| |
| <!DOCTYPE html> |
| <!-- Root element variants selected by legacy-IE conditional comments; lang declares the page language. --> |
| <!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7" lang="en"> <![endif]--> |
| <!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8" lang="en"> <![endif]--> |
| <!--[if IE 8]> <html class="no-js lt-ie9" lang="en"> <![endif]--> |
| <!--[if gt IE 8]><!--> <html class="no-js" lang="en"> <!--<![endif]--> |
| <head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1"> |
| <title>Evaluation Metrics - RDD-based API - Spark 2.4.5 Documentation</title> |
| |
| |
| |
| |
| <link rel="stylesheet" href="css/bootstrap.min.css"> |
| <style> |
| /* Vertical body padding; the top value presumably keeps content clear of the fixed-top navbar (verify against main.css). */ |
| body { |
| padding-top: 60px; |
| padding-bottom: 40px; |
| } |
| </style> |
| <meta name="viewport" content="width=device-width"> |
| <link rel="stylesheet" href="css/bootstrap-responsive.min.css"> |
| <link rel="stylesheet" href="css/main.css"> |
| |
| <script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script> |
| |
| <link rel="stylesheet" href="css/pygments-default.css"> |
| |
| |
| <!-- Google analytics script --> |
| <script type="text/javascript"> |
| // Google Analytics command queue: reuse an existing queue or start a new one. |
| var _gaq = _gaq || []; |
| _gaq.push(['_setAccount', 'UA-32518208-2']); |
| _gaq.push(['_trackPageview']); |
| |
| // Build an async script element for ga.js and insert it ahead of the first script on the page. |
| (function() { |
| var loader = document.createElement('script'); |
| loader.type = 'text/javascript'; |
| loader.async = true; |
| // Pick the host prefix that matches the protocol the page was served over. |
| var hostPrefix = 'https:' == document.location.protocol ? 'https://ssl' : 'http://www'; |
| loader.src = hostPrefix + '.google-analytics.com/ga.js'; |
| var firstScript = document.getElementsByTagName('script')[0]; |
| firstScript.parentNode.insertBefore(loader, firstScript); |
| })(); |
| </script> |
| |
| |
| </head> |
| <body> |
| <!--[if lt IE 7]> |
| <p class="chromeframe">You are using an outdated browser. <a href="https://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p> |
| <![endif]--> |
| |
| <!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html --> |
| |
| <div class="navbar navbar-fixed-top" id="topbar"> |
| <div class="navbar-inner"> |
| <div class="container"> |
| <!-- Brand: logo linking to the docs index, followed by the documented Spark version. --> |
| <div class="brand"><a href="index.html"> |
| <img src="img/spark-logo-hd.png" alt="Apache Spark" style="height:50px;"/></a><span class="version">2.4.5</span> |
| </div> |
| <ul class="nav"> |
| <!--TODO(andyk): Add class="active" attribute to li some how.--> |
| <li><a href="index.html">Overview</a></li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="quick-start.html">Quick Start</a></li> |
| <li><a href="rdd-programming-guide.html">RDDs, Accumulators, Broadcast Vars</a></li> |
| <li><a href="sql-programming-guide.html">SQL, DataFrames, and Datasets</a></li> |
| <li><a href="structured-streaming-programming-guide.html">Structured Streaming</a></li> |
| <li><a href="streaming-programming-guide.html">Spark Streaming (DStreams)</a></li> |
| <li><a href="ml-guide.html">MLlib (Machine Learning)</a></li> |
| <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li> |
| <li><a href="sparkr.html">SparkR (R on Spark)</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li> |
| <li><a href="api/java/index.html">Java</a></li> |
| <li><a href="api/python/index.html">Python</a></li> |
| <li><a href="api/R/index.html">R</a></li> |
| <li><a href="api/sql/index.html">SQL, Built-in Functions</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="cluster-overview.html">Overview</a></li> |
| <li><a href="submitting-applications.html">Submitting Applications</a></li> |
| <li class="divider"></li> |
| <li><a href="spark-standalone.html">Spark Standalone</a></li> |
| <li><a href="running-on-mesos.html">Mesos</a></li> |
| <li><a href="running-on-yarn.html">YARN</a></li> |
| <li><a href="running-on-kubernetes.html">Kubernetes</a></li> |
| </ul> |
| </li> |
| |
| <li class="dropdown"> |
| <a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a> |
| <ul class="dropdown-menu"> |
| <li><a href="configuration.html">Configuration</a></li> |
| <li><a href="monitoring.html">Monitoring</a></li> |
| <li><a href="tuning.html">Tuning Guide</a></li> |
| <li><a href="job-scheduling.html">Job Scheduling</a></li> |
| <li><a href="security.html">Security</a></li> |
| <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li> |
| <li class="divider"></li> |
| <li><a href="building-spark.html">Building Spark</a></li> |
| <li><a href="https://spark.apache.org/contributing.html">Contributing to Spark</a></li> |
| <li><a href="https://spark.apache.org/third-party-projects.html">Third Party Projects</a></li> |
| </ul> |
| </li> |
| </ul> |
| <!--<p class="navbar-text pull-right"><span class="version-text">v2.4.5</span></p>--> |
| </div> |
| </div> |
| </div> |
| |
| <div class="container-wrapper"> |
| |
| |
| |
| <div class="left-menu-wrapper"> |
| <div class="left-menu"> |
| <h3><a href="ml-guide.html">MLlib: Main Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="ml-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-datasource.html"> |
| |
| Data sources |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-pipeline.html"> |
| |
| Pipelines |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-features.html"> |
| |
| Extracting, transforming and selecting features |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-classification-regression.html"> |
| |
| Classification and Regression |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-frequent-pattern-mining.html"> |
| |
| Frequent Pattern Mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-tuning.html"> |
| |
| Model selection and tuning |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="ml-advanced.html"> |
| |
| Advanced topics |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| <h3><a href="mllib-guide.html">MLlib: RDD-based API Guide</a></h3> |
| |
| <ul> |
| |
| <li> |
| <a href="mllib-data-types.html"> |
| |
| Data types |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-statistics.html"> |
| |
| Basic statistics |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-classification-regression.html"> |
| |
| Classification and regression |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-collaborative-filtering.html"> |
| |
| Collaborative filtering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-clustering.html"> |
| |
| Clustering |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-dimensionality-reduction.html"> |
| |
| Dimensionality reduction |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-feature-extraction.html"> |
| |
| Feature extraction and transformation |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-frequent-pattern-mining.html"> |
| |
| Frequent pattern mining |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-evaluation-metrics.html"> |
| |
| <b>Evaluation metrics</b> |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-pmml-model-export.html"> |
| |
| PMML model export |
| |
| </a> |
| </li> |
| |
| |
| |
| <li> |
| <a href="mllib-optimization.html"> |
| |
| Optimization (developer) |
| |
| </a> |
| </li> |
| |
| |
| |
| </ul> |
| |
| </div> |
| </div> |
| |
| <input id="nav-trigger" class="nav-trigger" checked type="checkbox"> |
| <label for="nav-trigger"></label> |
| <div class="content-with-sidebar" id="content"> |
| |
| <h1 class="title">Evaluation Metrics - RDD-based API</h1> |
| |
| |
| <ul id="markdown-toc"> |
| <li><a href="#classification-model-evaluation" id="markdown-toc-classification-model-evaluation">Classification model evaluation</a> <ul> |
| <li><a href="#binary-classification" id="markdown-toc-binary-classification">Binary classification</a> <ul> |
| <li><a href="#threshold-tuning" id="markdown-toc-threshold-tuning">Threshold tuning</a></li> |
| </ul> |
| </li> |
| <li><a href="#multiclass-classification" id="markdown-toc-multiclass-classification">Multiclass classification</a> <ul> |
| <li><a href="#label-based-metrics" id="markdown-toc-label-based-metrics">Label based metrics</a></li> |
| </ul> |
| </li> |
| <li><a href="#multilabel-classification" id="markdown-toc-multilabel-classification">Multilabel classification</a></li> |
| <li><a href="#ranking-systems" id="markdown-toc-ranking-systems">Ranking systems</a></li> |
| </ul> |
| </li> |
| <li><a href="#regression-model-evaluation" id="markdown-toc-regression-model-evaluation">Regression model evaluation</a></li> |
| </ul> |
| |
| <p><code>spark.mllib</code> comes with a number of machine learning algorithms that can be used to learn from and make predictions |
| on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance |
| of the model on some criteria, which depends on the application and its requirements. <code>spark.mllib</code> also provides a |
| suite of metrics for the purpose of evaluating the performance of machine learning models.</p> |
| |
| <p>Specific machine learning algorithms fall under broader types of machine learning applications like classification, |
| regression, clustering, etc. Each of these types has well-established metrics for performance evaluation and those |
| metrics that are currently available in <code>spark.mllib</code> are detailed in this section.</p> |
| |
| <h2 id="classification-model-evaluation">Classification model evaluation</h2> |
| |
| <p>While there are many different types of classification algorithms, classification models are all evaluated using |
| similar principles. In a <a href="https://en.wikipedia.org/wiki/Statistical_classification">supervised classification problem</a>, |
| there exists a true output and a model-generated predicted output for each data point. For this reason, the results for |
| each data point can be assigned to one of four categories:</p> |
| |
| <ul> |
| <li>True Positive (TP) - label is positive and prediction is also positive</li> |
| <li>True Negative (TN) - label is negative and prediction is also negative</li> |
| <li>False Positive (FP) - label is negative but prediction is positive</li> |
| <li>False Negative (FN) - label is positive but prediction is negative</li> |
| </ul> |
| |
| <p>These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering |
| classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. This |
| is because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from |
| a dataset where 95% of the data points are <em>not fraud</em> and 5% of the data points are <em>fraud</em>, then a naive classifier |
| that predicts <em>not fraud</em>, regardless of input, will be 95% accurate. For this reason, metrics like |
| <a href="https://en.wikipedia.org/wiki/Precision_and_recall">precision and recall</a> are typically used because they take into |
| account the <em>type</em> of error. In most applications there is some desired balance between precision and recall, which can |
| be captured by combining the two into a single metric, called the <a href="https://en.wikipedia.org/wiki/F1_score">F-measure</a>.</p> |
| |
| <h3 id="binary-classification">Binary classification</h3> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Binary_classification">Binary classifiers</a> are used to separate the elements of a given |
| dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of multiclass classification. |
| Most binary classification metrics can be generalized to multiclass classification metrics.</p> |
| |
| <h4 id="threshold-tuning">Threshold tuning</h4> |
| |
| <p>It is important to understand that many classification models actually output a “score” (often a probability) for |
| each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for |
| each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where |
| the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a |
| credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction <em>threshold</em> |
| which determines what the predicted class will be based on the probabilities that the model outputs.</p> |
| |
| <p>Tuning the prediction threshold will change the precision and recall of the model and is an important part of model |
| optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold it is |
| common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision, |
| recall) points for different threshold values, while a |
| <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">receiver operating characteristic</a>, or ROC, curve |
| plots (recall, false positive rate) points.</p> |
| |
| <p><strong>Available metrics</strong></p> |
| |
| <table class="table"> |
| <thead> |
| <tr><th scope="col">Metric</th><th scope="col">Definition</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Precision (Positive Predictive Value)</td> |
| <td>$PPV=\frac{TP}{TP + FP}$</td> |
| </tr> |
| <tr> |
| <td>Recall (True Positive Rate)</td> |
| <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td> |
| </tr> |
| <tr> |
| <td>F-measure</td> |
| <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR} |
| {\beta^2 \cdot PPV + TPR}\right)$</td> |
| </tr> |
| <tr> |
| <td>Receiver Operating Characteristic (ROC)</td> |
| <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td> |
| </tr> |
| <tr> |
| <td>Area Under ROC Curve</td> |
| <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td> |
| </tr> |
| <tr> |
| <td>Area Under Precision-Recall Curve</td> |
| <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the |
| data, and evaluate the performance of the algorithm by several binary evaluation metrics. |
| |
| <div data-lang="scala"> |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS"><code>LogisticRegressionWithLBFGS</code> Scala docs</a> and <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics"><code>BinaryClassificationMetrics</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.BinaryClassificationMetrics</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span> |
| |
| <span class="c1">// Load training data in LIBSVM format</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">"data/mllib/sample_binary_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split data into training (60%) and test (40%)</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">11L</span><span class="o">)</span> |
| <span class="n">training</span><span class="o">.</span><span class="n">cache</span><span class="o">()</span> |
| |
| <span class="c1">// Run training algorithm to build the model</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegressionWithLBFGS</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setNumClasses</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">run</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Clear the prediction threshold so the model will return probabilities</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">clearThreshold</span> |
| |
| <span class="c1">// Compute raw scores on the test set</span> |
| <span class="k">val</span> <span class="n">predictionAndLabels</span> <span class="k">=</span> <span class="n">test</span><span class="o">.</span><span class="n">map</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">LabeledPoint</span><span class="o">(</span><span class="n">label</span><span class="o">,</span> <span class="n">features</span><span class="o">)</span> <span class="k">=></span> |
| <span class="k">val</span> <span class="n">prediction</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="o">(</span><span class="n">features</span><span class="o">)</span> |
| <span class="o">(</span><span class="n">prediction</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">BinaryClassificationMetrics</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">)</span> |
| |
| <span class="c1">// Precision by threshold</span> |
| <span class="k">val</span> <span class="n">precision</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precisionByThreshold</span> |
| <span class="n">precision</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">t</span><span class="o">,</span> <span class="n">p</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Threshold: </span><span class="si">$t</span><span class="s">, Precision: </span><span class="si">$p</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Recall by threshold</span> |
| <span class="k">val</span> <span class="n">recall</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recallByThreshold</span> |
| <span class="n">recall</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">t</span><span class="o">,</span> <span class="n">r</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Threshold: </span><span class="si">$t</span><span class="s">, Recall: </span><span class="si">$r</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Precision-Recall Curve</span> |
| <span class="k">val</span> <span class="nc">PRC</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">pr</span> |
| |
| <span class="c1">// F-measure</span> |
| <span class="k">val</span> <span class="n">f1Score</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">fMeasureByThreshold</span> |
| <span class="n">f1Score</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">t</span><span class="o">,</span> <span class="n">f</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Threshold: </span><span class="si">$t</span><span class="s">, F-score: </span><span class="si">$f</span><span class="s">, Beta = 1"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// F-measure with a custom beta (0.5 here)</span> |
| <span class="k">val</span> <span class="n">beta</span> <span class="k">=</span> <span class="mf">0.5</span> |
| <span class="k">val</span> <span class="n">fScore</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">fMeasureByThreshold</span><span class="o">(</span><span class="n">beta</span><span class="o">)</span> |
| <span class="n">fScore</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">t</span><span class="o">,</span> <span class="n">f</span><span class="o">)</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Threshold: </span><span class="si">$t</span><span class="s">, F-score: </span><span class="si">$f</span><span class="s">, Beta = 0.5"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// AUPRC</span> |
| <span class="k">val</span> <span class="n">auPRC</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">areaUnderPR</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Area under precision-recall curve = </span><span class="si">$auPRC</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Compute thresholds used in ROC and PR curves</span> |
| <span class="k">val</span> <span class="n">thresholds</span> <span class="k">=</span> <span class="n">precision</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">_1</span><span class="o">)</span> |
| |
| <span class="c1">// ROC Curve</span> |
| <span class="k">val</span> <span class="n">roc</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">roc</span> |
| |
| <span class="c1">// AUROC</span> |
| <span class="k">val</span> <span class="n">auROC</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">areaUnderROC</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Area under ROC = </span><span class="si">$auROC</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| <p>Refer to the <a href="api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html"><code>LogisticRegressionModel</code> Java docs</a> and <a href="api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html"><code>LogisticRegressionWithLBFGS</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">scala.Tuple2</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.BinaryClassificationMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span> |
| |
| <span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_binary_classification_data.txt"</span><span class="o">;</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="n">path</span><span class="o">).</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| |
| <span class="c1">// Split initial RDD into two... [60% training data, 40% testing data].</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> |
| <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">11L</span><span class="o">);</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">].</span><span class="na">cache</span><span class="o">();</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Run training algorithm to build the model.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegressionWithLBFGS</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setNumClasses</span><span class="o">(</span><span class="mi">2</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">run</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Clear the prediction threshold so the model will return probabilities</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">clearThreshold</span><span class="o">();</span> |
| |
| <span class="c1">// Compute raw scores on the test set.</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">></span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">test</span><span class="o">.</span><span class="na">mapToPair</span><span class="o">(</span><span class="n">p</span> <span class="o">-></span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">model</span><span class="o">.</span><span class="na">predict</span><span class="o">(</span><span class="n">p</span><span class="o">.</span><span class="na">features</span><span class="o">()),</span> <span class="n">p</span><span class="o">.</span><span class="na">label</span><span class="o">()));</span> |
| |
| <span class="c1">// Get evaluation metrics.</span> |
| <span class="n">BinaryClassificationMetrics</span> <span class="n">metrics</span> <span class="o">=</span> |
| <span class="k">new</span> <span class="n">BinaryClassificationMetrics</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Precision by threshold</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">>></span> <span class="n">precision</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">precisionByThreshold</span><span class="o">().</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision by threshold: "</span> <span class="o">+</span> <span class="n">precision</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="c1">// Recall by threshold</span> |
| <span class="n">JavaRDD</span><span class="o"><?></span> <span class="n">recall</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">recallByThreshold</span><span class="o">().</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Recall by threshold: "</span> <span class="o">+</span> <span class="n">recall</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="c1">// F Score by threshold</span> |
| <span class="n">JavaRDD</span><span class="o"><?></span> <span class="n">f1Score</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">().</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F1 Score by threshold: "</span> <span class="o">+</span> <span class="n">f1Score</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="n">JavaRDD</span><span class="o"><?></span> <span class="n">f2Score</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">(</span><span class="mf">2.0</span><span class="o">).</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"F2 Score by threshold: "</span> <span class="o">+</span> <span class="n">f2Score</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="c1">// Precision-recall curve</span> |
| <span class="n">JavaRDD</span><span class="o"><?></span> <span class="n">prc</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">pr</span><span class="o">().</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Precision-recall curve: "</span> <span class="o">+</span> <span class="n">prc</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="c1">// Thresholds</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Double</span><span class="o">></span> <span class="n">thresholds</span> <span class="o">=</span> <span class="n">precision</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">t</span> <span class="o">-></span> <span class="n">Double</span><span class="o">.</span><span class="na">parseDouble</span><span class="o">(</span><span class="n">t</span><span class="o">.</span><span class="na">_1</span><span class="o">().</span><span class="na">toString</span><span class="o">()));</span> |
| |
| <span class="c1">// ROC Curve</span> |
| <span class="n">JavaRDD</span><span class="o"><?></span> <span class="n">roc</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">roc</span><span class="o">().</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"ROC curve: "</span> <span class="o">+</span> <span class="n">roc</span><span class="o">.</span><span class="na">collect</span><span class="o">());</span> |
| |
| <span class="c1">// AUPRC</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Area under precision-recall curve = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="na">areaUnderPR</span><span class="o">());</span> |
| |
| <span class="c1">// AUROC</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Area under ROC = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="na">areaUnderROC</span><span class="o">());</span> |
| |
| <span class="c1">// Save and load model</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">save</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">"target/tmp/LogisticRegressionModel"</span><span class="o">);</span> |
| <span class="n">LogisticRegressionModel</span><span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">"target/tmp/LogisticRegressionModel"</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| <p>Refer to the <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics"><code>BinaryClassificationMetrics</code> Python docs</a> and <a href="api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS"><code>LogisticRegressionWithLBFGS</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.mllib.classification</span> <span class="kn">import</span> <span class="n">LogisticRegressionWithLBFGS</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.evaluation</span> <span class="kn">import</span> <span class="n">BinaryClassificationMetrics</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span> |
| |
| <span class="c1"># Several of the methods available in scala are currently missing from pyspark</span> |
| <span class="c1"># Load training data in LIBSVM format</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s2">"data/mllib/sample_binary_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split data into training (60%) and test (40%)</span> |
| <span class="n">training</span><span class="p">,</span> <span class="n">test</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="n">seed</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span> |
| <span class="n">training</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| |
| <span class="c1"># Run training algorithm to build the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">LogisticRegressionWithLBFGS</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training</span><span class="p">)</span> |
| |
| <span class="c1"># Compute raw scores on the test set</span> |
| <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">test</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">lp</span><span class="p">:</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">lp</span><span class="o">.</span><span class="n">features</span><span class="p">)),</span> <span class="n">lp</span><span class="o">.</span><span class="n">label</span><span class="p">))</span> |
| |
| <span class="c1"># Instantiate metrics object</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="n">BinaryClassificationMetrics</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)</span> |
| |
| <span class="c1"># Area under precision-recall curve</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Area under PR = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">areaUnderPR</span><span class="p">)</span> |
| |
| <span class="c1"># Area under ROC curve</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Area under ROC = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">areaUnderROC</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/mllib/binary_classification_metrics_example.py" in the Spark repo.</small></div> |
| </div> |
| </div> |
| |
| <h3 id="multiclass-classification">Multiclass classification</h3> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Multiclass_classification">Multiclass classification</a> describes a classification |
| problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary |
| classification problem). For example, classifying handwriting samples as the digits 0 to 9 is a multiclass problem with 10 possible classes.</p> |
| |
| <p>For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still |
| be positive or negative, but they must be considered under the context of a particular class. Each label and prediction |
| takes on the value of one of the multiple classes and so they are said to be positive for their particular class and negative |
| for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative |
| occurs when neither the prediction nor the label takes on the value of a given class. By this convention, there can be |
| multiple true negatives for a given data sample. The extension of false negatives and false positives from the former |
| definitions of positive and negative labels is straightforward.</p> |
| |
| <h4 id="label-based-metrics">Label based metrics</h4> |
| |
| <p>As opposed to binary classification, where there are only two possible labels, multiclass classification problems have many |
| possible labels and so the concept of label-based metrics is introduced. Accuracy measures precision across all |
| labels - the number of times any class was predicted correctly (true positives) normalized by the number of data |
| points. Precision by label considers only one class, and measures the number of times a specific label was predicted |
| correctly normalized by the number of times that label appears in the output.</p> |
| |
| <p><strong>Available metrics</strong></p> |
| |
| <p>Define the class, or label, set as</p> |
| |
| <script type="math/tex; mode=display">L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \}</script> |
| |
| <p>The true output vector $\mathbf{y}$ consists of $N$ elements</p> |
| |
| <script type="math/tex; mode=display">\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L</script> |
| |
| <p>A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements</p> |
| |
| <script type="math/tex; mode=display">\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L</script> |
| |
| <p>For this section, a modified delta function $\hat{\delta}(x)$ will prove useful</p> |
| |
| <script type="math/tex; mode=display">% <![CDATA[ |
| \hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases} %]]></script> |
| |
| <table class="table"> |
| <thead> |
| <tr><th>Metric</th><th>Definition</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Confusion Matrix</td> |
| <td> |
| $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\ |
| \left( \begin{array}{ccc} |
| \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & |
| \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \\ |
| \vdots & \ddots & \vdots \\ |
| \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & |
| \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) |
| \end{array} \right)$ |
| </td> |
| </tr> |
| <tr> |
| <td>Accuracy</td> |
| <td>$ACC = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - |
| \mathbf{y}_i\right)$</td> |
| </tr> |
| <tr> |
| <td>Precision by label</td> |
| <td>$PPV(\ell) = \frac{TP}{TP + FP} = |
| \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} |
| {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td> |
| </tr> |
| <tr> |
| <td>Recall by label</td> |
| <td>$TPR(\ell)=\frac{TP}{P} = |
| \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} |
| {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td> |
| </tr> |
| <tr> |
| <td>F-measure by label</td> |
| <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} |
| {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td> |
| </tr> |
| <tr> |
| <td>Weighted precision</td> |
| <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) |
| \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td> |
| </tr> |
| <tr> |
| <td>Weighted recall</td> |
| <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) |
| \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td> |
| </tr> |
| <tr> |
| <td>Weighted F-measure</td> |
| <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) |
| \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on |
| the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. |
| |
| <div data-lang="scala"> |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics"><code>MulticlassMetrics</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span> |
| |
| <span class="c1">// Load training data in LIBSVM format</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">)</span> |
| |
| <span class="c1">// Split data into training (60%) and test (40%)</span> |
| <span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">training</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">11L</span><span class="o">)</span> |
| <span class="n">training</span><span class="o">.</span><span class="n">cache</span><span class="o">()</span> |
| |
| <span class="c1">// Run training algorithm to build the model</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegressionWithLBFGS</span><span class="o">()</span> |
| <span class="o">.</span><span class="n">setNumClasses</span><span class="o">(</span><span class="mi">3</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">run</span><span class="o">(</span><span class="n">training</span><span class="o">)</span> |
| |
| <span class="c1">// Compute raw scores on the test set</span> |
| <span class="k">val</span> <span class="n">predictionAndLabels</span> <span class="k">=</span> <span class="n">test</span><span class="o">.</span><span class="n">map</span> <span class="o">{</span> <span class="k">case</span> <span class="nc">LabeledPoint</span><span class="o">(</span><span class="n">label</span><span class="o">,</span> <span class="n">features</span><span class="o">)</span> <span class="k">=></span> |
| <span class="k">val</span> <span class="n">prediction</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="o">(</span><span class="n">features</span><span class="o">)</span> |
| <span class="o">(</span><span class="n">prediction</span><span class="o">,</span> <span class="n">label</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassMetrics</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">)</span> |
| |
| <span class="c1">// Confusion matrix</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Confusion matrix:"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="n">metrics</span><span class="o">.</span><span class="n">confusionMatrix</span><span class="o">)</span> |
| |
| <span class="c1">// Overall Statistics</span> |
| <span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">accuracy</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">"Summary Statistics"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Accuracy = </span><span class="si">$accuracy</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Precision by label</span> |
| <span class="k">val</span> <span class="n">labels</span> <span class="k">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">labels</span> |
| <span class="n">labels</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">l</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Precision(</span><span class="si">$l</span><span class="s">) = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="o">(</span><span class="n">l</span><span class="o">))</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Recall by label</span> |
| <span class="n">labels</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">l</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Recall(</span><span class="si">$l</span><span class="s">) = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="o">(</span><span class="n">l</span><span class="o">))</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// False positive rate by label</span> |
| <span class="n">labels</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">l</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"FPR(</span><span class="si">$l</span><span class="s">) = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="n">falsePositiveRate</span><span class="o">(</span><span class="n">l</span><span class="o">))</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// F-measure by label</span> |
| <span class="n">labels</span><span class="o">.</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">l</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"F1-Score(</span><span class="si">$l</span><span class="s">) = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="n">fMeasure</span><span class="o">(</span><span class="n">l</span><span class="o">))</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Weighted stats</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Weighted precision: </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">weightedPrecision</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Weighted recall: </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">weightedRecall</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Weighted F1 score: </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">weightedFMeasure</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Weighted false positive rate: </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">weightedFalsePositiveRate</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| <p>Refer to the <a href="api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html"><code>MulticlassMetrics</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">scala.Tuple2</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Matrix</span><span class="o">;</span> |
| |
| <span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="o">;</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="n">path</span><span class="o">).</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| |
| <span class="c1">// Split initial RDD into two... [60% training data, 40% testing data].</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">>[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">11L</span><span class="o">);</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">training</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">].</span><span class="na">cache</span><span class="o">();</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">LabeledPoint</span><span class="o">></span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span> |
| |
| <span class="c1">// Run training algorithm to build the model.</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="k">new</span> <span class="n">LogisticRegressionWithLBFGS</span><span class="o">()</span> |
| <span class="o">.</span><span class="na">setNumClasses</span><span class="o">(</span><span class="mi">3</span><span class="o">)</span> |
| <span class="o">.</span><span class="na">run</span><span class="o">(</span><span class="n">training</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Compute raw scores on the test set.</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">></span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">test</span><span class="o">.</span><span class="na">mapToPair</span><span class="o">(</span><span class="n">p</span> <span class="o">-></span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">model</span><span class="o">.</span><span class="na">predict</span><span class="o">(</span><span class="n">p</span><span class="o">.</span><span class="na">features</span><span class="o">()),</span> <span class="n">p</span><span class="o">.</span><span class="na">label</span><span class="o">()));</span> |
| |
| <span class="c1">// Get evaluation metrics.</span> |
| <span class="n">MulticlassMetrics</span> <span class="n">metrics</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MulticlassMetrics</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Confusion matrix</span> |
| <span class="n">Matrix</span> <span class="n">confusion</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">confusionMatrix</span><span class="o">();</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Confusion matrix: \n"</span> <span class="o">+</span> <span class="n">confusion</span><span class="o">);</span> |
| |
| <span class="c1">// Overall statistics</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"Accuracy = "</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="na">accuracy</span><span class="o">());</span> |
| |
| <span class="c1">// Stats by labels</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">i</span> <span class="o"><</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">().</span><span class="na">length</span><span class="o">;</span> <span class="n">i</span><span class="o">++)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %f precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span><span class="n">metrics</span><span class="o">.</span><span class="na">precision</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %f recall = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span> <span class="n">metrics</span><span class="o">.</span><span class="na">recall</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %f F1 score = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span> <span class="n">metrics</span><span class="o">.</span><span class="na">fMeasure</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="o">}</span> |
| |
| <span class="c1">//Weighted stats</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Weighted precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">weightedPrecision</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Weighted recall = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">weightedRecall</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Weighted F1 score = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">weightedFMeasure</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Weighted false positive rate = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">weightedFalsePositiveRate</span><span class="o">());</span> |
| |
| <span class="c1">// Save and load model</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">save</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">"target/tmp/LogisticRegressionModel"</span><span class="o">);</span> |
| <span class="n">LogisticRegressionModel</span> <span class="n">sameModel</span> <span class="o">=</span> <span class="n">LogisticRegressionModel</span><span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> |
| <span class="s">"target/tmp/LogisticRegressionModel"</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| <p>Refer to the <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics"><code>MulticlassMetrics</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.mllib.classification</span> <span class="kn">import</span> <span class="n">LogisticRegressionWithLBFGS</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassMetrics</span> |
| |
| <span class="c1"># Load training data in LIBSVM format</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s2">"data/mllib/sample_multiclass_classification_data.txt"</span><span class="p">)</span> |
| |
| <span class="c1"># Split data into training (60%) and test (40%)</span> |
| <span class="n">training</span><span class="p">,</span> <span class="n">test</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="n">seed</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span> |
| <span class="n">training</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> |
| |
| <span class="c1"># Run training algorithm to build the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">LogisticRegressionWithLBFGS</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training</span><span class="p">,</span> <span class="n">numClasses</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> |
| |
| <span class="c1"># Compute raw scores on the test set</span> |
| <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">test</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">lp</span><span class="p">:</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">lp</span><span class="o">.</span><span class="n">features</span><span class="p">)),</span> <span class="n">lp</span><span class="o">.</span><span class="n">label</span><span class="p">))</span> |
| |
| <span class="c1"># Instantiate metrics object</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="n">MulticlassMetrics</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)</span> |
| |
| <span class="c1"># Overall statistics</span> |
| <span class="n">precision</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="p">()</span> |
| <span class="n">recall</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="p">()</span> |
| <span class="n">f1Score</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="n">fMeasure</span><span class="p">()</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Summary Stats"</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">precision</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">recall</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"F1 Score = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">f1Score</span><span class="p">)</span> |
| |
| <span class="c1"># Statistics by class</span> |
| <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">lp</span><span class="p">:</span> <span class="n">lp</span><span class="o">.</span><span class="n">label</span><span class="p">)</span><span class="o">.</span><span class="n">distinct</span><span class="p">()</span><span class="o">.</span><span class="n">collect</span><span class="p">()</span> |
| <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">labels</span><span class="p">):</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="p">(</span><span class="n">label</span><span class="p">)))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="p">(</span><span class="n">label</span><span class="p">)))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> F1 Measure = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">fMeasure</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)))</span> |
| |
| <span class="c1"># Weighted stats</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Weighted recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">weightedRecall</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Weighted precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">weightedPrecision</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Weighted F(1) Score = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">weightedFMeasure</span><span class="p">())</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Weighted F(0.5) Score = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">weightedFMeasure</span><span class="p">(</span><span class="n">beta</span><span class="o">=</span><span class="mf">0.5</span><span class="p">))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Weighted false positive rate = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">weightedFalsePositiveRate</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/mllib/multi_class_metrics_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| </div> |
| |
| <h3 id="multilabel-classification">Multilabel classification</h3> |
| |
| <p>A <a href="https://en.wikipedia.org/wiki/Multi-label_classification">multilabel classification</a> problem involves mapping |
| each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not |
| mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both |
| science and politics.</p> |
| |
| <p>Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label <em>sets</em>, rather |
| than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to |
| operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted |
| set and it exists in the true label set, for a specific data point.</p> |
| |
| <p><strong>Available metrics</strong></p> |
| |
| <p>Here we define a set $D$ of $N$ documents</p> |
| |
| <script type="math/tex; mode=display">D = \left\{d_0, d_1, ..., d_{N-1}\right\}</script> |
| |
| <p>Define $L_0, L_1, \ldots, L_{N-1}$ to be a family of label sets and $P_0, P_1, \ldots, P_{N-1}$ |
| to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that |
| correspond to document $d_i$.</p> |
| |
| <p>The set of all unique labels is given by</p> |
| |
| <script type="math/tex; mode=display">L = \bigcup_{k=0}^{N-1} L_k</script> |
| |
| <p>The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary</p> |
| |
| <script type="math/tex; mode=display">% <![CDATA[ |
| I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases} %]]></script> |
| |
| <table class="table"> |
| <thead> |
| <tr><th>Metric</th><th>Definition</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Precision</td><td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td> |
| </tr> |
| <tr> |
| <td>Recall</td><td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td> |
| </tr> |
| <tr> |
| <td>Accuracy</td> |
| <td> |
| $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|} |
| {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$ |
| </td> |
| </tr> |
| <tr> |
| <td>Precision by label</td><td>$PPV(\ell)=\frac{TP}{TP + FP}= |
| \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} |
| {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td> |
| </tr> |
| <tr> |
| <td>Recall by label</td><td>$TPR(\ell)=\frac{TP}{P}= |
| \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} |
| {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td> |
| </tr> |
| <tr> |
| <td>F1-measure by label</td><td>$F1(\ell) = 2 |
| \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} |
| {PPV(\ell) + TPR(\ell)}\right)$</td> |
| </tr> |
| <tr> |
| <td>Hamming Loss</td> |
| <td> |
| $\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left(\left|L_i\right| + \left|P_i\right| - 2\left|L_i |
| \cap P_i\right|\right)$ |
| </td> |
| </tr> |
| <tr> |
| <td>Subset Accuracy</td> |
| <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td> |
| </tr> |
| <tr> |
| <td>F1 Measure</td> |
| <td>$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td> |
| </tr> |
| <tr> |
| <td>Micro precision</td> |
| <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} |
| {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td> |
| </tr> |
| <tr> |
| <td>Micro recall</td> |
| <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} |
| {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td> |
| </tr> |
| <tr> |
| <td>Micro F1 Measure</td> |
| <td> |
| $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot |
| \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1} |
| \left|P_i - L_i\right|}$ |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples |
| use the fake prediction and label data for multilabel classification that is shown below.</p> |
| |
| <p>Document predictions:</p> |
| |
| <ul> |
| <li>doc 0 - predict 0, 1 - class 0, 2</li> |
| <li>doc 1 - predict 0, 2 - class 0, 1</li> |
| <li>doc 2 - predict none - class 0</li> |
| <li>doc 3 - predict 2 - class 2</li> |
| <li>doc 4 - predict 2, 0 - class 2, 0</li> |
| <li>doc 5 - predict 0, 1, 2 - class 0, 1</li> |
| <li>doc 6 - predict 1 - class 1, 2</li> |
| </ul> |
| |
| <p>Predicted classes:</p> |
| |
| <ul> |
| <li>class 0 - doc 0, 1, 4, 5 (total 4)</li> |
| <li>class 1 - doc 0, 5, 6 (total 3)</li> |
| <li>class 2 - doc 1, 3, 4, 5 (total 4)</li> |
| </ul> |
| |
| <p>True classes:</p> |
| |
| <ul> |
| <li>class 0 - doc 0, 1, 2, 4, 5 (total 5)</li> |
| <li>class 1 - doc 1, 5, 6 (total 3)</li> |
| <li>class 2 - doc 0, 3, 4, 6 (total 4)</li> |
| </ul> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics"><code>MultilabelMetrics</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MultilabelMetrics</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.rdd.RDD</span> |
| |
| <span class="k">val</span> <span class="n">scoreAndLabels</span><span class="k">:</span> <span class="kt">RDD</span><span class="o">[(</span><span class="kt">Array</span><span class="o">[</span><span class="kt">Double</span><span class="o">]</span>, <span class="kt">Array</span><span class="o">[</span><span class="kt">Double</span><span class="o">])]</span> <span class="k">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="o">(</span> |
| <span class="nc">Seq</span><span class="o">((</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">.</span><span class="n">empty</span><span class="o">[</span><span class="kt">Double</span><span class="o">],</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">2.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">2.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">)),</span> |
| <span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">1.0</span><span class="o">),</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">))),</span> <span class="mi">2</span><span class="o">)</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MultilabelMetrics</span><span class="o">(</span><span class="n">scoreAndLabels</span><span class="o">)</span> |
| |
| <span class="c1">// Summary stats</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Recall = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Precision = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"F1 measure = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">f1Measure</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Accuracy = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">accuracy</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Individual label stats</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">label</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Class </span><span class="si">$label</span><span class="s"> precision = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="o">(</span><span class="n">label</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">))</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">label</span> <span class="k">=></span> <span class="n">println</span><span class="o">(</span><span class="s">s"Class </span><span class="si">$label</span><span class="s"> recall = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="o">(</span><span class="n">label</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">))</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">label</span> <span class="k">=></span> <span class="n">println</span><span class="o">(</span><span class="s">s"Class </span><span class="si">$label</span><span class="s"> F1-score = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">f1Measure</span><span class="o">(</span><span class="n">label</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">))</span> |
| |
| <span class="c1">// Micro stats</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Micro recall = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">microRecall</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Micro precision = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">microPrecision</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Micro F1 measure = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">microF1Measure</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Hamming loss</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Hamming loss = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">hammingLoss</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Subset accuracy</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Subset accuracy = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">subsetAccuracy</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| <p>Refer to the <a href="api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html"><code>MultilabelMetrics</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">scala.Tuple2</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MultilabelMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> |
| |
| <span class="n">List</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="kt">double</span><span class="o">[],</span> <span class="kt">double</span><span class="o">[]>></span> <span class="n">data</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">2.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">2.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">2.0</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.0</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">}),</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">1.0</span><span class="o">},</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">1.0</span><span class="o">,</span> <span class="mf">2.0</span><span class="o">})</span> |
| <span class="o">);</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="kt">double</span><span class="o">[],</span> <span class="kt">double</span><span class="o">[]>></span> <span class="n">scoreAndLabels</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="na">parallelize</span><span class="o">(</span><span class="n">data</span><span class="o">);</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="n">MultilabelMetrics</span> <span class="n">metrics</span> <span class="o">=</span> <span class="k">new</span> <span class="n">MultilabelMetrics</span><span class="o">(</span><span class="n">scoreAndLabels</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Summary stats</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Recall = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">recall</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">precision</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"F1 measure = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">f1Measure</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Accuracy = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">accuracy</span><span class="o">());</span> |
| |
| <span class="c1">// Stats by labels</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">i</span> <span class="o"><</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">().</span><span class="na">length</span><span class="o">;</span> <span class="n">i</span><span class="o">++)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %1.1f precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span> <span class="n">metrics</span><span class="o">.</span><span class="na">precision</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %1.1f recall = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span> <span class="n">metrics</span><span class="o">.</span><span class="na">recall</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Class %1.1f F1 score = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">],</span> <span class="n">metrics</span><span class="o">.</span><span class="na">f1Measure</span><span class="o">(</span> |
| <span class="n">metrics</span><span class="o">.</span><span class="na">labels</span><span class="o">()[</span><span class="n">i</span><span class="o">]));</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Micro stats</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Micro recall = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">microRecall</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Micro precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">microPrecision</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Micro F1 measure = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">microF1Measure</span><span class="o">());</span> |
| |
| <span class="c1">// Hamming loss</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Hamming loss = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">hammingLoss</span><span class="o">());</span> |
| |
| <span class="c1">// Subset accuracy</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Subset accuracy = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">subsetAccuracy</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| <p>Refer to the <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics"><code>MultilabelMetrics</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.mllib.evaluation</span> <span class="kn">import</span> <span class="n">MultilabelMetrics</span> |
| |
| <span class="n">scoreAndLabels</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">parallelize</span><span class="p">([</span> |
| <span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">]),</span> |
| <span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]),</span> |
| <span class="p">([],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">]),</span> |
| <span class="p">([</span><span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">2.0</span><span class="p">]),</span> |
| <span class="p">([</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">]),</span> |
| <span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]),</span> |
| <span class="p">([</span><span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">])])</span> |
| |
| <span class="c1"># Instantiate metrics object</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="n">MultilabelMetrics</span><span class="p">(</span><span class="n">scoreAndLabels</span><span class="p">)</span> |
| |
| <span class="c1"># Summary stats</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="p">())</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="p">())</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"F1 measure = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">f1Measure</span><span class="p">())</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Accuracy = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">accuracy</span><span class="p">)</span> |
| |
| <span class="c1"># Individual label stats</span> |
| <span class="n">labels</span> <span class="o">=</span> <span class="n">scoreAndLabels</span><span class="o">.</span><span class="n">flatMap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">distinct</span><span class="p">()</span><span class="o">.</span><span class="n">collect</span><span class="p">()</span> |
| <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">labels</span><span class="p">:</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">precision</span><span class="p">(</span><span class="n">label</span><span class="p">)))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">recall</span><span class="p">(</span><span class="n">label</span><span class="p">)))</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Class </span><span class="si">%s</span><span class="s2"> F1 Measure = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">metrics</span><span class="o">.</span><span class="n">f1Measure</span><span class="p">(</span><span class="n">label</span><span class="p">)))</span> |
| |
| <span class="c1"># Micro stats</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Micro precision = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">microPrecision</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Micro recall = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">microRecall</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Micro F1 measure = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">microF1Measure</span><span class="p">)</span> |
| |
| <span class="c1"># Hamming loss</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Hamming loss = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">hammingLoss</span><span class="p">)</span> |
| |
| <span class="c1"># Subset accuracy</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Subset accuracy = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">subsetAccuracy</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/mllib/multi_label_metrics_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| </div> |
| |
| <h3 id="ranking-systems">Ranking systems</h3> |
| |
| <p>The role of a ranking algorithm (often thought of as a <a href="https://en.wikipedia.org/wiki/Recommender_system">recommender system</a>) |
| is to return to the user a set of relevant items or documents based on some training data. The definition of relevance |
| may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these |
| rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth |
| set of relevant documents, while other metrics may incorporate numerical ratings explicitly.</p> |
| |
| <p><strong>Available metrics</strong></p> |
| |
| <p>A ranking system usually deals with a set of $M$ users</p> |
| |
| <script type="math/tex; mode=display">U = \left\{u_0, u_1, ..., u_{M-1}\right\}</script> |
| |
| <p>Each user ($u_i$) has a set of $N_i$ ground truth relevant documents</p> |
| |
| <script type="math/tex; mode=display">D_i = \left\{d_0, d_1, ..., d_{N_i-1}\right\}</script> |
| |
| <p>Each user also has a list of $Q_i$ recommended documents, in order of decreasing relevance</p> |
| |
| <script type="math/tex; mode=display">R_i = \left[r_0, r_1, ..., r_{Q_i-1}\right]</script> |
| |
| <p>The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the |
| sets and the effectiveness of the algorithms can be measured using the metrics listed below.</p> |
| |
| <p>It is necessary to define a function which, provided a recommended document and a set of ground truth relevant |
| documents, returns a relevance score for the recommended document.</p> |
| |
| <script type="math/tex; mode=display">% <![CDATA[ |
| rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases} %]]></script> |
| |
| <table class="table"> |
| <thead> |
| <!-- Column headers for the ranking metrics table; scope="col" associates each header with its column. --> |
| <tr><th scope="col">Metric</th><th scope="col">Definition</th><th scope="col">Notes</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td> |
| Precision at k |
| </td> |
| <td> |
| $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(Q_i, k) - 1} rel_{D_i}(R_i(j))}$ |
| </td> |
| <td> |
| <a href="https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision_at_K">Precision at k</a> is a measure of |
| how many of the first k recommended documents are in the set of true relevant documents averaged across all |
| users. In this metric, the order of the recommendations is not taken into account. |
| </td> |
| </tr> |
| <tr> |
| <td>Mean Average Precision</td> |
| <td> |
| $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{N_i} \sum_{j=0}^{Q_i-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$ |
| </td> |
| <td> |
| <a href="https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision">MAP</a> is a measure of how |
| many of the recommended documents are in the set of true relevant documents, where the |
| order of the recommendations is taken into account (i.e. relevant documents that appear earlier in the recommended list contribute more to the score). |
| </td> |
| </tr> |
| <tr> |
| <td>Normalized Discounted Cumulative Gain</td> |
| <td> |
| $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} |
| \frac{rel_{D_i}(R_i(j))}{\text{log}(j+2)}} \\ |
| \text{Where} \\ |
| \hspace{5 mm} n = \text{min}\left(\text{max}\left(Q_i, N_i\right),k\right) \\ |
| \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{log}(j+2)}$ |
| </td> |
| <td> |
| <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG">NDCG at k</a> is a |
| measure of how many of the first k recommended documents are in the set of true relevant documents averaged |
| across all users. In contrast to precision at k, this metric takes into account the order of the recommendations |
| (documents are assumed to be in order of decreasing relevance). |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <p>The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation |
| model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the |
| methodology is provided below.</p> |
| |
| <p>MovieLens ratings are on a scale of 1-5:</p> |
| |
| <ul> |
| <li>5: Must see</li> |
| <li>4: Will enjoy</li> |
| <li>3: It’s okay</li> |
| <li>2: Fairly bad</li> |
| <li>1: Awful</li> |
| </ul> |
| |
| <p>So we should not recommend a movie if the predicted rating is less than 3. |
| To map ratings to confidence scores, we use:</p> |
| |
| <ul> |
| <li>5 -> 2.5</li> |
| <li>4 -> 1.5</li> |
| <li>3 -> 0.5</li> |
| <li>2 -> -0.5</li> |
| <li>1 -> -1.5</li> |
| </ul> |
| |
| <p>This mapping means unobserved entries are generally between It’s okay and Fairly bad. The semantics of 0 in this |
| expanded world of non-positive weights are “the same as never having interacted at all.”</p> |
| |
| <div class="codetabs"> |
| |
| <div data-lang="scala"> |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics"><code>RegressionMetrics</code> Scala docs</a> and <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics"><code>RankingMetrics</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.</span><span class="o">{</span><span class="nc">RankingMetrics</span><span class="o">,</span> <span class="nc">RegressionMetrics</span><span class="o">}</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.recommendation.</span><span class="o">{</span><span class="nc">ALS</span><span class="o">,</span> <span class="nc">Rating</span><span class="o">}</span> |
| |
| <span class="c1">// Read in the ratings data</span> |
| <span class="k">val</span> <span class="n">ratings</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">textFile</span><span class="o">(</span><span class="s">"data/mllib/sample_movielens_data.txt"</span><span class="o">).</span><span class="n">rdd</span><span class="o">.</span><span class="n">map</span> <span class="o">{</span> <span class="n">line</span> <span class="k">=></span> |
| <span class="k">val</span> <span class="n">fields</span> <span class="k">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="o">(</span><span class="s">"::"</span><span class="o">)</span> |
| <span class="nc">Rating</span><span class="o">(</span><span class="n">fields</span><span class="o">(</span><span class="mi">0</span><span class="o">).</span><span class="n">toInt</span><span class="o">,</span> <span class="n">fields</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">toInt</span><span class="o">,</span> <span class="n">fields</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">toDouble</span> <span class="o">-</span> <span class="mf">2.5</span><span class="o">)</span> |
| <span class="o">}.</span><span class="n">cache</span><span class="o">()</span> |
| |
| <span class="c1">// Map ratings to 1 or 0, 1 indicating a movie that should be recommended</span> |
| <span class="k">val</span> <span class="n">binarizedRatings</span> <span class="k">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">r</span> <span class="k">=></span> <span class="nc">Rating</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="o">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="o">,</span> |
| <span class="k">if</span> <span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="n">rating</span> <span class="o">></span> <span class="mi">0</span><span class="o">)</span> <span class="mf">1.0</span> <span class="k">else</span> <span class="mf">0.0</span><span class="o">)).</span><span class="n">cache</span><span class="o">()</span> |
| |
| <span class="c1">// Summarize ratings</span> |
| <span class="k">val</span> <span class="n">numRatings</span> <span class="k">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">count</span><span class="o">()</span> |
| <span class="k">val</span> <span class="n">numUsers</span> <span class="k">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">user</span><span class="o">).</span><span class="n">distinct</span><span class="o">().</span><span class="n">count</span><span class="o">()</span> |
| <span class="k">val</span> <span class="n">numMovies</span> <span class="k">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">product</span><span class="o">).</span><span class="n">distinct</span><span class="o">().</span><span class="n">count</span><span class="o">()</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Got </span><span class="si">$numRatings</span><span class="s"> ratings from </span><span class="si">$numUsers</span><span class="s"> users on </span><span class="si">$numMovies</span><span class="s"> movies."</span><span class="o">)</span> |
| |
| <span class="c1">// Build the model</span> |
| <span class="k">val</span> <span class="n">numIterations</span> <span class="k">=</span> <span class="mi">10</span> |
| <span class="k">val</span> <span class="n">rank</span> <span class="k">=</span> <span class="mi">10</span> |
| <span class="k">val</span> <span class="n">lambda</span> <span class="k">=</span> <span class="mf">0.01</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="nc">ALS</span><span class="o">.</span><span class="n">train</span><span class="o">(</span><span class="n">ratings</span><span class="o">,</span> <span class="n">rank</span><span class="o">,</span> <span class="n">numIterations</span><span class="o">,</span> <span class="n">lambda</span><span class="o">)</span> |
| |
| <span class="c1">// Define a function to clamp ratings to the range [0, 1]</span> |
| <span class="k">def</span> <span class="n">scaledRating</span><span class="o">(</span><span class="n">r</span><span class="k">:</span> <span class="kt">Rating</span><span class="o">)</span><span class="k">:</span> <span class="kt">Rating</span> <span class="o">=</span> <span class="o">{</span> |
| <span class="k">val</span> <span class="n">scaledRating</span> <span class="k">=</span> <span class="n">math</span><span class="o">.</span><span class="n">max</span><span class="o">(</span><span class="n">math</span><span class="o">.</span><span class="n">min</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="n">rating</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">),</span> <span class="mf">0.0</span><span class="o">)</span> |
| <span class="nc">Rating</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="o">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="o">,</span> <span class="n">scaledRating</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Get sorted top ten predictions for each user and then clamp the ratings to [0, 1]</span> |
| <span class="k">val</span> <span class="n">userRecommended</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">recommendProductsForUsers</span><span class="o">(</span><span class="mi">10</span><span class="o">).</span><span class="n">map</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">user</span><span class="o">,</span> <span class="n">recs</span><span class="o">)</span> <span class="k">=></span> |
| <span class="o">(</span><span class="n">user</span><span class="o">,</span> <span class="n">recs</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">scaledRating</span><span class="o">))</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document</span> |
| <span class="c1">// Compare with top ten most relevant documents</span> |
| <span class="k">val</span> <span class="n">userMovies</span> <span class="k">=</span> <span class="n">binarizedRatings</span><span class="o">.</span><span class="n">groupBy</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">user</span><span class="o">)</span> |
| <span class="k">val</span> <span class="n">relevantDocuments</span> <span class="k">=</span> <span class="n">userMovies</span><span class="o">.</span><span class="n">join</span><span class="o">(</span><span class="n">userRecommended</span><span class="o">).</span><span class="n">map</span> <span class="o">{</span> <span class="k">case</span> <span class="o">(</span><span class="n">user</span><span class="o">,</span> <span class="o">(</span><span class="n">actual</span><span class="o">,</span> |
| <span class="n">predictions</span><span class="o">))</span> <span class="k">=></span> |
| <span class="o">(</span><span class="n">predictions</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">product</span><span class="o">),</span> <span class="n">actual</span><span class="o">.</span><span class="n">filter</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">rating</span> <span class="o">></span> <span class="mf">0.0</span><span class="o">).</span><span class="n">map</span><span class="o">(</span><span class="k">_</span><span class="o">.</span><span class="n">product</span><span class="o">).</span><span class="n">toArray</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RankingMetrics</span><span class="o">(</span><span class="n">relevantDocuments</span><span class="o">)</span> |
| |
| <span class="c1">// Precision at K</span> |
| <span class="nc">Array</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">5</span><span class="o">).</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">k</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Precision at </span><span class="si">$k</span><span class="s"> = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">precisionAt</span><span class="o">(</span><span class="n">k</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Mean average precision</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Mean average precision = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">meanAveragePrecision</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Normalized discounted cumulative gain</span> |
| <span class="nc">Array</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">5</span><span class="o">).</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">k</span> <span class="k">=></span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"NDCG at </span><span class="si">$k</span><span class="s"> = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">ndcgAt</span><span class="o">(</span><span class="n">k</span><span class="o">)</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Get predictions for each data point</span> |
| <span class="k">val</span> <span class="n">allPredictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="o">(</span><span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">r</span> <span class="k">=></span> <span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="o">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="o">))).</span><span class="n">map</span><span class="o">(</span><span class="n">r</span> <span class="k">=></span> <span class="o">((</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="o">,</span> |
| <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="o">),</span> <span class="n">r</span><span class="o">.</span><span class="n">rating</span><span class="o">))</span> |
| <span class="k">val</span> <span class="n">allRatings</span> <span class="k">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">r</span> <span class="k">=></span> <span class="o">((</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="o">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="o">),</span> <span class="n">r</span><span class="o">.</span><span class="n">rating</span><span class="o">))</span> |
| <span class="k">val</span> <span class="n">predictionsAndLabels</span> <span class="k">=</span> <span class="n">allPredictions</span><span class="o">.</span><span class="n">join</span><span class="o">(</span><span class="n">allRatings</span><span class="o">).</span><span class="n">map</span> <span class="o">{</span> <span class="k">case</span> <span class="o">((</span><span class="n">user</span><span class="o">,</span> <span class="n">product</span><span class="o">),</span> |
| <span class="o">(</span><span class="n">predicted</span><span class="o">,</span> <span class="n">actual</span><span class="o">))</span> <span class="k">=></span> |
| <span class="o">(</span><span class="n">predicted</span><span class="o">,</span> <span class="n">actual</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Get the RMSE using regression metrics</span> |
| <span class="k">val</span> <span class="n">regressionMetrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionMetrics</span><span class="o">(</span><span class="n">predictionsAndLabels</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"RMSE = </span><span class="si">${</span><span class="n">regressionMetrics</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// R-squared</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"R-squared = </span><span class="si">${</span><span class="n">regressionMetrics</span><span class="o">.</span><span class="n">r2</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| <p>Refer to the <a href="api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html"><code>RegressionMetrics</code> Java docs</a> and <a href="api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html"><code>RankingMetrics</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">java.util.*</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">scala.Tuple2</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.RegressionMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.RankingMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.recommendation.ALS</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.recommendation.MatrixFactorizationModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.recommendation.Rating</span><span class="o">;</span> |
| |
| <span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_movielens_data.txt"</span><span class="o">;</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">String</span><span class="o">></span> <span class="n">data</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="na">textFile</span><span class="o">(</span><span class="n">path</span><span class="o">);</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Rating</span><span class="o">></span> <span class="n">ratings</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">line</span> <span class="o">-></span> <span class="o">{</span> |
| <span class="n">String</span><span class="o">[]</span> <span class="n">parts</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="na">split</span><span class="o">(</span><span class="s">"::"</span><span class="o">);</span> |
| <span class="k">return</span> <span class="k">new</span> <span class="n">Rating</span><span class="o">(</span><span class="n">Integer</span><span class="o">.</span><span class="na">parseInt</span><span class="o">(</span><span class="n">parts</span><span class="o">[</span><span class="mi">0</span><span class="o">]),</span> <span class="n">Integer</span><span class="o">.</span><span class="na">parseInt</span><span class="o">(</span><span class="n">parts</span><span class="o">[</span><span class="mi">1</span><span class="o">]),</span> <span class="n">Double</span> |
| <span class="o">.</span><span class="na">parseDouble</span><span class="o">(</span><span class="n">parts</span><span class="o">[</span><span class="mi">2</span><span class="o">])</span> <span class="o">-</span> <span class="mf">2.5</span><span class="o">);</span> |
| <span class="o">});</span> |
| <span class="n">ratings</span><span class="o">.</span><span class="na">cache</span><span class="o">();</span> |
| |
| <span class="c1">// Train an ALS model</span> |
| <span class="n">MatrixFactorizationModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">ALS</span><span class="o">.</span><span class="na">train</span><span class="o">(</span><span class="n">JavaRDD</span><span class="o">.</span><span class="na">toRDD</span><span class="o">(</span><span class="n">ratings</span><span class="o">),</span> <span class="mi">10</span><span class="o">,</span> <span class="mi">10</span><span class="o">,</span> <span class="mf">0.01</span><span class="o">);</span> |
| |
| <span class="c1">// Get top 10 recommendations for every user and clamp ratings to the range [0, 1]</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Rating</span><span class="o">[]>></span> <span class="n">userRecs</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">recommendProductsForUsers</span><span class="o">(</span><span class="mi">10</span><span class="o">).</span><span class="na">toJavaRDD</span><span class="o">();</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Rating</span><span class="o">[]>></span> <span class="n">userRecsScaled</span> <span class="o">=</span> <span class="n">userRecs</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">t</span> <span class="o">-></span> <span class="o">{</span> |
| <span class="n">Rating</span><span class="o">[]</span> <span class="n">scaledRatings</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Rating</span><span class="o">[</span><span class="n">t</span><span class="o">.</span><span class="na">_2</span><span class="o">().</span><span class="na">length</span><span class="o">];</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">i</span> <span class="o"><</span> <span class="n">scaledRatings</span><span class="o">.</span><span class="na">length</span><span class="o">;</span> <span class="n">i</span><span class="o">++)</span> <span class="o">{</span> |
| <span class="kt">double</span> <span class="n">newRating</span> <span class="o">=</span> <span class="n">Math</span><span class="o">.</span><span class="na">max</span><span class="o">(</span><span class="n">Math</span><span class="o">.</span><span class="na">min</span><span class="o">(</span><span class="n">t</span><span class="o">.</span><span class="na">_2</span><span class="o">()[</span><span class="n">i</span><span class="o">].</span><span class="na">rating</span><span class="o">(),</span> <span class="mf">1.0</span><span class="o">),</span> <span class="mf">0.0</span><span class="o">);</span> |
| <span class="n">scaledRatings</span><span class="o">[</span><span class="n">i</span><span class="o">]</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Rating</span><span class="o">(</span><span class="n">t</span><span class="o">.</span><span class="na">_2</span><span class="o">()[</span><span class="n">i</span><span class="o">].</span><span class="na">user</span><span class="o">(),</span> <span class="n">t</span><span class="o">.</span><span class="na">_2</span><span class="o">()[</span><span class="n">i</span><span class="o">].</span><span class="na">product</span><span class="o">(),</span> <span class="n">newRating</span><span class="o">);</span> |
| <span class="o">}</span> |
| <span class="k">return</span> <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">t</span><span class="o">.</span><span class="na">_1</span><span class="o">(),</span> <span class="n">scaledRatings</span><span class="o">);</span> |
| <span class="o">});</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Rating</span><span class="o">[]></span> <span class="n">userRecommended</span> <span class="o">=</span> <span class="n">JavaPairRDD</span><span class="o">.</span><span class="na">fromJavaRDD</span><span class="o">(</span><span class="n">userRecsScaled</span><span class="o">);</span> |
| |
| <span class="c1">// Map ratings to 1 or 0, 1 indicating a movie that should be recommended</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Rating</span><span class="o">></span> <span class="n">binarizedRatings</span> <span class="o">=</span> <span class="n">ratings</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">r</span> <span class="o">-></span> <span class="o">{</span> |
| <span class="kt">double</span> <span class="n">binaryRating</span><span class="o">;</span> |
| <span class="k">if</span> <span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="na">rating</span><span class="o">()</span> <span class="o">></span> <span class="mf">0.0</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">binaryRating</span> <span class="o">=</span> <span class="mf">1.0</span><span class="o">;</span> |
| <span class="o">}</span> <span class="k">else</span> <span class="o">{</span> |
| <span class="n">binaryRating</span> <span class="o">=</span> <span class="mf">0.0</span><span class="o">;</span> |
| <span class="o">}</span> |
| <span class="k">return</span> <span class="k">new</span> <span class="n">Rating</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="na">user</span><span class="o">(),</span> <span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">(),</span> <span class="n">binaryRating</span><span class="o">);</span> |
| <span class="o">});</span> |
| |
| <span class="c1">// Group ratings by common user</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Iterable</span><span class="o"><</span><span class="n">Rating</span><span class="o">>></span> <span class="n">userMovies</span> <span class="o">=</span> <span class="n">binarizedRatings</span><span class="o">.</span><span class="na">groupBy</span><span class="o">(</span><span class="n">Rating</span><span class="o">::</span><span class="n">user</span><span class="o">);</span> |
| |
| <span class="c1">// Get true relevant documents from all user ratings</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">>></span> <span class="n">userMoviesList</span> <span class="o">=</span> <span class="n">userMovies</span><span class="o">.</span><span class="na">mapValues</span><span class="o">(</span><span class="n">docs</span> <span class="o">-></span> <span class="o">{</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">></span> <span class="n">products</span> <span class="o">=</span> <span class="k">new</span> <span class="n">ArrayList</span><span class="o"><>();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Rating</span> <span class="n">r</span> <span class="o">:</span> <span class="n">docs</span><span class="o">)</span> <span class="o">{</span> |
| <span class="k">if</span> <span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="na">rating</span><span class="o">()</span> <span class="o">></span> <span class="mf">0.0</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">products</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">());</span> |
| <span class="o">}</span> |
| <span class="o">}</span> |
| <span class="k">return</span> <span class="n">products</span><span class="o">;</span> |
| <span class="o">});</span> |
| |
| <span class="c1">// Extract the product id from each recommendation</span> |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">>></span> <span class="n">userRecommendedList</span> <span class="o">=</span> <span class="n">userRecommended</span><span class="o">.</span><span class="na">mapValues</span><span class="o">(</span><span class="n">docs</span> <span class="o">-></span> <span class="o">{</span> |
| <span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">></span> <span class="n">products</span> <span class="o">=</span> <span class="k">new</span> <span class="n">ArrayList</span><span class="o"><>();</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Rating</span> <span class="n">r</span> <span class="o">:</span> <span class="n">docs</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">products</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">());</span> |
| <span class="o">}</span> |
| <span class="k">return</span> <span class="n">products</span><span class="o">;</span> |
| <span class="o">});</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">>,</span> <span class="n">List</span><span class="o"><</span><span class="n">Integer</span><span class="o">>>></span> <span class="n">relevantDocs</span> <span class="o">=</span> <span class="n">userMoviesList</span><span class="o">.</span><span class="na">join</span><span class="o">(</span> |
| <span class="n">userRecommendedList</span><span class="o">).</span><span class="na">values</span><span class="o">();</span> |
| |
| <span class="c1">// Instantiate the metrics object</span> |
| <span class="n">RankingMetrics</span><span class="o"><</span><span class="n">Integer</span><span class="o">></span> <span class="n">metrics</span> <span class="o">=</span> <span class="n">RankingMetrics</span><span class="o">.</span><span class="na">of</span><span class="o">(</span><span class="n">relevantDocs</span><span class="o">);</span> |
| |
| <span class="c1">// Precision and NDCG at k</span> |
| <span class="n">Integer</span><span class="o">[]</span> <span class="n">kVector</span> <span class="o">=</span> <span class="o">{</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">5</span><span class="o">};</span> |
| <span class="k">for</span> <span class="o">(</span><span class="n">Integer</span> <span class="n">k</span> <span class="o">:</span> <span class="n">kVector</span><span class="o">)</span> <span class="o">{</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Precision at %d = %f\n"</span><span class="o">,</span> <span class="n">k</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">precisionAt</span><span class="o">(</span><span class="n">k</span><span class="o">));</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"NDCG at %d = %f\n"</span><span class="o">,</span> <span class="n">k</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">ndcgAt</span><span class="o">(</span><span class="n">k</span><span class="o">));</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Mean average precision</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Mean average precision = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">meanAveragePrecision</span><span class="o">());</span> |
| |
| <span class="c1">// Evaluate the model using numerical ratings and regression metrics</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">>></span> <span class="n">userProducts</span> <span class="o">=</span> |
| <span class="n">ratings</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">r</span> <span class="o">-></span> <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">r</span><span class="o">.</span><span class="na">user</span><span class="o">(),</span> <span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">()));</span> |
| |
| <span class="n">JavaPairRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Integer</span><span class="o">,</span> <span class="n">Integer</span><span class="o">>,</span> <span class="n">Object</span><span class="o">></span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">JavaPairRDD</span><span class="o">.</span><span class="na">fromJavaRDD</span><span class="o">(</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">predict</span><span class="o">(</span><span class="n">JavaRDD</span><span class="o">.</span><span class="na">toRDD</span><span class="o">(</span><span class="n">userProducts</span><span class="o">)).</span><span class="na">toJavaRDD</span><span class="o">().</span><span class="na">map</span><span class="o">(</span><span class="n">r</span> <span class="o">-></span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">r</span><span class="o">.</span><span class="na">user</span><span class="o">(),</span> <span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">()),</span> <span class="n">r</span><span class="o">.</span><span class="na">rating</span><span class="o">())));</span> |
| <span class="n">JavaRDD</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">>></span> <span class="n">ratesAndPreds</span> <span class="o">=</span> |
| <span class="n">JavaPairRDD</span><span class="o">.</span><span class="na">fromJavaRDD</span><span class="o">(</span><span class="n">ratings</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">r</span> <span class="o">-></span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><</span><span class="n">Tuple2</span><span class="o"><</span><span class="n">Integer</span><span class="o">,</span> <span class="n">Integer</span><span class="o">>,</span> <span class="n">Object</span><span class="o">>(</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o"><>(</span><span class="n">r</span><span class="o">.</span><span class="na">user</span><span class="o">(),</span> <span class="n">r</span><span class="o">.</span><span class="na">product</span><span class="o">()),</span> |
| <span class="n">r</span><span class="o">.</span><span class="na">rating</span><span class="o">())</span> |
| <span class="o">)).</span><span class="na">join</span><span class="o">(</span><span class="n">predictions</span><span class="o">).</span><span class="na">values</span><span class="o">();</span> |
| |
| <span class="c1">// Create regression metrics object</span> |
| <span class="n">RegressionMetrics</span> <span class="n">regressionMetrics</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RegressionMetrics</span><span class="o">(</span><span class="n">ratesAndPreds</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Root mean squared error</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"RMSE = %f\n"</span><span class="o">,</span> <span class="n">regressionMetrics</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span> |
| |
| <span class="c1">// R-squared</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"R-squared = %f\n"</span><span class="o">,</span> <span class="n">regressionMetrics</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| <p>Refer to the <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics"><code>RegressionMetrics</code> Python docs</a> and <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics"><code>RankingMetrics</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.mllib.recommendation</span> <span class="kn">import</span> <span class="n">ALS</span><span class="p">,</span> <span class="n">Rating</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.evaluation</span> <span class="kn">import</span> <span class="n">RegressionMetrics</span><span class="p">,</span> <span class="n">RankingMetrics</span> |
| |
| <span class="c1"># Read in the ratings data</span> |
| <span class="n">lines</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">textFile</span><span class="p">(</span><span class="s2">"data/mllib/sample_movielens_data.txt"</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">parseLine</span><span class="p">(</span><span class="n">line</span><span class="p">):</span> |
| <span class="n">fields</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"::"</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">Rating</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">fields</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">fields</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="nb">float</span><span class="p">(</span><span class="n">fields</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="o">-</span> <span class="mf">2.5</span><span class="p">)</span> |
| <span class="n">ratings</span> <span class="o">=</span> <span class="n">lines</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">r</span><span class="p">:</span> <span class="n">parseLine</span><span class="p">(</span><span class="n">r</span><span class="p">))</span> |
| |
| <span class="c1"># Train a model to predict user-product ratings</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">ALS</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">ratings</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">)</span> |
| |
| <span class="c1"># Get predicted ratings on all existing user-product pairs</span> |
| <span class="n">testData</span> <span class="o">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">user</span><span class="p">,</span> <span class="n">p</span><span class="o">.</span><span class="n">product</span><span class="p">))</span> |
| <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predictAll</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">r</span><span class="p">:</span> <span class="p">((</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="p">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="p">),</span> <span class="n">r</span><span class="o">.</span><span class="n">rating</span><span class="p">))</span> |
| |
| <span class="n">ratingsTuple</span> <span class="o">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">r</span><span class="p">:</span> <span class="p">((</span><span class="n">r</span><span class="o">.</span><span class="n">user</span><span class="p">,</span> <span class="n">r</span><span class="o">.</span><span class="n">product</span><span class="p">),</span> <span class="n">r</span><span class="o">.</span><span class="n">rating</span><span class="p">))</span> |
| <span class="n">scoreAndLabels</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ratingsTuple</span><span class="p">)</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">tup</span><span class="p">:</span> <span class="n">tup</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> |
| |
| <span class="c1"># Instantiate regression metrics to compare predicted and actual ratings</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="n">RegressionMetrics</span><span class="p">(</span><span class="n">scoreAndLabels</span><span class="p">)</span> |
| |
| <span class="c1"># Root mean squared error</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"RMSE = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="p">)</span> |
| |
| <span class="c1"># R-squared</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"R-squared = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">r2</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/mllib/ranking_metrics_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| </div> |
| |
| <h2 id="regression-model-evaluation">Regression model evaluation</h2> |
| |
| <p><a href="https://en.wikipedia.org/wiki/Regression_analysis">Regression analysis</a> is used when predicting a continuous output |
| variable from a number of independent variables.</p> |
| |
| <p><strong>Available metrics</strong></p> |
| |
| <table class="table"> |
| <thead> |
| <tr><th>Metric</th><th>Definition</th></tr> |
| </thead> |
| <tbody> |
| <tr> |
| <td>Mean Squared Error (MSE)</td> |
| <td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td> |
| </tr> |
| <tr> |
| <td>Root Mean Squared Error (RMSE)</td> |
| <td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td> |
| </tr> |
| <tr> |
| <td>Mean Absolute Error (MAE)</td> |
| <td>$MAE=\frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$</td> |
| </tr> |
| <tr> |
| <td>Coefficient of Determination $(R^2)$</td> |
| <td>$R^2=1 - \frac{N \cdot MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1} |
| (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td> |
| </tr> |
| <tr> |
| <td>Explained Variance</td> |
| <td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td> |
| </tr> |
| </tbody> |
| </table> |
| |
| <p><strong>Examples</strong></p> |
| |
| <div class="codetabs"> |
| The following code snippets illustrate how to load a sample dataset, train a linear regression model on the data, |
| and evaluate the performance of the model using several regression metrics. |
| |
| <div data-lang="scala"> |
| <p>Refer to the <a href="api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics"><code>RegressionMetrics</code> Scala docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.RegressionMetrics</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vector</span> |
| <span class="k">import</span> <span class="nn">org.apache.spark.mllib.regression.</span><span class="o">{</span><span class="nc">LabeledPoint</span><span class="o">,</span> <span class="nc">LinearRegressionWithSGD</span><span class="o">}</span> |
| |
| <span class="c1">// Load the data</span> |
| <span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span> |
| <span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">"libsvm"</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">)</span> |
| <span class="o">.</span><span class="n">rdd</span><span class="o">.</span><span class="n">map</span><span class="o">(</span><span class="n">row</span> <span class="k">=></span> <span class="nc">LabeledPoint</span><span class="o">(</span><span class="n">row</span><span class="o">.</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">),</span> <span class="n">row</span><span class="o">.</span><span class="n">get</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">Vector</span><span class="o">]))</span> |
| <span class="o">.</span><span class="n">cache</span><span class="o">()</span> |
| |
| <span class="c1">// Build the model</span> |
| <span class="k">val</span> <span class="n">numIterations</span> <span class="k">=</span> <span class="mi">100</span> |
| <span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="nc">LinearRegressionWithSGD</span><span class="o">.</span><span class="n">train</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">numIterations</span><span class="o">)</span> |
| |
| <span class="c1">// Get predictions</span> |
| <span class="k">val</span> <span class="n">valuesAndPreds</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">map</span><span class="o">{</span> <span class="n">point</span> <span class="k">=></span> |
| <span class="k">val</span> <span class="n">prediction</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="o">(</span><span class="n">point</span><span class="o">.</span><span class="n">features</span><span class="o">)</span> |
| <span class="o">(</span><span class="n">prediction</span><span class="o">,</span> <span class="n">point</span><span class="o">.</span><span class="n">label</span><span class="o">)</span> |
| <span class="o">}</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionMetrics</span><span class="o">(</span><span class="n">valuesAndPreds</span><span class="o">)</span> |
| |
| <span class="c1">// Squared error</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"MSE = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">meanSquaredError</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"RMSE = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// R-squared</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"R-squared = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">r2</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Mean absolute error</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"MAE = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">meanAbsoluteError</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| |
| <span class="c1">// Explained variance</span> |
| <span class="n">println</span><span class="o">(</span><span class="s">s"Explained variance = </span><span class="si">${</span><span class="n">metrics</span><span class="o">.</span><span class="n">explainedVariance</span><span class="si">}</span><span class="s">"</span><span class="o">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="java"> |
| <p>Refer to the <a href="api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html"><code>RegressionMetrics</code> Java docs</a> for details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">scala.Tuple2</span><span class="o">;</span> |
| |
| <span class="kn">import</span> <span class="nn">org.apache.spark.api.java.*</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Vectors</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LinearRegressionModel</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LinearRegressionWithSGD</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.RegressionMetrics</span><span class="o">;</span> |
| <span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span> |
| |
| <span class="c1">// Load and parse the data</span> |
| <span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">"data/mllib/sample_linear_regression_data.txt"</span><span class="o">;</span> |
| <span class="n">JavaRDD</span><span class="o">&lt;</span><span class="n">String</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="na">textFile</span><span class="o">(</span><span class="n">path</span><span class="o">);</span> |
| <span class="n">JavaRDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">parsedData</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">map</span><span class="o">(</span><span class="n">line</span> <span class="o">-&gt;</span> <span class="o">{</span> |
| <span class="n">String</span><span class="o">[]</span> <span class="n">parts</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="na">split</span><span class="o">(</span><span class="s">" "</span><span class="o">);</span> |
| <span class="kt">double</span><span class="o">[]</span> <span class="n">v</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[</span><span class="n">parts</span><span class="o">.</span><span class="na">length</span> <span class="o">-</span> <span class="mi">1</span><span class="o">];</span> |
| <span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">1</span><span class="o">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">parts</span><span class="o">.</span><span class="na">length</span><span class="o">;</span> <span class="n">i</span><span class="o">++)</span> <span class="o">{</span> |
| <span class="n">v</span><span class="o">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="o">]</span> <span class="o">=</span> <span class="n">Double</span><span class="o">.</span><span class="na">parseDouble</span><span class="o">(</span><span class="n">parts</span><span class="o">[</span><span class="n">i</span><span class="o">].</span><span class="na">split</span><span class="o">(</span><span class="s">":"</span><span class="o">)[</span><span class="mi">1</span><span class="o">]);</span> |
| <span class="o">}</span> |
| <span class="k">return</span> <span class="k">new</span> <span class="n">LabeledPoint</span><span class="o">(</span><span class="n">Double</span><span class="o">.</span><span class="na">parseDouble</span><span class="o">(</span><span class="n">parts</span><span class="o">[</span><span class="mi">0</span><span class="o">]),</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="n">v</span><span class="o">));</span> |
| <span class="o">});</span> |
| <span class="n">parsedData</span><span class="o">.</span><span class="na">cache</span><span class="o">();</span> |
| |
| <span class="c1">// Building the model</span> |
| <span class="kt">int</span> <span class="n">numIterations</span> <span class="o">=</span> <span class="mi">100</span><span class="o">;</span> |
| <span class="n">LinearRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">LinearRegressionWithSGD</span><span class="o">.</span><span class="na">train</span><span class="o">(</span><span class="n">JavaRDD</span><span class="o">.</span><span class="na">toRDD</span><span class="o">(</span><span class="n">parsedData</span><span class="o">),</span> |
| <span class="n">numIterations</span><span class="o">);</span> |
| |
| <span class="c1">// Evaluate model on training examples and compute training error</span> |
| <span class="n">JavaPairRDD</span><span class="o">&lt;</span><span class="n">Object</span><span class="o">,</span> <span class="n">Object</span><span class="o">&gt;</span> <span class="n">valuesAndPreds</span> <span class="o">=</span> <span class="n">parsedData</span><span class="o">.</span><span class="na">mapToPair</span><span class="o">(</span><span class="n">point</span> <span class="o">-&gt;</span> |
| <span class="k">new</span> <span class="n">Tuple2</span><span class="o">&lt;&gt;(</span><span class="n">model</span><span class="o">.</span><span class="na">predict</span><span class="o">(</span><span class="n">point</span><span class="o">.</span><span class="na">features</span><span class="o">()),</span> <span class="n">point</span><span class="o">.</span><span class="na">label</span><span class="o">()));</span> |
| |
| <span class="c1">// Instantiate metrics object</span> |
| <span class="n">RegressionMetrics</span> <span class="n">metrics</span> <span class="o">=</span> <span class="k">new</span> <span class="n">RegressionMetrics</span><span class="o">(</span><span class="n">valuesAndPreds</span><span class="o">.</span><span class="na">rdd</span><span class="o">());</span> |
| |
| <span class="c1">// Squared error</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"MSE = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">meanSquaredError</span><span class="o">());</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"RMSE = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span> |
| |
| <span class="c1">// R-squared</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"R Squared = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span> |
| |
| <span class="c1">// Mean absolute error</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"MAE = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">meanAbsoluteError</span><span class="o">());</span> |
| |
| <span class="c1">// Explained variance</span> |
| <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">"Explained Variance = %f\n"</span><span class="o">,</span> <span class="n">metrics</span><span class="o">.</span><span class="na">explainedVariance</span><span class="o">());</span> |
| |
| <span class="c1">// Save and load model</span> |
| <span class="n">model</span><span class="o">.</span><span class="na">save</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> <span class="s">"target/tmp/LinearRegressionModel"</span><span class="o">);</span> |
| <span class="n">LinearRegressionModel</span> <span class="n">sameModel</span> <span class="o">=</span> <span class="n">LinearRegressionModel</span><span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> |
| <span class="s">"target/tmp/LinearRegressionModel"</span><span class="o">);</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java" in the Spark repo.</small></div> |
| |
| </div> |
| |
| <div data-lang="python"> |
| <p>Refer to the <a href="api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics"><code>RegressionMetrics</code> Python docs</a> for more details on the API.</p> |
| |
| <div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">pyspark.mllib.regression</span> <span class="kn">import</span> <span class="n">LabeledPoint</span><span class="p">,</span> <span class="n">LinearRegressionWithSGD</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.evaluation</span> <span class="kn">import</span> <span class="n">RegressionMetrics</span> |
| <span class="kn">from</span> <span class="nn">pyspark.mllib.linalg</span> <span class="kn">import</span> <span class="n">DenseVector</span> |
| |
| <span class="c1"># Load and parse the data</span> |
| <span class="k">def</span> <span class="nf">parsePoint</span><span class="p">(</span><span class="n">line</span><span class="p">):</span> |
|     <span class="n">values</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> |
|     <span class="k">return</span> <span class="n">LabeledPoint</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">values</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> |
|                         <span class="n">DenseVector</span><span class="p">([</span><span class="nb">float</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">':'</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">values</span><span class="p">[</span><span class="mi">1</span><span class="p">:]]))</span> |
| |
| <span class="n">data</span> <span class="o">=</span> <span class="n">sc</span><span class="o">.</span><span class="n">textFile</span><span class="p">(</span><span class="s2">"data/mllib/sample_linear_regression_data.txt"</span><span class="p">)</span> |
| <span class="n">parsedData</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">parsePoint</span><span class="p">)</span> |
| |
| <span class="c1"># Build the model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">LinearRegressionWithSGD</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">parsedData</span><span class="p">)</span> |
| |
| <span class="c1"># Get predictions</span> |
| <span class="n">valuesAndPreds</span> <span class="o">=</span> <span class="n">parsedData</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">features</span><span class="p">)),</span> <span class="n">p</span><span class="o">.</span><span class="n">label</span><span class="p">))</span> |
| |
| <span class="c1"># Instantiate metrics object</span> |
| <span class="n">metrics</span> <span class="o">=</span> <span class="n">RegressionMetrics</span><span class="p">(</span><span class="n">valuesAndPreds</span><span class="p">)</span> |
| |
| <span class="c1"># Squared Error</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"MSE = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">meanSquaredError</span><span class="p">)</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"RMSE = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">rootMeanSquaredError</span><span class="p">)</span> |
| |
| <span class="c1"># R-squared</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"R-squared = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">r2</span><span class="p">)</span> |
| |
| <span class="c1"># Mean absolute error</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"MAE = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">meanAbsoluteError</span><span class="p">)</span> |
| |
| <span class="c1"># Explained variance</span> |
| <span class="k">print</span><span class="p">(</span><span class="s2">"Explained variance = </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">metrics</span><span class="o">.</span><span class="n">explainedVariance</span><span class="p">)</span> |
| </pre></div> |
| <div><small>Find full example code at "examples/src/main/python/mllib/regression_metrics_example.py" in the Spark repo.</small></div> |
| |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| <!-- /container --> |
| </div> |
| |
| <script src="js/vendor/jquery-1.12.4.min.js"></script> |
| <script src="js/vendor/bootstrap.min.js"></script> |
| <script src="js/vendor/anchor.min.js"></script> |
| <script src="js/main.js"></script> |
| |
| <!-- MathJax Section --> |
| <script type="text/x-mathjax-config"> |
|   // MathJax configuration block: sets TeX equation auto-numbering to "AMS". |
|   MathJax.Hub.Config({ |
|     TeX: { |
|       equationNumbers: { |
|         autoNumber: "AMS" |
|       } |
|     } |
|   }); |
| </script> |
| <script> |
|   // Load MathJax asynchronously from the CDN by appending a script element to <head>. |
|   // The URL is always absolute HTTPS: it resolves when this page is opened from a local |
|   // file (file://) as well as over HTTP or HTTPS, and the library is never fetched over |
|   // plain HTTP. A protocol-relative "//cdn..." URL is avoided because it breaks on file://. |
|   (function(d, script) { |
|     script = d.createElement('script'); |
|     script.type = 'text/javascript'; |
|     script.async = true; |
|     // Apply the tex2jax settings once MathJax.js itself has finished loading. |
|     script.onload = function(){ |
|       MathJax.Hub.Config({ |
|         tex2jax: { |
|           // NOTE(review): "\\\\(" evaluates to a delimiter with two backslashes, while |
|           // displayMath uses a single-backslash "\\[" -- confirm this matches the page source. |
|           inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ], |
|           displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], |
|           processEscapes: true, |
|           // Elements listed here (including 'pre' code samples) are passed as tex2jax skipTags. |
|           skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] |
|         } |
|       }); |
|     }; |
|     script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js' + |
|       '?config=TeX-AMS-MML_HTMLorMML'; |
|     d.getElementsByTagName('head')[0].appendChild(script); |
|   }(document)); |
| </script> |
| </body> |
| </html> |