| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from collections.abc import Callable |
| from dataclasses import dataclass |
| from typing import Any |
| from typing import Dict |
| from typing import Optional |
| |
| import apache_beam as beam |
| from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict |
| from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema |
| from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig |
| from apache_beam.ml.rag.types import Chunk |
| from apache_beam.typehints.row_type import RowTypeConstraint |
| |
# Signature of a user-supplied function that converts a Chunk into a dict
# whose keys match the column names of the target BigQuery table schema.
ChunkToDictFn = Callable[[Chunk], Dict[str, Any]]
| |
| |
@dataclass
class SchemaConfig:
  """Custom BigQuery table schema paired with its Chunk-to-row converter.

  Pass an instance to ``BigQueryVectorWriterConfig`` to replace the default
  id/embedding/content/metadata table layout with any table structure you
  need. The two attributes must agree: every dict produced by
  ``chunk_to_dict_fn`` should have keys that match the field names declared
  in ``schema``.

  Attributes:
    schema: BigQuery TableSchema expressed as a plain dict, for example::

        {
            'fields': [
                {'name': 'id', 'type': 'STRING'},
                {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
                {'name': 'custom_field', 'type': 'STRING'},
            ]
        }

    chunk_to_dict_fn: Callable taking a Chunk and returning a
      ``Dict[str, Any]`` keyed by the schema's field names, for example::

        def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]:
          return {
              'id': chunk.id,
              'embedding': chunk.embedding.dense_embedding,
              'custom_field': chunk.metadata.get('custom_field'),
          }
  """
  # TableSchema dict ({'fields': [...]}) describing the destination table.
  schema: Dict
  # Converts one Chunk into a row dict whose keys follow `schema`.
  chunk_to_dict_fn: ChunkToDictFn
| |
| |
class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
  """Write configuration for storing Chunks in a BigQuery table."""
  def __init__(
      self,
      write_config: Dict[str, Any],
      *,  # Force keyword arguments
      schema_config: Optional[SchemaConfig] = None):
    """Configuration for writing vectors to BigQuery using managed transforms.

    Supports both default schema (id, embedding, content, metadata columns) and
    custom schemas through SchemaConfig.

    Example with default schema:
      >>> config = BigQueryVectorWriterConfig(
      ...     write_config={'table': 'project.dataset.embeddings'})

    Example with custom schema:
      >>> schema_config = SchemaConfig(
      ...   schema={
      ...     'fields': [
      ...       {'name': 'id', 'type': 'STRING'},
      ...       {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
      ...       {'name': 'source_url', 'type': 'STRING'}
      ...     ]
      ...   },
      ...   chunk_to_dict_fn=lambda chunk: {
      ...       'id': chunk.id,
      ...       'embedding': chunk.embedding.dense_embedding,
      ...       'source_url': chunk.metadata.get('url')
      ...   }
      ... )
      >>> config = BigQueryVectorWriterConfig(
      ...   write_config={'table': 'project.dataset.embeddings'},
      ...   schema_config=schema_config
      ... )

    Args:
      write_config: BigQuery write configuration dict. Must include 'table'.
        Other options like create_disposition, write_disposition can be
        specified. The dict is passed unchanged as the ``config`` of the
        managed BigQuery write.
      schema_config: Optional configuration for custom schema and row
        conversion.
        If not provided, uses default schema with id, embedding, content and
        metadata columns.

    Raises:
      ValueError: If write_config is missing/empty or doesn't include a
        table specification.
    """
    # Check for a missing config first: `'table' in None` would otherwise
    # fail with an unhelpful TypeError instead of the documented ValueError.
    if not write_config or 'table' not in write_config:
      raise ValueError("write_config must be provided with 'table' specified")

    self.write_config = write_config
    self.schema_config = schema_config

  def create_write_transform(self) -> beam.PTransform:
    """Creates transform to write to BigQuery.

    Returns:
      A PTransform that converts each Chunk into a BigQuery row (using
      ``schema_config`` when set, otherwise the default schema) and writes it
      with the managed BigQuery sink configured by ``write_config``.
    """
    return _WriteToBigQueryVectorDatabase(self)
| |
| |
def _default_chunk_to_dict_fn(chunk: Chunk):
  """Builds a row dict for the default BigQuery vector table layout.

  The produced keys are ``id``, ``embedding`` (the dense embedding),
  ``content`` (the chunk text) and ``metadata``. Metadata is flattened into a
  list of ``{"key": ..., "value": ...}`` records with every value passed
  through ``str()``, so it fits a repeated RECORD column of STRING key/value.

  Args:
    chunk: The Chunk to convert.

  Returns:
    A dict with the four keys described above.

  Raises:
    ValueError: If the chunk has no embedding or no dense embedding.
  """
  has_dense_embedding = (
      chunk.embedding is not None and
      chunk.embedding.dense_embedding is not None)
  if not has_dense_embedding:
    raise ValueError("chunk must contain dense embedding")

  # Fill the row in column order so key order matches the default schema.
  row = {}
  row['id'] = chunk.id
  row['embedding'] = chunk.embedding.dense_embedding
  row['content'] = chunk.content.text
  row['metadata'] = [{
      "key": meta_key, "value": str(meta_value)
  } for meta_key, meta_value in chunk.metadata.items()]
  return row
| |
| |
| def _default_schema(): |
| return { |
| 'fields': [{ |
| 'name': 'id', 'type': 'STRING' |
| }, { |
| 'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED' |
| }, { |
| 'name': 'content', 'type': 'STRING' |
| }, |
| { |
| 'name': 'metadata', |
| 'type': 'RECORD', |
| 'mode': 'REPEATED', |
| 'fields': [{ |
| 'name': 'key', 'type': 'STRING' |
| }, { |
| 'name': 'value', 'type': 'STRING' |
| }] |
| }] |
| } |
| |
| |
class _WriteToBigQueryVectorDatabase(beam.PTransform):
  """Implementation of BigQuery vector database write.

  Converts each Chunk to a dict (with the configured or default conversion
  function), turns that dict into a Beam Row typed from the BigQuery table
  schema, and writes the rows through the managed BigQuery sink.
  """
  def __init__(self, config: BigQueryVectorWriterConfig):
    # Let PTransform initialize its own base state (e.g. its label) before
    # attaching our configuration.
    super().__init__()
    self.config = config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Converts chunks to schema'd rows and writes them to BigQuery.

    Args:
      pcoll: PCollection of Chunks to write.

    Returns:
      The output of applying ``beam.managed.Write`` for BigQuery with the
      configured ``write_config``.
    """
    # Resolve schema and converter together so they always come from the
    # same source (custom SchemaConfig vs. the built-in defaults).
    schema_config = self.config.schema_config
    if schema_config is not None:
      schema = schema_config.schema
      chunk_to_dict_fn = schema_config.chunk_to_dict_fn
    else:
      schema = _default_schema()
      chunk_to_dict_fn = _default_chunk_to_dict_fn

    # Beam row type derived from the BigQuery schema; declared as the output
    # type of the row-conversion step.
    row_type = RowTypeConstraint.from_fields(
        get_beam_typehints_from_tableschema(schema))

    return (
        pcoll
        | "Chunk to dict" >> beam.Map(chunk_to_dict_fn)
        | "Chunk dict to schema'd row" >> beam.Map(
            lambda chunk_dict: beam_row_from_dict(
                row=chunk_dict, schema=schema)).with_output_types(row_type)
        | "Write to BigQuery" >> beam.managed.Write(
            beam.managed.BIGQUERY, config=self.config.write_config))