blob: 3bb36930a0457dcf5e5908e88f0f209a300af364 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Streaming pipeline that uses RunInference to perform Sentiment analysis
with a DistilBERT model.
It reads input lines from a Pub/Sub topic (populated from a GCS file),
performs sentiment classification using a DistilBERT model via RunInference,
and writes results to BigQuery.
Pub/Sub topic/subscription creation and cleanup are handled programmatically.
"""
import argparse
import logging
import threading
import time
from collections.abc import Iterable
from google.cloud import pubsub_v1
import apache_beam as beam
import torch
import torch.nn.functional as F
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.runner import PipelineResult
from transformers import DistilBertConfig
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertTokenizerFast
class SentimentPostProcessor(beam.DoFn):
  """Turn a keyed ``PredictionResult`` into a flat sentiment record.

  The tokenizer passed to the constructor is only stored on the instance;
  ``process`` does not read it.
  """
  def __init__(self, tokenizer: DistilBertTokenizerFast):
    self.tokenizer = tokenizer

  def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
    """Yield one dict with ``text``, ``sentiment`` and ``confidence`` keys.

    Args:
      element: ``(text, prediction_result)`` pair; the ``'logits'`` entry of
        ``prediction_result.inference`` is what gets scored.
        NOTE(review): the indexing below assumes the logits tensor is 1-D
        (one score per class) -- confirm against the model handler output.

    Yields:
      A single record. Class index 1 is reported as ``'POSITIVE'``, every
      other index as ``'NEGATIVE'``; ``confidence`` is the softmax
      probability of the winning class.
    """
    original_text, result = element
    class_probs = F.softmax(result.inference['logits'], dim=-1)
    winner = torch.argmax(class_probs).item()
    if winner == 1:
      label = 'POSITIVE'
    else:
      label = 'NEGATIVE'
    record = {
        'text': original_text,
        'sentiment': label,
        'confidence': float(class_probs[winner].item()),
    }
    yield record
def tokenize_text(text: str,
                  tokenizer: DistilBertTokenizerFast) -> tuple[str, dict]:
  """Encode ``text`` with ``tokenizer`` and pair it with the resulting tensors.

  The tokenizer is asked for PyTorch tensors, padded and truncated to 128
  tokens. Every returned tensor is passed through ``torch.squeeze``, which
  drops all size-1 dimensions (presumably the leading batch dimension the
  tokenizer adds for a single string -- verify if shapes ever change).

  Args:
    text: The raw input string; it is returned unchanged as the key.
    tokenizer: Callable tokenizer producing a mapping of name -> tensor.

  Returns:
    ``(text, features)`` where ``features`` maps each tokenizer output name
    to its squeezed tensor.
  """
  encoded = tokenizer(
      text,
      return_tensors="pt",
      max_length=128,
      truncation=True,
      padding='max_length')
  features = {}
  for field_name, tensor in encoded.items():
    features[field_name] = torch.squeeze(tensor)
  return text, features
class RateLimitDoFn(beam.DoFn):
  """Pass elements through unchanged, sleeping before emitting each one.

  The per-element delay is ``1 / rate_per_sec`` seconds.
  NOTE(review): the sleep happens inside ``process``, so the limit applies to
  each caller of ``process`` separately; the aggregate rate presumably scales
  with however many instances the runner executes in parallel.
  """
  def __init__(self, rate_per_sec: float):
    """Store the delay derived from ``rate_per_sec``.

    Args:
      rate_per_sec: Target number of elements per second; must be positive.

    Raises:
      ValueError: If ``rate_per_sec`` is zero or negative.
    """
    # Reject bad rates up front: zero would divide by zero here, and a
    # negative rate would only surface later, when time.sleep() is handed a
    # negative delay.
    if rate_per_sec <= 0:
      raise ValueError(
          f'rate_per_sec must be a positive number, got {rate_per_sec!r}')
    self.delay = 1.0 / rate_per_sec

  def process(self, element):
    """Sleep for the configured delay, then yield ``element`` unchanged."""
    time.sleep(self.delay)
    yield element
def parse_known_args(argv):
  """Split ``argv`` into this example's options and leftover pipeline args.

  Args:
    argv: List of command-line tokens, or ``None`` to let argparse read
      ``sys.argv``.

  Returns:
    The ``(namespace, remaining_args)`` pair produced by
    ``ArgumentParser.parse_known_args``; tokens this parser does not
    recognise end up in ``remaining_args``.
  """
  # One (flag, add_argument keyword arguments) entry per supported option,
  # registered in this order.
  option_table = [
      ('--output_table', {
          'required': True,
          'help': 'Path to output BigQuery table',
      }),
      ('--model_path', {
          'default': 'distilbert-base-uncased-finetuned-sst-2-english',
          'help': 'Hugging Face model name or local path',
      }),
      ('--model_state_dict_path', {
          'dest': 'model_state_dict_path',
          'required': True,
          'help': "Path to the model's state_dict.",
      }),
      ('--input', {
          'required': True,
          'help': 'Path to input file on GCS',
      }),
      ('--pubsub_topic', {
          'default': 'projects/apache-beam-testing/topics/test_sentiment_topic',
          'help': 'Pub/Sub topic to publish text lines to',
      }),
      ('--pubsub_subscription', {
          'default': 'projects/apache-beam-testing/subscriptions/'
          'test_sentiment_subscription',
          'help': 'Pub/Sub subscription to read from',
      }),
      ('--project', {
          'default': 'apache-beam-testing',
          'help': 'GCP project ID',
      }),
      ('--mode', {
          'default': 'streaming',
          'choices': ['streaming', 'batch'],
      }),
      ('--rate_limit', {
          'type': float,
          'default': None,
          'help': 'Elements per second to send to Pub/Sub',
      }),
  ]

  arg_parser = argparse.ArgumentParser()
  for flag, settings in option_table:
    arg_parser.add_argument(flag, **settings)
  return arg_parser.parse_known_args(argv)
def ensure_pubsub_resources(
    project: str, topic_path: str, subscription_path: str):
  """Create the Pub/Sub topic and subscription if looking them up fails.

  Only the last ``/``-separated segment of each path is used; the full
  resource names are rebuilt under ``project``.

  Args:
    project: GCP project ID the resources are resolved against.
    topic_path: Topic name or path; its final segment is the topic ID.
    subscription_path: Subscription name or path; its final segment is the
      subscription ID.
  """
  publisher = pubsub_v1.PublisherClient()
  subscriber = pubsub_v1.SubscriberClient()

  topic_id = topic_path.rsplit("/", 1)[-1]
  subscription_id = subscription_path.rsplit("/", 1)[-1]
  topic_resource = publisher.topic_path(project, topic_id)
  subscription_resource = subscriber.subscription_path(
      project, subscription_id)

  # Any failure of the lookup is treated as "does not exist yet" and answered
  # with a create call.
  try:
    publisher.get_topic(request={"topic": topic_resource})
  except Exception:
    publisher.create_topic(name=topic_resource)

  try:
    subscriber.get_subscription(
        request={"subscription": subscription_resource})
  except Exception:
    subscriber.create_subscription(
        name=subscription_resource, topic=topic_resource)
def cleanup_pubsub_resources(
    project: str, topic_path: str, subscription_path: str):
  """Best-effort removal of the pipeline's Pub/Sub subscription and topic.

  The subscription is deleted first, then the topic. A failure of either
  delete is printed and does not stop the other one or raise to the caller.

  Args:
    project: GCP project ID the resources are resolved against.
    topic_path: Topic name or path; its final segment is the topic ID.
    subscription_path: Subscription name or path; its final segment is the
      subscription ID.
  """
  publisher = pubsub_v1.PublisherClient()
  subscriber = pubsub_v1.SubscriberClient()

  topic_name = topic_path.rsplit("/", 1)[-1]
  subscription_name = subscription_path.rsplit("/", 1)[-1]
  subscription_request = {
      "subscription": subscriber.subscription_path(project, subscription_name)
  }

  try:
    subscriber.delete_subscription(request=subscription_request)
    print(f"Deleted subscription: {subscription_name}")
  except Exception as e:
    print(f"Failed to delete subscription: {e}")

  topic_request = {"topic": publisher.topic_path(project, topic_name)}
  try:
    publisher.delete_topic(request=topic_request)
    print(f"Deleted topic: {topic_name}")
  except Exception as e:
    print(f"Failed to delete topic: {e}")
def override_or_add(args, flag, value):
  """Force ``flag`` to ``value`` in the argument list ``args``, in place.

  Both spellings of an option are recognised: the two-token form
  (``flag value``) and the single-token form (``flag=value``). Every
  occurrence is rewritten so that no stale value can take precedence; if the
  flag is absent, ``flag`` and ``value`` are appended as two tokens.

  Args:
    args: Mutable list of command-line tokens; modified in place.
    flag: Option name including leading dashes, e.g. ``'--num_workers'``.
    value: New value; stored as ``str(value)``.
  """
  text = str(value)
  inline_prefix = flag + '='
  found = False
  idx = 0
  while idx < len(args):
    token = args[idx]
    if token == flag:
      found = True
      if idx + 1 < len(args):
        args[idx + 1] = text
      else:
        # The flag is the last token and has no value slot to overwrite.
        args.append(text)
      # Skip the value just written so it is never mistaken for a flag.
      idx += 2
    elif token.startswith(inline_prefix):
      found = True
      args[idx] = inline_prefix + text
      idx += 1
    else:
      idx += 1
  if not found:
    args.extend([flag, text])
def run_load_pipeline(known_args, pipeline_args):
  """Load data pipeline: read lines from GCS file and send to Pub/Sub.

  Non-empty lines of ``known_args.input`` are optionally throttled (when
  ``known_args.rate_limit`` is truthy), UTF-8 encoded and published to
  ``known_args.pubsub_topic``.

  Args:
    known_args: Parsed options; ``input``, ``pubsub_topic``, ``project`` and
      ``rate_limit`` are read here.
    pipeline_args: Remaining command-line tokens used as the base for this
      pipeline's options. The list is not modified.

  Returns:
    The result object returned by ``pipeline.run()``.
  """
  # Work on a private copy: the caller keeps using the same list for the main
  # pipeline, so the overrides below must not leak into it.
  load_args = list(pipeline_args)
  override_or_add(load_args, '--device', 'CPU')
  override_or_add(load_args, '--num_workers', '5')
  override_or_add(load_args, '--max_num_workers', '10')
  # The timestamp suffix gives every launch a distinct job name.
  override_or_add(
      load_args, '--job_name', f"sentiment-load-pubsub-{int(time.time())}")
  override_or_add(load_args, '--project', known_args.project)

  pipeline_options = PipelineOptions(load_args)
  pipeline = beam.Pipeline(options=pipeline_options)

  lines = (
      pipeline
      | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input)
      | 'FilterEmpty' >> beam.Filter(lambda line: line.strip()))

  if known_args.rate_limit:
    lines = (
        lines
        | 'RateLimit' >> beam.ParDo(
            RateLimitDoFn(rate_per_sec=known_args.rate_limit)))

  _ = (
      lines
      | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8'))
      |
      'PublishToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic))

  return pipeline.run()
def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """Build and run the sentiment pipeline, then remove the Pub/Sub resources.

  In ``streaming`` mode the pipeline reads from the Pub/Sub subscription and a
  background thread launches the load pipeline after a 100 second delay; in
  ``batch`` mode it reads the GCS input file directly. Results are written to
  the BigQuery table named by ``--output_table``.

  Args:
    argv: Command-line tokens, or ``None`` to let argparse read ``sys.argv``.
    save_main_session: Value assigned to ``SetupOptions.save_main_session``.
    test_pipeline: Optional pipeline to use instead of constructing a new
      ``beam.Pipeline`` from the parsed options.

  Returns:
    The ``PipelineResult`` of the main pipeline, after waiting up to 30
    minutes for it and then requesting cancellation.
  """
  known_args, pipeline_args = parse_known_args(argv)
  is_streaming = known_args.mode == 'streaming'

  ensure_pubsub_resources(
      project=known_args.project,
      topic_path=known_args.pubsub_topic,
      subscription_path=known_args.pubsub_subscription)

  # Set once this function is finished (normally or through an exception) so
  # that a loader still waiting for its delay does not start afterwards.
  loader_cancelled = threading.Event()
  try:
    if is_streaming:
      # The loader gets its own copy so that its option overrides cannot
      # affect the list used for the main pipeline below.
      loader_args = list(pipeline_args)

      def _delayed_load():
        # Event.wait returns False when the timeout elapses without the event
        # being set, i.e. the main pipeline is still in progress.
        if not loader_cancelled.wait(100):
          run_load_pipeline(known_args, loader_args)

      threading.Thread(target=_delayed_load, daemon=True).start()

    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(
        SetupOptions).save_main_session = save_main_session
    pipeline_options.view_as(StandardOptions).streaming = is_streaming
    if is_streaming:
      method = beam.io.WriteToBigQuery.Method.STREAMING_INSERTS
    else:
      method = beam.io.WriteToBigQuery.Method.FILE_LOADS

    model_handler = PytorchModelHandlerKeyedTensor(
        model_class=DistilBertForSequenceClassification,
        model_params={'config': DistilBertConfig(num_labels=2)},
        state_dict_path=known_args.model_state_dict_path,
        device='GPU')
    tokenizer = DistilBertTokenizerFast.from_pretrained(known_args.model_path)

    pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)

    # Main pipeline: read, process, write result to BigQuery output table
    if not is_streaming:
      lines = (
          pipeline
          | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input)
          | 'FilterEmpty' >> beam.Filter(lambda line: line.strip()))
    else:
      lines = (
          pipeline
          | 'ReadFromPubSub' >>
          beam.io.ReadFromPubSub(subscription=known_args.pubsub_subscription)
          | 'DecodeText' >> beam.Map(lambda x: x.decode('utf-8'))
          | 'WindowedOutput' >> beam.WindowInto(
              beam.window.FixedWindows(60),
              trigger=beam.trigger.AfterProcessingTime(30),
              accumulation_mode=beam.trigger.AccumulationMode.DISCARDING,
              allowed_lateness=0))

    _ = (
        lines
        | 'Tokenize' >> beam.Map(lambda text: tokenize_text(text, tokenizer))
        | 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
        | 'PostProcess' >> beam.ParDo(SentimentPostProcessor(tokenizer))
        | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
            known_args.output_table,
            schema='text:STRING, sentiment:STRING, confidence:FLOAT',
            write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
            method=method))

    result = pipeline.run()
    result.wait_until_finish(duration=1800000)  # 30 min
    result.cancel()
  finally:
    # Runs on failure as well, so the topic and subscription created above
    # are not left behind when building or running the pipeline raises.
    loader_cancelled.set()
    cleanup_pubsub_resources(
        project=known_args.project,
        topic_path=known_args.pubsub_topic,
        subscription_path=known_args.pubsub_subscription)
  return result
if __name__ == '__main__':
  # Script entry point: raise the root logger to INFO, then run the pipeline
  # with arguments taken from the command line.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  run()