{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "C1rAsD2L-hSO"
},
"outputs": [],
"source": [
"# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
"\n",
"# Licensed to the Apache Software Foundation (ASF) under one\n",
"# or more contributor license agreements. See the NOTICE file\n",
"# distributed with this work for additional information\n",
"# regarding copyright ownership. The ASF licenses this file\n",
"# to you under the Apache License, Version 2.0 (the\n",
"# \"License\"); you may not use this file except in compliance\n",
"# with the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing,\n",
"# software distributed under the License is distributed on an\n",
"# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
"# KIND, either express or implied. See the License for the\n",
"# specific language governing permissions and limitations\n",
"# under the License"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b6f8f3af-744e-4eaa-8a30-6d03e8e4d21e"
},
"source": [
"# Bring your own ML model to Beam RunInference\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_custom_inference.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_custom_inference.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A8xNRyZMW1yK"
},
"source": [
"This notebook demonstrates how to run inference on your custom framework using the\n",
"[ModelHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelHandler) class.\n",
"\n",
"Named-entity recognition (NER) is one of the most common tasks for natural language processing (NLP). \n",
"NLP locates named entities in unstructured text and classifies the entities using pre-defined labels, such as person name, organization, date, and so on.\n",
"\n",
"This example illustrates how to use the popular `spaCy` package to load a machine learning (ML) model and perform inference in an Apache Beam pipeline using the RunInference `PTransform`.\n",
"For more information about the RunInference API, see [About Beam ML](https://beam.apache.org/documentation/ml/about-ml) in the Apache Beam documentation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "299af9bb-b2fc-405c-96e7-ee0a6ae24bdd"
},
"source": [
"## Install package dependencies\n",
"\n",
"The RunInference library is available in Apache Beam versions 2.40 and later.\n",
"\n",
"For this example, you need to install `spaCy` and `pandas`. A small NER model, `en_core_web_sm`, is also installed, but you can use any valid `spaCy` model."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7f841596-f217-46d2-b64e-1952db4de4cb",
"outputId": "da04ccb9-0801-47f6-ec9e-e87f0ca4569f"
},
"outputs": [],
"source": [
"# Uncomment the following lines to install the required packages.\n",
"# %pip install spacy pandas\n",
"# %pip install \"apache-beam[gcp, dataframe, interactive]\"\n",
"# !python -m spacy download en_core_web_sm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4cc"
},
"source": [
"## Learn about `spaCy`\n",
"\n",
"To learn more about `spaCy`, create a `spaCy` language object in memory using `spaCy`'s trained models.\n",
"You can install these models as Python packages.\n",
"For more information, see spaCy's [Models and Languages](https://spacy.io/usage/models) documentation."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4cd"
},
"outputs": [],
"source": [
"import spacy\n",
"\n",
"nlp = spacy.load(\"en_core_web_sm\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4ce"
},
"outputs": [],
"source": [
"# Add text strings.\n",
"text_strings = [\n",
" \"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\",\n",
" \"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\"\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4cf"
},
"outputs": [],
"source": [
"# Check which entities spaCy can recognize.\n",
"doc = nlp(text_strings[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4d0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The New York Times 0 18 ORG\n",
"American 25 33 NORP\n",
"daily 34 39 DATE\n",
"New York City 59 72 GPE\n"
]
}
],
"source": [
"for ent in doc.ents:\n",
" print(ent.text, ent.start_char, ent.end_char, ent.label_)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4d1"
},
"outputs": [
{
"data": {
"text/html": [
"<span class=\"tex2jax_ignore\"><div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">\n",
"<mark class=\"entity\" style=\"background: #7aecec; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" The New York Times\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">ORG</span>\n",
"</mark>\n",
" is an \n",
"<mark class=\"entity\" style=\"background: #c887fb; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" American\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">NORP</span>\n",
"</mark>\n",
" \n",
"<mark class=\"entity\" style=\"background: #bfe1d9; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" daily\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">DATE</span>\n",
"</mark>\n",
" newspaper based in \n",
"<mark class=\"entity\" style=\"background: #feca74; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" New York City\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">GPE</span>\n",
"</mark>\n",
" with a worldwide readership.</div></span>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize the results.\n",
"from spacy import displacy\n",
"displacy.render(doc, style=\"ent\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e0"
},
"outputs": [
{
"data": {
"text/html": [
"<span class=\"tex2jax_ignore\"><div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">It was founded in \n",
"<mark class=\"entity\" style=\"background: #bfe1d9; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" 1851\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">DATE</span>\n",
"</mark>\n",
" by \n",
"<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" Henry Jarvis\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
"</mark>\n",
" \n",
"<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" Raymond\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
"</mark>\n",
" and \n",
"<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" George Jones\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
"</mark>\n",
", and was initially published by \n",
"<mark class=\"entity\" style=\"background: #7aecec; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
" Raymond, Jones &amp; Company\n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">ORG</span>\n",
"</mark>\n",
".</div></span>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize another example.\n",
"displacy.render(nlp(text_strings[1]), style=\"ent\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e1"
},
"source": [
"## Create a model handler\n",
"\n",
"This section demonstrates how to create your own `ModelHandler` so that you can use `spaCy` for inference."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e2"
},
"outputs": [
{
"data": {
"application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\n",
"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\n"
]
}
],
"source": [
"\n",
"import apache_beam as beam\n",
"from apache_beam.options.pipeline_options import PipelineOptions\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"\n",
"pipeline = beam.Pipeline()\n",
"\n",
"# Print the results for verification.\n",
"with pipeline as p:\n",
" (p \n",
" | \"CreateSentences\" >> beam.Create(text_strings)\n",
" | beam.Map(print)\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e3"
},
"outputs": [],
"source": [
"# Define `SpacyModelHandler` to load the model and perform the inference.\n",
"\n",
"from apache_beam.ml.inference.base import RunInference\n",
"from apache_beam.ml.inference.base import ModelHandler\n",
"from apache_beam.ml.inference.base import PredictionResult\n",
"from spacy import Language\n",
"from typing import Any\n",
"from typing import Dict\n",
"from typing import Iterable\n",
"from typing import Optional\n",
"from typing import Sequence\n",
"\n",
"class SpacyModelHandler(ModelHandler[str,\n",
" PredictionResult,\n",
" Language]):\n",
" def __init__(\n",
" self,\n",
" model_name: str = \"en_core_web_sm\",\n",
" ):\n",
" \"\"\" Implementation of the ModelHandler interface for spaCy using text as input.\n",
"\n",
" Example Usage::\n",
"\n",
" pcoll | RunInference(SpacyModelHandler())\n",
"\n",
" Args:\n",
" model_name: The spaCy model name. Default is en_core_web_sm.\n",
" \"\"\"\n",
" self._model_name = model_name\n",
" self._env_vars = {}\n",
"\n",
" def load_model(self) -> Language:\n",
" \"\"\"Loads and initializes a model for processing.\"\"\"\n",
" return spacy.load(self._model_name)\n",
"\n",
" def run_inference(\n",
" self,\n",
" batch: Sequence[str],\n",
" model: Language,\n",
" inference_args: Optional[Dict[str, Any]] = None\n",
" ) -> Iterable[PredictionResult]:\n",
" \"\"\"Runs inferences on a batch of text strings.\n",
"\n",
" Args:\n",
" batch: A sequence of examples as text strings. \n",
" model: A spaCy language model\n",
" inference_args: Any additional arguments for an inference.\n",
"\n",
" Returns:\n",
" An Iterable of type PredictionResult.\n",
" \"\"\"\n",
" # Loop each text string, and use a tuple to store the inference results.\n",
" predictions = []\n",
" for one_text in batch:\n",
" doc = model(one_text)\n",
" predictions.append(\n",
" [(ent.text, ent.start_char, ent.end_char, ent.label_) for ent in doc.ents])\n",
" return [PredictionResult(x, y) for x, y in zip(batch, predictions)]\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\n",
"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\n",
"PredictionResult(example='The New York Times is an American daily newspaper based in New York City with a worldwide readership.', inference=[('The New York Times', 0, 18, 'ORG'), ('American', 25, 33, 'NORP'), ('daily', 34, 39, 'DATE'), ('New York City', 59, 72, 'GPE')])\n",
"PredictionResult(example='It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.', inference=[('1851', 18, 22, 'DATE'), ('Henry Jarvis', 26, 38, 'PERSON'), ('Raymond', 39, 46, 'PERSON'), ('George Jones', 51, 63, 'PERSON'), ('Raymond, Jones & Company', 96, 120, 'ORG')])\n"
]
}
],
"source": [
"# Verify that the inference results are correct.\n",
"with pipeline as p:\n",
" (p \n",
" | \"CreateSentences\" >> beam.Create(text_strings)\n",
" | \"RunInferenceSpacy\" >> RunInference(SpacyModelHandler(\"en_core_web_sm\"))\n",
" | beam.Map(print)\n",
" )\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e5"
},
"source": [
"## Use `KeyedModelHandler` to handle keyed data\n",
"\n",
"If you have keyed data, use `KeyedModelHandler`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e6"
},
"outputs": [],
"source": [
"# You can use these text strings with keys to distinguish examples.\n",
"text_strings_with_keys = [\n",
" (\"example_0\", \"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\"),\n",
" (\"example_1\", \"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\")\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e7"
},
"outputs": [],
"source": [
"from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
"from apache_beam.ml.inference.base import KeyedModelHandler\n",
"from apache_beam.dataframe.convert import to_dataframe\n",
"\n",
"pipeline = beam.Pipeline(InteractiveRunner())\n",
"\n",
"keyed_spacy_model_handler = KeyedModelHandler(SpacyModelHandler(\"en_core_web_sm\"))\n",
"\n",
"# Verify that the inference results are correct.\n",
"with pipeline as p:\n",
" results = (p \n",
" | \"CreateSentences\" >> beam.Create(text_strings_with_keys)\n",
" | \"RunInferenceSpacy\" >> RunInference(keyed_spacy_model_handler)\n",
" # Generate a schema suitable for conversion to a dataframe using Map to Row objects.\n",
" | 'ToRows' >> beam.Map(lambda row: beam.Row(key=row[0], text=row[1][0], predictions=row[1][1]))\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e8"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <link rel=\"stylesheet\" href=\"https://stackpath.bootstrapcdn.com/bootstrap/4.4.1/css/bootstrap.min.css\" integrity=\"sha384-Vkoo8x4CGsO3+Hhxv8T/Q5PaXtkKtu6ug5TOeNV6gBiFeWPGFN9MuhOf23Q9Ifjh\" crossorigin=\"anonymous\">\n",
" <div id=\"progress_indicator_25aaf10d571c025a28901bea46b94c93\">\n",
" <div class=\"spinner-border text-info\" role=\"status\"></div>\n",
" <span class=\"text-info\">Processing... collect</span>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_25aaf10d571c025a28901bea46b94c93\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_25aaf10d571c025a28901bea46b94c93\").remove();\n });\n }"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Convert the results to a pandas dataframe.\n",
"import apache_beam.runners.interactive.interactive_beam as ib\n",
"\n",
"beam_df = to_dataframe(results)\n",
"df = ib.collect(beam_df)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "7f841596-f217-46d2-b64e-1952db4de4e9"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>key</th>\n",
" <th>text</th>\n",
" <th>predictions</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>example_0</td>\n",
" <td>The New York Times is an American daily newspa...</td>\n",
" <td>[(The New York Times, 0, 18, ORG), (American, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>example_1</td>\n",
" <td>It was founded in 1851 by Henry Jarvis Raymond...</td>\n",
" <td>[(1851, 18, 22, DATE), (Henry Jarvis, 26, 38, ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" key text \\\n",
"0 example_0 The New York Times is an American daily newspa... \n",
"0 example_1 It was founded in 1851 by Henry Jarvis Raymond... \n",
"\n",
" predictions \n",
"0 [(The New York Times, 0, 18, ORG), (American, ... \n",
"0 [(1851, 18, 22, DATE), (Henry Jarvis, 26, 38, ... "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Beam RunInference",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3.9.13 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "aab5fceeb08468f7e142944162550e82df74df803ff2eb1987d9526d4285522f"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}