blob: e6fe4b5751a401395b663882235371b8543e532e [file] [log] [blame]
<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7" lang="en"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8" lang="en"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9" lang="en"> <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en"> <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>ML Pipelines - Spark 3.4.3 Documentation</title>
<link rel="stylesheet" href="css/bootstrap.min.css">
<style>
body {
padding-top: 60px;
padding-bottom: 40px;
}
</style>
<meta name="viewport" content="width=device-width">
<link rel="stylesheet" href="css/main.css">
<script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script>
<link rel="stylesheet" href="css/pygments-default.css">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link rel="stylesheet" href="css/docsearch.css">
<!-- Matomo -->
<script>
// Matomo analytics bootstrap (vendor snippet). Tracker commands are queued
// as arrays on the global `_paq` array; matomo.js is then loaded
// asynchronously and presumably replays this queue once it runs (Matomo's
// async-tracking design — not visible here, see Matomo docs).
// Reuse an existing queue if one was already created on the page.
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
// Queue the "disableCookies" configuration command ahead of trackPageView,
// per the ordering note above.
_paq.push(["disableCookies"]);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
// Point the tracker at the Apache analytics server (matomo.php endpoint),
// select site ID '40', and inject matomo.js as an async <script> inserted
// before the page's first <script> element so it does not block parsing.
(function() {
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '40']);
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
</head>
<body>
<!--[if lt IE 7]>
<p class="chromeframe">You are using an outdated browser. <a href="https://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p>
<![endif]-->
<!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html -->
<nav class="navbar fixed-top navbar-expand-md navbar-light bg-light" id="topbar">
<div class="container">
<div class="navbar-header">
<div class="navbar-brand"><a href="index.html">
<img src="img/spark-logo-hd.png" style="height:50px;" alt="Apache Spark"/></a><span class="version">3.4.3</span>
</div>
</div>
<button class="navbar-toggler" type="button" data-toggle="collapse"
data-target="#navbarCollapse" aria-controls="navbarCollapse"
aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div class="collapse navbar-collapse" id="navbarCollapse">
<ul class="navbar-nav">
<!--TODO(andyk): Add class="active" attribute to li somehow.-->
<li class="nav-item"><a href="index.html" class="nav-link">Overview</a></li>
<li class="nav-item dropdown">
<a href="#" class="nav-link dropdown-toggle" id="navbarQuickStart" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">Programming Guides</a>
<div class="dropdown-menu" aria-labelledby="navbarQuickStart">
<a class="dropdown-item" href="quick-start.html">Quick Start</a>
<a class="dropdown-item" href="rdd-programming-guide.html">RDDs, Accumulators, Broadcasts Vars</a>
<a class="dropdown-item" href="sql-programming-guide.html">SQL, DataFrames, and Datasets</a>
<a class="dropdown-item" href="structured-streaming-programming-guide.html">Structured Streaming</a>
<a class="dropdown-item" href="streaming-programming-guide.html">Spark Streaming (DStreams)</a>
<a class="dropdown-item" href="ml-guide.html">MLlib (Machine Learning)</a>
<a class="dropdown-item" href="graphx-programming-guide.html">GraphX (Graph Processing)</a>
<a class="dropdown-item" href="sparkr.html">SparkR (R on Spark)</a>
<a class="dropdown-item" href="api/python/getting_started/index.html">PySpark (Python on Spark)</a>
</div>
</li>
<li class="nav-item dropdown">
<a href="#" class="nav-link dropdown-toggle" id="navbarAPIDocs" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">API Docs</a>
<div class="dropdown-menu" aria-labelledby="navbarAPIDocs">
<a class="dropdown-item" href="api/scala/org/apache/spark/index.html">Scala</a>
<a class="dropdown-item" href="api/java/index.html">Java</a>
<a class="dropdown-item" href="api/python/index.html">Python</a>
<a class="dropdown-item" href="api/R/index.html">R</a>
<a class="dropdown-item" href="api/sql/index.html">SQL, Built-in Functions</a>
</div>
</li>
<li class="nav-item dropdown">
<a href="#" class="nav-link dropdown-toggle" id="navbarDeploying" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">Deploying</a>
<div class="dropdown-menu" aria-labelledby="navbarDeploying">
<a class="dropdown-item" href="cluster-overview.html">Overview</a>
<a class="dropdown-item" href="submitting-applications.html">Submitting Applications</a>
<div class="dropdown-divider"></div>
<a class="dropdown-item" href="spark-standalone.html">Spark Standalone</a>
<a class="dropdown-item" href="running-on-mesos.html">Mesos</a>
<a class="dropdown-item" href="running-on-yarn.html">YARN</a>
<a class="dropdown-item" href="running-on-kubernetes.html">Kubernetes</a>
</div>
</li>
<li class="nav-item dropdown">
<a href="#" class="nav-link dropdown-toggle" id="navbarMore" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">More</a>
<div class="dropdown-menu" aria-labelledby="navbarMore">
<a class="dropdown-item" href="configuration.html">Configuration</a>
<a class="dropdown-item" href="monitoring.html">Monitoring</a>
<a class="dropdown-item" href="tuning.html">Tuning Guide</a>
<a class="dropdown-item" href="job-scheduling.html">Job Scheduling</a>
<a class="dropdown-item" href="security.html">Security</a>
<a class="dropdown-item" href="hardware-provisioning.html">Hardware Provisioning</a>
<a class="dropdown-item" href="migration-guide.html">Migration Guide</a>
<div class="dropdown-divider"></div>
<a class="dropdown-item" href="building-spark.html">Building Spark</a>
<a class="dropdown-item" href="https://spark.apache.org/contributing.html">Contributing to Spark</a>
<a class="dropdown-item" href="https://spark.apache.org/third-party-projects.html">Third Party Projects</a>
</div>
</li>
<li class="nav-item">
<input type="text" id="docsearch-input" placeholder="Search the docs…" aria-label="Search the docs">
</li>
</ul>
<!--<span class="navbar-text navbar-right"><span class="version-text">v3.4.3</span></span>-->
</div>
</div>
</nav>
<div class="container-wrapper">
<div class="left-menu-wrapper">
<div class="left-menu">
<h3><a href="ml-guide.html">MLlib: Main Guide</a></h3>
<ul>
<li>
<a href="ml-statistics.html">
Basic statistics
</a>
</li>
<li>
<a href="ml-datasource.html">
Data sources
</a>
</li>
<li>
<a href="ml-pipeline.html">
Pipelines
</a>
</li>
<li>
<a href="ml-features.html">
Extracting, transforming and selecting features
</a>
</li>
<li>
<a href="ml-classification-regression.html">
Classification and Regression
</a>
</li>
<li>
<a href="ml-clustering.html">
Clustering
</a>
</li>
<li>
<a href="ml-collaborative-filtering.html">
Collaborative filtering
</a>
</li>
<li>
<a href="ml-frequent-pattern-mining.html">
Frequent Pattern Mining
</a>
</li>
<li>
<a href="ml-tuning.html">
Model selection and tuning
</a>
</li>
<li>
<a href="ml-advanced.html">
Advanced topics
</a>
</li>
</ul>
<h3><a href="mllib-guide.html">MLlib: RDD-based API Guide</a></h3>
<ul>
<li>
<a href="mllib-data-types.html">
Data types
</a>
</li>
<li>
<a href="mllib-statistics.html">
Basic statistics
</a>
</li>
<li>
<a href="mllib-classification-regression.html">
Classification and regression
</a>
</li>
<li>
<a href="mllib-collaborative-filtering.html">
Collaborative filtering
</a>
</li>
<li>
<a href="mllib-clustering.html">
Clustering
</a>
</li>
<li>
<a href="mllib-dimensionality-reduction.html">
Dimensionality reduction
</a>
</li>
<li>
<a href="mllib-feature-extraction.html">
Feature extraction and transformation
</a>
</li>
<li>
<a href="mllib-frequent-pattern-mining.html">
Frequent pattern mining
</a>
</li>
<li>
<a href="mllib-evaluation-metrics.html">
Evaluation metrics
</a>
</li>
<li>
<a href="mllib-pmml-model-export.html">
PMML model export
</a>
</li>
<li>
<a href="mllib-optimization.html">
Optimization (developer)
</a>
</li>
</ul>
</div>
</div>
<input id="nav-trigger" class="nav-trigger" checked type="checkbox" aria-label="Toggle sidebar navigation">
<label for="nav-trigger"></label>
<div class="content-with-sidebar mr-3" id="content">
<h1 class="title">ML Pipelines</h1>
<p><code class="language-plaintext highlighter-rouge">\[
\newcommand{\R}{\mathbb{R}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\av}{\mathbf{\alpha}}
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\unit}{\mathbf{e}}
\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]</code></p>
<p>In this section, we introduce the concept of <strong><em>ML Pipelines</em></strong>.
ML Pipelines provide a uniform set of high-level APIs built on top of
<a href="sql-programming-guide.html">DataFrames</a> that help users create and tune practical
machine learning pipelines.</p>
<p><strong>Table of Contents</strong></p>
<ul id="markdown-toc">
<li><a href="#main-concepts-in-pipelines" id="markdown-toc-main-concepts-in-pipelines">Main concepts in Pipelines</a> <ul>
<li><a href="#dataframe" id="markdown-toc-dataframe">DataFrame</a></li>
<li><a href="#pipeline-components" id="markdown-toc-pipeline-components">Pipeline components</a> <ul>
<li><a href="#transformers" id="markdown-toc-transformers">Transformers</a></li>
<li><a href="#estimators" id="markdown-toc-estimators">Estimators</a></li>
<li><a href="#properties-of-pipeline-components" id="markdown-toc-properties-of-pipeline-components">Properties of pipeline components</a></li>
</ul>
</li>
<li><a href="#pipeline" id="markdown-toc-pipeline">Pipeline</a> <ul>
<li><a href="#how-it-works" id="markdown-toc-how-it-works">How it works</a></li>
<li><a href="#details" id="markdown-toc-details">Details</a></li>
</ul>
</li>
<li><a href="#parameters" id="markdown-toc-parameters">Parameters</a></li>
<li><a href="#ml-persistence-saving-and-loading-pipelines" id="markdown-toc-ml-persistence-saving-and-loading-pipelines">ML persistence: Saving and Loading Pipelines</a> <ul>
<li><a href="#backwards-compatibility-for-ml-persistence" id="markdown-toc-backwards-compatibility-for-ml-persistence">Backwards compatibility for ML persistence</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#code-examples" id="markdown-toc-code-examples">Code examples</a> <ul>
<li><a href="#example-estimator-transformer-and-param" id="markdown-toc-example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</a></li>
<li><a href="#example-pipeline" id="markdown-toc-example-pipeline">Example: Pipeline</a></li>
<li><a href="#model-selection-hyperparameter-tuning" id="markdown-toc-model-selection-hyperparameter-tuning">Model selection (hyperparameter tuning)</a></li>
</ul>
</li>
</ul>
<h1 id="main-concepts-in-pipelines">Main concepts in Pipelines</h1>
<p>MLlib standardizes APIs for machine learning algorithms to make it easier to combine multiple
algorithms into a single pipeline, or workflow.
This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is
mostly inspired by the <a href="https://scikit-learn.org/">scikit-learn</a> project.</p>
<ul>
<li>
<p><strong><a href="ml-pipeline.html#dataframe"><code class="language-plaintext highlighter-rouge">DataFrame</code></a></strong>: This ML API uses <code class="language-plaintext highlighter-rouge">DataFrame</code> from Spark SQL as an ML
dataset, which can hold a variety of data types.
E.g., a <code class="language-plaintext highlighter-rouge">DataFrame</code> could have different columns storing text, feature vectors, true labels, and predictions.</p>
</li>
<li>
<p><strong><a href="ml-pipeline.html#transformers"><code class="language-plaintext highlighter-rouge">Transformer</code></a></strong>: A <code class="language-plaintext highlighter-rouge">Transformer</code> is an algorithm which can transform one <code class="language-plaintext highlighter-rouge">DataFrame</code> into another <code class="language-plaintext highlighter-rouge">DataFrame</code>.
E.g., an ML model is a <code class="language-plaintext highlighter-rouge">Transformer</code> which transforms a <code class="language-plaintext highlighter-rouge">DataFrame</code> with features into a <code class="language-plaintext highlighter-rouge">DataFrame</code> with predictions.</p>
</li>
<li>
<p><strong><a href="ml-pipeline.html#estimators"><code class="language-plaintext highlighter-rouge">Estimator</code></a></strong>: An <code class="language-plaintext highlighter-rouge">Estimator</code> is an algorithm which can be fit on a <code class="language-plaintext highlighter-rouge">DataFrame</code> to produce a <code class="language-plaintext highlighter-rouge">Transformer</code>.
E.g., a learning algorithm is an <code class="language-plaintext highlighter-rouge">Estimator</code> which trains on a <code class="language-plaintext highlighter-rouge">DataFrame</code> and produces a model.</p>
</li>
<li>
<p><strong><a href="ml-pipeline.html#pipeline"><code class="language-plaintext highlighter-rouge">Pipeline</code></a></strong>: A <code class="language-plaintext highlighter-rouge">Pipeline</code> chains multiple <code class="language-plaintext highlighter-rouge">Transformer</code>s and <code class="language-plaintext highlighter-rouge">Estimator</code>s together to specify an ML workflow.</p>
</li>
<li>
<p><strong><a href="ml-pipeline.html#parameters"><code class="language-plaintext highlighter-rouge">Parameter</code></a></strong>: All <code class="language-plaintext highlighter-rouge">Transformer</code>s and <code class="language-plaintext highlighter-rouge">Estimator</code>s now share a common API for specifying parameters.</p>
</li>
</ul>
<h2 id="dataframe">DataFrame</h2>
<p>Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data.
This API adopts the <code class="language-plaintext highlighter-rouge">DataFrame</code> from Spark SQL in order to support a variety of data types.</p>
<p><code class="language-plaintext highlighter-rouge">DataFrame</code> supports many basic and structured types; see the <a href="sql-ref-datatypes.html">Spark SQL datatype reference</a> for a list of supported types.
In addition to the types listed in the Spark SQL guide, <code class="language-plaintext highlighter-rouge">DataFrame</code> can use ML <a href="mllib-data-types.html#local-vector"><code class="language-plaintext highlighter-rouge">Vector</code></a> types.</p>
<p>A <code class="language-plaintext highlighter-rouge">DataFrame</code> can be created either implicitly or explicitly from a regular <code class="language-plaintext highlighter-rouge">RDD</code>. See the code examples below and the <a href="sql-programming-guide.html">Spark SQL programming guide</a> for examples.</p>
<p>Columns in a <code class="language-plaintext highlighter-rouge">DataFrame</code> are named. The code examples below use names such as &#8220;text&#8221;, &#8220;features&#8221;, and &#8220;label&#8221;.</p>
<h2 id="pipeline-components">Pipeline components</h2>
<h3 id="transformers">Transformers</h3>
<p>A <code class="language-plaintext highlighter-rouge">Transformer</code> is an abstraction that includes feature transformers and learned models.
Technically, a <code class="language-plaintext highlighter-rouge">Transformer</code> implements a method <code class="language-plaintext highlighter-rouge">transform()</code>, which converts one <code class="language-plaintext highlighter-rouge">DataFrame</code> into
another, generally by appending one or more columns.
For example:</p>
<ul>
<li>A feature transformer might take a <code class="language-plaintext highlighter-rouge">DataFrame</code>, read a column (e.g., text), map it into a new
column (e.g., feature vectors), and output a new <code class="language-plaintext highlighter-rouge">DataFrame</code> with the mapped column appended.</li>
<li>A learning model might take a <code class="language-plaintext highlighter-rouge">DataFrame</code>, read the column containing feature vectors, predict the
label for each feature vector, and output a new <code class="language-plaintext highlighter-rouge">DataFrame</code> with predicted labels appended as a
column.</li>
</ul>
<h3 id="estimators">Estimators</h3>
<p>An <code class="language-plaintext highlighter-rouge">Estimator</code> abstracts the concept of a learning algorithm or any algorithm that fits or trains on
data.
Technically, an <code class="language-plaintext highlighter-rouge">Estimator</code> implements a method <code class="language-plaintext highlighter-rouge">fit()</code>, which accepts a <code class="language-plaintext highlighter-rouge">DataFrame</code> and produces a
<code class="language-plaintext highlighter-rouge">Model</code>, which is a <code class="language-plaintext highlighter-rouge">Transformer</code>.
For example, a learning algorithm such as <code class="language-plaintext highlighter-rouge">LogisticRegression</code> is an <code class="language-plaintext highlighter-rouge">Estimator</code>, and calling
<code class="language-plaintext highlighter-rouge">fit()</code> trains a <code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code>, which is a <code class="language-plaintext highlighter-rouge">Model</code> and hence a <code class="language-plaintext highlighter-rouge">Transformer</code>.</p>
<h3 id="properties-of-pipeline-components">Properties of pipeline components</h3>
<p><code class="language-plaintext highlighter-rouge">Transformer.transform()</code> and <code class="language-plaintext highlighter-rouge">Estimator.fit()</code> are both stateless. In the future, stateful algorithms may be supported via alternative concepts.</p>
<p>Each instance of a <code class="language-plaintext highlighter-rouge">Transformer</code> or <code class="language-plaintext highlighter-rouge">Estimator</code> has a unique ID, which is useful in specifying parameters (discussed below).</p>
<h2 id="pipeline">Pipeline</h2>
<p>In machine learning, it is common to run a sequence of algorithms to process and learn from data.
E.g., a simple text document processing workflow might include several stages:</p>
<ul>
<li>Split each document&#8217;s text into words.</li>
<li>Convert each document&#8217;s words into a numerical feature vector.</li>
<li>Learn a prediction model using the feature vectors and labels.</li>
</ul>
<p>MLlib represents such a workflow as a <code class="language-plaintext highlighter-rouge">Pipeline</code>, which consists of a sequence of
<code class="language-plaintext highlighter-rouge">PipelineStage</code>s (<code class="language-plaintext highlighter-rouge">Transformer</code>s and <code class="language-plaintext highlighter-rouge">Estimator</code>s) to be run in a specific order.
We will use this simple workflow as a running example in this section.</p>
<h3 id="how-it-works">How it works</h3>
<p>A <code class="language-plaintext highlighter-rouge">Pipeline</code> is specified as a sequence of stages, and each stage is either a <code class="language-plaintext highlighter-rouge">Transformer</code> or an <code class="language-plaintext highlighter-rouge">Estimator</code>.
These stages are run in order, and the input <code class="language-plaintext highlighter-rouge">DataFrame</code> is transformed as it passes through each stage.
For <code class="language-plaintext highlighter-rouge">Transformer</code> stages, the <code class="language-plaintext highlighter-rouge">transform()</code> method is called on the <code class="language-plaintext highlighter-rouge">DataFrame</code>.
For <code class="language-plaintext highlighter-rouge">Estimator</code> stages, the <code class="language-plaintext highlighter-rouge">fit()</code> method is called to produce a <code class="language-plaintext highlighter-rouge">Transformer</code> (which becomes part of the <code class="language-plaintext highlighter-rouge">PipelineModel</code>, or fitted <code class="language-plaintext highlighter-rouge">Pipeline</code>), and that <code class="language-plaintext highlighter-rouge">Transformer</code>&#8217;s <code class="language-plaintext highlighter-rouge">transform()</code> method is called on the <code class="language-plaintext highlighter-rouge">DataFrame</code>.</p>
<p>We illustrate this for the simple text document workflow. The figure below is for the <em>training time</em> usage of a <code class="language-plaintext highlighter-rouge">Pipeline</code>.</p>
<p style="text-align: center;">
<img src="img/ml-Pipeline.png" title="ML Pipeline Example" alt="ML Pipeline Example" width="80%" />
</p>
<p>Above, the top row represents a <code class="language-plaintext highlighter-rouge">Pipeline</code> with three stages.
The first two (<code class="language-plaintext highlighter-rouge">Tokenizer</code> and <code class="language-plaintext highlighter-rouge">HashingTF</code>) are <code class="language-plaintext highlighter-rouge">Transformer</code>s (blue), and the third (<code class="language-plaintext highlighter-rouge">LogisticRegression</code>) is an <code class="language-plaintext highlighter-rouge">Estimator</code> (red).
The bottom row represents data flowing through the pipeline, where cylinders indicate <code class="language-plaintext highlighter-rouge">DataFrame</code>s.
The <code class="language-plaintext highlighter-rouge">Pipeline.fit()</code> method is called on the original <code class="language-plaintext highlighter-rouge">DataFrame</code>, which has raw text documents and labels.
The <code class="language-plaintext highlighter-rouge">Tokenizer.transform()</code> method splits the raw text documents into words, adding a new column with words to the <code class="language-plaintext highlighter-rouge">DataFrame</code>.
The <code class="language-plaintext highlighter-rouge">HashingTF.transform()</code> method converts the words column into feature vectors, adding a new column with those vectors to the <code class="language-plaintext highlighter-rouge">DataFrame</code>.
Now, since <code class="language-plaintext highlighter-rouge">LogisticRegression</code> is an <code class="language-plaintext highlighter-rouge">Estimator</code>, the <code class="language-plaintext highlighter-rouge">Pipeline</code> first calls <code class="language-plaintext highlighter-rouge">LogisticRegression.fit()</code> to produce a <code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code>.
If the <code class="language-plaintext highlighter-rouge">Pipeline</code> had more <code class="language-plaintext highlighter-rouge">Estimator</code>s, it would call the <code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code>&#8217;s <code class="language-plaintext highlighter-rouge">transform()</code>
method on the <code class="language-plaintext highlighter-rouge">DataFrame</code> before passing the <code class="language-plaintext highlighter-rouge">DataFrame</code> to the next stage.</p>
<p>A <code class="language-plaintext highlighter-rouge">Pipeline</code> is an <code class="language-plaintext highlighter-rouge">Estimator</code>.
Thus, after a <code class="language-plaintext highlighter-rouge">Pipeline</code>&#8217;s <code class="language-plaintext highlighter-rouge">fit()</code> method runs, it produces a <code class="language-plaintext highlighter-rouge">PipelineModel</code>, which is a
<code class="language-plaintext highlighter-rouge">Transformer</code>.
This <code class="language-plaintext highlighter-rouge">PipelineModel</code> is used at <em>test time</em>; the figure below illustrates this usage.</p>
<p style="text-align: center;">
<img src="img/ml-PipelineModel.png" title="ML PipelineModel Example" alt="ML PipelineModel Example" width="80%" />
</p>
<p>In the figure above, the <code class="language-plaintext highlighter-rouge">PipelineModel</code> has the same number of stages as the original <code class="language-plaintext highlighter-rouge">Pipeline</code>, but all <code class="language-plaintext highlighter-rouge">Estimator</code>s in the original <code class="language-plaintext highlighter-rouge">Pipeline</code> have become <code class="language-plaintext highlighter-rouge">Transformer</code>s.
When the <code class="language-plaintext highlighter-rouge">PipelineModel</code>&#8217;s <code class="language-plaintext highlighter-rouge">transform()</code> method is called on a test dataset, the data are passed
through the fitted pipeline in order.
Each stage&#8217;s <code class="language-plaintext highlighter-rouge">transform()</code> method updates the dataset and passes it to the next stage.</p>
<p><code class="language-plaintext highlighter-rouge">Pipeline</code>s and <code class="language-plaintext highlighter-rouge">PipelineModel</code>s help to ensure that training and test data go through identical feature processing steps.</p>
<h3 id="details">Details</h3>
<p><em>DAG <code class="language-plaintext highlighter-rouge">Pipeline</code>s</em>: A <code class="language-plaintext highlighter-rouge">Pipeline</code>&#8217;s stages are specified as an ordered array. The examples given here are all for linear <code class="language-plaintext highlighter-rouge">Pipeline</code>s, i.e., <code class="language-plaintext highlighter-rouge">Pipeline</code>s in which each stage uses data produced by the previous stage. It is possible to create non-linear <code class="language-plaintext highlighter-rouge">Pipeline</code>s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the <code class="language-plaintext highlighter-rouge">Pipeline</code> forms a DAG, then the stages must be specified in topological order.</p>
<p><em>Runtime checking</em>: Since <code class="language-plaintext highlighter-rouge">Pipeline</code>s can operate on <code class="language-plaintext highlighter-rouge">DataFrame</code>s with varied types, they cannot use
compile-time type checking.
<code class="language-plaintext highlighter-rouge">Pipeline</code>s and <code class="language-plaintext highlighter-rouge">PipelineModel</code>s instead do runtime checking before actually running the <code class="language-plaintext highlighter-rouge">Pipeline</code>.
This type checking is done using the <code class="language-plaintext highlighter-rouge">DataFrame</code> <em>schema</em>, a description of the data types of columns in the <code class="language-plaintext highlighter-rouge">DataFrame</code>.</p>
<p><em>Unique Pipeline stages</em>: A <code class="language-plaintext highlighter-rouge">Pipeline</code>&#8217;s stages should be unique instances. E.g., the same instance
<code class="language-plaintext highlighter-rouge">myHashingTF</code> should not be inserted into the <code class="language-plaintext highlighter-rouge">Pipeline</code> twice since <code class="language-plaintext highlighter-rouge">Pipeline</code> stages must have
unique IDs. However, different instances <code class="language-plaintext highlighter-rouge">myHashingTF1</code> and <code class="language-plaintext highlighter-rouge">myHashingTF2</code> (both of type <code class="language-plaintext highlighter-rouge">HashingTF</code>)
can be put into the same <code class="language-plaintext highlighter-rouge">Pipeline</code> since different instances will be created with different IDs.</p>
<h2 id="parameters">Parameters</h2>
<p>MLlib <code class="language-plaintext highlighter-rouge">Estimator</code>s and <code class="language-plaintext highlighter-rouge">Transformer</code>s use a uniform API for specifying parameters.</p>
<p>A <code class="language-plaintext highlighter-rouge">Param</code> is a named parameter with self-contained documentation.
A <code class="language-plaintext highlighter-rouge">ParamMap</code> is a set of (parameter, value) pairs.</p>
<p>There are two main ways to pass parameters to an algorithm:</p>
<ol>
<li>Set parameters for an instance. E.g., if <code class="language-plaintext highlighter-rouge">lr</code> is an instance of <code class="language-plaintext highlighter-rouge">LogisticRegression</code>, one could
call <code class="language-plaintext highlighter-rouge">lr.setMaxIter(10)</code> to make <code class="language-plaintext highlighter-rouge">lr.fit()</code> use at most 10 iterations.
This API resembles the API used in the <code class="language-plaintext highlighter-rouge">spark.mllib</code> package.</li>
<li>Pass a <code class="language-plaintext highlighter-rouge">ParamMap</code> to <code class="language-plaintext highlighter-rouge">fit()</code> or <code class="language-plaintext highlighter-rouge">transform()</code>. Any parameters in the <code class="language-plaintext highlighter-rouge">ParamMap</code> will override parameters previously specified via setter methods.</li>
</ol>
<p>Parameters belong to specific instances of <code class="language-plaintext highlighter-rouge">Estimator</code>s and <code class="language-plaintext highlighter-rouge">Transformer</code>s.
For example, if we have two <code class="language-plaintext highlighter-rouge">LogisticRegression</code> instances <code class="language-plaintext highlighter-rouge">lr1</code> and <code class="language-plaintext highlighter-rouge">lr2</code>, then we can build a <code class="language-plaintext highlighter-rouge">ParamMap</code> with both <code class="language-plaintext highlighter-rouge">maxIter</code> parameters specified: <code class="language-plaintext highlighter-rouge">ParamMap(lr1.maxIter -&gt; 10, lr2.maxIter -&gt; 20)</code>.
This is useful if there are two algorithms with the <code class="language-plaintext highlighter-rouge">maxIter</code> parameter in a <code class="language-plaintext highlighter-rouge">Pipeline</code>.</p>
<h2 id="ml-persistence-saving-and-loading-pipelines">ML persistence: Saving and Loading Pipelines</h2>
<p>It is often worth saving a model or a pipeline to disk for later use. In Spark 1.6, model import/export functionality was added to the Pipeline API.
As of Spark 2.3, the DataFrame-based API in <code class="language-plaintext highlighter-rouge">spark.ml</code> and <code class="language-plaintext highlighter-rouge">pyspark.ml</code> has complete coverage.</p>
<p>ML persistence works across Scala, Java and Python. However, R currently uses a modified format,
so models saved in R can only be loaded back in R; this should be fixed in the future and is
tracked in <a href="https://issues.apache.org/jira/browse/SPARK-15572">SPARK-15572</a>.</p>
<h3 id="backwards-compatibility-for-ml-persistence">Backwards compatibility for ML persistence</h3>
<p>In general, MLlib maintains backwards compatibility for ML persistence. That is, if you save an ML
model or Pipeline in one version of Spark, then you should be able to load it back and use it in a
future version of Spark. However, there are rare exceptions, described below.</p>
<p>Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark
version X loadable by Spark version Y?</p>
<ul>
<li>Major versions: No guarantees, but best-effort.</li>
<li>Minor and patch versions: Yes; these are backwards compatible.</li>
<li>Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible.</li>
</ul>
<p>Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y?</p>
<ul>
<li>Major versions: No guarantees, but best-effort.</li>
<li>Minor and patch versions: Identical behavior, except for bug fixes.</li>
</ul>
<p>For both model persistence and model behavior, any breaking changes across a minor version or patch
version are reported in the Spark version release notes. If a breakage is not reported in release
notes, then it should be treated as a bug to be fixed.</p>
<h1 id="code-examples">Code examples</h1>
<p>This section gives code examples illustrating the functionality discussed above.
For more information, refer to the API documentation
(<a href="api/scala/org/apache/spark/ml/package.html">Scala</a>,
<a href="api/java/org/apache/spark/ml/package-summary.html">Java</a>,
and <a href="api/python/reference/pyspark.ml.html">Python</a>).</p>
<h2 id="example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</h2>
<p>This example covers the concepts of <code class="language-plaintext highlighter-rouge">Estimator</code>, <code class="language-plaintext highlighter-rouge">Transformer</code>, and <code class="language-plaintext highlighter-rouge">Param</code>.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/org/apache/spark/ml/Estimator.html"><code class="language-plaintext highlighter-rouge">Estimator</code> Scala docs</a>,
the <a href="api/scala/org/apache/spark/ml/Transformer.html"><code class="language-plaintext highlighter-rouge">Transformer</code> Scala docs</a> and
the <a href="api/scala/org/apache/spark/ml/param/Params.html"><code class="language-plaintext highlighter-rouge">Params</code> Scala docs</a> for details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.</span><span class="o">{</span><span class="nc">Vector</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.Row</span>
<span class="c1">// Prepare training data from a list of (label, features) tuples.</span>
<span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))</span>
<span class="o">)).</span><span class="py">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span>
<span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span>
<span class="k">val</span> <span class="nv">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="c1">// Print out the parameters, documentation, and any default values.</span>
<span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"LogisticRegression parameters:\n ${lr.explainParams()}\n"</span><span class="o">)</span>
<span class="c1">// We may set parameters using setter methods.</span>
<span class="nv">lr</span><span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span>
<span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span>
<span class="k">val</span> <span class="nv">model1</span> <span class="k">=</span> <span class="nv">lr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span>
<span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span>
<span class="c1">// we can view the parameters it used during fit().</span>
<span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span>
<span class="c1">// LogisticRegression instance.</span>
<span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Model 1 was fit using parameters: ${model1.parent.extractParamMap}"</span><span class="o">)</span>
<span class="c1">// We may alternatively specify parameters using a ParamMap,</span>
<span class="c1">// which supports several methods for specifying parameters.</span>
<span class="k">val</span> <span class="nv">paramMap</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="nv">lr</span><span class="o">.</span><span class="py">maxIter</span> <span class="o">-&gt;</span> <span class="mi">20</span><span class="o">)</span>
<span class="o">.</span><span class="py">put</span><span class="o">(</span><span class="nv">lr</span><span class="o">.</span><span class="py">maxIter</span><span class="o">,</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// Specify 1 Param. This overwrites the original maxIter.</span>
<span class="o">.</span><span class="py">put</span><span class="o">(</span><span class="nv">lr</span><span class="o">.</span><span class="py">regParam</span> <span class="o">-&gt;</span> <span class="mf">0.1</span><span class="o">,</span> <span class="nv">lr</span><span class="o">.</span><span class="py">threshold</span> <span class="o">-&gt;</span> <span class="mf">0.55</span><span class="o">)</span> <span class="c1">// Specify multiple Params.</span>
<span class="c1">// One can also combine ParamMaps.</span>
<span class="k">val</span> <span class="nv">paramMap2</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="nv">lr</span><span class="o">.</span><span class="py">probabilityCol</span> <span class="o">-&gt;</span> <span class="s">"myProbability"</span><span class="o">)</span> <span class="c1">// Change output column name.</span>
<span class="k">val</span> <span class="nv">paramMapCombined</span> <span class="k">=</span> <span class="n">paramMap</span> <span class="o">++</span> <span class="n">paramMap2</span>
<span class="c1">// Now learn a new model using the paramMapCombined parameters.</span>
<span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span>
<span class="k">val</span> <span class="nv">model2</span> <span class="k">=</span> <span class="nv">lr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">)</span>
<span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Model 2 was fit using parameters: ${model2.parent.extractParamMap}"</span><span class="o">)</span>
<span class="c1">// Prepare test data.</span>
<span class="k">val</span> <span class="nv">test</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))</span>
<span class="o">)).</span><span class="py">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span>
<span class="c1">// Make predictions on test data using the Transformer.transform() method.</span>
<span class="c1">// LogisticRegression.transform will only use the 'features' column.</span>
<span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span>
<span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span>
<span class="nv">model2</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span>
<span class="o">.</span><span class="py">collect</span><span class="o">()</span>
<span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">features</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=&gt;</span>
<span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($features, $label) -&gt; prob=$prob, prediction=$prediction"</span><span class="o">)</span>
<span class="o">}</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/Estimator.html"><code class="language-plaintext highlighter-rouge">Estimator</code> Java docs</a>,
the <a href="api/java/org/apache/spark/ml/Transformer.html"><code class="language-plaintext highlighter-rouge">Transformer</code> Java docs</a> and
the <a href="api/java/org/apache/spark/ml/param/Params.html"><code class="language-plaintext highlighter-rouge">Params</code> Java docs</a> for details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.VectorUDT</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.DataTypes</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.Metadata</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructField</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructType</span><span class="o">;</span>
<span class="c1">// Prepare training data.</span>
<span class="nc">List</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">dataTraining</span> <span class="o">=</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))</span>
<span class="o">);</span>
<span class="nc">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StructType</span><span class="o">(</span><span class="k">new</span> <span class="nc">StructField</span><span class="o">[]{</span>
<span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="nc">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="nc">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span>
<span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="k">new</span> <span class="nc">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span class="o">,</span> <span class="nc">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span>
<span class="o">});</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">dataTraining</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span>
<span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span>
<span class="nc">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">();</span>
<span class="c1">// Print out the parameters, documentation, and any default values.</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="na">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span>
<span class="c1">// We may set parameters using setter methods.</span>
<span class="n">lr</span><span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">).</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span>
<span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span>
<span class="nc">LogisticRegressionModel</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span>
<span class="c1">// we can view the parameters it used during fit().</span>
<span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span>
<span class="c1">// LogisticRegression instance.</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span>
<span class="c1">// We may alternatively specify parameters using a ParamMap.</span>
<span class="nc">ParamMap</span> <span class="n">paramMap</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">ParamMap</span><span class="o">()</span>
<span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mi">20</span><span class="o">))</span> <span class="c1">// Specify 1 Param.</span>
<span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">(),</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// This overwrites the original maxIter.</span>
<span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.1</span><span class="o">),</span> <span class="n">lr</span><span class="o">.</span><span class="na">threshold</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.55</span><span class="o">));</span> <span class="c1">// Specify multiple Params.</span>
<span class="c1">// One can also combine ParamMaps.</span>
<span class="nc">ParamMap</span> <span class="n">paramMap2</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">ParamMap</span><span class="o">()</span>
<span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">probabilityCol</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="s">"myProbability"</span><span class="o">));</span> <span class="c1">// Change output column name</span>
<span class="nc">ParamMap</span> <span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="o">.</span><span class="n">$plus$plus</span><span class="o">(</span><span class="n">paramMap2</span><span class="o">);</span>
<span class="c1">// Now learn a new model using the paramMapCombined parameters.</span>
<span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span>
<span class="nc">LogisticRegressionModel</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">);</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span>
<span class="c1">// Prepare test data.</span>
<span class="nc">List</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">dataTest</span> <span class="o">=</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span>
<span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))</span>
<span class="o">);</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">dataTest</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span>
<span class="c1">// Make predictions on test data using the Transformer.transform() method.</span>
<span class="c1">// LogisticRegression.transform will only use the 'features' column.</span>
<span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span>
<span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">results</span> <span class="o">=</span> <span class="n">model2</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">rows</span> <span class="o">=</span> <span class="n">results</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">);</span>
<span class="k">for</span> <span class="o">(</span><span class="nc">Row</span> <span class="nl">r:</span> <span class="n">rows</span><span class="o">.</span><span class="na">collectAsList</span><span class="o">())</span> <span class="o">{</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") -&gt; prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span>
<span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span>
<span class="o">}</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/reference/api/pyspark.ml.Estimator.html"><code class="language-plaintext highlighter-rouge">Estimator</code> Python docs</a>,
the <a href="api/python/reference/api/pyspark.ml.Transformer.html"><code class="language-plaintext highlighter-rouge">Transformer</code> Python docs</a> and
the <a href="api/python/reference/api/pyspark.ml.param.Params.html"><code class="language-plaintext highlighter-rouge">Params</code> Python docs</a> for more details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>
<span class="c1"># Prepare training data from a list of (label, features) tuples.
</span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">createDataFrame</span><span class="p">([</span>
<span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">])),</span>
<span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">])),</span>
<span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])),</span>
<span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.2</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span>
<span class="c1"># Create a LogisticRegression instance. This instance is an Estimator.
</span><span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
<span class="c1"># Print out the parameters, documentation, and any default values.
</span><span class="k">print</span><span class="p">(</span><span class="s">"LogisticRegression parameters:</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="n">lr</span><span class="p">.</span><span class="n">explainParams</span><span class="p">()</span> <span class="o">+</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Learn a LogisticRegression model. This uses the parameters stored in lr.
</span><span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c1"># Since model1 is a Model (i.e., a Transformer produced by an Estimator),
# we can view the parameters it used during fit().
# This prints the parameter (name: value) pairs, where names are unique IDs for this
# LogisticRegression instance.
</span><span class="k">print</span><span class="p">(</span><span class="s">"Model 1 was fit using parameters: "</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">model1</span><span class="p">.</span><span class="n">extractParamMap</span><span class="p">())</span>
<span class="c1"># We may alternatively specify parameters using a Python dictionary as a paramMap
</span><span class="n">paramMap</span> <span class="o">=</span> <span class="p">{</span><span class="n">lr</span><span class="p">.</span><span class="n">maxIter</span><span class="p">:</span> <span class="mi">20</span><span class="p">}</span>
<span class="n">paramMap</span><span class="p">[</span><span class="n">lr</span><span class="p">.</span><span class="n">maxIter</span><span class="p">]</span> <span class="o">=</span> <span class="mi">30</span> <span class="c1"># Specify 1 Param, overwriting the original maxIter.
# Specify multiple Params.
</span><span class="n">paramMap</span><span class="p">.</span><span class="n">update</span><span class="p">({</span><span class="n">lr</span><span class="p">.</span><span class="n">regParam</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">lr</span><span class="p">.</span><span class="n">threshold</span><span class="p">:</span> <span class="mf">0.55</span><span class="p">})</span> <span class="c1"># type: ignore
</span>
<span class="c1"># You can combine paramMaps, which are python dictionaries.
# Change output column name
</span><span class="n">paramMap2</span> <span class="o">=</span> <span class="p">{</span><span class="n">lr</span><span class="p">.</span><span class="n">probabilityCol</span><span class="p">:</span> <span class="s">"myProbability"</span><span class="p">}</span>
<span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span>
<span class="n">paramMapCombined</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">paramMap2</span><span class="p">)</span> <span class="c1"># type: ignore
</span>
<span class="c1"># Now learn a new model using the paramMapCombined parameters.
# paramMapCombined overrides all parameters set earlier via lr.set* methods.
</span><span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">,</span> <span class="n">paramMapCombined</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Model 2 was fit using parameters: "</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">model2</span><span class="p">.</span><span class="n">extractParamMap</span><span class="p">())</span>
<span class="c1"># Prepare test data
</span><span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">createDataFrame</span><span class="p">([</span>
<span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">])),</span>
<span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">3.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="p">])),</span>
<span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.2</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="p">]))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span>
<span class="c1"># Make predictions on test data using the Transformer.transform() method.
# LogisticRegression.transform will only use the 'features' column.
# Note that model2.transform() outputs a "myProbability" column instead of the usual
# 'probability' column since we renamed the lr.probabilityCol parameter previously.
</span><span class="n">prediction</span> <span class="o">=</span> <span class="n">model2</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"features"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"myProbability"</span><span class="p">,</span> <span class="s">"prediction"</span><span class="p">)</span> \
<span class="p">.</span><span class="n">collect</span><span class="p">()</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">result</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"features=%s, label=%s -&gt; prob=%s, prediction=%s"</span>
<span class="o">%</span> <span class="p">(</span><span class="n">row</span><span class="p">.</span><span class="n">features</span><span class="p">,</span> <span class="n">row</span><span class="p">.</span><span class="n">label</span><span class="p">,</span> <span class="n">row</span><span class="p">.</span><span class="n">myProbability</span><span class="p">,</span> <span class="n">row</span><span class="p">.</span><span class="n">prediction</span><span class="p">))</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/estimator_transformer_param_example.py" in the Spark repo.</small></div>
</div>
</div>
<h2 id="example-pipeline">Example: Pipeline</h2>
<p>This example follows the simple text document <code class="language-plaintext highlighter-rouge">Pipeline</code> illustrated in the figures above.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/org/apache/spark/ml/Pipeline.html"><code class="language-plaintext highlighter-rouge">Pipeline</code> Scala docs</a> for details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.</span><span class="o">{</span><span class="nc">Pipeline</span><span class="o">,</span> <span class="nc">PipelineModel</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.Vector</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.Row</span>
<span class="c1">// Prepare training documents from a list of (id, text, label) tuples.</span>
<span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)</span>
<span class="o">)).</span><span class="py">toDF</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="k">val</span> <span class="nv">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="py">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span>
<span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="nv">tokenizer</span><span class="o">.</span><span class="py">getOutputCol</span><span class="o">)</span>
<span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.001</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span>
<span class="c1">// Fit the pipeline to training documents.</span>
<span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span>
<span class="c1">// Now we can optionally save the fitted pipeline to disk</span>
<span class="nv">model</span><span class="o">.</span><span class="py">write</span><span class="o">.</span><span class="py">overwrite</span><span class="o">().</span><span class="py">save</span><span class="o">(</span><span class="s">"/tmp/spark-logistic-regression-model"</span><span class="o">)</span>
<span class="c1">// We can also save this unfit pipeline to disk</span>
<span class="nv">pipeline</span><span class="o">.</span><span class="py">write</span><span class="o">.</span><span class="py">overwrite</span><span class="o">().</span><span class="py">save</span><span class="o">(</span><span class="s">"/tmp/unfit-lr-model"</span><span class="o">)</span>
<span class="c1">// And load it back in during production</span>
<span class="k">val</span> <span class="nv">sameModel</span> <span class="k">=</span> <span class="nv">PipelineModel</span><span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"/tmp/spark-logistic-regression-model"</span><span class="o">)</span>
<span class="c1">// Prepare test documents, which are unlabeled (id, text) tuples.</span>
<span class="k">val</span> <span class="nv">test</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark hadoop spark"</span><span class="o">),</span>
<span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)</span>
<span class="o">)).</span><span class="py">toDF</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">)</span>
<span class="c1">// Make predictions on test documents.</span>
<span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span>
<span class="o">.</span><span class="py">collect</span><span class="o">()</span>
<span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=&gt;</span>
<span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($id, $text) --&gt; prob=$prob, prediction=$prediction"</span><span class="o">)</span>
<span class="o">}</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/Pipeline.html"><code class="language-plaintext highlighter-rouge">Pipeline</code> Java docs</a> for details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="c1">// Prepare training documents, which are labeled.</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="nc">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)</span>
<span class="o">),</span> <span class="nc">JavaLabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span>
<span class="nc">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span>
<span class="nc">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span>
<span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span>
<span class="nc">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.001</span><span class="o">);</span>
<span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span>
<span class="c1">// Fit the pipeline to training documents.</span>
<span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Prepare test documents, which are unlabeled.</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="nc">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span>
<span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark hadoop spark"</span><span class="o">),</span>
<span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)</span>
<span class="o">),</span> <span class="nc">JavaDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Make predictions on test documents.</span>
<span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="k">for</span> <span class="o">(</span><span class="nc">Row</span> <span class="n">r</span> <span class="o">:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">).</span><span class="na">collectAsList</span><span class="o">())</span> <span class="o">{</span>
<span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --&gt; prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span>
<span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span>
<span class="o">}</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/reference/api/pyspark.ml.Pipeline.html"><code class="language-plaintext highlighter-rouge">Pipeline</code> Python docs</a> for details on the API.</p>
<div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">HashingTF</span><span class="p">,</span> <span class="n">Tokenizer</span>
<span class="c1"># Prepare training documents from a list of (id, text, label) tuples.
</span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">createDataFrame</span><span class="p">([</span>
<span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="s">"a b c d e spark"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span>
<span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="s">"b d"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">),</span>
<span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="s">"spark f g h"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span>
<span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="s">"hadoop mapreduce"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="p">],</span> <span class="p">[</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">])</span>
<span class="c1"># Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">Tokenizer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"text"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"words"</span><span class="p">)</span>
<span class="n">hashingTF</span> <span class="o">=</span> <span class="n">HashingTF</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">getOutputCol</span><span class="p">(),</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">)</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">hashingTF</span><span class="p">,</span> <span class="n">lr</span><span class="p">])</span>
<span class="c1"># Fit the pipeline to training documents.
</span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c1"># Prepare test documents, which are unlabeled (id, text) tuples.
</span><span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">createDataFrame</span><span class="p">([</span>
<span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="s">"spark i j k"</span><span class="p">),</span>
<span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="s">"l m n"</span><span class="p">),</span>
<span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="s">"spark hadoop spark"</span><span class="p">),</span>
<span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="s">"apache hadoop"</span><span class="p">)</span>
<span class="p">],</span> <span class="p">[</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">])</span>
<span class="c1"># Make predictions on test documents and print columns of interest.
</span><span class="n">prediction</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="n">selected</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"probability"</span><span class="p">,</span> <span class="s">"prediction"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">selected</span><span class="p">.</span><span class="n">collect</span><span class="p">():</span>
<span class="n">rid</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">prob</span><span class="p">,</span> <span class="n">prediction</span> <span class="o">=</span> <span class="n">row</span>
<span class="k">print</span><span class="p">(</span>
<span class="s">"(%d, %s) --&gt; prob=%s, prediction=%f"</span> <span class="o">%</span> <span class="p">(</span>
<span class="n">rid</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">prob</span><span class="p">),</span> <span class="n">prediction</span> <span class="c1"># type: ignore
</span> <span class="p">)</span>
<span class="p">)</span></code></pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/pipeline_example.py" in the Spark repo.</small></div>
</div>
</div>
<h2 id="model-selection-hyperparameter-tuning">Model selection (hyperparameter tuning)</h2>
<p>A big benefit of using ML Pipelines is hyperparameter optimization. See the <a href="ml-tuning.html">ML Tuning Guide</a> for more information on automatic model selection.</p>
</div>
<!-- /container -->
</div>
<script src="js/vendor/jquery-3.5.1.min.js"></script>
<script src="js/vendor/bootstrap.bundle.min.js"></script>
<script src="js/vendor/anchor.min.js"></script>
<script src="js/main.js"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js"></script>
<script type="text/javascript">
// Attach the Algolia DocSearch widget to the page's search box.
// DocSearch consists of two parts:
//   1. A crawler, run by the DocSearch team every 24 hours, that follows every
//      link on this website, extracts the content of each page it visits, and
//      pushes that content to an Algolia index.
//   2. This client-side snippet, which binds that Algolia index to the search
//      input and displays matching results in a dropdown.
// For more details on how DocSearch works, see the DocSearch documentation.
(function () {
  // Algolia facet filter: only records tagged with this docs version match.
  var versionFacets = ["version:3.4.3"];
  var searchOptions = {
    apiKey: 'd62f962a82bc9abb53471cb7b89da35e',
    appId: 'RAI69RXRSK',
    indexName: 'apache_spark',
    inputSelector: '#docsearch-input',
    enhancedSearchInput: true,
    algoliaOptions: {
      'facetFilters': versionFacets
    },
    // Set to true to keep the dropdown open so it can be inspected.
    debug: false
  };
  docsearch(searchOptions);
}());
</script>
<!-- MathJax Section -->
<script type="text/x-mathjax-config">
// Configure MathJax's TeX input processor to number equations automatically
// using the "AMS" policy (per the MathJax docs, only environments that AMS
// numbers by default, e.g. `equation`, receive numbers).
MathJax.Hub.Config({
  TeX: {
    equationNumbers: {
      autoNumber: "AMS"
    }
  }
});
</script>
<script>
// Load MathJax asynchronously by injecting a <script> element into <head>.
// We can't use a protocol-relative URL ("//cdnjs...") because it resolves to
// file:// when the docs are opened from a local checkout. Instead, always use
// an absolute https:// URL. It works from file://, http:// and https:// pages,
// and it never fetches executable code over plaintext HTTP (the previous
// http:// fallback allowed in-transit tampering).
(function(d, script) {
  script = d.createElement('script');
  script.type = 'text/javascript';
  script.async = true;
  // Once MathJax.js has loaded, configure the tex2jax preprocessor that scans
  // the page text for math delimiters.
  script.onload = function(){
    MathJax.Hub.Config({
      tex2jax: {
        // Inline math: $...$, plus the JS string "\\\\(" ... "\\\\)", whose
        // runtime value is the text \\( ... \\).
        // NOTE(review): confirm whether \( ... \) was the intended delimiter.
        inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
        displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
        // Allow \$ to produce a literal dollar sign instead of a delimiter.
        processEscapes: true,
        // Never scan these elements (e.g. code listings in <pre>) for math.
        skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
      }
    });
  };
  script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js' +
    '?config=TeX-AMS-MML_HTMLorMML';
  d.getElementsByTagName('head')[0].appendChild(script);
}(document));
</script>
</body>
</html>