#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This file contains the pipeline for loading a ML model, and exploring
the different RunInference metrics."""
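
# Example local run (the script file name "main.py" is an assumption here;
# the flags are defined in parse_arguments below):
#
#   python main.py --mode local --device CPU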
import argparse
import logging
import sys

import apache_beam as beam
import config as cfg
from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from pipeline.options import get_pipeline_options
from pipeline.transformations import CustomPytorchModelHandlerKeyedTensor
from pipeline.transformations import HuggingFaceStripBatchingWrapper
from pipeline.transformations import PostProcessor
from pipeline.transformations import Tokenize
from transformers import DistilBertConfig


def parse_arguments(argv):
    """
    Parses the arguments passed on the command line and
    returns them as a Namespace object.

    Args:
      argv: The arguments passed on the command line.

    Returns:
      The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="benchmark-runinference")
    parser.add_argument(
        "-m",
        "--mode",
        help="Mode to run the pipeline in.",
        choices=["local", "cloud"],
        default="local",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="GCP project to run the pipeline on.",
        default=cfg.PROJECT_ID,
    )
    parser.add_argument(
        "-d",
        "--device",
        help="Device to run the Dataflow job on.",
        choices=["CPU", "GPU"],
        default="CPU",
    )
    args, _ = parser.parse_known_args(args=argv)
    return args


def run():
    """
    Runs a pipeline that loads a transformer-based text classification model
    and runs inference on a list of sentences.
    At the end of the pipeline, different metrics, such as latency and
    throughput, are printed.
    """
    args = parse_arguments(sys.argv)
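
    # Repeating the sentences 1000x produces enough elements for the
    # throughput and latency metrics reported at the end to be meaningful.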
    inputs = [
        "This is the worst food I have ever eaten",
        "In my soul and in my heart, I’m convinced I’m wrong!",
        "Be with me always—take any form—drive me mad! "
        "only do not leave me in this abyss, where I cannot find you!",
        "Do I want to live? Would you like to live with your soul in the grave?",
        "Honest people don’t hide their deeds.",
        "Nelly, I am Heathcliff! He’s always, "
        "always in my mind: not as a pleasure, "
        "any more than I am always a pleasure to myself, but as my own being.",
    ] * 1000
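
    # The pipeline options combine the defaults from config.py with the CLI
    # flags; get_pipeline_options (in pipeline/options.py) translates them
    # into Beam/Dataflow pipeline options.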
    pipeline_options = get_pipeline_options(
        job_name=cfg.JOB_NAME,
        num_workers=cfg.NUM_WORKERS,
        project=args.project,
        mode=args.mode,
        device=args.device,
    )
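
    # On GPU, the stock Beam Pytorch handler is used; on CPU, a custom
    # handler defined in pipeline/transformations.py is swapped in.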
    model_handler_class = (
        PytorchModelHandlerKeyedTensor
        if args.device == "GPU" else CustomPytorchModelHandlerKeyedTensor)
    device = "cuda:0" if args.device == "GPU" else args.device
    model_handler = model_handler_class(
        state_dict_path=cfg.MODEL_STATE_DICT_PATH,
        model_class=HuggingFaceStripBatchingWrapper,
        model_params={
            "config": DistilBertConfig.from_pretrained(cfg.MODEL_CONFIG_PATH)
        },
        device=device,
    )
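
    # Wrapping the handler in a KeyedModelHandler lets the (key, tensor)
    # pairs produced upstream pass through inference with their keys intact.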
    with beam.Pipeline(options=pipeline_options) as pipeline:
        _ = (
            pipeline
            | "Create inputs" >> beam.Create(inputs)
            | "Tokenize" >> beam.ParDo(Tokenize(cfg.TOKENIZER_NAME))
            | "Inference" >>
            RunInference(model_handler=KeyedModelHandler(model_handler))
            | "Decode Predictions" >> beam.ParDo(PostProcessor()))
    metrics = pipeline.result.metrics().query(beam.metrics.MetricsFilter())
    logging.info(metrics)


if __name__ == "__main__":
    # The default log level is WARNING, which would hide the metrics
    # logged by run(), so raise the root logger to INFO first.
    logging.getLogger().setLevel(logging.INFO)
    run()