| |
| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"> |
| <title>pyspark.ml.classification — PySpark 3.4.0 documentation</title> |
| |
| <!-- Stylesheets: order is significant for the cascade, so it is kept exactly as generated. --> |
| <link rel="stylesheet" href="../../../_static/css/index.73d71520a4ca3b99cfee5594769eaaae.css"> |
| |
| <link rel="stylesheet" href="../../../_static/vendor/fontawesome/5.13.0/css/all.min.css"> |
| <!-- Font preloads carry crossorigin so they match the CORS mode of the later font fetch. --> |
| <link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2"> |
| <link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2"> |
| |
| <link rel="stylesheet" href="../../../_static/vendor/open-sans_all/1.44.1/index.css"> |
| <link rel="stylesheet" href="../../../_static/vendor/lato_latin-ext/1.44.1/index.css"> |
| |
| <link rel="stylesheet" href="../../../_static/basic.css"> |
| <link rel="stylesheet" href="../../../_static/pygments.css"> |
| <link rel="stylesheet" href="../../../_static/css/pyspark.css"> |
| |
| <link rel="preload" as="script" href="../../../_static/js/index.3da636dd464baa7582d2.js"> |
| |
| <!-- Scripts: blocking scripts stay in their generated order. --> |
| <script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script> |
| <script src="../../../_static/jquery.js"></script> |
| <script src="../../../_static/underscore.js"></script> |
| <script src="../../../_static/doctools.js"></script> |
| <script src="../../../_static/language_data.js"></script> |
| <script src="../../../_static/copybutton.js"></script> |
| <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script> |
| <!-- MathJax 2 collects text/x-mathjax-config blocks when it starts up, so the block is placed ahead of the async loader. --> |
| <script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script> |
| <script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script> |
| <link rel="search" title="Search" href="../../../search.html"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <meta name="docsearch:language" content="en"> |
| </head> |
| <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80"> |
| |
| <!-- Top navigation bar: brand link plus the collapsible list of documentation sections. --> |
| <nav class="navbar navbar-light navbar-expand-lg bg-light fixed-top bd-navbar" id="navbar-main" aria-label="Site navigation"> |
| <div class="container-xl"> |
| |
| <!-- The image is the only content of the link, so its alt text names the link destination. --> |
| <a class="navbar-brand" href="../../../index.html"> |
| <img class="logo" src="../../../_static/spark-logo-reverse.png" alt="PySpark documentation home"> |
| </a> |
| |
| <!-- Collapse toggler: data-target and aria-controls must keep pointing at the navbar-menu id below. --> |
| <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbar-menu" aria-controls="navbar-menu" aria-expanded="false" aria-label="Toggle navigation"> |
| <span class="navbar-toggler-icon"></span> |
| </button> |
| |
| <div id="navbar-menu" class="col-lg-9 collapse navbar-collapse"> |
| <ul id="navbar-main-elements" class="navbar-nav mr-auto"> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../index.html">Overview</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../getting_started/index.html">Getting Started</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../user_guide/index.html">User Guides</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../reference/index.html">API Reference</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../development/index.html">Development</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="../../../migration_guide/index.html">Migration Guides</a> |
| </li> |
| </ul> |
| |
| <!-- Second navbar list; it has no items on this page and is kept as generated. --> |
| <ul class="navbar-nav"> |
| </ul> |
| </div> |
| </div> |
| </nav> |
| |
| |
| <div class="container-xl"> |
| <div class="row"> |
| |
| <!-- Left sidebar: search form followed by the documentation navigation tree. --> |
| <div class="col-12 col-md-3 bd-sidebar"> |
| <!-- Submits the query as the q parameter to the search page using GET. --> |
| <form class="bd-search d-flex align-items-center" action="../../../search.html" method="get"> |
| <!-- Decorative icon: hidden from assistive technology; the input carries the accessible name. --> |
| <i class="icon fas fa-search" aria-hidden="true"></i> |
| <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off"> |
| </form> |
| <nav class="bd-links" id="bd-docs-nav" aria-label="Main navigation"> |
| |
| <div class="bd-toc-item active"> |
| <!-- The navigation list has no entries on this page. --> |
| <ul class="nav bd-sidenav"> |
| </ul> |
| <!-- Explicit close for bd-toc-item; previously it was only closed implicitly by the nav end tag. --> |
| </div> |
| |
| </nav> |
| </div> |
| |
| |
| |
| <!-- Right-hand column holding the section list; the body scroll-spy (data-target="#bd-toc-nav") depends on the nav id. --> |
| <div class="d-none d-xl-block col-xl-2 bd-toc"> |
| |
| <!-- Labelled so this landmark can be told apart from the other nav elements on the page. --> |
| <nav id="bd-toc-nav" aria-label="On this page"> |
| <!-- The section list has no entries on this page. --> |
| <ul class="nav section-nav flex-column"> |
| </ul> |
| </nav> |
| |
| </div> |
| |
| |
| |
| <main class="col-12 col-md-9 col-xl-7 py-md-5 pl-md-5 pr-md-4 bd-content" role="main"> |
| |
| <div> |
| |
| <h1>Source code for pyspark.ml.classification</h1><div class="highlight"><pre> |
| <span></span><span class="c1">#</span> |
| <span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span> |
| <span class="c1"># contributor license agreements. See the NOTICE file distributed with</span> |
| <span class="c1"># this work for additional information regarding copyright ownership.</span> |
| <span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span> |
| <span class="c1"># (the "License"); you may not use this file except in compliance with</span> |
| <span class="c1"># the License. You may obtain a copy of the License at</span> |
| <span class="c1">#</span> |
| <span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span> |
| <span class="c1">#</span> |
| <span class="c1"># Unless required by applicable law or agreed to in writing, software</span> |
| <span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span> |
| <span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span> |
| <span class="c1"># See the License for the specific language governing permissions and</span> |
| <span class="c1"># limitations under the License.</span> |
| <span class="c1">#</span> |
| |
| <span class="kn">import</span> <span class="nn">os</span> |
| <span class="kn">import</span> <span class="nn">operator</span> |
| <span class="kn">import</span> <span class="nn">sys</span> |
| <span class="kn">import</span> <span class="nn">uuid</span> |
| <span class="kn">import</span> <span class="nn">warnings</span> |
| <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABCMeta</span><span class="p">,</span> <span class="n">abstractmethod</span> |
| <span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span> |
| |
| <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">Any</span><span class="p">,</span> |
| <span class="n">Dict</span><span class="p">,</span> |
| <span class="n">Generic</span><span class="p">,</span> |
| <span class="n">Iterable</span><span class="p">,</span> |
| <span class="n">List</span><span class="p">,</span> |
| <span class="n">Optional</span><span class="p">,</span> |
| <span class="n">Type</span><span class="p">,</span> |
| <span class="n">TypeVar</span><span class="p">,</span> |
| <span class="n">Union</span><span class="p">,</span> |
| <span class="n">cast</span><span class="p">,</span> |
| <span class="n">overload</span><span class="p">,</span> |
| <span class="n">TYPE_CHECKING</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">keyword_only</span><span class="p">,</span> <span class="n">since</span><span class="p">,</span> <span class="n">SparkContext</span><span class="p">,</span> <span class="n">inheritable_thread_target</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">Predictor</span><span class="p">,</span> <span class="n">PredictionModel</span><span class="p">,</span> <span class="n">Model</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.param.shared</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">HasRawPredictionCol</span><span class="p">,</span> |
| <span class="n">HasProbabilityCol</span><span class="p">,</span> |
| <span class="n">HasThresholds</span><span class="p">,</span> |
| <span class="n">HasRegParam</span><span class="p">,</span> |
| <span class="n">HasMaxIter</span><span class="p">,</span> |
| <span class="n">HasFitIntercept</span><span class="p">,</span> |
| <span class="n">HasTol</span><span class="p">,</span> |
| <span class="n">HasStandardization</span><span class="p">,</span> |
| <span class="n">HasWeightCol</span><span class="p">,</span> |
| <span class="n">HasAggregationDepth</span><span class="p">,</span> |
| <span class="n">HasThreshold</span><span class="p">,</span> |
| <span class="n">HasBlockSize</span><span class="p">,</span> |
| <span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span> |
| <span class="n">Param</span><span class="p">,</span> |
| <span class="n">Params</span><span class="p">,</span> |
| <span class="n">TypeConverters</span><span class="p">,</span> |
| <span class="n">HasElasticNetParam</span><span class="p">,</span> |
| <span class="n">HasSeed</span><span class="p">,</span> |
| <span class="n">HasStepSize</span><span class="p">,</span> |
| <span class="n">HasSolver</span><span class="p">,</span> |
| <span class="n">HasParallelism</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.tree</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">_DecisionTreeModel</span><span class="p">,</span> |
| <span class="n">_DecisionTreeParams</span><span class="p">,</span> |
| <span class="n">_TreeEnsembleModel</span><span class="p">,</span> |
| <span class="n">_RandomForestParams</span><span class="p">,</span> |
| <span class="n">_GBTParams</span><span class="p">,</span> |
| <span class="n">_HasVarianceImpurity</span><span class="p">,</span> |
| <span class="n">_TreeClassifierParams</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">_FactorizationMachinesParams</span><span class="p">,</span> <span class="n">DecisionTreeRegressionModel</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.base</span> <span class="kn">import</span> <span class="n">_PredictorParams</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.util</span> <span class="kn">import</span> <span class="p">(</span> |
| <span class="n">DefaultParamsReader</span><span class="p">,</span> |
| <span class="n">DefaultParamsWriter</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">,</span> |
| <span class="n">JavaMLReader</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLWriter</span><span class="p">,</span> |
| <span class="n">MLReader</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">,</span> |
| <span class="n">MLWriter</span><span class="p">,</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="n">HasTrainingSummary</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.wrapper</span> <span class="kn">import</span> <span class="n">JavaParams</span><span class="p">,</span> <span class="n">JavaPredictor</span><span class="p">,</span> <span class="n">JavaPredictionModel</span><span class="p">,</span> <span class="n">JavaWrapper</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.common</span> <span class="kn">import</span> <span class="n">inherit_doc</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Matrix</span><span class="p">,</span> <span class="n">Vector</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">,</span> <span class="n">VectorUDT</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">DataFrame</span><span class="p">,</span> <span class="n">Row</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql.functions</span> <span class="kn">import</span> <span class="n">udf</span><span class="p">,</span> <span class="n">when</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql.types</span> <span class="kn">import</span> <span class="n">ArrayType</span><span class="p">,</span> <span class="n">DoubleType</span> |
| <span class="kn">from</span> <span class="nn">pyspark.storagelevel</span> <span class="kn">import</span> <span class="n">StorageLevel</span> |
| |
| |
| <span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> |
| <span class="kn">from</span> <span class="nn">pyspark.ml._typing</span> <span class="kn">import</span> <span class="n">P</span><span class="p">,</span> <span class="n">ParamMap</span> |
| <span class="kn">from</span> <span class="nn">py4j.java_gateway</span> <span class="kn">import</span> <span class="n">JavaObject</span> |
| |
| |
| <span class="n">T</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">"T"</span><span class="p">)</span> |
| <span class="n">JPM</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">"JPM"</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="n">JavaPredictionModel</span><span class="p">)</span> |
| <span class="n">CM</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">"CM"</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="s2">"ClassificationModel"</span><span class="p">)</span> |
| |
| <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="s2">"LinearSVC"</span><span class="p">,</span> |
| <span class="s2">"LinearSVCModel"</span><span class="p">,</span> |
| <span class="s2">"LinearSVCSummary"</span><span class="p">,</span> |
| <span class="s2">"LinearSVCTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"LogisticRegression"</span><span class="p">,</span> |
| <span class="s2">"LogisticRegressionModel"</span><span class="p">,</span> |
| <span class="s2">"LogisticRegressionSummary"</span><span class="p">,</span> |
| <span class="s2">"LogisticRegressionTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"BinaryLogisticRegressionSummary"</span><span class="p">,</span> |
| <span class="s2">"BinaryLogisticRegressionTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"DecisionTreeClassifier"</span><span class="p">,</span> |
| <span class="s2">"DecisionTreeClassificationModel"</span><span class="p">,</span> |
| <span class="s2">"GBTClassifier"</span><span class="p">,</span> |
| <span class="s2">"GBTClassificationModel"</span><span class="p">,</span> |
| <span class="s2">"RandomForestClassifier"</span><span class="p">,</span> |
| <span class="s2">"RandomForestClassificationModel"</span><span class="p">,</span> |
| <span class="s2">"RandomForestClassificationSummary"</span><span class="p">,</span> |
| <span class="s2">"RandomForestClassificationTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"BinaryRandomForestClassificationSummary"</span><span class="p">,</span> |
| <span class="s2">"BinaryRandomForestClassificationTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"NaiveBayes"</span><span class="p">,</span> |
| <span class="s2">"NaiveBayesModel"</span><span class="p">,</span> |
| <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">,</span> |
| <span class="s2">"MultilayerPerceptronClassificationModel"</span><span class="p">,</span> |
| <span class="s2">"MultilayerPerceptronClassificationSummary"</span><span class="p">,</span> |
| <span class="s2">"MultilayerPerceptronClassificationTrainingSummary"</span><span class="p">,</span> |
| <span class="s2">"OneVsRest"</span><span class="p">,</span> |
| <span class="s2">"OneVsRestModel"</span><span class="p">,</span> |
| <span class="s2">"FMClassifier"</span><span class="p">,</span> |
| <span class="s2">"FMClassificationModel"</span><span class="p">,</span> |
| <span class="s2">"FMClassificationSummary"</span><span class="p">,</span> |
| <span class="s2">"FMClassificationTrainingSummary"</span><span class="p">,</span> |
| <span class="p">]</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_ClassifierParams</span><span class="p">(</span><span class="n">HasRawPredictionCol</span><span class="p">,</span> <span class="n">_PredictorParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Classifier Params for classification tasks.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">Classifier</span><span class="p">(</span><span class="n">Predictor</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span> <span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">Generic</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Classifier for classification tasks.</span> |
| <span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">ClassificationModel</span><span class="p">(</span><span class="n">PredictionModel</span><span class="p">,</span> <span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model produced by a ``Classifier``.</span> |
| <span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@abstractmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">numClasses</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Number of classes (values which the label can take).</span> |
| <span class="sd"> """</span> |
| <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span> |
| |
| <span class="nd">@abstractmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictRaw</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Raw prediction for each possible label.</span> |
| <span class="sd"> """</span> |
| <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_ProbabilisticClassifierParams</span><span class="p">(</span><span class="n">HasProbabilityCol</span><span class="p">,</span> <span class="n">HasThresholds</span><span class="p">,</span> <span class="n">_ClassifierParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`ProbabilisticClassifier` and</span> |
| <span class="sd"> :py:class:`ProbabilisticClassificationModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">ProbabilisticClassifier</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Probabilistic Classifier for classification tasks.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setProbabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`probabilityCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">probabilityCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`thresholds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">ProbabilisticClassificationModel</span><span class="p">(</span> |
| <span class="n">ClassificationModel</span><span class="p">,</span> <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model produced by a ``ProbabilisticClassifier``.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setProbabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">CM</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">CM</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`probabilityCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">probabilityCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">CM</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">CM</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`thresholds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| <span class="nd">@abstractmethod</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictProbability</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Predict the probability of each class given the features.</span> |
| <span class="sd"> """</span> |
| <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_JavaClassifier</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">JavaPredictor</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">Generic</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Java Classifier for classification tasks.</span> |
| <span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_JavaClassificationModel</span><span class="p">(</span><span class="n">ClassificationModel</span><span class="p">,</span> <span class="n">JavaPredictionModel</span><span class="p">[</span><span class="n">T</span><span class="p">]):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Java Model produced by a ``Classifier``.</span> |
| <span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span> |
| <span class="sd"> To be mixed in with :class:`pyspark.ml.JavaModel`</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">numClasses</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Number of classes (values which the label can take).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"numClasses"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictRaw</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Raw prediction for each possible label.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"predictRaw"</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_JavaProbabilisticClassifier</span><span class="p">(</span> |
| <span class="n">ProbabilisticClassifier</span><span class="p">,</span> <span class="n">_JavaClassifier</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">Generic</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Java Probabilistic Classifier for classification tasks.</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_JavaProbabilisticClassificationModel</span><span class="p">(</span> |
| <span class="n">ProbabilisticClassificationModel</span><span class="p">,</span> <span class="n">_JavaClassificationModel</span><span class="p">[</span><span class="n">T</span><span class="p">]</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Java Model produced by a ``ProbabilisticClassifier``.</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictProbability</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Predict the probability of each class given the features.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"predictProbability"</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_ClassificationSummary</span><span class="p">(</span><span class="n">JavaWrapper</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for multiclass classification results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Dataframe outputted by the model's `transform` method.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"predictions"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">predictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the prediction of each class.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"predictionCol"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">labelCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the true label of each</span> |
| <span class="sd"> instance.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"labelCol"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the weight of each instance</span> |
| <span class="sd"> as a vector.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightCol"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> |
| <span class="k">def</span> <span class="nf">labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns the sequence of labels in ascending order. This order matches the order used</span> |
| <span class="sd"> in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the</span> |
| <span class="sd"> training set is missing a label, then all of the arrays over labels</span> |
| <span class="sd"> (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the</span> |
| <span class="sd"> expected numClasses.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"labels"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">truePositiveRateByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns true positive rate for each label (category).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"truePositiveRateByLabel"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">falsePositiveRateByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns false positive rate for each label (category).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"falsePositiveRateByLabel"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">precisionByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns precision for each label (category).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"precisionByLabel"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">recallByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns recall for each label (category).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"recallByLabel"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">fMeasureByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns f-measure for each label (category).</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"fMeasureByLabel"</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns accuracy.</span> |
| <span class="sd"> (equal to the total number of correctly classified instances</span> |
| <span class="sd"> out of the total number of instances.)</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"accuracy"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightedTruePositiveRate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns weighted true positive rate.</span> |
| <span class="sd"> (equal to precision, recall and f-measure)</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightedTruePositiveRate"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightedFalsePositiveRate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns weighted false positive rate.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightedFalsePositiveRate"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightedRecall</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns weighted averaged recall.</span> |
| <span class="sd"> (equal to precision, recall and f-measure)</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightedRecall"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightedPrecision</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns weighted averaged precision.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightedPrecision"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weightedFMeasure</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns weighted averaged f-measure.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weightedFMeasure"</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_TrainingSummary</span><span class="p">(</span><span class="n">JavaWrapper</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for Training results.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">objectiveHistory</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Objective function (scaled loss + regularization) at each</span> |
| <span class="sd"> iteration. It contains one more element, the initial state,</span> |
| <span class="sd"> than the number of iterations.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"objectiveHistory"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">totalIterations</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Number of training iterations until termination.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"totalIterations"</span><span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_BinaryClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Binary classification results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">scoreCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the probability or raw prediction</span> |
| <span class="sd"> of each class as a vector.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"scoreCol"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> |
| <span class="k">def</span> <span class="nf">roc</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns the receiver operating characteristic (ROC) curve,</span> |
| <span class="sd"> which is a Dataframe having two fields (FPR, TPR) with</span> |
| <span class="sd"> (0.0, 0.0) prepended and (1.0, 1.0) appended to it.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> `Wikipedia reference <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"roc"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">areaUnderROC</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Computes the area under the receiver operating characteristic</span> |
| <span class="sd"> (ROC) curve.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"areaUnderROC"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">pr</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns the precision-recall curve, which is a Dataframe</span> |
| <span class="sd"> containing two fields recall, precision with (0.0, 1.0) prepended</span> |
| <span class="sd"> to it.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"pr"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">fMeasureByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns a dataframe with two fields (threshold, F-Measure) curve</span> |
| <span class="sd"> with beta = 1.0.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"fMeasureByThreshold"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">precisionByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns a dataframe with two fields (threshold, precision) curve.</span> |
| <span class="sd"> Every possible probability obtained in transforming the dataset</span> |
| <span class="sd"> is used as a threshold in calculating the precision.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"precisionByThreshold"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">recallByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Returns a dataframe with two fields (threshold, recall) curve.</span> |
| <span class="sd"> Every possible probability obtained in transforming the dataset</span> |
| <span class="sd"> is used as a threshold in calculating the recall.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"recallByThreshold"</span><span class="p">)</span> |
| |
| |
| <span class="k">class</span> <span class="nc">_LinearSVCParams</span><span class="p">(</span> |
| <span class="n">_ClassifierParams</span><span class="p">,</span> |
| <span class="n">HasRegParam</span><span class="p">,</span> |
| <span class="n">HasMaxIter</span><span class="p">,</span> |
| <span class="n">HasFitIntercept</span><span class="p">,</span> |
| <span class="n">HasTol</span><span class="p">,</span> |
| <span class="n">HasStandardization</span><span class="p">,</span> |
| <span class="n">HasWeightCol</span><span class="p">,</span> |
| <span class="n">HasAggregationDepth</span><span class="p">,</span> |
| <span class="n">HasThreshold</span><span class="p">,</span> |
| <span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">threshold</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"threshold"</span><span class="p">,</span> |
| <span class="s2">"The threshold in binary classification applied to the linear model"</span> |
| <span class="s2">" prediction. This threshold can be any real number, where Inf will make"</span> |
| <span class="s2">" all predictions 0.0 and -Inf will make all predictions 1.0."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_LinearSVCParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span> |
| <span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="LinearSVC"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">LinearSVC</span><span class="p">(</span> |
| <span class="n">_JavaClassifier</span><span class="p">[</span><span class="s2">"LinearSVCModel"</span><span class="p">],</span> |
| <span class="n">_LinearSVCParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"LinearSVC"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.</span> |
| <span class="sd"> Only supports L2 regularization currently.</span> |
| |
| <span class="sd"> .. versionadded:: 2.2.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.sql import Row</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> df = sc.parallelize([</span> |
| <span class="sd"> ... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),</span> |
| <span class="sd"> ... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()</span> |
| <span class="sd"> >>> svm = LinearSVC()</span> |
| <span class="sd"> >>> svm.getMaxIter()</span> |
| <span class="sd"> 100</span> |
| <span class="sd"> >>> svm.setMaxIter(5)</span> |
| <span class="sd"> LinearSVC...</span> |
| <span class="sd"> >>> svm.getMaxIter()</span> |
| <span class="sd"> 5</span> |
| <span class="sd"> >>> svm.getRegParam()</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> svm.setRegParam(0.01)</span> |
| <span class="sd"> LinearSVC...</span> |
| <span class="sd"> >>> svm.getRegParam()</span> |
| <span class="sd"> 0.01</span> |
| <span class="sd"> >>> model = svm.fit(df)</span> |
| <span class="sd"> >>> model.setPredictionCol("newPrediction")</span> |
| <span class="sd"> LinearSVCModel...</span> |
| <span class="sd"> >>> model.getPredictionCol()</span> |
| <span class="sd"> 'newPrediction'</span> |
| <span class="sd"> >>> model.setThreshold(0.5)</span> |
| <span class="sd"> LinearSVCModel...</span> |
| <span class="sd"> >>> model.getThreshold()</span> |
| <span class="sd"> 0.5</span> |
| <span class="sd"> >>> model.getMaxBlockSizeInMB()</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model.coefficients</span> |
| <span class="sd"> DenseVector([0.0, -1.0319, -0.5159])</span> |
| <span class="sd"> >>> model.intercept</span> |
| <span class="sd"> 2.579645978780695</span> |
| <span class="sd"> >>> model.numClasses</span> |
| <span class="sd"> 2</span> |
| <span class="sd"> >>> model.numFeatures</span> |
| <span class="sd"> 3</span> |
| <span class="sd"> >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()</span> |
| <span class="sd"> >>> model.predict(test0.head().features)</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([-4.1274, 4.1274])</span> |
| <span class="sd"> >>> result = model.transform(test0).head()</span> |
| <span class="sd"> >>> result.newPrediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> result.rawPrediction</span> |
| <span class="sd"> DenseVector([-4.1274, 4.1274])</span> |
| <span class="sd"> >>> svm_path = temp_path + "/svm"</span> |
| <span class="sd"> >>> svm.save(svm_path)</span> |
| <span class="sd"> >>> svm2 = LinearSVC.load(svm_path)</span> |
| <span class="sd"> >>> svm2.getMaxIter()</span> |
| <span class="sd"> 5</span> |
| <span class="sd"> >>> model_path = temp_path + "/svm_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = LinearSVCModel.load(model_path)</span> |
| <span class="sd"> >>> model.coefficients[0] == model2.coefficients[0]</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.intercept == model2.intercept</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \</span> |
| <span class="sd"> aggregationDepth=2, maxBlockSizeInMB=0.0):</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">LinearSVC</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.LinearSVC"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="LinearSVC.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \</span> |
| <span class="sd"> aggregationDepth=2, maxBlockSizeInMB=0.0):</span> |
| <span class="sd"> Sets params for Linear SVM Classifier.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVCModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">LinearSVCModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="LinearSVC.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxIter`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setRegParam">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`regParam`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setTol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`tol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setFitIntercept">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setStandardization"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setStandardization">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setStandardization</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`standardization`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">standardization</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setThreshold"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setThreshold">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`threshold`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setAggregationDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setAggregationDepth">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setAggregationDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`aggregationDepth`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">aggregationDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVC.setMaxBlockSizeInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setMaxBlockSizeInMB">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMaxBlockSizeInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVC"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxBlockSizeInMB`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="LinearSVCModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel">[docs]</a><span class="k">class</span> <span class="nc">LinearSVCModel</span><span class="p">(</span> |
| <span class="n">_JavaClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_LinearSVCParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"LinearSVCModel"</span><span class="p">],</span> |
| <span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">"LinearSVCTrainingSummary"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by LinearSVC.</span> |
| |
| <span class="sd"> .. versionadded:: 2.2.0</span> |
| <span class="sd"> """</span> |
| |
| <div class="viewcode-block" id="LinearSVCModel.setThreshold"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.setThreshold">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVCModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`threshold`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">coefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model coefficients of Linear SVM Classifier.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"coefficients"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.2.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model intercept of Linear SVM Classifier.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"intercept"</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="LinearSVCModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVCTrainingSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the summary (accuracy/precision/recall, objective history, total iterations) of the model</span> |
| <span class="sd"> trained on the training set. A RuntimeError is raised if `hasSummary` is false.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">LinearSVCTrainingSummary</span><span class="p">(</span><span class="nb">super</span><span class="p">(</span><span class="n">LinearSVCModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"No training summary available for this </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> |
| <span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LinearSVCModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LinearSVCSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Evaluates the model on a test dataset.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span> |
| <span class="n">java_lsvc_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluate"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">LinearSVCSummary</span><span class="p">(</span><span class="n">java_lsvc_summary</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="LinearSVCSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCSummary.html#pyspark.ml.classification.LinearSVCSummary">[docs]</a><span class="k">class</span> <span class="nc">LinearSVCSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for LinearSVC Results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="LinearSVCTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCTrainingSummary.html#pyspark.ml.classification.LinearSVCTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">LinearSVCTrainingSummary</span><span class="p">(</span><span class="n">LinearSVCSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for LinearSVC Training results.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_LogisticRegressionParams</span><span class="p">(</span> |
| <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> |
| <span class="n">HasRegParam</span><span class="p">,</span> |
| <span class="n">HasElasticNetParam</span><span class="p">,</span> |
| <span class="n">HasMaxIter</span><span class="p">,</span> |
| <span class="n">HasFitIntercept</span><span class="p">,</span> |
| <span class="n">HasTol</span><span class="p">,</span> |
| <span class="n">HasStandardization</span><span class="p">,</span> |
| <span class="n">HasWeightCol</span><span class="p">,</span> |
| <span class="n">HasAggregationDepth</span><span class="p">,</span> |
| <span class="n">HasThreshold</span><span class="p">,</span> |
| <span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">threshold</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"threshold"</span><span class="p">,</span> |
| <span class="s2">"Threshold in binary classification prediction, in range [0, 1]."</span> |
| <span class="o">+</span> <span class="s2">" If threshold and thresholds are both set, they must match."</span> |
| <span class="o">+</span> <span class="s2">" e.g. if threshold is p, then thresholds must be equal to [1-p, p]."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">family</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"family"</span><span class="p">,</span> |
| <span class="s2">"The name of family which is a description of the label distribution to "</span> |
| <span class="o">+</span> <span class="s2">"be used in the model. Supported options: auto, binomial, multinomial"</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"lowerBoundsOnCoefficients"</span><span class="p">,</span> |
| <span class="s2">"The lower bounds on coefficients if fitting under bound "</span> |
| <span class="s2">"constrained optimization. The bound matrix must be "</span> |
| <span class="s2">"compatible with the shape "</span> |
| <span class="s2">"(1, number of features) for binomial regression, or "</span> |
| <span class="s2">"(number of classes, number of features) "</span> |
| <span class="s2">"for multinomial regression."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toMatrix</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"upperBoundsOnCoefficients"</span><span class="p">,</span> |
| <span class="s2">"The upper bounds on coefficients if fitting under bound "</span> |
| <span class="s2">"constrained optimization. The bound matrix must be "</span> |
| <span class="s2">"compatible with the shape "</span> |
| <span class="s2">"(1, number of features) for binomial regression, or "</span> |
| <span class="s2">"(number of classes, number of features) "</span> |
| <span class="s2">"for multinomial regression."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toMatrix</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"lowerBoundsOnIntercepts"</span><span class="p">,</span> |
| <span class="s2">"The lower bounds on intercepts if fitting under bound "</span> |
| <span class="s2">"constrained optimization. The bounds vector size must be "</span> |
| <span class="s2">"equal with 1 for binomial regression, or the number of "</span> |
| <span class="s2">"classes for multinomial regression."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"upperBoundsOnIntercepts"</span><span class="p">,</span> |
| <span class="s2">"The upper bounds on intercepts if fitting under bound "</span> |
| <span class="s2">"constrained optimization. The bound vector size must be "</span> |
| <span class="s2">"equal with 1 for binomial regression, or the number of "</span> |
| <span class="s2">"classes for multinomial regression."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_LogisticRegressionParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span> |
| <span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">family</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span> <span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="mf">0.0</span> |
| <span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`threshold`.</span> |
| <span class="sd"> Clears value of :py:attr:`thresholds` if it has been set.</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">clear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> <span class="c1"># type: ignore[attr-defined]</span> |
| <span class="k">return</span> <span class="bp">self</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Get threshold for binary classification.</span> |
| |
| <span class="sd"> If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),</span> |
| <span class="sd"> this returns the equivalent threshold:</span> |
| <span class="sd"> :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.</span> |
| <span class="sd"> Otherwise, returns :py:attr:`threshold` if set or its default value if unset.</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">):</span> |
| <span class="n">ts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> |
| <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ts</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="s2">"Logistic Regression getThreshold only applies to"</span> |
| <span class="o">+</span> <span class="s2">" binary classification, but thresholds has length != 2."</span> |
| <span class="o">+</span> <span class="s2">" thresholds: </span><span class="si">{ts}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ts</span><span class="o">=</span><span class="n">ts</span><span class="p">)</span> |
| <span class="p">)</span> |
| <span class="k">return</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">ts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">ts</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">"P"</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"P"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`thresholds`.</span> |
| <span class="sd"> Clears value of :py:attr:`threshold` if it has been set.</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">clear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> <span class="c1"># type: ignore[attr-defined]</span> |
| <span class="k">return</span> <span class="bp">self</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> If :py:attr:`thresholds` is set, return its value.</span> |
| <span class="sd"> Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary</span> |
| <span class="sd"> classification: (1-threshold, threshold).</span> |
| <span class="sd"> If neither is set, raise an error.</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">):</span> |
| <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> |
| <span class="k">return</span> <span class="p">[</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">t</span><span class="p">,</span> <span class="n">t</span><span class="p">]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">_checkThresholdConsistency</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">):</span> |
| <span class="n">ts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> |
| <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ts</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="s2">"Logistic Regression getThreshold only applies to"</span> |
| <span class="o">+</span> <span class="s2">" binary classification, but thresholds has length != 2."</span> |
| <span class="o">+</span> <span class="s2">" thresholds: </span><span class="si">{0}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">ts</span><span class="p">))</span> |
| <span class="p">)</span> |
| <span class="n">t</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">ts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">ts</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> |
| <span class="n">t2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> |
| <span class="k">if</span> <span class="nb">abs</span><span class="p">(</span><span class="n">t2</span> <span class="o">-</span> <span class="n">t</span><span class="p">)</span> <span class="o">>=</span> <span class="mf">1e-5</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="s2">"Logistic Regression getThreshold found inconsistent values for"</span> |
| <span class="o">+</span> <span class="s2">" threshold (</span><span class="si">%g</span><span class="s2">) and thresholds (equivalent to </span><span class="si">%g</span><span class="s2">)"</span> <span class="o">%</span> <span class="p">(</span><span class="n">t2</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
| <span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getFamily</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of :py:attr:`family` or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">family</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getLowerBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of :py:attr:`lowerBoundsOnCoefficients`</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lowerBoundsOnCoefficients</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getUpperBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of :py:attr:`upperBoundsOnCoefficients`</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">upperBoundsOnCoefficients</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getLowerBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of :py:attr:`lowerBoundsOnIntercepts`</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lowerBoundsOnIntercepts</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getUpperBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of :py:attr:`upperBoundsOnIntercepts`</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">upperBoundsOnIntercepts</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="LogisticRegression"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">LogisticRegression</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"LogisticRegressionModel"</span><span class="p">],</span> |
| <span class="n">_LogisticRegressionParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"LogisticRegression"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Logistic regression.</span> |
| <span class="sd"> This class supports multinomial logistic (softmax) and binomial logistic regression.</span> |
| |
| <span class="sd"> .. versionadded:: 1.3.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.sql import Row</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> bdf = sc.parallelize([</span> |
| <span class="sd"> ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),</span> |
| <span class="sd"> ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),</span> |
| <span class="sd"> ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),</span> |
| <span class="sd"> ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()</span> |
| <span class="sd"> >>> blor = LogisticRegression(weightCol="weight")</span> |
| <span class="sd"> >>> blor.getRegParam()</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> blor.setRegParam(0.01)</span> |
| <span class="sd"> LogisticRegression...</span> |
| <span class="sd"> >>> blor.getRegParam()</span> |
| <span class="sd"> 0.01</span> |
| <span class="sd"> >>> blor.setMaxIter(10)</span> |
| <span class="sd"> LogisticRegression...</span> |
| <span class="sd"> >>> blor.getMaxIter()</span> |
| <span class="sd"> 10</span> |
| <span class="sd"> >>> blor.clear(blor.maxIter)</span> |
| <span class="sd"> >>> blorModel = blor.fit(bdf)</span> |
| <span class="sd"> >>> blorModel.setFeaturesCol("features")</span> |
| <span class="sd"> LogisticRegressionModel...</span> |
| <span class="sd"> >>> blorModel.setProbabilityCol("newProbability")</span> |
| <span class="sd"> LogisticRegressionModel...</span> |
| <span class="sd"> >>> blorModel.getProbabilityCol()</span> |
| <span class="sd"> 'newProbability'</span> |
| <span class="sd"> >>> blorModel.getMaxBlockSizeInMB()</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> blorModel.setThreshold(0.1)</span> |
| <span class="sd"> LogisticRegressionModel...</span> |
| <span class="sd"> >>> blorModel.getThreshold()</span> |
| <span class="sd"> 0.1</span> |
| <span class="sd"> >>> blorModel.coefficients</span> |
| <span class="sd"> DenseVector([-1.080..., -0.646...])</span> |
| <span class="sd"> >>> blorModel.intercept</span> |
| <span class="sd"> 3.112...</span> |
| <span class="sd"> >>> blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"</span> |
| <span class="sd"> >>> mdf = spark.read.format("libsvm").load(data_path)</span> |
| <span class="sd"> >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")</span> |
| <span class="sd"> >>> mlorModel = mlor.fit(mdf)</span> |
| <span class="sd"> >>> mlorModel.coefficientMatrix</span> |
| <span class="sd"> SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)</span> |
| <span class="sd"> >>> mlorModel.interceptVector</span> |
| <span class="sd"> DenseVector([0.04..., -0.42..., 0.37...])</span> |
| <span class="sd"> >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()</span> |
| <span class="sd"> >>> blorModel.predict(test0.head().features)</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> blorModel.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([-3.54..., 3.54...])</span> |
| <span class="sd"> >>> blorModel.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([0.028, 0.972])</span> |
| <span class="sd"> >>> result = blorModel.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> result.newProbability</span> |
| <span class="sd"> DenseVector([0.02..., 0.97...])</span> |
| <span class="sd"> >>> result.rawPrediction</span> |
| <span class="sd"> DenseVector([-3.54..., 3.54...])</span> |
| <span class="sd"> >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()</span> |
| <span class="sd"> >>> blorModel.transform(test1).head().prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> blor.setParams("vector")</span> |
| <span class="sd"> Traceback (most recent call last):</span> |
| <span class="sd"> ...</span> |
| <span class="sd"> TypeError: Method setParams forces keyword arguments.</span> |
| <span class="sd"> >>> lr_path = temp_path + "/lr"</span> |
| <span class="sd"> >>> blor.save(lr_path)</span> |
| <span class="sd"> >>> lr2 = LogisticRegression.load(lr_path)</span> |
| <span class="sd"> >>> lr2.getRegParam()</span> |
| <span class="sd"> 0.01</span> |
| <span class="sd"> >>> model_path = temp_path + "/lr_model"</span> |
| <span class="sd"> >>> blorModel.save(model_path)</span> |
| <span class="sd"> >>> model2 = LogisticRegressionModel.load(model_path)</span> |
| <span class="sd"> >>> blorModel.coefficients[0] == model2.coefficients[0]</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> blorModel.intercept == model2.intercept</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model2</span> |
| <span class="sd"> LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2</span> |
| <span class="sd"> >>> blorModel.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="o">...</span> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="o">...</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \</span> |
| <span class="sd"> threshold=0.5, thresholds=None, probabilityCol="probability", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \</span> |
| <span class="sd"> aggregationDepth=2, family="auto", \</span> |
| <span class="sd"> lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \</span> |
| <span class="sd"> lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \</span> |
| <span class="sd"> maxBlockSizeInMB=0.0):</span> |
| <span class="sd"> If the threshold and thresholds Params are both set, they must be equivalent.</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegression</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.LogisticRegression"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="o">...</span> |
| |
| <span class="nd">@overload</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="o">...</span> |
| |
| <div class="viewcode-block" id="LogisticRegression.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> |
| <span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \</span> |
| <span class="sd"> threshold=0.5, thresholds=None, probabilityCol="probability", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \</span> |
| <span class="sd"> aggregationDepth=2, family="auto", \</span> |
| <span class="sd"> lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \</span> |
| <span class="sd"> lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \</span> |
| <span class="sd"> maxBlockSizeInMB=0.0):</span> |
| <span class="sd"> Sets params for logistic regression.</span> |
| <span class="sd"> If the threshold and thresholds Params are both set, they must be equivalent.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span> |
| <span class="k">return</span> <span class="bp">self</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegressionModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">LogisticRegressionModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="LogisticRegression.setFamily"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setFamily">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFamily</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`family`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">family</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setLowerBoundsOnCoefficients"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setLowerBoundsOnCoefficients">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setLowerBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Matrix</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`lowerBoundsOnCoefficients`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lowerBoundsOnCoefficients</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setUpperBoundsOnCoefficients"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setUpperBoundsOnCoefficients">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setUpperBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Matrix</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`upperBoundsOnCoefficients`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">upperBoundsOnCoefficients</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setLowerBoundsOnIntercepts"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setLowerBoundsOnIntercepts">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setLowerBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`lowerBoundsOnIntercepts`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lowerBoundsOnIntercepts</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setUpperBoundsOnIntercepts"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setUpperBoundsOnIntercepts">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.3.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setUpperBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`upperBoundsOnIntercepts`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">upperBoundsOnIntercepts</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setMaxIter">[docs]</a> <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxIter`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setRegParam">[docs]</a> <span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`regParam`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setTol">[docs]</a> <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`tol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setElasticNetParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setElasticNetParam">[docs]</a> <span class="k">def</span> <span class="nf">setElasticNetParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`elasticNetParam`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">elasticNetParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setFitIntercept">[docs]</a> <span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setStandardization"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setStandardization">[docs]</a> <span class="k">def</span> <span class="nf">setStandardization</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`standardization`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">standardization</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setAggregationDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setAggregationDepth">[docs]</a> <span class="k">def</span> <span class="nf">setAggregationDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`aggregationDepth`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">aggregationDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="LogisticRegression.setMaxBlockSizeInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setMaxBlockSizeInMB">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMaxBlockSizeInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegression"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxBlockSizeInMB`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="LogisticRegressionModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionModel.html#pyspark.ml.classification.LogisticRegressionModel">[docs]</a><span class="k">class</span> <span class="nc">LogisticRegressionModel</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_LogisticRegressionParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"LogisticRegressionModel"</span><span class="p">],</span> |
| <span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">"LogisticRegressionTrainingSummary"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by LogisticRegression.</span> |
| |
| <span class="sd"> .. versionadded:: 1.3.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">coefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model coefficients of binomial logistic regression.</span> |
| <span class="sd"> An exception is thrown in the case of multinomial logistic regression.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"coefficients"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model intercept of binomial logistic regression.</span> |
| <span class="sd"> An exception is thrown in the case of multinomial logistic regression.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"intercept"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">coefficientMatrix</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model coefficients.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"coefficientMatrix"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">interceptVector</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model intercept.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"interceptVector"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegressionTrainingSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span> |
| <span class="sd"> trained on the training set. A RuntimeError is raised if no training summary is available (`hasSummary` is False).</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o"><=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">BinaryLogisticRegressionTrainingSummary</span><span class="p">(</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegressionModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">LogisticRegressionTrainingSummary</span><span class="p">(</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegressionModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"No training summary available for this </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> |
| <span class="p">)</span> |
| |
| <div class="viewcode-block" id="LogisticRegressionModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionModel.html#pyspark.ml.classification.LogisticRegressionModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"LogisticRegressionSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Evaluates the model on a test dataset.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span> |
| <span class="n">java_blr_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluate"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o"><=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">BinaryLogisticRegressionSummary</span><span class="p">(</span><span class="n">java_blr_summary</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">LogisticRegressionSummary</span><span class="p">(</span><span class="n">java_blr_summary</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="LogisticRegressionSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionSummary.html#pyspark.ml.classification.LogisticRegressionSummary">[docs]</a><span class="k">class</span> <span class="nc">LogisticRegressionSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for Logistic Regression Results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">probabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the probability</span> |
| <span class="sd"> of each class as a vector.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"probabilityCol"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">featuresCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Field in "predictions" which gives the features of each instance</span> |
| <span class="sd"> as a vector.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"featuresCol"</span><span class="p">)</span></div> |
| |
| |
| <div class="viewcode-block" id="LogisticRegressionTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionTrainingSummary.html#pyspark.ml.classification.LogisticRegressionTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">LogisticRegressionTrainingSummary</span><span class="p">(</span><span class="n">LogisticRegressionSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for multinomial Logistic Regression Training results.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="BinaryLogisticRegressionSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryLogisticRegressionSummary.html#pyspark.ml.classification.BinaryLogisticRegressionSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">BinaryLogisticRegressionSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">,</span> <span class="n">LogisticRegressionSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Binary Logistic regression results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="BinaryLogisticRegressionTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">BinaryLogisticRegressionTrainingSummary</span><span class="p">(</span> |
| <span class="n">BinaryLogisticRegressionSummary</span><span class="p">,</span> <span class="n">LogisticRegressionTrainingSummary</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Binary Logistic regression training results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_DecisionTreeClassifierParams</span><span class="p">(</span><span class="n">_DecisionTreeParams</span><span class="p">,</span> <span class="n">_TreeClassifierParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span> |
| <span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="o">=</span><span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="o">=</span><span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">DecisionTreeClassifier</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"DecisionTreeClassificationModel"</span><span class="p">],</span> |
| <span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"DecisionTreeClassifier"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> `Decision tree &lt;https://en.wikipedia.org/wiki/Decision_tree_learning&gt;`_</span> |
| <span class="sd"> learning algorithm for classification.</span> |
| <span class="sd"> It supports both binary and multiclass labels, as well as both continuous and categorical</span> |
| <span class="sd"> features.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.feature import StringIndexer</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... (1.0, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])</span> |
| <span class="sd"> >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")</span> |
| <span class="sd"> >>> si_model = stringIndexer.fit(df)</span> |
| <span class="sd"> >>> td = si_model.transform(df)</span> |
| <span class="sd"> >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")</span> |
| <span class="sd"> >>> model = dt.fit(td)</span> |
| <span class="sd"> >>> model.getLabelCol()</span> |
| <span class="sd"> 'indexed'</span> |
| <span class="sd"> >>> model.setFeaturesCol("features")</span> |
| <span class="sd"> DecisionTreeClassificationModel...</span> |
| <span class="sd"> >>> model.numNodes</span> |
| <span class="sd"> 3</span> |
| <span class="sd"> >>> model.depth</span> |
| <span class="sd"> 1</span> |
| <span class="sd"> >>> model.featureImportances</span> |
| <span class="sd"> SparseVector(1, {0: 1.0})</span> |
| <span class="sd"> >>> model.numFeatures</span> |
| <span class="sd"> 1</span> |
| <span class="sd"> >>> model.numClasses</span> |
| <span class="sd"> 2</span> |
| <span class="sd"> >>> print(model.toDebugString)</span> |
| <span class="sd"> DecisionTreeClassificationModel...depth=1, numNodes=3...</span> |
| <span class="sd"> >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])</span> |
| <span class="sd"> >>> model.predict(test0.head().features)</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> model.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> result = model.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> result.probability</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> result.rawPrediction</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> result.leafId</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])</span> |
| <span class="sd"> >>> model.transform(test1).head().prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> dtc_path = temp_path + "/dtc"</span> |
| <span class="sd"> >>> dt.save(dtc_path)</span> |
| <span class="sd"> >>> dt2 = DecisionTreeClassifier.load(dtc_path)</span> |
| <span class="sd"> >>> dt2.getMaxDepth()</span> |
| <span class="sd"> 2</span> |
| <span class="sd"> >>> model_path = temp_path + "/dtc_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = DecisionTreeClassificationModel.load(model_path)</span> |
| <span class="sd"> >>> model.featureImportances == model2.featureImportances</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> df3 = spark.createDataFrame([</span> |
| <span class="sd"> ... (1.0, 0.2, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (1.0, 0.8, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])</span> |
| <span class="sd"> >>> si3 = StringIndexer(inputCol="label", outputCol="indexed")</span> |
| <span class="sd"> >>> si_model3 = si3.fit(df3)</span> |
| <span class="sd"> >>> td3 = si_model3.transform(df3)</span> |
| <span class="sd"> >>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")</span> |
| <span class="sd"> >>> model3 = dt3.fit(td3)</span> |
| <span class="sd"> >>> print(model3.toDebugString)</span> |
| <span class="sd"> DecisionTreeClassificationModel...depth=1, numNodes=3...</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \</span> |
| <span class="sd"> seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">DecisionTreeClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.DecisionTreeClassifier"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \</span> |
| <span class="sd"> seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)</span> |
| <span class="sd"> Sets params for the DecisionTreeClassifier.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassificationModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">DecisionTreeClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxDepth`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxBins`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`impurity`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setCheckpointInterval">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="DecisionTreeClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"DecisionTreeClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="DecisionTreeClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassificationModel.html#pyspark.ml.classification.DecisionTreeClassificationModel">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">DecisionTreeClassificationModel</span><span class="p">(</span> |
| <span class="n">_DecisionTreeModel</span><span class="p">,</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"DecisionTreeClassificationModel"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by DecisionTreeClassifier.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> |
| <span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Estimate of the importance of each feature.</span> |
| |
| <span class="sd"> This generalizes the idea of "Gini" importance to other losses,</span> |
| <span class="sd"> following the explanation of Gini importance from "Random Forests" documentation</span> |
| <span class="sd"> by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.</span> |
| |
| <span class="sd"> This feature importance is calculated as follows:</span> |
| <span class="sd"> - importance(feature j) = sum (over nodes which split on feature j) of the gain,</span> |
| <span class="sd"> where gain is scaled by the number of instances passing through node</span> |
| <span class="sd"> - Normalize importances for tree to sum to 1.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> Feature importance for single decision trees can have high variance due to</span> |
| <span class="sd"> correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`</span> |
| <span class="sd"> to determine feature importance instead.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"featureImportances"</span><span class="p">)</span></div> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">_RandomForestClassifierParams</span><span class="p">(</span><span class="n">_RandomForestParams</span><span class="p">,</span> <span class="n">_TreeClassifierParams</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_RandomForestClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span> |
| <span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="o">=</span><span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">numTrees</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="o">=</span><span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">bootstrap</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="RandomForestClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">RandomForestClassifier</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"RandomForestClassificationModel"</span><span class="p">],</span> |
| <span class="n">_RandomForestClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"RandomForestClassifier"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> `Random Forest &lt;http://en.wikipedia.org/wiki/Random_forest&gt;`_</span> |
| <span class="sd"> learning algorithm for classification.</span> |
| <span class="sd"> It supports both binary and multiclass labels, as well as both continuous and categorical</span> |
| <span class="sd"> features.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> import numpy</span> |
| <span class="sd"> >>> from numpy import allclose</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.feature import StringIndexer</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... (1.0, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])</span> |
| <span class="sd"> >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")</span> |
| <span class="sd"> >>> si_model = stringIndexer.fit(df)</span> |
| <span class="sd"> >>> td = si_model.transform(df)</span> |
| <span class="sd"> >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,</span> |
| <span class="sd"> ... leafCol="leafId")</span> |
| <span class="sd"> >>> rf.getMinWeightFractionPerNode()</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model = rf.fit(td)</span> |
| <span class="sd"> >>> model.getLabelCol()</span> |
| <span class="sd"> 'indexed'</span> |
| <span class="sd"> >>> model.setFeaturesCol("features")</span> |
| <span class="sd"> RandomForestClassificationModel...</span> |
| <span class="sd"> >>> model.setRawPredictionCol("newRawPrediction")</span> |
| <span class="sd"> RandomForestClassificationModel...</span> |
| <span class="sd"> >>> model.getBootstrap()</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.getRawPredictionCol()</span> |
| <span class="sd"> 'newRawPrediction'</span> |
| <span class="sd"> >>> model.featureImportances</span> |
| <span class="sd"> SparseVector(1, {0: 1.0})</span> |
| <span class="sd"> >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])</span> |
| <span class="sd"> >>> model.predict(test0.head().features)</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([2.0, 0.0])</span> |
| <span class="sd"> >>> model.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> result = model.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> numpy.argmax(result.probability)</span> |
| <span class="sd"> 0</span> |
| <span class="sd"> >>> numpy.argmax(result.newRawPrediction)</span> |
| <span class="sd"> 0</span> |
| <span class="sd"> >>> result.leafId</span> |
| <span class="sd"> DenseVector([0.0, 0.0, 0.0])</span> |
| <span class="sd"> >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])</span> |
| <span class="sd"> >>> model.transform(test1).head().prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.trees</span> |
| <span class="sd"> [DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]</span> |
| <span class="sd"> >>> rfc_path = temp_path + "/rfc"</span> |
| <span class="sd"> >>> rf.save(rfc_path)</span> |
| <span class="sd"> >>> rf2 = RandomForestClassifier.load(rfc_path)</span> |
| <span class="sd"> >>> rf2.getNumTrees()</span> |
| <span class="sd"> 3</span> |
| <span class="sd"> >>> model_path = temp_path + "/rfc_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = RandomForestClassificationModel.load(model_path)</span> |
| <span class="sd"> >>> model.featureImportances == model2.featureImportances</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">bootstrap</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \</span> |
| <span class="sd"> numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \</span> |
| <span class="sd"> leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.RandomForestClassifier"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"gini"</span><span class="p">,</span> |
| <span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">bootstrap</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \</span> |
| <span class="sd"> numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \</span> |
| <span class="sd"> leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)</span> |
| <span class="sd"> Sets params for random forest classification.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassificationModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">RandomForestClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxDepth`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxBins`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`impurity`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setNumTrees"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setNumTrees">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setNumTrees</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`numTrees`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">numTrees</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setBootstrap"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setBootstrap">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setBootstrap</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`bootstrap`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">bootstrap</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setSubsamplingRate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setSubsamplingRate">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSubsamplingRate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`subsamplingRate`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">subsamplingRate</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setFeatureSubsetStrategy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setFeatureSubsetStrategy">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFeatureSubsetStrategy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`featureSubsetStrategy`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setCheckpointInterval">[docs]</a> <span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="RandomForestClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="RandomForestClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationModel.html#pyspark.ml.classification.RandomForestClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">RandomForestClassificationModel</span><span class="p">(</span> |
| <span class="n">_TreeEnsembleModel</span><span class="p">,</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_RandomForestClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"RandomForestClassificationModel"</span><span class="p">],</span> |
| <span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">"RandomForestClassificationTrainingSummary"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by RandomForestClassifier.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> |
| <span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Estimate of the importance of each feature.</span> |
| |
| <span class="sd"> Each feature's importance is the average of its importance across all trees in the ensemble.</span> |
| <span class="sd"> The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.</span> |
| <span class="sd"> (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)</span> |
| <span class="sd"> and follows the implementation from scikit-learn.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> See Also</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> DecisionTreeClassificationModel.featureImportances</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"featureImportances"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">trees</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">DecisionTreeClassificationModel</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""Trees in this ensemble. Warning: These have null parent Estimators."""</span> |
| <span class="k">return</span> <span class="p">[</span><span class="n">DecisionTreeClassificationModel</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"trees"</span><span class="p">))]</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"RandomForestClassificationTrainingSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span> |
| <span class="sd"> trained on the training set. An exception is thrown if `trainingSummary is None`.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o"><=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">BinaryRandomForestClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">RandomForestClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"No training summary available for this </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> |
| <span class="p">)</span> |
| |
| <div class="viewcode-block" id="RandomForestClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationModel.html#pyspark.ml.classification.RandomForestClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="s2">"BinaryRandomForestClassificationSummary"</span><span class="p">,</span> <span class="s2">"RandomForestClassificationSummary"</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Evaluates the model on a test dataset.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span> |
| <span class="n">java_rf_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluate"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o"><=</span> <span class="mi">2</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">BinaryRandomForestClassificationSummary</span><span class="p">(</span><span class="n">java_rf_summary</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">RandomForestClassificationSummary</span><span class="p">(</span><span class="n">java_rf_summary</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="RandomForestClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationSummary.html#pyspark.ml.classification.RandomForestClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">RandomForestClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for RandomForestClassification results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="RandomForestClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationTrainingSummary.html#pyspark.ml.classification.RandomForestClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">RandomForestClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="n">RandomForestClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for RandomForestClassification training results.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="BinaryRandomForestClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryRandomForestClassificationSummary.html#pyspark.ml.classification.BinaryRandomForestClassificationSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">BinaryRandomForestClassificationSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> BinaryRandomForestClassification results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="BinaryRandomForestClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary.html#pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">BinaryRandomForestClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="n">BinaryRandomForestClassificationSummary</span><span class="p">,</span> <span class="n">RandomForestClassificationTrainingSummary</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> BinaryRandomForestClassification training results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_GBTClassifierParams</span><span class="p">(</span><span class="n">_GBTParams</span><span class="p">,</span> <span class="n">_HasVarianceImpurity</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`GBTClassifier` and :py:class:`GBTClassificationModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">supportedLossTypes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"logistic"</span><span class="p">]</span> |
| |
| <span class="n">lossType</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"lossType"</span><span class="p">,</span> |
| <span class="s2">"Loss function which GBT tries to minimize (case-insensitive). "</span> |
| <span class="o">+</span> <span class="s2">"Supported options: "</span> |
| <span class="o">+</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">supportedLossTypes</span><span class="p">),</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_GBTClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span> |
| <span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> |
| <span class="n">lossType</span><span class="o">=</span><span class="s2">"logistic"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="o">=</span><span class="s2">"variance"</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="s2">"all"</span><span class="p">,</span> |
| <span class="n">validationTol</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="o">=</span><span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getLossType</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of lossType or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lossType</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="GBTClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">GBTClassifier</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"GBTClassificationModel"</span><span class="p">],</span> |
| <span class="n">_GBTClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"GBTClassifier"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_</span> |
| <span class="sd"> learning algorithm for classification.</span> |
| <span class="sd"> It supports binary labels, as well as both continuous and categorical features.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| |
| <span class="sd"> Notes</span> |
| <span class="sd"> -----</span> |
| <span class="sd"> Multiclass labels are not currently supported.</span> |
| |
| <span class="sd"> The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.</span> |
| |
| <span class="sd"> Gradient Boosting vs. TreeBoost:</span> |
| |
| <span class="sd"> - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.</span> |
| <span class="sd"> - Both algorithms learn tree ensembles by minimizing loss functions.</span> |
| <span class="sd"> - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes</span> |
| <span class="sd"> based on the loss function, whereas the original gradient boosting method does not.</span> |
| <span class="sd"> - We expect to implement TreeBoost in the future:</span> |
| <span class="sd"> `SPARK-4240 <https://issues.apache.org/jira/browse/SPARK-4240>`_</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from numpy import allclose</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.feature import StringIndexer</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... (1.0, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])</span> |
| <span class="sd"> >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")</span> |
| <span class="sd"> >>> si_model = stringIndexer.fit(df)</span> |
| <span class="sd"> >>> td = si_model.transform(df)</span> |
| <span class="sd"> >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,</span> |
| <span class="sd"> ... leafCol="leafId")</span> |
| <span class="sd"> >>> gbt.setMaxIter(5)</span> |
| <span class="sd"> GBTClassifier...</span> |
| <span class="sd"> >>> gbt.setMinWeightFractionPerNode(0.049)</span> |
| <span class="sd"> GBTClassifier...</span> |
| <span class="sd"> >>> gbt.getMaxIter()</span> |
| <span class="sd"> 5</span> |
| <span class="sd"> >>> gbt.getFeatureSubsetStrategy()</span> |
| <span class="sd"> 'all'</span> |
| <span class="sd"> >>> model = gbt.fit(td)</span> |
| <span class="sd"> >>> model.getLabelCol()</span> |
| <span class="sd"> 'indexed'</span> |
| <span class="sd"> >>> model.setFeaturesCol("features")</span> |
| <span class="sd"> GBTClassificationModel...</span> |
| <span class="sd"> >>> model.setThresholds([0.3, 0.7])</span> |
| <span class="sd"> GBTClassificationModel...</span> |
| <span class="sd"> >>> model.getThresholds()</span> |
| <span class="sd"> [0.3, 0.7]</span> |
| <span class="sd"> >>> model.featureImportances</span> |
| <span class="sd"> SparseVector(1, {0: 1.0})</span> |
| <span class="sd"> >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])</span> |
| <span class="sd"> >>> model.predict(test0.head().features)</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([1.1697, -1.1697])</span> |
| <span class="sd"> >>> model.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([0.9121, 0.0879])</span> |
| <span class="sd"> >>> result = model.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> result.leafId</span> |
| <span class="sd"> DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])</span> |
| <span class="sd"> >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])</span> |
| <span class="sd"> >>> model.transform(test1).head().prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.totalNumNodes</span> |
| <span class="sd"> 15</span> |
| <span class="sd"> >>> print(model.toDebugString)</span> |
| <span class="sd"> GBTClassificationModel...numTrees=5...</span> |
| <span class="sd"> >>> gbtc_path = temp_path + "gbtc"</span> |
| <span class="sd"> >>> gbt.save(gbtc_path)</span> |
| <span class="sd"> >>> gbt2 = GBTClassifier.load(gbtc_path)</span> |
| <span class="sd"> >>> gbt2.getMaxDepth()</span> |
| <span class="sd"> 2</span> |
| <span class="sd"> >>> model_path = temp_path + "gbtc_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = GBTClassificationModel.load(model_path)</span> |
| <span class="sd"> >>> model.featureImportances == model2.featureImportances</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.treeWeights == model2.treeWeights</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.trees</span> |
| <span class="sd"> [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]</span> |
| <span class="sd"> >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],</span> |
| <span class="sd"> ... ["indexed", "features"])</span> |
| <span class="sd"> >>> model.evaluateEachIteration(validation)</span> |
| <span class="sd"> [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]</span> |
| <span class="sd"> >>> model.numClasses</span> |
| <span class="sd"> 2</span> |
| <span class="sd"> >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")</span> |
| <span class="sd"> >>> gbt.getValidationIndicatorCol()</span> |
| <span class="sd"> 'validationIndicator'</span> |
| <span class="sd"> >>> gbt.getValidationTol()</span> |
| <span class="sd"> 0.01</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">lossType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"logistic"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"variance"</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"all"</span><span class="p">,</span> |
| <span class="n">validationTol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span> |
| <span class="n">validationIndicatorCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \</span> |
| <span class="sd"> lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \</span> |
| <span class="sd"> impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \</span> |
| <span class="sd"> validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \</span> |
| <span class="sd"> weightCol=None)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">GBTClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.GBTClassifier"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="GBTClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> |
| <span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> |
| <span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span> |
| <span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> |
| <span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> |
| <span class="n">lossType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"logistic"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"variance"</span><span class="p">,</span> |
| <span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"all"</span><span class="p">,</span> |
| <span class="n">validationTol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span> |
| <span class="n">validationIndicatorCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> |
| <span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span> |
| <span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \</span> |
| <span class="sd"> lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \</span> |
| <span class="sd"> impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \</span> |
| <span class="sd"> validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \</span> |
| <span class="sd"> weightCol=None)</span> |
| <span class="sd"> Sets params for Gradient Boosted Tree Classification.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassificationModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">GBTClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxDepth`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxBins`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`impurity`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setLossType"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setLossType">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setLossType</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`lossType`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lossType</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setSubsamplingRate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setSubsamplingRate">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSubsamplingRate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`subsamplingRate`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">subsamplingRate</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setFeatureSubsetStrategy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setFeatureSubsetStrategy">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFeatureSubsetStrategy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`featureSubsetStrategy`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setValidationIndicatorCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setValidationIndicatorCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setValidationIndicatorCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`validationIndicatorCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">validationIndicatorCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxIter`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setCheckpointInterval">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setSeed">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.4.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`stepSize`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="GBTClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"GBTClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="GBTClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassificationModel.html#pyspark.ml.classification.GBTClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">GBTClassificationModel</span><span class="p">(</span> |
| <span class="n">_TreeEnsembleModel</span><span class="p">,</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_GBTClassifierParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"GBTClassificationModel"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by GBTClassifier.</span> |
| |
| <span class="sd"> .. versionadded:: 1.4.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> |
| <span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Estimate of the importance of each feature.</span> |
| |
| <span class="sd"> Each feature's importance is the average of its importance across all trees in the ensemble.</span> |
| <span class="sd"> The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.</span> |
| <span class="sd"> (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)</span> |
| <span class="sd"> and follows the implementation from scikit-learn.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> See Also</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> DecisionTreeClassificationModel.featureImportances</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"featureImportances"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">trees</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">DecisionTreeRegressionModel</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""Trees in this ensemble. Warning: These have null parent Estimators."""</span> |
| <span class="k">return</span> <span class="p">[</span><span class="n">DecisionTreeRegressionModel</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"trees"</span><span class="p">))]</span> |
| |
| <div class="viewcode-block" id="GBTClassificationModel.evaluateEachIteration"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassificationModel.html#pyspark.ml.classification.GBTClassificationModel.evaluateEachIteration">[docs]</a> <span class="k">def</span> <span class="nf">evaluateEachIteration</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Method to compute error or loss for every iteration of gradient boosting.</span> |
| |
| <span class="sd"> .. versionadded:: 2.4.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluateEachIteration"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span></div></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_NaiveBayesParams</span><span class="p">(</span><span class="n">_PredictorParams</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">smoothing</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"smoothing"</span><span class="p">,</span> |
| <span class="s2">"The smoothing parameter, should be >= 0, "</span> <span class="o">+</span> <span class="s2">"default is 1.0"</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">modelType</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"modelType"</span><span class="p">,</span> |
| <span class="s2">"The model type which is a string "</span> |
| <span class="o">+</span> <span class="s2">"(case-sensitive). Supported options: multinomial (default), bernoulli "</span> |
| <span class="o">+</span> <span class="s2">"and gaussian."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_NaiveBayesParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">modelType</span><span class="o">=</span><span class="s2">"multinomial"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getSmoothing</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of smoothing or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getModelType</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of modelType or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">modelType</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="NaiveBayes"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">NaiveBayes</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"NaiveBayesModel"</span><span class="p">],</span> |
| <span class="n">_NaiveBayesParams</span><span class="p">,</span> |
| <span class="n">HasThresholds</span><span class="p">,</span> |
| <span class="n">HasWeightCol</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"NaiveBayes"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Naive Bayes Classifiers.</span> |
| <span class="sd"> It supports both Multinomial and Bernoulli NB. `Multinomial NB \</span> |
| <span class="sd"> <http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html>`_</span> |
| <span class="sd"> can handle finitely supported discrete data. For example, by converting documents into</span> |
| <span class="sd"> TF-IDF vectors, it can be used for document classification. By converting every vector to</span> |
| <span class="sd"> binary (0/1) data, it can also be used as `Bernoulli NB \</span> |
| <span class="sd"> <http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html>`_.</span> |
| |
| <span class="sd"> The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.</span> |
| <span class="sd"> Since 3.0.0, it supports Complement NB which is an adaptation of the Multinomial NB.</span> |
| <span class="sd"> Specifically, Complement NB uses statistics from the complement of each class to compute</span> |
| <span class="sd"> the model's coefficients. The inventors of Complement NB show empirically that the parameter</span> |
| <span class="sd"> estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the</span> |
| <span class="sd"> input feature values for Complement NB must be nonnegative.</span> |
| <span class="sd"> Since 3.0.0, it also supports `Gaussian NB \</span> |
| <span class="sd"> <https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes>`_,</span> |
| <span class="sd"> which can handle continuous data.</span> |
| |
| <span class="sd"> .. versionadded:: 1.5.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.sql import Row</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),</span> |
| <span class="sd"> ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),</span> |
| <span class="sd"> ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])</span> |
| <span class="sd"> >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")</span> |
| <span class="sd"> >>> model = nb.fit(df)</span> |
| <span class="sd"> >>> model.setFeaturesCol("features")</span> |
| <span class="sd"> NaiveBayesModel...</span> |
| <span class="sd"> >>> model.getSmoothing()</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.pi</span> |
| <span class="sd"> DenseVector([-0.81..., -0.58...])</span> |
| <span class="sd"> >>> model.theta</span> |
| <span class="sd"> DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)</span> |
| <span class="sd"> >>> model.sigma</span> |
| <span class="sd"> DenseMatrix(0, 0, [...], ...)</span> |
| <span class="sd"> >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()</span> |
| <span class="sd"> >>> model.predict(test0.head().features)</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([-1.72..., -0.99...])</span> |
| <span class="sd"> >>> model.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([0.32..., 0.67...])</span> |
| <span class="sd"> >>> result = model.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> result.probability</span> |
| <span class="sd"> DenseVector([0.32..., 0.67...])</span> |
| <span class="sd"> >>> result.rawPrediction</span> |
| <span class="sd"> DenseVector([-1.72..., -0.99...])</span> |
| <span class="sd"> >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()</span> |
| <span class="sd"> >>> model.transform(test1).head().prediction</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> nb_path = temp_path + "/nb"</span> |
| <span class="sd"> >>> nb.save(nb_path)</span> |
| <span class="sd"> >>> nb2 = NaiveBayes.load(nb_path)</span> |
| <span class="sd"> >>> nb2.getSmoothing()</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model_path = temp_path + "/nb_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = NaiveBayesModel.load(model_path)</span> |
| <span class="sd"> >>> model.pi == model2.pi</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.theta == model2.theta</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> nb = nb.setThresholds([0.01, 10.00])</span> |
| <span class="sd"> >>> model3 = nb.fit(df)</span> |
| <span class="sd"> >>> result = model3.transform(test0).head()</span> |
| <span class="sd"> >>> result.prediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> nb3 = NaiveBayes().setModelType("gaussian")</span> |
| <span class="sd"> >>> model4 = nb3.fit(df)</span> |
| <span class="sd"> >>> model4.getModelType()</span> |
| <span class="sd"> 'gaussian'</span> |
| <span class="sd"> >>> model4.sigma</span> |
| <span class="sd"> DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)</span> |
| <span class="sd"> >>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")</span> |
| <span class="sd"> >>> model5 = nb5.fit(df)</span> |
| <span class="sd"> >>> model5.getModelType()</span> |
| <span class="sd"> 'complement'</span> |
| <span class="sd"> >>> model5.theta</span> |
| <span class="sd"> DenseMatrix(2, 2, [...], 1)</span> |
| <span class="sd"> >>> model5.sigma</span> |
| <span class="sd"> DenseMatrix(0, 0, [...], ...)</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">modelType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"multinomial"</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \</span> |
| <span class="sd"> modelType="multinomial", thresholds=None, weightCol=None)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">NaiveBayes</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.NaiveBayes"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="NaiveBayes.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">modelType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"multinomial"</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"NaiveBayes"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \</span> |
| <span class="sd"> modelType="multinomial", thresholds=None, weightCol=None)</span> |
| <span class="sd"> Sets params for Naive Bayes.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"NaiveBayesModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">NaiveBayesModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="NaiveBayes.setSmoothing"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setSmoothing">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSmoothing</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"NaiveBayes"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`smoothing`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="NaiveBayes.setModelType"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setModelType">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.5.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setModelType</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"NaiveBayes"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`modelType`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">modelType</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="NaiveBayes.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"NaiveBayes"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="NaiveBayesModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayesModel.html#pyspark.ml.classification.NaiveBayesModel">[docs]</a><span class="k">class</span> <span class="nc">NaiveBayesModel</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_NaiveBayesParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"NaiveBayesModel"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by NaiveBayes.</span> |
| |
| <span class="sd"> .. versionadded:: 1.5.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">pi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> log of class priors.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"pi"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">theta</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> log of class conditional probabilities.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"theta"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">sigma</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> variance of each feature.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"sigma"</span><span class="p">)</span></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_MultilayerPerceptronParams</span><span class="p">(</span> |
| <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> |
| <span class="n">HasSeed</span><span class="p">,</span> |
| <span class="n">HasMaxIter</span><span class="p">,</span> |
| <span class="n">HasTol</span><span class="p">,</span> |
| <span class="n">HasStepSize</span><span class="p">,</span> |
| <span class="n">HasSolver</span><span class="p">,</span> |
| <span class="n">HasBlockSize</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`MultilayerPerceptronClassifier`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">layers</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"layers"</span><span class="p">,</span> |
| <span class="s2">"Sizes of layers from input layer to output layer. "</span> |
| <span class="o">+</span> <span class="s2">"E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 "</span> |
| <span class="o">+</span> <span class="s2">"neurons and output layer of 10 neurons."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toListInt</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">solver</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"solver"</span><span class="p">,</span> |
| <span class="s2">"The solver algorithm for optimization. Supported "</span> <span class="o">+</span> <span class="s2">"options: l-bfgs, gd."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">initialWeights</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span> |
| <span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> |
| <span class="s2">"initialWeights"</span><span class="p">,</span> |
| <span class="s2">"The initial weights of the model."</span><span class="p">,</span> |
| <span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">_MultilayerPerceptronParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">stepSize</span><span class="o">=</span><span class="mf">0.03</span><span class="p">,</span> <span class="n">solver</span><span class="o">=</span><span class="s2">"l-bfgs"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.6.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getLayers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of layers or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getInitialWeights</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of initialWeights or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initialWeights</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"MultilayerPerceptronClassificationModel"</span><span class="p">],</span> |
| <span class="n">_MultilayerPerceptronParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Classifier trainer based on the Multilayer Perceptron.</span> |
| <span class="sd"> Each intermediate layer has a sigmoid activation function; the output layer has softmax.</span> |
| <span class="sd"> Number of inputs has to be equal to the size of feature vectors.</span> |
| <span class="sd"> Number of outputs has to be equal to the total number of labels.</span> |
| |
| <span class="sd"> .. versionadded:: 1.6.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... (0.0, Vectors.dense([0.0, 0.0])),</span> |
| <span class="sd"> ... (1.0, Vectors.dense([0.0, 1.0])),</span> |
| <span class="sd"> ... (1.0, Vectors.dense([1.0, 0.0])),</span> |
| <span class="sd"> ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])</span> |
| <span class="sd"> >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)</span> |
| <span class="sd"> >>> mlp.setMaxIter(100)</span> |
| <span class="sd"> MultilayerPerceptronClassifier...</span> |
| <span class="sd"> >>> mlp.getMaxIter()</span> |
| <span class="sd"> 100</span> |
| <span class="sd"> >>> mlp.getBlockSize()</span> |
| <span class="sd"> 128</span> |
| <span class="sd"> >>> mlp.setBlockSize(1)</span> |
| <span class="sd"> MultilayerPerceptronClassifier...</span> |
| <span class="sd"> >>> mlp.getBlockSize()</span> |
| <span class="sd"> 1</span> |
| <span class="sd"> >>> model = mlp.fit(df)</span> |
| <span class="sd"> >>> model.setFeaturesCol("features")</span> |
| <span class="sd"> MultilayerPerceptronClassificationModel...</span> |
| <span class="sd"> >>> model.getMaxIter()</span> |
| <span class="sd"> 100</span> |
| <span class="sd"> >>> model.getLayers()</span> |
| <span class="sd"> [2, 2, 2]</span> |
| <span class="sd"> >>> model.weights.size</span> |
| <span class="sd"> 12</span> |
| <span class="sd"> >>> testDF = spark.createDataFrame([</span> |
| <span class="sd"> ... (Vectors.dense([1.0, 0.0]),),</span> |
| <span class="sd"> ... (Vectors.dense([0.0, 0.0]),)], ["features"])</span> |
| <span class="sd"> >>> model.predict(testDF.head().features)</span> |
| <span class="sd"> 1.0</span> |
| <span class="sd"> >>> model.predictRaw(testDF.head().features)</span> |
| <span class="sd"> DenseVector([-16.208, 16.344])</span> |
| <span class="sd"> >>> model.predictProbability(testDF.head().features)</span> |
| <span class="sd"> DenseVector([0.0, 1.0])</span> |
| <span class="sd"> >>> model.transform(testDF).select("features", "prediction").show()</span> |
| <span class="sd"> +---------+----------+</span> |
| <span class="sd"> | features|prediction|</span> |
| <span class="sd"> +---------+----------+</span> |
| <span class="sd"> |[1.0,0.0]| 1.0|</span> |
| <span class="sd"> |[0.0,0.0]| 0.0|</span> |
| <span class="sd"> +---------+----------+</span> |
| <span class="sd"> ...</span> |
| <span class="sd"> >>> mlp_path = temp_path + "/mlp"</span> |
| <span class="sd"> >>> mlp.save(mlp_path)</span> |
| <span class="sd"> >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)</span> |
| <span class="sd"> >>> mlp2.getBlockSize()</span> |
| <span class="sd"> 1</span> |
| <span class="sd"> >>> model_path = temp_path + "/mlp_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)</span> |
| <span class="sd"> >>> model.getLayers() == model2.getLayers()</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.weights == model2.weights</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(testDF).take(1) == model2.transform(testDF).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))</span> |
| <span class="sd"> >>> model3 = mlp2.fit(df)</span> |
| <span class="sd"> >>> model3.weights != model2.weights</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model3.getLayers() == model.getLayers()</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">blockSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.03</span><span class="p">,</span> |
| <span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"l-bfgs"</span><span class="p">,</span> |
| <span class="n">initialWeights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \</span> |
| <span class="sd"> solver="l-bfgs", initialWeights=None, probabilityCol="probability", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction")</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">MultilayerPerceptronClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.MultilayerPerceptronClassifier"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.6.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">blockSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.03</span><span class="p">,</span> |
| <span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"l-bfgs"</span><span class="p">,</span> |
| <span class="n">initialWeights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \</span> |
| <span class="sd"> solver="l-bfgs", initialWeights=None, probabilityCol="probability", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction"):</span> |
| <span class="sd"> Sets params for MultilayerPerceptronClassifier.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassificationModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">MultilayerPerceptronClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setLayers"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setLayers">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.6.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setLayers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`layers`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">layers</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setBlockSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setBlockSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"1.6.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setBlockSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`blockSize`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">blockSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setInitialWeights"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setInitialWeights">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setInitialWeights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`initialWeights`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">initialWeights</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setMaxIter">[docs]</a> <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxIter`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setTol">[docs]</a> <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`tol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`stepSize`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassifier.setSolver"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setSolver">[docs]</a> <span class="k">def</span> <span class="nf">setSolver</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`solver`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">solver</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationModel</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_MultilayerPerceptronParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"MultilayerPerceptronClassificationModel"</span><span class="p">],</span> |
| <span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">"MultilayerPerceptronClassificationTrainingSummary"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by MultilayerPerceptronClassifier.</span> |
| |
| <span class="sd"> .. versionadded:: 1.6.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">weights</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> The weights of layers.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"weights"</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassificationModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassificationTrainingSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span> |
| <span class="sd"> trained on the training set. A RuntimeError is raised if no training summary is available.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">MultilayerPerceptronClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">MultilayerPerceptronClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span> |
| <span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"No training summary available for this </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> |
| <span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"MultilayerPerceptronClassificationSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Evaluates the model on a test dataset.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span> |
| <span class="n">java_mlp_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluate"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">MultilayerPerceptronClassificationSummary</span><span class="p">(</span><span class="n">java_mlp_summary</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationSummary.html#pyspark.ml.classification.MultilayerPerceptronClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for MultilayerPerceptronClassifier Results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="MultilayerPerceptronClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary.html#pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationTrainingSummary</span><span class="p">(</span> |
| <span class="n">MultilayerPerceptronClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for MultilayerPerceptronClassifier Training results.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_OneVsRestParams</span><span class="p">(</span><span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">classifier</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Classifier</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span><span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">"classifier"</span><span class="p">,</span> <span class="s2">"base binary classifier"</span><span class="p">)</span> |
| |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">getClassifier</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Classifier</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets the value of classifier or its default value.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="OneVsRest"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">OneVsRest</span><span class="p">(</span> |
| <span class="n">Estimator</span><span class="p">[</span><span class="s2">"OneVsRestModel"</span><span class="p">],</span> |
| <span class="n">_OneVsRestParams</span><span class="p">,</span> |
| <span class="n">HasParallelism</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"OneVsRest"</span><span class="p">],</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="n">Generic</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Reduction of Multiclass Classification to Binary Classification.</span> |
| <span class="sd"> Performs reduction using one against all strategy.</span> |
| <span class="sd"> For a multiclass classification with k classes, train k models (one per class).</span> |
| <span class="sd"> Each example is scored against all k models and the model with highest score</span> |
| <span class="sd"> is picked to label the example.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.sql import Row</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"</span> |
| <span class="sd"> >>> df = spark.read.format("libsvm").load(data_path)</span> |
| <span class="sd"> >>> lr = LogisticRegression(regParam=0.01)</span> |
| <span class="sd"> >>> ovr = OneVsRest(classifier=lr)</span> |
| <span class="sd"> >>> ovr.getRawPredictionCol()</span> |
| <span class="sd"> 'rawPrediction'</span> |
| <span class="sd"> >>> ovr.setPredictionCol("newPrediction")</span> |
| <span class="sd"> OneVsRest...</span> |
| <span class="sd"> >>> model = ovr.fit(df)</span> |
| <span class="sd"> >>> model.models[0].coefficients</span> |
| <span class="sd"> DenseVector([0.5..., -1.0..., 3.4..., 4.2...])</span> |
| <span class="sd"> >>> model.models[1].coefficients</span> |
| <span class="sd"> DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])</span> |
| <span class="sd"> >>> model.models[2].coefficients</span> |
| <span class="sd"> DenseVector([0.3..., -3.4..., 1.0..., -1.1...])</span> |
| <span class="sd"> >>> [x.intercept for x in model.models]</span> |
| <span class="sd"> [-2.7..., -2.5..., -1.3...]</span> |
| <span class="sd"> >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()</span> |
| <span class="sd"> >>> model.transform(test0).head().newPrediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()</span> |
| <span class="sd"> >>> model.transform(test1).head().newPrediction</span> |
| <span class="sd"> 2.0</span> |
| <span class="sd"> >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()</span> |
| <span class="sd"> >>> model.transform(test2).head().newPrediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model_path = temp_path + "/ovr_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = OneVsRestModel.load(model_path)</span> |
| <span class="sd"> >>> model2.transform(test0).head().newPrediction</span> |
| <span class="sd"> 0.0</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> >>> model.transform(test2).columns</span> |
| <span class="sd"> ['features', 'rawPrediction', 'newPrediction']</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">classifier</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="OneVsRest.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">classifier</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):</span> |
| <span class="sd"> Sets params for OneVsRest.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setClassifier">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"2.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setClassifier</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">])</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`classifier`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setLabelCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setLabelCol">[docs]</a> <span class="k">def</span> <span class="nf">setLabelCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`labelCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setFeaturesCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setFeaturesCol">[docs]</a> <span class="k">def</span> <span class="nf">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`featuresCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`predictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">predictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setRawPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setRawPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`weightCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`parallelism`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="n">labelCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span> |
| <span class="n">featuresCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span> |
| <span class="n">predictionCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span> |
| <span class="n">classifier</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()</span> |
| |
| <span class="n">numClasses</span> <span class="o">=</span> <span class="p">(</span> |
| <span class="nb">int</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">Row</span><span class="p">,</span> <span class="n">dataset</span><span class="o">.</span><span class="n">agg</span><span class="p">({</span><span class="n">labelCol</span><span class="p">:</span> <span class="s2">"max"</span><span class="p">})</span><span class="o">.</span><span class="n">head</span><span class="p">())[</span><span class="s2">"max("</span> <span class="o">+</span> <span class="n">labelCol</span> <span class="o">+</span> <span class="s2">")"</span><span class="p">])</span> <span class="o">+</span> <span class="mi">1</span> |
| <span class="p">)</span> |
| |
| <span class="n">weightCol</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">classifier</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span> |
| <span class="n">weightCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">()</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span> |
| <span class="s2">"weightCol is ignored, "</span> <span class="s2">"as it is not supported by </span><span class="si">{}</span><span class="s2"> now."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">classifier</span><span class="p">)</span> |
| <span class="p">)</span> |
| |
| <span class="k">if</span> <span class="n">weightCol</span><span class="p">:</span> |
| <span class="n">multiclassLabeled</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">,</span> <span class="n">weightCol</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">multiclassLabeled</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">)</span> |
| |
| <span class="c1"># Persist the selected columns unless the input dataset is already persisted.</span> |
| <span class="n">handlePersistence</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">storageLevel</span> <span class="o">==</span> <span class="n">StorageLevel</span><span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span> |
| <span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">persist</span><span class="p">(</span><span class="n">StorageLevel</span><span class="o">.</span><span class="n">MEMORY_AND_DISK</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">trainSingleClass</span><span class="p">(</span><span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">CM</span><span class="p">:</span> |
| <span class="n">binaryLabelCol</span> <span class="o">=</span> <span class="s2">"mc2b$"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">index</span><span class="p">)</span> |
| <span class="n">trainingDataset</span> <span class="o">=</span> <span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span> |
| <span class="n">binaryLabelCol</span><span class="p">,</span> |
| <span class="n">when</span><span class="p">(</span><span class="n">multiclassLabeled</span><span class="p">[</span><span class="n">labelCol</span><span class="p">]</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="n">index</span><span class="p">),</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="n">otherwise</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> |
| <span class="p">)</span> |
| <span class="n">paramMap</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span> |
| <span class="p">[</span> |
| <span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">binaryLabelCol</span><span class="p">),</span> |
| <span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">featuresCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">),</span> |
| <span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">predictionCol</span><span class="p">,</span> <span class="n">predictionCol</span><span class="p">),</span> |
| <span class="p">]</span> |
| <span class="p">)</span> |
| <span class="k">if</span> <span class="n">weightCol</span><span class="p">:</span> |
| <span class="n">paramMap</span><span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">HasWeightCol</span><span class="p">,</span> <span class="n">classifier</span><span class="p">)</span><span class="o">.</span><span class="n">weightCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">weightCol</span> |
| <span class="k">return</span> <span class="n">classifier</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingDataset</span><span class="p">,</span> <span class="n">paramMap</span><span class="p">)</span> |
| |
| <span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numClasses</span><span class="p">))</span> |
| |
| <span class="n">models</span> <span class="o">=</span> <span class="n">pool</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">inheritable_thread_target</span><span class="p">(</span><span class="n">trainSingleClass</span><span class="p">),</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">))</span> |
| |
| <span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span> |
| <span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span><span class="n">OneVsRestModel</span><span class="p">(</span><span class="n">models</span><span class="o">=</span><span class="n">models</span><span class="p">))</span> |
| |
| <div class="viewcode-block" id="OneVsRest.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This creates a deep copy of the embedded paramMap,</span> |
| <span class="sd"> and copies the embedded and extra parameters over.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`OneVsRest`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">newOvr</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">):</span> |
| <span class="n">newOvr</span><span class="o">.</span><span class="n">setClassifier</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span> |
| <span class="k">return</span> <span class="n">newOvr</span></div> |
| |
| <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRest"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java OneVsRest, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| <span class="n">featuresCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span> |
| <span class="n">labelCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span> |
| <span class="n">predictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span> |
| <span class="n">rawPredictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">()</span> |
| <span class="n">classifier</span><span class="p">:</span> <span class="n">Classifier</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span> |
| <span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span> |
| <span class="n">featuresCol</span><span class="o">=</span><span class="n">featuresCol</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="o">=</span><span class="n">labelCol</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="o">=</span><span class="n">predictionCol</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">rawPredictionCol</span><span class="p">,</span> |
| <span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">,</span> |
| <span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="s2">"weightCol"</span><span class="p">)):</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">setWeightCol</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java OneVsRest. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.OneVsRest"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setClassifier</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setLabelCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">())</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">_java_obj</span> |
| |
| <div class="viewcode-block" id="OneVsRest.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestReader"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">OneVsRestReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRest.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.write">[docs]</a> <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">(),</span> <span class="n">JavaMLWritable</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">OneVsRestWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div></div> |
| |
| |
| <span class="k">class</span> <span class="nc">_OneVsRestSharedReadWrite</span><span class="p">:</span> |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span> |
| <span class="n">instance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">"OneVsRestModel"</span><span class="p">],</span> |
| <span class="n">sc</span><span class="p">:</span> <span class="n">SparkContext</span><span class="p">,</span> |
| <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> |
| <span class="n">extraMetadata</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">skipParams</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"classifier"</span><span class="p">]</span> |
| <span class="n">jsonParams</span> <span class="o">=</span> <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">extractJsonParams</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">skipParams</span><span class="p">)</span> |
| <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">saveMetadata</span><span class="p">(</span> |
| <span class="n">instance</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">sc</span><span class="p">,</span> <span class="n">paramMap</span><span class="o">=</span><span class="n">jsonParams</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span> |
| <span class="p">)</span> |
| <span class="n">classifierPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"classifier"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">classifierPath</span><span class="p">)</span> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sc</span><span class="p">:</span> <span class="n">SparkContext</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">"OneVsRestModel"</span><span class="p">]:</span> |
| <span class="n">classifierPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"classifier"</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">classifierPath</span><span class="p">,</span> <span class="n">sc</span><span class="p">)</span> |
| |
| <span class="nd">@staticmethod</span> |
| <span class="k">def</span> <span class="nf">validateParams</span><span class="p">(</span><span class="n">instance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">"OneVsRestModel"</span><span class="p">])</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">elems_to_check</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Params</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">instance</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()]</span> |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">OneVsRestModel</span><span class="p">):</span> |
| <span class="n">elems_to_check</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">)</span> |
| |
| <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">elems_to_check</span><span class="p">:</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">MLWritable</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> |
| <span class="sa">f</span><span class="s2">"OneVsRest write will fail because it contains </span><span class="si">{</span><span class="n">elem</span><span class="o">.</span><span class="n">uid</span><span class="si">}</span><span class="s2"> "</span> |
| <span class="sa">f</span><span class="s2">"which is not writable."</span> |
| <span class="p">)</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">OneVsRestReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">]):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">])</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span> |
| |
| <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">OneVsRest</span><span class="p">:</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">classifier</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">))</span> |
| <span class="n">ova</span><span class="p">:</span> <span class="n">OneVsRest</span> <span class="o">=</span> <span class="n">OneVsRest</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">)</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"uid"</span><span class="p">])</span> |
| <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">ova</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">"classifier"</span><span class="p">])</span> |
| <span class="k">return</span> <span class="n">ova</span> |
| |
| |
| <span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">OneVsRestWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="n">OneVsRest</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span> |
| |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span> |
| <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="OneVsRestModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel">[docs]</a><span class="k">class</span> <span class="nc">OneVsRestModel</span><span class="p">(</span> |
| <span class="n">Model</span><span class="p">,</span> |
| <span class="n">_OneVsRestParams</span><span class="p">,</span> |
| <span class="n">MLReadable</span><span class="p">[</span><span class="s2">"OneVsRestModel"</span><span class="p">],</span> |
| <span class="n">MLWritable</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by OneVsRest.</span> |
| <span class="sd"> This stores the models resulting from training k binary classifiers: one for each class.</span> |
| <span class="sd"> Each example is scored against all k models, and the model with the highest score</span> |
| <span class="sd"> is picked to label the example.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| <span class="sd"> """</span> |
| |
| <div class="viewcode-block" id="OneVsRestModel.setFeaturesCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setFeaturesCol">[docs]</a> <span class="k">def</span> <span class="nf">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`featuresCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRestModel.setPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`predictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">predictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="OneVsRestModel.setRawPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setRawPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">models</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">]):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">models</span> <span class="o">=</span> <span class="n">models</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">models</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">JavaMLWritable</span><span class="p">):</span> |
| <span class="k">return</span> |
| <span class="c1"># set java instance</span> |
| <span class="n">java_models</span> <span class="o">=</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassificationModel</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">java_models_array</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_array</span><span class="p">(</span> |
| <span class="n">java_models</span><span class="p">,</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="n">ClassificationModel</span> |
| <span class="p">)</span> |
| <span class="c1"># TODO: need to set metadata</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">"org.apache.spark.sql.types.Metadata"</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.OneVsRestModel"</span><span class="p">,</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span> |
| <span class="n">metadata</span><span class="o">.</span><span class="n">empty</span><span class="p">(),</span> |
| <span class="n">java_models_array</span><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="n">DataFrame</span><span class="p">:</span> |
| <span class="c1"># determine the input columns: these need to be passed through</span> |
| <span class="n">origCols</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">columns</span> |
| |
| <span class="c1"># add an accumulator column to store predictions of all the models</span> |
| <span class="n">accColName</span> <span class="o">=</span> <span class="s2">"mbc$acc"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span> |
| <span class="n">initUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="p">[],</span> <span class="n">ArrayType</span><span class="p">(</span><span class="n">DoubleType</span><span class="p">()))</span> |
| <span class="n">newDataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span><span class="n">accColName</span><span class="p">,</span> <span class="n">initUDF</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">origCols</span><span class="p">[</span><span class="mi">0</span><span class="p">]]))</span> |
| |
| <span class="c1"># persist if underlying dataset is not persistent.</span> |
| <span class="n">handlePersistence</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">storageLevel</span> <span class="o">==</span> <span class="n">StorageLevel</span><span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span> |
| <span class="n">newDataset</span><span class="o">.</span><span class="n">persist</span><span class="p">(</span><span class="n">StorageLevel</span><span class="o">.</span><span class="n">MEMORY_AND_DISK</span><span class="p">)</span> |
| |
| <span class="c1"># update the accumulator column with the result of prediction of models</span> |
| <span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">newDataset</span> |
| <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">model</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">):</span> |
| <span class="n">rawPredictionCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">()</span> |
| |
| <span class="n">columns</span> <span class="o">=</span> <span class="n">origCols</span> <span class="o">+</span> <span class="p">[</span><span class="n">rawPredictionCol</span><span class="p">,</span> <span class="n">accColName</span><span class="p">]</span> |
| |
| <span class="c1"># add temporary column to store intermediate scores and update</span> |
| <span class="n">tmpColName</span> <span class="o">=</span> <span class="s2">"mbc$tmp"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span> |
| <span class="n">updateUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span> |
| <span class="k">lambda</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">+</span> <span class="p">[</span><span class="n">prediction</span><span class="o">.</span><span class="n">tolist</span><span class="p">()[</span><span class="mi">1</span><span class="p">]],</span> |
| <span class="n">ArrayType</span><span class="p">(</span><span class="n">DoubleType</span><span class="p">()),</span> |
| <span class="p">)</span> |
| <span class="n">transformedDataset</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">)</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="o">*</span><span class="n">columns</span><span class="p">)</span> |
| <span class="n">updatedDataset</span> <span class="o">=</span> <span class="n">transformedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span> |
| <span class="n">tmpColName</span><span class="p">,</span> |
| <span class="n">updateUDF</span><span class="p">(</span><span class="n">transformedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">],</span> <span class="n">transformedDataset</span><span class="p">[</span><span class="n">rawPredictionCol</span><span class="p">]),</span> |
| <span class="p">)</span> |
| <span class="n">newColumns</span> <span class="o">=</span> <span class="n">origCols</span> <span class="o">+</span> <span class="p">[</span><span class="n">tmpColName</span><span class="p">]</span> |
| |
| <span class="c1"># switch out the intermediate column with the accumulator column</span> |
| <span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">updatedDataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="o">*</span><span class="n">newColumns</span><span class="p">)</span><span class="o">.</span><span class="n">withColumnRenamed</span><span class="p">(</span> |
| <span class="n">tmpColName</span><span class="p">,</span> <span class="n">accColName</span> |
| <span class="p">)</span> |
| |
| <span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span> |
| <span class="n">newDataset</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span> |
| |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">():</span> |
| |
| <span class="k">def</span> <span class="nf">func</span><span class="p">(</span><span class="n">predictions</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="n">predArray</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> |
| <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">predictions</span><span class="p">:</span> |
| <span class="n">predArray</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">predArray</span><span class="p">)</span> |
| |
| <span class="n">rawPredictionUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">VectorUDT</span><span class="p">())</span> |
| <span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">(),</span> <span class="n">rawPredictionUDF</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">])</span> |
| <span class="p">)</span> |
| |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">():</span> |
| <span class="c1"># output the index of the classifier with highest confidence as prediction</span> |
| <span class="n">labelUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span> |
| <span class="k">lambda</span> <span class="n">predictions</span><span class="p">:</span> <span class="nb">float</span><span class="p">(</span> |
| <span class="nb">max</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">),</span> <span class="n">key</span><span class="o">=</span><span class="n">operator</span><span class="o">.</span><span class="n">itemgetter</span><span class="p">(</span><span class="mi">1</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="p">),</span> |
| <span class="n">DoubleType</span><span class="p">(),</span> |
| <span class="p">)</span> |
| <span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">(),</span> <span class="n">labelUDF</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">])</span> |
| <span class="p">)</span> |
| <span class="k">return</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">accColName</span><span class="p">)</span> |
| |
| <!-- viewcode-block div for OneVsRestModel.copy. The [docs] anchor (class viewcode-back) links back to this method's entry on the OneVsRestModel API reference page. The listed method substitutes an empty dict when extra is None, calls Params.copy(self, extra), then rebuilds newModel.models by calling model.copy(extra) on every entry of self.models. --><div class="viewcode-block" id="OneVsRestModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">"ParamMap"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Creates a copy of this instance with a randomly generated uid</span> |
| <span class="sd"> and some extra params. This creates a deep copy of the embedded paramMap,</span> |
| <span class="sd"> and copies the embedded and extra parameters over.</span> |
| |
| <span class="sd"> .. versionadded:: 2.0.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> extra : dict, optional</span> |
| <span class="sd"> Extra parameters to copy to the new instance</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> :py:class:`OneVsRestModel`</span> |
| <span class="sd"> Copy of this instance</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span> |
| <span class="n">newModel</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span> |
| <span class="n">newModel</span><span class="o">.</span><span class="n">models</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span> |
| <span class="k">return</span> <span class="n">newModel</span></div> |
| |
| <!-- Listing of the classmethod _from_java; it has no viewcode-block wrapper or [docs] link of its own. The listed code reads featuresCol, labelCol and predictionCol from java_stage, wraps the Java classifier and each Java sub-model with JavaParams._from_java, builds cls(models=models), sets the column params and the classifier on it, and finally resets the uid to java_stage.uid(). --><span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModel"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Given a Java OneVsRestModel, create and return a Python wrapper of it.</span> |
| <span class="sd"> Used for ML persistence.</span> |
| <span class="sd"> """</span> |
| <span class="n">featuresCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span> |
| <span class="n">labelCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span> |
| <span class="n">predictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span> |
| <span class="n">classifier</span><span class="p">:</span> <span class="n">Classifier</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span> |
| <span class="n">models</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> |
| <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">models</span><span class="p">()</span> |
| <span class="p">]</span> |
| <span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">models</span><span class="o">=</span><span class="n">models</span><span class="p">)</span><span class="o">.</span><span class="n">setPredictionCol</span><span class="p">(</span><span class="n">predictionCol</span><span class="p">)</span><span class="o">.</span><span class="n">setFeaturesCol</span><span class="p">(</span><span class="n">featuresCol</span><span class="p">)</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="n">labelCol</span><span class="p">)</span> |
| <span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="s2">"weightCol"</span><span class="p">)):</span><!-- weightCol is copied only when the Java stage reports the param as defined --> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">)</span> |
| <span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">py_stage</span> |
| |
| <!-- Listing of _to_java; it has no viewcode-block wrapper or [docs] link of its own. The listed code asserts that an active SparkContext with a gateway exists, converts every entry of self.models with _to_java, packs the results into a Java array typed as ClassificationModel, constructs the Java OneVsRestModel from self.uid, metadata.empty() and that array, then sets classifier, featuresCol, labelCol and predictionCol on it. --><span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"JavaObject"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Transfer this instance to a Java OneVsRestModel. Used for ML persistence.</span> |
| |
| <span class="sd"> Returns</span> |
| <span class="sd"> -------</span> |
| <span class="sd"> py4j.java_gateway.JavaObject</span> |
| <span class="sd"> Java object equivalent to this instance.</span> |
| <span class="sd"> """</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span> |
| <span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> |
| |
| <span class="n">java_models</span> <span class="o">=</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassificationModel</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span> |
| <span class="n">java_models_array</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_array</span><span class="p">(</span> |
| <span class="n">java_models</span><span class="p">,</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="n">ClassificationModel</span> |
| <span class="p">)</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">"org.apache.spark.sql.types.Metadata"</span><span class="p">)</span> |
| <span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.OneVsRestModel"</span><span class="p">,</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span> |
| <span class="n">metadata</span><span class="o">.</span><span class="n">empty</span><span class="p">(),</span> |
| <span class="n">java_models_array</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">"classifier"</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">"featuresCol"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">"labelCol"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">())</span> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">"predictionCol"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">())</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span><!-- weightCol is forwarded only when it is both defined and truthy --> |
| <span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">"weightCol"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span> |
| <span class="k">return</span> <span class="n">_java_obj</span> |
| |
| <!-- viewcode-block div for the classmethod OneVsRestModel.read. The [docs] anchor links back to this method's entry on the OneVsRestModel API reference page. The listed method returns OneVsRestModelReader(cls). --><div class="viewcode-block" id="OneVsRestModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.read">[docs]</a> <span class="nd">@classmethod</span> |
| <span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"OneVsRestModelReader"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">OneVsRestModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div> |
| |
| <!-- viewcode-block div for OneVsRestModel.write. The [docs] anchor links back to this method's entry on the OneVsRestModel API reference page. The listed method returns JavaMLWriter(self) when the classifier and every sub-model are JavaMLWritable instances, and OneVsRestModelWriter(self) otherwise. --><div class="viewcode-block" id="OneVsRestModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.write">[docs]</a> <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">MLWriter</span><span class="p">:</span> |
| <span class="k">if</span> <span class="nb">all</span><span class="p">(</span> |
| <span class="nb">map</span><span class="p">(</span> |
| <span class="k">lambda</span> <span class="n">elem</span><span class="p">:</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">JavaMLWritable</span><span class="p">),</span> |
| <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="c1"># type: ignore[operator]</span> |
| <span class="p">)</span> |
| <span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">OneVsRestModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div></div><!-- NOTE(review): the first closing tag ends the write block; the second ends a div opened before this excerpt, presumably the enclosing OneVsRestModel class block. Confirm against the full page. --> |
| |
| |
| <!-- Listing of class OneVsRestModelReader (subclass of MLReader[OneVsRestModel]); it has no viewcode-block wrapper or [docs] link. The constructor stores cls. load() reads the metadata at path: when it is not a Python params instance it delegates to JavaMLReader(self.cls).load(path); otherwise it loads the classifier, loads numClasses sub-models from the model_N subdirectories, builds a OneVsRestModel with the saved uid, sets the classifier and applies the remaining params while skipping classifier. --><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">OneVsRestModelReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="n">OneVsRestModel</span><span class="p">]):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">OneVsRestModel</span><span class="p">]):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModelReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span> |
| |
| <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">OneVsRestModel</span><span class="p">:</span> |
| <span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="n">classifier</span> <span class="o">=</span> <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="n">numClasses</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"numClasses"</span><span class="p">]</span> |
| <span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">numClasses</span> |
| <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">):</span> |
| <span class="n">subModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"model_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="n">subModels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">subModelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span> |
| <span class="n">ovaModel</span> <span class="o">=</span> <span class="n">OneVsRestModel</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">],</span> <span class="n">subModels</span><span class="p">))</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span> |
| <span class="n">metadata</span><span class="p">[</span><span class="s2">"uid"</span><span class="p">]</span> |
| <span class="p">)</span> |
| <span class="n">ovaModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">ovaModel</span><span class="o">.</span><span class="n">classifier</span><span class="p">,</span> <span class="n">classifier</span><span class="p">)</span> |
| <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">ovaModel</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">"classifier"</span><span class="p">])</span> |
| <span class="k">return</span> <span class="n">ovaModel</span> |
| |
| |
| <!-- Listing of class OneVsRestModelWriter (subclass of MLWriter); it has no viewcode-block wrapper or [docs] link. The constructor stores instance. saveImpl() validates the instance params, records numClasses = len(instance.models) as extra metadata, calls _OneVsRestSharedReadWrite.saveImpl, then saves each sub-model under path/model_N, the same directory naming the reader listing uses when loading. --><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">OneVsRestModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="n">OneVsRestModel</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span> |
| |
| <span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> |
| <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span> |
| <span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span> |
| <span class="n">numClasses</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">)</span> |
| <span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"numClasses"</span><span class="p">:</span> <span class="n">numClasses</span><span class="p">}</span> |
| <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">):</span> |
| <span class="n">subModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"model_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> |
| <span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">subModelPath</span><span class="p">)</span> |
| |
| |
| <div class="viewcode-block" id="FMClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">FMClassifier</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">"FMClassificationModel"</span><span class="p">],</span> |
| <span class="n">_FactorizationMachinesParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"FMClassifier"</span><span class="p">],</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Factorization Machines learning algorithm for classification.</span> |
| |
| <span class="sd"> Solver supports:</span> |
| |
| <span class="sd"> * gd (normal mini-batch gradient descent)</span> |
| <span class="sd"> * adamW (default)</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| |
| <span class="sd"> Examples</span> |
| <span class="sd"> --------</span> |
| <span class="sd"> >>> from pyspark.ml.linalg import Vectors</span> |
| <span class="sd"> >>> from pyspark.ml.classification import FMClassifier</span> |
| <span class="sd"> >>> df = spark.createDataFrame([</span> |
| <span class="sd"> ... (1.0, Vectors.dense(1.0)),</span> |
| <span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])</span> |
| <span class="sd"> >>> fm = FMClassifier(factorSize=2)</span> |
| <span class="sd"> >>> fm.setSeed(11)</span> |
| <span class="sd"> FMClassifier...</span> |
| <span class="sd"> >>> model = fm.fit(df)</span> |
| <span class="sd"> >>> model.getMaxIter()</span> |
| <span class="sd"> 100</span> |
| <span class="sd"> >>> test0 = spark.createDataFrame([</span> |
| <span class="sd"> ... (Vectors.dense(-1.0),),</span> |
| <span class="sd"> ... (Vectors.dense(0.5),),</span> |
| <span class="sd"> ... (Vectors.dense(1.0),),</span> |
| <span class="sd"> ... (Vectors.dense(2.0),)], ["features"])</span> |
| <span class="sd"> >>> model.predictRaw(test0.head().features)</span> |
| <span class="sd"> DenseVector([22.13..., -22.13...])</span> |
| <span class="sd"> >>> model.predictProbability(test0.head().features)</span> |
| <span class="sd"> DenseVector([1.0, 0.0])</span> |
| <span class="sd"> >>> model.transform(test0).select("features", "probability").show(10, False)</span> |
| <span class="sd"> +--------+------------------------------------------+</span> |
| <span class="sd"> |features|probability |</span> |
| <span class="sd"> +--------+------------------------------------------+</span> |
| <span class="sd"> |[-1.0] |[0.9999999997574736,2.425264676902229E-10]|</span> |
| <span class="sd"> |[0.5] |[0.47627851732981163,0.5237214826701884] |</span> |
| <span class="sd"> |[1.0] |[5.491554426243495E-4,0.9994508445573757] |</span> |
| <span class="sd"> |[2.0] |[2.005766663870645E-10,0.9999999997994233]|</span> |
| <span class="sd"> +--------+------------------------------------------+</span> |
| <span class="sd"> ...</span> |
| <span class="sd"> >>> model.intercept</span> |
| <span class="sd"> -7.316665276826291</span> |
| <span class="sd"> >>> model.linear</span> |
| <span class="sd"> DenseVector([14.8232])</span> |
| <span class="sd"> >>> model.factors</span> |
| <span class="sd"> DenseMatrix(1, 2, [0.0163, -0.0051], 1)</span> |
| <span class="sd"> >>> model_path = temp_path + "/fm_model"</span> |
| <span class="sd"> >>> model.save(model_path)</span> |
| <span class="sd"> >>> model2 = FMClassificationModel.load(model_path)</span> |
| <span class="sd"> >>> model2.intercept</span> |
| <span class="sd"> -7.316665276826291</span> |
| <span class="sd"> >>> model2.linear</span> |
| <span class="sd"> DenseVector([14.8232])</span> |
| <span class="sd"> >>> model2.factors</span> |
| <span class="sd"> DenseMatrix(1, 2, [0.0163, -0.0051], 1)</span> |
| <span class="sd"> >>> model.transform(test0).take(1) == model2.transform(test0).take(1)</span> |
| <span class="sd"> True</span> |
| <span class="sd"> """</span> |
| |
| <span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> |
| |
| <span class="nd">@keyword_only</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">factorSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">fitLinear</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">miniBatchFraction</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">initStd</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"adamW"</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \</span> |
| <span class="sd"> miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \</span> |
| <span class="sd"> tol=1e-6, solver="adamW", thresholds=None, seed=None)</span> |
| <span class="sd"> """</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">FMClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span> |
| <span class="s2">"org.apache.spark.ml.classification.FMClassifier"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> |
| <span class="p">)</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="FMClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span> |
| <span class="bp">self</span><span class="p">,</span> |
| <span class="o">*</span><span class="p">,</span> |
| <span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"features"</span><span class="p">,</span> |
| <span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"label"</span><span class="p">,</span> |
| <span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"prediction"</span><span class="p">,</span> |
| <span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"probability"</span><span class="p">,</span> |
| <span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"rawPrediction"</span><span class="p">,</span> |
| <span class="n">factorSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> |
| <span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">fitLinear</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> |
| <span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> |
| <span class="n">miniBatchFraction</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">initStd</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span> |
| <span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> |
| <span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> |
| <span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span> |
| <span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"adamW"</span><span class="p">,</span> |
| <span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> |
| <span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \</span> |
| <span class="sd"> probabilityCol="probability", rawPredictionCol="rawPrediction", \</span> |
| <span class="sd"> factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \</span> |
| <span class="sd"> miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \</span> |
| <span class="sd"> tol=1e-6, solver="adamW", thresholds=None, seed=None)</span> |
| <span class="sd"> Sets Params for FMClassifier.</span> |
| <span class="sd"> """</span> |
| <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> |
| |
| <span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">"JavaObject"</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassificationModel"</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">FMClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="FMClassifier.setFactorSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFactorSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFactorSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`factorSize`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">factorSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setFitLinear"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFitLinear">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFitLinear</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`fitLinear`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitLinear</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setMiniBatchFraction"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setMiniBatchFraction">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMiniBatchFraction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`miniBatchFraction`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">miniBatchFraction</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setInitStd"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setInitStd">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setInitStd</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`initStd`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">initStd</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`maxIter`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`stepSize`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setTol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`tol`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setSolver"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setSolver">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSolver</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`solver`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">solver</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setSeed">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`seed`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFitIntercept">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassifier.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setRegParam">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassifier"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Sets the value of :py:attr:`regParam`.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="FMClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">FMClassificationModel</span><span class="p">(</span> |
| <span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span> |
| <span class="n">_FactorizationMachinesParams</span><span class="p">,</span> |
| <span class="n">JavaMLWritable</span><span class="p">,</span> |
| <span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">"FMClassificationModel"</span><span class="p">],</span> |
| <span class="n">HasTrainingSummary</span><span class="p">,</span> |
| <span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model fitted by :class:`FMClassifier`.</span> |
| |
| <span class="sd"> .. versionadded:: 3.0.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model intercept.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"intercept"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">linear</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Vector</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model linear term.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"linear"</span><span class="p">)</span> |
| |
| <span class="nd">@property</span> <span class="c1"># type: ignore[misc]</span> |
| <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.0.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">factors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Matrix</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Model factor term.</span> |
| <span class="sd"> """</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"factors"</span><span class="p">)</span> |
| |
| <div class="viewcode-block" id="FMClassificationModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">"3.1.0"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassificationTrainingSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span> |
| <span class="sd"> trained on the training set. An exception is thrown if `trainingSummary is None`.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">FMClassificationTrainingSummary</span><span class="p">(</span><span class="nb">super</span><span class="p">(</span><span class="n">FMClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> |
| <span class="s2">"No training summary available for this </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> |
| <span class="p">)</span></div> |
| |
| <div class="viewcode-block" id="FMClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"FMClassificationSummary"</span><span class="p">:</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Evaluates the model on a test dataset.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| |
| <span class="sd"> Parameters</span> |
| <span class="sd"> ----------</span> |
| <span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span> |
| <span class="sd"> Test dataset to evaluate model on.</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">."</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span> |
| <span class="n">java_fm_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">"evaluate"</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">FMClassificationSummary</span><span class="p">(</span><span class="n">java_fm_summary</span><span class="p">)</span></div></div> |
| |
| |
| <div class="viewcode-block" id="FMClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationSummary.html#pyspark.ml.classification.FMClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">FMClassificationSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for FMClassifier Results for a given model.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <div class="viewcode-block" id="FMClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationTrainingSummary.html#pyspark.ml.classification.FMClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span> |
| <span class="k">class</span> <span class="nc">FMClassificationTrainingSummary</span><span class="p">(</span><span class="n">FMClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""</span> |
| <span class="sd"> Abstraction for FMClassifier Training results.</span> |
| |
| <span class="sd"> .. versionadded:: 3.1.0</span> |
| <span class="sd"> """</span> |
| |
| <span class="k">pass</span></div> |
| |
| |
| <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span> |
| <span class="kn">import</span> <span class="nn">doctest</span> |
| <span class="kn">import</span> <span class="nn">pyspark.ml.classification</span> |
| <span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">SparkSession</span> |
| |
| <span class="n">globs</span> <span class="o">=</span> <span class="n">pyspark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> |
| <span class="c1"># Run the doctests against a local[2] SparkSession and expose it</span> |
| <span class="c1"># (plus its SparkContext) to the examples through ``globs``:</span> |
| <span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">"local[2]"</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">"ml.classification tests"</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span> |
| <span class="n">sc</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span> |
| <span class="n">globs</span><span class="p">[</span><span class="s2">"sc"</span><span class="p">]</span> <span class="o">=</span> <span class="n">sc</span> |
| <span class="n">globs</span><span class="p">[</span><span class="s2">"spark"</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span> |
| <span class="kn">import</span> <span class="nn">tempfile</span> |
| |
| <span class="n">temp_path</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">mkdtemp</span><span class="p">()</span> |
| <span class="n">globs</span><span class="p">[</span><span class="s2">"temp_path"</span><span class="p">]</span> <span class="o">=</span> <span class="n">temp_path</span> |
| <span class="k">try</span><span class="p">:</span> |
| <span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span><span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span> <span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span><span class="p">)</span> |
| <span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span> |
| <span class="k">finally</span><span class="p">:</span> |
| <span class="kn">from</span> <span class="nn">shutil</span> <span class="kn">import</span> <span class="n">rmtree</span> |
| |
| <span class="k">try</span><span class="p">:</span> |
| <span class="n">rmtree</span><span class="p">(</span><span class="n">temp_path</span><span class="p">)</span> |
| <span class="k">except</span> <span class="ne">OSError</span><span class="p">:</span> |
| <span class="k">pass</span> |
| <span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span> |
| <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> |
| </pre></div> |
| |
| </div> |
| |
| |
| <!-- Container that is empty on this page; presumably reserved for previous/next page links (inferred from the class name). --> |
| <div class="prev-next-bottom"> |
| |
| |
| </div> |
| |
| </main> |
| |
| |
| </div> |
| </div> |
| |
| |
| <!-- Site JavaScript bundle; the same file is preloaded in <head> via <link rel="preload" as="script">. --> |
| <script src="../../../_static/js/index.3da636dd464baa7582d2.js"></script> |
| |
| |
| <!-- Page footer: copyright notice and generator credit. --> |
| <footer class="footer mt-5 mt-md-0"> |
| <div class="container"> |
| <p> |
| <!-- NOTE(review): the copyright holder is blank in the generated output; presumably the Sphinx "copyright" setting is empty. Confirm and fix in the docs configuration. --> |
| © Copyright .<br/> |
| Created using <a href="https://www.sphinx-doc.org/">Sphinx</a> 3.0.4.<br/> |
| </p> |
| </div> |
| </footer> |
| </body> |
| </html> |