| |
| <!DOCTYPE html> |
| <!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]--> |
| <!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]--> |
| <!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]--> |
| <!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]--> |
| <head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> |
| <title>ML Pipelines - Spark 2.1.1 Documentation</title> |
| |
| |
| |
| |
| <link rel="stylesheet" href="css/bootstrap.min.css"> |
| <style> |
| body { |
| padding-top: 60px; |
| padding-bottom: 40px; |
| } |
| </style> |
| <meta name="viewport" content="width=device-width"> |
| <link rel="stylesheet" href="css/bootstrap-responsive.min.css"> |
| <link rel="stylesheet" href="css/main.css"> |
| |
| <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> |
| |
| <link rel="stylesheet" href="css/pygments-default.css"> |
| |
| |
| <!-- Google analytics script --> |
| <script type="text/javascript"> |
| var _gaq = _gaq || []; |
| _gaq.push(['_setAccount', 'UA-32518208-2']); |
| _gaq.push(['_trackPageview']); |
| |
| (function() { |
| var ga = document.createElement('script'); ga.type = 'text/javascript'; ga.async = true; |
| ga.src = ('https:' == document.location.protocol ? 'https://ssl' : 'http://www') + '.google-analytics.com/ga.js'; |
| var s = document.getElementsByTagName('script')[0]; s.parentNode.insertBefore(ga, s); |
| })(); |
| </script> |
| |
| |
| </head> |
| <body> |
| <!--[if lt IE 7]> |
| <p class="chromeframe">You are using an outdated browser. <a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> |
| <![endif]--> |
| |
| <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> |
| |
| <div class="navbar navbar-fixed-top" id="topbar"> |
| <div class="navbar-inner"> |
| <div class="container"> |
| <div class="brand"><a href="index.html"> |
| <img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">2.1.1</span> |
| </div> |
| <ul class="nav"> |
| <!--TODO(andyk): Add class="active" attribute to li some how.--> |
| <li><a href="index.html">Overview</a></li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="quick-start.html">Quick Start</a></li> |
| <li><a href="programming-guide.html">Spark Programming Guide</a></li> |
| <li class="divider"></li> |
| <li><a href="streaming-programming-guide.html">Spark Streaming</a></li> |
| <li><a href="sql-programming-guide.html">DataFrames, Datasets and SQL</a></li> |
| <li><a href="structured-streaming-programming-guide.html">Structured Streaming</a></li> |
| <li><a href="ml-guide.html">MLlib (Machine Learning)</a></li> |
| <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li> |
| <li><a href="sparkr.html">SparkR (R on Spark)</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li> |
| <li><a href="api/java/index.html">Java</a></li> |
| <li><a href="api/python/index.html">Python</a></li> |
| <li><a href="api/R/index.html">R</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="cluster-overview.html">Overview</a></li> |
| <li><a href="submitting-applications.html">Submitting Applications</a></li> |
| <li class="divider"></li> |
| <li><a href="spark-standalone.html">Spark Standalone</a></li> |
| <li><a href="running-on-mesos.html">Mesos</a></li> |
| <li><a href="running-on-yarn.html">YARN</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="configuration.html">Configuration</a></li> |
| <li><a href="monitoring.html">Monitoring</a></li> |
| <li><a href="tuning.html">Tuning Guide</a></li> |
| <li><a href="job-scheduling.html">Job Scheduling</a></li> |
| <li><a href="security.html">Security</a></li> |
| <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li> |
| <li class="divider"></li> |
| <li><a href="building-spark.html">Building Spark</a></li> |
| <li><a href="http://spark.apache.org/contributing.html">Contributing to Spark</a></li> |
| <li><a href="http://spark.apache.org/third-party-projects.html">Third Party Projects</a></li> |
| </ul> |
| </li> |
| </ul> |
| <!--<p class="navbar-text pull-right"><span class="version-text">v2.1.1</span></p>--> |
| </div> |
| </div> |
| </div> |
| |
| <div class="container-wrapper"> |
| |
| |
| <div class="left-menu-wrapper"> |
| <div class="left-menu"> |
| <h3><a href="ml-guide.html">MLlib: Main Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="ml-pipeline.html"> |
| |
| <b>Pipelines</b> |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-features.html"> |
| |
| Extracting, transforming and selecting features |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-classification-regression.html"> |
| |
| Classification and Regression |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-tuning.html"> |
| |
| Model selection and tuning |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="ml-advanced.html"> |
| |
| Advanced topics |
| |
| </a> |
| </li> |
| |
| |
| </ul> |
| |
| <h3><a href="mllib-guide.html">MLlib: RDD-based API Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="mllib-data-types.html"> |
| |
| Data types |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-classification-regression.html"> |
| |
| Classification and regression |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-dimensionality-reduction.html"> |
| |
| Dimensionality reduction |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-feature-extraction.html"> |
| |
| Feature extraction and transformation |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-frequent-pattern-mining.html"> |
| |
| Frequent pattern mining |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-evaluation-metrics.html"> |
| |
| Evaluation metrics |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-pmml-model-export.html"> |
| |
| PMML model export |
| |
| </a> |
| </li> |
| |
| |
| <li> |
| <a href="mllib-optimization.html"> |
| |
| Optimization (developer) |
| |
| </a> |
| </li> |
| |
| |
| </ul> |
| |
| </div> |
| </div> |
| <input id="nav-trigger" class="nav-trigger" checked type="checkbox"> |
| <label for="nav-trigger"></label> |
| <div class="content-with-sidebar" id="content"> |
| |
| <h1 class="title">ML Pipelines</h1> |
| |
| |
| <p><code>\[ |
| \newcommand{\R}{\mathbb{R}} |
| \newcommand{\E}{\mathbb{E}} |
| \newcommand{\x}{\mathbf{x}} |
| \newcommand{\y}{\mathbf{y}} |
| \newcommand{\wv}{\mathbf{w}} |
| \newcommand{\av}{\mathbf{\alpha}} |
| \newcommand{\bv}{\mathbf{b}} |
| \newcommand{\N}{\mathbb{N}} |
| \newcommand{\id}{\mathbf{I}} |
| \newcommand{\ind}{\mathbf{1}} |
| \newcommand{\0}{\mathbf{0}} |
| \newcommand{\unit}{\mathbf{e}} |
| \newcommand{\one}{\mathbf{1}} |
| \newcommand{\zero}{\mathbf{0}} |
| \]</code></p> |
| |
| <p>In this section, we introduce the concept of <strong><em>ML Pipelines</em></strong>. |
| ML Pipelines provide a uniform set of high-level APIs built on top of |
| <a href="sql-programming-guide.html">DataFrames</a> that help users create and tune practical |
| machine learning pipelines.</p> |
| |
| <p><strong>Table of Contents</strong></p> |
| |
| <ul id="markdown-toc"> |
| <li><a href="#main-concepts-in-pipelines" id="markdown-toc-main-concepts-in-pipelines">Main concepts in Pipelines</a> <ul> |
| <li><a href="#dataframe" id="markdown-toc-dataframe">DataFrame</a></li> |
| <li><a href="#pipeline-components" id="markdown-toc-pipeline-components">Pipeline components</a> <ul> |
| <li><a href="#transformers" id="markdown-toc-transformers">Transformers</a></li> |
| <li><a href="#estimators" id="markdown-toc-estimators">Estimators</a></li> |
| <li><a href="#properties-of-pipeline-components" id="markdown-toc-properties-of-pipeline-components">Properties of pipeline components</a></li> |
| </ul> |
| </li> |
| <li><a href="#pipeline" id="markdown-toc-pipeline">Pipeline</a> <ul> |
| <li><a href="#how-it-works" id="markdown-toc-how-it-works">How it works</a></li> |
| <li><a href="#details" id="markdown-toc-details">Details</a></li> |
| </ul> |
| </li> |
| <li><a href="#parameters" id="markdown-toc-parameters">Parameters</a></li> |
| <li><a href="#saving-and-loading-pipelines" id="markdown-toc-saving-and-loading-pipelines">Saving and Loading Pipelines</a></li> |
| </ul> |
| </li> |
| <li><a href="#code-examples" id="markdown-toc-code-examples">Code examples</a> <ul> |
| <li><a href="#example-estimator-transformer-and-param" id="markdown-toc-example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</a></li> |
| <li><a href="#example-pipeline" id="markdown-toc-example-pipeline">Example: Pipeline</a></li> |
| <li><a href="#model-selection-hyperparameter-tuning" id="markdown-toc-model-selection-hyperparameter-tuning">Model selection (hyperparameter tuning)</a></li> |
| </ul> |
| </li> |
| </ul> |
| |
| <h1 id="main-concepts-in-pipelines">Main concepts in Pipelines</h1> |
| |
| <p>MLlib standardizes APIs for machine learning algorithms to make it easier to combine multiple |
| algorithms into a single pipeline, or workflow. |
| This section covers the key concepts introduced by the Pipelines API, whose design is |
| largely inspired by the <a href="http://scikit-learn.org/">scikit-learn</a> project.</p> |
| |
| <ul> |
| <li> |
| <p><strong><a href="ml-pipeline.html#dataframe"><code>DataFrame</code></a></strong>: This ML API uses <code>DataFrame</code> from Spark SQL as an ML |
| dataset, which can hold a variety of data types. |
| E.g., a <code>DataFrame</code> could have different columns storing text, feature vectors, true labels, and predictions.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-pipeline.html#transformers"><code>Transformer</code></a></strong>: A <code>Transformer</code> is an algorithm which can transform one <code>DataFrame</code> into another <code>DataFrame</code>. |
| E.g., an ML model is a <code>Transformer</code> which transforms a <code>DataFrame</code> with features into a <code>DataFrame</code> with predictions.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-pipeline.html#estimators"><code>Estimator</code></a></strong>: An <code>Estimator</code> is an algorithm which can be fit on a <code>DataFrame</code> to produce a <code>Transformer</code>. |
| E.g., a learning algorithm is an <code>Estimator</code> which trains on a <code>DataFrame</code> and produces a model.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-pipeline.html#pipeline"><code>Pipeline</code></a></strong>: A <code>Pipeline</code> chains multiple <code>Transformer</code>s and <code>Estimator</code>s together to specify an ML workflow.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-pipeline.html#parameters"><code>Parameter</code></a></strong>: All <code>Transformer</code>s and <code>Estimator</code>s now share a common API for specifying parameters.</p> |
| </li> |
| </ul> |
| |
| <h2 id="dataframe">DataFrame</h2> |
| |
| <p>Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. |
| This API adopts the <code>DataFrame</code> from Spark SQL in order to support a variety of data types.</p> |
| |
| <p><code>DataFrame</code> supports many basic and structured types; see the <a href="sql-programming-guide.html#data-types">Spark SQL datatype reference</a> for a list of supported types. |
| In addition to the types listed in the Spark SQL guide, <code>DataFrame</code> can use ML <a href="mllib-data-types.html#local-vector"><code>Vector</code></a> types.</p> |
| |
| <p>A <code>DataFrame</code> can be created either implicitly or explicitly from a regular <code>RDD</code>. See the code examples below and the <a href="sql-programming-guide.html">Spark SQL programming guide</a> for examples.</p> |
| |
| <p>Columns in a <code>DataFrame</code> are named. The code examples below use names such as “text,” “features,” and “label.”</p> |
| |
| <h2 id="pipeline-components">Pipeline components</h2> |
| |
| <h3 id="transformers">Transformers</h3> |
| |
| <p>A <code>Transformer</code> is an abstraction that includes feature transformers and learned models. |
| Technically, a <code>Transformer</code> implements a method <code>transform()</code>, which converts one <code>DataFrame</code> into |
| another, generally by appending one or more columns. |
| For example:</p> |
| |
| <ul> |
| <li>A feature transformer might take a <code>DataFrame</code>, read a column (e.g., text), map it into a new |
| column (e.g., feature vectors), and output a new <code>DataFrame</code> with the mapped column appended.</li> |
| <li>A learning model might take a <code>DataFrame</code>, read the column containing feature vectors, predict the |
| label for each feature vector, and output a new <code>DataFrame</code> with predicted labels appended as a |
| column.</li> |
| </ul> |
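<p>To make the <code>transform()</code> contract concrete, here is a minimal plain-Python sketch, not the Spark API. A list of dicts stands in for a <code>DataFrame</code>. The hypothetical <code>ToyTokenizer</code> lowercases the text and splits it on whitespace. It reads one named column and returns new rows with an extra column appended, leaving its input untouched.</p>

```python
# Toy stand-in for a DataFrame: a list of rows, each row a dict of named columns.
# This is NOT the Spark API; it only illustrates the Transformer contract.


class ToyTokenizer:
    """Hypothetical feature transformer: reads input_col, appends output_col."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def transform(self, rows):
        # Build new rows instead of mutating the input, mirroring how
        # transform() returns a new DataFrame with an appended column.
        return [
            {**row, self.output_col: row[self.input_col].lower().split()}
            for row in rows
        ]


docs = [{"id": 0, "text": "Spark ML pipelines"}, {"id": 1, "text": "hello world"}]
out = ToyTokenizer("text", "words").transform(docs)
print(out[0]["words"])     # ['spark', 'ml', 'pipelines']
print("words" in docs[0])  # False: the input rows are unchanged
```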
| |
| <h3 id="estimators">Estimators</h3> |
| |
| <p>An <code>Estimator</code> abstracts the concept of a learning algorithm or any algorithm that fits or trains on |
| data. |
| Technically, an <code>Estimator</code> implements a method <code>fit()</code>, which accepts a <code>DataFrame</code> and produces a |
| <code>Model</code>, which is a <code>Transformer</code>. |
| For example, a learning algorithm such as <code>LogisticRegression</code> is an <code>Estimator</code>, and calling |
| <code>fit()</code> trains a <code>LogisticRegressionModel</code>, which is a <code>Model</code> and hence a <code>Transformer</code>.</p> |
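<p>The <code>fit()</code>/<code>Model</code> relationship can be sketched the same way. This toy example uses hypothetical class names and plain Python rather than Spark. The Estimator <code>MeanCenterer</code> learns a column mean in <code>fit()</code> and returns a <code>MeanCentererModel</code>. That model is itself a Transformer that applies the learned mean.</p>

```python
# Plain-Python toy, not Spark: rows are dicts, class names are hypothetical.


class MeanCentererModel:
    """The Model (a Transformer) produced by fit(); it holds the learned mean."""

    def __init__(self, mean, input_col, output_col):
        self.mean = mean
        self.input_col = input_col
        self.output_col = output_col

    def transform(self, rows):
        return [{**r, self.output_col: r[self.input_col] - self.mean} for r in rows]


class MeanCenterer:
    """An Estimator: fit() learns from data and returns a Model."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def fit(self, rows):
        values = [r[self.input_col] for r in rows]
        mean = sum(values) / len(values)
        return MeanCentererModel(mean, self.input_col, self.output_col)


train = [{"x": 1.0}, {"x": 2.0}, {"x": 6.0}]
model = MeanCenterer("x", "x_centered").fit(train)
print(model.mean)                      # 3.0
print(model.transform([{"x": 5.0}]))   # [{'x': 5.0, 'x_centered': 2.0}]
```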
| |
| <h3 id="properties-of-pipeline-components">Properties of pipeline components</h3> |
| |
| <p><code>Transformer.transform()</code> and <code>Estimator.fit()</code> are both stateless. In the future, stateful algorithms may be supported via alternative concepts.</p> |
| |
| <p>Each instance of a <code>Transformer</code> or <code>Estimator</code> has a unique ID, which is useful in specifying parameters (discussed below).</p> |
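<p>As an illustration of per-instance IDs, the sketch below gives every new object a prefix plus a random suffix. The format only imitates the general look of Spark ML uids, for example ones starting with <code>logreg_</code>. The exact suffix length and the <code>ToyHashingTF</code> class are assumptions.</p>

```python
import uuid


def random_uid(prefix):
    # Prefix plus a 12-hex-character random suffix; this imitates the shape of
    # Spark ML uids, but the exact format is an assumption of this sketch.
    return f"{prefix}_{uuid.uuid4().hex[-12:]}"


class ToyHashingTF:
    """Hypothetical stage: each instance receives its own uid at construction."""

    def __init__(self):
        self.uid = random_uid("hashingTF")


a, b = ToyHashingTF(), ToyHashingTF()
print(a.uid.startswith("hashingTF_"))  # True
print(a.uid != b.uid)                  # True: two instances never share an ID
```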
| |
| <h2 id="pipeline">Pipeline</h2> |
| |
| <p>In machine learning, it is common to run a sequence of algorithms to process and learn from data. |
| E.g., a simple text document processing workflow might include several stages:</p> |
| |
| <ul> |
| <li>Split each document’s text into words.</li> |
| <li>Convert each document’s words into a numerical feature vector.</li> |
| <li>Learn a prediction model using the feature vectors and labels.</li> |
| </ul> |
| |
| <p>MLlib represents such a workflow as a <code>Pipeline</code>, which consists of a sequence of |
| <code>PipelineStage</code>s (<code>Transformer</code>s and <code>Estimator</code>s) to be run in a specific order. |
| We will use this simple workflow as a running example in this section.</p> |
| |
| <h3 id="how-it-works">How it works</h3> |
| |
| <p>A <code>Pipeline</code> is specified as a sequence of stages, and each stage is either a <code>Transformer</code> or an <code>Estimator</code>. |
| These stages are run in order, and the input <code>DataFrame</code> is transformed as it passes through each stage. |
| For <code>Transformer</code> stages, the <code>transform()</code> method is called on the <code>DataFrame</code>. |
| For <code>Estimator</code> stages, the <code>fit()</code> method is called to produce a <code>Transformer</code> (which becomes part of the <code>PipelineModel</code>, or fitted <code>Pipeline</code>), and that <code>Transformer</code>’s <code>transform()</code> method is called on the <code>DataFrame</code>.</p> |
| |
| <p>We illustrate this for the simple text document workflow. The figure below is for the <em>training time</em> usage of a <code>Pipeline</code>.</p> |
| |
| <p style="text-align: center;"> |
| <img src="img/ml-Pipeline.png" title="ML Pipeline Example" alt="ML Pipeline Example" width="80%" /> |
| </p> |
| |
| <p>Above, the top row represents a <code>Pipeline</code> with three stages. |
| The first two (<code>Tokenizer</code> and <code>HashingTF</code>) are <code>Transformer</code>s (blue), and the third (<code>LogisticRegression</code>) is an <code>Estimator</code> (red). |
| The bottom row represents data flowing through the pipeline, where cylinders indicate <code>DataFrame</code>s. |
| The <code>Pipeline.fit()</code> method is called on the original <code>DataFrame</code>, which has raw text documents and labels. |
| The <code>Tokenizer.transform()</code> method splits the raw text documents into words, adding a new column with words to the <code>DataFrame</code>. |
| The <code>HashingTF.transform()</code> method converts the words column into feature vectors, adding a new column with those vectors to the <code>DataFrame</code>. |
| Now, since <code>LogisticRegression</code> is an <code>Estimator</code>, the <code>Pipeline</code> first calls <code>LogisticRegression.fit()</code> to produce a <code>LogisticRegressionModel</code>. |
| If the <code>Pipeline</code> had more <code>Estimator</code>s after <code>LogisticRegression</code>, it would call the <code>LogisticRegressionModel</code>’s <code>transform()</code> |
| method on the <code>DataFrame</code> before passing the <code>DataFrame</code> to the next stage.</p> |
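<p>The fitting procedure just described can be sketched in plain Python. This is a toy model, not Spark. <code>Tokenizer</code> and <code>HashingTF</code> are simplified stand-ins, with a sum-of-character-codes hash replacing Spark's hashing. A nearest-centroid <code>CentroidClassifier</code> replaces <code>LogisticRegression</code>. <code>Pipeline.fit()</code> fits each Estimator and transforms the data only while a later Estimator still needs it. A Transformer placed after the last Estimator is therefore simply copied into the <code>PipelineModel</code>.</p>

```python
# Toy stand-ins (not the Spark API): a "DataFrame" is a list of dicts.


class Tokenizer:  # Transformer: text -> words
    def transform(self, rows):
        return [{**r, "words": r["text"].lower().split()} for r in rows]


class HashingTF:  # Transformer: words -> fixed-size term-count vector
    def __init__(self, num_features=8):
        self.num_features = num_features

    def transform(self, rows):
        out = []
        for r in rows:
            vec = [0.0] * self.num_features
            for w in r["words"]:
                # Deterministic toy hash (sum of character codes) into a bucket.
                vec[sum(map(ord, w)) % self.num_features] += 1.0
            out.append({**r, "features": vec})
        return out


class CentroidModel:  # Model (a Transformer) produced by CentroidClassifier.fit()
    def __init__(self, centroids):
        self.centroids = centroids

    def transform(self, rows):
        def dist(a, b):
            return sum((x - y) ** 2 for x, y in zip(a, b))

        return [
            {**r, "prediction": min(self.centroids,
                                    key=lambda lbl: dist(r["features"], self.centroids[lbl]))}
            for r in rows
        ]


class CentroidClassifier:  # Estimator standing in for LogisticRegression
    def fit(self, rows):
        sums, counts = {}, {}
        for r in rows:
            acc = sums.setdefault(r["label"], [0.0] * len(r["features"]))
            for j, v in enumerate(r["features"]):
                acc[j] += v
            counts[r["label"]] = counts.get(r["label"], 0) + 1
        return CentroidModel({lbl: [v / counts[lbl] for v in acc] for lbl, acc in sums.items()})


class PipelineModel:  # fitted Pipeline: every stage is now a Transformer
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


class Pipeline:  # itself an Estimator
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        is_est = [hasattr(s, "fit") for s in self.stages]
        last_est = max((i for i, e in enumerate(is_est) if e), default=-1)
        fitted = []
        for i, stage in enumerate(self.stages):
            model = stage.fit(rows) if is_est[i] else stage
            fitted.append(model)
            if i < last_est:  # only a later Estimator needs the transformed data
                rows = model.transform(rows)
        return PipelineModel(fitted)


training = [
    {"text": "spark is great", "label": 1.0},
    {"text": "hadoop mapreduce", "label": 0.0},
]
model = Pipeline([Tokenizer(), HashingTF(8), CentroidClassifier()]).fit(training)
preds = model.transform([{"text": "spark great"}, {"text": "hadoop"}])
print([p["prediction"] for p in preds])  # [1.0, 0.0]
```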
| |
| <p>A <code>Pipeline</code> is an <code>Estimator</code>. |
| Thus, after a <code>Pipeline</code>’s <code>fit()</code> method runs, it produces a <code>PipelineModel</code>, which is a |
| <code>Transformer</code>. |
| This <code>PipelineModel</code> is used at <em>test time</em>; the figure below illustrates this usage.</p> |
| |
| <p style="text-align: center;"> |
| <img src="img/ml-PipelineModel.png" title="ML PipelineModel Example" alt="ML PipelineModel Example" width="80%" /> |
| </p> |
| |
| <p>In the figure above, the <code>PipelineModel</code> has the same number of stages as the original <code>Pipeline</code>, but all <code>Estimator</code>s in the original <code>Pipeline</code> have become <code>Transformer</code>s. |
| When the <code>PipelineModel</code>’s <code>transform()</code> method is called on a test dataset, the data are passed |
| through the fitted pipeline in order. |
| Each stage’s <code>transform()</code> method updates the dataset and passes it to the next stage.</p> |
| |
| <p><code>Pipeline</code>s and <code>PipelineModel</code>s help to ensure that training and test data go through identical feature processing steps.</p> |
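<p>The benefit is easiest to see with a stage whose output depends on training statistics. This plain-Python sketch uses a hypothetical <code>ToyMinMaxScaler</code>, loosely modeled on min-max scaling. The fitted model keeps the training minimum and maximum, so test rows are scaled exactly as training rows were. Refitting on the test set would quietly produce different features.</p>

```python
# Plain-Python toy, not Spark: rows are dicts with a numeric column "x".


class ToyMinMaxScaler:  # Estimator: learns min/max of column "x"
    def fit(self, rows):
        xs = [r["x"] for r in rows]
        return ToyMinMaxScalerModel(min(xs), max(xs))


class ToyMinMaxScalerModel:  # Transformer holding the training statistics
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def transform(self, rows):
        span = self.hi - self.lo
        return [{**r, "scaled": (r["x"] - self.lo) / span} for r in rows]


class PipelineModel:  # applies the fitted stages in order
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


train = [{"x": 0.0}, {"x": 10.0}]
test = [{"x": 5.0}, {"x": 20.0}]

fitted = PipelineModel([ToyMinMaxScaler().fit(train)])
print([r["scaled"] for r in fitted.transform(test)])  # [0.5, 2.0]

# Refitting on the test set would silently use different statistics:
print([r["scaled"] for r in ToyMinMaxScaler().fit(test).transform(test)])  # [0.0, 1.0]
```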
| |
| <h3 id="details">Details</h3> |
| |
| <p><em>DAG <code>Pipeline</code>s</em>: A <code>Pipeline</code>’s stages are specified as an ordered array. The examples given here are all for linear <code>Pipeline</code>s, i.e., <code>Pipeline</code>s in which each stage uses data produced by the previous stage. It is possible to create non-linear <code>Pipeline</code>s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the <code>Pipeline</code> forms a DAG, then the stages must be specified in topological order.</p> |
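<p>To illustrate how a column-name-defined DAG determines a valid stage order, the sketch below applies Kahn's topological sort to a hypothetical description of each stage's input and output columns. Spark infers this graph from the stages' column parameters and runs stages in the order given, so you must supply a valid order yourself. The function and the <code>lengthFeature</code> stage are illustrative assumptions.</p>

```python
from collections import deque


def topological_order(stages, source_cols):
    """Order stages so each one runs after the stages producing its input columns.

    stages: dict name -> (input_cols, output_cols), a hypothetical description.
    source_cols: columns already present in the input DataFrame.
    """
    producer = {col: name for name, (_, outs) in stages.items() for col in outs}
    deps = {}
    for name, (ins, _) in stages.items():
        deps[name] = set()
        for c in ins:
            if c in source_cols:
                continue
            if c not in producer:
                raise ValueError(f"no stage produces column {c!r}")
            deps[name].add(producer[c])

    ready = deque(sorted(n for n, d in deps.items() if not d))
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for m in sorted(deps):
            if n in deps[m]:
                deps[m].discard(n)
                if not deps[m]:  # all of m's producers have been scheduled
                    ready.append(m)
    if len(order) != len(stages):
        raise ValueError("stage graph contains a cycle")
    return order


stages = {
    "lr": (["features", "label"], ["prediction"]),
    "assembler": (["tf", "length"], ["features"]),
    "hashingTF": (["words"], ["tf"]),
    "tokenizer": (["text"], ["words"]),
    "lengthFeature": (["text"], ["length"]),  # hypothetical second branch
}
print(topological_order(stages, {"text", "label"}))
# ['lengthFeature', 'tokenizer', 'hashingTF', 'assembler', 'lr']
```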
| |
| <p><em>Runtime checking</em>: Since <code>Pipeline</code>s can operate on <code>DataFrame</code>s with varied types, they cannot use |
| compile-time type checking. |
| <code>Pipeline</code>s and <code>PipelineModel</code>s instead do runtime checking before actually running the <code>Pipeline</code>. |
| This type checking is done using the <code>DataFrame</code> <em>schema</em>, a description of the data types of columns in the <code>DataFrame</code>.</p> |
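<p>The idea of checking a schema before touching any data can be sketched as follows. Each stage declares the column types it requires and the columns it adds. The hypothetical <code>transform_schema</code> function propagates a <code>{column: type}</code> dictionary through the stages and fails fast on a mismatch. It is loosely analogous to the <code>transformSchema</code> method of Spark's pipeline stages, but the types here are plain strings, not Spark SQL types.</p>

```python
def transform_schema(schema, stages):
    """Propagate a {column: type} schema through stages, checking input types first.

    Each stage is a hypothetical (name, {input_col: required_type}, {output_col: type})
    triple. No data is read, mirroring a check done before the Pipeline runs.
    """
    schema = dict(schema)  # work on a copy
    for name, inputs, outputs in stages:
        for col, required in inputs.items():
            actual = schema.get(col)  # None if the column does not exist yet
            if actual != required:
                raise TypeError(f"{name}: column {col!r} must be {required}, got {actual}")
        schema.update(outputs)
    return schema


stages = [
    ("tokenizer", {"text": "string"}, {"words": "array<string>"}),
    ("hashingTF", {"words": "array<string>"}, {"features": "vector"}),
    ("lr", {"features": "vector", "label": "double"}, {"prediction": "double"}),
]
print(transform_schema({"text": "string", "label": "double"}, stages))

try:
    transform_schema({"text": "double", "label": "double"}, stages)
except TypeError as e:
    print("rejected before running:", e)
```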
| |
| <p><em>Unique Pipeline stages</em>: A <code>Pipeline</code>’s stages should be unique instances. E.g., the same instance |
| <code>myHashingTF</code> should not be inserted into the <code>Pipeline</code> twice since <code>Pipeline</code> stages must have |
| unique IDs. However, different instances <code>myHashingTF1</code> and <code>myHashingTF2</code> (both of type <code>HashingTF</code>) |
| can be put into the same <code>Pipeline</code> since different instances will be created with different IDs.</p> |
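<p>A duplicate-instance check of this kind can be sketched in a few lines of plain Python. <code>ToyStage</code> and <code>validate_unique</code> are hypothetical. Two separately constructed stages of the same kind pass the check, while one instance listed twice is rejected.</p>

```python
import itertools

_next_id = itertools.count()


class ToyStage:
    """Hypothetical stage: every new instance gets a fresh uid."""

    def __init__(self, kind):
        self.uid = f"{kind}_{next(_next_id)}"


def validate_unique(stages):
    uids = [s.uid for s in stages]
    if len(set(uids)) != len(uids):
        raise ValueError("a stage instance appears more than once in the pipeline")
    return True


my_hashing_tf1 = ToyStage("hashingTF")
my_hashing_tf2 = ToyStage("hashingTF")
print(validate_unique([my_hashing_tf1, my_hashing_tf2]))  # True: distinct instances

try:
    validate_unique([my_hashing_tf1, my_hashing_tf1])  # same instance twice
except ValueError as e:
    print("rejected:", e)
```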
| |
| <h2 id="parameters">Parameters</h2> |
| |
| <p>MLlib <code>Estimator</code>s and <code>Transformer</code>s use a uniform API for specifying parameters.</p> |
| |
| <p>A <code>Param</code> is a named parameter with self-contained documentation. |
| A <code>ParamMap</code> is a set of (parameter, value) pairs.</p> |
| |
| <p>There are two main ways to pass parameters to an algorithm:</p> |
| |
| <ol> |
| <li>Set parameters for an instance. E.g., if <code>lr</code> is an instance of <code>LogisticRegression</code>, one could |
| call <code>lr.setMaxIter(10)</code> to make <code>lr.fit()</code> use at most 10 iterations. |
| This API resembles the API used in the <code>spark.mllib</code> package.</li> |
| <li>Pass a <code>ParamMap</code> to <code>fit()</code> or <code>transform()</code>. Any parameters in the <code>ParamMap</code> will override parameters previously specified via setter methods.</li> |
| </ol> |
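<p>The precedence between the two mechanisms can be modeled with ordinary dictionaries. In this plain-Python sketch, <code>ToyLogisticRegression</code> is hypothetical. Its parameter names and defaults (<code>maxIter</code> 100, <code>regParam</code> 0.0) mirror <code>LogisticRegression</code>. Later sources override earlier ones: defaults first, then setter values, then a map passed to <code>fit()</code>.</p>

```python
class ToyLogisticRegression:
    """Hypothetical toy; parameter names mirror LogisticRegression's maxIter/regParam."""

    defaults = {"maxIter": 100, "regParam": 0.0}

    def __init__(self):
        self.explicit = {}  # values set via setter methods

    def setMaxIter(self, value):
        self.explicit["maxIter"] = value
        return self  # returning self allows chained setters

    def setRegParam(self, value):
        self.explicit["regParam"] = value
        return self

    def resolve(self, param_map=None):
        # Later sources win: defaults, then setters, then the map passed to fit().
        return {**self.defaults, **self.explicit, **(param_map or {})}

    def fit(self, rows, param_map=None):
        # A stand-in for training: just record which parameters were used.
        return {"fittedWith": self.resolve(param_map)}


lr = ToyLogisticRegression().setMaxIter(10).setRegParam(0.01)
print(lr.fit([])["fittedWith"])                   # {'maxIter': 10, 'regParam': 0.01}
print(lr.fit([], {"maxIter": 30})["fittedWith"])  # {'maxIter': 30, 'regParam': 0.01}
```

Passing a map to <code>fit()</code> affects only that call. The values stored on <code>lr</code> by the setters stay as they were.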
| |
| <p>Parameters belong to specific instances of <code>Estimator</code>s and <code>Transformer</code>s. |
| For example, if we have two <code>LogisticRegression</code> instances <code>lr1</code> and <code>lr2</code>, then we can build a <code>ParamMap</code> with both <code>maxIter</code> parameters specified: <code>ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)</code>. |
| This is useful if there are two algorithms with the <code>maxIter</code> parameter in a <code>Pipeline</code>.</p> |
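<p>Because each parameter is tied to its owning instance, one map can carry different values for the same parameter name. The plain-Python sketch below models a <code>Param</code> as an <code>(owner uid, name)</code> pair. <code>ToyLR</code> and <code>params_for</code> are hypothetical helpers, not Spark classes.</p>

```python
import uuid


class ToyLR:
    """Hypothetical estimator whose Param is bound to its own uid."""

    def __init__(self):
        self.uid = "logreg_" + uuid.uuid4().hex[-12:]
        self.maxIter = (self.uid, "maxIter")  # a Param: (owner uid, name)


def params_for(instance, param_map):
    # Pick out only the entries whose Param belongs to this instance.
    return {name: v for (owner, name), v in param_map.items() if owner == instance.uid}


lr1, lr2 = ToyLR(), ToyLR()
param_map = {lr1.maxIter: 10, lr2.maxIter: 20}  # cf. ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)
print(params_for(lr1, param_map))  # {'maxIter': 10}
print(params_for(lr2, param_map))  # {'maxIter': 20}
```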
| |
| <h2 id="saving-and-loading-pipelines">Saving and Loading Pipelines</h2> |
| |
| <p>It is often worthwhile to save a model or a pipeline to disk for later use. In Spark 1.6, model import/export functionality was added to the Pipeline API. Most basic transformers are supported, as are some of the more basic ML models. Please refer to the algorithm’s API documentation to see whether saving and loading is supported.</p> |
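<p>In Spark, supported models and pipelines are written with <code>save(path)</code> and read back with the matching class's <code>load(path)</code>. The toy below only illustrates the round trip in plain Python. It stores a hypothetical model's uid and learned value as JSON, which is not Spark's on-disk format.</p>

```python
import json
import os
import tempfile


class ToyScalerModel:
    """Hypothetical fitted model with one learned value."""

    def __init__(self, uid, mean):
        self.uid, self.mean = uid, mean

    def save(self, path):
        # Store identity plus learned state; the real Spark format differs.
        with open(path, "w") as f:
            json.dump({"class": type(self).__name__, "uid": self.uid, "mean": self.mean}, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            meta = json.load(f)
        return cls(meta["uid"], meta["mean"])


path = os.path.join(tempfile.mkdtemp(), "scaler.json")
ToyScalerModel("scaler_abc123", 3.0).save(path)
restored = ToyScalerModel.load(path)
print(restored.uid, restored.mean)  # scaler_abc123 3.0
```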
| |
| <h1 id="code-examples">Code examples</h1> |
| |
| <p>This section gives code examples illustrating the functionality discussed above. |
| For more info, please refer to the API documentation |
| (<a href="api/scala/index.html#org.apache.spark.ml.package">Scala</a>, |
| <a href="api/java/org/apache/spark/ml/package-summary.html">Java</a>, |
| and <a href="api/python/pyspark.ml.html">Python</a>).</p> |
| |
| <h2 id="example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</h2> |
| |
| <p>This example covers the concepts of <code>Estimator</code>, <code>Transformer</code>, and <code>Param</code>.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.Estimator"><code>Estimator</code> Scala docs</a>, |
| the <a href="api/scala/index.html#org.apache.spark.ml.Transformer"><code>Transformer</code> Scala docs</a> and |
| the <a href="api/scala/index.html#org.apache.spark.ml.param.Params"><code>Params</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.</span><span class="o">{</span><span class="nc">Vector</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.Row</span> |
| |
| <span class="c1">// Prepare training data from a list of (label, features) tuples.</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))</span> |
| <span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span> |
| |
| <span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span> |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="c1">// Print out the parameters, documentation, and any default values.</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="n">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">)</span> |
| |
| <span class="c1">// We may set parameters using setter methods.</span> |
| <span class="n">lr</span><span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span> |
| |
| <span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span> |
| <span class="k">val</span> <span class="n">model1</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| <span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span> |
| <span class="c1">// we can view the parameters it used during fit().</span> |
| <span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span> |
| <span class="c1">// LogisticRegression instance.</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">extractParamMap</span><span class="o">)</span> |
| |
| <span class="c1">// We may alternatively specify parameters using a ParamMap,</span> |
| <span class="c1">// which supports several methods for specifying parameters.</span> |
| <span class="k">val</span> <span class="n">paramMap</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span> <span class="o">-></span> <span class="mi">20</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span><span class="o">,</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// Specify 1 Param. This overwrites the original maxIter.</span> |
| <span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span> <span class="o">-></span> <span class="mf">0.1</span><span class="o">,</span> <span class="n">lr</span><span class="o">.</span><span class="n">threshold</span> <span class="o">-></span> <span class="mf">0.55</span><span class="o">)</span> <span class="c1">// Specify multiple Params.</span> |
| |
| <span class="c1">// One can also combine ParamMaps.</span> |
| <span class="k">val</span> <span class="n">paramMap2</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">probabilityCol</span> <span class="o">-></span> <span class="s">"myProbability"</span><span class="o">)</span> <span class="c1">// Change output column name.</span> |
| <span class="k">val</span> <span class="n">paramMapCombined</span> <span class="k">=</span> <span class="n">paramMap</span> <span class="o">++</span> <span class="n">paramMap2</span> |
| |
| <span class="c1">// Now learn a new model using the paramMapCombined parameters.</span> |
| <span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span> |
| <span class="k">val</span> <span class="n">model2</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">extractParamMap</span><span class="o">)</span> |
| |
| <span class="c1">// Prepare test data.</span> |
| <span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))</span> |
| <span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions on test data using the Transformer.transform() method.</span> |
| <span class="c1">// LogisticRegressionModel.transform will only use the 'features' column.</span> |
| <span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span> |
| <span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span> |
| <span class="n">model2</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">collect</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">features</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($features, $label) -> prob=$prob, prediction=$prediction"</span><span class="o">)</span> |
| <span class="o">}</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/Estimator.html"><code>Estimator</code> Java docs</a>, |
| the <a href="api/java/org/apache/spark/ml/Transformer.html"><code>Transformer</code> Java docs</a> and |
| the <a href="api/java/org/apache/spark/ml/param/Params.html"><code>Params</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.VectorUDT</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.DataTypes</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.Metadata</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructField</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructType</span><span class="o">;</span> |
| |
| <span class="c1">// Prepare training data.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataTraining</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))</span> |
| <span class="o">);</span> |
| <span class="n">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StructType</span><span class="o">(</span><span class="k">new</span> <span class="n">StructField</span><span class="o">[]{</span> |
| <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> |
| <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="k">new</span> <span class="nf">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span> |
| <span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">dataTraining</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span> |
| |
| <span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span> |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">();</span> |
| <span class="c1">// Print out the parameters, documentation, and any default values.</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="na">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| |
| <span class="c1">// We may set parameters using setter methods.</span> |
| <span class="n">lr</span><span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">).</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span> |
| |
| <span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| <span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span> |
| <span class="c1">// we can view the parameters it used during fit().</span> |
| <span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span> |
| <span class="c1">// LogisticRegression instance.</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span> |
| |
| <span class="c1">// We may alternatively specify parameters using a ParamMap.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMap</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mi">20</span><span class="o">))</span> <span class="c1">// Specify 1 Param.</span> |
| <span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">(),</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// This overwrites the original maxIter.</span> |
| <span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.1</span><span class="o">),</span> <span class="n">lr</span><span class="o">.</span><span class="na">threshold</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="mf">0.55</span><span class="o">));</span> <span class="c1">// Specify multiple Params.</span> |
| |
| <span class="c1">// One can also combine ParamMaps.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMap2</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">probabilityCol</span><span class="o">().</span><span class="na">w</span><span class="o">(</span><span class="s">"myProbability"</span><span class="o">));</span> <span class="c1">// Change output column name.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="o">.</span><span class="na">$plus$plus</span><span class="o">(</span><span class="n">paramMap2</span><span class="o">);</span> |
| |
| <span class="c1">// Now learn a new model using the paramMapCombined parameters.</span> |
| <span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="na">parent</span><span class="o">().</span><span class="na">extractParamMap</span><span class="o">());</span> |
| |
| <span class="c1">// Prepare test data.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataTest</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))</span> |
| <span class="o">);</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">dataTest</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions on test data using the Transformer.transform() method.</span> |
| <span class="c1">// LogisticRegressionModel.transform will only use the 'features' column.</span> |
| <span class="c1">// Note that model2.transform() outputs a 'myProbability' column instead of the usual</span> |
| <span class="c1">// 'probability' column since we renamed the lr.probabilityCol parameter previously.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">results</span> <span class="o">=</span> <span class="n">model2</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">rows</span> <span class="o">=</span> <span class="n">results</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"myProbability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">);</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">rows</span><span class="o">.</span><span class="na">collectAsList</span><span class="o">())</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") -> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span> |
| <span class="o">}</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.Estimator"><code>Estimator</code> Python docs</a>, |
| the <a href="api/python/pyspark.ml.html#pyspark.ml.Transformer"><code>Transformer</code> Python docs</a> and |
| the <a href="api/python/pyspark.ml.html#pyspark.ml.param.Params"><code>Params</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c"># Prepare training data from a list of (label, features) tuples.</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">])),</span> |
| <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">])),</span> |
| <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])),</span> |
| <span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.2</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">]))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span> |
| |
| <span class="c"># Create a LogisticRegression instance. This instance is an Estimator.</span> |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span> |
| <span class="c"># Print out the parameters, documentation, and any default values.</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"LogisticRegression parameters:</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="n">explainParams</span><span class="p">()</span> <span class="o">+</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">)</span> |
| |
| <span class="c"># Learn a LogisticRegression model. This uses the parameters stored in lr.</span> |
| <span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c"># Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span> |
| <span class="c"># we can view the parameters it used during fit().</span> |
| <span class="c"># This prints the parameter (name: value) pairs, where names are unique IDs for this</span> |
| <span class="c"># LogisticRegression instance.</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Model 1 was fit using parameters: "</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">model1</span><span class="o">.</span><span class="n">extractParamMap</span><span class="p">())</span> |
| |
| <span class="c"># We may alternatively specify parameters using a Python dictionary as a paramMap.</span> |
| <span class="n">paramMap</span> <span class="o">=</span> <span class="p">{</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span><span class="p">:</span> <span class="mi">20</span><span class="p">}</span> |
| <span class="n">paramMap</span><span class="p">[</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span><span class="p">]</span> <span class="o">=</span> <span class="mi">30</span> <span class="c"># Specify 1 Param, overwriting the original maxIter.</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">lr</span><span class="o">.</span><span class="n">threshold</span><span class="p">:</span> <span class="mf">0.55</span><span class="p">})</span> <span class="c"># Specify multiple Params.</span> |
| |
| <span class="c"># You can combine paramMaps, which are Python dictionaries.</span> |
| <span class="n">paramMap2</span> <span class="o">=</span> <span class="p">{</span><span class="n">lr</span><span class="o">.</span><span class="n">probabilityCol</span><span class="p">:</span> <span class="s">"myProbability"</span><span class="p">}</span> <span class="c"># Change output column name</span> |
| <span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> |
| <span class="n">paramMapCombined</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">paramMap2</span><span class="p">)</span> |
| |
| <span class="c"># Now learn a new model using the paramMapCombined parameters.</span> |
| <span class="c"># paramMapCombined overrides all parameters set earlier on lr, here via the LogisticRegression constructor.</span> |
| <span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">,</span> <span class="n">paramMapCombined</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Model 2 was fit using parameters: "</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">model2</span><span class="o">.</span><span class="n">extractParamMap</span><span class="p">())</span> |
| |
| <span class="c"># Prepare test data.</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">])),</span> |
| <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">3.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="p">])),</span> |
| <span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.2</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="p">]))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span> |
| |
| <span class="c"># Make predictions on test data using the Transformer.transform() method.</span> |
| <span class="c"># LogisticRegressionModel.transform will only use the 'features' column.</span> |
| <span class="c"># Note that model2.transform() outputs a 'myProbability' column instead of the usual</span> |
| <span class="c"># 'probability' column since we renamed the lr.probabilityCol parameter previously.</span> |
| <span class="n">prediction</span> <span class="o">=</span> <span class="n">model2</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">result</span> <span class="o">=</span> <span class="n">prediction</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"features"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"myProbability"</span><span class="p">,</span> <span class="s">"prediction"</span><span class="p">)</span> \ |
| <span class="o">.</span><span class="n">collect</span><span class="p">()</span> |
| |
| <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">result</span><span class="p">:</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"features=</span><span class="si">%s</span><span class="s">, label=</span><span class="si">%s</span><span class="s"> -> prob=</span><span class="si">%s</span><span class="s">, prediction=</span><span class="si">%s</span><span class="s">"</span> |
| <span class="o">%</span> <span class="p">(</span><span class="n">row</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="n">row</span><span class="o">.</span><span class="n">label</span><span class="p">,</span> <span class="n">row</span><span class="o">.</span><span class="n">myProbability</span><span class="p">,</span> <span class="n">row</span><span class="o">.</span><span class="n">prediction</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/estimator_transformer_param_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
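<p>The parameter precedence that the examples above rely on can be modeled with plain Python dictionaries. This is a minimal sketch, not Spark code: param names are plain strings here (in Spark they are <code>Param</code> objects owned by one estimator instance), <code>resolve_params</code> is a hypothetical helper, and the default values are assumed to mirror what <code>lr.explainParams()</code> reports for <code>LogisticRegression</code>. It assumes the resolution order defaults, then explicitly set values, then the ParamMap passed to <code>fit()</code>, with later sources winning.</p>

```python
# Hypothetical model of Spark ML parameter precedence using plain dicts.
# Param names are strings here; in Spark they are Param objects tied to `lr`.

def resolve_params(defaults, user_set, extra):
    """Effective params: defaults, overridden by user-set values, overridden by `extra`."""
    effective = dict(defaults)
    effective.update(user_set)
    effective.update(extra)
    return effective

# Assumed defaults (illustrative; check lr.explainParams() for the real values).
defaults = {"maxIter": 100, "regParam": 0.0, "threshold": 0.5, "probabilityCol": "probability"}

# Like lr.setMaxIter(10).setRegParam(0.01) in the Scala/Java examples.
user_set = {"maxIter": 10, "regParam": 0.01}

param_map = {"maxIter": 20}
param_map["maxIter"] = 30  # a later entry for the same param overwrites the earlier one
param_map.update({"regParam": 0.1, "threshold": 0.55})
param_map2 = {"probabilityCol": "myProbability"}

# Combining maps: on a key collision the right-hand map wins (like paramMap ++ paramMap2).
combined = {**param_map, **param_map2}

effective = resolve_params(defaults, user_set, combined)
print(effective)
```

With these inputs the effective <code>maxIter</code> is 30 rather than 10 and <code>probabilityCol</code> becomes <code>myProbability</code>, which is why <code>model2</code> writes a <code>myProbability</code> column while <code>lr</code> itself keeps its earlier settings.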
| |
| <h2 id="example-pipeline">Example: Pipeline</h2> |
| |
| <p>This example follows the simple text document <code>Pipeline</code> illustrated in the figures above.</p> |
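<p>To make the intermediate <code>words</code> and <code>features</code> columns concrete, the first two stages can be approximated in plain Python. This is a sketch under stated assumptions: <code>tokenize</code> and <code>hashing_tf</code> are hypothetical helpers, not Spark APIs; <code>tokenize</code> roughly mimics <code>Tokenizer</code> (lowercase, then split on whitespace); and <code>hashing_tf</code> uses <code>zlib.crc32</code> as a stand-in hash, whereas Spark's <code>HashingTF</code> uses MurmurHash3 by default, so the bucket indices will not match Spark's output.</p>

```python
import zlib
from collections import Counter


def tokenize(text):
    """Lowercase the text and split it on whitespace (approximates Tokenizer)."""
    return text.lower().split()


def hashing_tf(tokens, num_features):
    """Count terms into num_features buckets, keyed by hash(term) mod num_features.

    zlib.crc32 is a stand-in hash; Spark's HashingTF uses MurmurHash3, so the
    indices differ from Spark's. Colliding terms share a bucket, as in Spark.
    Returns a sparse {index: count} dict.
    """
    counts = Counter()
    for term in tokens:
        index = zlib.crc32(term.encode("utf-8")) % num_features
        counts[index] += 1
    return dict(counts)


# The same (id, text, label) training tuples as in the example.
training = [
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0),
]

for doc_id, text, label in training:
    words = tokenize(text)                             # like the "words" column
    features = hashing_tf(words, num_features=1000)    # like the "features" column
    print(doc_id, words, features, label)
```

The final stage, <code>LogisticRegression</code>, then trains on the sparse feature vectors and labels; hashing trades the occasional collision for not having to build and store a vocabulary.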
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.Pipeline"><code>Pipeline</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.</span><span class="o">{</span><span class="nc">Pipeline</span><span class="o">,</span> <span class="nc">PipelineModel</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.Vector</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.Row</span> |
| |
| <span class="c1">// Prepare training documents from a list of (id, text, label) tuples.</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)</span> |
| <span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="k">val</span> <span class="n">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.001</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span> |
| |
| <span class="c1">// Fit the pipeline to training documents.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Now we can optionally save the fitted pipeline to disk</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">write</span><span class="o">.</span><span class="n">overwrite</span><span class="o">().</span><span class="n">save</span><span class="o">(</span><span class="s">"/tmp/spark-logistic-regression-model"</span><span class="o">)</span> |
| |
| <span class="c1">// We can also save the unfitted pipeline to disk</span> |
| <span class="n">pipeline</span><span class="o">.</span><span class="n">write</span><span class="o">.</span><span class="n">overwrite</span><span class="o">().</span><span class="n">save</span><span class="o">(</span><span class="s">"/tmp/unfit-lr-model"</span><span class="o">)</span> |
| |
| <span class="c1">// And load it back in during production</span> |
| <span class="k">val</span> <span class="n">sameModel</span> <span class="k">=</span> <span class="nc">PipelineModel</span><span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"/tmp/spark-logistic-regression-model"</span><span class="o">)</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled (id, text) tuples.</span> |
| <span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark hadoop spark"</span><span class="o">),</span> |
| <span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)</span> |
| <span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions on test documents.</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">collect</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">"($id, $text) --> prob=$prob, prediction=$prediction"</span><span class="o">)</span> |
| <span class="o">}</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/Pipeline.html"><code>Pipeline</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
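| <span class="c1">// Prepare labeled training documents. JavaLabeledDocument and JavaDocument are</span> |
| <span class="c1">// simple JavaBean classes that ship with this example in the Spark repo.</span> |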
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaLabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)</span> |
| <span class="o">),</span> <span class="n">JavaLabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="n">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span> |
| <span class="n">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span> |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.001</span><span class="o">);</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span> |
| |
| <span class="c1">// Fit the pipeline to training documents.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark hadoop spark"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">JavaDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)</span> |
| <span class="o">),</span> <span class="n">JavaDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions on test documents.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="n">r</span> <span class="o">:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"id"</span><span class="o">,</span> <span class="s">"text"</span><span class="o">,</span> <span class="s">"probability"</span><span class="o">,</span> <span class="s">"prediction"</span><span class="o">).</span><span class="na">collectAsList</span><span class="o">())</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span> |
| <span class="o">}</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.Pipeline"><code>Pipeline</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">HashingTF</span><span class="p">,</span> <span class="n">Tokenizer</span> |
| |
| <span class="c"># Prepare training documents from a list of (id, text, label) tuples.</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="s">"a b c d e spark"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="s">"b d"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="s">"spark f g h"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="s">"hadoop mapreduce"</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span> |
| <span class="p">],</span> <span class="p">[</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">])</span> |
| |
| <span class="c"># Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">Tokenizer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"text"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"words"</span><span class="p">)</span> |
| <span class="n">hashingTF</span> <span class="o">=</span> <span class="n">HashingTF</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="p">(),</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">)</span> |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">hashingTF</span><span class="p">,</span> <span class="n">lr</span><span class="p">])</span> |
| |
| <span class="c"># Fit the pipeline to training documents.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c"># Prepare test documents, which are unlabeled (id, text) tuples.</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="s">"spark i j k"</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="s">"l m n"</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="s">"spark hadoop spark"</span><span class="p">),</span> |
| <span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="s">"apache hadoop"</span><span class="p">)</span> |
| <span class="p">],</span> <span class="p">[</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">])</span> |
| |
| <span class="c"># Make predictions on test documents and print columns of interest.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">selected</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">"id"</span><span class="p">,</span> <span class="s">"text"</span><span class="p">,</span> <span class="s">"probability"</span><span class="p">,</span> <span class="s">"prediction"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">selected</span><span class="o">.</span><span class="n">collect</span><span class="p">():</span> |
| <span class="n">rid</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">prob</span><span class="p">,</span> <span class="n">prediction</span> <span class="o">=</span> <span class="n">row</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"(</span><span class="si">%d</span><span class="s">, </span><span class="si">%s</span><span class="s">) --> prob=</span><span class="si">%s</span><span class="s">, prediction=</span><span class="si">%f</span><span class="s">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">rid</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">prob</span><span class="p">),</span> <span class="n">prediction</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/pipeline_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
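<p>The fitting behavior of a <code>Pipeline</code> can be modeled in a few lines of plain Python. In this sketch every class is a hypothetical stand-in, not the Spark API. Datasets are lists of dicts, and <code>KeywordEstimator</code> is a toy classifier, not logistic regression. <code>Pipeline.fit()</code> runs the stages in order and calls <code>fit()</code> on each <code>Estimator</code> to obtain a <code>Transformer</code>. It transforms the data only while a later <code>Estimator</code> still needs to see it. The result is a <code>PipelineModel</code> that simply chains the fitted stages' <code>transform()</code> calls.</p>

```python
# Toy model of Pipeline.fit() semantics in plain Python; this is NOT the
# Spark API. Assumptions: a dataset is a list of dicts, and every class
# below is a hypothetical stand-in for a Spark pipeline stage.


class Transformer:
    def transform(self, rows):
        raise NotImplementedError


class Estimator:
    def fit(self, rows):
        raise NotImplementedError


class Tokenize(Transformer):
    """Hypothetical Tokenizer stand-in: lowercase, then split on whitespace."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def transform(self, rows):
        return [{**r, self.output_col: r[self.input_col].lower().split()} for r in rows]


class KeywordModel(Transformer):
    """Predicts 1.0 when a row contains any keyword learned by fit()."""

    def __init__(self, keywords):
        self.keywords = keywords

    def transform(self, rows):
        return [{**r, "prediction": 1.0 if self.keywords & set(r["words"]) else 0.0}
                for r in rows]


class KeywordEstimator(Estimator):
    """Toy classifier: keeps words seen only in rows labeled 1.0."""

    def fit(self, rows):
        pos = set().union(*(r["words"] for r in rows if r["label"] == 1.0))
        neg = set().union(*(r["words"] for r in rows if r["label"] == 0.0))
        return KeywordModel(pos - neg)


class PipelineModel(Transformer):
    """A fitted pipeline: applies each fitted stage's transform() in order."""

    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


class Pipeline(Estimator):
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        # Position of the last Estimator; -1 if the pipeline has none.
        last_est = max((i for i, s in enumerate(self.stages)
                        if isinstance(s, Estimator)), default=-1)
        fitted, current = [], rows
        for i, stage in enumerate(self.stages):
            # Estimators are fitted into Transformers; Transformers pass through.
            model = stage.fit(current) if isinstance(stage, Estimator) else stage
            fitted.append(model)
            # Transform the data only if a later Estimator still needs it.
            if i < last_est:
                current = model.transform(current)
        return PipelineModel(fitted)


if __name__ == "__main__":
    training = [
        {"id": 0, "text": "a b c d e spark", "label": 1.0},
        {"id": 1, "text": "b d", "label": 0.0},
        {"id": 2, "text": "spark f g h", "label": 1.0},
        {"id": 3, "text": "hadoop mapreduce", "label": 0.0},
    ]
    model = Pipeline([Tokenize("text", "words"), KeywordEstimator()]).fit(training)
    test = [{"id": 4, "text": "spark i j k"}, {"id": 7, "text": "apache hadoop"}]
    for row in model.transform(test):
        print(row["id"], row["prediction"])  # prints "4 1.0", then "7 0.0"
```

<p>Document 4 is predicted 1.0 because it contains <code>spark</code>, a word seen only in positively labeled training rows. Document 7 shares no such word.</p>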
| |
| <h2 id="model-selection-hyperparameter-tuning">Model selection (hyperparameter tuning)</h2> |
| |
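| <p>A big benefit of using ML Pipelines is hyperparameter optimization: because a <code>Pipeline</code> is itself an <code>Estimator</code>, it can be tuned as a single unit, with the parameters of all of its stages searched together. See the <a href="ml-tuning.html">ML Tuning Guide</a> for more information on automatic model selection.</p> |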
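<p>The following minimal sketch shows what such a search space looks like. The plain-Python code builds the Cartesian product of candidate parameter values, which is how <code>ParamGridBuilder</code> combines the values added for each parameter. It then keeps the best-scoring combination. Parameter names are written as plain strings here for illustration only. <code>toy_score</code> is a hypothetical stand-in for the cross-validated metric that <code>CrossValidator</code> would compute by actually fitting the pipeline.</p>

```python
# Plain-Python sketch of a hyperparameter grid (not the Spark API).
# Assumptions: parameters are named by plain strings, and toy_score is a
# hypothetical stand-in for a cross-validated evaluation metric.
from itertools import product


def build_param_grid(grid_spec):
    """Return one {param: value} dict per combination of candidate values,
    i.e. the Cartesian product of the value lists."""
    names = list(grid_spec)
    return [dict(zip(names, values))
            for values in product(*(grid_spec[name] for name in names))]


def select_best(param_maps, score):
    """Return the param map with the highest score (higher is better)."""
    return max(param_maps, key=score)


if __name__ == "__main__":
    grid = build_param_grid({
        "hashingTF.numFeatures": [10, 100, 1000],
        "lr.regParam": [0.1, 0.01],
    })
    print(len(grid))  # 6

    def toy_score(params):
        # Hypothetical: pretend more features and weaker regularization
        # validate better. A real search would fit and evaluate the pipeline.
        return params["hashingTF.numFeatures"] - params["lr.regParam"]

    print(select_best(grid, toy_score))
    # {'hashingTF.numFeatures': 1000, 'lr.regParam': 0.01}
```

<p>With three values for the number of features and two for the regularization parameter, the grid has 3 × 2 = 6 settings. Each additional parameter multiplies the number of pipelines that must be fitted and evaluated.</p>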
| |
| |
| </div> |
| |
| <!-- /container --> |
| </div> |
| |
| <script src="js/vendor/jquery-1.8.0.min.js"></script> |
| <script src="js/vendor/bootstrap.min.js"></script> |
| <script src="js/vendor/anchor.min.js"></script> |
| <script src="js/main.js"></script> |
| |
| <!-- MathJax Section --> |
| <script type="text/x-mathjax-config"> |
| MathJax.Hub.Config({ |
| TeX: { equationNumbers: { autoNumber: "AMS" } } |
| }); |
| </script> |
| <script> |
| // Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS. |
| // We could use "//cdn.mathjax...", but that won't support "file://". |
| (function(d, script) { |
| script = d.createElement('script'); |
| script.type = 'text/javascript'; |
| script.async = true; |
| script.onload = function(){ |
| MathJax.Hub.Config({ |
| tex2jax: { |
| inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ], |
| displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], |
| processEscapes: true, |
| skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] |
| } |
| }); |
| }; |
| script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') + |
| 'cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML'; |
| d.getElementsByTagName('head')[0].appendChild(script); |
| }(document)); |
| </script> |
| </body> |
| </html> |