| |
| <!DOCTYPE html> |
| <!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7" lang="en"> <![endif]--> |
| <!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8" lang="en"> <![endif]--> |
| <!--[if IE 8]> <html class="no-js lt-ie9" lang="en"> <![endif]--> |
| <!--[if gt IE 8]><!--> <html class="no-js" lang="en"> <!--<![endif]--> |
| <head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> |
| <title>Classification and regression - Spark 3.3.0 Documentation</title> |
| |
| |
| |
| |
| <link rel="stylesheet" href="css/bootstrap.min.css"> |
| <style> |
| body { |
| padding-top: 60px; |
| padding-bottom: 40px; |
| } |
| </style> |
| <meta name="viewport" content="width=device-width"> |
| <link rel="stylesheet" href="css/main.css"> |
| |
| <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> |
| |
| <link rel="stylesheet" href="css/pygments-default.css"> |
| <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" /> |
| <link rel="stylesheet" href="css/docsearch.css"> |
| |
| <!-- Matomo --> |
| <script type="text/javascript"> |
| var _paq = window._paq = window._paq || []; |
| /* tracker methods like "setCustomDimension" should be called before "trackPageView" */ |
| _paq.push(["disableCookies"]); |
| _paq.push(['trackPageView']); |
| _paq.push(['enableLinkTracking']); |
| (function() { |
| var u="https://analytics.apache.org/"; |
| _paq.push(['setTrackerUrl', u+'matomo.php']); |
| _paq.push(['setSiteId', '40']); |
| var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0]; |
| g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s); |
| })(); |
| </script> |
| <!-- End Matomo Code --> |
| </head> |
| <body> |
| <!--[if lt IE 7]> |
| <p class="chromeframe">You are using an outdated browser. <a href="https://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> |
| <![endif]--> |
| |
| <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> |
| |
| <nav class="navbar fixed-top navbar-expand-md navbar-light bg-light" id="topbar"> |
| <div class="container"> |
| <div class="navbar-header"> |
| <div class="navbar-brand"><a href="index.html"> |
| <img src="img/spark-logo-hd.png" alt="Apache Spark" style="height:50px;"/></a><span class="version">3.3.0</span> |
| </div> |
| </div> |
| <button class="navbar-toggler" type="button" data-toggle="collapse" |
| data-target="#navbarCollapse" aria-controls="navbarCollapse" |
| aria-expanded="false" aria-label="Toggle navigation"> |
| <span class="navbar-toggler-icon"></span> |
| </button> |
| <div class="collapse navbar-collapse" id="navbarCollapse"> |
| <ul class="navbar-nav"> |
| <!--TODO(andyk): Add class="active" attribute to li somehow.--> |
| <li class="nav-item"><a href="index.html" class="nav-link">Overview</a></li> |
| |
| <li class="nav-item dropdown"> |
| <a href="#" class="nav-link dropdown-toggle" id="navbarQuickStart" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">Programming Guides</a> |
| <div class="dropdown-menu" aria-labelledby="navbarQuickStart"> |
| <a class="dropdown-item" href="quick-start.html">Quick Start</a> |
| <a class="dropdown-item" href="rdd-programming-guide.html">RDDs, Accumulators, Broadcast Vars</a> |
| <a class="dropdown-item" href="sql-programming-guide.html">SQL, DataFrames, and Datasets</a> |
| <a class="dropdown-item" href="structured-streaming-programming-guide.html">Structured Streaming</a> |
| <a class="dropdown-item" href="streaming-programming-guide.html">Spark Streaming (DStreams)</a> |
| <a class="dropdown-item" href="ml-guide.html">MLlib (Machine Learning)</a> |
| <a class="dropdown-item" href="graphx-programming-guide.html">GraphX (Graph Processing)</a> |
| <a class="dropdown-item" href="sparkr.html">SparkR (R on Spark)</a> |
| <a class="dropdown-item" href="api/python/getting_started/index.html">PySpark (Python on Spark)</a> |
| </div> |
| </li> |
| |
| <li class="nav-item dropdown"> |
| <a href="#" class="nav-link dropdown-toggle" id="navbarAPIDocs" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">API Docs</a> |
| <div class="dropdown-menu" aria-labelledby="navbarAPIDocs"> |
| <a class="dropdown-item" href="api/scala/org/apache/spark/index.html">Scala</a> |
| <a class="dropdown-item" href="api/java/index.html">Java</a> |
| <a class="dropdown-item" href="api/python/index.html">Python</a> |
| <a class="dropdown-item" href="api/R/index.html">R</a> |
| <a class="dropdown-item" href="api/sql/index.html">SQL, Built-in Functions</a> |
| </div> |
| </li> |
| |
| <li class="nav-item dropdown"> |
| <a href="#" class="nav-link dropdown-toggle" id="navbarDeploying" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">Deploying</a> |
| <div class="dropdown-menu" aria-labelledby="navbarDeploying"> |
| <a class="dropdown-item" href="cluster-overview.html">Overview</a> |
| <a class="dropdown-item" href="submitting-applications.html">Submitting Applications</a> |
| <div class="dropdown-divider"></div> |
| <a class="dropdown-item" href="spark-standalone.html">Spark Standalone</a> |
| <a class="dropdown-item" href="running-on-mesos.html">Mesos</a> |
| <a class="dropdown-item" href="running-on-yarn.html">YARN</a> |
| <a class="dropdown-item" href="running-on-kubernetes.html">Kubernetes</a> |
| </div> |
| </li> |
| |
| <li class="nav-item dropdown"> |
| <a href="#" class="nav-link dropdown-toggle" id="navbarMore" role="button" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">More</a> |
| <div class="dropdown-menu" aria-labelledby="navbarMore"> |
| <a class="dropdown-item" href="configuration.html">Configuration</a> |
| <a class="dropdown-item" href="monitoring.html">Monitoring</a> |
| <a class="dropdown-item" href="tuning.html">Tuning Guide</a> |
| <a class="dropdown-item" href="job-scheduling.html">Job Scheduling</a> |
| <a class="dropdown-item" href="security.html">Security</a> |
| <a class="dropdown-item" href="hardware-provisioning.html">Hardware Provisioning</a> |
| <a class="dropdown-item" href="migration-guide.html">Migration Guide</a> |
| <div class="dropdown-divider"></div> |
| <a class="dropdown-item" href="building-spark.html">Building Spark</a> |
| <a class="dropdown-item" href="https://spark.apache.org/contributing.html">Contributing to Spark</a> |
| <a class="dropdown-item" href="https://spark.apache.org/third-party-projects.html">Third Party Projects</a> |
| </div> |
| </li> |
| |
| <li class="nav-item"> |
| <input type="text" id="docsearch-input" placeholder="Search the docs…" aria-label="Search the docs"> |
| </li> |
| </ul> |
| <!--<span class="navbar-text navbar-right"><span class="version-text">v3.3.0</span></span>--> |
| </div> |
| </div> |
| </nav> |
| |
| <div class="container-wrapper"> |
| |
| |
| |
| <div class="left-menu-wrapper"> |
| <div class="left-menu"> |
| <h3><a href="ml-guide.html">MLlib: Main Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="ml-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-datasource.html"> |
| |
| Data sources |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-pipeline.html"> |
| |
| Pipelines |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-features.html"> |
| |
| Extracting, transforming and selecting features |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-classification-regression.html"> |
| |
| Classification and Regression |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-frequent-pattern-mining.html"> |
| |
| Frequent Pattern Mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-tuning.html"> |
| |
| Model selection and tuning |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-advanced.html"> |
| |
| Advanced topics |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| <h3><a href="mllib-guide.html">MLlib: RDD-based API Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="mllib-data-types.html"> |
| |
| Data types |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-classification-regression.html"> |
| |
| Classification and regression |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-dimensionality-reduction.html"> |
| |
| Dimensionality reduction |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-feature-extraction.html"> |
| |
| Feature extraction and transformation |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-frequent-pattern-mining.html"> |
| |
| Frequent pattern mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-evaluation-metrics.html"> |
| |
| Evaluation metrics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-pmml-model-export.html"> |
| |
| PMML model export |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-optimization.html"> |
| |
| Optimization (developer) |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| </div> |
| </div> |
| |
| <input id="nav-trigger" class="nav-trigger" checked type="checkbox"> |
| <label for="nav-trigger"></label> |
| <div class="content-with-sidebar mr-3" id="content"> |
| |
| <h1 class="title">Classification and regression</h1> |
| |
| |
| <p><code class="language-plaintext highlighter-rouge">\[ |
| \newcommand{\R}{\mathbb{R}} |
| \newcommand{\E}{\mathbb{E}} |
| \newcommand{\x}{\mathbf{x}} |
| \newcommand{\y}{\mathbf{y}} |
| \newcommand{\wv}{\mathbf{w}} |
| \newcommand{\av}{\mathbf{\alpha}} |
| \newcommand{\bv}{\mathbf{b}} |
| \newcommand{\N}{\mathbb{N}} |
| \newcommand{\id}{\mathbf{I}} |
| \newcommand{\ind}{\mathbf{1}} |
| \newcommand{\0}{\mathbf{0}} |
| \newcommand{\unit}{\mathbf{e}} |
| \newcommand{\one}{\mathbf{1}} |
| \newcommand{\zero}{\mathbf{0}} |
| \]</code></p> |
| |
| <p>This page covers algorithms for Classification and Regression. It also includes sections |
| discussing specific classes of algorithms, such as linear methods, trees, and ensembles.</p> |
| |
| <p><strong>Table of Contents</strong></p> |
| |
| <ul id="markdown-toc"> |
| <li><a href="#classification" id="markdown-toc-classification">Classification</a> <ul> |
| <li><a href="#logistic-regression" id="markdown-toc-logistic-regression">Logistic regression</a> <ul> |
| <li><a href="#binomial-logistic-regression" id="markdown-toc-binomial-logistic-regression">Binomial logistic regression</a></li> |
| <li><a href="#multinomial-logistic-regression" id="markdown-toc-multinomial-logistic-regression">Multinomial logistic regression</a></li> |
| </ul> |
| </li> |
| <li><a href="#decision-tree-classifier" id="markdown-toc-decision-tree-classifier">Decision tree classifier</a></li> |
| <li><a href="#random-forest-classifier" id="markdown-toc-random-forest-classifier">Random forest classifier</a></li> |
| <li><a href="#gradient-boosted-tree-classifier" id="markdown-toc-gradient-boosted-tree-classifier">Gradient-boosted tree classifier</a></li> |
| <li><a href="#multilayer-perceptron-classifier" id="markdown-toc-multilayer-perceptron-classifier">Multilayer perceptron classifier</a></li> |
| <li><a href="#linear-support-vector-machine" id="markdown-toc-linear-support-vector-machine">Linear Support Vector Machine</a></li> |
| <li><a href="#one-vs-rest-classifier-aka-one-vs-all" id="markdown-toc-one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</a></li> |
| <li><a href="#naive-bayes" id="markdown-toc-naive-bayes">Naive Bayes</a></li> |
| <li><a href="#factorization-machines-classifier" id="markdown-toc-factorization-machines-classifier">Factorization machines classifier</a></li> |
| </ul> |
| </li> |
| <li><a href="#regression" id="markdown-toc-regression">Regression</a> <ul> |
| <li><a href="#linear-regression" id="markdown-toc-linear-regression">Linear regression</a></li> |
| <li><a href="#generalized-linear-regression" id="markdown-toc-generalized-linear-regression">Generalized linear regression</a> <ul> |
| <li><a href="#available-families" id="markdown-toc-available-families">Available families</a></li> |
| </ul> |
| </li> |
| <li><a href="#decision-tree-regression" id="markdown-toc-decision-tree-regression">Decision tree regression</a></li> |
| <li><a href="#random-forest-regression" id="markdown-toc-random-forest-regression">Random forest regression</a></li> |
| <li><a href="#gradient-boosted-tree-regression" id="markdown-toc-gradient-boosted-tree-regression">Gradient-boosted tree regression</a></li> |
| <li><a href="#survival-regression" id="markdown-toc-survival-regression">Survival regression</a></li> |
| <li><a href="#isotonic-regression" id="markdown-toc-isotonic-regression">Isotonic regression</a></li> |
| <li><a href="#factorization-machines-regressor" id="markdown-toc-factorization-machines-regressor">Factorization machines regressor</a></li> |
| </ul> |
| </li> |
| <li><a href="#linear-methods" id="markdown-toc-linear-methods">Linear methods</a></li> |
| <li><a href="#factorization-machines" id="markdown-toc-factorization-machines">Factorization Machines</a></li> |
| <li><a href="#decision-trees" id="markdown-toc-decision-trees">Decision trees</a> <ul> |
| <li><a href="#inputs-and-outputs" id="markdown-toc-inputs-and-outputs">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns" id="markdown-toc-input-columns">Input Columns</a></li> |
| <li><a href="#output-columns" id="markdown-toc-output-columns">Output Columns</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li><a href="#tree-ensembles" id="markdown-toc-tree-ensembles">Tree Ensembles</a> <ul> |
| <li><a href="#random-forests" id="markdown-toc-random-forests">Random Forests</a> <ul> |
| <li><a href="#inputs-and-outputs-1" id="markdown-toc-inputs-and-outputs-1">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns-1" id="markdown-toc-input-columns-1">Input Columns</a></li> |
| <li><a href="#output-columns-predictions" id="markdown-toc-output-columns-predictions">Output Columns (Predictions)</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li><a href="#gradient-boosted-trees-gbts" id="markdown-toc-gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</a> <ul> |
| <li><a href="#inputs-and-outputs-2" id="markdown-toc-inputs-and-outputs-2">Inputs and Outputs</a> <ul> |
| <li><a href="#input-columns-2" id="markdown-toc-input-columns-2">Input Columns</a></li> |
| <li><a href="#output-columns-predictions-1" id="markdown-toc-output-columns-predictions-1">Output Columns (Predictions)</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| |
| <h1 id="classification">Classification</h1> |
| |
| <h2 id="logistic-regression">Logistic regression</h2> |
| |
| <p>Logistic regression is a popular method to predict a categorical response. It is a special case of <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">Generalized Linear models</a> that predicts the probability of the outcomes. |
| In <code class="language-plaintext highlighter-rouge">spark.ml</code> logistic regression can be used to predict a binary outcome by using binomial logistic regression, or it can be used to predict a multiclass outcome by using multinomial logistic regression. Use the <code class="language-plaintext highlighter-rouge">family</code> |
| parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant.</p> |
| |
| <blockquote> |
| <p>Multinomial logistic regression can be used for binary classification by setting the <code class="language-plaintext highlighter-rouge">family</code> param to “multinomial”. It will produce two sets of coefficients and two intercepts.</p> |
| </blockquote> |
| |
| <blockquote> |
| <p>When fitting a LogisticRegressionModel without an intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.</p> |
| </blockquote> |
| |
| <h3 id="binomial-logistic-regression">Binomial logistic regression</h3> |
| |
| <p>For more background and more details about the implementation of binomial logistic regression, refer to the documentation of <a href="mllib-linear-methods.html#logistic-regression">logistic regression in <code class="language-plaintext highlighter-rouge">spark.mllib</code></a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example shows how to train binomial and multinomial logistic regression |
| models for binary classification with elastic net regularization. <code class="language-plaintext highlighter-rouge">elasticNetParam</code> corresponds to |
| $\alpha$ and <code class="language-plaintext highlighter-rouge">regParam</code> corresponds to $\lambda$.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/org/apache/spark/ml/classification/LogisticRegression.html">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="nv">lrModel</span> <span class="k">=</span> <span class="nv">lr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for logistic regression</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}"</span><span class="o">)</span> |
| |
| <span class="c1">// We can also use the multinomial family for binary classification</span> |
| <span class="k">val</span> <span class="nv">mlr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFamily</span><span class="o">(</span><span class="s">"multinomial"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">mlrModel</span> <span class="k">=</span> <span class="nv">mlr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercepts for logistic regression with multinomial family</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Multinomial coefficients: ${mlrModel.coefficientMatrix}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Multinomial intercepts: ${mlrModel.interceptVector}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/LogisticRegression.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="nc">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="nc">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for logistic regression</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// We can also use the multinomial family for binary classification</span> |
| <span class="nc">LogisticRegression</span> <span class="n">mlr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFamily</span><span class="o">(</span><span class="s">"multinomial"</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="nc">LogisticRegressionModel</span> <span class="n">mlrModel</span> <span class="o">=</span> <span class="n">mlr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercepts for logistic regression with multinomial family</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Multinomial coefficients: "</span> <span class="o">+</span> <span class="n">mlrModel</span><span class="o">.</span><span class="na">coefficientMatrix</span><span class="o">()</span> |
| <span class="o">+</span> <span class="s">"\nMultinomial intercepts: "</span> <span class="o">+</span> <span class="n">mlrModel</span><span class="o">.</span><span class="na">interceptVector</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/reference/api/pyspark.ml.classification.LogisticRegression.html">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for logistic regression |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># We can also use the multinomial family for binary classification |
| </span><span class="n">mlr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">family</span><span class="o">=</span><span class="s">"multinomial"</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">mlrModel</span> <span class="o">=</span> <span class="n">mlr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercepts for logistic regression with multinomial family |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Multinomial coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">mlrModel</span><span class="p">.</span><span class="n">coefficientMatrix</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Multinomial intercepts: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">mlrModel</span><span class="p">.</span><span class="n">interceptVector</span><span class="p">))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_with_elastic_net.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>More details on parameters can be found in the <a href="api/R/spark.logit.html">R API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a binomial logistic regression model with spark.logit</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.logit</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">,</span><span class="w"> </span><span class="n">regParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.3</span><span class="p">,</span><span class="w"> </span><span class="n">elasticNetParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.8</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/logit.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <p>The <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation of logistic regression also supports |
| extracting a summary of the model over the training set. Note that the |
| predictions and metrics which are stored as <code class="language-plaintext highlighter-rouge">DataFrame</code> in |
| <code class="language-plaintext highlighter-rouge">LogisticRegressionSummary</code> are annotated <code class="language-plaintext highlighter-rouge">@transient</code> and hence |
| only available on the driver.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p><a href="api/scala/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/scala/org/apache/spark/ml/classification/LogisticRegressionModel.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. ROC curve. The binary summary can be accessed via the |
| <code class="language-plaintext highlighter-rouge">binarySummary</code> method. See <a href="api/scala/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| |
| <span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> |
| <span class="c1">// example</span> |
| <span class="k">val</span> <span class="nv">trainingSummary</span> <span class="k">=</span> <span class="nv">lrModel</span><span class="o">.</span><span class="py">binarySummary</span> |
| |
| <span class="c1">// Obtain the objective per iteration.</span> |
| <span class="k">val</span> <span class="nv">objectiveHistory</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">objectiveHistory</span> |
| <span class="nf">println</span><span class="o">(</span><span class="s">"objectiveHistory:"</span><span class="o">)</span> |
| <span class="nv">objectiveHistory</span><span class="o">.</span><span class="py">foreach</span><span class="o">(</span><span class="n">loss</span> <span class="k">=></span> <span class="nf">println</span><span class="o">(</span><span class="n">loss</span><span class="o">))</span> |
| |
| <span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> |
| <span class="k">val</span> <span class="nv">roc</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">roc</span> |
| <span class="nv">roc</span><span class="o">.</span><span class="py">show</span><span class="o">()</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"areaUnderROC: ${trainingSummary.areaUnderROC}"</span><span class="o">)</span> |
| |
| <span class="c1">// Set the model threshold to maximize F-Measure</span> |
| <span class="k">val</span> <span class="nv">fMeasure</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">fMeasureByThreshold</span> |
| <span class="k">val</span> <span class="nv">maxFMeasure</span> <span class="k">=</span> <span class="nv">fMeasure</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="nf">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="py">head</span><span class="o">().</span><span class="py">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">bestThreshold</span> <span class="k">=</span> <span class="nv">fMeasure</span><span class="o">.</span><span class="py">where</span><span class="o">(</span><span class="n">$</span><span class="s">"F-Measure"</span> <span class="o">===</span> <span class="n">maxFMeasure</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="py">head</span><span class="o">().</span><span class="py">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="nv">lrModel</span><span class="o">.</span><span class="py">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| <p><a href="api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. ROC curve. The binary summary can be accessed via the |
| <code class="language-plaintext highlighter-rouge">binarySummary</code> method. See <a href="api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.functions</span><span class="o">;</span> |
| |
| <span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span> |
| <span class="c1">// example</span> |
| <span class="nc">BinaryLogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">binarySummary</span><span class="o">();</span> |
| |
| <span class="c1">// Obtain the loss per iteration.</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">roc</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">roc</span><span class="o">();</span> |
| <span class="n">roc</span><span class="o">.</span><span class="na">show</span><span class="o">();</span> |
| <span class="n">roc</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"FPR"</span><span class="o">).</span><span class="na">show</span><span class="o">();</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">areaUnderROC</span><span class="o">());</span> |
| |
| <span class="c1">// Get the threshold corresponding to the maximum F-Measure and set it as the threshold of the</span> |
| <span class="c1">// fitted model.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="n">functions</span><span class="o">.</span><span class="na">max</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">)).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">where</span><span class="o">(</span><span class="n">fMeasure</span><span class="o">.</span><span class="na">col</span><span class="o">(</span><span class="s">"F-Measure"</span><span class="o">).</span><span class="na">equalTo</span><span class="o">(</span><span class="n">maxFMeasure</span><span class="o">))</span> |
| <span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"threshold"</span><span class="o">).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span> |
| <span class="n">lrModel</span><span class="o">.</span><span class="na">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">);</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| <p><a href="api/python/reference/api/pyspark.ml.classification.LogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionTrainingSummary</code></a> |
| provides a summary for a |
| <a href="api/python/reference/api/pyspark.ml.classification.LogisticRegressionModel.html"><code class="language-plaintext highlighter-rouge">LogisticRegressionModel</code></a>. |
| In the case of binary classification, certain additional metrics are |
| available, e.g. ROC curve. See <a href="api/python/reference/api/pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary.html"><code class="language-plaintext highlighter-rouge">BinaryLogisticRegressionTrainingSummary</code></a>.</p> |
| |
| <p>Continuing the earlier example:</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Extract the summary from the returned LogisticRegressionModel instance trained |
| # in the earlier example |
| </span><span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="p">.</span><span class="n">summary</span> |
| |
| <span class="c1"># Obtain the objective per iteration |
| </span><span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">objectiveHistory</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"objectiveHistory:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">objective</span> <span class="ow">in</span> <span class="n">objectiveHistory</span><span class="p">:</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">objective</span><span class="p">)</span> |
| |
| <span class="c1"># Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. |
| </span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">roc</span><span class="p">.</span><span class="n">show</span><span class="p">()</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"areaUnderROC: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">areaUnderROC</span><span class="p">))</span> |
| |
| <span class="c1"># Set the model threshold to maximize F-Measure |
| </span><span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">fMeasureByThreshold</span> |
| <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="p">.</span><span class="n">groupBy</span><span class="p">().</span><span class="nb">max</span><span class="p">(</span><span class="s">'F-Measure'</span><span class="p">).</span><span class="n">select</span><span class="p">(</span><span class="s">'max(F-Measure)'</span><span class="p">).</span><span class="n">head</span><span class="p">()</span> |
| <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">fMeasure</span><span class="p">[</span><span class="s">'F-Measure'</span><span class="p">]</span> <span class="o">==</span> <span class="n">maxFMeasure</span><span class="p">[</span><span class="s">'max(F-Measure)'</span><span class="p">])</span> \ |
| <span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">'threshold'</span><span class="p">).</span><span class="n">head</span><span class="p">()[</span><span class="s">'threshold'</span><span class="p">]</span> |
| <span class="n">lr</span><span class="p">.</span><span class="n">setThreshold</span><span class="p">(</span><span class="n">bestThreshold</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_summary_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h3 id="multinomial-logistic-regression">Multinomial logistic regression</h3> |
| |
| <p>Multiclass classification is supported via multinomial logistic (softmax) regression. In multinomial logistic regression, |
| the algorithm produces $K$ sets of coefficients, or a matrix of dimension $K \times J$ where $K$ is the number of outcome |
| classes and $J$ is the number of features. If the algorithm is fit with an intercept term then a length $K$ vector of |
| intercepts is available.</p> |
| |
| <blockquote> |
| <p>Multinomial coefficients are available as <code class="language-plaintext highlighter-rouge">coefficientMatrix</code> and intercepts are available as <code class="language-plaintext highlighter-rouge">interceptVector</code>.</p> |
| </blockquote> |
| |
| <blockquote> |
| <p><code class="language-plaintext highlighter-rouge">coefficients</code> and <code class="language-plaintext highlighter-rouge">intercept</code> methods on a logistic regression model trained with multinomial family are not supported. Use <code class="language-plaintext highlighter-rouge">coefficientMatrix</code> and <code class="language-plaintext highlighter-rouge">interceptVector</code> instead.</p> |
| </blockquote> |
| |
| <p>The conditional probabilities of the outcome classes $k \in \{0, 1, \ldots, K-1\}$ are modeled using the softmax function.</p> |
| |
| <p><code class="language-plaintext highlighter-rouge">\[ |
| P(Y=k|\mathbf{X}, \boldsymbol{\beta}_k, \beta_{0k}) = \frac{e^{\boldsymbol{\beta}_k \cdot \mathbf{X} + \beta_{0k}}}{\sum_{k'=0}^{K-1} e^{\boldsymbol{\beta}_{k'} \cdot \mathbf{X} + \beta_{0k'}}} |
| \]</code></p> |
| |
| <p>We minimize the weighted negative log-likelihood, using a multinomial response model, with elastic-net penalty to control for overfitting.</p> |
| |
| <p><code class="language-plaintext highlighter-rouge">\[ |
| \min_{\beta, \beta_0} -\left[\sum_{i=1}^L w_i \cdot \log P(Y = y_i|\mathbf{x}_i)\right] + \lambda \left[\frac{1}{2}\left(1 - \alpha\right)||\boldsymbol{\beta}||_2^2 + \alpha ||\boldsymbol{\beta}||_1\right] |
| \]</code></p> |
| |
| <p>For a detailed derivation, see the <a href="https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model">log-linear model formulation of multinomial logistic regression</a> on Wikipedia.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example shows how to train a multiclass logistic regression |
| model with elastic net regularization, as well as extract the multiclass |
| training summary for evaluating the model.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="py">read</span> |
| <span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="nv">lrModel</span> <span class="k">=</span> <span class="nv">lr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for multinomial logistic regression</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: \n${lrModel.coefficientMatrix}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Intercepts: \n${lrModel.interceptVector}"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">trainingSummary</span> <span class="k">=</span> <span class="nv">lrModel</span><span class="o">.</span><span class="py">summary</span> |
| |
| <span class="c1">// Obtain the objective per iteration</span> |
| <span class="k">val</span> <span class="nv">objectiveHistory</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">objectiveHistory</span> |
| <span class="nf">println</span><span class="o">(</span><span class="s">"objectiveHistory:"</span><span class="o">)</span> |
| <span class="nv">objectiveHistory</span><span class="o">.</span><span class="py">foreach</span><span class="o">(</span><span class="n">println</span><span class="o">)</span> |
| |
| <span class="c1">// for multiclass, we can inspect metrics on a per-label basis</span> |
| <span class="nf">println</span><span class="o">(</span><span class="s">"False positive rate by label:"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">falsePositiveRateByLabel</span><span class="o">.</span><span class="py">zipWithIndex</span><span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="nf">case</span> <span class="o">(</span><span class="n">rate</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"label $label: $rate"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="nf">println</span><span class="o">(</span><span class="s">"True positive rate by label:"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">truePositiveRateByLabel</span><span class="o">.</span><span class="py">zipWithIndex</span><span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="nf">case</span> <span class="o">(</span><span class="n">rate</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"label $label: $rate"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="nf">println</span><span class="o">(</span><span class="s">"Precision by label:"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">precisionByLabel</span><span class="o">.</span><span class="py">zipWithIndex</span><span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="nf">case</span> <span class="o">(</span><span class="n">prec</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"label $label: $prec"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="nf">println</span><span class="o">(</span><span class="s">"Recall by label:"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">recallByLabel</span><span class="o">.</span><span class="py">zipWithIndex</span><span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="nf">case</span> <span class="o">(</span><span class="n">rec</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"label $label: $rec"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| |
| <span class="nf">println</span><span class="o">(</span><span class="s">"F-measure by label:"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">fMeasureByLabel</span><span class="o">.</span><span class="py">zipWithIndex</span><span class="o">.</span><span class="py">foreach</span> <span class="o">{</span> <span class="nf">case</span> <span class="o">(</span><span class="n">f</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> <span class="k">=></span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"label $label: $f"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">accuracy</span> |
| <span class="k">val</span> <span class="nv">falsePositiveRate</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">weightedFalsePositiveRate</span> |
| <span class="k">val</span> <span class="nv">truePositiveRate</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">weightedTruePositiveRate</span> |
| <span class="k">val</span> <span class="nv">fMeasure</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">weightedFMeasure</span> |
| <span class="k">val</span> <span class="nv">precision</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">weightedPrecision</span> |
| <span class="k">val</span> <span class="nv">recall</span> <span class="k">=</span> <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">weightedRecall</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n"</span> <span class="o">+</span> |
| <span class="n">s</span><span class="s">"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">);</span> |
| |
| <span class="nc">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="nc">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for multinomial logistic regression</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: \n"</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficientMatrix</span><span class="o">()</span> <span class="o">+</span> <span class="s">" \nIntercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">interceptVector</span><span class="o">());</span> |
| <span class="nc">LogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| |
| <span class="c1">// Obtain the loss per iteration.</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// for multiclass, we can inspect metrics on a per-label basis</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"False positive rate by label:"</span><span class="o">);</span> |
| <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">fprLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">falsePositiveRateByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">fpr</span> <span class="o">:</span> <span class="n">fprLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">fpr</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"True positive rate by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">tprLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">truePositiveRateByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">tpr</span> <span class="o">:</span> <span class="n">tprLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">tpr</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">precLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">precisionByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">prec</span> <span class="o">:</span> <span class="n">precLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">prec</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Recall by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">recLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">recallByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">rec</span> <span class="o">:</span> <span class="n">recLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">rec</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F-measure by label:"</span><span class="o">);</span> |
| <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">fLabel</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">fMeasureByLabel</span><span class="o">();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">f</span> <span class="o">:</span> <span class="n">fLabel</span><span class="o">)</span> <span class="o">{</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"label "</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="s">": "</span> <span class="o">+</span> <span class="n">f</span><span class="o">);</span> |
| <span class="n">i</span><span class="o">++;</span> |
| <span class="o">}</span> |
| |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">accuracy</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">falsePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedFalsePositiveRate</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">truePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedTruePositiveRate</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedFMeasure</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">precision</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedPrecision</span><span class="o">();</span> |
| <span class="kt">double</span> <span class="n">recall</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">weightedRecall</span><span class="o">();</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Accuracy: "</span> <span class="o">+</span> <span class="n">accuracy</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"FPR: "</span> <span class="o">+</span> <span class="n">falsePositiveRate</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"TPR: "</span> <span class="o">+</span> <span class="n">truePositiveRate</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F-measure: "</span> <span class="o">+</span> <span class="n">fMeasure</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision: "</span> <span class="o">+</span> <span class="n">precision</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Recall: "</span> <span class="o">+</span> <span class="n">recall</span><span class="o">);</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span> \ |
| <span class="p">.</span><span class="n">read</span> \ |
| <span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span> \ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for multinomial logistic regression |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: </span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">coefficientMatrix</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">interceptVector</span><span class="p">))</span> |
| |
| <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="p">.</span><span class="n">summary</span> |
| |
| <span class="c1"># Obtain the objective per iteration |
| </span><span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">objectiveHistory</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"objectiveHistory:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">objective</span> <span class="ow">in</span> <span class="n">objectiveHistory</span><span class="p">:</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">objective</span><span class="p">)</span> |
| |
| <span class="c1"># for multiclass, we can inspect metrics on a per-label basis |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"False positive rate by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rate</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">falsePositiveRateByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"label %d: %s"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rate</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s">"True positive rate by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rate</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">truePositiveRateByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"label %d: %s"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rate</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s">"Precision by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">prec</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">precisionByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"label %d: %s"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">prec</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s">"Recall by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rec</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">recallByLabel</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"label %d: %s"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">rec</span><span class="p">))</span> |
| |
| <span class="k">print</span><span class="p">(</span><span class="s">"F-measure by label:"</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">f</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">fMeasureByLabel</span><span class="p">()):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"label %d: %s"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span> |
| |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">accuracy</span> |
| <span class="n">falsePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">weightedFalsePositiveRate</span> |
| <span class="n">truePositiveRate</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">weightedTruePositiveRate</span> |
| <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">weightedFMeasure</span><span class="p">()</span> |
| <span class="n">precision</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">weightedPrecision</span> |
| <span class="n">recall</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">weightedRecall</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Accuracy: %s</span><span class="se">\n</span><span class="s">FPR: %s</span><span class="se">\n</span><span class="s">TPR: %s</span><span class="se">\n</span><span class="s">F-measure: %s</span><span class="se">\n</span><span class="s">Precision: %s</span><span class="se">\n</span><span class="s">Recall: %s"</span> |
| <span class="o">%</span> <span class="p">(</span><span class="n">accuracy</span><span class="p">,</span> <span class="n">falsePositiveRate</span><span class="p">,</span> <span class="n">truePositiveRate</span><span class="p">,</span> <span class="n">fMeasure</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>More details on parameters can be found in the <a href="api/R/spark.logit.html">R API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a multinomial logistic regression model with spark.logit</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.logit</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">,</span><span class="w"> </span><span class="n">regParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.3</span><span class="p">,</span><span class="w"> </span><span class="n">elasticNetParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.8</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/logit.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="decision-tree-classifier">Decision tree classifier</h2> |
| |
| <p>Decision trees are a popular family of classification and regression methods. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further below in the <a href="#decision-trees">section on decision trees</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code class="language-plaintext highlighter-rouge">DataFrame</code> which the Decision Tree algorithm can recognize.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.html">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="nv">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous.</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="k">val</span> <span class="nv">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="nv">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setLabels</span><span class="o">(</span><span class="nv">labelIndexer</span><span class="o">.</span><span class="py">labelsArray</span><span class="o">(</span><span class="mi">0</span><span class="o">))</span> |
| |
| <span class="c1">// Chain indexers and tree in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test Error = ${(1.0 - accuracy)}"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">treeModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeClassificationModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned classification tree model:\n ${treeModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="na">read</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="nc">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with > 4 distinct values are treated as continuous.</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="nc">DecisionTreeClassifier</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">DecisionTreeClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="nc">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labelsArray</span><span class="o">()[</span><span class="mi">0</span><span class="o">]);</span> |
| |
| <span class="c1">// Chain indexers and tree in a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="nc">DecisionTreeClassificationModel</span> <span class="n">treeModel</span> <span class="o">=</span> |
| <span class="o">(</span><span class="nc">DecisionTreeClassificationModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/reference/api/pyspark.ml.classification.DecisionTreeClassifier.html">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">DecisionTreeClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load the data stored in LIBSVM format as a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column. |
| # Fit on whole dataset to include all labels in index. |
| </span><span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # We specify maxCategories so features with > 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a DecisionTree model. |
| </span><span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and tree in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test Error = %g "</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="c1"># summary only |
| </span><span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_classification_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.decisionTree.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a DecisionTree classification model with spark.decisionTree</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.decisionTree</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"classification"</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/decisionTree.R" in the Spark repo.</small></div> |
| |
| </div> |
| |
| </div> |
| |
| <h2 id="random-forest-classifier">Random forest classifier</h2> |
| |
| <p>Random forests are a popular family of classification and regression methods. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further down in the <a href="#random-forests">section on random forests</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code class="language-plaintext highlighter-rouge">DataFrame</code> which the tree-based algorithms can recognize.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/RandomForestClassifier.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">RandomForestClassificationModel</span><span class="o">,</span> <span class="nc">RandomForestClassifier</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="nv">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="k">val</span> <span class="nv">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setNumTrees</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="nv">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setLabels</span><span class="o">(</span><span class="nv">labelIndexer</span><span class="o">.</span><span class="py">labelsArray</span><span class="o">(</span><span class="mi">0</span><span class="o">))</span> |
| |
| <span class="c1">// Chain indexers and forest in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test Error = ${(1.0 - accuracy)}"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">rfModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestClassificationModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned classification forest model:\n ${rfModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/RandomForestClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="nc">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="nc">RandomForestClassifier</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="nc">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labelsArray</span><span class="o">()[</span><span class="mi">0</span><span class="o">]);</span> |
| |
| <span class="c1">// Chain indexers and forest in a Pipeline</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="nc">RandomForestClassificationModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">RandomForestClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.RandomForestClassifier.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">IndexToString</span><span class="p">,</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column. |
| # Fit on whole dataset to include all labels in index. |
| </span><span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # Set maxCategories so features with &gt; 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a RandomForest model. |
| </span><span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">numTrees</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Convert indexed labels back to original labels. |
| </span><span class="n">labelConverter</span> <span class="o">=</span> <span class="n">IndexToString</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"predictedLabel"</span><span class="p">,</span> |
| <span class="n">labels</span><span class="o">=</span><span class="n">labelIndexer</span><span class="p">.</span><span class="n">labels</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and forest in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">,</span> <span class="n">labelConverter</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"predictedLabel"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test Error = %g"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c1"># summary only</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/random_forest_classifier_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.randomForest.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a random forest classification model with spark.randomForest</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.randomForest</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"classification"</span><span class="p">,</span><span class="w"> </span><span class="n">numTrees</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/randomForest.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="gradient-boosted-tree-classifier">Gradient-boosted tree classifier</h2> |
| |
| <p>Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further down in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code class="language-plaintext highlighter-rouge">DataFrame</code> which the tree-based algorithms can recognize.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/GBTClassifier.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">GBTClassificationModel</span><span class="o">,</span> <span class="nc">GBTClassifier</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="nv">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="k">val</span> <span class="nv">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeatureSubsetStrategy</span><span class="o">(</span><span class="s">"auto"</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="nv">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setLabels</span><span class="o">(</span><span class="nv">labelIndexer</span><span class="o">.</span><span class="py">labelsArray</span><span class="o">(</span><span class="mi">0</span><span class="o">))</span> |
| |
| <span class="c1">// Chain indexers and GBT in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test Error = ${1.0 - accuracy}"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">gbtModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">GBTClassificationModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned classification GBT model:\n ${gbtModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/GBTClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="na">read</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="nc">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="nc">GBTClassifier</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="nc">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labelsArray</span><span class="o">()[</span><span class="mi">0</span><span class="o">]);</span> |
| |
| <span class="c1">// Chain indexers and GBT in a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexers.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span> |
| |
| <span class="nc">GBTClassificationModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">GBTClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned classification GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.GBTClassifier.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">GBTClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column. |
| # Fit on whole dataset to include all labels in index. |
| </span><span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # Set maxCategories so features with > 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a GBT model. |
| </span><span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexers and GBT in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexers. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test Error = %g"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span> |
| |
| <span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c1"># summary only</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.gbt.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a GBT classification model with spark.gbt</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.gbt</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"classification"</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/gbt.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="multilayer-perceptron-classifier">Multilayer perceptron classifier</h2> |
| |
| <p>Multilayer perceptron classifier (MLPC) is a classifier based on the <a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">feedforward artificial neural network</a>. |
| MLPC consists of multiple layers of nodes. |
| Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs |
| by a linear combination of the inputs with the node’s weights <code class="language-plaintext highlighter-rouge">$\wv$</code> and bias <code class="language-plaintext highlighter-rouge">$\bv$</code> and applying an activation function. |
| This can be written in matrix form for MLPC with <code class="language-plaintext highlighter-rouge">$K+1$</code> layers as follows: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+\bv_1)+\bv_2)...+\bv_K) |
| \]</code> |
| Nodes in intermediate layers use the sigmoid (logistic) function: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} |
| \]</code> |
| Nodes in the output layer use the softmax function: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} |
| \]</code> |
| The number of nodes <code class="language-plaintext highlighter-rouge">$N$</code> in the output layer corresponds to the number of classes.</p> |
| |
| <p>MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into train and test</span> |
| <span class="k">val</span> <span class="nv">splits</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">train</span> <span class="k">=</span> <span class="nf">splits</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">test</span> <span class="k">=</span> <span class="nf">splits</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span> |
| |
| <span class="c1">// specify layers for the neural network:</span> |
| <span class="c1">// input layer of size 4 (features), two intermediate layers of size 5 and 4</span> |
| <span class="c1">// and output of size 3 (classes)</span> |
| <span class="k">val</span> <span class="nv">layers</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">[</span><span class="kt">Int</span><span class="o">](</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">)</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="k">val</span> <span class="nv">trainer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">)</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">trainer</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="k">val</span> <span class="nv">result</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">predictionAndLabels</span> <span class="k">=</span> <span class="nv">result</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="nc">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">;</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">dataFrame</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="n">path</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into train and test</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// specify layers for the neural network:</span> |
| <span class="c1">// input layer of size 4 (features), two intermediate layers of size 5 and 4</span> |
| <span class="c1">// and output of size 3 (classes)</span> |
| <span class="kt">int</span><span class="o">[]</span> <span class="n">layers</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]</span> <span class="o">{</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">};</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="nc">MultilayerPerceptronClassifier</span> <span class="n">trainer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">);</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="nc">MultilayerPerceptronClassificationModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">);</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">));</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">MultilayerPerceptronClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load the data stored in LIBSVM format as a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into train and test |
| </span><span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| |
| <span class="c1"># specify layers for the neural network: |
| # input layer of size 4 (features), two intermediate layers of size 5 and 4 |
| # and output of size 3 (classes) |
| </span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> |
| |
| <span class="c1"># create the trainer and set its parameters |
| </span><span class="n">trainer</span> <span class="o">=</span> <span class="n">MultilayerPerceptronClassifier</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">layers</span><span class="o">=</span><span class="n">layers</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">1234</span><span class="p">)</span> |
| |
| <span class="c1"># train the model |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># compute accuracy on the test set |
| </span><span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">)</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/multilayer_perceptron_classification.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.mlp.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># specify layers for the neural network:</span><span class="w"> |
| </span><span class="c1"># input layer of size 4 (features), two intermediate of size 5 and 4</span><span class="w"> |
| </span><span class="c1"># and output of size 3 (classes)</span><span class="w"> |
| </span><span class="n">layers</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">4</span><span class="p">,</span><span class="w"> </span><span class="m">5</span><span class="p">,</span><span class="w"> </span><span class="m">4</span><span class="p">,</span><span class="w"> </span><span class="m">3</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Fit a multi-layer perceptron neural network model with spark.mlp</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.mlp</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">100</span><span class="p">,</span><span class="w"> |
| </span><span class="n">layers</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">layers</span><span class="p">,</span><span class="w"> </span><span class="n">blockSize</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">128</span><span class="p">,</span><span class="w"> </span><span class="n">seed</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1234</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/mlp.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="linear-support-vector-machine">Linear Support Vector Machine</h2> |
| |
| <p>A <a href="https://en.wikipedia.org/wiki/Support_vector_machine">support vector machine</a> constructs a hyperplane |
| or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification, |
| regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has |
| the largest distance to the nearest training-data points of any class (so-called functional margin), |
| since in general the larger the margin the lower the generalization error of the classifier. LinearSVC |
| in Spark ML supports binary classification with linear SVM. Internally, it optimizes the |
| <a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge Loss</a> using the OWLQN optimizer.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/LinearSVC.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVC</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">lsvc</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearSVC</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.1</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="nv">lsvcModel</span> <span class="k">=</span> <span class="nv">lsvc</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear svc</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${lsvcModel.coefficients} Intercept: ${lsvcModel.intercept}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/LinearSVC.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVC</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LinearSVCModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="nc">LinearSVC</span> <span class="n">lsvc</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LinearSVC</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.1</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="nc">LinearSVCModel</span> <span class="n">lsvcModel</span> <span class="o">=</span> <span class="n">lsvc</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for LinearSVC</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lsvcModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lsvcModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.LinearSVC.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LinearSVC</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lsvc</span> <span class="o">=</span> <span class="n">LinearSVC</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">lsvcModel</span> <span class="o">=</span> <span class="n">lsvc</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for linear SVC |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lsvcModel</span><span class="p">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lsvcModel</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/linearsvc.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.svmLinear.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># load training data</span><span class="w"> |
| </span><span class="n">t</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">as.data.frame</span><span class="p">(</span><span class="n">Titanic</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">createDataFrame</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># fit Linear SVM model</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.svmLinear</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">Survived</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">.</span><span class="p">,</span><span class="w"> </span><span class="n">regParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.01</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">prediction</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">training</span><span class="p">)</span><span class="w"> |
| </span><span class="n">showDF</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/svmLinear.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</h2> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest">OneVsRest</a> is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as “One-vs-All.”</p> |
| |
| <p><code class="language-plaintext highlighter-rouge">OneVsRest</code> is implemented as an <code class="language-plaintext highlighter-rouge">Estimator</code>. For the base classifier, it takes instances of <code class="language-plaintext highlighter-rouge">Classifier</code> and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.</p> |
| |
| <p>Predictions are made by evaluating each binary classifier; the index of the most confident classifier is output as the label.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The example below demonstrates how to load the |
| <a href="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale">Iris dataset</a>, parse it as a DataFrame and perform multiclass classification using <code class="language-plaintext highlighter-rouge">OneVsRest</code>. The test error is calculated to measure the algorithm’s accuracy.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/OneVsRest.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">LogisticRegression</span><span class="o">,</span> <span class="nc">OneVsRest</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// load data file.</span> |
| <span class="k">val</span> <span class="nv">inputData</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// generate the train/test split.</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">train</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="nv">inputData</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">))</span> |
| |
| <span class="c1">// instantiate the base classifier</span> |
| <span class="k">val</span> <span class="nv">classifier</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setTol</span><span class="o">(</span><span class="mi">1</span><span class="n">E</span><span class="o">-</span><span class="mi">6</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">)</span> |
| |
| <span class="c1">// instantiate the One Vs Rest Classifier.</span> |
| <span class="k">val</span> <span class="nv">ovr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">().</span><span class="py">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">)</span> |
| |
| <span class="c1">// train the multiclass model.</span> |
| <span class="k">val</span> <span class="nv">ovrModel</span> <span class="k">=</span> <span class="nv">ovr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span> |
| |
| <span class="c1">// score the model on test data.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">ovrModel</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| |
| <span class="c1">// obtain evaluator.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| |
| <span class="c1">// compute the classification error on test data.</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test Error = ${1 - accuracy}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/OneVsRest.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRest</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRestModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// load data file.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// generate the train/test split.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">tmp</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">train</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// configure the base classifier.</span> |
| <span class="nc">LogisticRegression</span> <span class="n">classifier</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setTol</span><span class="o">(</span><span class="mi">1</span><span class="no">E</span><span class="o">-</span><span class="mi">6</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">);</span> |
| |
| <span class="c1">// instantiate the One Vs Rest Classifier.</span> |
| <span class="nc">OneVsRest</span> <span class="n">ovr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">().</span><span class="na">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">);</span> |
| |
| <span class="c1">// train the multiclass model.</span> |
| <span class="nc">OneVsRestModel</span> <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// score the model on test data.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">);</span> |
| |
| <span class="c1">// obtain evaluator.</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| |
| <span class="c1">// compute the classification error on test data.</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Error = "</span> <span class="o">+</span> <span class="o">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.OneVsRest.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span><span class="p">,</span> <span class="n">OneVsRest</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># load data file. |
| </span><span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span> \ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># generate the train/test split. |
| </span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">test</span><span class="p">)</span> <span class="o">=</span> <span class="n">inputData</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">])</span> |
| |
| <span class="c1"># instantiate the base classifier. |
| </span><span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1E-6</span><span class="p">,</span> <span class="n">fitIntercept</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| |
| <span class="c1"># instantiate the One Vs Rest Classifier. |
| </span><span class="n">ovr</span> <span class="o">=</span> <span class="n">OneVsRest</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span> |
| |
| <span class="c1"># train the multiclass model. |
| </span><span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># score the model on test data. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| |
| <span class="c1"># obtain evaluator. |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| |
| <span class="c1"># compute the classification error on test data. |
| </span><span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test Error = %g"</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/one_vs_rest_example.py" in the Spark repo.</small></div> |
| </div> |
| </div> |
| |
| <h2 id="naive-bayes">Naive Bayes</h2> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier">Naive Bayes classifiers</a> are a family of simple |
| probabilistic, multiclass classifiers based on applying Bayes’ theorem with strong (naive) independence |
| assumptions between every pair of features.</p> |
| |
| <p>Naive Bayes can be trained very efficiently. With a single pass over the training data, |
| it computes the conditional probability distribution of each feature given each label. |
| For prediction, it applies Bayes’ theorem to compute the conditional probability distribution |
| of each label given an observation.</p> |
| |
| <p>MLlib supports <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">Multinomial naive Bayes</a>, |
| <a href="https://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf">Complement naive Bayes</a>, |
| <a href="https://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html">Bernoulli naive Bayes</a> |
| and <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes">Gaussian naive Bayes</a>.</p> |
| |
| <p><em>Input data</em>: |
| The Multinomial, Complement and Bernoulli models are typically used for <a href="https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html">document classification</a>. |
| Within that context, each observation is a document and each feature represents a term. |
| A feature’s value is the frequency of the term (in Multinomial or Complement Naive Bayes) or |
| a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). |
| Feature values for Multinomial and Bernoulli models must be <em>non-negative</em>. The model type is selected with the optional |
| <code>modelType</code> parameter, which accepts “multinomial”, “complement”, “bernoulli” or “gaussian”, with “multinomial” as the default. |
| For document classification, the input feature vectors should usually be sparse vectors. |
| Since the training data is only used once, it is not necessary to cache it.</p> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Lidstone_smoothing">Additive smoothing</a> can be used by |
| setting the parameter $\lambda$ (defaults to $1.0$).</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/NaiveBayes.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span> |
| |
| <span class="c1">// Train a NaiveBayes model.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">NaiveBayes</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">show</span><span class="o">()</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test accuracy</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test set accuracy = $accuracy"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/NaiveBayes.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayesModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">dataFrame</span> <span class="o">=</span> |
| <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| <span class="c1">// Split the data into train and test</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// create the trainer and set its parameters</span> |
| <span class="nc">NaiveBayes</span> <span class="n">nb</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">NaiveBayes</span><span class="o">();</span> |
| |
| <span class="c1">// train the model</span> |
| <span class="nc">NaiveBayesModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">show</span><span class="o">();</span> |
| |
| <span class="c1">// compute accuracy on the test set</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="n">accuracy</span><span class="o">);</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.NaiveBayes.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">NaiveBayes</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span> \ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into train and test |
| </span><span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span> |
| <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| |
| <span class="c1"># create the trainer and set its parameters |
| </span><span class="n">nb</span> <span class="o">=</span> <span class="n">NaiveBayes</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">modelType</span><span class="o">=</span><span class="s">"multinomial"</span><span class="p">)</span> |
| |
| <span class="c1"># train the model |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span> |
| |
| <span class="c1"># select example rows to display. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span> |
| <span class="n">predictions</span><span class="p">.</span><span class="n">show</span><span class="p">()</span> |
| |
| <span class="c1"># compute accuracy on the test set |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> |
| <span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test set accuracy = "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">accuracy</span><span class="p">))</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/naive_bayes_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.naiveBayes.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Fit a Bernoulli naive Bayes model with spark.naiveBayes</span><span class="w"> |
| </span><span class="n">titanic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">as.data.frame</span><span class="p">(</span><span class="n">Titanic</span><span class="p">)</span><span class="w"> |
| </span><span class="n">titanicDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">createDataFrame</span><span class="p">(</span><span class="n">titanic</span><span class="p">[</span><span class="n">titanic</span><span class="o">$</span><span class="n">Freq</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="m">-5</span><span class="p">])</span><span class="w"> |
| </span><span class="n">nbDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">titanicDF</span><span class="w"> |
| </span><span class="n">nbTestDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">titanicDF</span><span class="w"> |
| </span><span class="n">nbModel</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.naiveBayes</span><span class="p">(</span><span class="n">nbDF</span><span class="p">,</span><span class="w"> </span><span class="n">Survived</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">Class</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">Sex</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">Age</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">nbModel</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">nbPredictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">nbModel</span><span class="p">,</span><span class="w"> </span><span class="n">nbTestDF</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">nbPredictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/naiveBayes.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="factorization-machines-classifier">Factorization machines classifier</h2> |
| |
| <p>For more background and more details about the implementation of factorization machines, |
| refer to the <a href="ml-classification-regression.html#factorization-machines">Factorization Machines section</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, |
| train on the first dataset, and then evaluate on the held-out test set. |
| We scale features to be between 0 and 1 to prevent the exploding gradient problem.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/classification/FMClassifier.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">FMClassificationModel</span><span class="o">,</span> <span class="nc">FMClassifier</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">MinMaxScaler</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="k">val</span> <span class="nv">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| <span class="c1">// Scale features.</span> |
| <span class="k">val</span> <span class="nv">featureScaler</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MinMaxScaler</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a FM model.</span> |
| <span class="k">val</span> <span class="nv">fm</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FMClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setStepSize</span><span class="o">(</span><span class="mf">0.001</span><span class="o">)</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="k">val</span> <span class="nv">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setLabels</span><span class="o">(</span><span class="nv">labelIndexer</span><span class="o">.</span><span class="py">labelsArray</span><span class="o">(</span><span class="mi">0</span><span class="o">))</span> |
| |
| <span class="c1">// Create a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureScaler</span><span class="o">,</span> <span class="n">fm</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span> |
| |
| <span class="c1">// Train model.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test accuracy.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">accuracy</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Test set accuracy = $accuracy"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">fmModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">FMClassificationModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Factors: ${fmModel.factors} Linear: ${fmModel.linear} "</span> <span class="o">+</span> |
| <span class="n">s</span><span class="s">"Intercept: ${fmModel.intercept}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/FMClassifierExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/classification/FMClassifier.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.FMClassificationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.FMClassifier</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="na">read</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Index labels, adding metadata to the label column.</span> |
| <span class="c1">// Fit on whole dataset to include all labels in index.</span> |
| <span class="nc">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| <span class="c1">// Scale features.</span> |
| <span class="nc">MinMaxScalerModel</span> <span class="n">featureScaler</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MinMaxScaler</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train an FM model.</span> |
| <span class="nc">FMClassifier</span> <span class="n">fm</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">FMClassifier</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setStepSize</span><span class="o">(</span><span class="mf">0.001</span><span class="o">);</span> |
| |
| <span class="c1">// Convert indexed labels back to original labels.</span> |
| <span class="nc">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labelsArray</span><span class="o">()[</span><span class="mi">0</span><span class="o">]);</span> |
| |
| <span class="c1">// Create a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureScaler</span><span class="o">,</span> <span class="n">fm</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span> |
| |
| <span class="c1">// Train model.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"predictedLabel"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test accuracy.</span> |
| <span class="nc">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"indexedLabel"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"accuracy"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Test Accuracy = "</span> <span class="o">+</span> <span class="n">accuracy</span><span class="o">);</span> |
| |
| <span class="nc">FMClassificationModel</span> <span class="n">fmModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">FMClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Factors: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">factors</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Linear: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">linear</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaFMClassifierExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.classification.FMClassifier.html">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">FMClassifier</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">MinMaxScaler</span><span class="p">,</span> <span class="n">StringIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Index labels, adding metadata to the label column. |
| # Fit on whole dataset to include all labels in index. |
| </span><span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| <span class="c1"># Scale features. |
| </span><span class="n">featureScaler</span> <span class="o">=</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"scaledFeatures"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train an FM model. |
| </span><span class="n">fm</span> <span class="o">=</span> <span class="n">FMClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">"scaledFeatures"</span><span class="p">,</span> <span class="n">stepSize</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> |
| |
| <span class="c1"># Create a Pipeline. |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureScaler</span><span class="p">,</span> <span class="n">fm</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"indexedLabel"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test accuracy |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"indexedLabel"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"accuracy"</span><span class="p">)</span> |
| <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Test set accuracy = %g"</span> <span class="o">%</span> <span class="n">accuracy</span><span class="p">)</span> |
| |
| <span class="n">fmModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Factors: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">factors</span><span class="p">))</span> <span class="c1"># type: ignore |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Linear: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">linear</span><span class="p">))</span> <span class="c1"># type: ignore |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> <span class="c1"># type: ignore</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/fm_classifier_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.fmClassifier.html">R API docs</a> for more details.</p> |
| |
| <p>Note: At the moment SparkR doesn’t support feature scaling.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit an FM classification model</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.fmClassifier</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/fmClassifier.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h1 id="regression">Regression</h1> |
| |
| <h2 id="linear-regression">Linear regression</h2> |
| |
| <p>The interface for working with linear regression models and model |
| summaries is similar to the logistic regression case.</p> |
| |
| <blockquote> |
| <p>When fitting a LinearRegressionModel without an intercept on a dataset with a constant nonzero column using the “l-bfgs” solver, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior is the same as R’s glmnet but different from LIBSVM.</p> |
| </blockquote> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following |
| example demonstrates training an elastic net regularized linear |
| regression model and extracting model summary statistics.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/org/apache/spark/ml/regression/LinearRegression.html">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="nv">lrModel</span> <span class="k">=</span> <span class="nv">lr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear regression</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}"</span><span class="o">)</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="k">val</span> <span class="nv">trainingSummary</span> <span class="k">=</span> <span class="nv">lrModel</span><span class="o">.</span><span class="py">summary</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"numIterations: ${trainingSummary.totalIterations}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"objectiveHistory: [${trainingSummary.objectiveHistory.mkString("</span><span class="o">,</span><span class="s">")}]"</span><span class="o">)</span> |
| <span class="nv">trainingSummary</span><span class="o">.</span><span class="py">residuals</span><span class="o">.</span><span class="py">show</span><span class="o">()</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"RMSE: ${trainingSummary.rootMeanSquaredError}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"r2: ${trainingSummary.r2}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/LinearRegression.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">);</span> |
| |
| <span class="nc">LinearRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">LinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model.</span> |
| <span class="nc">LinearRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for linear regression.</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> |
| <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">" Intercept: "</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics.</span> |
| <span class="nc">LinearRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"numIterations: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">totalIterations</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"objectiveHistory: "</span> <span class="o">+</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">()));</span> |
| <span class="n">trainingSummary</span><span class="o">.</span><span class="na">residuals</span><span class="o">().</span><span class="na">show</span><span class="o">();</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"RMSE: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"r2: "</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| <!-- The Python example below prints the model training summary. --> |
| |
| <p>More details on parameters can be found in the <a href="api/python/reference/api/pyspark.ml.regression.LinearRegression.html#pyspark.ml.regression.LinearRegression">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">LinearRegression</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">lr</span> <span class="o">=</span> <span class="n">LinearRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for linear regression |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: %s"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: %s"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># Summarize the model over the training set and print out some metrics |
| </span><span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="p">.</span><span class="n">summary</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"numIterations: %d"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">totalIterations</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"objectiveHistory: %s"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">trainingSummary</span><span class="p">.</span><span class="n">objectiveHistory</span><span class="p">))</span> |
| <span class="n">trainingSummary</span><span class="p">.</span><span class="n">residuals</span><span class="p">.</span><span class="n">show</span><span class="p">()</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"RMSE: %f"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">rootMeanSquaredError</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"r2: %f"</span> <span class="o">%</span> <span class="n">trainingSummary</span><span class="p">.</span><span class="n">r2</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/linear_regression_with_elastic_net.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>More details on parameters can be found in the <a href="api/R/spark.lm.html">R API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a linear regression model</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.lm</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">regParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.3</span><span class="p">,</span><span class="w"> </span><span class="n">elasticNetParam</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.8</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Summarize</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/lm_with_elastic_net.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="generalized-linear-regression">Generalized linear regression</h2> |
| |
| <p>Contrasted with linear regression where the output is assumed to follow a Gaussian |
| distribution, <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">generalized linear models</a> (GLMs) are specifications of linear models where the response variable $Y_i$ follows some |
| distribution from the <a href="https://en.wikipedia.org/wiki/Exponential_family">exponential family of distributions</a>. |
| Spark’s <code class="language-plaintext highlighter-rouge">GeneralizedLinearRegression</code> interface |
| allows for flexible specification of GLMs which can be used for various types of |
| prediction problems including linear regression, Poisson regression, logistic regression, and others. |
| Currently in <code class="language-plaintext highlighter-rouge">spark.ml</code>, only a subset of the exponential family distributions is supported; the supported families are listed |
| <a href="#available-families">below</a>.</p> |
| |
| <p><strong>NOTE</strong>: Spark currently only supports up to 4096 features through its <code class="language-plaintext highlighter-rouge">GeneralizedLinearRegression</code> |
| interface, and will throw an exception if this constraint is exceeded. See the <a href="ml-advanced.html">advanced section</a> for more details. |
| Still, for linear and logistic regression, models with an increased number of features can be trained |
| using the <code class="language-plaintext highlighter-rouge">LinearRegression</code> and <code class="language-plaintext highlighter-rouge">LogisticRegression</code> estimators.</p> |
| |
| <p>GLMs require exponential family distributions that can be written in their “canonical” or “natural” form, aka |
| <a href="https://en.wikipedia.org/wiki/Natural_exponential_family">natural exponential family distributions</a>. The form of a natural exponential family distribution is given as:</p> |
| |
| \[f_Y(y|\theta, \tau) = h(y, \tau)\exp{\left( \frac{\theta \cdot y - A(\theta)}{d(\tau)} \right)}\] |
| |
| <p>where $\theta$ is the parameter of interest and $\tau$ is a dispersion parameter. In a GLM the response variable $Y_i$ is assumed to be drawn from a natural exponential family distribution:</p> |
| |
| \[Y_i \sim f\left(\cdot|\theta_i, \tau \right)\] |
| |
| <p>where the parameter of interest $\theta_i$ is related to the expected value of the response variable $\mu_i$ by</p> |
| |
| \[\mu_i = A'(\theta_i)\] |
| |
| <p>Here, $A'(\theta_i)$ is defined by the form of the distribution selected. GLMs also allow specification |
| of a link function, which defines the relationship between the expected value of the response variable $\mu_i$ |
| and the so-called <em>linear predictor</em> $\eta_i$:</p> |
| |
| \[g(\mu_i) = \eta_i = \vec{x_i}^T \cdot \vec{\beta}\] |
| |
| <p>Often, the link function is chosen such that $A' = g^{-1}$, which yields a simplified relationship |
| between the parameter of interest $\theta$ and the linear predictor $\eta$. In this case, the link |
| function $g(\mu)$ is said to be the “canonical” link function.</p> |
| |
| \[\theta_i = A'^{-1}(\mu_i) = g(g^{-1}(\eta_i)) = \eta_i\] |
| |
| <p>A GLM finds the regression coefficients $\vec{\beta}$ which maximize the likelihood function.</p> |
| |
| \[\max_{\vec{\beta}} \mathcal{L}(\vec{\theta}|\vec{y},X) = |
| \prod_{i=1}^{N} h(y_i, \tau) \exp{\left(\frac{y_i\theta_i - A(\theta_i)}{d(\tau)}\right)}\] |
| |
| <p>where the parameter of interest $\theta_i$ is related to the regression coefficients $\vec{\beta}$ |
| by</p> |
| |
| \[\theta_i = A'^{-1}(g^{-1}(\vec{x_i} \cdot \vec{\beta}))\] |
| |
| <p>Spark’s generalized linear regression interface also provides summary statistics for diagnosing the |
| fit of GLMs, including residuals, p-values, deviances, the Akaike information criterion, and |
| others.</p> |
| |
| <p>See <a href="http://data.princeton.edu/wws509/notes/">these notes</a> for a more comprehensive review of GLMs and their applications.</p> |
| |
| <h3 id="available-families">Available families</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th>Family</th> |
| <th>Response Type</th> |
| <th>Supported Links</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Gaussian</td> |
| <td>Continuous</td> |
| <td>Identity*, Log, Inverse</td> |
| </tr> |
| <tr> |
| <td>Binomial</td> |
| <td>Binary</td> |
| <td>Logit*, Probit, CLogLog</td> |
| </tr> |
| <tr> |
| <td>Poisson</td> |
| <td>Count</td> |
| <td>Log*, Identity, Sqrt</td> |
| </tr> |
| <tr> |
| <td>Gamma</td> |
| <td>Continuous</td> |
| <td>Inverse*, Identity, Log</td> |
| </tr> |
| <tr> |
| <td>Tweedie</td> |
| <td>Zero-inflated continuous</td> |
| <td>Power link function</td> |
| </tr> |
| </tbody> |
| <tfoot><tr><td colspan="3">* Canonical Link</td></tr></tfoot> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following example demonstrates training a GLM with a Gaussian response and identity link |
| function and extracting model summary statistics.</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegression</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="k">val</span> <span class="nv">dataset</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">glr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GeneralizedLinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setFamily</span><span class="o">(</span><span class="s">"gaussian"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setLink</span><span class="o">(</span><span class="s">"identity"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">glr</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients and intercept for generalized linear regression model</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${model.coefficients}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Intercept: ${model.intercept}"</span><span class="o">)</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="k">val</span> <span class="nv">summary</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">summary</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString("</span><span class="o">,</span><span class="s">")}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"T Values: ${summary.tValues.mkString("</span><span class="o">,</span><span class="s">")}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"P Values: ${summary.pValues.mkString("</span><span class="o">,</span><span class="s">")}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Dispersion: ${summary.dispersion}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Null Deviance: ${summary.nullDeviance}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Residual Degree Of Freedom Null: ${summary.residualDegreeOfFreedomNull}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Deviance: ${summary.deviance}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Residual Degree Of Freedom: ${summary.residualDegreeOfFreedom}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"AIC: ${summary.aic}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="s">"Deviance Residuals: "</span><span class="o">)</span> |
| <span class="nv">summary</span><span class="o">.</span><span class="py">residuals</span><span class="o">().</span><span class="py">show</span><span class="o">()</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// Load training data</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">);</span> |
| |
| <span class="nc">GeneralizedLinearRegression</span> <span class="n">glr</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">GeneralizedLinearRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setFamily</span><span class="o">(</span><span class="s">"gaussian"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setLink</span><span class="o">(</span><span class="s">"identity"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">);</span> |
| |
| <span class="c1">// Fit the model</span> |
| <span class="nc">GeneralizedLinearRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">glr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients and intercept for generalized linear regression model</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| |
| <span class="c1">// Summarize the model over the training set and print out some metrics</span> |
| <span class="nc">GeneralizedLinearRegressionTrainingSummary</span> <span class="n">summary</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficient Standard Errors: "</span> |
| <span class="o">+</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">coefficientStandardErrors</span><span class="o">()));</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"T Values: "</span> <span class="o">+</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">tValues</span><span class="o">()));</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"P Values: "</span> <span class="o">+</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">toString</span><span class="o">(</span><span class="n">summary</span><span class="o">.</span><span class="na">pValues</span><span class="o">()));</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Dispersion: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">dispersion</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Null Deviance: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">nullDeviance</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Residual Degree Of Freedom Null: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">residualDegreeOfFreedomNull</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Deviance: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">deviance</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Residual Degree Of Freedom: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">residualDegreeOfFreedom</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"AIC: "</span> <span class="o">+</span> <span class="n">summary</span><span class="o">.</span><span class="na">aic</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Deviance Residuals: "</span><span class="o">);</span> |
| <span class="n">summary</span><span class="o">.</span><span class="na">residuals</span><span class="o">().</span><span class="na">show</span><span class="o">();</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.GeneralizedLinearRegression.html#pyspark.ml.regression.GeneralizedLinearRegression">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GeneralizedLinearRegression</span> |
| |
| <span class="c1"># Load training data |
| </span><span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> |
| |
| <span class="n">glr</span> <span class="o">=</span> <span class="n">GeneralizedLinearRegression</span><span class="p">(</span><span class="n">family</span><span class="o">=</span><span class="s">"gaussian"</span><span class="p">,</span> <span class="n">link</span><span class="o">=</span><span class="s">"identity"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> |
| |
| <span class="c1"># Fit the model |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">glr</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients and intercept for generalized linear regression model |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> |
| |
| <span class="c1"># Summarize the model over the training set and print out some metrics |
| </span><span class="n">summary</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">summary</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Coefficient Standard Errors: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">coefficientStandardErrors</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"T Values: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">tValues</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"P Values: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">pValues</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Dispersion: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">dispersion</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Null Deviance: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">nullDeviance</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Residual Degree Of Freedom Null: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">residualDegreeOfFreedomNull</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Deviance: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">deviance</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Residual Degree Of Freedom: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">residualDegreeOfFreedom</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"AIC: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">summary</span><span class="p">.</span><span class="n">aic</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Deviance Residuals: "</span><span class="p">)</span> |
| <span class="n">summary</span><span class="p">.</span><span class="n">residuals</span><span class="p">().</span><span class="n">show</span><span class="p">()</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/generalized_linear_regression_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.glm.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="c1"># Fit a generalized linear model of family "gaussian" with spark.glm</span><span class="w"> |
| </span><span class="n">df_list</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">randomSplit</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">7</span><span class="p">,</span><span class="w"> </span><span class="m">3</span><span class="p">),</span><span class="w"> </span><span class="m">2</span><span class="p">)</span><span class="w"> |
| </span><span class="n">gaussianDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df_list</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w"> |
| </span><span class="n">gaussianTestDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df_list</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> |
| </span><span class="n">gaussianGLM</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.glm</span><span class="p">(</span><span class="n">gaussianDF</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">family</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"gaussian"</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">gaussianGLM</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">gaussianPredictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">gaussianGLM</span><span class="p">,</span><span class="w"> </span><span class="n">gaussianTestDF</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">gaussianPredictions</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Fit a generalized linear model with glm (R-compliant)</span><span class="w"> |
| </span><span class="n">gaussianGLM2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">glm</span><span class="p">(</span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">gaussianDF</span><span class="p">,</span><span class="w"> </span><span class="n">family</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"gaussian"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">gaussianGLM2</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Fit a generalized linear model of family "binomial" with spark.glm</span><span class="w"> |
| </span><span class="n">training2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">transform</span><span class="p">(</span><span class="n">training2</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">cast</span><span class="p">(</span><span class="n">training2</span><span class="o">$</span><span class="n">label</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="s2">"integer"</span><span class="p">))</span><span class="w"> |
| </span><span class="n">df_list2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">randomSplit</span><span class="p">(</span><span class="n">training2</span><span class="p">,</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">7</span><span class="p">,</span><span class="w"> </span><span class="m">3</span><span class="p">),</span><span class="w"> </span><span class="m">2</span><span class="p">)</span><span class="w"> |
| </span><span class="n">binomialDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df_list2</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w"> |
| </span><span class="n">binomialTestDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df_list2</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> |
| </span><span class="n">binomialGLM</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.glm</span><span class="p">(</span><span class="n">binomialDF</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">family</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"binomial"</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">binomialGLM</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">binomialPredictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">binomialGLM</span><span class="p">,</span><span class="w"> </span><span class="n">binomialTestDF</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">binomialPredictions</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Fit a generalized linear model of family "tweedie" with spark.glm</span><span class="w"> |
| </span><span class="n">training3</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tweedieDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">transform</span><span class="p">(</span><span class="n">training3</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">training3</span><span class="o">$</span><span class="n">label</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="nf">exp</span><span class="p">(</span><span class="n">randn</span><span class="p">(</span><span class="m">10</span><span class="p">)))</span><span class="w"> |
| </span><span class="n">tweedieGLM</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.glm</span><span class="p">(</span><span class="n">tweedieDF</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">family</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"tweedie"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">var.power</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1.2</span><span class="p">,</span><span class="w"> </span><span class="n">link.power</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">tweedieGLM</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/glm.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="decision-tree-regression">Decision tree regression</h2> |
| |
| <p>Decision trees are a popular family of classification and regression methods. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further down this page, in the <a href="#decision-trees">section on decision trees</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use a feature transformer to index categorical features, adding metadata to the <code class="language-plaintext highlighter-rouge">DataFrame</code> which the Decision Tree algorithm can recognize.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>More details on parameters can be found in the <a href="api/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.html">Scala API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Here, we treat features with > 4 distinct values as continuous.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="k">val</span> <span class="nv">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and tree in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">rmse</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Root Mean Squared Error (RMSE) on test data = $rmse"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">treeModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeRegressionModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned regression tree model:\n ${treeModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html">Java API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a DecisionTree model.</span> |
| <span class="nc">DecisionTreeRegressor</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">DecisionTreeRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and tree in a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="nc">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="nc">DecisionTreeRegressionModel</span> <span class="n">treeModel</span> <span class="o">=</span> |
| <span class="o">(</span><span class="nc">DecisionTreeRegressionModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression tree model:\n"</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>More details on parameters can be found in the <a href="api/python/reference/api/pyspark.ml.regression.DecisionTreeRegressor.html#pyspark.ml.regression.DecisionTreeRegressor">Python API documentation</a>.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">DecisionTreeRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load the data stored in LIBSVM format as a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # We specify maxCategories so features with > 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a DecisionTree model. |
| </span><span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and tree in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = %g"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="c1"># summary only |
| </span><span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/decision_tree_regression_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.decisionTree.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a DecisionTree regression model with spark.decisionTree</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.decisionTree</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"regression"</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/decisionTree.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="random-forest-regression">Random forest regression</h2> |
| |
| <p>Random forests are a popular family of classification and regression methods. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further down this page, in the <a href="#random-forests">section on random forests</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. |
| We use a feature transformer to index categorical features, adding metadata to the <code class="language-plaintext highlighter-rouge">DataFrame</code> which the tree-based algorithms can recognize.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/RandomForestRegressor.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">RandomForestRegressionModel</span><span class="o">,</span> <span class="nc">RandomForestRegressor</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="k">val</span> <span class="nv">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and forest in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">rmse</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Root Mean Squared Error (RMSE) on test data = $rmse"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">rfModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestRegressionModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned regression forest model:\n ${rfModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/RandomForestRegressor.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing)</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a RandomForest model.</span> |
| <span class="nc">RandomForestRegressor</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and forest in a Pipeline</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error</span> |
| <span class="nc">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="nc">RandomForestRegressionModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">RandomForestRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression forest model:\n"</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.RandomForestRegressor.html#pyspark.ml.regression.RandomForestRegressor">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # Set maxCategories so features with > 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a RandomForest model. |
| </span><span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and forest in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = %g"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c1"># summary only</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/random_forest_regressor_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.randomForest.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a random forest regression model with spark.randomForest</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">spark.randomForest</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"regression"</span><span class="p">,</span><span class="w"> </span><span class="n">numTrees</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/randomForest.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="gradient-boosted-tree-regression">Gradient-boosted tree regression</h2> |
| |
| <p>Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. |
| More information about the <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation can be found further down, in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>Note: For this example dataset, <code class="language-plaintext highlighter-rouge">GBTRegressor</code> actually only needs 1 iteration, but that will not |
| be true in general.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/GBTRegressor.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">GBTRegressionModel</span><span class="o">,</span> <span class="nc">GBTRegressor</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="k">val</span> <span class="nv">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="k">val</span> <span class="nv">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span> |
| |
| <span class="c1">// Chain indexer and GBT in a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">))</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">rmse</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Root Mean Squared Error (RMSE) on test data = $rmse"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">gbtModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">GBTRegressionModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Learned regression GBT model:\n ${gbtModel.toDebugString}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GBTRegressor.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Automatically identify categorical features, and index them.</span> |
| <span class="c1">// Set maxCategories so features with > 4 distinct values are treated as continuous.</span> |
| <span class="nc">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o">&lt;</span><span class="nc">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a GBT model.</span> |
| <span class="nc">GBTRegressor</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"indexedFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span> |
| |
| <span class="c1">// Chain indexer and GBT in a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">().</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">});</span> |
| |
| <span class="c1">// Train model. This also runs the indexer.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="nc">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="nc">GBTRegressionModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">GBTRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Learned regression GBT model:\n"</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.GBTRegressor.html#pyspark.ml.regression.GBTRegressor">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GBTRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Automatically identify categorical features, and index them. |
| # Set maxCategories so features with > 4 distinct values are treated as continuous. |
| </span><span class="n">featureIndexer</span> <span class="o">=</span>\ |
| <span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a GBT model. |
| </span><span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"indexedFeatures"</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| |
| <span class="c1"># Chain indexer and GBT in a Pipeline |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. This also runs the indexer. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = %g"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c1"># summary only</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.gbt.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit a GBT regression model with spark.gbt</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.gbt</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="s2">"regression"</span><span class="p">,</span><span class="w"> </span><span class="n">maxIter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/gbt.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="survival-regression">Survival regression</h2> |
| |
| <p>In <code class="language-plaintext highlighter-rouge">spark.ml</code>, we implement the <a href="https://en.wikipedia.org/wiki/Accelerated_failure_time_model">Accelerated failure time (AFT)</a> |
| model which is a parametric survival regression model for censored data. |
| It describes a model for the log of survival time, so it’s often called a |
| log-linear model for survival analysis. Different from a |
| <a href="https://en.wikipedia.org/wiki/Proportional_hazards_model">Proportional hazards</a> model |
| designed for the same purpose, the AFT model is easier to parallelize |
| because each instance contributes to the objective function independently.</p> |
| |
| <p>Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of |
| subjects i = 1, …, n, with possible right-censoring, |
| the likelihood function under the AFT model is given as: |
| <code class="language-plaintext highlighter-rouge">\[ |
| L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} |
| \]</code> |
| Where $\delta_{i}$ indicates whether the event has occurred, i.e. whether the observation is uncensored or not. |
| Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function |
| assumes the form: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] |
| \]</code> |
| Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, |
| and $f_{0}(\epsilon_{i})$ is the corresponding density function.</p> |
| |
| <p>The most commonly used AFT model is based on the Weibull distribution of the survival time. |
| The Weibull distribution for lifetime corresponds to the extreme value distribution for the |
| log of the lifetime, and the $S_{0}(\epsilon)$ function is: |
| <code class="language-plaintext highlighter-rouge">\[ |
| S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) |
| \]</code> |
| the $f_{0}(\epsilon_{i})$ function is: |
| <code class="language-plaintext highlighter-rouge">\[ |
| f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) |
| \]</code> |
| The log-likelihood function for AFT model with a Weibull distribution of lifetime is: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] |
| \]</code> |
| Because minimizing the negative log-likelihood is equivalent to maximizing the a posteriori probability, |
| the loss function we use to optimize is $-\iota(\beta,\sigma)$. |
| The gradient functions for $\beta$ and $\log\sigma$ respectively are: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} |
| \]</code> |
| <code class="language-plaintext highlighter-rouge">\[ |
| \frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] |
| \]</code></p> |
| |
| <p>The AFT model can be formulated as a convex optimization problem, |
| i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ |
| that depends on the coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. |
| The optimization algorithm underlying the implementation is L-BFGS. |
| The implementation matches the result from R’s survival function |
| <a href="https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html">survreg</a>.</p> |
| |
| <blockquote> |
| <p>When fitting AFTSurvivalRegressionModel without intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg.</p> |
| </blockquote> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span> |
| |
| <span class="k">val</span> <span class="nv">training</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span> |
| <span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> |
| <span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nv">Vectors</span><span class="o">.</span><span class="py">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> |
| <span class="o">)).</span><span class="py">toDF</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="s">"censor"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">quantileProbabilities</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">aft</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">AFTSurvivalRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">aft</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Coefficients: ${model.coefficients}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Intercept: ${model.intercept}"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Scale: ${model.scale}"</span><span class="o">)</span> |
| <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="kc">false</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.VectorUDT</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.DataTypes</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.Metadata</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructField</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.StructType</span><span class="o">;</span> |
| |
| <span class="nc">List</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="nc">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span> |
| <span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span> |
| <span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span> |
| <span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span> |
| <span class="nc">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span> |
| <span class="o">);</span> |
| <span class="nc">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">StructType</span><span class="o">(</span><span class="k">new</span> <span class="nc">StructField</span><span class="o">[]{</span> |
| <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"label"</span><span class="o">,</span> <span class="nc">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="nc">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> |
| <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"censor"</span><span class="o">,</span> <span class="nc">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="nc">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span> |
| <span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">"features"</span><span class="o">,</span> <span class="k">new</span> <span class="nc">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span class="o">,</span> <span class="nc">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span> |
| <span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">};</span> |
| <span class="nc">AFTSurvivalRegression</span> <span class="n">aft</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">AFTSurvivalRegression</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setQuantilesCol</span><span class="o">(</span><span class="s">"quantiles"</span><span class="o">);</span> |
| |
| <span class="nc">AFTSurvivalRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span> |
| |
| <span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Scale: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">scale</span><span class="o">());</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="kc">false</span><span class="o">);</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.AFTSurvivalRegression.html#pyspark.ml.regression.AFTSurvivalRegression">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">AFTSurvivalRegression</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span> |
| |
| <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">createDataFrame</span><span class="p">([</span> |
| <span class="p">(</span><span class="mf">1.218</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.560</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">2.949</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.346</span><span class="p">,</span> <span class="mf">2.158</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">3.627</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.380</span><span class="p">,</span> <span class="mf">0.231</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">0.273</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.520</span><span class="p">,</span> <span class="mf">1.151</span><span class="p">)),</span> |
| <span class="p">(</span><span class="mf">4.199</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.795</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="p">))],</span> <span class="p">[</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"censor"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">])</span> |
| <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]</span> |
| <span class="n">aft</span> <span class="o">=</span> <span class="n">AFTSurvivalRegression</span><span class="p">(</span><span class="n">quantileProbabilities</span><span class="o">=</span><span class="n">quantileProbabilities</span><span class="p">,</span> |
| <span class="n">quantilesCol</span><span class="o">=</span><span class="s">"quantiles"</span><span class="p">)</span> |
| |
| <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Print the coefficients, intercept and scale parameter for AFT survival regression |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Coefficients: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">coefficients</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Scale: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">scale</span><span class="p">))</span> |
| <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">training</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="n">truncate</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/aft_survival_regression.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.survreg.html">R API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Use the ovarian dataset available in R survival package</span><span class="w"> |
| </span><span class="n">library</span><span class="p">(</span><span class="n">survival</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Fit an accelerated failure time (AFT) survival regression model with spark.survreg</span><span class="w"> |
| </span><span class="n">ovarianDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">suppressWarnings</span><span class="p">(</span><span class="n">createDataFrame</span><span class="p">(</span><span class="n">ovarian</span><span class="p">))</span><span class="w"> |
| </span><span class="n">aftDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">ovarianDF</span><span class="w"> |
| </span><span class="n">aftTestDF</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">ovarianDF</span><span class="w"> |
| </span><span class="n">aftModel</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.survreg</span><span class="p">(</span><span class="n">aftDF</span><span class="p">,</span><span class="w"> </span><span class="n">Surv</span><span class="p">(</span><span class="n">futime</span><span class="p">,</span><span class="w"> </span><span class="n">fustat</span><span class="p">)</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">ecog_ps</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">rx</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">aftModel</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">aftPredictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">aftModel</span><span class="p">,</span><span class="w"> </span><span class="n">aftTestDF</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">aftPredictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/survreg.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="isotonic-regression">Isotonic regression</h2> |
| <p><a href="https://en.wikipedia.org/wiki/Isotonic_regression">Isotonic regression</a> |
| belongs to the family of regression algorithms. Formally, isotonic regression is a problem where, |
| given a finite set of real numbers <code class="language-plaintext highlighter-rouge">$Y = \{y_1, y_2, ..., y_n\}$</code> representing observed responses |
| and <code class="language-plaintext highlighter-rouge">$X = \{x_1, x_2, ..., x_n\}$</code> the unknown response values to be fitted, |
| the goal is to find a function that minimizes</p> |
| |
| <p><code class="language-plaintext highlighter-rouge">\begin{equation} |
| f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 |
| \end{equation}</code></p> |
| |
| <p>with respect to a complete order, subject to |
| <code class="language-plaintext highlighter-rouge">$x_1\le x_2\le ...\le x_n$</code> where <code class="language-plaintext highlighter-rouge">$w_i$</code> are positive weights. |
| The resulting function is called isotonic regression and it is unique. |
| It can be viewed as a least squares problem under an order restriction. |
| Essentially, isotonic regression is a |
| <a href="https://en.wikipedia.org/wiki/Monotonic_function">monotonic function</a> |
| best fitting the original data points.</p> |
| |
| <p>We implement a |
| <a href="https://doi.org/10.1198/TECH.2010.10111">pool adjacent violators algorithm</a> |
| which uses an approach to |
| <a href="https://doi.org/10.1007/978-3-642-99789-1_10">parallelizing isotonic regression</a>. |
| The training input is a DataFrame which contains three columns: |
| label, features and weight. Additionally, the IsotonicRegression algorithm has one |
| optional parameter called $isotonic$, defaulting to true. |
| This argument specifies if the isotonic regression is |
| isotonic (monotonically increasing) or antitonic (monotonically decreasing).</p> |
| |
| <p>Training returns an IsotonicRegressionModel that can be used to predict |
| labels for both known and unknown features. The result of isotonic regression |
| is treated as a piecewise linear function. The rules for prediction therefore are:</p> |
| |
| <ul> |
| <li>If the prediction input exactly matches a training feature |
| then the associated prediction is returned. In case there are multiple predictions with the same |
| feature then one of them is returned. Which one is undefined |
| (same as java.util.Arrays.binarySearch).</li> |
| <li>If the prediction input is lower or higher than all training features |
| then the prediction with the lowest or highest feature is returned respectively. |
| In case there are multiple predictions with the same feature |
| then the lowest or highest is returned respectively.</li> |
| <li>If the prediction input falls between two training features then the prediction is treated |
| as a piecewise linear function and the interpolated value is calculated from the |
| predictions of the two closest features. In case there are multiple values |
| with the same feature then the same rules as in the previous point are used.</li> |
| </ul> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/IsotonicRegression.html"><code class="language-plaintext highlighter-rouge">IsotonicRegression</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegression</span> |
| |
| <span class="c1">// Loads data.</span> |
| <span class="k">val</span> <span class="nv">dataset</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Trains an isotonic regression model.</span> |
| <span class="k">val</span> <span class="nv">ir</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IsotonicRegression</span><span class="o">()</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">ir</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">)</span> |
| |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Boundaries in increasing order: ${model.boundaries}\n"</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Predictions associated with the boundaries: ${model.predictions}\n"</span><span class="o">)</span> |
| |
| <span class="c1">// Makes predictions.</span> |
| <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">dataset</span><span class="o">).</span><span class="py">show</span><span class="o">()</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala" in the Spark repo.</small></div> |
| </div> |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/IsotonicRegression.html"><code class="language-plaintext highlighter-rouge">IsotonicRegression</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegression</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.IsotonicRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| |
| <span class="c1">// Loads data.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Trains an isotonic regression model.</span> |
| <span class="nc">IsotonicRegression</span> <span class="n">ir</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">IsotonicRegression</span><span class="o">();</span> |
| <span class="nc">IsotonicRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">ir</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">dataset</span><span class="o">);</span> |
| |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Boundaries in increasing order: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">boundaries</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Predictions associated with the boundaries: "</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">predictions</span><span class="o">()</span> <span class="o">+</span> <span class="s">"\n"</span><span class="o">);</span> |
| |
| <span class="c1">// Makes predictions.</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">dataset</span><span class="o">).</span><span class="na">show</span><span class="o">();</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java" in the Spark repo.</small></div> |
| </div> |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.IsotonicRegression.html#pyspark.ml.regression.IsotonicRegression"><code class="language-plaintext highlighter-rouge">IsotonicRegression</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">IsotonicRegression</span> |
| |
| <span class="c1"># Loads data. |
| </span><span class="n">dataset</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">)</span>\ |
| <span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Trains an isotonic regression model. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">IsotonicRegression</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Boundaries in increasing order: %s</span><span class="se">\n</span><span class="s">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">boundaries</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Predictions associated with the boundaries: %s</span><span class="se">\n</span><span class="s">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">predictions</span><span class="p">))</span> |
| |
| <span class="c1"># Makes predictions. |
| </span><span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">).</span><span class="n">show</span><span class="p">()</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/isotonic_regression_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.isoreg.html"><code class="language-plaintext highlighter-rouge">IsotonicRegression</code> R API docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_isotonic_regression_libsvm_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">df</span><span class="w"> |
| |
| </span><span class="c1"># Fit an isotonic regression model with spark.isoreg</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.isoreg</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">isotonic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/isoreg.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h2 id="factorization-machines-regressor">Factorization machines regressor</h2> |
| |
| <p>For more background and more details about the implementation of factorization machines, |
| refer to the <a href="ml-classification-regression.html#factorization-machines">Factorization Machines section</a>.</p> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following examples load a dataset in LibSVM format, split it into training and test sets, |
| train on the first dataset, and then evaluate on the held-out test set. |
| We scale features to be between 0 and 1 to prevent the exploding gradient problem.</p> |
| |
| <div class="codetabs"> |
| <div data-lang="scala"> |
| |
| <p>Refer to the <a href="api/scala/org/apache/spark/ml/regression/FMRegressor.html">Scala API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.MinMaxScaler</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">FMRegressionModel</span><span class="o">,</span> <span class="nc">FMRegressor</span><span class="o">}</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">spark</span><span class="o">.</span><span class="py">read</span><span class="o">.</span><span class="py">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="py">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Scale features.</span> |
| <span class="k">val</span> <span class="nv">featureScaler</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MinMaxScaler</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setOutputCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="k">val</span> <span class="nv">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="nv">data</span><span class="o">.</span><span class="py">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span> |
| |
| <span class="c1">// Train a FM model.</span> |
| <span class="k">val</span> <span class="nv">fm</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FMRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setFeaturesCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setStepSize</span><span class="o">(</span><span class="mf">0.001</span><span class="o">)</span> |
| |
| <span class="c1">// Create a Pipeline.</span> |
| <span class="k">val</span> <span class="nv">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureScaler</span><span class="o">,</span> <span class="n">fm</span><span class="o">))</span> |
| |
| <span class="c1">// Train model.</span> |
| <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">pipeline</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="k">val</span> <span class="nv">predictions</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="nv">predictions</span><span class="o">.</span><span class="py">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="py">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="k">val</span> <span class="nv">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="py">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="py">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">)</span> |
| <span class="k">val</span> <span class="nv">rmse</span> <span class="k">=</span> <span class="nv">evaluator</span><span class="o">.</span><span class="py">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Root Mean Squared Error (RMSE) on test data = $rmse"</span><span class="o">)</span> |
| |
| <span class="k">val</span> <span class="nv">fmModel</span> <span class="k">=</span> <span class="nv">model</span><span class="o">.</span><span class="py">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="py">asInstanceOf</span><span class="o">[</span><span class="kt">FMRegressionModel</span><span class="o">]</span> |
| <span class="nf">println</span><span class="o">(</span><span class="n">s</span><span class="s">"Factors: ${fmModel.factors} Linear: ${fmModel.linear} "</span> <span class="o">+</span> |
| <span class="n">s</span><span class="s">"Intercept: ${fmModel.intercept}"</span><span class="o">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/FMRegressorExample.scala" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="java"> |
| |
| <p>Refer to the <a href="api/java/org/apache/spark/ml/regression/FMRegressor.html">Java API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.MinMaxScaler</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.MinMaxScalerModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.FMRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.FMRegressor</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data file, converting it to a DataFrame.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="o">);</span> |
| |
| <span class="c1">// Scale features.</span> |
| <span class="nc">MinMaxScalerModel</span> <span class="n">featureScaler</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">MinMaxScaler</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">"features"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Split the data into training and test sets (30% held out for testing).</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Train a FM model.</span> |
| <span class="nc">FMRegressor</span> <span class="n">fm</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">FMRegressor</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">"scaledFeatures"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setStepSize</span><span class="o">(</span><span class="mf">0.001</span><span class="o">);</span> |
| |
| <span class="c1">// Create a Pipeline.</span> |
| <span class="nc">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">().</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="nc">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureScaler</span><span class="o">,</span> <span class="n">fm</span><span class="o">});</span> |
| |
| <span class="c1">// Train model.</span> |
| <span class="nc">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span> |
| |
| <span class="c1">// Make predictions.</span> |
| <span class="nc">Dataset</span><span class="o"><</span><span class="nc">Row</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span> |
| |
| <span class="c1">// Select example rows to display.</span> |
| <span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">,</span> <span class="s">"label"</span><span class="o">,</span> <span class="s">"features"</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span> |
| |
| <span class="c1">// Select (prediction, true label) and compute test error.</span> |
| <span class="nc">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">"label"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">"prediction"</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">"rmse"</span><span class="o">);</span> |
| <span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = "</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span> |
| |
| <span class="nc">FMRegressionModel</span> <span class="n">fmModel</span> <span class="o">=</span> <span class="o">(</span><span class="nc">FMRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Factors: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">factors</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Linear: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">linear</span><span class="o">());</span> |
| <span class="nc">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="n">fmModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaFMRegressorExample.java" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="python"> |
| |
| <p>Refer to the <a href="api/python/reference/api/pyspark.ml.regression.FMRegressor.html#pyspark.ml.regression.FMRegressor">Python API docs</a> for more details.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">FMRegressor</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">MinMaxScaler</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span> |
| |
| <span class="c1"># Load and parse the data file, converting it to a DataFrame. |
| </span><span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="p">.</span><span class="n">read</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="s">"libsvm"</span><span class="p">).</span><span class="n">load</span><span class="p">(</span><span class="s">"data/mllib/sample_libsvm_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Scale features. |
| </span><span class="n">featureScaler</span> <span class="o">=</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">"features"</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">"scaledFeatures"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> |
| |
| <span class="c1"># Split the data into training and test sets (30% held out for testing) |
| </span><span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="p">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span> |
| |
| <span class="c1"># Train a FM model. |
| </span><span class="n">fm</span> <span class="o">=</span> <span class="n">FMRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">"scaledFeatures"</span><span class="p">,</span> <span class="n">stepSize</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> |
| |
| <span class="c1"># Create a Pipeline. |
| </span><span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureScaler</span><span class="p">,</span> <span class="n">fm</span><span class="p">])</span> |
| |
| <span class="c1"># Train model. |
| </span><span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span> |
| |
| <span class="c1"># Make predictions. |
| </span><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span> |
| |
| <span class="c1"># Select example rows to display. |
| </span><span class="n">predictions</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="s">"prediction"</span><span class="p">,</span> <span class="s">"label"</span><span class="p">,</span> <span class="s">"features"</span><span class="p">).</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> |
| |
| <span class="c1"># Select (prediction, true label) and compute test error |
| </span><span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="s">"label"</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">"prediction"</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">"rmse"</span><span class="p">)</span> |
| <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Root Mean Squared Error (RMSE) on test data = %g"</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span> |
| |
| <span class="n">fmModel</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> |
| <span class="k">print</span><span class="p">(</span><span class="s">"Factors: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">factors</span><span class="p">))</span> <span class="c1"># type: ignore |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Linear: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">linear</span><span class="p">))</span> <span class="c1"># type: ignore |
| </span><span class="k">print</span><span class="p">(</span><span class="s">"Intercept: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">fmModel</span><span class="p">.</span><span class="n">intercept</span><span class="p">))</span> <span class="c1"># type: ignore</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/python/ml/fm_regressor_example.py" in the Spark repo.</small></div> |
| </div> |
| |
| <div data-lang="r"> |
| |
| <p>Refer to the <a href="api/R/spark.fmRegressor.html">R API documentation</a> for more details.</p> |
| |
| <p>Note: At the moment SparkR doesn’t support feature scaling.</p> |
| |
| <div class="highlight"><pre class="codehilite"><code><span class="c1"># Load training data</span><span class="w"> |
| </span><span class="n">df</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.df</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">source</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"libsvm"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">training_test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">randomSplit</span><span class="p">(</span><span class="n">df</span><span class="p">,</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">0.7</span><span class="p">,</span><span class="w"> </span><span class="m">0.3</span><span class="p">))</span><span class="w"> |
| </span><span class="n">training</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">training_test</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">training_test</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> |
| |
| </span><span class="c1"># Fit a FM regression model</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">spark.fmRegressor</span><span class="p">(</span><span class="n">training</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">~</span><span class="w"> </span><span class="n">features</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Model summary</span><span class="w"> |
| </span><span class="n">summary</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="c1"># Prediction</span><span class="w"> |
| </span><span class="n">predictions</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="n">head</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span></code></pre></div> |
| <div><small>Find full example code at "examples/src/main/r/ml/fmRegressor.R" in the Spark repo.</small></div> |
| </div> |
| |
| </div> |
| |
| <h1 id="linear-methods">Linear methods</h1> |
| |
| <p>We implement popular linear methods such as logistic |
| regression and linear least squares with $L_1$ or $L_2$ regularization. |
| Refer to <a href="mllib-linear-methods.html">the linear methods guide for the RDD-based API</a> for |
| details about implementation and tuning; this information is still relevant.</p> |
| |
| <p>We also include a DataFrame API for <a href="http://en.wikipedia.org/wiki/Elastic_net_regularization">Elastic |
| net</a>, a hybrid |
| of $L_1$ and $L_2$ regularization proposed in <a href="http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf">Zou et al, Regularization |
| and variable selection via the elastic |
| net</a>. |
| Mathematically, it is defined as a convex combination of the $L_1$ and |
| the $L_2$ regularization terms: |
| <code class="language-plaintext highlighter-rouge">\[ |
| \alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 |
| \]</code> |
| By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ |
| regularization as special cases. For example, if a <a href="https://en.wikipedia.org/wiki/Linear_regression">linear |
| regression</a> model is |
| trained with the elastic net parameter $\alpha$ set to $1$, it is |
| equivalent to a |
| <a href="http://en.wikipedia.org/wiki/Least_squares#Lasso_method">Lasso</a> model. |
| On the other hand, if $\alpha$ is set to $0$, the trained model reduces |
| to a <a href="http://en.wikipedia.org/wiki/Tikhonov_regularization">ridge |
| regression</a> model. |
| We implement Pipelines API for both linear regression and logistic |
| regression with elastic net regularization.</p> |
| |
| <h1 id="factorization-machines">Factorization Machines</h1> |
| |
| <p><a href="https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf">Factorization Machines</a> are able to estimate interactions |
| between features even in problems with huge sparsity (like advertising and recommendation systems). |
| The <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation supports factorization machines for binary classification and for regression.</p> |
| |
| <p>The factorization machines formula is:</p> |
| |
| \[\hat{y} = w_0 + \sum\limits^n_{i=1} w_i x_i + |
| \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j\] |
| |
| <p>The first two terms denote the intercept and the linear term (same as in linear regression), |
| and the last term denotes the pairwise interaction term. \(v_i\) describes the i-th variable |
| with k factors.</p> |
| |
| <p>FM can be used for regression, where the optimization criterion is mean squared error. FM can also be used for |
| binary classification through the sigmoid function, where the optimization criterion is logistic loss.</p> |
| |
| <p>The pairwise interactions can be reformulated:</p> |
| |
| \[\sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j |
| = \frac{1}{2}\sum\limits^k_{f=1} |
| \left(\left( \sum\limits^n_{i=1}v_{i,f}x_i \right)^2 - |
| \sum\limits^n_{i=1}v_{i,f}^2x_i^2 \right)\] |
| |
| <p>This equation has only linear complexity in both k and n - i.e. its computation is in \(O(kn)\).</p> |
| |
| <p>In general, in order to prevent the exploding gradient problem, it is best to scale continuous features to be between 0 and 1, |
| or bin the continuous features and one-hot encode them.</p> |
| |
| <h1 id="decision-trees">Decision trees</h1> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Decision_tree_learning">Decision trees</a> |
| and their ensembles are popular methods for the machine learning tasks of |
| classification and regression. Decision trees are widely used since they are easy to interpret, |
| handle categorical features, extend to the multiclass classification setting, do not require |
| feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble |
| algorithms such as random forests and boosting are among the top performers for classification and |
| regression tasks.</p> |
| |
| <p>The <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation supports decision trees for binary and multiclass classification and for regression, |
| using both continuous and categorical features. The implementation partitions data by rows, |
| allowing distributed training with millions or even billions of instances.</p> |
| |
| <p>Users can find more information about the decision tree algorithm in the <a href="mllib-decision-tree.html">MLlib Decision Tree guide</a>. |
| The main differences between this API and the <a href="mllib-decision-tree.html">original MLlib Decision Tree API</a> are:</p> |
| |
| <ul> |
| <li>support for ML Pipelines</li> |
| <li>separation of Decision Trees for classification vs. regression</li> |
| <li>use of DataFrame metadata to distinguish continuous and categorical features</li> |
| </ul> |
| |
| <p>The Pipelines API for Decision Trees offers a bit more functionality than the original API.<br /> |
| In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); |
| for regression, users can get the biased sample variance of prediction.</p> |
| |
| <p>Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the <a href="#tree-ensembles">Tree ensembles section</a>.</p> |
| |
| <h2 id="inputs-and-outputs">Inputs and Outputs</h2> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h3 id="input-columns">Input Columns</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h3 id="output-columns">Output Columns</h3> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| <tr> |
| <td>rawPredictionCol</td> |
| <td>Vector</td> |
| <td>"rawPrediction"</td> |
| <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>probabilityCol</td> |
| <td>Vector</td> |
| <td>"probability"</td> |
| <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>varianceCol</td> |
| <td>Double</td> |
| <td></td> |
| <td>The biased sample variance of prediction</td> |
| <td>Regression only</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h1 id="tree-ensembles">Tree Ensembles</h1> |
| |
| <p>The DataFrame API supports two major tree ensemble algorithms: <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forests</a> and <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>. |
| Both use <a href="ml-classification-regression.html#decision-trees"><code class="language-plaintext highlighter-rouge">spark.ml</code> decision trees</a> as their base models.</p> |
| |
| <p>Users can find more information about ensemble algorithms in the <a href="mllib-ensembles.html">MLlib Ensemble guide</a>.<br /> |
| In this section, we demonstrate the DataFrame API for ensembles.</p> |
| |
| <p>The main differences between this API and the <a href="mllib-ensembles.html">original MLlib ensembles API</a> are:</p> |
| |
| <ul> |
| <li>support for DataFrames and ML Pipelines</li> |
| <li>separation of classification vs. regression</li> |
| <li>use of DataFrame metadata to distinguish continuous and categorical features</li> |
| <li>more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.</li> |
| </ul> |
| |
| <h2 id="random-forests">Random Forests</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Random_forest">Random forests</a> |
| are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>. |
| Random forests combine many decision trees in order to reduce the risk of overfitting. |
| The <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation supports random forests for binary and multiclass classification and for regression, |
| using both continuous and categorical features.</p> |
| |
| <p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html#random-forests"><code class="language-plaintext highlighter-rouge">spark.mllib</code> documentation on random forests</a>.</p> |
| |
| <h3 id="inputs-and-outputs-1">Inputs and Outputs</h3> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h4 id="input-columns-1">Input Columns</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h4 id="output-columns-predictions">Output Columns (Predictions)</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| <tr> |
| <td>rawPredictionCol</td> |
| <td>Vector</td> |
| <td>"rawPrediction"</td> |
| <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td> |
| <td>Classification only</td> |
| </tr> |
| <tr> |
| <td>probabilityCol</td> |
| <td>Vector</td> |
| <td>"probability"</td> |
| <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td> |
| <td>Classification only</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <h2 id="gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</h2> |
| |
| <p><a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a> |
| are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>. |
| GBTs iteratively train decision trees in order to minimize a loss function. |
| The <code class="language-plaintext highlighter-rouge">spark.ml</code> implementation supports GBTs for binary classification and for regression, |
| using both continuous and categorical features.</p> |
| |
| <p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html#gradient-boosted-trees-gbts"><code class="language-plaintext highlighter-rouge">spark.mllib</code> documentation on GBTs</a>.</p> |
| |
| <h3 id="inputs-and-outputs-2">Inputs and Outputs</h3> |
| |
| <p>We list the input and output (prediction) column types here. |
| All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p> |
| |
| <h4 id="input-columns-2">Input Columns</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>labelCol</td> |
| <td>Double</td> |
| <td>"label"</td> |
| <td>Label to predict</td> |
| </tr> |
| <tr> |
| <td>featuresCol</td> |
| <td>Vector</td> |
| <td>"features"</td> |
| <td>Feature vector</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p>Note that <code class="language-plaintext highlighter-rouge">GBTClassifier</code> currently only supports binary labels.</p> |
| |
| <h4 id="output-columns-predictions-1">Output Columns (Predictions)</h4> |
| |
| <table class="table"> |
| <thead> |
| <tr> |
| <th align="left">Param name</th> |
| <th align="left">Type(s)</th> |
| <th align="left">Default</th> |
| <th align="left">Description</th> |
| <th align="left">Notes</th> |
| </tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>predictionCol</td> |
| <td>Double</td> |
| <td>"prediction"</td> |
| <td>Predicted label</td> |
| <td></td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p>In the future, <code class="language-plaintext highlighter-rouge">GBTClassifier</code> will also output columns for <code class="language-plaintext highlighter-rouge">rawPrediction</code> and <code class="language-plaintext highlighter-rouge">probability</code>, just as <code class="language-plaintext highlighter-rouge">RandomForestClassifier</code> does.</p> |
| |
| |
| </div> |
| |
| <!-- /container --> |
| </div> |
| |
| <script src="js/vendor/jquery-3.5.1.min.js"></script> |
| <script src="js/vendor/bootstrap.bundle.min.js"></script> |
| <script src="js/vendor/anchor.min.js"></script> |
| <script src="js/main.js"></script> |
| <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js"></script> |
| <script type="text/javascript"> |
| // DocSearch is entirely free and automated. DocSearch is built in two parts: |
| // 1. a crawler which we run on our own infrastructure every 24 hours. It follows every link |
| // in your website and extracts content from every page it traverses. It then pushes this |
| // content to an Algolia index. |
| // 2. a JavaScript snippet to be inserted in your website that will bind this Algolia index |
| // to your search input and display its results in a dropdown UI. If you want to find more |
| // details on how DocSearch works, check the DocSearch docs. |
| // |
| // The call below is part 2: it binds the 'apache_spark' index to the element matching |
| // '#docsearch-input' and forwards a "version:3.3.0" facet filter to Algolia |
| // (presumably so that only results for this documentation version are returned). |
| docsearch({ |
| apiKey: 'd62f962a82bc9abb53471cb7b89da35e', |
| appId: 'RAI69RXRSK', |
| indexName: 'apache_spark', |
| inputSelector: '#docsearch-input', |
| enhancedSearchInput: true, |
| algoliaOptions: { |
| 'facetFilters': ["version:3.3.0"] |
| }, |
| debug: false // Set debug to true if you want to inspect the dropdown |
| }); |
| |
| </script> |
| |
| <!-- MathJax Section --> |
| <script type="text/x-mathjax-config"> |
| // MathJax configuration block: asks the TeX input processor to number equations |
| // automatically using the "AMS" scheme. |
| MathJax.Hub.Config({ |
| TeX: { equationNumbers: { autoNumber: "AMS" } } |
| }); |
| </script> |
| <script> |
| // Load MathJax by injecting a script element, so the page works whether it is served over |
| // HTTP or HTTPS or opened from a local file (file://). A protocol-relative URL |
| // ("//cdn.mathjax...") would not resolve under file://, so the CDN is always addressed |
| // explicitly over HTTPS; this also avoids fetching executable code over plaintext HTTP. |
| (function(d, script) { |
| script = d.createElement('script'); |
| script.type = 'text/javascript'; |
| script.async = true; |
| // Once MathJax.js has loaded, configure which delimiters the tex2jax preprocessor |
| // recognises as math. |
| script.onload = function(){ |
| MathJax.Hub.Config({ |
| tex2jax: { |
| // JavaScript unescapes each string literal once: "\\(" is the delimiter \( used by the |
| // inline formulas on this page, while "\\\\(" is the two-backslash form \\( that was the |
| // only one registered previously and is kept for backward compatibility. |
| inlineMath: [ ["$", "$"], ["\\(","\\)"], ["\\\\(","\\\\)"] ], |
| displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], |
| processEscapes: true, |
| // 'code' is deliberately absent here: formulas wrapped in inline code elements are still typeset. |
| skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] |
| } |
| }); |
| }; |
| script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js' + |
| '?config=TeX-AMS-MML_HTMLorMML'; |
| d.getElementsByTagName('head')[0].appendChild(script); |
| }(document)); |
| </script> |
| </body> |
| </html> |