blob: a6905c375451486a775e35c33a451d04f39555f1 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["SentenceTransformerEmbeddings", "InferenceAPIEmbeddings"]
import logging
import os
from collections.abc import Callable
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import Optional
import requests
import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
try:
from sentence_transformers import SentenceTransformer
except ImportError:
SentenceTransformer = None
_LOGGER = logging.getLogger(__name__)
# TODO: https://github.com/apache/beam/issues/29621
# Use HuggingFaceModelHandlerTensor once the import issue is fixed.
# Right now, the Hugging Face model handler imports torch and tensorflow
# at the same time, which adds too much weight to the container unnecessarily.
class _SentenceTransformerModelHandler(ModelHandler):
    """
    ModelHandler that loads a sentence-transformers model and embeds batches
    of inputs via ``model.encode``.

    Note: Intended for internal use and guarantees no backwards compatibility.
    """
    def __init__(
        self,
        model_name: str,
        model_class: Callable,
        load_model_args: Optional[dict] = None,
        min_batch_size: Optional[int] = None,
        max_batch_size: Optional[int] = None,
        max_seq_length: Optional[int] = None,
        large_model: bool = False,
        **kwargs):
        """
        Args:
          model_name: Name of the sentence-transformers model to load.
          model_class: Callable used to construct the model
            (normally ``SentenceTransformer``).
          load_model_args: Extra keyword arguments forwarded to
            ``model_class`` when the model is loaded.
          min_batch_size: Minimum batch size for inference, if set.
          max_batch_size: Maximum batch size for inference, if set.
          max_seq_length: If set, assigned to ``model.max_seq_length``
            after loading.
          large_model: Whether to share the model across processes.
        """
        # Fail fast before storing state: the handler is unusable without
        # the sentence-transformers package.
        if not SentenceTransformer:
            raise ImportError(
                "sentence-transformers is required to use "
                "SentenceTransformerEmbeddings. "
                "Please install it with `pip install sentence-transformers`.")
        self._max_seq_length = max_seq_length
        self.model_name = model_name
        self._model_class = model_class
        # Default to an empty dict: the original default of None made
        # load_model crash on `**None` whenever no args were supplied.
        self._load_model_args = load_model_args or {}
        self._min_batch_size = min_batch_size
        self._max_batch_size = max_batch_size
        self._large_model = large_model
        self._kwargs = kwargs

    def run_inference(
        self,
        batch: Sequence[str],
        model: SentenceTransformer,
        inference_args: Optional[dict[str, Any]] = None,
    ):
        """Embed a batch of inputs with ``model.encode``."""
        inference_args = inference_args or {}
        return model.encode(batch, **inference_args)

    def load_model(self):
        """Construct the model and apply the optional max sequence length."""
        model = self._model_class(self.model_name, **self._load_model_args)
        if self._max_seq_length:
            model.max_seq_length = self._max_seq_length
        return model

    def share_model_across_processes(self) -> bool:
        return self._large_model

    def batch_elements_kwargs(self) -> Mapping[str, Any]:
        """Return only the batch-size bounds that were explicitly set."""
        batch_sizes = {}
        if self._min_batch_size:
            batch_sizes["min_batch_size"] = self._min_batch_size
        if self._max_batch_size:
            batch_sizes["max_batch_size"] = self._max_batch_size
        return batch_sizes

    def __repr__(self) -> str:
        # ModelHandler is internal to the user and is not exposed.
        # Hence we need to override the __repr__ method to expose
        # the name of the class.
        return 'SentenceTransformerEmbeddings'
