blob: 5afe3162820572e1ea1b18f6668f1889c3813b4a [file] [log] [blame]
from __future__ import annotations
import json
import click
import data_module
import embedding_module
from hamilton import base, driver
@click.command()
@click.option(
"--vector_db",
type=click.Choice(["lancedb", "weaviate", "pinecone", "marqo"], case_sensitive=False),
default="lancedb",
help="Vector database service",
)
@click.option(
"--vector_db_config",
default='{"uri": "data/lancedb"}',
help="Pass a JSON string for vector database config.\
Weaviate needs a dictionary {'url': ''}\
Pinecone needs dictionary {'environment': '', 'api_key': ''}",
)
@click.option(
"--embedding_service",
type=click.Choice(["openai", "cohere", "sentence_transformer", "marqo"], case_sensitive=False),
default="sentence_transformer",
help="Text embedding service.",
)
@click.option(
"--embedding_service_api_key",
default=None,
help="API Key for embedding service. Needed if using OpenAI or Cohere.",
)
@click.option("--model_name", default=None, help="Text embedding model name.")
@click.option("--display_dag", is_flag=True, help="Generate a .png of the Hamilton DAG")
@click.option(
"--other_input_kwargs", default=None, help="Pass in other input kwargs as a JSON string."
)
def main(
vector_db: str,
vector_db_config: str,
embedding_service: str,
embedding_service_api_key: str | None,
model_name: str,
display_dag: bool,
other_input_kwargs: str,
):
if vector_db == "weaviate":
import weaviate_module
vector_db_module = weaviate_module
elif vector_db == "pinecone":
import pinecone_module
vector_db_module = pinecone_module
elif vector_db == "lancedb":
import lancedb_module
vector_db_module = lancedb_module
elif vector_db == "marqo":
import marqo_module
vector_db_module = marqo_module
else:
raise ValueError(f"Unknown vector_db {vector_db}")
if model_name is None:
if embedding_service == "openai":
model_name = "text-embedding-ada-002"
elif embedding_service == "cohere":
model_name = "embed-english-light-v2.0"
elif embedding_service == "sentence_transformer":
model_name = "multi-qa-MiniLM-L6-cos-v1"
config = dict(
vector_db_config=json.loads(vector_db_config),
embedding_service=embedding_service, # this triggers config.when() in embedding_module
api_key=embedding_service_api_key,
model_name=model_name,
)
if other_input_kwargs is None:
other_input_kwargs = "{}"
input_dict = json.loads(other_input_kwargs)
input_dict.update(dict(class_name="SQuADEntry"))
dr = driver.Driver(
config,
data_module,
vector_db_module, # this points to one of the VectorDB modules
embedding_module,
adapter=base.DefaultAdapter(),
)
# The `final_vars` requested are functions with side-effects
dr.execute(
final_vars=["initialize_vector_db_indices", "push_to_vector_db"],
inputs=input_dict,
)
if display_dag:
filename = f"{vector_db}_{embedding_service}_dag"
dr.display_all_functions(filename, {"format": "png"})
if __name__ == "__main__":
main()