from typing import Union
import torch
from datasets import Dataset, load_dataset # noqa: F401
from datasets.formatting.formatting import LazyBatch
from sentence_transformers import SentenceTransformer
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
pipeline,
)
from transformers.pipelines import base
from hamilton.function_modifiers import load_from, save_to, source, value
# Could explicitly load the dataset this way
# def medium_articles() -> Dataset:
# """Loads medium dataset into a hugging face dataset"""
# ds = load_dataset(
# "fabiochiu/medium-articles",
# data_files="medium_articles.csv",
# split="train"
# )
# return ds
@load_from.hf_dataset(
path=value("fabiochiu/medium-articles"),
data_files=value("medium_articles.csv"),
split=value("train"),
)
def medium_articles(dataset: Dataset) -> Dataset:
"""Loads medium dataset into a hugging face dataset"""
return dataset
def sampled_articles(
medium_articles: Dataset,
sample_size: int = 104,
random_state: int = 32,
max_text_length: int = 1000,
) -> Dataset:
"""Samples the articles and does some light transformations.
Transformations:
- selects the first 1000 characters of text. This is for performance here. But in real life you'd \
do something for your use case.
- Joins article title and the text to create one text string.
"""
# Filter out entries with NaN values in 'text' or 'title' fields
dataset = medium_articles.filter(
lambda example: example["text"] is not None and example["title"] is not None
)
    # Shuffle and take the first `sample_size` samples
    dataset = dataset.shuffle(seed=random_state).select(range(sample_size))
    # Truncate the 'text' to the first `max_text_length` characters
    dataset = dataset.map(lambda example: {"text": example["text"][:max_text_length]})
# Concatenate the 'title' and truncated 'text'
dataset = dataset.map(lambda example: {"title_text": example["title"] + ". " + example["text"]})
return dataset
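
# For illustration, the final map above turns a row like
#   {"title": "Intro to NER", "text": "Named entity recognition is..."}
# into one that also has
#   {"title_text": "Intro to NER. Named entity recognition is..."}
# (the title and text values here are made up).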
def device() -> str:
"""Whether this is a CUDA or CPU enabled device."""
return "cuda" if torch.cuda.is_available() else "cpu"
def NER_model_id() -> str:
"""Model ID to use
To extract named entities, we will use a NER model finetuned on a BERT-base model.
The model can be loaded from the HuggingFace model hub.
Use `overrides={"NER_model_id": VALUE}` to switch this without changing code.
"""
return "dslim/bert-base-NER"
def tokenizer(NER_model_id: str) -> PreTrainedTokenizer:
"""Loads the tokenizer for the NER model ID from huggingface"""
return AutoTokenizer.from_pretrained(NER_model_id)
def model(NER_model_id: str) -> PreTrainedModel:
"""Loads the NER model from huggingface"""
return AutoModelForTokenClassification.from_pretrained(NER_model_id)
def ner_pipeline(
model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: str
) -> base.Pipeline:
"""Loads the tokenizer and model into a NER pipeline. That is it combines them."""
device_no = torch.cuda.current_device() if device == "cuda" else None
return pipeline(
"ner", model=model, tokenizer=tokenizer, aggregation_strategy="max", device=device_no
)
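
# Illustrative output of the pipeline (a sketch; exact scores and spans will vary):
#
#   ner_pipeline("Hugging Face is based in New York City")
#   # -> [{"entity_group": "ORG", "word": "Hugging Face", "score": 0.99, ...},
#   #     {"entity_group": "LOC", "word": "New York City", "score": 0.99, ...}]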
def retriever(
device: str, retriever_model_id: str = "flax-sentence-embeddings/all_datasets_v3_mpnet-base"
) -> SentenceTransformer:
"""Our retriever model to create embeddings.
A retriever model is used to embed passages (article title + first 1000 characters)
and queries. It creates embeddings such that queries and passages with similar
meanings are close in the vector space. We will use a sentence-transformer model
as our retriever. The model can be loaded as follows:
"""
return SentenceTransformer(retriever_model_id, device=device)
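
# Example usage (a sketch; this mpnet-base model emits 768-dimensional vectors):
#
#   embeddings = retriever.encode(["passage one", "passage two"])
#   embeddings.shape  # -> (2, 768)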
def _extract_named_entities_text(
    title_text_batch: Union[LazyBatch, list[str]], _ner_pipeline: base.Pipeline
) -> list[list[str]]:
"""Helper function to extract named entities given a batch of text."""
    # extract named entities using the NER pipeline
    extracted_batch = _ner_pipeline(title_text_batch)
    # loop through the results and only select the entity names
    entities = []
    for text in extracted_batch:
        ne = [entity["word"] for entity in text]
        entities.append(ne)
    # deduplicate entity names within each document
    _named_entities = [list(set(entity)) for entity in entities]
    return _named_entities
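
# For example, _extract_named_entities_text(["Angela Merkel visited Paris"], ner)
# might return [["Angela Merkel", "Paris"]]; order within each inner list is not
# guaranteed because duplicates are removed via a set.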
def _batch_map(
    dataset: LazyBatch, _retriever: SentenceTransformer, _ner_pipeline: base.Pipeline
) -> dict:
    """Helper function to create the embedding vectors and extract named entities."""
title_text_list = dataset["title_text"]
emb = _retriever.encode(title_text_list)
_named_entities = _extract_named_entities_text(title_text_list, _ner_pipeline)
return {
"vector": emb,
"named_entities": _named_entities,
}
def columns_of_interest() -> list[str]:
"""The columns we expect to pull from the dataset to be saved to lancedb"""
return ["vector", "named_entities", "title", "url", "authors", "timestamp", "tags"]
@save_to.lancedb(
db_client=source("db_client"),
table_name=source("table_name"),
columns_to_write=source("columns_of_interest"),
output_name_="load_into_lancedb",
)
def final_dataset(
sampled_articles: Dataset,
retriever: SentenceTransformer,
ner_pipeline: base.Pipeline,
) -> Dataset:
"""The final dataset to be pushed to lancedb.
This adds two columns:
- vector -- the vector embedding
- named_entities -- the names of entities extracted from the text
"""
# goes over the data in batches so that the GPU can be properly utilized.
final_ds = sampled_articles.map(
_batch_map,
batched=True,
fn_kwargs={"_retriever": retriever, "_ner_pipeline": ner_pipeline},
desc="extracting entities",
)
return final_ds
def named_entities(query: str, ner_pipeline: base.Pipeline) -> list[str]:
"""The entities to extract from the query via the pipeline."""
return _extract_named_entities_text([query], ner_pipeline)[0]
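
# A minimal end-to-end sketch of running this module (assumptions: `lancedb` is
# installed, this module is importable as `pipeline_module`, and the DB path and
# table name below are illustrative):
#
#   import lancedb
#   from hamilton import driver
#   import pipeline_module
#
#   dr = driver.Builder().with_modules(pipeline_module).build()
#   dr.execute(
#       ["load_into_lancedb"],
#       inputs={"db_client": lancedb.connect("./.lancedb"), "table_name": "articles"},
#   )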