| <!DOCTYPE html> |
| <!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]--> |
| <!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]--> |
| <!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]--> |
| <!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]--> |
| <head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> |
| <title>Spark ML Programming Guide - Spark 1.2.2 Documentation</title> |
| <meta name="description" content=""> |
| |
| |
| |
| <link rel="stylesheet" href="css/bootstrap.min.css"> |
| <style> |
| body { |
| padding-top: 60px; |
| padding-bottom: 40px; |
| } |
| </style> |
| <meta name="viewport" content="width=device-width"> |
| <link rel="stylesheet" href="css/bootstrap-responsive.min.css"> |
| <link rel="stylesheet" href="css/main.css"> |
| |
| <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> |
| |
| <link rel="stylesheet" href="css/pygments-default.css"> |
| |
| |
| <!-- Google analytics script --> |
| <script type="text/javascript"> |
| var _gaq = _gaq || []; |
| _gaq.push(['_setAccount', 'UA-32518208-2']); |
| _gaq.push(['_trackPageview']); |
| |
| (function() { |
| var ga = document.createElement('script'); ga.type = 'text/javascript'; ga.async = true; |
| ga.src = ('https:' == document.location.protocol ? 'https://ssl' : 'http://www') + '.google-analytics.com/ga.js'; |
| var s = document.getElementsByTagName('script')[0]; s.parentNode.insertBefore(ga, s); |
| })(); |
| </script> |
| |
| |
| </head> |
| <body> |
| <!--[if lt IE 7]> |
| <p class="chromeframe">You are using an outdated browser. <a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> |
| <![endif]--> |
| |
| <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> |
| |
| <div class="navbar navbar-fixed-top" id="topbar"> |
| <div class="navbar-inner"> |
| <div class="container"> |
| <div class="brand"><a href="index.html"> |
| <img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">1.2.2</span> |
| </div> |
| <ul class="nav"> |
| <!--TODO(andyk): Add class="active" attribute to li some how.--> |
| <li><a href="index.html">Overview</a></li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="quick-start.html">Quick Start</a></li> |
| <li><a href="programming-guide.html">Spark Programming Guide</a></li> |
| <li class="divider"></li> |
| <li><a href="streaming-programming-guide.html">Spark Streaming</a></li> |
| <li><a href="sql-programming-guide.html">Spark SQL</a></li> |
| <li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li> |
| <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li> |
| <li><a href="bagel-programming-guide.html">Bagel (Pregel on Spark)</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li> |
| <li><a href="api/java/index.html">Java</a></li> |
| <li><a href="api/python/index.html">Python</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="cluster-overview.html">Overview</a></li> |
| <li><a href="submitting-applications.html">Submitting Applications</a></li> |
| <li class="divider"></li> |
| <li><a href="spark-standalone.html">Spark Standalone</a></li> |
| <li><a href="running-on-mesos.html">Mesos</a></li> |
| <li><a href="running-on-yarn.html">YARN</a></li> |
| <li class="divider"></li> |
| <li><a href="ec2-scripts.html">Amazon EC2</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="configuration.html">Configuration</a></li> |
| <li><a href="monitoring.html">Monitoring</a></li> |
| <li><a href="tuning.html">Tuning Guide</a></li> |
| <li><a href="job-scheduling.html">Job Scheduling</a></li> |
| <li><a href="security.html">Security</a></li> |
| <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li> |
| <li><a href="hadoop-third-party-distributions.html">3<sup>rd</sup>-Party Hadoop Distros</a></li> |
| <li class="divider"></li> |
| <li><a href="building-spark.html">Building Spark</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects">Supplemental Projects</a></li> |
| </ul> |
| </li> |
| </ul> |
| <!--<p class="navbar-text pull-right"><span class="version-text">v1.2.2</span></p>--> |
| </div> |
| </div> |
| </div> |
| |
| <div class="container" id="content"> |
| |
| <h1 class="title">Spark ML Programming Guide</h1> |
| |
| |
| <p><code>spark.ml</code> is a new package introduced in Spark 1.2, which aims to provide a uniform set of |
| high-level APIs that help users create and tune practical machine learning pipelines. |
| It is currently an alpha component, and we would like to hear back from the community about |
| how it fits real-world use cases and how it could be improved.</p> |
| |
<p>Note that we will keep supporting and adding features to <code>spark.mllib</code> alongside the
development of <code>spark.ml</code>.
Users should feel comfortable using <code>spark.mllib</code> features and can expect more features to come.
Developers should contribute new algorithms to <code>spark.mllib</code> and can optionally contribute
to <code>spark.ml</code>.</p>
| |
| <p><strong>Table of Contents</strong></p> |
| |
| <ul id="markdown-toc"> |
| <li><a href="#main-concepts">Main Concepts</a> <ul> |
| <li><a href="#ml-dataset">ML Dataset</a></li> |
| <li><a href="#ml-algorithms">ML Algorithms</a> <ul> |
| <li><a href="#transformers">Transformers</a></li> |
| <li><a href="#estimators">Estimators</a></li> |
| <li><a href="#properties-of-ml-algorithms">Properties of ML Algorithms</a></li> |
| </ul> |
| </li> |
| <li><a href="#pipeline">Pipeline</a> <ul> |
| <li><a href="#how-it-works">How It Works</a></li> |
| <li><a href="#details">Details</a></li> |
| </ul> |
| </li> |
| <li><a href="#parameters">Parameters</a></li> |
| </ul> |
| </li> |
| <li><a href="#code-examples">Code Examples</a> <ul> |
| <li><a href="#example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</a></li> |
| <li><a href="#example-pipeline">Example: Pipeline</a></li> |
| <li><a href="#example-model-selection-via-cross-validation">Example: Model Selection via Cross-Validation</a></li> |
| </ul> |
| </li> |
| <li><a href="#dependencies">Dependencies</a></li> |
| </ul> |
| |
| <h1 id="main-concepts">Main Concepts</h1> |
| |
| <p>Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API.</p> |
| |
| <ul> |
| <li> |
| <p><strong><a href="ml-guide.html#ml-dataset">ML Dataset</a></strong>: Spark ML uses the <a href="api/scala/index.html#org.apache.spark.sql.SchemaRDD"><code>SchemaRDD</code></a> from Spark SQL as a dataset which can hold a variety of data types. |
| E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-guide.html#transformers"><code>Transformer</code></a></strong>: A <code>Transformer</code> is an algorithm which can transform one <code>SchemaRDD</code> into another <code>SchemaRDD</code>. |
| E.g., an ML model is a <code>Transformer</code> which transforms an RDD with features into an RDD with predictions.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-guide.html#estimators"><code>Estimator</code></a></strong>: An <code>Estimator</code> is an algorithm which can be fit on a <code>SchemaRDD</code> to produce a <code>Transformer</code>. |
| E.g., a learning algorithm is an <code>Estimator</code> which trains on a dataset and produces a model.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-guide.html#pipeline"><code>Pipeline</code></a></strong>: A <code>Pipeline</code> chains multiple <code>Transformer</code>s and <code>Estimator</code>s together to specify an ML workflow.</p> |
| </li> |
| <li> |
| <p><strong><a href="ml-guide.html#parameters"><code>Param</code></a></strong>: All <code>Transformer</code>s and <code>Estimator</code>s now share a common API for specifying parameters.</p> |
| </li> |
| </ul> |
| |
| <h2 id="ml-dataset">ML Dataset</h2> |
| |
| <p>Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. |
| Spark ML adopts the <a href="api/scala/index.html#org.apache.spark.sql.SchemaRDD"><code>SchemaRDD</code></a> from Spark SQL in order to support a variety of data types under a unified Dataset concept.</p> |
| |
| <p><code>SchemaRDD</code> supports many basic and structured types; see the <a href="sql-programming-guide.html#spark-sql-datatype-reference">Spark SQL datatype reference</a> for a list of supported types. |
| In addition to the types listed in the Spark SQL guide, <code>SchemaRDD</code> can use ML <a href="api/scala/index.html#org.apache.spark.mllib.linalg.Vector"><code>Vector</code></a> types.</p> |
| |
| <p>A <code>SchemaRDD</code> can be created either implicitly or explicitly from a regular <code>RDD</code>. See the code examples below and the <a href="sql-programming-guide.html">Spark SQL programming guide</a> for examples.</p> |
| |
| <p>Columns in a <code>SchemaRDD</code> are named. The code examples below use names such as “text,” “features,” and “label.”</p> |
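<p>As a quick sketch of the implicit conversion (the <code>Document</code> case class here is hypothetical; the conversion itself comes from <code>import sqlContext._</code>, assuming an existing <code>SparkContext</code> named <code>sc</code>):</p>

```scala
import org.apache.spark.sql.SQLContext

// A hypothetical case class; Spark SQL infers the schema from its fields.
case class Document(id: Long, text: String, label: Double)

val sqlContext = new SQLContext(sc)
import sqlContext._  // brings the implicit RDD-to-SchemaRDD conversion into scope

// An RDD of case-class instances converts implicitly where a SchemaRDD is expected.
val docs = sc.parallelize(Seq(
  Document(0L, "spark is fast", 1.0),
  Document(1L, "hello world", 0.0)))
docs.registerTempTable("docs")  // the implicit conversion to SchemaRDD happens here
```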
| |
| <h2 id="ml-algorithms">ML Algorithms</h2> |
| |
| <h3 id="transformers">Transformers</h3> |
| |
| <p>A <a href="api/scala/index.html#org.apache.spark.ml.Transformer"><code>Transformer</code></a> is an abstraction which includes feature transformers and learned models. Technically, a <code>Transformer</code> implements a method <code>transform()</code> which converts one <code>SchemaRDD</code> into another, generally by appending one or more columns. |
| For example:</p> |
| |
| <ul> |
| <li>A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset.</li> |
| <li>A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset.</li> |
| </ul> |
| |
| <h3 id="estimators">Estimators</h3> |
| |
| <p>An <a href="api/scala/index.html#org.apache.spark.ml.Estimator"><code>Estimator</code></a> abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an <code>Estimator</code> implements a method <code>fit()</code> which accepts a <code>SchemaRDD</code> and produces a <code>Transformer</code>. |
| For example, a learning algorithm such as <code>LogisticRegression</code> is an <code>Estimator</code>, and calling <code>fit()</code> trains a <code>LogisticRegressionModel</code>, which is a <code>Transformer</code>.</p> |
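<p>A minimal sketch of the fit-then-transform pattern, assuming <code>SchemaRDD</code>s named <code>training</code> and <code>test</code> with "label" and "features" columns already exist:</p>

```scala
import org.apache.spark.ml.classification.LogisticRegression

// An Estimator: configuration lives on the instance.
val lr = new LogisticRegression()
lr.setMaxIter(10)

// fit() trains on the dataset and returns a Transformer
// (here, a LogisticRegressionModel).
val model = lr.fit(training)

// The returned Transformer can then make predictions on another SchemaRDD.
val predictions = model.transform(test)
```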
| |
| <h3 id="properties-of-ml-algorithms">Properties of ML Algorithms</h3> |
| |
| <p><code>Transformer</code>s and <code>Estimator</code>s are both stateless. In the future, stateful algorithms may be supported via alternative concepts.</p> |
| |
| <p>Each instance of a <code>Transformer</code> or <code>Estimator</code> has a unique ID, which is useful in specifying parameters (discussed below).</p> |
| |
| <h2 id="pipeline">Pipeline</h2> |
| |
| <p>In machine learning, it is common to run a sequence of algorithms to process and learn from data. |
| E.g., a simple text document processing workflow might include several stages:</p> |
| |
| <ul> |
| <li>Split each document’s text into words.</li> |
| <li>Convert each document’s words into a numerical feature vector.</li> |
| <li>Learn a prediction model using the feature vectors and labels.</li> |
| </ul> |
| |
| <p>Spark ML represents such a workflow as a <a href="api/scala/index.html#org.apache.spark.ml.Pipeline"><code>Pipeline</code></a>, |
| which consists of a sequence of <a href="api/scala/index.html#org.apache.spark.ml.PipelineStage"><code>PipelineStage</code>s</a> (<code>Transformer</code>s and <code>Estimator</code>s) to be run in a specific order. We will use this simple workflow as a running example in this section.</p> |
| |
| <h3 id="how-it-works">How It Works</h3> |
| |
| <p>A <code>Pipeline</code> is specified as a sequence of stages, and each stage is either a <code>Transformer</code> or an <code>Estimator</code>. |
| These stages are run in order, and the input dataset is modified as it passes through each stage. |
| For <code>Transformer</code> stages, the <code>transform()</code> method is called on the dataset. |
| For <code>Estimator</code> stages, the <code>fit()</code> method is called to produce a <code>Transformer</code> (which becomes part of the <code>PipelineModel</code>, or fitted <code>Pipeline</code>), and that <code>Transformer</code>’s <code>transform()</code> method is called on the dataset.</p> |
| |
| <p>We illustrate this for the simple text document workflow. The figure below is for the <em>training time</em> usage of a <code>Pipeline</code>.</p> |
| |
| <p style="text-align: center;"> |
| <img src="img/ml-Pipeline.png" title="Spark ML Pipeline Example" alt="Spark ML Pipeline Example" width="80%" /> |
| </p> |
| |
| <p>Above, the top row represents a <code>Pipeline</code> with three stages. |
| The first two (<code>Tokenizer</code> and <code>HashingTF</code>) are <code>Transformer</code>s (blue), and the third (<code>LogisticRegression</code>) is an <code>Estimator</code> (red). |
| The bottom row represents data flowing through the pipeline, where cylinders indicate <code>SchemaRDD</code>s. |
| The <code>Pipeline.fit()</code> method is called on the original dataset which has raw text documents and labels. |
| The <code>Tokenizer.transform()</code> method splits the raw text documents into words, adding a new column with words into the dataset. |
| The <code>HashingTF.transform()</code> method converts the words column into feature vectors, adding a new column with those vectors to the dataset. |
| Now, since <code>LogisticRegression</code> is an <code>Estimator</code>, the <code>Pipeline</code> first calls <code>LogisticRegression.fit()</code> to produce a <code>LogisticRegressionModel</code>. |
| If the <code>Pipeline</code> had more stages, it would call the <code>LogisticRegressionModel</code>’s <code>transform()</code> method on the dataset before passing the dataset to the next stage.</p> |
| |
| <p>A <code>Pipeline</code> is an <code>Estimator</code>. |
| Thus, after a <code>Pipeline</code>’s <code>fit()</code> method runs, it produces a <code>PipelineModel</code> which is a <code>Transformer</code>. This <code>PipelineModel</code> is used at <em>test time</em>; the figure below illustrates this usage.</p> |
| |
| <p style="text-align: center;"> |
| <img src="img/ml-PipelineModel.png" title="Spark ML PipelineModel Example" alt="Spark ML PipelineModel Example" width="80%" /> |
| </p> |
| |
| <p>In the figure above, the <code>PipelineModel</code> has the same number of stages as the original <code>Pipeline</code>, but all <code>Estimator</code>s in the original <code>Pipeline</code> have become <code>Transformer</code>s. |
| When the <code>PipelineModel</code>’s <code>transform()</code> method is called on a test dataset, the data are passed through the <code>Pipeline</code> in order. |
| Each stage’s <code>transform()</code> method updates the dataset and passes it to the next stage.</p> |
| |
| <p><code>Pipeline</code>s and <code>PipelineModel</code>s help to ensure that training and test data go through identical feature processing steps.</p> |
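<p>The three-stage workflow in the figures can be sketched as follows. This assumes <code>SchemaRDD</code>s named <code>training</code> (with "text" and "label" columns) and <code>test</code> (with a "text" column); the full runnable version appears in the Pipeline example below.</p>

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// The three stages from the figure: two Transformers and one Estimator.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)

// Chain the stages in order; fitting the Pipeline runs the Transformers
// and fits lr on the transformed data.
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
val model = pipeline.fit(training)  // model is a PipelineModel (a Transformer)

// At test time, the same feature processing is applied before prediction.
val predictions = model.transform(test)
```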
| |
| <h3 id="details">Details</h3> |
| |
| <p><em>DAG <code>Pipeline</code>s</em>: A <code>Pipeline</code>’s stages are specified as an ordered array. The examples given here are all for linear <code>Pipeline</code>s, i.e., <code>Pipeline</code>s in which each stage uses data produced by the previous stage. It is possible to create non-linear <code>Pipeline</code>s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the <code>Pipeline</code> forms a DAG, then the stages must be specified in topological order.</p> |
| |
| <p><em>Runtime checking</em>: Since <code>Pipeline</code>s can operate on datasets with varied types, they cannot use compile-time type checking. <code>Pipeline</code>s and <code>PipelineModel</code>s instead do runtime checking before actually running the <code>Pipeline</code>. This type checking is done using the dataset <em>schema</em>, a description of the data types of columns in the <code>SchemaRDD</code>.</p> |
| |
| <h2 id="parameters">Parameters</h2> |
| |
| <p>Spark ML <code>Estimator</code>s and <code>Transformer</code>s use a uniform API for specifying parameters.</p> |
| |
| <p>A <a href="api/scala/index.html#org.apache.spark.ml.param.Param"><code>Param</code></a> is a named parameter with self-contained documentation. |
| A <a href="api/scala/index.html#org.apache.spark.ml.param.ParamMap"><code>ParamMap</code></a> is a set of (parameter, value) pairs.</p> |
| |
| <p>There are two main ways to pass parameters to an algorithm:</p> |
| |
| <ol> |
| <li>Set parameters for an instance. E.g., if <code>lr</code> is an instance of <code>LogisticRegression</code>, one could call <code>lr.setMaxIter(10)</code> to make <code>lr.fit()</code> use at most 10 iterations. This API resembles the API used in MLlib.</li> |
| <li>Pass a <code>ParamMap</code> to <code>fit()</code> or <code>transform()</code>. Any parameters in the <code>ParamMap</code> will override parameters previously specified via setter methods.</li> |
| </ol> |
| |
| <p>Parameters belong to specific instances of <code>Estimator</code>s and <code>Transformer</code>s. |
| For example, if we have two <code>LogisticRegression</code> instances <code>lr1</code> and <code>lr2</code>, then we can build a <code>ParamMap</code> with both <code>maxIter</code> parameters specified: <code>ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)</code>. |
| This is useful if there are two algorithms with the <code>maxIter</code> parameter in a <code>Pipeline</code>.</p> |
| |
| <h1 id="code-examples">Code Examples</h1> |
| |
| <p>This section gives code examples illustrating the functionality discussed above. |
| There is not yet documentation for specific algorithms in Spark ML. For more info, please refer to the <a href="api/scala/index.html#org.apache.spark.ml.package">API Documentation</a>. Spark ML algorithms are currently wrappers for MLlib algorithms, and the <a href="mllib-guide.html">MLlib programming guide</a> has details on specific algorithms.</p> |
| |
| <h2 id="example-estimator-transformer-and-param">Example: Estimator, Transformer, and Param</h2> |
| |
| <p>This example covers the concepts of <code>Estimator</code>, <code>Transformer</code>, and <code>Param</code>.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.</span><span class="o">{</span><span class="nc">Vector</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span> |
| |
| <span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"SimpleParamsExample"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span> |
| <span class="k">import</span> <span class="nn">sqlContext._</span> |
| |
| <span class="c1">// Prepare training data.</span> |
| <span class="c1">// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes</span> |
| <span class="c1">// into SchemaRDDs, where it uses the case class metadata to infer the schema.</span> |
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span> |
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span> |
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">))))</span> |
| |
| <span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span> |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="c1">// Print out the parameters, documentation, and any default values.</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="n">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">)</span> |
| |
| <span class="c1">// We may set parameters using setter methods.</span> |
| <span class="n">lr</span><span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span> |
| |
| <span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span> |
| <span class="k">val</span> <span class="n">model1</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| <span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span> |
| <span class="c1">// we can view the parameters it used during fit().</span> |
| <span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span> |
| <span class="c1">// LogisticRegression instance.</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="n">fittingParamMap</span><span class="o">)</span> |
| |
| <span class="c1">// We may alternatively specify parameters using a ParamMap,</span> |
| <span class="c1">// which supports several methods for specifying parameters.</span> |
| <span class="k">val</span> <span class="n">paramMap</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span> <span class="o">-></span> <span class="mi">20</span><span class="o">)</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">maxIter</span><span class="o">,</span> <span class="mi">30</span><span class="o">)</span> <span class="c1">// Specify 1 Param. This overwrites the original maxIter.</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="n">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span> <span class="o">-></span> <span class="mf">0.1</span><span class="o">,</span> <span class="n">lr</span><span class="o">.</span><span class="n">threshold</span> <span class="o">-></span> <span class="mf">0.5</span><span class="o">)</span> <span class="c1">// Specify multiple Params.</span> |
| |
| <span class="c1">// One can also combine ParamMaps.</span> |
| <span class="k">val</span> <span class="n">paramMap2</span> <span class="k">=</span> <span class="nc">ParamMap</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">scoreCol</span> <span class="o">-></span> <span class="s">"probability"</span><span class="o">)</span> <span class="c1">// Changes output column name.</span> |
| <span class="k">val</span> <span class="n">paramMapCombined</span> <span class="k">=</span> <span class="n">paramMap</span> <span class="o">++</span> <span class="n">paramMap2</span> |
| |
| <span class="c1">// Now learn a new model using the paramMapCombined parameters.</span> |
| <span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span> |
| <span class="k">val</span> <span class="n">model2</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="n">fittingParamMap</span><span class="o">)</span> |
| |
| <span class="c1">// Prepare test documents.</span> |
<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span> |
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span> |
| <span class="nc">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">))))</span> |
| |
| <span class="c1">// Make predictions on test documents using the Transformer.transform() method.</span> |
| <span class="c1">// LogisticRegression.transform will only use the 'features' column.</span> |
| <span class="c1">// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'</span> |
| <span class="c1">// column since we renamed the lr.scoreCol parameter previously.</span> |
| <span class="n">model2</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="-Symbol">'features</span><span class="o">,</span> <span class="-Symbol">'label</span><span class="o">,</span> <span class="-Symbol">'probability</span><span class="o">,</span> <span class="-Symbol">'prediction</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">collect</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">features</span><span class="k">:</span> <span class="kt">Vector</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prob</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">features</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">label</span> <span class="o">+</span> <span class="s">") -> prob="</span> <span class="o">+</span> <span class="n">prob</span> <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">prediction</span><span class="o">)</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| |
| <div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSQLContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSchemaRDD</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.Row</span><span class="o">;</span> |
| |
| <span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaSimpleParamsExample"</span><span class="o">);</span> |
| <span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span> |
| <span class="n">JavaSQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span> |
| |
| <span class="c1">// Prepare training data.</span> |
| <span class="c1">// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes</span> |
| <span class="c1">// into SchemaRDDs, where it uses the case class metadata to infer the schema.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.1</span><span class="o">,</span> <span class="mf">0.1</span><span class="o">)),</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="o">)),</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="o">)));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">training</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Create a LogisticRegression instance. This instance is an Estimator.</span> |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">();</span> |
| <span class="c1">// Print out the parameters, documentation, and any default values.</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"LogisticRegression parameters:\n"</span> <span class="o">+</span> <span class="n">lr</span><span class="o">.</span><span class="na">explainParams</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| |
| <span class="c1">// We may set parameters using setter methods.</span> |
| <span class="n">lr</span><span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span> |
| |
| <span class="c1">// Learn a LogisticRegression model. This uses the parameters stored in lr.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| <span class="c1">// Since model1 is a Model (i.e., a Transformer produced by an Estimator),</span> |
| <span class="c1">// we can view the parameters it used during fit().</span> |
| <span class="c1">// This prints the parameter (name: value) pairs, where names are unique IDs for this</span> |
| <span class="c1">// LogisticRegression instance.</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 1 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model1</span><span class="o">.</span><span class="na">fittingParamMap</span><span class="o">());</span> |
| |
| <span class="c1">// We may alternatively specify parameters using a ParamMap.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMap</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">();</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">(),</span> <span class="mi">20</span><span class="o">);</span> <span class="c1">// Specify 1 Param.</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">maxIter</span><span class="o">(),</span> <span class="mi">30</span><span class="o">);</span> <span class="c1">// This overwrites the original maxIter.</span> |
| <span class="n">paramMap</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">(),</span> <span class="mf">0.1</span><span class="o">);</span> |
| |
| <span class="c1">// One can also combine ParamMaps.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMap2</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamMap</span><span class="o">();</span> |
| <span class="n">paramMap2</span><span class="o">.</span><span class="na">put</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">scoreCol</span><span class="o">(),</span> <span class="s">"probability"</span><span class="o">);</span> <span class="c1">// Changes output column name.</span> |
| <span class="n">ParamMap</span> <span class="n">paramMapCombined</span> <span class="o">=</span> <span class="n">paramMap</span><span class="o">.</span><span class="na">$plus$plus</span><span class="o">(</span><span class="n">paramMap2</span><span class="o">);</span> |
| |
| <span class="c1">// Now learn a new model using the paramMapCombined parameters.</span> |
| <span class="c1">// paramMapCombined overrides all parameters set earlier via lr.set* methods.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">paramMapCombined</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Model 2 was fit using parameters: "</span> <span class="o">+</span> <span class="n">model2</span><span class="o">.</span><span class="na">fittingParamMap</span><span class="o">());</span> |
| |
| <span class="c1">// Prepare test documents.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(-</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">1.5</span><span class="o">,</span> <span class="mf">1.3</span><span class="o">)),</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">3.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="o">)),</span> |
| <span class="k">new</span> <span class="nf">LabeledPoint</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.2</span><span class="o">,</span> <span class="o">-</span><span class="mf">1.5</span><span class="o">)));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions on test documents using the Transformer.transform() method.</span> |
| <span class="c1">// LogisticRegressionModel.transform() will only use the 'features' column.</span> |
| <span class="c1">// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'</span> |
| <span class="c1">// column since we renamed the lr.scoreCol parameter previously.</span> |
| <span class="n">model2</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">).</span><span class="na">registerAsTable</span><span class="o">(</span><span class="s">"results"</span><span class="o">);</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">results</span> <span class="o">=</span> |
| <span class="n">jsql</span><span class="o">.</span><span class="na">sql</span><span class="o">(</span><span class="s">"SELECT features, label, probability, prediction FROM results"</span><span class="o">);</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">results</span><span class="o">.</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") -> prob="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| </div> |
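The ParamMap calls above (a later `put()` overwriting an earlier one, and the combined map passed to `fit()` overriding values set via `lr.set*`) follow ordinary map-union semantics. A plain-Scala analogy, using hypothetical string keys rather than the Spark `ParamMap` API:

```scala
// Plain-Scala analogy for ParamMap override semantics (NOT the Spark API).
// A ParamMap behaves like a Map from param to value: a later put()
// overwrites an earlier one, and when two maps are combined with ++,
// the right-hand side wins on conflicts.
val viaSetters = Map[String, Any]("maxIter" -> 10, "regParam" -> 0.01) // lr.set* calls
val paramMap = Map[String, Any]("maxIter" -> 20) +
  ("maxIter" -> 30) +     // overwrites the original maxIter, as in the example
  ("regParam" -> 0.1)
val paramMap2 = Map[String, Any]("scoreCol" -> "probability")
val paramMapCombined = paramMap ++ paramMap2

// Passing a ParamMap to fit() overrides any parameters set via setters:
val effective = viaSetters ++ paramMapCombined
```

Here `effective` carries `maxIter -> 30`, `regParam -> 0.1`, and `scoreCol -> "probability"`, matching the parameters `model2` reports via `fittingParamMap`.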
| |
| <h2 id="example-pipeline">Example: Pipeline</h2> |
| |
| <p>This example follows the simple text document <code>Pipeline</code> illustrated in the figures above.</p> |
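Before reading the full listings, it may help to see the mechanics the `Pipeline` abstraction relies on: `fit()` threads the training data through each stage in order, fitting every `Estimator` on the output of the preceding stages, and returns a model that replays the fitted stages on new data. A minimal, self-contained sketch in plain Scala (not the Spark API — `ToyPipeline`, `Lowercaser`, and `VocabFilter` are illustrative inventions, and data is a `Seq[String]` rather than a SchemaRDD):

```scala
// Toy sketch of the Pipeline abstraction -- NOT the Spark API.
sealed trait PipelineStage
trait Transformer extends PipelineStage {
  def transform(data: Seq[String]): Seq[String]
}
trait Estimator extends PipelineStage {
  // fit() learns from the data and produces a Transformer.
  def fit(data: Seq[String]): Transformer
}

// A feature transformer needs no fitting: it is applied as-is.
class Lowercaser extends Transformer {
  def transform(data: Seq[String]): Seq[String] = data.map(_.toLowerCase)
}

// An estimator: "learns" the training vocabulary, then drops unseen words.
class VocabFilter extends Estimator {
  def fit(data: Seq[String]): Transformer = {
    val vocab = data.flatMap(_.split(" ")).toSet
    new Transformer {
      def transform(data: Seq[String]): Seq[String] =
        data.map(_.split(" ").filter(vocab.contains).mkString(" "))
    }
  }
}

class ToyPipeline(stages: Seq[PipelineStage]) {
  // Fitting threads the training data through each stage in order,
  // fitting every Estimator on the output of the previous stages.
  // The result plays the role of a PipelineModel: a Transformer that
  // replays all the fitted stages on new data.
  def fit(data: Seq[String]): Transformer = {
    var current = data
    val fitted = stages.map {
      case t: Transformer => current = t.transform(current); t
      case e: Estimator   => val m = e.fit(current); current = m.transform(current); m
    }
    new Transformer {
      def transform(data: Seq[String]): Seq[String] =
        fitted.foldLeft(data)((d, t) => t.transform(d))
    }
  }
}
```

For example, `new ToyPipeline(Seq(new Lowercaser, new VocabFilter)).fit(Seq("Spark Hadoop"))` yields a model whose `transform` lowercases new documents and keeps only words seen in training. The real pipeline below works the same way, with `Tokenizer` and `HashingTF` as pass-through transformer stages and `LogisticRegression` as the estimator stage.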
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span> |
| |
| <span class="c1">// Labeled and unlabeled instance types.</span> |
| <span class="c1">// Spark SQL can infer schema from case classes.</span> |
| <span class="k">case</span> <span class="k">class</span> <span class="nc">LabeledDocument</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">label</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> |
| <span class="k">case</span> <span class="k">class</span> <span class="nc">Document</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">)</span> |
| |
| <span class="c1">// Set up contexts. Import implicit conversions to SchemaRDD from sqlContext.</span> |
| <span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"SimpleTextClassificationPipeline"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span> |
| <span class="k">import</span> <span class="nn">sqlContext._</span> |
| |
| <span class="c1">// Prepare training documents, which are labeled.</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)))</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="k">val</span> <span class="n">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span> |
| |
| <span class="c1">// Fit the pipeline to training documents.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled.</span> |
| <span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)))</span> |
| |
| <span class="c1">// Make predictions on test documents.</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="-Symbol">'id</span><span class="o">,</span> <span class="-Symbol">'text</span><span class="o">,</span> <span class="-Symbol">'score</span><span class="o">,</span> <span class="-Symbol">'prediction</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">collect</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">score</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">id</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">text</span> <span class="o">+</span> <span class="s">") --> score="</span> <span class="o">+</span> <span class="n">score</span> <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">prediction</span><span class="o">)</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| |
| <div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.io.Serializable</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSQLContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSchemaRDD</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> |
| |
| <span class="c1">// Labeled and unlabeled instance types.</span> |
| <span class="c1">// Spark SQL can infer schema from Java Beans.</span> |
| <span class="kd">public</span> <span class="kd">class</span> <span class="nc">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span> |
| <span class="kd">private</span> <span class="n">Long</span> <span class="n">id</span><span class="o">;</span> |
| <span class="kd">private</span> <span class="n">String</span> <span class="n">text</span><span class="o">;</span> |
| |
| <span class="kd">public</span> <span class="nf">Document</span><span class="o">(</span><span class="n">Long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span> |
| <span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span> |
| <span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span> |
| <span class="o">}</span> |
| |
| <span class="kd">public</span> <span class="n">Long</span> <span class="nf">getId</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span><span class="o">;</span> <span class="o">}</span> |
| <span class="kd">public</span> <span class="kt">void</span> <span class="nf">setId</span><span class="o">(</span><span class="n">Long</span> <span class="n">id</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">id</span> <span class="o">=</span> <span class="n">id</span><span class="o">;</span> <span class="o">}</span> |
| |
| <span class="kd">public</span> <span class="n">String</span> <span class="nf">getText</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span><span class="o">;</span> <span class="o">}</span> |
| <span class="kd">public</span> <span class="kt">void</span> <span class="nf">setText</span><span class="o">(</span><span class="n">String</span> <span class="n">text</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">;</span> <span class="o">}</span> |
| <span class="o">}</span> |
| |
| <span class="kd">public</span> <span class="kd">class</span> <span class="nc">LabeledDocument</span> <span class="kd">extends</span> <span class="n">Document</span> <span class="kd">implements</span> <span class="n">Serializable</span> <span class="o">{</span> |
| <span class="kd">private</span> <span class="n">Double</span> <span class="n">label</span><span class="o">;</span> |
| |
| <span class="kd">public</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="n">Long</span> <span class="n">id</span><span class="o">,</span> <span class="n">String</span> <span class="n">text</span><span class="o">,</span> <span class="n">Double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span> |
| <span class="kd">super</span><span class="o">(</span><span class="n">id</span><span class="o">,</span> <span class="n">text</span><span class="o">);</span> |
| <span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span> |
| <span class="o">}</span> |
| |
| <span class="kd">public</span> <span class="n">Double</span> <span class="nf">getLabel</span><span class="o">()</span> <span class="o">{</span> <span class="k">return</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span><span class="o">;</span> <span class="o">}</span> |
| <span class="kd">public</span> <span class="kt">void</span> <span class="nf">setLabel</span><span class="o">(</span><span class="n">Double</span> <span class="n">label</span><span class="o">)</span> <span class="o">{</span> <span class="k">this</span><span class="o">.</span><span class="na">label</span> <span class="o">=</span> <span class="n">label</span><span class="o">;</span> <span class="o">}</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Set up contexts.</span> |
| <span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaSimpleTextClassificationPipeline"</span><span class="o">);</span> |
| <span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span> |
| <span class="n">JavaSQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span> |
| |
| <span class="c1">// Prepare training documents, which are labeled.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">LabeledDocument</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">training</span> <span class="o">=</span> |
| <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="n">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span> |
| <span class="n">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span> |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span> |
| |
| <span class="c1">// Fit the pipeline to training documents.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Document</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">test</span> <span class="o">=</span> |
| <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">Document</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions on test documents.</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">).</span><span class="na">registerAsTable</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">);</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">sql</span><span class="o">(</span><span class="s">"SELECT id, text, score, prediction FROM prediction"</span><span class="o">);</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --> score="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| </div> |
| |
| <h2 id="example-model-selection-via-cross-validation">Example: Model Selection via Cross-Validation</h2> |
| |
| <p>An important task in ML is <em>model selection</em>, or using data to find the best model or parameters for a given task. This is also called <em>tuning</em>. |
| <code>Pipeline</code>s facilitate model selection by making it easy to tune an entire <code>Pipeline</code> at once, rather than tuning each element in the <code>Pipeline</code> separately.</p> |
| |
| <p>Currently, <code>spark.ml</code> supports model selection using the <a href="api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator"><code>CrossValidator</code></a> class, which takes an <code>Estimator</code>, a set of <code>ParamMap</code>s, and an <a href="api/scala/index.html#org.apache.spark.ml.Evaluator"><code>Evaluator</code></a>. |
| <code>CrossValidator</code> begins by splitting the dataset into a set of <em>folds</em> which are used as separate training and test datasets; e.g., with <code>$k=3$</code> folds, <code>CrossValidator</code> will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. |
| <code>CrossValidator</code> iterates through the set of <code>ParamMap</code>s. For each <code>ParamMap</code>, it trains the given <code>Estimator</code> and evaluates it using the given <code>Evaluator</code>. |
| The <code>ParamMap</code> which produces the best evaluation metric (averaged over the <code>$k$</code> folds) is selected as the best set of parameters. |
| <code>CrossValidator</code> finally fits the <code>Estimator</code> using the best <code>ParamMap</code> and the entire dataset.</p> |
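The procedure above can be sketched outside Spark in a few lines. This is an illustrative toy, not the `CrossValidator` implementation: the "model" is a fixed threshold classifier on a single feature (so the training step is deliberately trivial), the metric is plain accuracy, all names are made up for this sketch, and it stops at parameter selection, omitting the final refit on the full dataset.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

// Toy k-fold parameter selection: for each candidate parameter, train on k-1
// folds, evaluate on the held-out fold, average the metric, and keep the winner.
public class KFoldSketch {
  // "Fit" a model for a candidate parameter. A real Estimator would learn from
  // trainSet; here the parameter fully determines the model, to keep the sketch small.
  static DoubleUnaryOperator fit(List<double[]> trainSet, double threshold) {
    return x -> x >= threshold ? 1.0 : 0.0;
  }

  // Fraction of (feature, label) examples the model classifies correctly.
  static double accuracy(DoubleUnaryOperator model, List<double[]> data) {
    long correct = data.stream()
        .filter(ex -> model.applyAsDouble(ex[0]) == ex[1])
        .count();
    return (double) correct / data.size();
  }

  // Return the candidate parameter with the best metric averaged over k folds.
  static double selectParam(List<double[]> data, double[] candidates, int k) {
    // One simple splitting strategy: assign examples to folds round-robin.
    List<List<double[]>> folds = new ArrayList<>();
    for (int i = 0; i < k; i++) folds.add(new ArrayList<>());
    for (int i = 0; i < data.size(); i++) folds.get(i % k).add(data.get(i));

    double best = candidates[0];
    double bestScore = Double.NEGATIVE_INFINITY;
    for (double t : candidates) {
      double sum = 0.0;
      for (int i = 0; i < k; i++) {
        List<double[]> heldOut = folds.get(i);        // test fold
        List<double[]> trainSet = new ArrayList<>();  // remaining k-1 folds
        for (int j = 0; j < k; j++) if (j != i) trainSet.addAll(folds.get(j));
        sum += accuracy(fit(trainSet, t), heldOut);
      }
      double avg = sum / k;                           // average over the k folds
      if (avg > bestScore) { bestScore = avg; best = t; }
    }
    return best;
  }
}
```

For example, with labels that flip at feature value 5, `selectParam` over candidates `{0.0, 5.0, 100.0}` picks the threshold `5.0`, since it alone scores perfectly on every held-out fold.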
| |
| <p>The following example demonstrates using <code>CrossValidator</code> to select from a grid of parameters. |
| To help construct the parameter grid, we use the <a href="api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder"><code>ParamGridBuilder</code></a> utility.</p> |
| |
| <p>Note that cross-validation over a grid of parameters is expensive. |
| E.g., in the example below, the parameter grid has 3 values for <code>hashingTF.numFeatures</code> and 2 values for <code>lr.regParam</code>, and <code>CrossValidator</code> uses 2 folds. This multiplies out to <code>$(3 \times 2) \times 2 = 12$</code> different models being trained. |
| In realistic settings, it is common to try many more parameters and use more folds (<code>$k=3$</code> and <code>$k=10$</code> are common). |
| In other words, using <code>CrossValidator</code> can be very expensive. |
| However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.</p> |
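The cost arithmetic in this paragraph can be checked with a small helper (the counts mirror the example below; the method and variable names are illustrative, not part of the Spark API):

```java
// Each (parameter setting, fold) pair costs one model fit, so the total number
// of models trained during the grid search is (grid size) x (number of folds).
public class CvCost {
  static int modelsTrained(int numFeatureValues, int regParamValues, int numFolds) {
    int parameterSettings = numFeatureValues * regParamValues; // size of the grid
    return parameterSettings * numFolds;                       // one fit per (setting, fold)
  }
}
```

With 3 values for `hashingTF.numFeatures`, 2 values for `lr.regParam`, and 2 folds, this gives (3 × 2) × 2 = 12 fits, matching the count above.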
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.</span><span class="o">{</span><span class="nc">SparkConf</span><span class="o">,</span> <span class="nc">SparkContext</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.SparkContext._</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.BinaryClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">HashingTF</span><span class="o">,</span> <span class="nc">Tokenizer</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.tuning.</span><span class="o">{</span><span class="nc">ParamGridBuilder</span><span class="o">,</span> <span class="nc">CrossValidator</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span> |
| |
| <span class="k">val</span> <span class="n">conf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkConf</span><span class="o">().</span><span class="n">setAppName</span><span class="o">(</span><span class="s">"CrossValidatorExample"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span> |
| <span class="k">import</span> <span class="nn">sqlContext._</span> |
| |
| <span class="c1">// Prepare training documents, which are labeled.</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"b spark who"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"g d a y"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark fly"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"was mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">8L</span><span class="o">,</span> <span class="s">"e spark program"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">9L</span><span class="o">,</span> <span class="s">"a e c l"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">10L</span><span class="o">,</span> <span class="s">"spark compile"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="nc">LabeledDocument</span><span class="o">(</span><span class="mi">11L</span><span class="o">,</span> <span class="s">"hadoop software"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)))</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="k">val</span> <span class="n">tokenizer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">hashingTF</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">getOutputCol</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">))</span> |
| |
| <span class="c1">// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.</span> |
| <span class="c1">// This will allow us to jointly choose parameters for all Pipeline stages.</span> |
| <span class="c1">// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.</span> |
| <span class="k">val</span> <span class="n">crossval</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">CrossValidator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setEstimator</span><span class="o">(</span><span class="n">pipeline</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setEvaluator</span><span class="o">(</span><span class="k">new</span> <span class="nc">BinaryClassificationEvaluator</span><span class="o">)</span> |
| <span class="c1">// We use a ParamGridBuilder to construct a grid of parameters to search over.</span> |
| <span class="c1">// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,</span> |
| <span class="c1">// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.</span> |
| <span class="k">val</span> <span class="n">paramGrid</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">ParamGridBuilder</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">addGrid</span><span class="o">(</span><span class="n">hashingTF</span><span class="o">.</span><span class="n">numFeatures</span><span class="o">,</span> <span class="nc">Array</span><span class="o">(</span><span class="mi">10</span><span class="o">,</span> <span class="mi">100</span><span class="o">,</span> <span class="mi">1000</span><span class="o">))</span> |
| <span class="o">.</span><span class="n">addGrid</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="n">regParam</span><span class="o">,</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.1</span><span class="o">,</span> <span class="mf">0.01</span><span class="o">))</span> |
| <span class="o">.</span><span class="n">build</span><span class="o">()</span> |
| <span class="n">crossval</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="o">(</span><span class="n">paramGrid</span><span class="o">)</span> |
| <span class="n">crossval</span><span class="o">.</span><span class="n">setNumFolds</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> <span class="c1">// Use 3+ in practice</span> |
| |
| <span class="c1">// Run cross-validation, and choose the best set of parameters.</span> |
| <span class="k">val</span> <span class="n">cvModel</span> <span class="k">=</span> <span class="n">crossval</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| <span class="c1">// Get the best fitted model (a PipelineModel trained with the best set of parameters from paramGrid).</span> |
| <span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">cvModel</span><span class="o">.</span><span class="n">bestModel</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled.</span> |
| <span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span> |
| <span class="nc">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">)))</span> |
| |
| <span class="c1">// Make predictions on test documents. cvModel uses the best model found (lrModel).</span> |
| <span class="n">cvModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="-Symbol">'id</span><span class="o">,</span> <span class="-Symbol">'text</span><span class="o">,</span> <span class="-Symbol">'score</span><span class="o">,</span> <span class="-Symbol">'prediction</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">collect</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">id</span><span class="k">:</span> <span class="kt">Long</span><span class="o">,</span> <span class="n">text</span><span class="k">:</span> <span class="kt">String</span><span class="o">,</span> <span class="n">score</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">prediction</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">id</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">text</span> <span class="o">+</span> <span class="s">") --> score="</span> <span class="o">+</span> <span class="n">score</span> <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">prediction</span><span class="o">)</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| |
| <div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">com.google.common.collect.Lists</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.Model</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.BinaryClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.HashingTF</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.Tokenizer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.param.ParamMap</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.CrossValidator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.CrossValidatorModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.tuning.ParamGridBuilder</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSQLContext</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.JavaSchemaRDD</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.api.java.Row</span><span class="o">;</span> |
| |
| <span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">"JavaCrossValidatorExample"</span><span class="o">);</span> |
| <span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span> |
| <span class="n">JavaSQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span> |
| |
| <span class="c1">// Prepare training documents, which are labeled.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">LabeledDocument</span><span class="o">></span> <span class="n">localTraining</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">0L</span><span class="o">,</span> <span class="s">"a b c d e spark"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">1L</span><span class="o">,</span> <span class="s">"b d"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">2L</span><span class="o">,</span> <span class="s">"spark f g h"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">3L</span><span class="o">,</span> <span class="s">"hadoop mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"b spark who"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"g d a y"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"spark fly"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"was mapreduce"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">8L</span><span class="o">,</span> <span class="s">"e spark program"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">9L</span><span class="o">,</span> <span class="s">"a e c l"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">10L</span><span class="o">,</span> <span class="s">"spark compile"</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">LabeledDocument</span><span class="o">(</span><span class="mi">11L</span><span class="o">,</span> <span class="s">"hadoop software"</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">training</span> <span class="o">=</span> |
| <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTraining</span><span class="o">),</span> <span class="n">LabeledDocument</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.</span> |
| <span class="n">Tokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Tokenizer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"text"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"words"</span><span class="o">);</span> |
| <span class="n">HashingTF</span> <span class="n">hashingTF</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">HashingTF</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setNumFeatures</span><span class="o">(</span><span class="mi">1000</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="na">getOutputCol</span><span class="o">())</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">);</span> |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.01</span><span class="o">);</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">tokenizer</span><span class="o">,</span> <span class="n">hashingTF</span><span class="o">,</span> <span class="n">lr</span><span class="o">});</span> |
| |
| <span class="c1">// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.</span> |
| <span class="c1">// This will allow us to jointly choose parameters for all Pipeline stages.</span> |
| <span class="c1">// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.</span> |
| <span class="n">CrossValidator</span> <span class="n">crossval</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">CrossValidator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setEstimator</span><span class="o">(</span><span class="n">pipeline</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setEvaluator</span><span class="o">(</span><span class="k">new</span> <span class="nf">BinaryClassificationEvaluator</span><span class="o">());</span> |
| <span class="c1">// We use a ParamGridBuilder to construct a grid of parameters to search over.</span> |
| <span class="c1">// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,</span> |
| <span class="c1">// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.</span> |
| <span class="n">ParamMap</span><span class="o">[]</span> <span class="n">paramGrid</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">ParamGridBuilder</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">addGrid</span><span class="o">(</span><span class="n">hashingTF</span><span class="o">.</span><span class="na">numFeatures</span><span class="o">(),</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]{</span><span class="mi">10</span><span class="o">,</span> <span class="mi">100</span><span class="o">,</span> <span class="mi">1000</span><span class="o">})</span> |
| <span class="o">.</span><span class="na">addGrid</span><span class="o">(</span><span class="n">lr</span><span class="o">.</span><span class="na">regParam</span><span class="o">(),</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.1</span><span class="o">,</span> <span class="mf">0.01</span><span class="o">})</span> |
| <span class="o">.</span><span class="na">build</span><span class="o">();</span> |
| <span class="n">crossval</span><span class="o">.</span><span class="na">setEstimatorParamMaps</span><span class="o">(</span><span class="n">paramGrid</span><span class="o">);</span> |
| <span class="n">crossval</span><span class="o">.</span><span class="na">setNumFolds</span><span class="o">(</span><span class="mi">2</span><span class="o">);</span> <span class="c1">// Use 3+ in practice</span> |
| |
| <span class="c1">// Run cross-validation, and choose the best set of parameters.</span> |
| <span class="n">CrossValidatorModel</span> <span class="n">cvModel</span> <span class="o">=</span> <span class="n">crossval</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| <span class="c1">// Get the best LogisticRegression model (with the best set of parameters from paramGrid).</span> |
| <span class="n">Model</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">cvModel</span><span class="o">.</span><span class="na">bestModel</span><span class="o">();</span> |
| |
| <span class="c1">// Prepare test documents, which are unlabeled.</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Document</span><span class="o">></span> <span class="n">localTest</span> <span class="o">=</span> <span class="n">Lists</span><span class="o">.</span><span class="na">newArrayList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">4L</span><span class="o">,</span> <span class="s">"spark i j k"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">5L</span><span class="o">,</span> <span class="s">"l m n"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">6L</span><span class="o">,</span> <span class="s">"mapreduce spark"</span><span class="o">),</span> |
| <span class="k">new</span> <span class="nf">Document</span><span class="o">(</span><span class="mi">7L</span><span class="o">,</span> <span class="s">"apache hadoop"</span><span class="o">));</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">test</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">applySchema</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">localTest</span><span class="o">),</span> <span class="n">Document</span><span class="o">.</span><span class="na">class</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions on test documents. cvModel uses the best model found (lrModel).</span> |
| <span class="n">cvModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">).</span><span class="na">registerAsTable</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">);</span> |
| <span class="n">JavaSchemaRDD</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">sql</span><span class="o">(</span><span class="s">"SELECT id, text, score, prediction FROM prediction"</span><span class="o">);</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Row</span> <span class="nl">r:</span> <span class="n">predictions</span><span class="o">.</span><span class="na">collect</span><span class="o">())</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"("</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> <span class="o">+</span> <span class="s">", "</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> <span class="o">+</span> <span class="s">") --> score="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">+</span> <span class="s">", prediction="</span> <span class="o">+</span> <span class="n">r</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="mi">3</span><span class="o">));</span> |
| <span class="o">}</span></code></pre></div> |
| |
| </div> |
| |
| </div> |
| |
| <h1 id="dependencies">Dependencies</h1> |
| |
<p>Spark ML currently depends on MLlib and has the same dependencies.
Please see the <a href="mllib-guide.html#Dependencies">MLlib Dependencies guide</a> for more information.</p>
| |
<p>Spark ML also depends on Spark SQL, but the relevant parts of Spark SQL do not introduce additional dependencies.</p>
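<p>In practice, declaring the <code>spark-mllib</code> artifact in your build is usually sufficient, since it transitively pulls in Spark Core and the parts of Spark SQL that Spark ML uses. A minimal Maven sketch, assuming the Scala 2.10 artifacts published for Spark 1.2.x:</p>

```xml
<!-- Sketch only: adjust the version to match your cluster's Spark release. -->
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-mllib_2.10</artifactId>
  <version>1.2.2</version>
</dependency>
```

<p>sbt users can declare the equivalent <code>"org.apache.spark" %% "spark-mllib" % "1.2.2"</code> dependency.</p>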
| |
| |
| </div> <!-- /container --> |
| |
| <script src="js/vendor/jquery-1.8.0.min.js"></script> |
| <script src="js/vendor/bootstrap.min.js"></script> |
| <script src="js/main.js"></script> |
| |
| <!-- MathJax Section --> |
| <script type="text/x-mathjax-config"> |
| MathJax.Hub.Config({ |
| TeX: { equationNumbers: { autoNumber: "AMS" } } |
| }); |
| </script> |
| <script> |
| // Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS. |
| // We could use "//cdn.mathjax...", but that won't support "file://". |
| (function(d, script) { |
| script = d.createElement('script'); |
| script.type = 'text/javascript'; |
| script.async = true; |
| script.onload = function(){ |
| MathJax.Hub.Config({ |
| tex2jax: { |
| inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ], |
| displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], |
| processEscapes: true, |
| skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] |
| } |
| }); |
| }; |
| script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') + |
| 'cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML'; |
| d.getElementsByTagName('head')[0].appendChild(script); |
| }(document)); |
| </script> |
| </body> |
| </html> |