{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
"\n",
"# Licensed to the Apache Software Foundation (ASF) under one\n",
"# or more contributor license agreements. See the NOTICE file\n",
"# distributed with this work for additional information\n",
"# regarding copyright ownership. The ASF licenses this file\n",
"# to you under the Apache License, Version 2.0 (the\n",
"# \"License\"); you may not use this file except in compliance\n",
"# with the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing,\n",
"# software distributed under the License is distributed on an\n",
"# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
"# KIND, either express or implied. See the License for the\n",
"# specific language governing permissions and limitations\n",
"# under the License"
],
"metadata": {
"cellView": "form",
"id": "IVkpU8HZ1eyz"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use RunInference for Generative AI\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_generative_ai.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_generative_ai.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>\n"
],
"metadata": {
"id": "kH8SORNim8on"
}
},
{
"cell_type": "markdown",
"source": [
"This notebook shows how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform for generative AI tasks. It uses a large language model (LLM) from the [Hugging Face Model Hub](https://huggingface.co/models).\n",
"\n",
"This notebook demonstrates the following steps:\n",
"- Load and save a model from the Hugging Face Model Hub.\n",
"- Use the PyTorch model handler for RunInference.\n",
"\n",
"For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation."
],
"metadata": {
"id": "7N2XzwoA0k4L"
}
},
{
"cell_type": "markdown",
"source": [
"## Install the Apache Beam SDK and dependencies\n",
"\n",
"Use the following code to install the Apache Beam Python SDK, PyTorch, and Transformers."
],
"metadata": {
"id": "nhf_lOeEsO1C"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wS9a3Y0oZ_l5"
},
"outputs": [],
"source": [
"!pip install apache_beam[interactive,gcp]==2.48.0\n",
"!pip install torch\n",
"!pip install transformers"
]
},
{
"cell_type": "markdown",
"source": [
"Use the following code to import dependencies\n",
"\n",
"**Important**: If an error occurs, restart your runtime."
],
"metadata": {
"id": "I7vMsFGW16bZ"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import apache_beam as beam\n",
"from apache_beam.options.pipeline_options import PipelineOptions\n",
"from apache_beam.ml.inference.base import PredictionResult\n",
"from apache_beam.ml.inference.base import RunInference\n",
"from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn\n",
"from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
"import torch\n",
"from transformers import AutoConfig\n",
"from transformers import AutoModelForSeq2SeqLM\n",
"from transformers import AutoTokenizer\n",
"from transformers.tokenization_utils import PreTrainedTokenizer\n",
"\n",
"\n",
"MAX_RESPONSE_TOKENS = 256\n",
"\n",
"model_name = \"google/flan-t5-small\"\n",
"state_dict_path = \"saved_model\""
],
"metadata": {
"id": "uhbOYUzvbOSc"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Download and save the model\n",
"This notebook uses the [auto classes](https://huggingface.co/docs/transformers/model_doc/auto) from Hugging Face to instantly load the model in memory. Later, the model is saved to the path defined previously."
],
"metadata": {
"id": "yRls3LmxswrC"
}
},
{
"cell_type": "code",
"source": [
"model = AutoModelForSeq2SeqLM.from_pretrained(\n",
" model_name, torch_dtype=torch.bfloat16\n",
" )\n",
"\n",
"directory = os.path.dirname(state_dict_path)\n",
"torch.save(model.state_dict(), state_dict_path)"
],
"metadata": {
"id": "PKhkiQFJe44n"
},
"execution_count": null,
"outputs": []
},
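{
"cell_type": "markdown",
"source": [
"Optionally, before you use the model in a pipeline, run it directly to confirm that it generates text. This sanity check is not part of the pipeline, and the prompt and the `check_tokenizer` name are illustrative examples only."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sanity check: generate a response from the in-memory model.\n",
"# If your runtime lacks bfloat16 CPU kernels, call model.float() first.\n",
"check_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"inputs = check_tokenizer(\n",
"    \"translate English to Spanish: Good morning.\", return_tensors=\"pt\"\n",
")\n",
"outputs = model.generate(inputs.input_ids, max_new_tokens=MAX_RESPONSE_TOKENS)\n",
"print(check_tokenizer.decode(outputs[0], skip_special_tokens=True))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},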
{
"cell_type": "markdown",
"source": [
"## Define utility functions\n",
"The input and output for the [`google/flan-t5-small`](https://huggingface.co/google/flan-t5-small) model are token tensors. These utility functions are used for the conversion of text to token tensors and then back to text.\n"
],
"metadata": {
"id": "7TSqb3l1s7F7"
}
},
{
"cell_type": "code",
"source": [
"def to_tensors(input_text: str, tokenizer) -> torch.Tensor:\n",
" \"\"\"Encodes input text into token tensors.\n",
" Args:\n",
" input_text: Input text for the LLM model.\n",
" tokenizer: Tokenizer for the LLM model.\n",
" Returns: Tokenized input tokens.\n",
" \"\"\"\n",
" return tokenizer(input_text, return_tensors=\"pt\").input_ids[0]\n",
"\n",
"\n",
"def from_tensors(result: PredictionResult, tokenizer) -> str:\n",
" \"\"\"Decodes output token tensors into text.\n",
" Args:\n",
" result: Prediction results from the RunInference transform.\n",
" tokenizer: Tokenizer for the LLM model.\n",
" Returns: The model's response as text.\n",
" \"\"\"\n",
" output_tokens = result.inference\n",
" return tokenizer.decode(output_tokens, skip_special_tokens=True)"
],
"metadata": {
"id": "OeTMbaLidnBe"
},
"execution_count": 5,
"outputs": []
},
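{
"cell_type": "markdown",
"source": [
"As a quick, optional check, the following cell round-trips a prompt through `to_tensors` and `from_tensors`. It wraps the encoded tokens in a `PredictionResult` by hand, which is something the RunInference transform normally does for you; the prompt is an illustrative example only."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sanity check: encode a prompt, then decode it again.\n",
"check_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"tokens = to_tensors(\"translate English to Spanish: Hello.\", check_tokenizer)\n",
"# Reuse the input tokens as a stand-in for model output tokens.\n",
"result = PredictionResult(example=tokens, inference=tokens)\n",
"print(from_tensors(result, check_tokenizer))  # Prints the original prompt."
],
"metadata": {},
"execution_count": null,
"outputs": []
},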
{
"cell_type": "code",
"source": [
"# Load the tokenizer.\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"# Create an instance of the PyTorch model handler.\n",
"model_handler = PytorchModelHandlerTensor(\n",
" state_dict_path=state_dict_path,\n",
" model_class=AutoModelForSeq2SeqLM.from_config,\n",
" model_params={\"config\": AutoConfig.from_pretrained(model_name)},\n",
" inference_fn=make_tensor_model_fn(\"generate\"),\n",
" )"
],
"metadata": {
"id": "77mOPlQR5Mev"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Run the Pipeline\n"
],
"metadata": {
"id": "y7IU3cYKtA3e"
}
},
{
"cell_type": "code",
"source": [
"example = [\"translate English to Spanish: We are in New York City.\"]\n",
"\n",
"pipeline = beam.Pipeline(options=PipelineOptions(save_main_session=True,pickle_library=\"cloudpickle\"))\n",
"\n",
"with pipeline as p:\n",
" _ = (\n",
" p\n",
" | \"Create Examples\" >> beam.Create(example)\n",
" | \"To tensors\" >> beam.Map(to_tensors, tokenizer)\n",
" | \"RunInference\"\n",
" >> RunInference(\n",
" model_handler,\n",
" inference_args={\"max_new_tokens\": MAX_RESPONSE_TOKENS},\n",
" )\n",
" | \"From tensors\" >> beam.Map(from_tensors, tokenizer)\n",
" | \"Print\" >> beam.Map(print)\n",
" )"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 106
},
"id": "JUBrJEDYcjRG",
"outputId": "e7e7f0a2-dc0d-427d-bdf8-047739bb3160"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Estamos en Nueva York City.\n"
]
}
]
}
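,
{
"cell_type": "markdown",
"source": [
"As a variation, you can run inference on several prompts in one pipeline; RunInference batches the incoming elements before it invokes the model. The prompts in this cell are illustrative examples only."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"examples = [\n",
"    \"translate English to Spanish: We are in New York City.\",\n",
"    \"translate English to French: I love programming.\",\n",
"]\n",
"\n",
"options = PipelineOptions(save_main_session=True, pickle_library=\"cloudpickle\")\n",
"\n",
"with beam.Pipeline(options=options) as p:\n",
"    _ = (\n",
"        p\n",
"        | \"Create Examples\" >> beam.Create(examples)\n",
"        | \"To tensors\" >> beam.Map(to_tensors, tokenizer)\n",
"        | \"RunInference\"\n",
"        >> RunInference(\n",
"            model_handler,\n",
"            inference_args={\"max_new_tokens\": MAX_RESPONSE_TOKENS},\n",
"        )\n",
"        | \"From tensors\" >> beam.Map(from_tensors, tokenizer)\n",
"        | \"Print\" >> beam.Map(print)\n",
"    )"
],
"metadata": {},
"execution_count": null,
"outputs": []
}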
]
}