blob: f6e33e5be78698b8e1a41cb616a22b092be64100 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import logging
import os
import tempfile
import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler
from apache_beam.ml.inference.vllm_inference import _VLLMModelServer
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
class GemmaVLLMOptions(PipelineOptions):
    """Custom pipeline options for the Gemma vLLM batch inference job."""

    @classmethod
    def _add_argparse_args(cls, parser):
        # Register the three required job flags on Beam's argparse parser.
        add = parser.add_argument
        add(
            "--input",
            dest="input_file",
            required=True,
            help="Input file gs://path containing prompts.")
        add(
            "--output_table",
            required=True,
            help="BigQuery table to write to in the form project:dataset.table.")
        add(
            "--model_gcs_path",
            required=True,
            help="GCS path to the directory containing model files.")
class FormatOutput(beam.DoFn):
    """Turns one inference result into a BigQuery-ready row dict."""

    def process(self, element):
        """Yield a {'prompt', 'completion'} dict for a PredictionResult.

        The completion text is extracted with a chain of fallbacks:
        OpenAI-style `.choices[0].text` first, then a bare `.text`
        attribute, and finally the plain string form of the object.
        """
        result = element.inference
        if hasattr(result, 'choices'):
            text = result.choices[0].text
        elif hasattr(result, 'text'):
            text = result.text
        else:
            text = str(result)
        yield {'prompt': element.example, 'completion': text}
class GcsVLLMCompletionsModelHandler(VLLMCompletionsModelHandler):
    """VLLMCompletionsModelHandler that also accepts a gs:// model directory.

    When ``model_name`` is a ``gs://`` path, the model directory is first
    copied into a per-worker temp directory and vLLM is started from that
    local copy; any other name is delegated unchanged to the parent
    handler (e.g. a Hugging Face hub model id).
    """

    # Copy files in fixed-size chunks so multi-GB weight shards are never
    # buffered whole in worker memory (the original `src.read()` was an
    # OOM risk for large models).
    _COPY_CHUNK_BYTES = 64 * 1024 * 1024

    def __init__(self, model_name, vllm_server_kwargs=None):
        """Init.

        Args:
          model_name: gs:// path to a model directory, or a model id
            understood by the parent handler.
          vllm_server_kwargs: optional dict of extra vLLM server flags,
            forwarded to the parent / the vLLM server.
        """
        super().__init__(model_name, vllm_server_kwargs)
        # Set lazily in load_model() when the model comes from GCS.
        self._local_model_dir = None

    def _download_gcs_directory(self, gcs_path: str, local_path: str):
        """Recursively copy every object under gcs_path into local_path."""
        logging.info("Downloading model from %s to %s…", gcs_path, local_path)
        matches = FileSystems.match([os.path.join(gcs_path, "**")])[0].metadata_list
        for md in matches:
            # Mirror the object's path relative to gcs_path under local_path.
            rel = os.path.relpath(md.path, gcs_path)
            dst = os.path.join(local_path, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            with FileSystems.open(md.path) as src, open(dst, "wb") as dstf:
                # Stream instead of reading the whole object at once.
                while chunk := src.read(self._COPY_CHUNK_BYTES):
                    dstf.write(chunk)
        logging.info("Download complete.")

    def load_model(self) -> _VLLMModelServer:
        """Start the vLLM server, localizing the model from GCS first if needed."""
        uri = self._model_name
        if not uri.startswith("gs://"):
            logging.info("Loading vLLM from HF hub: %s", uri)
            return super().load_model()
        self._local_model_dir = tempfile.mkdtemp(prefix="vllm_model_")
        self._download_gcs_directory(uri, self._local_model_dir)
        logging.info("Loading vLLM from local dir %s", self._local_model_dir)
        return _VLLMModelServer(self._local_model_dir, self._vllm_server_kwargs)
def run(argv=None, save_main_session=True, test_pipeline=None):
    """Builds and runs the prompts -> vLLM inference -> BigQuery pipeline.

    Args:
      argv: command-line arguments forwarded to PipelineOptions.
      save_main_session: pickle the __main__ session so module-level
        definitions are available on remote workers.
      test_pipeline: optional pre-built pipeline, used by tests in place
        of a fresh beam.Pipeline.

    Returns:
      The PipelineResult attached to the pipeline after it has run.
    """
    # Build pipeline options
    opts = PipelineOptions(argv)
    gem = opts.view_as(GemmaVLLMOptions)
    opts.view_as(SetupOptions).save_main_session = save_main_session
    logging.info("Pipeline starting with model path: %s", gem.model_gcs_path)
    handler = GcsVLLMCompletionsModelHandler(
        model_name=gem.model_gcs_path,
        # vLLM server flag: name under which the model is served.
        vllm_server_kwargs={"served-model-name": gem.model_gcs_path})
    # Exiting the `with` block runs the pipeline and waits for completion.
    with (test_pipeline or beam.Pipeline(options=opts)) as p:
        _ = (
            p
            # One prompt per line of the input file.
            | "Read" >> beam.io.ReadFromText(gem.input_file)
            # NOTE(review): confirm this Beam version's RunInference accepts
            # an `inference_batch_size` keyword — batching is commonly
            # configured on the model handler instead.
            | "InferBatch" >> RunInference(handler, inference_batch_size=32)
            | "FormatForBQ" >> beam.ParDo(FormatOutput())
            | "WriteToBQ" >> beam.io.WriteToBigQuery(
                gem.output_table,
                schema="prompt:STRING,completion:STRING",
                write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
                create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
                method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
            ))
    # Pipeline.__exit__ stores the run's PipelineResult on `p.result`.
    return p.result
if __name__ == "__main__":
    # Make INFO-level pipeline logs visible when executed as a script.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    run()