blob: 0128d6a6d6fc53d203a7ef2f457debf37ca9ef4e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG-specific embedding adapters.
This module provides adapters for extracting content from
EmbeddableItem instances and mapping embeddings back. Adapters
are used by EmbeddingsManager to support various input types.
"""
from collections.abc import Sequence
from typing import List
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
def create_text_adapter(
) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]:
"""Creates adapter for text content embedding.
Works with any EmbeddableItem that has text content
(content.text). Extracts text for embedding and maps
results back as Embedding objects.
Returns:
EmbeddingTypeAdapter configured for text embedding
"""
return EmbeddingTypeAdapter(
input_fn=_extract_text, output_fn=_add_embedding_fn)
# Backward compatibility alias.
create_rag_adapter = create_text_adapter
def _extract_text(items: Sequence[EmbeddableItem]) -> List[str]:
"""Extract text from items for embedding."""
texts = []
for item in items:
if not item.content.text:
raise ValueError(
f"Expected text content in {type(item).__name__} {item.id}, "
"got None")
texts.append(item.content.text)
return texts
def _add_embedding_fn(
items: Sequence[EmbeddableItem],
embeddings: Sequence[List[float]]) -> List[EmbeddableItem]:
"""Create Embeddings from items and embedding vectors."""
for item, embedding in zip(items, embeddings):
item.embedding = Embedding(dense_embedding=embedding)
return list(items)