class SentenceTransformerEmbeddings(EmbeddingsManager):
    def __init__(
        self,
        model_name: str,
        columns: list[str],
        max_seq_length: Optional[int] = None,
        image_model: bool = False,
        **kwargs):
        """
        Embedding config for sentence-transformers, for use with MLTransform
        to embed text (or image) data. The model itself is loaded and run by
        the RunInference PTransform through an internal ModelHandler.

        Args:
          model_name: Name of the model to use. The model should be hosted on
            HuggingFace Hub or compatible with sentence_transformers. For image
            embedding models, see
            https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
            for a list of available sentence_transformers models.
          columns: List of columns to be embedded.
          max_seq_length: Max sequence length to use for the model if applicable.
          image_model: Whether the model is generating image embeddings.
          min_batch_size: The minimum batch size to be used for inference.
          max_batch_size: The maximum batch size to be used for inference.
          large_model: Whether to share the model across processes.
        """
        super().__init__(columns=columns, **kwargs)
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.image_model = image_model

    def get_model_handler(self):
        """Build the internal sentence-transformers model handler."""
        return _SentenceTransformerModelHandler(
            model_name=self.model_name,
            model_class=SentenceTransformer,
            load_model_args=self.load_model_args,
            max_seq_length=self.max_seq_length,
            min_batch_size=self.min_batch_size,
            max_batch_size=self.max_batch_size,
            large_model=self.large_model)

    def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
        # Wrap the model handler in the embedding handler matching the data
        # modality so that inputs get some type checking.
        wrapper = _ImageEmbeddingHandler if self.image_model else _TextEmbeddingHandler
        return RunInference(
            model_handler=wrapper(self),
            inference_args=self.inference_args,
        )
class _InferenceAPIHandler(ModelHandler):
    """
    ModelHandler that sends batches to HuggingFace's Inference API over HTTP.

    Note: Intended for internal use and guarantees no backwards compatibility.
    """
    def __init__(self, config: 'InferenceAPIEmbeddings'):
        super().__init__()
        self._config = config

    def load_model(self):
        """Create a requests.Session carrying the Authorization header.

        If the token was not provided at construction time, it might have
        been provided with a custom container, in which case we can read it
        from the HF_TOKEN environment variable at runtime.
        """
        session = requests.Session()
        if not self._config.hf_token:
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                raise ValueError(
                    'HF_TOKEN environment variable not set. '
                    'Please set the environment variable or pass the token as an '
                    'argument.')
            session.headers.update({"Authorization": f"Bearer {hf_token}"})
            return session
        session.headers.update(self._config.authorization_token)
        return session

    def run_inference(
        self, batch, session: requests.Session, inference_args=None):
        """POST the batch to the configured API URL and return the JSON body."""
        # Rely on the session's Authorization header set in load_model.
        # Passing config.authorization_token here explicitly would override
        # it with "Bearer None" whenever the token was sourced from the
        # HF_TOKEN environment variable instead of the config.
        response = session.post(
            self._config.api_url,
            json={
                "inputs": batch, "options": inference_args
            })
        return response.json()
class InferenceAPIEmbeddings(EmbeddingsManager):
    """
    Feature extraction using HuggingFace's Inference API.
    Intended to be used for feature-extraction. For other tasks, please
    refer to https://huggingface.co/inference-api.
    Args:
      hf_token: HuggingFace token.
      columns: List of columns to be embedded.
      model_name: Model name used for feature extraction.
      api_url: API url for feature extraction. If specified, model_name will be
        ignored. If none, the default url for feature extraction
        will be used.
    """
    def __init__(
        self,
        hf_token: Optional[str],
        columns: list[str],
        model_name: Optional[str] = None,  # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long
        api_url: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(columns=columns, **kwargs)
        # NOTE: when hf_token is None this header is "Bearer None"; the
        # handler's load_model falls back to the HF_TOKEN env var in that case.
        self._authorization_token = {"Authorization": f"Bearer {hf_token}"}
        self._model_name = model_name
        self.hf_token = hf_token
        if not api_url:
            # No explicit URL: derive the feature-extraction endpoint from
            # the model name, which is then required.
            if not self._model_name:
                raise ValueError("Either api_url or model_name must be provided.")
            self._api_url = (
                f"https://router.huggingface.co/hf-inference/models/{self._model_name}/pipeline/feature-extraction"  # pylint: disable=line-too-long
            )
        else:
            self._api_url = api_url
        # Fix: the URL argument was missing, so the literal "%s" was logged.
        _LOGGER.info("HuggingFace API URL: %s", self._api_url)

    def get_token(self):
        """Return the HF_TOKEN environment variable, if set."""
        return os.environ.get('HF_TOKEN')

    @property
    def api_url(self):
        return self._api_url

    @property
    def authorization_token(self):
        return self._authorization_token

    def get_model_handler(self):
        return _InferenceAPIHandler(self)

    def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
        options = {
            # Sometimes the model is not ready and returns an error response
            # instead of waiting. So we wait for the model to be ready.
            'wait_for_model': True,
            'use_cache': True,
        }
        self.inference_args.update(options)
        return (
            RunInference(
                model_handler=_TextEmbeddingHandler(self),
                inference_args=self.inference_args))