{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XobBB6Sv8mB3"
},
"outputs": [],
"source": [
"# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
"\n",
"# Licensed to the Apache Software Foundation (ASF) under one\n",
"# or more contributor license agreements. See the NOTICE file\n",
"# distributed with this work for additional information\n",
"# regarding copyright ownership. The ASF licenses this file\n",
"# to you under the Apache License, Version 2.0 (the\n",
"# \"License\"); you may not use this file except in compliance\n",
"# with the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing,\n",
"# software distributed under the License is distributed on an\n",
"# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
"# KIND, either express or implied. See the License for the\n",
"# specific language governing permissions and limitations\n",
"# under the License"
]
},
{
"cell_type": "markdown",
"source": [
"# Apache Beam RunInference for XGBoost\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_xgboost.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_xgboost.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>\n"
],
"metadata": {
"id": "DUGbrRuv89CS"
}
},
{
"cell_type": "markdown",
"source": [
"This notebook shows how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform with [XGBoost](https://xgboost.readthedocs.io/en/stable/).\n",
"The Apache Beam RunInference transform has implementations of the `ModelHandler` class prebuilt for XGBoost. For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
"\n",
"You can choose the appropriate model handler based on your input data type:\n",
"\n",
"- NumPy model handler\n",
"- Pandas DataFrame model handler\n",
"- DataTable model handler\n",
"- SciPy model handler\n",
"\n",
"With RunInference, these model handlers manage batching, vectorization, and prediction optimization for your XGBoost pipeline or model.\n",
"\n",
"This notebook demonstrates the following common RunInference patterns:\n",
"\n",
"- Generate predictions.\n",
"- Postprocess results after running inference.\n",
"- Classify iris flowers.\n",
"- Make home price predictions with a regression model."
],
"metadata": {
"id": "6nh2h-sIOAOg"
}
},
{
"cell_type": "markdown",
"source": [
"## Before you begin\n",
"Install dependencies for Apache Beam."
],
"metadata": {
"id": "nRCJBcTUOq1k"
}
},
{
"cell_type": "code",
"source": [
"!pip install 'apache-beam[gcp]>=2.47.0'"
],
"metadata": {
"id": "gbmH329jOuj1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import xgboost\n",
"import apache_beam as beam\n",
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from apache_beam.ml.inference.base import KeyedModelHandler\n",
"from apache_beam.ml.inference.base import PredictionResult\n",
"from apache_beam.ml.inference.base import RunInference\n",
"from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy\n",
"from apache_beam.options.pipeline_options import PipelineOptions"
],
"metadata": {
"id": "_O0BN_XqOwp1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"SEED = 999\n",
"CLASSIFICATION_MODEL_STATE = '/tmp/classification_model.json'\n",
"REGRESSION_MODEL_STATE = '/tmp/regression_model.json'"
],
"metadata": {
"id": "ue_5a-oaO-Lz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load data from scikit-learn and train XGBoost models\n",
"In this example, you create two models, one to classify iris flowers and one to predict housing prices in California.\n",
"\n",
"This section demonstrates the following steps:\n",
"1. Load the iris flower and California housing datasets from scikit-learn, and then create classification and regression models.\n",
"2. Train the classification and regression models.\n",
"3. Save the models to JSON files by using `save_model`. For more information, see [Introduction to Model IO](https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html) in the XGBoost documentation."
],
"metadata": {
"id": "74oE5pGgPE0M"
}
},
{
"cell_type": "code",
"source": [
"# Train the classification model.\n",
"iris_dataset = load_iris()\n",
"x_train_classification, x_test_classification, y_train_classification, y_test_classification = train_test_split(\n",
" iris_dataset['data'], iris_dataset['target'], test_size=.2, random_state=SEED)\n",
"booster = xgboost.XGBClassifier(\n",
" n_estimators=2, max_depth=2, learning_rate=1, objective='binary:logistic')\n",
"booster.fit(x_train_classification, y_train_classification)\n",
"booster.save_model(CLASSIFICATION_MODEL_STATE)\n",
"\n",
"# Train the regression model.\n",
"california_dataset = fetch_california_housing()\n",
"x_train_regression, x_test_regression, y_train_regression, y_test_regression = train_test_split(\n",
" california_dataset['data'], california_dataset['target'], test_size=.2, random_state=SEED)\n",
"model = xgboost.XGBRegressor(\n",
" n_estimators=1000,\n",
" max_depth=8,\n",
" eta=0.1,\n",
" subsample=0.75,\n",
" colsample_bytree=0.8)\n",
"model.fit(x_train_regression, y_train_regression)\n",
"model.save_model(REGRESSION_MODEL_STATE)\n",
"\n",
"\n",
"# Reshape the test data, because XGBoost expects a batch instead of a single element.\n",
"# For more information, see https://xgboost.readthedocs.io/en/stable/prediction.html\n",
"x_test_classification = x_test_classification.reshape(5, 6, 4)\n",
"x_test_regression = x_test_regression.reshape(258, 16, 8)"
],
"metadata": {
"id": "KVSKt3pFPBnj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Implement postprocessing helper functions\n",
"The following code demonstrates how to implement postprocessing helper functions for your models."
],
"metadata": {
"id": "VGQj-B1Abioq"
}
},
{
"cell_type": "code",
"source": [
"def translate_labels(inference_results: PredictionResult):\n",
" \"\"\"\n",
" Maps the output values (0, 1, or 2) of the XGBoost iris classification\n",
" model to the names of the corresponding iris species.\n",
" Args:\n",
" inference_results: A PredictionResult containing the outputs of the XGBoost iris classification model\n",
" \"\"\"\n",
" return PredictionResult(\n",
" inference_results.example,\n",
" np.vectorize(['Setosa', 'Versicolour',\n",
" 'Virginica'].__getitem__)(inference_results.inference))\n",
"\n",
"\n",
"class FlattenBatchPredictionResults(beam.DoFn):\n",
" \"\"\"\n",
" This DoFn takes a PredictionResult containing a batch (list) of\n",
" examples and predictions as input and prints all example/prediction pairs.\n",
" \"\"\"\n",
" def process(self, batch_prediction_result: PredictionResult):\n",
" for example, inference in zip(batch_prediction_result.example, batch_prediction_result.inference):\n",
" print(PredictionResult(\n",
" example, inference, batch_prediction_result.model_id))\n"
],
"metadata": {
"id": "e1xDKfwbbg0z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Create an XGBoost RunInference pipeline\n",
"This section demonstrates how to do the following tasks:\n",
"1. Define an XGBoost model handler that accepts a `numpy.ndarray` object as input.\n",
"2. Load the data from the datasets.\n",
"3. Use the XGBoost trained models and the XGBoost RunInference transform on unkeyed data.\n",
"4. Use postprocessing to translate the XGBoost numeric outputs into flower names, and then flatten the batched outputs."
],
"metadata": {
"id": "ItuxdQoXSNTQ"
}
},
{
"cell_type": "code",
"source": [
"xgboost_classification_model_handler = XGBoostModelHandlerNumpy(\n",
" model_class=xgboost.XGBClassifier, model_state=CLASSIFICATION_MODEL_STATE)\n",
"\n",
"pipeline_options = PipelineOptions().from_dictionary({})\n",
"\n",
"with beam.Pipeline(options=pipeline_options) as p:\n",
" (\n",
" p\n",
" | \"Load Data\" >> beam.Create(x_test_classification)\n",
" | \"RunInferenceXGBoost\" >>\n",
" RunInference(model_handler=xgboost_classification_model_handler)\n",
" | \"TranslateLabels\" >> beam.Map(translate_labels)\n",
" | \"FlattenBatchPredictionResults\" >> beam.ParDo(\n",
" FlattenBatchPredictionResults()))"
],
"metadata": {
"id": "SBdMq3-CSGqZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"xgboost_regression_model_handler = XGBoostModelHandlerNumpy(\n",
" model_class=xgboost.XGBRegressor, model_state=REGRESSION_MODEL_STATE)\n",
"\n",
"pipeline_options = PipelineOptions().from_dictionary({})\n",
"\n",
"with beam.Pipeline(options=pipeline_options) as p:\n",
" (\n",
" p\n",
" | \"Load Data\" >> beam.Create(x_test_regression)\n",
" | \"RunInferenceXGBoost\" >>\n",
" RunInference(model_handler=xgboost_regression_model_handler)\n",
" | \"FlattenBatchPredictionResults\" >> beam.ParDo(\n",
" FlattenBatchPredictionResults()))"
],
"metadata": {
"id": "IYUXIJt7UIm6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Use XGBoost with RunInference on keyed inputs\n",
"To retain metadata about the example, associate examples with a key before doing inference.\n",
"For example, you might want to retain the original URL of a preprocessed image or a non-preprocessed input.\n",
"To use RunInference to retain metadata, use a `KeyedModelHandler`.\n",
"This section demonstrates how to do the following tasks with a `KeyedModelHandler`:\n",
"\n",
"\n",
"1. To handle keyed data, wrap the `XGBoostModelHandlerNumpy` with a `KeyedModelHandler`.\n",
"2. Load the data from the datasets.\n",
"3. Use the XGBoost trained models and the XGBoost RunInference transform on the keyed data.\n",
"4. Postprocess the results to flatten the batched outputs."
],
"metadata": {
"id": "ptTZUGmqW4s2"
}
},
{
"cell_type": "code",
"source": [
"x_test_classification = [(f'batch {i}', sample) for i, sample in enumerate(x_test_classification)]\n",
"x_test_regression = [(f'batch {i}', sample) for i, sample in enumerate(x_test_regression)]"
],
"metadata": {
"id": "MBSbY569W3zm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"keyed_xgboost_classification_model_handler = KeyedModelHandler(xgboost_classification_model_handler)\n",
"\n",
"pipeline_options = PipelineOptions().from_dictionary({})\n",
"\n",
"with beam.Pipeline(options=pipeline_options) as p:\n",
" (\n",
" p\n",
" | \"Load Data\" >> beam.Create(x_test_classification)\n",
" | \"RunInferenceXGBoost\" >>\n",
" RunInference(model_handler=keyed_xgboost_classification_model_handler)\n",
" | \"TranslateLabels\" >> beam.Map(translate_labels)\n",
" | \"FlattenBatchPredictionResults\" >> beam.ParDo(\n",
" FlattenBatchPredictionResults()))"
],
"metadata": {
"id": "8L7sU7a5YXrI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"keyed_xgboost_regression_model_handler = KeyedModelHandler(xgboost_regression_model_handler)\n",
"\n",
"pipeline_options = PipelineOptions().from_dictionary({})\n",
"\n",
"with beam.Pipeline(options=pipeline_options) as p:\n",
" (\n",
" p\n",
" | \"Load Data\" >> beam.Create(x_test_regression)\n",
" | \"RunInferenceXGBoost\" >>\n",
" RunInference(model_handler=keyed_xgboost_regression_model_handler)\n",
" | \"FlattenBatchPredictionResults\" >> beam.ParDo(\n",
" FlattenBatchPredictionResults()))"
],
"metadata": {
"id": "935E8Q7cYkDK"
},
"execution_count": null,
"outputs": []
}
]
}