| |
| <!DOCTYPE html> |
| <!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]--> |
| <!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]--> |
| <!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]--> |
| <!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]--> |
| <head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> |
| <title>Classification and regression - Spark 2.4.5 Documentation</title> |
| |
| |
| |
| |
| <link rel="stylesheet" href="css/bootstrap.min.css"> |
| <style> |
| body { |
| padding-top: 60px; |
| padding-bottom: 40px; |
| } |
| </style> |
| <meta name="viewport" content="width=device-width"> |
| <link rel="stylesheet" href="css/bootstrap-responsive.min.css"> |
| <link rel="stylesheet" href="css/main.css"> |
| |
| <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> |
| |
| <link rel="stylesheet" href="css/pygments-default.css"> |
| |
| |
| <!-- Google analytics script --> |
| <script type="text/javascript"> |
| var _gaq = _gaq || []; |
| _gaq.push(['_setAccount', 'UA-32518208-2']); |
| _gaq.push(['_trackPageview']); |
| |
| (function() { |
| var ga = document.createElement('script'); ga.type = 'text/javascript'; ga.async = true; |
| ga.src = ('https:' == document.location.protocol ? 'https://ssl' : 'http://www') + '.google-analytics.com/ga.js'; |
| var s = document.getElementsByTagName('script')[0]; s.parentNode.insertBefore(ga, s); |
| })(); |
| </script> |
| |
| |
| </head> |
| <body> |
| <!--[if lt IE 7]> |
| <p class="chromeframe">You are using an outdated browser. <a href="https://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> |
| <![endif]--> |
| |
| <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> |
| |
| <div class="navbar navbar-fixed-top" id="topbar"> |
| <div class="navbar-inner"> |
| <div class="container"> |
| <div class="brand"><a href="index.html"> |
| <img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">2.4.5</span> |
| </div> |
| <ul class="nav"> |
| <!--TODO(andyk): Add class="active" attribute to li some how.--> |
| <li><a href="index.html">Overview</a></li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="quick-start.html">Quick Start</a></li> |
| <li><a href="rdd-programming-guide.html">RDDs, Accumulators, Broadcasts Vars</a></li> |
| <li><a href="sql-programming-guide.html">SQL, DataFrames, and Datasets</a></li> |
| <li><a href="structured-streaming-programming-guide.html">Structured Streaming</a></li> |
| <li><a href="streaming-programming-guide.html">Spark Streaming (DStreams)</a></li> |
| <li><a href="ml-guide.html">MLlib (Machine Learning)</a></li> |
| <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li> |
| <li><a href="sparkr.html">SparkR (R on Spark)</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li> |
| <li><a href="api/java/index.html">Java</a></li> |
| <li><a href="api/python/index.html">Python</a></li> |
| <li><a href="api/R/index.html">R</a></li> |
| <li><a href="api/sql/index.html">SQL, Built-in Functions</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="cluster-overview.html">Overview</a></li> |
| <li><a href="submitting-applications.html">Submitting Applications</a></li> |
| <li class="divider"></li> |
| <li><a href="spark-standalone.html">Spark Standalone</a></li> |
| <li><a href="running-on-mesos.html">Mesos</a></li> |
| <li><a href="running-on-yarn.html">YARN</a></li> |
| <li><a href="running-on-kubernetes.html">Kubernetes</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="configuration.html">Configuration</a></li> |
| <li><a href="monitoring.html">Monitoring</a></li> |
| <li><a href="tuning.html">Tuning Guide</a></li> |
| <li><a href="job-scheduling.html">Job Scheduling</a></li> |
| <li><a href="security.html">Security</a></li> |
| <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li> |
| <li class="divider"></li> |
| <li><a href="building-spark.html">Building Spark</a></li> |
| <li><a href="https://spark.apache.org/contributing.html">Contributing to Spark</a></li> |
| <li><a href="https://spark.apache.org/third-party-projects.html">Third Party Projects</a></li> |
| </ul> |
| </li> |
| </ul> |
| <!--<p class="navbar-text pull-right"><span class="version-text">v2.4.5</span></p>--> |
| </div> |
| </div> |
| </div> |
| |
| <div class="container-wrapper"> |
| |
| |
| |
| <div class="left-menu-wrapper"> |
| <div class="left-menu"> |
| <h3><a href="ml-guide.html">MLlib: Main Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="ml-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-datasource.html"> |
| |
| Data sources |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-pipeline.html"> |
| |
| Pipelines |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-features.html"> |
| |
| Extracting, transforming and selecting features |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-classification-regression.html"> |
| |
| <b>Classification and Regression</b> |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-frequent-pattern-mining.html"> |
| |
| Frequent Pattern Mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-tuning.html"> |
| |
| Model selection and tuning |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-advanced.html"> |
| |
| Advanced topics |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| <h3><a href="mllib-guide.html">MLlib: RDD-based API Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="mllib-data-types.html"> |
| |
| Data types |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-classification-regression.html"> |
| |
| Classification and regression |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-dimensionality-reduction.html"> |
| |
| Dimensionality reduction |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-feature-extraction.html"> |
| |
| Feature extraction and transformation |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-frequent-pattern-mining.html"> |
| |
| Frequent pattern mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-evaluation-metrics.html"> |
| |
| Evaluation metrics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-pmml-model-export.html"> |
| |
| PMML model export |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-optimization.html"> |
| |
| Optimization (developer) |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| </div> |
| </div> |
| |
| <input id="nav-trigger" class="nav-trigger" checked type="checkbox"> |
| <label for="nav-trigger"></label> |
| <div class="content-with-sidebar" id="content"> |
| |
| <h1 class="title">Classification and regression</h1> |
| |
| |
| <p><code>\[ |
| \newcommand{\R}{\mathbb{R}} |
| \newcommand{\E}{\mathbb{E}} |
| \newcommand{\x}{\mathbf{x}} |
| \newcommand{\y}{\mathbf{y}} |
| \newcommand{\wv}{\mathbf{w}} |
| \newcommand{\av}{\mathbf{\alpha}} |
| \newcommand{\bv}{\mathbf{b}} |
| \newcommand{\N}{\mathbb{N}} |
| \newcommand{\id}{\mathbf{I}} |
| \newcommand{\ind}{\mathbf{1}} |
| \newcommand{\0}{\mathbf{0}} |
| \newcommand{\unit}{\mathbf{e}} |
| \newcommand{\one}{\mathbf{1}} |
| \newcommand{\zero}{\mathbf{0}} |
| \]</code></p> |
| |
| <p>This page covers algorithms for Classification and Regression. It also includes sections |
| discussing specific classes of algorithms, such as linear methods, trees, and ensembles.</p> |
| |
| <p><strong>Table of Contents</strong></p> |
| |
| <ul id="markdown-toc"> |
| <li><a href="#classification" id="markdown-toc-classification">Classification</a> <ul> |
| <li><a href="#logistic-regression" id="markdown-toc-logistic-regression">Logistic regression</a> <ul> |
| <li><a href="#binomial-logistic-regression" id="markdown-toc-binomial-logistic-regression">Binomial logistic regression</a></li> |
| <li><a href="#multinomial-logistic-regression" id="markdown-toc-multinomial-logistic-regression">Multinomial logistic regression</a></li> |
| </ul> |
| </li> |
| <li><a href="#decision-tree-classifier" id="markdown-toc-decision-tree-classifier">Decision tree classifier</a></li> |
| <li><a href="#random-forest-classifier" id="markdown-toc-random-forest-classifier">Random forest classifier</a></li> |
| <li><a href="#gradient-boosted-tree-classifier" id="markdown-toc-gradient-boosted-tree-classifier">Gradient-boosted tree classifier</a></li> |
| <li><a href="#multilayer-perceptron-classifier" id="markdown-toc-multilayer-perceptron-classifier">Multilayer perceptron classifier</a></li> |
| <li><a href="#linear-support-vector-machine" id="markdown-toc-linear-support-vector-machine">Linear Support Vector Machine</a></li> |
| <li><a href="#one-vs-rest-classifier-aka-one-vs-all" id="markdown-toc-one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</a></li> |
| <li><a href="#naive-bayes" id="markdown-toc-naive-bayes">Naive Bayes</a></li> |
| </ul> |
| </li> |
| <li><a href="#regression" id="markdown-toc-regression">Regression</a> <ul> |
| <li><a href="#linear-regression" id="markdown-toc-linear-regression">Linear regression</a></li> |
| <li><a href="#generalized-linear-regression" id="markdown-toc-generalized-linear-regression">Generalized linear regression</a> <ul> |
| <li><a href="#available-families" id="markdown-toc-available-families">Available families</a></li> |
| </ul> |
| </li> |
| <li><a href="#decision-tree-regression" id="markdown-toc-decision-tree-regression">Decision tree regression</a></li> |
| <li><a href="#random-forest-regression" id="markdown-toc-random-forest-regression">Random forest regression</a></li> |
| <li><a href="#gradient-boosted-tree-regression" id="markdown-toc-gradient-boosted-tree-regression">Gradient-boosted tree regression</a></li> |
| <li><a href="#survival-regression" id="markdown-toc-survival-regression">Survival regression</a></li> |
| <li><a href="#isotonic-regression" id="markdown-toc-isotonic-regression">Isotonic regression</a></li> |
| </ul> |
| </li> |
| <li><a href="#linear-methods" id="markdown-toc-linear-methods">Linear methods</a></li> |
| <li><a href="#decision-trees" id="markdown-toc-decision-trees">Decision trees</a> <ul> |
| <li><a href="#inputs-and-outputs" id="markdown-toc-inputs-and-outputs">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns" id="markdown-toc-input-columns">Input Columns</a></li> |
| <li><a href="#output-columns" id="markdown-toc-output-columns">Output Columns</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li><a href="#tree-ensembles" id="markdown-toc-tree-ensembles">Tree Ensembles</a> <ul> |
| <li><a href="#random-forests" id="markdown-toc-random-forests">Random Forests</a> <ul> |
| <li><a href="#inputs-and-outputs-1" id="markdown-toc-inputs-and-outputs-1">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns-1" id="markdown-toc-input-columns-1">Input Columns</a></li> |
| <li><a href="#output-columns-predictions" id="markdown-toc-output-columns-predictions">Output Columns (Predictions)</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li><a href="#gradient-boosted-trees-gbts" id="markdown-toc-gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</a> <ul> |
| <li><a href="#inputs-and-outputs-2" id="markdown-toc-inputs-and-outputs-2">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns-2" id="markdown-toc-input-columns-2">Input Columns</a></li> |
| <li><a href="#output-columns-predictions-1" id="markdown-toc-output-columns-predictions-1">Output Columns (Predictions)</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| |
| <h1 id="classification">Classification</h1> |
| |
| <h2 id="logistic-regression">Logistic regression</h2> |
| |
| <p>Logistic regression is a popular method to predict a categorical response. It is a special case of <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">generalized linear models</a> that predicts the probability of the outcomes. |
| In <code>spark.ml</code>, logistic regression can be used to predict a binary outcome by using binomial logistic regression, or a multiclass outcome by using multinomial logistic regression. Use the <code>family</code> |
| parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant from the number of label classes.</p> |
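<p>As a plain-Python sketch (not Spark's implementation), the binomial model maps a linear score $z = \wv^T \x + b$ to a probability with the sigmoid function, and the hypothetical helper <code>infer_family</code> encodes the assumed selection rule for an unset <code>family</code>: binomial for one or two label classes, multinomial for more than two.</p>

```python
import math


def sigmoid(z):
    # Binomial logistic regression: P(y = 1 | x) = 1 / (1 + exp(-z)), where z = w . x + b.
    # Branch on the sign of z so that exp() never overflows for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def infer_family(num_classes):
    # Hypothetical helper mirroring the assumed behavior of an unset family:
    # binomial for one or two label classes, multinomial when there are more.
    return "multinomial" if num_classes > 2 else "binomial"


print(sigmoid(0.0))                     # 0.5
print(infer_family(2), infer_family(3))  # binomial multinomial
```

<p>In practice the label column determines the number of classes, so most users can leave <code>family</code> unset and only set it explicitly to force the multinomial parameterization on binary data.</p>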
| |
| <blockquote> |
| <p>Multinomial logistic regression can be used for binary classification by setting the <code>family</code> param to “multinomial”. It will produce two sets of coefficients and two intercepts.</p> |
| </blockquote> |
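<p>The two sets of coefficients are not a different model: for two classes, the softmax probability of class 1 equals the sigmoid of the difference between the class-1 and class-0 scores. The sketch below checks this numerically on a hypothetical two-class model (the coefficient values are made up for illustration and are not output from Spark).</p>

```python
import math


def softmax(scores):
    # Subtract the largest score before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))


# Hypothetical two-class multinomial model: one coefficient row and one intercept per class.
coefficient_matrix = [[0.5, -1.0], [-0.25, 0.75]]
intercept_vector = [0.1, -0.2]
x = [2.0, 1.0]

# P(y = 1 | x) under the multinomial (softmax) parameterization.
scores = [dot(w, x) + b for w, b in zip(coefficient_matrix, intercept_vector)]
p_multinomial = softmax(scores)[1]

# The same probability from a single binomial model whose parameters are the
# class-1 row minus the class-0 row.
w_diff = [c1 - c0 for c0, c1 in zip(coefficient_matrix[0], coefficient_matrix[1])]
b_diff = intercept_vector[1] - intercept_vector[0]
p_binomial = 1.0 / (1.0 + math.exp(-(dot(w_diff, x) + b_diff)))

print(p_multinomial, p_binomial)  # equal up to floating-point rounding
```

<p>Because only the difference between the two rows affects predictions, the multinomial coefficients for a binary problem are not directly comparable to the single binomial coefficient vector.</p>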
| |
| <blockquote> |
| <p>When fitting a LogisticRegressionModel without an intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.</p> |
| </blockquote> |
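<p>Before fitting with <code>fitIntercept</code> set to false, it can help to know which columns this affects. The hypothetical helper below is a plain-Python sketch that scans an in-memory list of feature rows for columns holding the same nonzero value in every row; with a DataFrame, the same check could be made from per-column minimum and maximum statistics.</p>

```python
def constant_nonzero_columns(rows):
    """Return indices of columns holding the same nonzero value in every row."""
    if not rows:
        return []
    result = []
    for j in range(len(rows[0])):
        first = rows[0][j]
        if first != 0 and all(row[j] == first for row in rows):
            result.append(j)
    return result


# Hypothetical feature rows: column 1 is constant and nonzero, column 2 is constant zero.
features = [
    [1.0, 3.0, 0.0],
    [2.0, 3.0, 0.0],
    [0.5, 3.0, 0.0],
]
print(constant_nonzero_columns(features))  # [1]
```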
| |
| <h3 id="binomial-logistic-regression">Binomial logistic regression</h3> |
| |
| <p>For more background and more details about the implementation of binomial logistic regression, refer to the documentation of <a href="mllib-linear-methods.html#logistic-regression">logistic regression in <code>spark.mllib</code></a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example shows how to train binomial and multinomial logistic regression |
| models for binary classification with elastic net regularization. <code>elasticNetParam</code> corresponds to |
| $\alpha$ and <code>regParam</code> corresponds to $\lambda$.</p> |
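<p>In the formulation used by the <code>spark.mllib</code> linear methods guide, the elastic net penalty added to the loss is $\lambda \left( \alpha \|\wv\|_1 + \frac{1-\alpha}{2} \|\wv\|_2^2 \right)$, so $\alpha = 1$ gives a pure L1 (lasso) penalty and $\alpha = 0$ a pure L2 (ridge) penalty. The sketch below evaluates that expression for a hypothetical weight vector; it ignores the intercept and the feature standardization that <code>spark.ml</code> applies by default, so treat it only as an illustration of how the two parameters combine.</p>

```python
def elastic_net_penalty(weights, reg_param, elastic_net_param):
    # reg_param is lambda; elastic_net_param is alpha in [0, 1].
    l1 = sum(abs(w) for w in weights)
    l2_squared = sum(w * w for w in weights)
    return reg_param * (elastic_net_param * l1 + (1.0 - elastic_net_param) / 2.0 * l2_squared)


# Same settings as the example: regParam = 0.3, elasticNetParam = 0.8.
print(elastic_net_penalty([1.0, -2.0], reg_param=0.3, elastic_net_param=0.8))  # approximately 0.87
```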
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegression">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for logistic regression</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: </span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="si">}</span><span class="s"> Intercept: </span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// We can also use the multinomial family for binary classification</span> |
| <span class="k">val</span> <span class="n">mlr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFamily</span><span class="o">(</span><span class="s">"multinomial"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">mlrModel</span> <span class="k">=</span> <span class="n">mlr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercepts for logistic regression with multinomial family</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Multinomial coefficients: </span><span class="si">${</span><span class="n">mlrModel</span><span class="o">.</span><span class="n">coefficientMatrix</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Multinomial intercepts: </span><span class="si">${</span><span class="n">mlrModel</span><span class="o">.</span><span class="n">interceptVector</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/LogisticRegression.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for logistic regression</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// We can also use the multinomial family for binary classification</span> |
| <span class="n">LogisticRegression</span> <span class="n">mlr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFamily</span><span class="o">(</span><span class="s">"multinomial"</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">mlrModel</span> <span class="o">=</span> <span class="n">mlr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercepts for logistic regression with multinomial family</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Multinomial coefficients: "</span> <span class="o">+</span> <span class="n">mlrModel</span><span class="o">.</span><span class="na">coefficientMatrix</span><span class="o">()</span> |
| <span class="o">+</span> <span class="s">"\nMultinomial intercepts: "</span> <span class="o">+</span> <span class="n">mlrModel</span><span class="o">.</span><span class="na">interceptVector</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegression">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for logistic regression</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># We can also use the multinomial family for binary classification</span> |
| <span class="n">mlr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">family</span><span class="o">=</span><span class="s2">"multinomial"</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">mlrModel</span> <span class="o">=</span> <span class="n">mlr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercepts for logistic regression with multinomial family</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Multinomial coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">mlrModel</span><span class="o">.</span><span class="n">coefficientMatrix</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Multinomial intercepts: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">mlrModel</span><span class="o">.</span><span class="n">interceptVector</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_with_elastic_net.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>More details on parameters can be found in the <a href="api/R/spark.logit.html">R API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit a binomial logistic regression model with spark.logit</span> |
| model <span class="o"><-</span> spark.logit<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> maxIter <span class="o">=</span> <span class="m">10</span><span class="p">,</span> regParam <span class="o">=</span> <span class="m">0.3</span><span class="p">,</span> elasticNetParam <span class="o">=</span> <span class="m">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/logit.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <p>The <code>spark.ml</code> implementation of logistic regression also supports |
| extracting a summary of the model over the training set. Note that the |
| predictions and metrics, which are stored as <code>DataFrame</code>s in |
| <code>LogisticRegressionSummary</code>, are annotated <code>@transient</code> and hence |
| only available on the driver.</p> |
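<p>The binary summary's <code>fMeasureByThreshold</code> DataFrame, which the Scala example below uses to pick a threshold, pairs candidate thresholds with an F-measure. The following Spark-independent sketch illustrates that computation under stated assumptions: β = 1, every distinct score is a candidate threshold, and a row is predicted positive when its class-1 probability is at least the threshold. The helper names and the scored data are hypothetical.</p>

```python
def f_measure_by_threshold(scores_and_labels, beta=1.0):
    """Return (threshold, F-measure) pairs, predicting class 1 when score >= threshold."""
    total_positives = sum(1 for _, label in scores_and_labels if label == 1)
    b2 = beta * beta
    results = []
    # Every distinct score is a candidate threshold, from highest to lowest.
    for t in sorted({score for score, _ in scores_and_labels}, reverse=True):
        tp = sum(1 for s, label in scores_and_labels if s >= t and label == 1)
        fp = sum(1 for s, label in scores_and_labels if s >= t and label == 0)
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / total_positives if total_positives > 0 else 0.0
        denom = b2 * precision + recall
        f = (1.0 + b2) * precision * recall / denom if denom > 0 else 0.0
        results.append((t, f))
    return results


def best_threshold(scores_and_labels):
    """Threshold with the highest F-measure (ties go to the higher threshold)."""
    return max(f_measure_by_threshold(scores_and_labels), key=lambda pair: pair[1])[0]


# Hypothetical (probability of class 1, true label) pairs.
data = [(0.9, 1), (0.8, 1), (0.6, 0), (0.4, 1), (0.2, 0)]
print(best_threshold(data))  # 0.4
```

<p>Lowering the threshold from 0.8 to 0.4 here trades one false positive for one extra true positive, which raises F1 from 0.8 to about 0.857; the Scala example performs the same kind of search over the summary's DataFrame.</p>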
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p><a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary"><code>LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel"><code>LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. the ROC curve. The binary summary can be accessed via the |
| <code>binarySummary</code> method. See <a href="api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"><code>BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
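| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.sql.functions.max</span> |
| <span class="k">import</span> <span class="nn">spark.implicits._</span> |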
| |
| <span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> |
| <span class="c1">// example</span> |
| <span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">binarySummary</span> |
| |
| <span class="c1">// Obtain the objective per iteration.</span> |
| <span class="k">val</span> <span class="n">objectiveHistory</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"objectiveHistory:"</span><span class="o">)</span> |
| <span class="n">objectiveHistory</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">loss</span> <span class="k">=></span> <span class="n">println</span><span class="o">(</span><span class="n">loss</span><span class="o">))</span> |
| |
| <span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> |
| <span class="k">val</span> <span class="n">roc</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">roc</span> |
| <span class="n">roc</span><span class="o">.</span><span class="n">show</span><span class="o">()</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"areaUnderROC: </span><span class="si">${</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">areaUnderROC</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Set the model threshold to maximize F-Measure</span> |
| <span class="k">val</span> <span class="n">fMeasure</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">fMeasureByThreshold</span> |
| <span class="k">val</span> <span class="n">maxFMeasure</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="n">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">bestThreshold</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">where</span><span class="o">(</span><span class="n">$</span><span class="s">"F-Measure"</span> <span class="o">===</span> <span class="n">maxFMeasure</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="n">lrModel</span><span class="o">.</span><span class="n">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| <p><a href="api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html"><code>LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html"><code>LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. the ROC curve. The binary summary can be accessed via the |
| <code>binarySummary</code> method. See <a href="api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html"><code>BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.functions</span><span class="o">;</span> |
| |
| <span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> |
| <span class="c1">// example</span> |
| <span class="n">BinaryLogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">binarySummary</span><span class="o">();</span> |
| |
| <span class="c1">// Obtain the loss per iteration.</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">roc</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">roc</span><span class="o">();</span> |
| <span class="n">roc</span><span class="o">.</span><span class="na">show</span><span class="o">();</span> |
| <span class="n">roc</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"FPR"</span><span class="o">).</span><span class="na">show</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">areaUnderROC</span><span class="o">());</span> |
| |
| <span class="c1">// Get the threshold corresponding to the maximum F-Measure and set it as the</span> |
| <span class="c1">// model's decision threshold.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="n">functions</span><span class="o">.</span><span class="na">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">where</span><span class="o">(</span><span class="n">fMeasure</span><span class="o">.</span><span class="na">col</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">).</span><span class="na">equalTo</span><span class="o">(</span><span class="n">maxFMeasure</span><span class="o">))</span> |
| <span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> |
| <span class="n">lrModel</span><span class="o">.</span><span class="na">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| <p><a href="api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionTrainingSummary"><code>LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel"><code>LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. the ROC curve; for a binary model, the <code>summary</code> property returns a |
| <a href="api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary"><code>BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Extract the summary from the returned LogisticRegressionModel instance trained</span> |
| <span class="c1"># in the earlier example</span> |
| <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> |
| |
| <span class="c1"># Obtain the objective per iteration</span> |
| <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"objectiveHistory:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">objective</span> <span class="ow">in</span> <span class="n">objectiveHistory</span><span class="p">:</span> |
|     <span class="k">print</span><span class="p">(</span><span class="n">objective</span><span class="p">)</span> |
| |
| <span class="c1"># Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">roc</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"areaUnderROC: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">areaUnderROC</span><span class="p">))</span> |
| |
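| <span class="c1"># Find the threshold that maximizes F-Measure and set it on the estimator lr;</span> |
| <span class="c1"># models subsequently fit with lr will use it</span> |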
| <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">fMeasureByThreshold</span> |
| <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">groupBy</span><span class="p">()</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="s1">'F-Measure'</span><span class="p">)</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s1">'max(F-Measure)'</span><span class="p">)</span><span class="o">.</span><span class="n">head</span><span class="p">()</span> |
| <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">fMeasure</span><span class="p">[</span><span class="s1">'F-Measure'</span><span class="p">]</span> <span class="o">==</span> <span class="n">maxFMeasure</span><span class="p">[</span><span class="s1">'max(F-Measure)'</span><span class="p">])</span> \ |
| <span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s1">'threshold'</span><span class="p">)</span><span class="o">.</span><span class="n">head</span><span class="p">()[</span><span class="s1">'threshold'</span><span class="p">]</span> |
| <span class="n">lr</span><span class="o">.</span><span class="n">setThreshold</span><span class="p">(</span><span class="n">bestThreshold</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_summary_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
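<p>The binary summary metrics used above can be reproduced at toy scale to show what they measure. The following pure-Python sketch does not use Spark. It builds ROC points (FPR, TPR), the area under the ROC curve by the trapezoidal rule, and F-Measure by threshold from a list of scores and 0/1 labels. It then picks the threshold with the highest F-Measure, mirroring the selection done on <code>fMeasureByThreshold</code> above. The function names and toy data are hypothetical. The sketch also assumes that candidate thresholds are the distinct scores, and that an instance is predicted positive when its score is at least the threshold. Spark's <code>BinaryClassificationMetrics</code> may order, bin, or pad thresholds differently.</p>

```python
def roc_and_f_measure(scores, labels):
    """Return ROC points (FPR, TPR) and (threshold, F-Measure) pairs.

    Hypothetical helper: candidate thresholds are the distinct scores in
    descending order, and an instance counts as positive when score >= threshold.
    Assumes labels are 0/1 and both classes are present.
    """
    num_pos = sum(labels)
    num_neg = len(labels) - num_pos
    roc = [(0.0, 0.0)]  # the curve starts at the origin
    f_by_threshold = []
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        roc.append((fp / num_neg, tp / num_pos))
        # tp + fp is at least 1 because t is one of the scores
        precision = tp / (tp + fp)
        recall = tp / num_pos
        f = 0.0 if tp == 0 else 2 * precision * recall / (precision + recall)
        f_by_threshold.append((t, f))
    roc.append((1.0, 1.0))  # the curve ends at (1, 1)
    return roc, f_by_threshold


def area_under_curve(points):
    """Trapezoidal area under (x, y) points that are sorted by x."""
    return sum((x2 - x1) * (y1 + y2) / 2
               for (x1, y1), (x2, y2) in zip(points, points[1:]))


def best_threshold(f_by_threshold):
    """First threshold that reaches the maximum F-Measure."""
    return max(f_by_threshold, key=lambda pair: pair[1])[0]


scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3]
labels = [1, 1, 0, 1, 0, 0]
roc, f_by_threshold = roc_and_f_measure(scores, labels)
print("areaUnderROC:", round(area_under_curve(roc), 4))  # 0.8889
print("best threshold:", best_threshold(f_by_threshold))  # 0.6
```

<p>Because thresholds are visited in descending order, the false positive rate never decreases. The ROC points are therefore already sorted by FPR, which the trapezoidal rule requires.</p>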
| |
| <h3 id="multinomial-logistic-regression">Multinomial logistic regression</h3> |
| |
| <p>Multiclass classification is supported via multinomial logistic (softmax) regression. In multinomial logistic regression, |
| the algorithm produces $K$ sets of coefficients, or a matrix of dimension $K \times J$ where $K$ is the number of outcome |
| classes and $J$ is the number of features. If the algorithm is fit with an intercept term, a length-$K$ vector of |
| intercepts is also available.</p> |
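<p>To make the shapes concrete: row $k$ of the $K \times J$ coefficient matrix, together with the $k$-th intercept, gives a linear score (margin) for class $k$. The sketch below computes these margins. Plain Python lists stand in for Spark's <code>coefficientMatrix</code> and <code>interceptVector</code>, and the helper name and numbers are hypothetical.</p>

```python
def class_margins(coefficient_matrix, intercept_vector, x):
    """Margins beta_k . x + beta_0k for every class k.

    coefficient_matrix: K rows (classes) of J coefficients (features).
    intercept_vector: K intercepts, one per class.
    x: J feature values.
    """
    return [sum(b * xj for b, xj in zip(row, x)) + b0
            for row, b0 in zip(coefficient_matrix, intercept_vector)]


# Hypothetical model with K = 3 classes and J = 2 features.
coefficient_matrix = [[1.0, -1.0],
                      [0.5, 0.0],
                      [-1.0, 2.0]]
intercept_vector = [0.0, 0.1, -0.2]
print(class_margins(coefficient_matrix, intercept_vector, [2.0, 1.0]))  # [1.0, 1.1, -0.2]
```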
| |
| <blockquote> |
| <p>Multinomial coefficients are available as <code>coefficientMatrix</code> and intercepts are available as <code>interceptVector</code>.</p> |
| </blockquote> |
| |
| <blockquote> |
| <p><code>coefficients</code> and <code>intercept</code> methods on a logistic regression model trained with the multinomial family are not supported. Use <code>coefficientMatrix</code> and <code>interceptVector</code> instead.</p> |
| </blockquote> |
| |
| <p>The conditional probabilities of the outcome classes $k \in \{0, 1, \ldots, K-1\}$ are modeled using the softmax function.</p> |
| |
| <p><code>\[ |
| P(Y=k|\mathbf{X}, \boldsymbol{\beta}_k, \beta_{0k}) = \frac{e^{\boldsymbol{\beta}_k \cdot \mathbf{X} + \beta_{0k}}}{\sum_{k'=0}^{K-1} e^{\boldsymbol{\beta}_{k'} \cdot \mathbf{X} + \beta_{0k'}}} |
| \]</code></p> |
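<p>The softmax formula above can be evaluated directly from a list of per-class margins $\boldsymbol{\beta}_k \cdot \mathbf{X} + \beta_{0k}$. Before exponentiating, the sketch below subtracts the largest margin. This common trick avoids overflow without changing the result, because the factor $e^{-\max}$ cancels between numerator and denominator. It is not a claim about how Spark computes this internally. The margins used are hypothetical.</p>

```python
import math


def softmax(margins):
    """P(Y = k) = exp(m_k) / sum_k' exp(m_k') for margins m_0, ..., m_{K-1}."""
    m_max = max(margins)  # shifting all margins by a constant leaves the ratios unchanged
    exps = [math.exp(m - m_max) for m in margins]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 1.1, -0.2])  # hypothetical margins for K = 3 classes
print(probs)
print("predicted class:", probs.index(max(probs)))  # 1
```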
| |
| <p>We minimize the weighted negative log-likelihood of a multinomial response model, with an elastic-net penalty to control overfitting.</p> |
| |
| <p><code>\[ |
| \min_{\beta, \beta_0} -\left[\sum_{i=1}^L w_i \cdot \log P(Y = y_i|\mathbf{x}_i)\right] + \lambda \left[\frac{1}{2}\left(1 - \alpha\right)||\boldsymbol{\beta}||_2^2 + \alpha ||\boldsymbol{\beta}||_1\right] |
| \]</code></p> |
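<p>The function below is a sketch of how the terms of this objective fit together. It evaluates two terms for given coefficients:</p>
<ul>
<li>the weighted negative log-likelihood, with $\log P(Y = y_i|\mathbf{x}_i)$ computed from the softmax via a log-sum-exp;</li>
<li>the elastic-net penalty on the coefficient matrix.</li>
</ul>
<p>Intercepts are left unpenalized, matching the formula, which penalizes only $\boldsymbol{\beta}$. Function and argument names are hypothetical, and labels are assumed to be 0-based class indices. Spark's optimizer may additionally scale the loss (for example by the total instance weight) and standardize features, so its <code>objectiveHistory</code> values need not match this function's output.</p>

```python
import math


def multinomial_objective(coefficient_matrix, intercept_vector, xs, ys, weights,
                          reg_param, elastic_net_param):
    """Weighted multinomial negative log-likelihood plus elastic-net penalty.

    coefficient_matrix: K x J nested lists; intercept_vector: K values.
    ys: 0-based class indices; weights: one w_i per instance.
    reg_param plays the role of lambda and elastic_net_param of alpha.
    """
    nll = 0.0
    for x, y, w in zip(xs, ys, weights):
        margins = [sum(b * xj for b, xj in zip(row, x)) + b0
                   for row, b0 in zip(coefficient_matrix, intercept_vector)]
        m_max = max(margins)
        log_norm = m_max + math.log(sum(math.exp(m - m_max) for m in margins))
        nll -= w * (margins[y] - log_norm)  # log P(Y = y | x) = margin_y - log_norm
    flat = [b for row in coefficient_matrix for b in row]  # intercepts are not penalized
    l2_squared = sum(b * b for b in flat)
    l1 = sum(abs(b) for b in flat)
    penalty = reg_param * (0.5 * (1 - elastic_net_param) * l2_squared
                           + elastic_net_param * l1)
    return nll + penalty


# Hypothetical data: K = 3 classes, J = 2 features, unit weights.
coefficient_matrix = [[1.0, -1.0], [0.5, 0.0], [-1.0, 2.0]]
intercept_vector = [0.0, 0.1, -0.2]
xs = [[2.0, 1.0], [0.0, 1.0]]
ys = [0, 2]
weights = [1.0, 1.0]
print(multinomial_objective(coefficient_matrix, intercept_vector, xs, ys, weights,
                            reg_param=0.3, elastic_net_param=0.8))
```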
| |
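| <p>Here $L$ is the number of training instances, $w_i$ is the weight of instance $i$, $\lambda$ is the regularization parameter (<code>regParam</code>), and $\alpha \in [0, 1]$ is the elastic-net mixing parameter (<code>elasticNetParam</code>). For a detailed derivation, see <a href="https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model">multinomial logistic regression as a log-linear model</a>.</p> |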
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example shows how to train a multiclass logistic regression |
| model with elastic net regularization, as well as extract the multiclass |
| training summary for evaluating the model.</p> |
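<p>The multiclass summary's per-label and weighted metrics can be sketched in plain Python from lists of true labels and predictions. This hypothetical helper computes four one-vs-rest metrics for each label that occurs among the true labels: precision, recall (the true positive rate), false positive rate, and F1. The weighted versions average these with weights equal to each label's frequency among the true labels. This is our understanding of how metrics such as <code>weightedPrecision</code> and <code>weightedRecall</code> are defined. Under this weighting, the weighted recall equals the accuracy. Names and data are illustrative only.</p>

```python
from collections import Counter


def multiclass_metrics(labels, predictions):
    """One-vs-rest metrics per true label, plus label-frequency-weighted averages."""
    n = len(labels)
    label_counts = Counter(labels)
    by_label = {}
    for c in sorted(label_counts):
        tp = sum(1 for y, p in zip(labels, predictions) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, predictions) if y != c and p == c)
        negatives = n - label_counts[c]
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / label_counts[c]  # also the true positive rate
        fpr = fp / negatives if negatives else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        by_label[c] = {"precision": precision, "recall": recall, "fpr": fpr, "f1": f1}
    weighted = {name: sum(by_label[c][name] * label_counts[c] / n for c in by_label)
                for name in ("precision", "recall", "fpr", "f1")}
    accuracy = sum(1 for y, p in zip(labels, predictions) if y == p) / n
    return by_label, weighted, accuracy


labels = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]
by_label, weighted, accuracy = multiclass_metrics(labels, predictions)
for c, m in by_label.items():
    print(f"label {c}: {m}")
print("accuracy:", accuracy, "weighted recall:", weighted["recall"])
```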
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="n">read</span> |
| <span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for multinomial logistic regression</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: \n</span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficientMatrix</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Intercepts: \n</span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">interceptVector</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> |
| |
| <span class="c1">// Obtain the objective per iteration</span> |
| <span class="k">val</span> <span class="n">objectiveHistory</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"objectiveHistory:"</span><span class="o">)</span> |
| <span class="n">objectiveHistory</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">println</span><span class="o">)</span> |
| |
| <span class="c1">// for multiclass, we can inspect metrics on a per-label basis</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"False positive rate by label:"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">falsePositiveRateByLabel</span><span class="o">.</span><span class="n">zipWithIndex</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">rate</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"label </span><span class="si">$label</span><span class="s">: </span><span class="si">$rate</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">"True positive rate by label:"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">truePositiveRateByLabel</span><span class="o">.</span><span class="n">zipWithIndex</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">rate</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"label </span><span class="si">$label</span><span class="s">: </span><span class="si">$rate</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">"Precision by label:"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">precisionByLabel</span><span class="o">.</span><span class="n">zipWithIndex</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">prec</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"label </span><span class="si">$label</span><span class="s">: </span><span class="si">$prec</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">"Recall by label:"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">recallByLabel</span><span class="o">.</span><span class="n">zipWithIndex</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">rec</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"label </span><span class="si">$label</span><span class="s">: </span><span class="si">$rec</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">"F-measure by label:"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">fMeasureByLabel</span><span class="o">.</span><span class="n">zipWithIndex</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">f</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"label </span><span class="si">$label</span><span class="s">: </span><span class="si">$f</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">accuracy</span> |
| <span class="k">val</span> <span class="n">falsePositiveRate</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedFalsePositiveRate</span> |
| <span class="k">val</span> <span class="n">truePositiveRate</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedTruePositiveRate</span> |
| <span class="k">val</span> <span class="n">fMeasure</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedFMeasure</span> |
| <span class="k">val</span> <span class="n">precision</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedPrecision</span> |
| <span class="k">val</span> <span class="n">recall</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedRecall</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Accuracy: </span><span class="si">$accuracy</span><span class="s">\nFPR: </span><span class="si">$falsePositiveRate</span><span class="s">\nTPR: </span><span class="si">$truePositiveRate</span><span class="s">\n"</span> <span class="o">+</span> |
| <span class="s">s"F-measure: </span><span class="si">$fMeasure</span><span class="s">\nPrecision: </span><span class="si">$precision</span><span class="s">\nRecall: </span><span class="si">$recall</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">);</span> |
| |
| <span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for multinomial logistic regression</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: \n"</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficientMatrix</span><span class="o">()</span> <span class="o">+</span> <span class="s">" \nIntercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">interceptVector</span><span class="o">());</span> |
| <span class="n">LogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| |
| <span class="c1">// Obtain the loss per iteration.</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// for multiclass, we can inspect metrics on a per-label basis</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"False positive rate by label:"</span><span class="o">);</span> |
| <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">fprLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">falsePositiveRateByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">fpr</span> <span class="o">:</span> <span class="n">fprLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">fpr</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"True positive rate by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">tprLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">truePositiveRateByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">tpr</span> <span class="o">:</span> <span class="n">tprLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">tpr</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">precLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">precisionByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">prec</span> <span class="o">:</span> <span class="n">precLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">prec</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Recall by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">recLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">recallByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">rec</span> <span class="o">:</span> <span class="n">recLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">rec</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F-measure by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">fLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">fMeasureByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">f</span> <span class="o">:</span> <span class="n">fLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">f</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">accuracy</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">falsePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedFalsePositiveRate</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">truePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedTruePositiveRate</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedFMeasure</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">precision</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedPrecision</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">recall</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedRecall</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Accuracy: "</span> <span class="o">+</span> <span class="n">accuracy</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"FPR: "</span> <span class="o">+</span> <span class="n">falsePositiveRate</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"TPR: "</span> <span class="o">+</span> <span class="n">truePositiveRate</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F-measure: "</span> <span class="o">+</span> <span class="n">fMeasure</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision: "</span> <span class="o">+</span> <span class="n">precision</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Recall: "</span> <span class="o">+</span> <span class="n">recall</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span> \ |
| <span class="o">.</span><span class="n">read</span> \ |
| <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span> \ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for multinomial logistic regression</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: </span><span class="se">\n</span><span class="s2">"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficientMatrix</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">interceptVector</span><span class="p">))</span> |
| |
| <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> |
| |
| <span class="c1"># Obtain the objective per iteration</span> |
| <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"objectiveHistory:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">objective</span> <span class="ow">in</span> <span class="n">objectiveHistory</span><span class="p">:</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">objective</span><span class="p">)</span> |
| |
| <span class="c1"># for multiclass, we can inspect metrics on a per-label basis</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"False positive rate by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rate</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">falsePositiveRateByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"label </span><span class="si">%d</span><span class="s2">: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rate</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s2">"True positive rate by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rate</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">truePositiveRateByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"label </span><span class="si">%d</span><span class="s2">: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rate</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Precision by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">prec</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">precisionByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"label </span><span class="si">%d</span><span class="s2">: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">prec</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Recall by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rec</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">recallByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"label </span><span class="si">%d</span><span class="s2">: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rec</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s2">"F-measure by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">f</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">fMeasureByLabel</span><span class="p">()):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"label </span><span class="si">%d</span><span class="s2">: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span> |
| |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">accuracy</span> |
| <span class="n">falsePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedFalsePositiveRate</span> |
| <span class="n">truePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedTruePositiveRate</span> |
| <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedFMeasure</span><span class="p">()</span> |
| <span class="n">precision</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedPrecision</span> |
| <span class="n">recall</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">weightedRecall</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Accuracy: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">FPR: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">TPR: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">F-measure: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">Precision: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">Recall: </span><span class="si">%s</span><span class="s2">"</span> |
| <span class="o">%</span> <span class="p">(</span><span class="n">accuracy</span><span class="p">,</span> <span class="n">falsePositiveRate</span><span class="p">,</span> <span class="n">truePositiveRate</span><span class="p">,</span> <span class="n">fMeasure</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py" in the Spark repo.</small></div> |
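<p>The <code>regParam</code> and <code>elasticNetParam</code> values above blend L1 and L2 regularization. As a rough illustration, the plain-Python sketch below evaluates the elastic net penalty for a small, made-up coefficient matrix. It assumes the usual elastic net form α(λ‖w‖₁) + (1 − α)(λ/2)‖w‖₂², with λ = <code>regParam</code> and α = <code>elasticNetParam</code>. The function name <code>elastic_net_penalty</code> is hypothetical and not part of PySpark. Intercepts are left out because they are not penalized. The sketch also ignores the feature standardization that <code>LogisticRegression</code> applies by default.</p>

```python
# A minimal sketch of the elastic net penalty; this is NOT a PySpark API.
# reg_param plays the role of regParam (lambda) and
# elastic_net_param plays the role of elasticNetParam (alpha).


def elastic_net_penalty(coefficients, reg_param, elastic_net_param):
    """Return alpha * lam * ||w||_1 + (1 - alpha) * lam / 2 * ||w||_2^2.

    `coefficients` is a list of rows (one row per class, as in the
    coefficient matrix of multinomial logistic regression). Intercepts
    are not included because they are not penalized.
    """
    flat = [w for row in coefficients for w in row]
    l1 = sum(abs(w) for w in flat)           # ||w||_1
    l2_squared = sum(w * w for w in flat)    # ||w||_2^2
    alpha, lam = elastic_net_param, reg_param
    return alpha * lam * l1 + (1.0 - alpha) * lam / 2.0 * l2_squared


# Hypothetical 3-class, 2-feature coefficient matrix.
coefficients = [[0.5, -0.25], [0.0, 1.0], [-0.5, 0.0]]
print(elastic_net_penalty(coefficients, reg_param=0.3, elastic_net_param=0.8))
```

<p>Setting <code>elasticNetParam</code> to 1.0 leaves only the L1 (lasso) term. Setting it to 0.0 leaves only the L2 (ridge) term.</p>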
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>More details on parameters can be found in the <a href="api/R/spark.logit.html">R API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
<span class="c1"># For simplicity, this example trains and predicts on the same data</span>
training <span class="o"><-</span> df
test <span class="o"><-</span> df
| |
| <span class="c1"># Fit a multinomial logistic regression model with spark.logit</span> |
| model <span class="o"><-</span> spark.logit<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> maxIter <span class="o">=</span> <span class="m">10</span><span class="p">,</span> regParam <span class="o">=</span> <span class="m">0.3</span><span class="p">,</span> elasticNetParam <span class="o">=</span> <span class="m">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/logit.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
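<p>The per-label and weighted metrics printed by the Java and Python examples treat each label in a one-vs-rest fashion. The plain-Python sketch below recomputes them from hypothetical (label, prediction) pairs to make the definitions concrete. It assumes unweighted rows. It also assumes, as we understand Spark's <code>MulticlassMetrics</code>, that each weighted metric averages the per-label values using each label's share of the true labels as the weight. Check the API documentation before relying on exact values. The helper names are illustrative only and are not Spark APIs.</p>

```python
# Sketch of one-vs-rest per-label metrics and their label-frequency-weighted
# averages. Plain Python; the helper names are hypothetical, not Spark APIs.
from collections import Counter


def per_label_metrics(pairs):
    """pairs: list of (true_label, predicted_label) tuples."""
    labels = sorted({y for y, _ in pairs} | {p for _, p in pairs})
    n = len(pairs)
    true_counts = Counter(y for y, _ in pairs)
    metrics = {}
    for label in labels:
        tp = sum(1 for y, p in pairs if y == label and p == label)
        fp = sum(1 for y, p in pairs if y != label and p == label)
        fn = sum(1 for y, p in pairs if y == label and p != label)
        negatives = n - true_counts[label]  # rows whose true label is not `label`
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0  # recall == true positive rate
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        fpr = fp / negatives if negatives else 0.0  # FP / (FP + TN)
        metrics[label] = {"precision": precision, "recall": recall,
                          "f1": f1, "fpr": fpr}
    return metrics, true_counts, n


def weighted(metrics, true_counts, n, name):
    # Weight each label's metric by the fraction of rows with that true label.
    return sum(m[name] * true_counts[label] / n for label, m in metrics.items())


pairs = [(0, 0), (0, 0), (0, 1), (1, 1), (1, 1), (2, 2), (2, 0), (2, 2)]
metrics, true_counts, n = per_label_metrics(pairs)
accuracy = sum(1 for y, p in pairs if y == p) / n
for label, m in sorted(metrics.items()):
    print("label %d: %s" % (label, m))
print("Accuracy:", accuracy)
print("Weighted recall:", weighted(metrics, true_counts, n, "recall"))
```

<p>With unweighted rows, the frequency-weighted recall always equals accuracy. The weights cancel the per-label denominators, leaving the total number of correct predictions divided by the number of rows.</p>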
| |
| <h2 id="decision-tree-classifier">Decision tree classifier</h2> |
| |
<p>Decision trees are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found later in the <a href="#decision-trees">section on decision trees</a>.</p>
| |
| <p><strong>Examples</strong></p> |
| |
<p>The following examples load a dataset in LIBSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data: they index the label and the categorical features, adding metadata to the <code>DataFrame</code> that the decision tree algorithm can recognize.</p>
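<p>To make the two indexing steps concrete, the plain-Python sketch below mimics them on toy data. The function names are hypothetical, and the sketch makes several assumptions about Spark's behavior.</p>
<ul>
<li>It assumes <code>StringIndexer</code>'s default frequency-descending order, in which the most frequent label (converted to a string) receives index 0. Ties are broken alphabetically here, which is this sketch's own choice.</li>
<li>It assumes a <code>VectorIndexer</code>-style rule: a feature with at most <code>maxCategories</code> distinct values is treated as categorical. The value 0.0, if present, maps to index 0, and the remaining values are sorted and numbered after it.</li>
<li>Spark also stores this information as column metadata, which the sketch does not model.</li>
</ul>

```python
# Toy re-implementation of the two indexing ideas; not the Spark API.
from collections import Counter


def fit_label_index(labels):
    """Return the distinct labels (as strings), most frequent first.

    The position of a label in the returned list is its index. Ties are
    broken by sorting the label strings (an assumption of this sketch).
    """
    counts = Counter(str(label) for label in labels)
    return sorted(counts, key=lambda s: (-counts[s], s))


def fit_feature_categories(rows, max_categories):
    """Return {feature_index: {value: category_index}} for categorical features.

    A feature is treated as categorical when it has at most `max_categories`
    distinct values; 0.0 (if present) gets index 0 and the remaining values
    are sorted and numbered after it.
    """
    category_maps = {}
    for j in range(len(rows[0])):
        distinct = {row[j] for row in rows}
        if len(distinct) > max_categories:
            continue  # treated as continuous, left unchanged
        non_zero = sorted(v for v in distinct if v != 0.0)
        ordered = ([0.0] if 0.0 in distinct else []) + non_zero
        category_maps[j] = {v: i for i, v in enumerate(ordered)}
    return category_maps


label_index = fit_label_index([1.0, 0.0, 1.0, 1.0, 0.0])
print(label_index)  # list position = index; IndexToString reverses this lookup

rows = [[0.0, 3.5, 2.0], [1.0, 7.25, 2.0], [0.0, 1.0, 5.0],
        [2.0, 9.0, 2.0], [1.0, 4.0, 5.0]]
print(fit_feature_categories(rows, max_categories=4))
```

<p>In this toy data the second feature has five distinct values, so with <code>max_categories=4</code> it is left as continuous, matching the comment on <code>setMaxCategories(4)</code> in the examples. Converting a predicted index back with <code>IndexToString</code> amounts to looking up <code>label_index[i]</code>.</p>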
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous.</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
<span class="c1">// Split the data into training and test sets (about 30% held out for testing; the split is random, so exact sizes vary).</span>
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexers and tree in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test Error = </span><span class="si">${</span><span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeClassificationModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned classification tree model:\n </span><span class="si">${</span><span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="na">read</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous.</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
<span class="c1">// Split the data into training and test sets (about 30% held out for testing; the split is random, so exact sizes vary).</span>
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="n">DecisionTreeClassifier</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="n">DecisionTreeClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="n">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> |
| |
| <span class="c1">// Chain indexers and tree in a Pipeline.</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="n">DecisionTreeClassificationModel</span> <span class="n">treeModel</span> <span class="o">=</span> |
| <span class="o">(</span><span class="n">DecisionTreeClassificationModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">DecisionTreeClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column.</span> |
| <span class="c1"># Fit on whole dataset to include all labels in index.</span> |
| <span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># We specify maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a DecisionTree model.</span> |
| <span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and tree in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test Error = </span><span class="si">%g</span><span class="s2"> "</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
<span class="c1"># Print a one-line summary only; treeModel.toDebugString shows the full tree</span>
| <span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_classification_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.decisionTree.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
<span class="c1"># For simplicity, the same data is used for training and prediction</span>
training <span class="o"><-</span> df
test <span class="o"><-</span> df
| |
| <span class="c1"># Fit a DecisionTree classification model with spark.decisionTree</span> |
| model <span class="o"><-</span> spark.decisionTree<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"classification"</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/decisionTree.R" in the Spark repo.</small></div> |
| |
| </div> |
| |
| </div> |
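<p>The examples fit the label indexer on the whole dataset rather than on the training split. A plain-Python sketch of why, using hypothetical helpers rather than Spark code: after a random split, a rare label can land only in the test set, and an index built from the training split alone has no entry for it. By default, Spark's <code>StringIndexer</code> raises an error on such unseen labels, because its <code>handleInvalid</code> parameter defaults to <code>"error"</code>. The sketch orders labels alphabetically for simplicity, which is an assumption; it is not the indexer's default ordering.</p>

```python
def build_label_index(labels):
    """Map each distinct label to a float index.

    Hypothetical helper: uses sorted order for simplicity, whereas Spark's
    StringIndexer orders labels by frequency by default.
    """
    return {label: float(i) for i, label in enumerate(sorted(set(labels)))}


def index_labels(labels, index):
    """Look up every label, failing on labels the index has never seen."""
    unseen = sorted(set(labels) - set(index))
    if unseen:
        raise KeyError("unseen labels: %s" % unseen)
    return [index[label] for label in labels]


all_labels = ["a", "a", "a", "b", "b", "c"]
training_labels = ["a", "a", "b"]  # suppose "c" landed only in the test split
test_labels = ["a", "b", "c"]

# Index fit on the training split only: the test label "c" has no entry.
train_only = build_label_index(training_labels)
try:
    index_labels(test_labels, train_only)
except KeyError as err:
    print("training-only index fails:", err)

# Index fit on the whole dataset covers every label in both splits.
full_index = build_label_index(all_labels)
print(index_labels(test_labels, full_index))  # [0.0, 1.0, 2.0]
```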
| |
| <h2 id="random-forest-classifier">Random forest classifier</h2> |
| |
| <p>Random forests are a popular family of classification and regression methods. |
More information about the <code>spark.ml</code> implementation can be found later in the <a href="#random-forests">section on random forests</a>.</p>
| |
| <p><strong>Examples</strong></p> |
| |
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data: <code>StringIndexer</code> indexes the label and <code>VectorIndexer</code> indexes categorical features, both adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
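<p>A rough plain-Python sketch of what the two indexers decide, under stated assumptions. By default, <code>StringIndexer</code> orders labels by descending frequency (its <code>stringOrderType</code> defaults to <code>"frequencyDesc"</code>), so the most frequent label gets index 0.0. <code>VectorIndexer</code> treats a feature as categorical only if it has at most <code>maxCategories</code> distinct values. The helper names and sample rows are hypothetical, tie-breaking between equally frequent labels is not modeled, and the sketch ignores how <code>VectorIndexer</code> assigns indices to individual category values.</p>

```python
from collections import Counter


def frequency_desc_index(labels):
    """Label -> float index, most frequent label first (assumes no ties)."""
    counts = Counter(labels)
    # most_common() lists labels by descending count.
    return {label: float(i) for i, (label, _) in enumerate(counts.most_common())}


def categorical_features(rows, max_categories):
    """Return positions of features with at most max_categories distinct values.

    Features with more distinct values than max_categories are treated as
    continuous, mirroring the maxCategories rule used in the examples.
    """
    if not rows:
        return []
    categorical = []
    for j in range(len(rows[0])):
        distinct = {row[j] for row in rows}
        if len(distinct) > max_categories:
            continue  # too many distinct values: continuous
        categorical.append(j)
    return categorical


labels = [1.0, 1.0, 1.0, 0.0, 0.0, 2.0]
print(frequency_desc_index(labels))  # {1.0: 0.0, 0.0: 1.0, 2.0: 2.0}

# Hypothetical rows: feature 0 has 2 distinct values, feature 1 has 6.
rows = [
    [0.0, 0.1],
    [1.0, 0.7],
    [0.0, 2.5],
    [1.0, 3.3],
    [0.0, 4.2],
    [1.0, 5.9],
]
print(categorical_features(rows, max_categories=4))  # [0]
```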
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">RandomForestClassificationModel</span><span class="o">,</span> <span class="nc">RandomForestClassifier</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setNumTrees</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexers and forest in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test Error = </span><span class="si">${</span><span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestClassificationModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned classification forest model:\n </span><span class="si">${</span><span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/RandomForestClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="n">RandomForestClassifier</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RandomForestClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="n">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> |
| |
| <span class="c1">// Chain indexers and forest in a Pipeline</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="n">RandomForestClassificationModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">IndexToString</span><span class="p">,</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column.</span> |
| <span class="c1"># Fit on whole dataset to include all labels in index.</span> |
| <span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a RandomForest model.</span> |
| <span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">numTrees</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Convert indexed labels back to original labels.</span> |
| <span class="n">labelConverter</span> <span class="o">=</span> <span class="n">IndexToString</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"predictedLabel"</span><span class="p">,</span> |
| <span class="n">labels</span><span class="o">=</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and forest in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">,</span> <span class="n">labelConverter</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"predictedLabel"</span><span class="p">,</span> <span class="s2">"label"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test Error = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c1"># summary only</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/random_forest_classifier_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.randomForest.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit a random forest classification model with spark.randomForest</span> |
| model <span class="o"><-</span> spark.randomForest<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"classification"</span><span class="p">,</span> numTrees <span class="o">=</span> <span class="m">10</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/randomForest.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
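<p>The Scala, Java, and Python random forest examples above report the test error as <code>1.0 - accuracy</code>. Here accuracy is the fraction of test rows whose <code>prediction</code> matches <code>indexedLabel</code>. The examples also use <code>IndexToString</code> to turn predicted indices back into the original label strings. The plain-Python sketch below reproduces that arithmetic on a few hypothetical rows. <code>LABELS</code>, <code>to_original_label</code>, <code>accuracy</code>, and <code>test_error</code> are illustrative names, not Spark APIs, and the sketch ignores the distributed computation that <code>MulticlassClassificationEvaluator</code> performs.</p>

<div class="highlight"><pre>
# Hypothetical stand-in for labelIndexer.labels: position i holds the
# original label string that StringIndexer mapped to index i.
LABELS = ["1.0", "0.0"]


def to_original_label(prediction, labels=LABELS):
    """Mimic IndexToString: look up the label string for a numeric index."""
    return labels[int(prediction)]


def accuracy(pairs):
    """Fraction of (prediction, indexedLabel) pairs that agree."""
    if not pairs:
        raise ValueError("need at least one (prediction, label) pair")
    correct = sum(1 for prediction, label in pairs if prediction == label)
    return correct / len(pairs)


def test_error(pairs):
    """Test error as printed by the examples: 1.0 - accuracy."""
    return 1.0 - accuracy(pairs)


# Hypothetical (prediction, indexedLabel) rows from a transformed test set.
rows = [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)]
print([to_original_label(p) for p, _ in rows])
print("Test Error = %g" % test_error(rows))
</pre></div>

<p>With these four rows, three predictions match their labels, so the sketch prints a test error of 0.25.</p>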
| |
| <h2 id="gradient-boosted-tree-classifier">Gradient-boosted tree classifier</h2> |
| |
<p>Gradient-boosted trees (GBTs) are a popular classification and regression method that trains an ensemble of decision trees sequentially. Each new tree is fit to correct the errors of the trees before it.
More information about the <code>spark.ml</code> implementation can be found later on this page, in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p>
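<p>To make the boosting idea concrete, the sketch below trains a tiny gradient-boosted binary classifier in plain Python. Each round does three things:</p>
<ul>
<li>computes the negative gradient of the log loss (the label minus the current predicted probability);</li>
<li>fits a one-split regression tree (a stump) to that gradient;</li>
<li>adds a scaled copy of the stump to the ensemble.</li>
</ul>
<p>This is a generic, single-machine illustration under simplifying assumptions: one numeric feature, depth-1 trees, labels in {0, 1}, and a step size of 0.1 chosen for illustration. It is not the <code>spark.ml</code> implementation, which differs in many details, and every function name here is hypothetical.</p>

<div class="highlight"><pre>
import math


def sigmoid(z):
    """Map a raw ensemble score to a probability of class 1."""
    return 1.0 / (1.0 + math.exp(-z))


def fit_stump(xs, targets):
    """Least-squares stump on one feature: returns (threshold, left, right).
    Points with x > threshold get `right`; all others get `left`."""
    best = None
    for t in sorted(set(xs))[:-1]:  # every distinct value except the largest
        left = [r for x, r in zip(xs, targets) if not x > t]
        right = [r for x, r in zip(xs, targets) if x > t]
        lv, rv = sum(left) / len(left), sum(right) / len(right)
        sse = sum((r - lv) ** 2 for r in left) + sum((r - rv) ** 2 for r in right)
        if best is None or best[0] > sse:
            best = (sse, t, lv, rv)
    if best is None:  # all x identical: fall back to a constant prediction
        mean = sum(targets) / len(targets)
        return xs[0], mean, mean
    return best[1:]


def predict_stump(stump, x):
    threshold, left, right = stump
    return right if x > threshold else left


def train_gbt(xs, ys, num_iters=10, step_size=0.1):
    """Gradient boosting with log loss; ys are 0/1 labels."""
    stumps, scores = [], [0.0] * len(xs)
    for _ in range(num_iters):
        # Negative gradient of log loss w.r.t. the raw score F is y - sigmoid(F).
        residuals = [y - sigmoid(f) for y, f in zip(ys, scores)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        scores = [f + step_size * predict_stump(stump, x) for f, x in zip(scores, xs)]
    return stumps


def predict_proba(stumps, x, step_size=0.1):
    """Probability of class 1; step_size must match the one used in training."""
    return sigmoid(sum(step_size * predict_stump(s, x) for s in stumps))


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [0, 0, 0, 1, 1, 1]
model = train_gbt(xs, ys)
print(round(predict_proba(model, 1.0), 3), round(predict_proba(model, 6.0), 3))
</pre></div>

<p>On this toy data, every stump splits at 3.0. The predicted probability of class 1 therefore rises above 0.5 for large <code>x</code> and falls below 0.5 for small <code>x</code>, moving further from 0.5 with each added round.</p>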
| |
| <p><strong>Examples</strong></p> |
| |
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data: <code>StringIndexer</code> indexes the label and <code>VectorIndexer</code> indexes the categorical features. Both add metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
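<p>The indexing step can be sketched in plain Python to show what the two transformers compute:</p>
<ul>
<li><code>StringIndexer</code>, with its default frequency-descending order, gives index 0 to the most frequent label.</li>
<li><code>VectorIndexer</code> treats a feature as categorical when it has at most <code>maxCategories</code> distinct values. This matches the "features with &gt; 4 distinct values are treated as continuous" comment in the examples.</li>
</ul>
<p>The functions below are hypothetical illustrations, not Spark APIs. The alphabetical tie-breaking and the sorted order of category indices are assumptions made for this sketch. The column metadata that Spark attaches is not modeled.</p>

<div class="highlight"><pre>
from collections import Counter


def fit_string_indexer(labels):
    """Map each label to a float index, most frequent label first (index 0.0).
    Ties are broken alphabetically here, which is an assumption of this sketch."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


def fit_vector_indexer(rows, max_categories):
    """Return {feature_position: {value: index}} for features with at most
    max_categories distinct values; other features stay continuous."""
    category_maps = {}
    for j in range(len(rows[0])):
        distinct = sorted({row[j] for row in rows})
        if max_categories >= len(distinct):
            category_maps[j] = {v: float(i) for i, v in enumerate(distinct)}
    return category_maps


def transform_vector(row, category_maps):
    """Replace categorical values by their indices; keep continuous values as-is."""
    return [category_maps[j][v] if j in category_maps else v for j, v in enumerate(row)]


# Hypothetical toy data: 3 features per row, maxCategories = 4 as in the examples.
labels = ["1.0", "0.0", "1.0", "1.0", "0.0"]
rows = [
    [0.0, 10.5, 3.0],
    [1.0, 7.25, 3.0],
    [2.0, 1.0, 5.0],
    [1.0, 99.0, 5.0],
    [0.0, 42.0, 3.0],
]
label_index = fit_string_indexer(labels)
category_maps = fit_vector_indexer(rows, max_categories=4)
print(label_index)
print(sorted(category_maps))
print(transform_vector(rows[2], category_maps))
</pre></div>

<p>Here <code>"1.0"</code> occurs three times and receives index 0.0. Features 0 and 2 have three and two distinct values, so they are indexed as categorical. Feature 1 has five distinct values, so it is left continuous.</p>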
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">GBTClassificationModel</span><span class="o">,</span> <span class="nc">GBTClassifier</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeatureSubsetStrategy</span><span class="o">(</span><span class="s">"auto"</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexers and GBT in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test Error = </span><span class="si">${</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTClassificationModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned classification GBT model:\n </span><span class="si">${</span><span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/GBTClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="na">read</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="n">GBTClassifier</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="n">GBTClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="n">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span> |
| |
| <span class="c1">// Chain indexers and GBT in a Pipeline.</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="n">GBTClassificationModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">GBTClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column.</span> |
| <span class="c1"># Fit on whole dataset to include all labels in index.</span> |
| <span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a GBT model.</span> |
| <span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and GBT in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test Error = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c1"># summary only</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.gbt.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit a GBT classification model with spark.gbt</span> |
| model <span class="o"><-</span> spark.gbt<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"classification"</span><span class="p">,</span> maxIter <span class="o">=</span> <span class="m">10</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/gbt.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="multilayer-perceptron-classifier">Multilayer perceptron classifier</h2> |
| |
<p>The multilayer perceptron classifier (MLPC) is a classifier based on the <a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">feedforward artificial neural network</a>.
MLPC consists of multiple layers of nodes.
Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs
by taking a linear combination of the inputs with the node’s weights <code>$\wv$</code> and bias <code>$\bv$</code> and then applying an activation function.
For an MLPC with <code>$K+1$</code> layers, this can be written in matrix form as follows:
<code>\[
\mathrm{y}(\x) = \mathrm{f_K}(\ldots\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+\bv_1)+\bv_2)\ldots+\bv_K)
\]</code>
Nodes in the intermediate layers use the sigmoid (logistic) function:
| <code>\[ |
| \mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} |
| \]</code> |
Nodes in the output layer use the softmax function:
| <code>\[ |
| \mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} |
| \]</code> |
| The number of nodes <code>$N$</code> in the output layer corresponds to the number of classes.</p> |
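<p>To make the notation concrete, here is a minimal pure-Python sketch of the forward pass for the layer sizes used in the examples (4, 5, 4, 3): sigmoid activations in the intermediate layers and softmax in the output layer, with the predicted class taken as the most probable output node. It is not Spark's implementation; the helper names (<code>affine</code>, <code>init_layers</code>, <code>forward</code>), the random initialisation and the input vector are hypothetical and chosen only for illustration.</p>

```python
import math
import random


def sigmoid(z):
    # Logistic activation used by the intermediate (hidden) layers
    return 1.0 / (1.0 + math.exp(-z))


def softmax(zs):
    # Softmax activation used by the output layer; subtracting the max
    # avoids overflow without changing the result
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def affine(weights, bias, x):
    # weights[j] holds the incoming weights of node j, so this computes W^T x + b
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def init_layers(layer_sizes, seed=1234):
    # Hypothetical random initialisation, for illustration only
    rng = random.Random(seed)
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)]
                   for _ in range(n_out)]
        params.append((weights, [0.0] * n_out))
    return params


def forward(params, x):
    # Hidden layers apply sigmoid(W_k^T a + b_k); the last layer applies softmax
    a = x
    for weights, bias in params[:-1]:
        a = [sigmoid(z) for z in affine(weights, bias, a)]
    weights, bias = params[-1]
    return softmax(affine(weights, bias, a))


layers = [4, 5, 4, 3]            # same layer sizes as the examples
params = init_layers(layers)
# Each layer has (n_in + 1) * n_out parameters: the weights plus one bias per node
n_params = sum((n_in + 1) * n_out for n_in, n_out in zip(layers[:-1], layers[1:]))
probs = forward(params, [0.2, -0.1, 0.5, 0.3])   # hypothetical input features
prediction = max(range(len(probs)), key=lambda k: probs[k])  # most probable class
print(n_params, probs, prediction)
```

<p>The output always has one entry per class and sums to 1, which is why the size of the last layer must equal the number of classes.</p>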
| |
<p>MLPC employs backpropagation to learn the model. We use the logistic loss function for optimization and L-BFGS as the optimization routine.</p>
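<p>For a single training example, the logistic (cross-entropy) loss of the softmax output and its gradient with respect to the inputs of the output layer can be sketched as follows; this gradient is the error signal from which backpropagation starts. The sketch is a hypothetical illustration: <code>logistic_loss_and_grad</code> is not part of Spark's API, and it does not reproduce how Spark aggregates the loss over the training data for L-BFGS.</p>

```python
import math


def softmax(zs):
    # Output-layer activation, as in the formula above
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def logistic_loss_and_grad(zs, label):
    # zs: inputs to the output-layer softmax; label: integer class index
    probs = softmax(zs)
    # Logistic (cross-entropy) loss: minus the log-probability of the true class
    loss = -math.log(probs[label])
    # Gradient of the loss w.r.t. zs is (softmax output - one-hot label);
    # backpropagation propagates this error back through the hidden layers
    grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    return loss, grad


loss, grad = logistic_loss_and_grad([2.0, 1.0, 0.1], label=0)
print(loss, grad)
```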
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.MultilayerPerceptronClassifier">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into train and test</span> |
| <span class="k">val</span> <span class="n">splits</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">train</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> |
| |
| <span class="c1">// specify layers for the neural network:</span> |
| <span class="c1">// input layer of size 4 (features), two intermediate of size 5 and 4</span> |
| <span class="c1">// and output of size 3 (classes)</span> |
| <span class="k">val</span> <span class="n">layers</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">[</span><span class="kt">Int</span><span class="o">](</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">)</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="k">val</span> <span class="n">trainer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">)</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="k">val</span> <span class="n">result</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">predictionAndLabels</span> <span class="k">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test set accuracy = </span><span class="si">${</span><span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">;</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataFrame</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="n">path</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into train and test</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// specify layers for the neural network:</span> |
| <span class="c1">// input layer of size 4 (features), two intermediate of size 5 and 4</span> |
| <span class="c1">// and output of size 3 (classes)</span> |
| <span class="kt">int</span><span class="o">[]</span> <span class="n">layers</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]</span> <span class="o">{</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">};</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="n">MultilayerPerceptronClassifier</span> <span class="n">trainer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MultilayerPerceptronClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">);</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="n">MultilayerPerceptronClassificationModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">);</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">));</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.MultilayerPerceptronClassifier">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">MultilayerPerceptronClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span>\ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into train and test</span> |
| <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| |
| <span class="c1"># specify layers for the neural network:</span> |
| <span class="c1"># input layer of size 4 (features), two intermediate of size 5 and 4</span> |
| <span class="c1"># and output of size 3 (classes)</span> |
| <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> |
| |
| <span class="c1"># create the trainer and set its parameters</span> |
| <span class="n">trainer</span> <span class="o">=</span> <span class="n">MultilayerPerceptronClassifier</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">layers</span><span class="o">=</span><span class="n">layers</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">1234</span><span class="p">)</span> |
| |
| <span class="c1"># train the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># compute accuracy on the test set</span> |
| <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"label"</span><span class="p">)</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test set accuracy = "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/multilayer_perceptron_classification.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.mlp.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># specify layers for the neural network:</span> |
| <span class="c1"># input layer of size 4 (features), two intermediate of size 5 and 4</span> |
| <span class="c1"># and output of size 3 (classes)</span> |
| layers <span class="o">=</span> <span class="kt">c</span><span class="p">(</span><span class="m">4</span><span class="p">,</span> <span class="m">5</span><span class="p">,</span> <span class="m">4</span><span class="p">,</span> <span class="m">3</span><span class="p">)</span> |
| |
| <span class="c1"># Fit a multi-layer perceptron neural network model with spark.mlp</span> |
| model <span class="o"><-</span> spark.mlp<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> maxIter <span class="o">=</span> <span class="m">100</span><span class="p">,</span> |
| layers <span class="o">=</span> layers<span class="p">,</span> blockSize <span class="o">=</span> <span class="m">128</span><span class="p">,</span> seed <span class="o">=</span> <span class="m">1234</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/mlp.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="linear-support-vector-machine">Linear Support Vector Machine</h2> |
| |
| <p>A <a href="https://en.wikipedia.org/wiki/Support_vector_machine">support vector machine</a> constructs a hyperplane |
| or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification, |
| regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has |
| the largest distance to the nearest training-data points of any class (so-called functional margin), |
| since in general the larger the margin the lower the generalization error of the classifier. LinearSVC |
| in Spark ML supports binary classification with linear SVM. Internally, it optimizes the |
| <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge Loss</a> using OWLQN optimizer.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.LinearSVC">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVC</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">lsvc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearSVC</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.1</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="n">lsvcModel</span> <span class="k">=</span> <span class="n">lsvc</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear svc</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: </span><span class="si">${</span><span class="n">lsvcModel</span><span class="o">.</span><span class="n">coefficients</span><span class="si">}</span><span class="s"> Intercept: </span><span class="si">${</span><span class="n">lsvcModel</span><span class="o">.</span><span class="n">intercept</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/LinearSVC.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVC</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVCModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="n">LinearSVC</span> <span class="n">lsvc</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LinearSVC</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.1</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="n">LinearSVCModel</span> <span class="n">lsvcModel</span> <span class="o">=</span> <span class="n">lsvc</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for LinearSVC</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lsvcModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lsvcModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.LinearSVC">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LinearSVC</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lsvc</span> <span class="o">=</span> <span class="n">LinearSVC</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">lsvcModel</span> <span class="o">=</span> <span class="n">lsvc</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for linear SVC</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lsvcModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lsvcModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/linearsvc.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.svmLinear.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># load training data</span> |
| t <span class="o"><-</span> <span class="kp">as.data.frame</span><span class="p">(</span>Titanic<span class="p">)</span> |
| training <span class="o"><-</span> createDataFrame<span class="p">(</span><span class="kp">t</span><span class="p">)</span> |
| |
| <span class="c1"># fit Linear SVM model</span> |
| model <span class="o"><-</span> spark.svmLinear<span class="p">(</span>training<span class="p">,</span> Survived <span class="o">~</span> <span class="m">.</span><span class="p">,</span> regParam <span class="o">=</span> <span class="m">0.01</span><span class="p">,</span> maxIter <span class="o">=</span> <span class="m">10</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| prediction <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> training<span class="p">)</span> |
| showDF<span class="p">(</span>prediction<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/svmLinear.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
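<p>To show how the printed coefficients and intercept are used, the following Python sketch computes the raw margin $\mathbf{w}^T \mathbf{x} + b$ and turns it into a binary label. It assumes the behavior of the <code>threshold</code> parameter of <code>LinearSVC</code> (default $0.0$): an example is labeled $1.0$ when its margin exceeds the threshold, and $0.0$ otherwise. The helper names <code>margin</code> and <code>predict</code> are hypothetical and are not part of the Spark API.</p>

```python
# Illustrative sketch (not Spark code): turning a linear SVM's coefficients and
# intercept into a binary prediction. Assumes the rule "margin > threshold -> 1.0",
# with threshold defaulting to 0.0.
from typing import Sequence


def margin(coefficients: Sequence[float], intercept: float, features: Sequence[float]) -> float:
    """Raw prediction: dot product of weights and features, plus the intercept."""
    if len(coefficients) != len(features):
        raise ValueError("coefficients and features must have the same length")
    return sum(w * x for w, x in zip(coefficients, features)) + intercept


def predict(coefficients: Sequence[float], intercept: float,
            features: Sequence[float], threshold: float = 0.0) -> float:
    """Return 1.0 if the margin exceeds the threshold, else 0.0."""
    return 1.0 if margin(coefficients, intercept, features) > threshold else 0.0
```

<p>For example, with coefficients <code>[1.0, 2.0]</code>, intercept <code>-1.0</code>, and features <code>[3.0, 0.5]</code>, the margin is $3.0$, so the sketch predicts $1.0$; raising the threshold above $3.0$ would flip the prediction to $0.0$.</p>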
| |
| <h2 id="one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest">OneVsRest</a> is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as “One-vs-All.”</p> |
| |
| <p><code>OneVsRest</code> is implemented as an <code>Estimator</code>. It takes an instance of <code>Classifier</code> as its base classifier and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.</p> |
| |
| <p>To make a prediction, every binary classifier is evaluated on the input, and the index of the most confident classifier is output as the label.</p> |
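<p>The reduction described above can be sketched in plain Python: train one binary scorer per class on relabeled data (class k becomes 1, every other class becomes 0), then predict the class whose scorer is most confident. This is not Spark's implementation. The base classifier here, a nearest-centroid scorer named <code>fit_centroid_binary</code>, is a hypothetical stand-in for the <code>Classifier</code> (for example <code>LogisticRegression</code>) that Spark's <code>OneVsRest</code> would wrap.</p>

```python
# Illustrative one-vs-rest sketch. The centroid-based base classifier is a
# hypothetical stand-in for a real binary classifier such as logistic regression.
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]
Scorer = Callable[[Vector], float]


def _mean(rows: List[Vector]) -> List[float]:
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def _sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def fit_centroid_binary(X: List[Vector], y: List[int]) -> Scorer:
    """Hypothetical base classifier: confidence that x is in the positive class.

    The score is positive when x is closer to the mean of the positive examples
    than to the mean of the negative examples.
    """
    mu_pos = _mean([x for x, t in zip(X, y) if t == 1])
    mu_neg = _mean([x for x, t in zip(X, y) if t == 0])
    return lambda x: _sq_dist(x, mu_neg) - _sq_dist(x, mu_pos)


def fit_one_vs_rest(X: List[Vector], labels: List[int],
                    fit_binary: Callable[[List[Vector], List[int]], Scorer] = fit_centroid_binary
                    ) -> Dict[int, Scorer]:
    """Train one binary classifier per class: class k versus all other classes."""
    models = {}
    for k in sorted(set(labels)):
        binary_y = [1 if lab == k else 0 for lab in labels]
        models[k] = fit_binary(X, binary_y)
    return models


def predict_one_vs_rest(models: Dict[int, Scorer], x: Vector) -> int:
    """Output the class whose binary classifier is most confident."""
    return max(models, key=lambda k: models[k](x))
```

<p>Because the scores of the k binary classifiers are compared directly, the base classifier should produce confidences on a comparable scale; any scorer with that property can be passed through <code>fit_binary</code>.</p>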
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The example below demonstrates how to load the |
| <a href="http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale">Iris dataset</a>, parse it as a DataFrame, and perform multiclass classification using <code>OneVsRest</code>. The test error (one minus the accuracy on the held-out test set) is computed to measure the algorithm’s performance.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.OneVsRest">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">LogisticRegression</span><span class="o">,</span> <span class="nc">OneVsRest</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// load data file.</span> |
| <span class="k">val</span> <span class="n">inputData</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// generate the train/test split.</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">train</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">inputData</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">))</span> |
| |
| <span class="c1">// instantiate the base classifier</span> |
| <span class="k">val</span> <span class="n">classifier</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setTol</span><span class="o">(</span><span class="mf">1E-6</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">)</span> |
| |
| <span class="c1">// instantiate the One Vs Rest Classifier.</span> |
| <span class="k">val</span> <span class="n">ovr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">().</span><span class="n">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">)</span> |
| |
| <span class="c1">// train the multiclass model.</span> |
| <span class="k">val</span> <span class="n">ovrModel</span> <span class="k">=</span> <span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span> |
| |
| <span class="c1">// score the model on test data.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| |
| <span class="c1">// obtain evaluator.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| |
| <span class="c1">// compute the classification error on test data.</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test Error = </span><span class="si">${</span><span class="mi">1</span> <span class="o">-</span> <span class="n">accuracy</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/OneVsRest.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRest</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRestModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// load data file.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// generate the train/test split.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">tmp</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">train</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// configure the base classifier.</span> |
| <span class="n">LogisticRegression</span> <span class="n">classifier</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setTol</span><span class="o">(</span><span class="mf">1E-6</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">);</span> |
| |
| <span class="c1">// instantiate the One Vs Rest Classifier.</span> |
| <span class="n">OneVsRest</span> <span class="n">ovr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">OneVsRest</span><span class="o">().</span><span class="na">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">);</span> |
| |
| <span class="c1">// train the multiclass model.</span> |
| <span class="n">OneVsRestModel</span> <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// score the model on test data.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">);</span> |
| |
| <span class="c1">// obtain evaluator.</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| |
| <span class="c1">// compute the classification error on test data.</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.OneVsRest">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span><span class="p">,</span> <span class="n">OneVsRest</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># load data file.</span> |
| <span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span> \ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># generate the train/test split.</span> |
| <span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">test</span><span class="p">)</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">])</span> |
| |
| <span class="c1"># instantiate the base classifier.</span> |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1E-6</span><span class="p">,</span> <span class="n">fitIntercept</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| |
| <span class="c1"># instantiate the One Vs Rest Classifier.</span> |
| <span class="n">ovr</span> <span class="o">=</span> <span class="n">OneVsRest</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span> |
| |
| <span class="c1"># train the multiclass model.</span> |
| <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># score the model on test data.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| |
| <span class="c1"># obtain evaluator.</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| |
| <span class="c1"># compute the classification error on test data.</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test Error = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/one_vs_rest_example.py" in the Spark repo.</small></div> |
| </div> |
| </div> |
| |
| <h2 id="naive-bayes">Naive Bayes</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Naive_Bayes_classifier">Naive Bayes classifiers</a> are a family of simple |
| probabilistic, multiclass classifiers based on applying Bayes’ theorem with strong (naive) independence |
| assumptions between every pair of features.</p> |
| |
| <p>Naive Bayes can be trained very efficiently. With a single pass over the training data, |
| it computes the conditional probability distribution of each feature given each label. |
| For prediction, it applies Bayes’ theorem to compute the conditional probability distribution |
| of each label given an observation.</p> |
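<p>The single training pass and the Bayes’ theorem prediction step can be sketched in plain Python for the multinomial model. This is an illustrative sketch, not MLlib’s implementation: the function names are hypothetical, both the label priors and the per-label term probabilities are smoothed with a parameter <code>lam</code> (the additive smoothing $\lambda$ discussed below), and prediction compares log-posteriors, which is equivalent to comparing posteriors because the evidence term is the same for every label.</p>

```python
# Illustrative multinomial naive Bayes sketch (hypothetical helpers, not MLlib's API).
import math
from collections import defaultdict


def train_multinomial_nb(data, num_features, lam=1.0):
    """Single pass over (label, term-count vector) pairs.

    Returns (log_prior, log_theta): smoothed log P(label) and, for each label,
    smoothed log P(term i | label). lam must be > 0 to avoid log(0).
    """
    doc_count = defaultdict(int)                            # documents per label
    term_count = defaultdict(lambda: [0.0] * num_features)  # summed term counts per label
    for label, features in data:
        doc_count[label] += 1
        row = term_count[label]
        for i, v in enumerate(features):
            if v < 0:
                raise ValueError("feature values must be non-negative")
            row[i] += v

    n_docs = sum(doc_count.values())
    n_labels = len(doc_count)
    log_prior = {y: math.log((c + lam) / (n_docs + lam * n_labels))
                 for y, c in doc_count.items()}
    log_theta = {}
    for y, row in term_count.items():
        denom = sum(row) + lam * num_features
        log_theta[y] = [math.log((c + lam) / denom) for c in row]
    return log_prior, log_theta


def predict_multinomial_nb(model, features):
    """Return the label with the largest log-posterior (up to a shared constant)."""
    log_prior, log_theta = model

    def log_posterior(y):
        return log_prior[y] + sum(v * t for v, t in zip(features, log_theta[y]))

    return max(log_prior, key=log_posterior)
```

<p>Working in log space keeps the products of many small probabilities from underflowing, and the training loop touches each example exactly once.</p>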
| |
| <p>MLlib supports both <a href="http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">multinomial naive Bayes</a> |
| and <a href="http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html">Bernoulli naive Bayes</a>.</p> |
| |
| <p><em>Input data</em>: |
| These models are typically used for <a href="http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html">document classification</a>. |
| Within that context, each observation is a document and each feature represents a term. |
| A feature’s value is the frequency of the term (in multinomial Naive Bayes) or |
| a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). |
| Feature values must be <em>non-negative</em>. The model type is selected with the optional parameter <code>modelType</code>, |
| which accepts “multinomial” or “bernoulli”; the default is “multinomial”. |
| For document classification, the input feature vectors should usually be sparse vectors. |
| Since the training data is only used once, it is not necessary to cache it.</p> |
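<p>The two input representations can be illustrated with a small Python sketch: Bernoulli naive Bayes expects 0/1 presence indicators rather than term frequencies, and sparse input stores only the non-zero entries. The helpers <code>to_bernoulli</code> and <code>to_sparse</code> are hypothetical; <code>to_sparse</code> returns a <code>(size, indices, values)</code> triple in the same argument order as <code>pyspark.ml.linalg.Vectors.sparse</code>.</p>

```python
# Illustrative helpers (hypothetical, not part of Spark) for preparing
# document-term features for naive Bayes.
from typing import List, Sequence, Tuple


def to_bernoulli(term_counts: Sequence[float]) -> List[float]:
    """Map term frequencies to 0/1 presence indicators for Bernoulli naive Bayes."""
    if any(v < 0 for v in term_counts):
        raise ValueError("feature values must be non-negative")
    return [1.0 if v > 0 else 0.0 for v in term_counts]


def to_sparse(values: Sequence[float]) -> Tuple[int, List[int], List[float]]:
    """Return (size, indices, values) for the non-zero entries of a dense vector."""
    indices = [i for i, v in enumerate(values) if v != 0]
    return len(values), indices, [float(values[i]) for i in indices]
```

<p>For a document whose term counts are <code>[0, 3, 0, 4]</code>, the Bernoulli features are <code>[0.0, 1.0, 0.0, 1.0]</code> and the sparse form is <code>(4, [1, 3], [3.0, 4.0])</code>.</p>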
| |
| <p><a href="http://en.wikipedia.org/wiki/Lidstone_smoothing">Additive smoothing</a> can be applied by |
| setting the parameter $\lambda$ (the <code>smoothing</code> parameter, default $1.0$).</p> |
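<p>For reference, the standard additive (Lidstone) estimates of the per-label feature parameters are given below; this is the textbook form, offered as an aid rather than a transcription of MLlib’s code. For the multinomial model, $N_{yi}$ is the total count of feature $i$ over training examples with label $y$, $N_y = \sum_{i=1}^{n} N_{yi}$, and $n$ is the number of features. For the Bernoulli model, $N_{yi}$ is the number of examples with label $y$ in which feature $i$ equals $1$, and $D_y$ is the number of examples with label $y$.</p>
<p>$$\hat{\theta}_{yi}^{\text{multinomial}} = \frac{N_{yi} + \lambda}{N_y + \lambda n}, \qquad \hat{\theta}_{yi}^{\text{Bernoulli}} = \frac{N_{yi} + \lambda}{D_y + 2\lambda}$$</p>
<p>Setting $\lambda = 1$ gives Laplace smoothing; with any $\lambda > 0$, a feature that never appears with a given label in the training data still receives a non-zero probability for that label.</p>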
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span> |
| |
| <span class="c1">// Train a NaiveBayes model.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">NaiveBayes</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">show</span><span class="o">()</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test accuracy</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Test set accuracy = </span><span class="si">$accuracy</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/NaiveBayes.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayesModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataFrame</span> <span class="o">=</span> |
| <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| <span class="c1">// Split the data into train and test</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="n">NaiveBayes</span> <span class="n">nb</span> <span class="o">=</span> <span class="k">new</span> <span class="n">NaiveBayes</span><span class="o">();</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="n">NaiveBayesModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">show</span><span class="o">();</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="n">accuracy</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">NaiveBayes</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span> \ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into train and test</span> |
| <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| |
| <span class="c1"># create the trainer and set its parameters</span> |
| <span class="n">nb</span> <span class="o">=</span> <span class="n">NaiveBayes</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">modelType</span><span class="o">=</span><span class="s2">"multinomial"</span><span class="p">)</span> |
| |
| <span class="c1"># train the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># select example rows to display.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| |
| <span class="c1"># compute accuracy on the test set</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">metricName</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Test set accuracy = "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">accuracy</span><span class="p">))</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/naive_bayes_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.naiveBayes.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Fit a Bernoulli naive Bayes model with spark.naiveBayes</span> |
| titanic <span class="o"><-</span> <span class="kp">as.data.frame</span><span class="p">(</span>Titanic<span class="p">)</span> |
| titanicDF <span class="o"><-</span> createDataFrame<span class="p">(</span>titanic<span class="p">[</span>titanic<span class="o">$</span>Freq <span class="o">></span> <span class="m">0</span><span class="p">,</span> <span class="m">-5</span><span class="p">])</span> |
| nbDF <span class="o"><-</span> titanicDF |
| nbTestDF <span class="o"><-</span> titanicDF |
| nbModel <span class="o"><-</span> spark.naiveBayes<span class="p">(</span>nbDF<span class="p">,</span> Survived <span class="o">~</span> Class <span class="o">+</span> Sex <span class="o">+</span> Age<span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>nbModel<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| nbPredictions <span class="o"><-</span> predict<span class="p">(</span>nbModel<span class="p">,</span> nbTestDF<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>nbPredictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/naiveBayes.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
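<p>To make the role of the <code>smoothing</code> parameter concrete, here is a minimal plain-Python sketch of multinomial naive Bayes training with additive smoothing. It is not Spark's implementation. The function names and the list-of-counts input format are hypothetical. The sketch assumes that smoothing is applied both to the class priors and to the per-class feature probabilities.</p>

```python
import math
from collections import defaultdict


def train_multinomial_nb(rows, smoothing=1.0):
    """rows: list of (label, feature_counts) pairs; feature_counts is a fixed-length
    list of non-negative term counts. Returns (log_prior, log_theta) dictionaries."""
    num_features = len(rows[0][1])
    label_counts = defaultdict(int)
    feature_totals = defaultdict(lambda: [0.0] * num_features)
    for label, counts in rows:
        label_counts[label] += 1
        totals = feature_totals[label]
        for j, c in enumerate(counts):
            totals[j] += c

    n = len(rows)
    num_labels = len(label_counts)
    log_prior, log_theta = {}, {}
    for label, count in label_counts.items():
        # Additively smoothed class prior.
        log_prior[label] = math.log((count + smoothing) / (n + num_labels * smoothing))
        # Additively smoothed per-class term probabilities.
        totals = feature_totals[label]
        denom = sum(totals) + num_features * smoothing
        log_theta[label] = [math.log((t + smoothing) / denom) for t in totals]
    return log_prior, log_theta


def predict_label(model, counts):
    """Pick the label with the highest log prior + sum of count * log probability."""
    log_prior, log_theta = model
    scores = {
        label: log_prior[label] + sum(c * lt for c, lt in zip(counts, log_theta[label]))
        for label in log_prior
    }
    return max(scores, key=scores.get)


rows = [(0, [2, 0]), (0, [3, 1]), (1, [0, 2]), (1, [1, 3])]
model = train_multinomial_nb(rows, smoothing=1.0)
print(predict_label(model, [4, 0]), predict_label(model, [0, 4]))
```

<p>Setting <code>smoothing</code> to 0 in this sketch would assign zero probability (and a <code>math.log</code> error) to any term that never appears with a label. The additive term avoids that.</p>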
| |
| <h1 id="regression">Regression</h1> |
| |
| <h2 id="linear-regression">Linear regression</h2> |
| |
| <p>The interface for working with linear regression models and model |
| summaries is similar to the logistic regression case.</p> |
| |
| <blockquote> |
| <p>When fitting a LinearRegressionModel without an intercept on a dataset with a constant nonzero column using the “l-bfgs” solver, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior matches R glmnet but differs from LIBSVM.</p> |
| </blockquote> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following |
| example demonstrates training an elastic net regularized linear |
| regression model and extracting model summary statistics.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.regression.LinearRegression">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear regression</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: </span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="si">}</span><span class="s"> Intercept: </span><span class="si">${</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"numIterations: </span><span class="si">${</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">totalIterations</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"objectiveHistory: [</span><span class="si">${</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span><span class="o">.</span><span class="n">mkString</span><span class="o">(</span><span class="s">","</span><span class="o">)</span><span class="si">}</span><span class="s">]"</span><span class="o">)</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">residuals</span><span class="o">.</span><span class="n">show</span><span class="o">()</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"RMSE: </span><span class="si">${</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"r2: </span><span class="si">${</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">r2</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/LinearRegression.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">);</span> |
| |
| <span class="n">LinearRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model.</span> |
| <span class="n">LinearRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear regression.</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics.</span> |
| <span class="n">LinearRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"numIterations: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">totalIterations</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"objectiveHistory: "</span> <span class="o">+</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">()));</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="na">residuals</span><span class="o">().</span><span class="na">show</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"RMSE: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"r2: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.LinearRegression">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">LinearRegression</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span>\ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LinearRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for linear regression</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># Summarize the model over the training set and print out some metrics</span> |
| <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"numIterations: </span><span class="si">%d</span><span class="s2">"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">totalIterations</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"objectiveHistory: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span><span class="p">))</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="n">residuals</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"RMSE: </span><span class="si">%f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"r2: </span><span class="si">%f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">r2</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/linear_regression_with_elastic_net.py" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
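<p>The <code>regParam</code> and <code>elasticNetParam</code> settings above weight an elastic-net penalty. The plain-Python sketch below assumes the objective is the squared loss divided by $2n$ plus $\lambda \left(\alpha \|w\|_1 + \frac{1-\alpha}{2}\|w\|_2^2\right)$ with an unpenalized intercept. It evaluates that objective for given coefficients and also computes the RMSE and $R^2$ that the training summary reports, assuming the model has an intercept. Spark standardizes features by default before applying the penalty, and this sketch ignores that. All function names are hypothetical.</p>

```python
import math


def predict(weights, intercept, x):
    return intercept + sum(w * xi for w, xi in zip(weights, x))


def elastic_net_objective(weights, intercept, X, y, reg_param, elastic_net_param):
    """Squared loss / (2n) plus the elastic-net penalty on the weights (not the intercept)."""
    n = len(y)
    loss = sum((predict(weights, intercept, x) - yi) ** 2 for x, yi in zip(X, y)) / (2 * n)
    l1 = sum(abs(w) for w in weights)
    l2 = sum(w * w for w in weights)
    # elastic_net_param = 1 gives a pure L1 penalty; 0 gives a pure L2 penalty.
    penalty = reg_param * (elastic_net_param * l1 + (1 - elastic_net_param) / 2 * l2)
    return loss + penalty


def rmse_and_r2(weights, intercept, X, y):
    """Root mean squared error and coefficient of determination on (X, y)."""
    preds = [predict(weights, intercept, x) for x in X]
    ss_res = sum((yi - p) ** 2 for yi, p in zip(y, preds))
    mean_y = sum(y) / len(y)
    ss_tot = sum((yi - mean_y) ** 2 for yi in y)
    return math.sqrt(ss_res / len(y)), 1 - ss_res / ss_tot


X = [[0.0], [1.0], [2.0]]
print(elastic_net_objective([2.0], 1.0, X, [1.0, 3.0, 5.0], reg_param=0.3, elastic_net_param=0.8))
print(rmse_and_r2([2.0], 1.0, X, [1.0, 3.0, 6.0]))
```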
| |
| <h2 id="generalized-linear-regression">Generalized linear regression</h2> |
| |
| <p>In contrast to linear regression, where the output is assumed to follow a Gaussian |
| distribution, <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">generalized linear models</a> (GLMs) are specifications of linear models where the response variable $Y_i$ follows some |
| distribution from the <a href="https://en.wikipedia.org/wiki/Exponential_family">exponential family of distributions</a>. |
| Spark’s <code>GeneralizedLinearRegression</code> interface |
| allows for flexible specification of GLMs, which can be used for various types of |
| prediction problems, including linear regression, Poisson regression, logistic regression, and others. |
| Currently, <code>spark.ml</code> supports only a subset of the exponential family distributions; they are listed |
| <a href="#available-families">below</a>.</p> |
| |
| <p><strong>NOTE</strong>: Spark currently supports only up to 4096 features through its <code>GeneralizedLinearRegression</code> |
| interface and will throw an exception if this limit is exceeded. See the <a href="ml-advanced">advanced section</a> for more details. |
| For linear and logistic regression, however, models with more features can still be trained |
| using the <code>LinearRegression</code> and <code>LogisticRegression</code> estimators.</p> |
| |
| <p>GLMs require exponential family distributions that can be written in their “canonical” or “natural” form, also known as |
| <a href="https://en.wikipedia.org/wiki/Natural_exponential_family">natural exponential family distributions</a>. The form of a natural exponential family distribution is:</p> |
| |
| <script type="math/tex; mode=display">f_Y(y|\theta, \tau) = h(y, \tau)\exp{\left( \frac{\theta \cdot y - A(\theta)}{d(\tau)} \right)}</script> |
| |
| <p>where $\theta$ is the parameter of interest and $\tau$ is a dispersion parameter. In a GLM the response variable $Y_i$ is assumed to be drawn from a natural exponential family distribution:</p> |
| |
| <script type="math/tex; mode=display">Y_i \sim f\left(\cdot|\theta_i, \tau \right)</script> |
| |
| <p>where the parameter of interest $\theta_i$ is related to the expected value of the response variable $\mu_i$ by</p> |
| |
| <script type="math/tex; mode=display">\mu_i = A'(\theta_i)</script> |
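<p>As a worked check of these definitions, the plain-Python sketch below writes three familiar distributions in natural form and recovers the mean $\mu = A'(\theta)$ by numerical differentiation. It assumes the standard parameterizations: Poisson with $\theta = \log \mu$, $A(\theta) = e^\theta$, $h(y) = 1/y!$; Bernoulli with $\theta = \log\frac{p}{1-p}$, $A(\theta) = \log(1 + e^\theta)$, $h(y) = 1$; and Gaussian with $\theta = \mu$, $A(\theta) = \theta^2/2$, $d(\tau) = \sigma^2$. The helper names are illustrative only.</p>

```python
import math


def natural_form_density(y, theta, h, A, d_tau=1.0):
    # f(y | theta, tau) = h(y, tau) * exp((theta * y - A(theta)) / d(tau))
    return h(y) * math.exp((theta * y - A(theta)) / d_tau)


# Log-partition functions A(theta) under the assumed parameterizations.
def a_poisson(theta):
    return math.exp(theta)


def a_bernoulli(theta):
    return math.log1p(math.exp(theta))


def a_gaussian(theta):
    return theta * theta / 2


def poisson_density(y, theta):
    return natural_form_density(y, theta, lambda k: 1.0 / math.factorial(k), a_poisson)


def bernoulli_density(y, theta):
    return natural_form_density(y, theta, lambda k: 1.0, a_bernoulli)


def gaussian_density(y, theta, sigma2):
    # Here the dispersion d(tau) is the variance sigma^2, and h also depends on it.
    def h(v):
        return math.exp(-v * v / (2 * sigma2)) / math.sqrt(2 * math.pi * sigma2)

    return natural_form_density(y, theta, h, a_gaussian, d_tau=sigma2)


def mean_from_theta(A, theta, step=1e-6):
    # mu = A'(theta), approximated with a central finite difference.
    return (A(theta + step) - A(theta - step)) / (2 * step)


print(mean_from_theta(a_poisson, math.log(3.0)))
print(mean_from_theta(a_bernoulli, math.log(0.25 / 0.75)))
```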
| |
| <p>Here, $A'(\theta_i)$ is determined by the chosen distribution. GLMs also allow specification |
| of a link function, which defines the relationship between the expected value of the response variable $\mu_i$ |
| and the so-called <em>linear predictor</em> $\eta_i$:</p> |
| |
| <script type="math/tex; mode=display">g(\mu_i) = \eta_i = \vec{x_i}^T \cdot \vec{\beta}</script> |
| |
| <p>Often, the link function is chosen such that $A' = g^{-1}$, which yields a simplified relationship |
| between the parameter of interest $\theta$ and the linear predictor $\eta$. In this case, the link |
| function $g(\mu)$ is said to be the “canonical” link function.</p> |
| |
| <script type="math/tex; mode=display">\theta_i = A'^{-1}(\mu_i) = g(g^{-1}(\eta_i)) = \eta_i</script> |
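<p>The identity $\theta_i = \eta_i$ can be checked numerically. In the plain-Python sketch below (illustrative helper names), the linear predictor is mapped through $g^{-1}$ and then through $A'^{-1}$. For the binomial family $A'$ is the logistic function, so $A'^{-1}$ is the logit, and for the Poisson family $A'^{-1}$ is the log. The canonical logit and log links return $\eta$ unchanged, while the non-canonical probit link (the standard normal CDF as $g^{-1}$) does not.</p>

```python
import math
from statistics import NormalDist


def logit(p):
    return math.log(p / (1 - p))


def inv_logit(eta):
    return 1 / (1 + math.exp(-eta))


def theta_from_eta(eta, inverse_link, inverse_mean_map):
    """theta = A'^{-1}(g^{-1}(eta)): map the linear predictor to the mean, then to theta."""
    return inverse_mean_map(inverse_link(eta))


eta = 0.7

# Binomial family: A'(theta) = inv_logit(theta), so A'^{-1} is logit.
theta_logit = theta_from_eta(eta, inv_logit, logit)          # canonical logit link
theta_probit = theta_from_eta(eta, NormalDist().cdf, logit)  # non-canonical probit link

# Poisson family: A'(theta) = exp(theta), so A'^{-1} is log.
theta_log = theta_from_eta(eta, math.exp, math.log)          # canonical log link

print(theta_logit, theta_log, theta_probit)
```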
| |
| <p>A GLM finds the regression coefficients $\vec{\beta}$ which maximize the likelihood function.</p> |
| |
| <script type="math/tex; mode=display">\max_{\vec{\beta}} \mathcal{L}(\vec{\theta}|\vec{y},X) = |
| \prod_{i=1}^{N} h(y_i, \tau) \exp{\left(\frac{y_i\theta_i - A(\theta_i)}{d(\tau)}\right)}</script> |
| |
| <p>where the parameter of interest $\theta_i$ is related to the regression coefficients $\vec{\beta}$ |
| by</p> |
| |
| <script type="math/tex; mode=display">\theta_i = A'^{-1}(g^{-1}(\vec{x_i} \cdot \vec{\beta}))</script> |
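<p>As we understand it, Spark’s <code>GeneralizedLinearRegression</code> maximizes this likelihood with iteratively reweighted least squares (IRLS) for most family and link combinations. As a hedged illustration, not Spark’s code, the plain-Python sketch below fits a Poisson GLM with the canonical log link, one feature, and an intercept. With a canonical link, the IRLS update coincides with a Newton-Raphson step on the log-likelihood. The function name and the fixed iteration count are assumptions made for the example.</p>

```python
import math


def fit_poisson_glm(xs, ys, iterations=25):
    """Fit log(mu_i) = b0 + b1 * x_i by Newton-Raphson (IRLS with the canonical log link)."""
    b0, b1 = math.log(sum(ys) / len(ys)), 0.0  # start from the intercept-only fit
    for _ in range(iterations):
        mus = [math.exp(b0 + b1 * x) for x in xs]
        # Score X^T (y - mu): gradient of the Poisson log-likelihood.
        g0 = sum(y - m for y, m in zip(ys, mus))
        g1 = sum((y - m) * x for x, y, m in zip(xs, ys, mus))
        # Fisher information X^T W X with W = diag(mu) under the canonical link.
        h00 = sum(mus)
        h01 = sum(m * x for x, m in zip(xs, mus))
        h11 = sum(m * x * x for x, m in zip(xs, mus))
        det = h00 * h11 - h01 * h01
        # Solve the 2x2 system (X^T W X) delta = X^T (y - mu) in closed form.
        b0 += (h11 * g0 - h01 * g1) / det
        b1 += (h00 * g1 - h01 * g0) / det
    return b0, b1


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
counts = [1.0, 2.0, 4.0, 7.0, 12.0]
b0, b1 = fit_poisson_glm(xs, counts)
print(f"intercept={b0:.4f} slope={b1:.4f}")
```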
| |
| <p>Spark’s generalized linear regression interface also provides summary statistics for diagnosing the |
| fit of GLM models, including residuals, p-values, deviances, the Akaike information criterion, and |
| others.</p> |
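<p>For intuition about the deviances reported in the summary, the following plain-Python sketch computes the Gaussian and Poisson deviances of a set of fitted means, plus a null deviance from the intercept-only fit. It assumes unit weights and a model with an intercept, and the function names are illustrative rather than part of the Spark API.</p>

```python
import math


def gaussian_deviance(ys, mus):
    # Sum of squared residuals.
    return sum((y - m) ** 2 for y, m in zip(ys, mus))


def poisson_deviance(ys, mus):
    # 2 * sum(y * log(y / mu) - (y - mu)), taking y * log(y / mu) as 0 when y == 0.
    total = 0.0
    for y, m in zip(ys, mus):
        term = y * math.log(y / m) if y else 0.0
        total += term - (y - m)
    return 2 * total


def null_deviance(ys, deviance_fn):
    # Deviance of the intercept-only model, whose fitted mean is the sample mean.
    mean_y = sum(ys) / len(ys)
    return deviance_fn(ys, [mean_y] * len(ys))


print(gaussian_deviance([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
print(null_deviance([1.0, 2.0, 3.0], gaussian_deviance))
```

<p>A smaller deviance relative to the null deviance indicates that the features explain more of the variation than the intercept alone.</p>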
| |
| <p><a href="http://data.princeton.edu/wws509/notes/">These lecture notes</a> give a more comprehensive review of GLMs and their applications.</p> |
| |
| <h3 id="available-families">Available families</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th>Family</th> |
| <th>Response Type</th> |
| <th>Supported Links</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Gaussian</td> |
| <td>Continuous</td> |
| <td>Identity*, Log, Inverse</td> |
| </tr> |
| <tr> |
| <td>Binomial</td> |
| <td>Binary</td> |
| <td>Logit*, Probit, CLogLog</td> |
| </tr> |
| <tr> |
| <td>Poisson</td> |
| <td>Count</td> |
| <td>Log*, Identity, Sqrt</td> |
| </tr> |
| <tr> |
| <td>Gamma</td> |
| <td>Continuous</td> |
| <td>Inverse*, Identity, Log</td> |
| </tr> |
| <tr> |
| <td>Tweedie</td> |
| <td>Zero-inflated continuous</td> |
| <td>Power link function</td> |
| </tr> |
| </tbody> |
| <tfoot><tr><td colspan="3">* Canonical link</td></tr></tfoot> |
| </table> |
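<p>The link names in the table correspond to the functions $g(\mu)$ and inverses $g^{-1}(\eta)$ sketched below in plain Python (hypothetical names, no Spark API). The sketch also includes a power link $g(\mu) = \mu^p$ of the kind used with the Tweedie family, treating $p = 0$ as the log link. In Spark, to our understanding, this exponent is set through the <code>linkPower</code> parameter.</p>

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()

# (g, g_inverse) pairs for the link names listed in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "inverse": (lambda mu: 1 / mu, lambda eta: 1 / eta),
    "logit": (lambda mu: math.log(mu / (1 - mu)), lambda eta: 1 / (1 + math.exp(-eta))),
    "probit": (_STD_NORMAL.inv_cdf, _STD_NORMAL.cdf),
    "cloglog": (lambda mu: math.log(-math.log(1 - mu)), lambda eta: 1 - math.exp(-math.exp(eta))),
    "sqrt": (math.sqrt, lambda eta: eta * eta),
}

# Canonical link for each family in the table (Tweedie omitted).
CANONICAL = {"gaussian": "identity", "binomial": "logit", "poisson": "log", "gamma": "inverse"}


def power_link(p):
    """Power link g(mu) = mu ** p, with p == 0 treated as the log link."""
    if p == 0:
        return LINKS["log"]
    return (lambda mu: mu ** p, lambda eta: eta ** (1 / p))


g, g_inv = LINKS[CANONICAL["binomial"]]
print(g(0.5), g_inv(0.0))
```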
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example demonstrates training a GLM with a Gaussian response and identity link |
| function and extracting model summary statistics.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="n">dataset</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">glr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GeneralizedLinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setFamily</span><span class="o">(</span><span class="s">"gaussian"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setLink</span><span class="o">(</span><span class="s">"identity"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">glr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for generalized linear regression model</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Intercept: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="k">val</span> <span class="n">summary</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficient Standard Errors: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">coefficientStandardErrors</span><span class="o">.</span><span class="n">mkString</span><span class="o">(</span><span class="s">","</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"T Values: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">tValues</span><span class="o">.</span><span class="n">mkString</span><span class="o">(</span><span class="s">","</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"P Values: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">pValues</span><span class="o">.</span><span class="n">mkString</span><span class="o">(</span><span class="s">","</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Dispersion: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">dispersion</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Null Deviance: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">nullDeviance</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Residual Degree Of Freedom Null: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">residualDegreeOfFreedomNull</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Deviance: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">deviance</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Residual Degree Of Freedom: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">residualDegreeOfFreedom</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"AIC: </span><span class="si">${</span><span class="n">summary</span><span class="o">.</span><span class="n">aic</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Deviance Residuals: "</span><span class="o">)</span> |
| <span class="n">summary</span><span class="o">.</span><span class="n">residuals</span><span class="o">().</span><span class="n">show</span><span class="o">()</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">);</span> |
| |
| <span class="n">GeneralizedLinearRegression</span> <span class="n">glr</span> <span class="o">=</span> <span class="k">new</span> <span class="n">GeneralizedLinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setFamily</span><span class="o">(</span><span class="s">"gaussian"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLink</span><span class="o">(</span><span class="s">"identity"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="n">GeneralizedLinearRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">glr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for generalized linear regression model</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="n">GeneralizedLinearRegressionTrainingSummary</span> <span class="n">summary</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficient Standard Errors: "</span> |
| <span class="o">+</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">coefficientStandardErrors</span><span class="o">()));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"T Values: "</span> <span class="o">+</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">tValues</span><span class="o">()));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"P Values: "</span> <span class="o">+</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">pValues</span><span class="o">()));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Dispersion: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">dispersion</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Null Deviance: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">nullDeviance</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Residual Degree Of Freedom Null: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">residualDegreeOfFreedomNull</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Deviance: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">deviance</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Residual Degree Of Freedom: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">residualDegreeOfFreedom</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"AIC: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">aic</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Deviance Residuals: "</span><span class="o">);</span> |
| <span class="n">summary</span><span class="o">.</span><span class="na">residuals</span><span class="o">().</span><span class="na">show</span><span class="o">();</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GeneralizedLinearRegression</span> |
| |
| <span class="c1"># Load training data</span> |
| <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span>\ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">glr</span> <span class="o">=</span> <span class="n">GeneralizedLinearRegression</span><span class="p">(</span><span class="n">family</span><span class="o">=</span><span class="s2">"gaussian"</span><span class="p">,</span> <span class="n">link</span><span class="o">=</span><span class="s2">"identity"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">glr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for generalized linear regression model</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># Summarize the model over the training set and print out some metrics</span> |
| <span class="n">summary</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficient Standard Errors: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">coefficientStandardErrors</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"T Values: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">tValues</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"P Values: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">pValues</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Dispersion: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">dispersion</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Null Deviance: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">nullDeviance</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Residual Degree Of Freedom Null: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">residualDegreeOfFreedomNull</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Deviance: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">deviance</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Residual Degree Of Freedom: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">residualDegreeOfFreedom</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"AIC: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="o">.</span><span class="n">aic</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Deviance Residuals: "</span><span class="p">)</span> |
| <span class="n">summary</span><span class="o">.</span><span class="n">residuals</span><span class="p">()</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/generalized_linear_regression_example.py" in the Spark repo.</small></div> |
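<p>The plain-Python sketch below computes counterparts of the printed summary fields for the simplest Gaussian/identity case, so you can see what each field measures. It assumes one feature plus an intercept, no weights and no regularization, which differs from the example above (that example sets <code>regParam=0.3</code>). It follows the textbook, R-style definitions. Whether Spark's summary reduces to exactly these values in this case is an assumption. The helper name <code>gaussian_glm_summary</code> is hypothetical and is not part of the Spark API.</p>

```python
import math


def gaussian_glm_summary(x, y):
    """Hypothetical sketch: Gaussian/identity GLM summary for y ~ b0 + b1 * x.

    Assumes unweighted data, an intercept and no regularization, in which
    case the GLM fit coincides with ordinary least squares.
    """
    n = len(x)
    x_bar = sum(x) / n
    y_bar = sum(y) / n
    sxx = sum((xi - x_bar) ** 2 for xi in x)
    sxy = sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y))

    slope = sxy / sxx
    intercept = y_bar - slope * x_bar
    rank = 2  # number of fitted coefficients: slope and intercept

    # For the Gaussian family, the deviance is the residual sum of squares.
    deviance = sum((yi - (intercept + slope * xi)) ** 2 for xi, yi in zip(x, y))
    # The null model fits only an intercept, i.e. it always predicts mean(y).
    null_deviance = sum((yi - y_bar) ** 2 for yi in y)

    residual_dof = n - rank
    residual_dof_null = n - 1  # the null model still has an intercept
    dispersion = deviance / residual_dof  # estimated error variance

    # Standard errors from dispersion * (X^T X)^-1 for simple regression.
    se_slope = math.sqrt(dispersion / sxx)
    se_intercept = math.sqrt(dispersion * (1.0 / n + x_bar ** 2 / sxx))

    # R-style Gaussian AIC: -2 * log-likelihood at the MLE of the variance,
    # plus 2 for the variance parameter and 2 per fitted coefficient.
    aic = n * (math.log(2 * math.pi * deviance / n) + 1) + 2 + 2 * rank

    # The intercept is listed last, mirroring the Spark summary arrays.
    return {
        "coefficients": (slope, intercept),
        "standard_errors": (se_slope, se_intercept),
        "t_values": (slope / se_slope, intercept / se_intercept),
        "dispersion": dispersion,
        "null_deviance": null_deviance,
        "residual_dof_null": residual_dof_null,
        "deviance": deviance,
        "residual_dof": residual_dof,
        "aic": aic,
    }


summary = gaussian_glm_summary([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 5.0, 8.0])
for name, value in summary.items():
    print(name, value)
```

<p>With regularization, as in the Spark example, the coefficients shrink, so the deviance and derived statistics differ from this unpenalized sketch.</p>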
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.glm.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span>training <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| <span class="c1"># Fit a generalized linear model of family "gaussian" with spark.glm</span> |
| df_list <span class="o"><-</span> randomSplit<span class="p">(</span>training<span class="p">,</span> <span class="kt">c</span><span class="p">(</span><span class="m">7</span><span class="p">,</span> <span class="m">3</span><span class="p">),</span> <span class="m">2</span><span class="p">)</span> |
| gaussianDF <span class="o"><-</span> df_list<span class="p">[[</span><span class="m">1</span><span class="p">]]</span> |
| gaussianTestDF <span class="o"><-</span> df_list<span class="p">[[</span><span class="m">2</span><span class="p">]]</span> |
| gaussianGLM <span class="o"><-</span> spark.glm<span class="p">(</span>gaussianDF<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> family <span class="o">=</span> <span class="s">"gaussian"</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>gaussianGLM<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| gaussianPredictions <span class="o"><-</span> predict<span class="p">(</span>gaussianGLM<span class="p">,</span> gaussianTestDF<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>gaussianPredictions<span class="p">)</span> |
| |
| <span class="c1"># Fit a generalized linear model with glm (R-compliant)</span> |
| gaussianGLM2 <span class="o"><-</span> glm<span class="p">(</span>label <span class="o">~</span> features<span class="p">,</span> gaussianDF<span class="p">,</span> family <span class="o">=</span> <span class="s">"gaussian"</span><span class="p">)</span> |
| <span class="kp">summary</span><span class="p">(</span>gaussianGLM2<span class="p">)</span> |
| |
| <span class="c1"># Fit a generalized linear model of family "binomial" with spark.glm</span> |
| training2 <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training2 <span class="o"><-</span> <span class="kp">transform</span><span class="p">(</span>training2<span class="p">,</span> label <span class="o">=</span> cast<span class="p">(</span>training2<span class="o">$</span>label <span class="o">></span> <span class="m">1</span><span class="p">,</span> <span class="s">"integer"</span><span class="p">))</span> |
| df_list2 <span class="o"><-</span> randomSplit<span class="p">(</span>training2<span class="p">,</span> <span class="kt">c</span><span class="p">(</span><span class="m">7</span><span class="p">,</span> <span class="m">3</span><span class="p">),</span> <span class="m">2</span><span class="p">)</span> |
| binomialDF <span class="o"><-</span> df_list2<span class="p">[[</span><span class="m">1</span><span class="p">]]</span> |
| binomialTestDF <span class="o"><-</span> df_list2<span class="p">[[</span><span class="m">2</span><span class="p">]]</span> |
| binomialGLM <span class="o"><-</span> spark.glm<span class="p">(</span>binomialDF<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> family <span class="o">=</span> <span class="s">"binomial"</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>binomialGLM<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| binomialPredictions <span class="o"><-</span> predict<span class="p">(</span>binomialGLM<span class="p">,</span> binomialTestDF<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>binomialPredictions<span class="p">)</span> |
| |
| <span class="c1"># Fit a generalized linear model of family "tweedie" with spark.glm</span> |
| training3 <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| tweedieDF <span class="o"><-</span> <span class="kp">transform</span><span class="p">(</span>training3<span class="p">,</span> label <span class="o">=</span> training3<span class="o">$</span>label <span class="o">*</span> <span class="kp">exp</span><span class="p">(</span>randn<span class="p">(</span><span class="m">10</span><span class="p">)))</span> |
| tweedieGLM <span class="o"><-</span> spark.glm<span class="p">(</span>tweedieDF<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> family <span class="o">=</span> <span class="s">"tweedie"</span><span class="p">,</span> |
| var.power <span class="o">=</span> <span class="m">1.2</span><span class="p">,</span> link.power <span class="o">=</span> <span class="m">0</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>tweedieGLM<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/glm.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
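<p>Each example above pairs a family with a link function. The Java and Python examples use the Gaussian family with the identity link. The R example fits a Tweedie model with <code>var.power = 1.2</code> and <code>link.power = 0</code>. The plain-Python sketch below shows what these settings mean mathematically. The Tweedie variance function is V(&mu;) = &mu;<sup>p</sup>, and the power link is g(&mu;) = &mu;<sup>q</sup>, where q = 0 is read as the log link. This is an illustration, not Spark code, and the function names are hypothetical.</p>

```python
import math


def tweedie_variance(mu, variance_power):
    """Tweedie variance function V(mu) = mu ** p (up to the dispersion factor)."""
    return mu ** variance_power


def power_link(mu, link_power):
    """Power link g(mu) = mu ** q, where q = 0 is taken to mean log(mu)."""
    if link_power == 0:
        return math.log(mu)
    return mu ** link_power


def inverse_power_link(eta, link_power):
    """Inverse of power_link: maps the linear predictor eta back to the mean.

    Assumes eta lies in the valid domain of the chosen link.
    """
    if link_power == 0:
        return math.exp(eta)
    return eta ** (1.0 / link_power)


# p = 0, 1 and 2 give the Gaussian, Poisson and Gamma variance functions;
# 1 < p < 2 (such as 1.2 in the R example) is a compound Poisson-gamma model.
for p in (0.0, 1.0, 2.0, 1.2):
    print(p, tweedie_variance(3.0, p))

# q = 1 is the identity link, q = 0 (as in the R example) is the log link.
eta = 0.5
print(inverse_power_link(eta, 1.0))  # identity link: the mean equals eta
print(inverse_power_link(eta, 0.0))  # log link: the mean is exp(eta)
```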
| |
| <h2 id="decision-tree-regression">Decision tree regression</h2> |
| |
<p>Decision trees are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found later on this page, in the <a href="#decision-trees">section on decision trees</a>.</p>
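<p>For regression, each split in a tree is chosen to reduce the variance of the label. In <code>spark.ml</code>, variance is the impurity measure used for regression trees. The plain-Python sketch below applies this idea to a single continuous feature. It tries each candidate threshold and keeps the one with the largest size-weighted variance reduction. It is a simplified illustration. Spark bins continuous features into at most <code>maxBins</code> candidate thresholds and evaluates splits in a distributed way, and this sketch models neither. The function names are hypothetical.</p>

```python
def variance(values):
    """Population variance; 0.0 for an empty list."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def best_split(xs, ys):
    """Return (threshold, gain) maximizing variance reduction for x <= threshold.

    Candidate thresholds are the distinct feature values except the largest,
    so both children are always non-empty. Returns (None, 0.0) if no split
    reduces the variance.
    """
    n = len(ys)
    parent = variance(ys)
    best = (None, 0.0)
    for threshold in sorted(set(xs))[:-1]:
        left = [y for x, y in zip(xs, ys) if x <= threshold]
        right = [y for x, y in zip(xs, ys) if x > threshold]
        # Impurity after the split is the size-weighted child variance.
        child = (len(left) * variance(left) + len(right) * variance(right)) / n
        gain = parent - child
        if gain > best[1]:
            best = (threshold, gain)
    return best


xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
ys = [5.0, 6.0, 5.5, 20.0, 21.0, 19.5]
print(best_split(xs, ys))  # the best threshold separates the two label clusters
```

<p>A full tree applies this search recursively to each child, over all features, until a stopping rule such as <code>maxDepth</code> is reached.</p>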
| |
| <p><strong>Examples</strong></p> |
| |
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use a feature transformer (<code>VectorIndexer</code>) to index categorical features, adding metadata to the <code>DataFrame</code> that the decision tree algorithm can recognize.</p>
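<p>The plain-Python sketch below illustrates the categorical-feature detection performed by <code>VectorIndexer</code>. A feature with at most <code>maxCategories</code> distinct values is treated as categorical, and its values are replaced by category indices. Other features are left unchanged. This is a simplified illustration, not Spark's implementation. In particular, assigning indices in sorted value order is an assumption of this sketch, and the function name is hypothetical. Spark records the resulting categorical information as column metadata. Here the category maps are simply returned.</p>

```python
def index_categorical(rows, max_categories):
    """Hypothetical sketch of VectorIndexer-style indexing over feature vectors.

    Returns (indexed_rows, category_maps), where category_maps maps a feature
    position to its {original value: index} dictionary.
    """
    num_features = len(rows[0])
    category_maps = {}
    for j in range(num_features):
        distinct = sorted({row[j] for row in rows})
        # At most max_categories distinct values: treat the feature as categorical.
        if len(distinct) <= max_categories:
            category_maps[j] = {value: i for i, value in enumerate(distinct)}

    indexed_rows = []
    for row in rows:
        # Replace categorical values by their index; keep continuous values.
        indexed_rows.append([
            float(category_maps[j][v]) if j in category_maps else v
            for j, v in enumerate(row)
        ])
    return indexed_rows, category_maps


rows = [
    [0.0, 1.5, 10.0],
    [1.0, 2.5, 10.0],
    [0.0, 3.5, 20.0],
    [1.0, 4.5, 30.0],
    [0.0, 5.5, 20.0],
]
indexed, maps = index_categorical(rows, max_categories=4)
print(maps)  # feature 1 has 5 distinct values, so it stays continuous
print(indexed)
```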
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Here, we treat features with > 4 distinct values as continuous.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and tree in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Root Mean Squared Error (RMSE) on test data = </span><span class="si">$rmse</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeRegressionModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned regression tree model:\n </span><span class="si">${</span><span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="n">DecisionTreeRegressor</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="n">DecisionTreeRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and tree in a Pipeline.</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="n">DecisionTreeRegressionModel</span> <span class="n">treeModel</span> <span class="o">=</span> |
| <span class="o">(</span><span class="n">DecisionTreeRegressionModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">DecisionTreeRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># We specify maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a DecisionTree model.</span> |
| <span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and tree in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"label"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="c1"># Printing the model shows a summary only; treeModel.toDebugString returns the full tree.</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_regression_example.py" in the Spark repo.</small></div> |
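<p>To make the evaluation steps concrete, here is a minimal pure-Python sketch (no Spark required) of what the example does around the model. First, it randomly assigns each row to a split in proportion to the weights passed to <code>randomSplit([0.7, 0.3])</code>. Then it computes the root mean squared error that <code>RegressionEvaluator</code> reports for <code>metricName="rmse"</code>. The helper names <code>random_split</code> and <code>rmse</code> are hypothetical. The per-row sampling illustrates the idea and is not Spark's exact implementation. As with Spark's <code>randomSplit</code>, the split sizes are only approximately 70/30.</p>

```python
import bisect
import itertools
import math
import random


def random_split(rows, weights, seed=0):
    """Assign each row to exactly one split with probability proportional to its weight.

    Hypothetical helper illustrating the idea behind randomSplit; not Spark's code.
    """
    total = sum(weights)
    # Cumulative normalized weights, e.g. [0.7, 1.0] for weights [0.7, 0.3].
    bounds = list(itertools.accumulate(w / total for w in weights))
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        k = bisect.bisect_right(bounds, rng.random())
        # Guard against floating-point rounding leaving the last bound just below 1.0.
        splits[min(k, len(splits) - 1)].append(row)
    return splits


def rmse(predictions, labels):
    """Root mean squared error: sqrt(mean((prediction - label)^2))."""
    squared = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    return math.sqrt(sum(squared) / len(squared))
```

<p>For example, <code>train, test = random_split(rows, [0.7, 0.3], seed=1)</code> places every row in exactly one of the two lists. <code>rmse(predictions, labels)</code> is 0.0 only when every prediction matches its label.</p>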
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.decisionTree.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
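| <span class="c1"># For brevity, the same DataFrame is used for both training and prediction.</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |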
| |
| <span class="c1"># Fit a DecisionTree regression model with spark.decisionTree</span> |
| model <span class="o"><-</span> spark.decisionTree<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"regression"</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/decisionTree.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="random-forest-regression">Random forest regression</h2> |
| |
| <p>Random forests are a popular family of classification and regression methods. |
| More information about the <code>spark.ml</code> implementation can be found later in the <a href="#random-forests">section on random forests</a>.</p> |
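<p>The core idea behind random forest regression can be illustrated with a toy sketch. Train several trees, each on a bootstrap sample (rows drawn with replacement), and predict by averaging the trees' outputs. This is a deliberate simplification, not the <code>spark.ml</code> implementation. Each "tree" here is a single-split regression stump on one numeric feature. The sketch also omits the per-node feature subsampling that Spark controls with <code>featureSubsetStrategy</code>. The names <code>fit_stump</code>, <code>fit_forest</code>, and <code>predict_forest</code> are hypothetical.</p>

```python
import random
from statistics import mean


def fit_stump(xs, ys):
    """Fit a depth-1 regression tree on one feature by minimizing squared error."""
    best = None
    # Candidate thresholds: every distinct value except the largest, so both sides are non-empty.
    for t in sorted(set(xs))[:-1]:
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        lm, rm = mean(left), mean(right)
        sse = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    if best is None:
        # All x values identical: no split is possible, so predict the overall mean.
        overall = mean(ys)
        return lambda x: overall
    _, t, lm, rm = best
    return lambda x: lm if x <= t else rm


def fit_forest(xs, ys, num_trees=10, seed=42):
    """Train num_trees stumps, each on a bootstrap sample of the rows (hypothetical helper)."""
    rng = random.Random(seed)
    n = len(xs)
    trees = []
    for _ in range(num_trees):
        idx = [rng.randrange(n) for _ in range(n)]  # sample n rows with replacement
        trees.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))
    return trees


def predict_forest(trees, x):
    """Regression forests average the predictions of their trees."""
    return mean(tree(x) for tree in trees)
```

<p>With <code>xs = [1, 2, 3, 4, 5, 6]</code> and <code>ys = [1, 1, 1, 10, 10, 10]</code>, the forest's prediction at <code>1</code> is lower than its prediction at <code>6</code>, because every stump that sees both groups splits between them. Averaging many such trees reduces the variance of any single tree.</p>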
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p> |
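<p>The categorical-feature detection that <code>VectorIndexer</code> performs with <code>maxCategories=4</code> can be sketched in plain Python. A feature with at most <code>maxCategories</code> distinct values is treated as categorical, and its values are replaced by 0-based category indices. Every other feature is left as a continuous value. This is a simplified sketch. The helper names are hypothetical, and rows are plain lists instead of <code>Vector</code>s. Indices are assigned in sorted value order, an ordering chosen for this sketch rather than one Spark promises. The sketch does not attach the <code>DataFrame</code> metadata that the real transformer adds.</p>

```python
def find_categorical_features(rows, max_categories=4):
    """Map feature position -> {raw value: category index} for features deemed categorical."""
    category_maps = {}
    for j in range(len(rows[0])):
        distinct = sorted({row[j] for row in rows})
        # Features with more than max_categories distinct values stay continuous.
        if len(distinct) <= max_categories:
            category_maps[j] = {value: index for index, value in enumerate(distinct)}
    return category_maps


def index_row(row, category_maps):
    """Replace categorical values by their indices and pass continuous values through."""
    return [float(category_maps[j][v]) if j in category_maps else v
            for j, v in enumerate(row)]
```

<p>For example, consider five rows whose second feature takes five different values. Only the first and third features are indexed, and the second feature is passed through unchanged.</p>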
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">RandomForestRegressionModel</span><span class="o">,</span> <span class="nc">RandomForestRegressor</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and forest in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Root Mean Squared Error (RMSE) on test data = </span><span class="si">$rmse</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestRegressionModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned regression forest model:\n</span><span class="si">${</span><span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/RandomForestRegressor.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="n">RandomForestRegressor</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RandomForestRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and forest in a Pipeline</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error</span> |
| <span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="n">RandomForestRegressionModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a RandomForest model.</span> |
| <span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and forest in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"label"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c1"># summary only</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/random_forest_regressor_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.randomForest.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit a random forest regression model with spark.randomForest</span> |
| model <span class="o"><-</span> spark.randomForest<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"regression"</span><span class="p">,</span> numTrees <span class="o">=</span> <span class="m">10</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/randomForest.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="gradient-boosted-tree-regression">Gradient-boosted tree regression</h2> |
| |
| <p>Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. |
| More information about the <code>spark.ml</code> implementation can be found in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p> |
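<p>The boosting idea can be illustrated with a toy sketch in plain Python (no Spark). It is not the <code>spark.ml</code> implementation. It assumes squared-error loss, a single numeric feature, and depth-1 trees ("stumps"), and the names <code>fit_stump</code>, <code>fit_gbt</code>, and <code>predict_gbt</code> are hypothetical. Each iteration fits a stump to the current residuals, which are the negative gradient of squared error up to a constant factor. The stump is then added to the ensemble, scaled by a step size.</p>

```python
"""Toy gradient boosting for regression (squared-error loss, 1-D stumps).

Illustrative sketch only; not Spark's GBTRegressor implementation.
"""


def fit_stump(xs, targets):
    """Least-squares depth-1 tree: returns (threshold, left_value, right_value)."""
    best = None
    for threshold in sorted(set(xs)):
        left = [t for x, t in zip(xs, targets) if x <= threshold]
        right = [t for x, t in zip(xs, targets) if x > threshold]
        if not right:
            continue  # splitting at the maximum leaves one side empty
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = sum((t - left_mean) ** 2 for t in left)
        sse += sum((t - right_mean) ** 2 for t in right)
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    if best is None:  # all xs identical: predict the mean everywhere
        mean = sum(targets) / len(targets)
        return (float("inf"), mean, mean)
    return best[1:]


def predict_stump(stump, x):
    threshold, left_value, right_value = stump
    return left_value if x <= threshold else right_value


def fit_gbt(xs, ys, max_iter=10, step_size=0.1):
    """Start from the mean, then repeatedly fit a stump to the residuals."""
    base = sum(ys) / len(ys)
    stumps = []
    preds = [base] * len(ys)
    for _ in range(max_iter):
        # The negative gradient of squared error is proportional to the residual.
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + step_size * predict_stump(stump, x) for p, x in zip(preds, xs)]
    return base, stumps, step_size


def predict_gbt(model, x):
    base, stumps, step_size = model
    return base + step_size * sum(predict_stump(s, x) for s in stumps)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 1.0, 3.0, 3.0]
one_step = fit_gbt(xs, ys, max_iter=1, step_size=1.0)
print(predict_gbt(one_step, 1.0), predict_gbt(one_step, 4.0))  # 1.0 3.0
```

<p>With <code>step_size=1.0</code>, a single stump already fits this toy data exactly. Smaller step sizes need more iterations, and they are commonly used as a form of regularization.</p>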
| |
| <p><strong>Examples</strong></p> |
| |
| <p>Note: For this example dataset, <code>GBTRegressor</code> needs only one iteration, but that is not |
| true in general.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">GBTRegressionModel</span><span class="o">,</span> <span class="nc">GBTRegressor</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and GBT in a Pipeline.</span> |
| <span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Root Mean Squared Error (RMSE) on test data = </span><span class="si">$rmse</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTRegressionModel</span><span class="o">]</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Learned regression GBT model:\n </span><span class="si">${</span><span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GBTRegressor.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="n">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="n">GBTRegressor</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="n">GBTRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and GBT in a Pipeline.</span> |
| <span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Pipeline</span><span class="o">().</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="n">GBTRegressionModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GBTRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them.</span> |
| <span class="c1"># Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s2">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing)</span> |
| <span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a GBT model.</span> |
| <span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s2">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and GBT in a Pipeline</span> |
| <span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions.</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="s2">"label"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s2">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s2">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s2">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s2">"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c1"># summary only</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.gbt.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit a GBT regression model with spark.gbt</span> |
| model <span class="o"><-</span> spark.gbt<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> <span class="s">"regression"</span><span class="p">,</span> maxIter <span class="o">=</span> <span class="m">10</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/gbt.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="survival-regression">Survival regression</h2> |
| |
| <p>In <code>spark.ml</code>, we implement the <a href="https://en.wikipedia.org/wiki/Accelerated_failure_time_model">Accelerated failure time (AFT)</a> |
| model, which is a parametric survival regression model for censored data. |
| It describes a model for the log of survival time, so it’s often called a |
| log-linear model for survival analysis. Unlike a |
| <a href="https://en.wikipedia.org/wiki/Proportional_hazards_model">Proportional hazards</a> model, |
| which is designed for the same purpose, the AFT model is easier to parallelize |
| because each instance contributes to the objective function independently.</p> |
| |
| <p>Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of |
| subjects $i = 1, \ldots, n$, with possible right-censoring, |
| the likelihood function under the AFT model is given as: |
| <code>\[ |
| L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} |
| \]</code> |
| where $\delta_{i}$ indicates whether the event has occurred, i.e., $\delta_{i}=1$ if the observation is uncensored and $\delta_{i}=0$ if it is censored. |
| Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function |
| takes the form: |
| <code>\[ |
| \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] |
| \]</code> |
| where $S_{0}(\epsilon_{i})$ is the baseline survivor function |
| and $f_{0}(\epsilon_{i})$ is the corresponding density function.</p> |
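<p>As a rough sketch (not Spark's implementation), the log-likelihood above can be evaluated directly in plain Python for any baseline. The helper names below are hypothetical. The baseline plugged in is the extreme-value baseline of the Weibull case described next. The four rows are illustrative values taken from the Scala example below (label, censor, features), and no intercept term is included. Because $\iota$ is a sum of per-instance terms, evaluating it on separate chunks of the data and adding the results gives the same value. This is the property that makes the AFT objective straightforward to parallelize.</p>

```python
import math


def aft_log_likelihood(beta, log_sigma, rows, log_f0, log_s0):
    """Generic AFT log-likelihood iota(beta, sigma).

    rows: (t, delta, x) tuples with t > 0, delta = 1.0 if the event occurred
    (uncensored) and 0.0 if right-censored, and x a list of covariates.
    log_f0 / log_s0: log of the baseline density / survivor function.
    """
    sigma = math.exp(log_sigma)
    total = 0.0
    for t, delta, x in rows:
        eps = (math.log(t) - sum(b * xj for b, xj in zip(beta, x))) / sigma
        total += delta * (-log_sigma + log_f0(eps)) + (1.0 - delta) * log_s0(eps)
    return total


# Extreme-value baseline (Weibull lifetimes): log S0(e) = -exp(e), log f0(e) = e - exp(e).
def log_s0(e):
    return -math.exp(e)


def log_f0(e):
    return e - math.exp(e)


# Illustrative rows (label, censor, features); no intercept term.
rows = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
]
beta, log_sigma = [0.1, -0.2], 0.3  # arbitrary parameter values

full = aft_log_likelihood(beta, log_sigma, rows, log_f0, log_s0)
# Per-instance terms are independent, so chunk-wise partial sums add up to the total.
chunked = sum(
    aft_log_likelihood(beta, log_sigma, chunk, log_f0, log_s0)
    for chunk in (rows[:2], rows[2:])
)
print(full, chunked)
```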
| |
| <p>The most commonly used AFT model is based on the Weibull distribution of the survival time. |
| The Weibull distribution for lifetime corresponds to the extreme value distribution for the |
| log of the lifetime, and the $S_{0}(\epsilon_{i})$ function is: |
| <code>\[ |
| S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) |
| \]</code> |
| and the $f_{0}(\epsilon_{i})$ function is: |
| <code>\[ |
| f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) |
| \]</code> |
| The log-likelihood function for the AFT model with a Weibull distribution of lifetime is: |
| <code>\[ |
| \iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] |
| \]</code> |
| Because minimizing the negative log-likelihood is equivalent to maximizing the likelihood, |
| the loss function we optimize is $-\iota(\beta,\sigma)$. |
| The gradient functions for $\beta$ and $\log\sigma$ respectively are: |
| <code>\[ |
| \frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} |
| \]</code> |
| <code>\[ |
| \frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] |
| \]</code></p> |
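<p>The closed-form Weibull loss and the gradients above can be sanity-checked numerically. The sketch below is plain Python with hypothetical helper names and illustrative rows taken from the Scala example below. It runs two checks:</p>
<ul>
  <li>it compares the closed form against the general formula with $S_{0}$ and $f_{0}$ plugged in;</li>
  <li>it compares the analytic gradients with central finite differences.</li>
</ul>
<p>The sketch parameterizes the scale as $\log\sigma$, matching the gradient formulas, and fits no intercept.</p>

```python
import math

# Illustrative rows (label, censor, features) and arbitrary parameter values.
rows = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
]
beta, log_sigma = [0.1, -0.2], 0.3


def eps_of(t, x, beta, log_sigma):
    """epsilon_i = (log t_i - x_i' beta) / sigma."""
    return (math.log(t) - sum(b * xj for b, xj in zip(beta, x))) / math.exp(log_sigma)


def loss_closed_form(beta, log_sigma, rows):
    """-iota = sum[delta*log(sigma) - delta*eps + exp(eps)]."""
    return sum(
        delta * log_sigma - delta * eps_of(t, x, beta, log_sigma)
        + math.exp(eps_of(t, x, beta, log_sigma))
        for t, delta, x in rows
    )


def loss_from_general_form(beta, log_sigma, rows):
    """-iota with S0(e) = exp(-e^e) and f0(e) = e^e * exp(-e^e) plugged in."""
    total = 0.0
    for t, delta, x in rows:
        e = eps_of(t, x, beta, log_sigma)
        s0 = math.exp(-math.exp(e))
        f0 = math.exp(e) * s0
        total += -delta * log_sigma + delta * math.log(f0) + (1.0 - delta) * math.log(s0)
    return -total


def analytic_gradient(beta, log_sigma, rows):
    """Gradients of -iota w.r.t. beta and log(sigma), from the formulas above."""
    sigma = math.exp(log_sigma)
    grad_beta = [0.0] * len(beta)
    grad_log_sigma = 0.0
    for t, delta, x in rows:
        e = eps_of(t, x, beta, log_sigma)
        w = delta - math.exp(e)
        for j, xj in enumerate(x):
            grad_beta[j] += w * xj / sigma
        grad_log_sigma += delta + w * e
    return grad_beta, grad_log_sigma


def numeric_gradient(beta, log_sigma, rows, h=1e-6):
    """Central finite differences of the closed-form loss."""
    grad_beta = []
    for j in range(len(beta)):
        up = [b + h if k == j else b for k, b in enumerate(beta)]
        down = [b - h if k == j else b for k, b in enumerate(beta)]
        diff = loss_closed_form(up, log_sigma, rows) - loss_closed_form(down, log_sigma, rows)
        grad_beta.append(diff / (2 * h))
    diff = loss_closed_form(beta, log_sigma + h, rows) - loss_closed_form(beta, log_sigma - h, rows)
    return grad_beta, diff / (2 * h)


print(loss_closed_form(beta, log_sigma, rows), loss_from_general_form(beta, log_sigma, rows))
print(analytic_gradient(beta, log_sigma, rows))
print(numeric_gradient(beta, log_sigma, rows))
```

<p>The closed form is also the numerically safer of the two losses. For large $\epsilon_{i}$, <code>math.exp(-math.exp(e))</code> underflows to zero, and taking its log would then fail.</p>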
| |
| <p>The AFT model can be formulated as a convex optimization problem, |
| i.e., the task of finding a minimizer of the convex function $-\iota(\beta,\sigma)$ |
| that depends on the coefficient vector $\beta$ and the log of the scale parameter, $\log\sigma$. |
| The optimization algorithm underlying the implementation is L-BFGS. |
| The implementation matches the result from the |
| <a href="https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html">survreg</a> function in R’s survival package.</p> |
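<p>Spark minimizes $-\iota(\beta,\sigma)$ with L-BFGS. To keep the example short, the sketch below swaps in plain gradient descent with a backtracking (Armijo) line search on the same Weibull loss. It illustrates minimizing this objective; it is not Spark's optimizer, and its output is not expected to match Spark's or <code>survreg</code>'s fitted values exactly. The helper names and the illustrative rows (taken from the Scala example below) are assumptions. No intercept is fitted; appending a constant 1 to each feature vector would add one.</p>

```python
import math

# Illustrative rows (label, censor, features); no intercept is fitted.
rows = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
]


def loss_and_grad(params, rows):
    """Weibull AFT loss -iota and its gradient; params = beta + [log_sigma]."""
    *beta, log_sigma = params
    loss = 0.0
    grad = [0.0] * len(params)
    try:
        inv_sigma = math.exp(-log_sigma)
        for t, delta, x in rows:
            eps = (math.log(t) - sum(b * xj for b, xj in zip(beta, x))) * inv_sigma
            e_eps = math.exp(eps)
            loss += delta * log_sigma - delta * eps + e_eps
            w = delta - e_eps
            for j, xj in enumerate(x):
                grad[j] += w * xj * inv_sigma
            grad[-1] += delta + w * eps
    except OverflowError:
        return math.inf, [0.0] * len(params)  # trial step too large
    return loss, grad


def minimize(rows, dim, max_iter=200, c=1e-4):
    """Gradient descent with Armijo backtracking: a simple stand-in for L-BFGS."""
    params = [0.0] * (dim + 1)  # beta = 0, log_sigma = 0 (sigma = 1)
    loss, grad = loss_and_grad(params, rows)
    for _ in range(max_iter):
        g2 = sum(g * g for g in grad)
        if g2 < 1e-18:
            break  # gradient is (numerically) zero
        step = 1.0
        for _ in range(60):
            trial = [p - step * g for p, g in zip(params, grad)]
            trial_loss, trial_grad = loss_and_grad(trial, rows)
            if trial_loss <= loss - c * step * g2:  # sufficient decrease
                break
            step *= 0.5
        else:
            break  # no acceptable step found
        params, loss, grad = trial, trial_loss, trial_grad
    return params, loss


initial_loss, _ = loss_and_grad([0.0, 0.0, 0.0], rows)
params, final_loss = minimize(rows, dim=2)
print(initial_loss, final_loss, params)
```

<p>L-BFGS builds a curvature estimate from recent gradients, so in practice it typically needs far fewer iterations than plain gradient descent on problems like this one.</p>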
| |
| <blockquote> |
| <p>When fitting an AFTSurvivalRegressionModel without an intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior differs from R’s <code>survival::survreg</code>.</p> |
| </blockquote> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span> |
| |
| <span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> |
| <span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"censor"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">quantileProbabilities</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">aft</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">AFTSurvivalRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Coefficients: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Intercept: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Scale: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">scale</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="kc">false</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.VectorUDT</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.DataTypes</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.Metadata</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructField</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructType</span><span class="o">;</span> |
| |
| <span class="n">List</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> |
| <span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> |
| <span class="o">);</span> |
| <span class="n">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="n">StructType</span><span class="o">(</span><span class="k">new</span> <span class="n">StructField</span><span class="o">[]{</span> |
| <span class="k">new</span> <span class="n">StructField</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> |
| <span class="k">new</span> <span class="n">StructField</span><span class="o">(</span><span class="s">"censor"</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> |
| <span class="k">new</span> <span class="n">StructField</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="k">new</span> <span class="n">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span> |
| <span class="o">});</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">};</span> |
| <span class="n">AFTSurvivalRegression</span> <span class="n">aft</span> <span class="o">=</span> <span class="k">new</span> <span class="n">AFTSurvivalRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">);</span> |
| |
| <span class="n">AFTSurvivalRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Scale: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">scale</span><span class="o">());</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="kc">false</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.AFTSurvivalRegression">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">AFTSurvivalRegression</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span> |
| |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mf">1.218</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.560</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">2.949</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.346</span><span class="p">,</span> <span class="mf">2.158</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">3.627</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.380</span><span class="p">,</span> <span class="mf">0.231</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">0.273</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.520</span><span class="p">,</span> <span class="mf">1.151</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">4.199</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.795</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="p">))],</span> <span class="p">[</span><span class="s2">"label"</span><span class="p">,</span> <span class="s2">"censor"</span><span class="p">,</span> <span class="s2">"features"</span><span class="p">])</span> |
| <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]</span> |
| <span class="n">aft</span> <span class="o">=</span> <span class="n">AFTSurvivalRegression</span><span class="p">(</span><span class="n">quantileProbabilities</span><span class="o">=</span><span class="n">quantileProbabilities</span><span class="p">,</span> |
| <span class="n">quantilesCol</span><span class="o">=</span><span class="s2">"quantiles"</span><span class="p">)</span> |
| |
| <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients, intercept and scale parameter for AFT survival regression</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Scale: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">scale</span><span class="p">))</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">training</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="n">truncate</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/aft_survival_regression.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.survreg.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Use the ovarian dataset available in R survival package</span> |
| <span class="kn">library</span><span class="p">(</span>survival<span class="p">)</span> |
| |
| <span class="c1"># Fit an accelerated failure time (AFT) survival regression model with spark.survreg</span> |
| ovarianDF <span class="o"><-</span> <span class="kp">suppressWarnings</span><span class="p">(</span>createDataFrame<span class="p">(</span>ovarian<span class="p">))</span> |
| aftDF <span class="o"><-</span> ovarianDF |
| aftTestDF <span class="o"><-</span> ovarianDF |
| aftModel <span class="o"><-</span> spark.survreg<span class="p">(</span>aftDF<span class="p">,</span> Surv<span class="p">(</span>futime<span class="p">,</span> fustat<span class="p">)</span> <span class="o">~</span> ecog_ps <span class="o">+</span> rx<span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>aftModel<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| aftPredictions <span class="o"><-</span> predict<span class="p">(</span>aftModel<span class="p">,</span> aftTestDF<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>aftPredictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/survreg.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="isotonic-regression">Isotonic regression</h2> |
<p><a href="http://en.wikipedia.org/wiki/Isotonic_regression">Isotonic regression</a>
belongs to the family of regression algorithms. Formally, isotonic regression is the following problem:
given a finite set of real numbers <code>$Y = \{y_1, y_2, ..., y_n\}$</code> representing observed responses
and <code>$X = \{x_1, x_2, ..., x_n\}$</code> the unknown response values to be fitted,
find a function that minimizes</p>
| |
| <p><code>\begin{equation} |
| f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 |
| \end{equation}</code></p> |
| |
<p>with respect to a complete order, subject to
<code>$x_1\le x_2\le ...\le x_n$</code>, where <code>$w_i$</code> are positive weights.
The resulting function is called isotonic regression, and it is unique.
It can be viewed as a least squares problem under an order restriction.
Essentially, isotonic regression is a
<a href="http://en.wikipedia.org/wiki/Monotonic_function">monotonic function</a>
that best fits the original data points.</p>
| |
| <p>We implement a |
| <a href="http://doi.org/10.1198/TECH.2010.10111">pool adjacent violators algorithm</a> |
| which uses an approach to |
| <a href="http://doi.org/10.1007/978-3-642-99789-1_10">parallelizing isotonic regression</a>. |
The training input is a DataFrame that contains three columns:
label, features and weight. Additionally, the IsotonicRegression algorithm has one
optional parameter called <code>isotonic</code>, defaulting to true.
This argument specifies whether the isotonic regression is
isotonic (monotonically increasing) or antitonic (monotonically decreasing).</p>
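<p>To illustrate the idea behind pool adjacent violators, here is a minimal sequential sketch in plain Python; it is not Spark's parallel implementation. It assumes the responses are already sorted by feature value, handles the antitonic case by negating the responses, and repeatedly merges adjacent blocks that violate the order, replacing each merged block by its weighted mean (which minimizes the weighted squared error for a single block). The function name and sample values are hypothetical.</p>

<div class="highlight"><pre>
def pool_adjacent_violators(ys, ws, isotonic=True):
    """Weighted least-squares fit under a monotone order constraint.

    Hypothetical helper: ys are responses sorted by feature value, ws are
    positive weights. Returns one fitted value per input response.
    """
    sign = 1.0 if isotonic else -1.0
    blocks = []  # each block: [weighted mean, total weight, point count]
    for y, w in zip(ys, ws):
        blocks.append([sign * y, w, 1])
        # Pool the last two blocks while they violate the increasing order.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            mean2, weight2, count2 = blocks.pop()
            mean1, weight1, count1 = blocks.pop()
            total = weight1 + weight2
            blocks.append([(mean1 * weight1 + mean2 * weight2) / total,
                           total, count1 + count2])
    fitted = []
    for mean, _, count in blocks:
        fitted.extend([sign * mean] * count)
    return fitted


print(pool_adjacent_violators([1.0, 3.0, 2.0, 4.0, 3.5, 5.0], [1.0] * 6))
# [1.0, 2.5, 2.5, 3.75, 3.75, 5.0]
print(pool_adjacent_violators([5.0, 3.0, 4.0, 1.0], [1.0] * 4, isotonic=False))
# [5.0, 3.5, 3.5, 1.0]
</pre></div>

<p>Each pooled block becomes a flat segment of the fitted function, which is why the result is a step-like monotonic sequence rather than a smooth curve.</p>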
| |
| <p>Training returns an IsotonicRegressionModel that can be used to predict |
| labels for both known and unknown features. The result of isotonic regression |
is treated as a piecewise linear function. The rules for prediction are therefore:</p>
| |
| <ul> |
<li>If the prediction input exactly matches a training feature,
then the associated prediction is returned. If there are multiple predictions with the same
feature, one of them is returned; which one is undefined
(same as java.util.Arrays.binarySearch).</li>
<li>If the prediction input is lower or higher than all training features,
then the prediction with the lowest or highest feature is returned, respectively.
If there are multiple predictions with the same feature,
then the lowest or highest is returned, respectively.</li>
<li>If the prediction input falls between two training features, then the prediction is treated
as a piecewise linear function and an interpolated value is calculated from the
predictions of the two closest features. If there are multiple values
with the same feature, then the same rules as in the previous point are used.</li>
| </ul> |
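<p>The three prediction rules can be sketched in plain Python with the standard <code>bisect</code> module, assuming the model is represented by a sorted list of boundaries and a parallel list of predictions (mirroring the <code>boundaries</code> and <code>predictions</code> printed in the examples below). This is an illustrative re-implementation, not Spark's code; the function name and sample values are hypothetical.</p>

<div class="highlight"><pre>
import bisect


def isotonic_predict(boundaries, predictions, x):
    """Predict with a piecewise linear isotonic model (hypothetical helper)."""
    i = bisect.bisect_left(boundaries, x)
    if i != len(boundaries) and boundaries[i] == x:
        return predictions[i]  # exact match with a training feature
    if i == 0:
        return predictions[0]  # below all training features
    if i == len(boundaries):
        return predictions[-1]  # above all training features
    # Between two features: interpolate linearly between the closest ones.
    x0, x1 = boundaries[i - 1], boundaries[i]
    y0, y1 = predictions[i - 1], predictions[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


boundaries = [0.0, 1.0, 3.0]
predictions = [1.0, 2.0, 4.0]
for x in [-1.0, 0.5, 1.0, 2.0, 5.0]:
    print(x, isotonic_predict(boundaries, predictions, x))
</pre></div>

<p>Duplicate boundaries are resolved by whichever positions <code>bisect_left</code> lands on, which may not match Spark's tie handling exactly.</p>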
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.IsotonicRegression"><code>IsotonicRegression</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegression</span> |
| |
| <span class="c1">// Loads data.</span> |
| <span class="k">val</span> <span class="n">dataset</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Trains an isotonic regression model.</span> |
| <span class="k">val</span> <span class="n">ir</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IsotonicRegression</span><span class="o">()</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">ir</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">)</span> |
| |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Boundaries in increasing order: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">boundaries</span><span class="si">}</span><span class="s">\n"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Predictions associated with the boundaries: </span><span class="si">${</span><span class="n">model</span><span class="o">.</span><span class="n">predictions</span><span class="si">}</span><span class="s">\n"</span><span class="o">)</span> |
| |
| <span class="c1">// Makes predictions.</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">dataset</span><span class="o">).</span><span class="n">show</span><span class="o">()</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/IsotonicRegression.html"><code>IsotonicRegression</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// Loads data.</span> |
| <span class="n">Dataset</span><span class="o"><</span><span class="n">Row</span><span class="o">></span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Trains an isotonic regression model.</span> |
| <span class="n">IsotonicRegression</span> <span class="n">ir</span> <span class="o">=</span> <span class="k">new</span> <span class="n">IsotonicRegression</span><span class="o">();</span> |
| <span class="n">IsotonicRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">ir</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">);</span> |
| |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Boundaries in increasing order: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">boundaries</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Predictions associated with the boundaries: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">predictions</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| |
| <span class="c1">// Makes predictions.</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">dataset</span><span class="o">).</span><span class="na">show</span><span class="o">();</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.IsotonicRegression"><code>IsotonicRegression</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">IsotonicRegression</span> |
| |
| <span class="c1"># Loads data.</span> |
| <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s2">"libsvm"</span><span class="p">)</span>\ |
| <span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Trains an isotonic regression model.</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">IsotonicRegression</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Boundaries in increasing order: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">boundaries</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Predictions associated with the boundaries: </span><span class="si">%s</span><span class="se">\n</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predictions</span><span class="p">))</span> |
| |
| <span class="c1"># Makes predictions.</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/isotonic_regression_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.isoreg.html"><code>IsotonicRegression</code> R API docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="c1"># Load training data</span> |
| df <span class="o"><-</span> read.df<span class="p">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="p">,</span> <span class="kn">source</span> <span class="o">=</span> <span class="s">"libsvm"</span><span class="p">)</span> |
| training <span class="o"><-</span> df |
| test <span class="o"><-</span> df |
| |
| <span class="c1"># Fit an isotonic regression model with spark.isoreg</span> |
| model <span class="o"><-</span> spark.isoreg<span class="p">(</span>training<span class="p">,</span> label <span class="o">~</span> features<span class="p">,</span> isotonic <span class="o">=</span> <span class="kc">FALSE</span><span class="p">)</span> |
| |
| <span class="c1"># Model summary</span> |
| <span class="kp">summary</span><span class="p">(</span>model<span class="p">)</span> |
| |
| <span class="c1"># Prediction</span> |
| predictions <span class="o"><-</span> predict<span class="p">(</span>model<span class="p">,</span> test<span class="p">)</span> |
| <span class="kp">head</span><span class="p">(</span>predictions<span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/isoreg.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
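<p>The isotonic examples above print <code>boundaries</code> and <code>predictions</code>. To illustrate what those arrays represent, here is a minimal, single-machine sketch of the pool adjacent violators algorithm (PAVA) for an increasing fit on unweighted points. It is followed by a prediction rule that returns the stored value at a boundary, interpolates linearly between boundaries, and clamps outside them. The helpers <code>pava</code> and <code>predict</code> are hypothetical names, not part of the Spark API, and Spark's distributed implementation differs in detail. The R example's <code>isotonic = FALSE</code> requests a decreasing fit, which this sketch does not cover.</p>

```python
from bisect import bisect_left


def pava(xs, ys):
    """Fit a non-decreasing step function with pool adjacent violators (unweighted).

    Returns (boundaries, predictions): every pooled block contributes its first
    and last x value (once if they coincide), both carrying the block's mean.
    """
    blocks = []  # each block: [sum of y, count, first x, last x]
    for x, y in sorted(zip(xs, ys)):
        blocks.append([y, 1, x, x])
        # Pool backwards while the previous block's mean exceeds the newest one.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count, _, last = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
            blocks[-1][3] = last
    boundaries, predictions = [], []
    for total, count, first, last in blocks:
        mean = total / count
        boundaries.append(first)
        predictions.append(mean)
        if last != first:
            boundaries.append(last)
            predictions.append(mean)
    return boundaries, predictions


def predict(boundaries, predictions, x):
    """Exact boundary -> stored value; between boundaries -> linear interpolation;
    outside the boundary range -> value at the nearest end."""
    if x <= boundaries[0]:
        return predictions[0]
    if x >= boundaries[-1]:
        return predictions[-1]
    i = bisect_left(boundaries, x)
    if boundaries[i] == x:
        return predictions[i]
    x0, x1 = boundaries[i - 1], boundaries[i]
    y0, y1 = predictions[i - 1], predictions[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


boundaries, predictions = pava([1, 2, 3, 4, 5], [1, 3, 2, 4, 6])
print(boundaries, predictions, predict(boundaries, predictions, 3.5))
```

<p>On this toy input the violating pair at <code>x = 2, 3</code> is pooled into one block with value 2.5, so both 2 and 3 appear as boundaries carrying 2.5, and a query at 3.5 interpolates between 2.5 and 4.0.</p>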
| |
| <h1 id="linear-methods">Linear methods</h1> |
| |
| <p>We implement popular linear methods such as logistic |
| regression and linear least squares with $L_1$ or $L_2$ regularization. |
| Refer to <a href="mllib-linear-methods.html">the linear methods guide for the RDD-based API</a> for |
| details about implementation and tuning; this information is still relevant.</p> |
| |
| <p>We also include a DataFrame API for <a href="http://en.wikipedia.org/wiki/Elastic_net_regularization">Elastic |
| net</a>, a hybrid of $L_1$ and $L_2$ regularization proposed in |
| <a href="http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf">Zou et al., Regularization |
| and variable selection via the elastic net</a>. |
| Mathematically, it is defined as a convex combination of the $L_1$ and |
| $L_2$ regularization terms: |
| <code>\[ |
| \alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 |
| \]</code> |
| By choosing $\alpha$ appropriately, elastic net contains both $L_1$ and $L_2$ |
| regularization as special cases. For example, if a <a href="https://en.wikipedia.org/wiki/Linear_regression">linear |
| regression</a> model is trained with the elastic net parameter $\alpha$ set to $1$, it is |
| equivalent to a <a href="http://en.wikipedia.org/wiki/Least_squares#Lasso_method">Lasso</a> model. |
| If instead $\alpha$ is set to $0$, the trained model reduces to a |
| <a href="http://en.wikipedia.org/wiki/Tikhonov_regularization">ridge regression</a> model. |
| We implement the Pipelines API for both linear regression and logistic |
| regression with elastic net regularization.</p> |
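<p>The penalty above can be evaluated directly. The following sketch computes it for a plain list of weights. Here $\alpha$ and $\lambda$ correspond to what the <code>spark.ml</code> linear estimators call <code>elasticNetParam</code> and <code>regParam</code>, but <code>elastic_net_penalty</code> is a hypothetical helper for illustration only; it ignores training details such as feature standardization and intercept handling.</p>

```python
def elastic_net_penalty(w, alpha, lam):
    """Elastic net penalty: alpha * (lam * ||w||_1) + (1 - alpha) * (lam / 2 * ||w||_2^2)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if lam < 0.0:
        raise ValueError("lam must be non-negative")
    l1_norm = sum(abs(wi) for wi in w)  # ||w||_1
    l2_sq = sum(wi * wi for wi in w)    # ||w||_2^2
    return alpha * (lam * l1_norm) + (1.0 - alpha) * (lam / 2.0 * l2_sq)


w = [1.0, -2.0]
for alpha in (1.0, 0.5, 0.0):  # 1.0 keeps only the Lasso term, 0.0 only the ridge term
    print(alpha, elastic_net_penalty(w, alpha, lam=0.5))
```

<p>With <code>w = [1.0, -2.0]</code> and $\lambda = 0.5$, setting $\alpha = 1$ leaves only the Lasso term $0.5 \cdot 3 = 1.5$, while $\alpha = 0$ leaves only the ridge term $0.25 \cdot 5 = 1.25$.</p>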
| |
| <h1 id="decision-trees">Decision trees</h1> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Decision_tree_learning">Decision trees</a> |
| and their ensembles are popular methods for the machine learning tasks of |
| classification and regression. Decision trees are widely used since they are easy to interpret, |
| handle categorical features, extend to the multiclass classification setting, do not require |
| feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble |
| algorithms such as random forests and boosting are among the top performers for classification and |
| regression tasks.</p> |
| |
| <p>The <code>spark.ml</code> implementation supports decision trees for binary and multiclass classification and for regression, |
| using both continuous and categorical features. The implementation partitions data by rows, |
| allowing distributed training with millions or even billions of instances.</p> |
| |
| <p>Users can find more information about the decision tree algorithm in the <a href="mllib-decision-tree.html">MLlib Decision Tree guide</a>. |
| The main differences between this API and the <a href="mllib-decision-tree.html">original MLlib Decision Tree API</a> are:</p> |
| |
| <ul> |
| <li>support for ML Pipelines</li> |
| <li>separation of Decision Trees for classification vs. regression</li> |
| <li>use of DataFrame metadata to distinguish continuous and categorical features</li> |
| </ul> |
| |
| <p>The Pipelines API for Decision Trees offers somewhat more functionality than the original API.<br /> |
| In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); |
| for regression, users can get the biased sample variance of the prediction.</p> |
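<p>"Biased" here means the sum of squared deviations is divided by $n$ rather than $n - 1$. The sketch below assumes, as we read the column description, that the statistic is taken over the training labels that land in the leaf making the prediction; the leaf's prediction is the mean of those labels. The helper name is hypothetical.</p>

```python
def leaf_mean_and_biased_variance(labels):
    """Return the mean and biased (population) variance of the labels in one leaf.

    The mean is the value a regression tree leaf predicts; the variance divides
    by n, not n - 1, which is what makes it "biased".
    """
    if not labels:
        raise ValueError("a leaf must contain at least one training label")
    n = len(labels)
    mean = sum(labels) / n
    variance = sum((y - mean) ** 2 for y in labels) / n
    return mean, variance


mean, variance = leaf_mean_and_biased_variance([1.0, 2.0, 3.0, 6.0])
print(mean, variance)
```

<p>For these four labels the biased estimate is $14 / 4 = 3.5$; the unbiased estimate would divide by $3$ instead.</p>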
| |
| <p>Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the <a href="#tree-ensembles">Tree ensembles section</a>.</p> |
| |
| <h2 id="inputs-and-outputs">Inputs and Outputs</h2> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h3 id="input-columns">Input Columns</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h3 id="output-columns">Output Columns</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| <tr> |
| <td>rawPredictionCol</td> |
| <td>Vector</td> |
| <td>"rawPrediction"</td> |
| <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>probabilityCol</td> |
| <td>Vector</td> |
| <td>"probability"</td> |
| <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>varianceCol</td> |
| <td>Double</td> |
| <td></td> |
| <td>The biased sample variance of prediction</td> |
| <td>Regression only</td> |
| </tr> |
| </tbody> |
| </table> |
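<p>For classification, the table describes <code>probability</code> as <code>rawPrediction</code> normalized to a multinomial distribution. The sketch below shows that normalization for a hypothetical vector of leaf label counts, together with the arg-max rule we assume yields <code>prediction</code> when no class <code>thresholds</code> are set. The helper names are illustrative only.</p>

```python
def normalize_counts(raw_prediction):
    """Scale per-class label counts so they sum to 1 (a multinomial distribution)."""
    total = sum(raw_prediction)
    if total <= 0:
        raise ValueError("rawPrediction must contain at least one positive count")
    return [count / total for count in raw_prediction]


def predicted_label(raw_prediction):
    """Index of the largest entry; ties resolve to the lowest class index."""
    return max(range(len(raw_prediction)), key=lambda k: raw_prediction[k])


raw = [2.0, 6.0, 0.0, 2.0]  # hypothetical label counts at the leaf that was reached
print(normalize_counts(raw), predicted_label(raw))
```

<p>Here a four-class leaf with counts (2, 6, 0, 2) gives probabilities (0.2, 0.6, 0.0, 0.2) and predicted class 1.</p>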
| |
| <h1 id="tree-ensembles">Tree Ensembles</h1> |
| |
| <p>The DataFrame API supports two major tree ensemble algorithms: <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forests</a> and <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>. |
| Both use <a href="ml-classification-regression.html#decision-trees"><code>spark.ml</code> decision trees</a> as their base models.</p> |
| |
| <p>Users can find more information about ensemble algorithms in the <a href="mllib-ensembles.html">MLlib Ensemble guide</a>.<br /> |
| In this section, we demonstrate the DataFrame API for ensembles.</p> |
| |
| <p>The main differences between this API and the <a href="mllib-ensembles.html">original MLlib ensembles API</a> are:</p> |
| |
| <ul> |
| <li>support for DataFrames and ML Pipelines</li> |
| <li>separation of classification vs. regression</li> |
| <li>use of DataFrame metadata to distinguish continuous and categorical features</li> |
| <li>more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.</li> |
| </ul> |
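<p>The feature importance estimates are impurity based. Following the description in the Spark scaladoc for <code>featureImportances</code>, a feature's importance within one tree is the sum of the gains of the nodes that split on it, each gain scaled by the number of instances reaching that node. Each tree's vector is normalized to sum to 1, the trees are combined, and the result is normalized again. The sketch below uses a hypothetical list-of-tuples representation of each tree's internal nodes rather than any real Spark type.</p>

```python
def feature_importances(trees, num_features):
    """Impurity-based importances from per-tree split records.

    Each tree is a list of (feature_index, gain, num_instances) tuples, one per
    internal node (a hypothetical format used only for this sketch).
    """
    totals = [0.0] * num_features
    for splits in trees:
        per_tree = [0.0] * num_features
        for feature, gain, count in splits:
            per_tree[feature] += gain * count  # gain weighted by instances at the node
        tree_sum = sum(per_tree)
        if tree_sum > 0:
            for j in range(num_features):
                totals[j] += per_tree[j] / tree_sum  # each tree contributes total weight 1
    grand_total = sum(totals)
    if grand_total == 0:
        return totals
    return [t / grand_total for t in totals]  # final vector sums to 1


trees = [[(0, 0.5, 100), (1, 0.1, 40)], [(1, 0.3, 100)]]
print(feature_importances(trees, num_features=3))
```

<p>In this toy ensemble, feature 0 dominates the first tree but never appears in the second, so after averaging it ends up slightly less important than feature 1, and the unused feature 2 gets 0.</p>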
| |
| <h2 id="random-forests">Random Forests</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Random_forest">Random forests</a> |
| are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>. |
| Random forests combine many decision trees in order to reduce the risk of overfitting. |
| The <code>spark.ml</code> implementation supports random forests for binary and multiclass classification and for regression, |
| using both continuous and categorical features.</p> |
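<p>The overfitting-reduction idea can be sketched with bagging: each tree is trained on a bootstrap sample (rows drawn with replacement) and the ensemble takes a majority vote. This minimal illustration uses depth-1 trees on a single numeric feature; the helper names are hypothetical, and the majority vote is one common aggregation rule, not necessarily the exact rule the <code>spark.ml</code> classifier applies.</p>

```python
import random
from collections import Counter


def fit_stump(points):
    """Fit a one-split classifier on a single numeric feature.

    points is a list of (x, label) pairs; the threshold with the fewest
    training misclassifications wins (earliest threshold on ties).
    """
    best = None
    for t in sorted({x for x, _ in points}):
        left = Counter(y for x, y in points if x <= t)
        right = Counter(y for x, y in points if x > t)
        left_label = left.most_common(1)[0][0]
        right_label = right.most_common(1)[0][0] if right else left_label
        errors = sum((left_label if x <= t else right_label) != y for x, y in points)
        if best is None or errors < best[0]:
            best = (errors, t, left_label, right_label)
    _, t, left_label, right_label = best
    return lambda x: left_label if x <= t else right_label


def fit_forest(points, num_trees, seed=0):
    """Bagging: train each tree on a bootstrap sample drawn with replacement."""
    rng = random.Random(seed)
    forest = []
    for _ in range(num_trees):
        sample = [rng.choice(points) for _ in range(len(points))]
        forest.append(fit_stump(sample))
    return forest


def predict_forest(forest, x):
    """Majority vote over the individual trees' predicted labels."""
    votes = Counter(tree(x) for tree in forest)
    return votes.most_common(1)[0][0]


data = [(0.0, 0), (1.0, 0), (2.0, 0), (3.0, 1), (4.0, 1), (5.0, 1)]
forest = fit_forest(data, num_trees=25, seed=7)
print(predict_forest(forest, -1.0), predict_forest(forest, 10.0))
```

<p>With only one feature there is nothing to subsample per split; with more features, each split would also consider a random subset of them, which is the role of <code>featureSubsetStrategy</code> in <code>spark.ml</code>.</p>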
| |
| <p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html#random-forests"><code>spark.mllib</code> documentation on random forests</a>.</p> |
| |
| <h3 id="inputs-and-outputs-1">Inputs and Outputs</h3> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h4 id="input-columns-1">Input Columns</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h4 id="output-columns-predictions">Output Columns (Predictions)</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| <tr> |
| <td>rawPredictionCol</td> |
| <td>Vector</td> |
| <td>"rawPrediction"</td> |
| <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>probabilityCol</td> |
| <td>Vector</td> |
| <td>"probability"</td> |
| <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> |
| <td>Classification only</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h2 id="gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a> |
| are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>. |
| GBTs iteratively train decision trees in order to minimize a loss function. |
| The <code>spark.ml</code> implementation supports GBTs for binary classification and for regression, |
| using both continuous and categorical features.</p> |
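<p>The iterative fitting can be sketched for the simplest case: a single numeric feature, depth-1 trees, and squared loss, where the negative gradient is just the residual. This is a textbook illustration with hypothetical helper names (<code>fit_regression_stump</code>, <code>fit_gbt</code>). <code>spark.ml</code> grows full decision trees and uses its own loss scaling and tree weighting, so its numbers will differ. <code>step_size</code> plays the role of the learning rate, which <code>spark.ml</code> exposes as <code>stepSize</code>.</p>

```python
def fit_regression_stump(xs, targets):
    """Best single split on one numeric feature under squared error.

    Each side of the split predicts the mean of its targets. Needs at least two
    distinct x values so that both sides are non-empty.
    """
    best = None
    for t in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, targets) if x <= t]
        right = [r for x, r in zip(xs, targets) if x > t]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, t, left_mean, right_mean)
    _, t, left_mean, right_mean = best
    return lambda x: left_mean if x <= t else right_mean


def fit_gbt(xs, ys, num_iters=20, step_size=0.1):
    """Gradient boosting with squared loss L(y, F) = (y - F)^2 / 2.

    The negative gradient of this loss with respect to F is the residual y - F,
    so each new tree is fit to the current residuals and added with shrinkage.
    """
    trees = []
    preds = [0.0] * len(xs)  # start from the constant model F = 0
    for _ in range(num_iters):
        residuals = [y - p for y, p in zip(ys, preds)]
        tree = fit_regression_stump(xs, residuals)
        trees.append(tree)
        preds = [p + step_size * tree(x) for p, x in zip(preds, xs)]
    return lambda x: sum(step_size * tree(x) for tree in trees)


xs = [1, 2, 3, 4, 5, 6]
ys = [1.0, 1.2, 0.9, 3.1, 2.9, 3.0]
model = fit_gbt(xs, ys, num_iters=50)
print([round(model(x), 2) for x in xs])
```

<p>With <code>step_size</code> in $(0, 1]$, each added stump can only lower the training squared error here, because the stump's leaf means are a least-squares fit to the current residuals.</p>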
| |
| <p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html#gradient-boosted-trees-gbts"><code>spark.mllib</code> documentation on GBTs</a>.</p> |
| |
| <h3 id="inputs-and-outputs-2">Inputs and Outputs</h3> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h4 id="input-columns-2">Input Columns</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p>Note that <code>GBTClassifier</code> currently only supports binary labels.</p> |
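<p>Binary GBT classification is commonly framed with a margin $F(x)$, the weighted sum of the trees' outputs, and a log loss on labels recoded to $y \in \{-1, +1\}$. The <code>spark.mllib</code> ensembles guide describes this loss as twice the binomial negative log likelihood, $2 \log(1 + \exp(-2 y F(x)))$ per instance. The sketch below evaluates that loss together with the class-1 probability it implies; it illustrates a loss defined for exactly two classes, uses hypothetical helper names, and is not Spark's implementation.</p>

```python
import math


def to_signed(label):
    """Map a {0, 1} label to {-1, +1}, the form used by the log loss below."""
    if label not in (0, 1):
        raise ValueError("binary labels must be 0 or 1")
    return 2 * label - 1


def log_loss(label, margin):
    """Per-instance loss 2 * log(1 + exp(-2 * y * F)) with y in {-1, +1}."""
    y = to_signed(label)
    return 2.0 * math.log1p(math.exp(-2.0 * y * margin))


def probability_of_one(margin):
    """P(label = 1 | F) implied by the loss above: 1 / (1 + exp(-2F))."""
    return 1.0 / (1.0 + math.exp(-2.0 * margin))


for margin in (-1.0, 0.0, 1.0):
    print(margin, probability_of_one(margin), log_loss(1, margin))
```

<p>Since the loss equals $-2 \log P(y \mid F)$ for this probability, a positive margin favors label 1 and a negative margin favors label 0.</p>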
| |
| <h4 id="output-columns-predictions-1">Output Columns (Predictions)</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p>In the future, <code>GBTClassifier</code> will also output columns for <code>rawPrediction</code> and <code>probability</code>, just as <code>RandomForestClassifier</code> does.</p> |
| |
| |
| |
| </div> |
| |
| <!-- /container --> |
| </div> |
| |
| <script src="js/vendor/jquery-1.12.4.min.js"></script> |
| <script src="js/vendor/bootstrap.min.js"></script> |
| <script src="js/vendor/anchor.min.js"></script> |
| <script src="js/main.js"></script> |
| |
| <!-- MathJax Section --> |
| <script type="text/x-mathjax-config"> |
| MathJax.Hub.Config({ |
| TeX: { equationNumbers: { autoNumber: "AMS" } } |
| }); |
| </script> |
| <script> |
| // Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS. |
| // We could use "//cdn.mathjax...", but that won't support "file://". |
| (function(d, script) { |
| script = d.createElement('script'); |
| script.type = 'text/javascript'; |
| script.async = true; |
| script.onload = function(){ |
| MathJax.Hub.Config({ |
| tex2jax: { |
| inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ], |
| displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], |
| processEscapes: true, |
| skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] |
| } |
| }); |
| }; |
| script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') + |
| 'cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js' + |
| '?config=TeX-AMS-MML_HTMLorMML'; |
| d.getElementsByTagName('head')[0].appendChild(script); |
| }(document)); |
| </script> |
| </body> |
| </html> |