<!-- blob: fd86b3829a175d6ad7d43d1799d19983f0e2d089 [file] [log] [blame] -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>pyspark.ml.functions &#8212; PySpark 3.5.5 documentation</title>
<link href="../../../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link rel="stylesheet"
href="../../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
<link rel="preload" as="font" type="font/woff2" crossorigin
href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
<link rel="preload" as="font" type="font/woff2" crossorigin
href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
<link rel="stylesheet" href="../../../_static/styles/pydata-sphinx-theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/pyspark.css" />
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">
<script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script src="../../../_static/jquery.js"></script>
<script src="../../../_static/underscore.js"></script>
<script src="../../../_static/doctools.js"></script>
<script src="../../../_static/language_data.js"></script>
<script src="../../../_static/clipboard.min.js"></script>
<script src="../../../_static/copybutton.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<link rel="canonical" href="https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/functions.html" />
<link rel="search" title="Search" href="../../../search.html" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="docsearch:language" content="en">
<!-- Google Analytics -->
</head>
<body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
<div class="container-fluid" id="banner"></div>
<nav class="navbar navbar-light navbar-expand-lg bg-light fixed-top bd-navbar" id="navbar-main"><div class="container-xl">
<div id="navbar-start">
<a class="navbar-brand" href="../../../index.html">
<img src="../../../_static/spark-logo-reverse.png" class="logo" alt="PySpark documentation home">
</a>
</div>
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbar-collapsible" aria-controls="navbar-collapsible" aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div id="navbar-collapsible" class="col-lg-9 collapse navbar-collapse">
<div id="navbar-center" class="mr-auto">
<div class="navbar-center-item">
<ul id="navbar-main-elements" class="navbar-nav">
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../index.html">
Overview
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../development/index.html">
Development
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</div>
</div>
<div id="navbar-end">
<div class="navbar-end-item">
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
3.5.5
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
// Build the docs homepage URL for one entry of the versions list by
// substituting entry.version for the "{version}" placeholder.
function buildURL(entry) {
  // URL pattern supplied by jinja
  const urlPattern = "https://spark.apache.org/docs/{version}/api/python/index.html";
  return urlPattern.replace("{version}", entry.version);
}
// Click handler for the version-switcher links.
// Checks whether the page currently being viewed also exists in the selected
// docs version and, if so, goes there instead of that version's homepage.
// The clicked link's "href" is the other version's homepage. Always returns
// false; used as an `onclick` handler this cancels the default navigation so
// the redirect is decided by the HEAD request below.
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/ml/functions.html",
        otherDocsHomepage = event.target.getAttribute("href");
  // The homepage URL ends in "index.html". Drop that file name so the page
  // path is appended to the docs root directory; appending it to the full
  // homepage URL would yield ".../index.html_modules/...", which never exists.
  const otherDocsRoot = otherDocsHomepage.replace(/index\.html$/, "");
  const tryUrl = `${otherDocsRoot}${currentFilePath}`;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // the page is missing or the request failed: use the version homepage
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher dropdown (runs immediately on script load).
(function () {
  // fetch the JSON list of documentation versions
  $.getJSON("https://spark.apache.org/static/versions.json", function(versions, textStatus, jqXHR) {
    // Add every link up front, in list order, so the ordering is fixed before
    // any per-click AJAX check happens; each link initially points at the
    // homepage of its docs version.
    $.each(versions, function(position, item) {
      // fall back to the version string when no custom name (e.g., "latest") is given
      if (!("name" in item)) {
        item.name = item.version;
      }
      // compute the target URL and build the dropdown link for it
      item.url = buildURL(item);
      const link = document.createElement("a");
      link.setAttribute("class", "list-group-item list-group-item-action py-1");
      link.setAttribute("href", String(item.url));
      link.textContent = String(item.name);
      link.onclick = checkPageExistsAndRedirect;
      $("#version_switcher").append(link);
    });
  });
})();
</script>
</div>
</div>
</div>
</div>
</nav>
<div class="container-xl">
<div class="row">
<!-- Only show if we have sidebars configured, else just a small margin -->
<div class="col-12 col-md-3 bd-sidebar">
<div class="sidebar-start-items"><form class="bd-search d-flex align-items-center" action="../../../search.html" method="get">
<i class="icon fas fa-search"></i>
<input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main navigation">
<div class="bd-toc-item active">
</div>
</nav>
</div>
<div class="sidebar-end-items">
</div>
</div>
<div class="d-none d-xl-block col-xl-2 bd-toc">
</div>
<main class="col-12 col-md-9 col-xl-7 py-md-5 pl-md-5 pr-md-4 bd-content" role="main">
<div>
<h1>Source code for pyspark.ml.functions</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span>
<span class="c1"># contributor license agreements. See the NOTICE file distributed with</span>
<span class="c1"># this work for additional information regarding copyright ownership.</span>
<span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span>
<span class="c1"># (the &quot;License&quot;); you may not use this file except in compliance with</span>
<span class="c1"># the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">__future__</span><span class="w"> </span><span class="kn">import</span> <span class="n">annotations</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">inspect</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">pandas</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">pd</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">uuid</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark</span><span class="w"> </span><span class="kn">import</span> <span class="n">SparkContext</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.sql.functions</span><span class="w"> </span><span class="kn">import</span> <span class="n">pandas_udf</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.sql.column</span><span class="w"> </span><span class="kn">import</span> <span class="n">Column</span><span class="p">,</span> <span class="n">_to_java_column</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.sql.types</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">ArrayType</span><span class="p">,</span>
<span class="n">ByteType</span><span class="p">,</span>
<span class="n">DataType</span><span class="p">,</span>
<span class="n">DoubleType</span><span class="p">,</span>
<span class="n">FloatType</span><span class="p">,</span>
<span class="n">IntegerType</span><span class="p">,</span>
<span class="n">LongType</span><span class="p">,</span>
<span class="n">ShortType</span><span class="p">,</span>
<span class="n">StringType</span><span class="p">,</span>
<span class="n">StructType</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.ml.util</span><span class="w"> </span><span class="kn">import</span> <span class="n">try_remote_functions</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Iterator</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Optional</span>
<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.sql._typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">UserDefinedFunctionLike</span>
<span class="n">supported_scalar_types</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">ByteType</span><span class="p">,</span>
<span class="n">ShortType</span><span class="p">,</span>
<span class="n">IntegerType</span><span class="p">,</span>
<span class="n">LongType</span><span class="p">,</span>
<span class="n">FloatType</span><span class="p">,</span>
<span class="n">DoubleType</span><span class="p">,</span>
<span class="n">StringType</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Callable type for end user predict functions that take a variable number of ndarrays as</span>
<span class="c1"># input and return one of the following as output:</span>
<span class="c1"># - single ndarray (single output)</span>
<span class="c1"># - dictionary of named ndarrays (multiple outputs represented in columnar form)</span>
<span class="c1"># - list of dictionaries of named ndarrays (multiple outputs represented in row form)</span>
<span class="n">PredictBatchFunction</span> <span class="o">=</span> <span class="n">Callable</span><span class="p">[</span>
<span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">Mapping</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">]]]</span>
<span class="p">]</span>
<div class="viewcode-block" id="vector_to_array"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.functions.vector_to_array.html#pyspark.ml.functions.vector_to_array">[docs]</a><span class="nd">@try_remote_functions</span>
<span class="k">def</span><span class="w"> </span><span class="nf">vector_to_array</span><span class="p">(</span><span class="n">col</span><span class="p">:</span> <span class="n">Column</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;float64&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Column</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Converts a column of MLlib sparse/dense vectors into a column of dense arrays.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> .. versionchanged:: 3.5.0</span>
<span class="sd"> Supports Spark Connect.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> col : :py:class:`pyspark.sql.Column` or str</span>
<span class="sd"> Input column</span>
<span class="sd"> dtype : str, optional</span>
<span class="sd"> The data type of the output array. Valid values: &quot;float64&quot; or &quot;float32&quot;.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`pyspark.sql.Column`</span>
<span class="sd"> The converted column of dense arrays.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.functions import vector_to_array</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.linalg import Vectors as OldVectors</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),</span>
<span class="sd"> ... (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]),</span>
<span class="sd"> ... OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))],</span>
<span class="sd"> ... [&quot;vec&quot;, &quot;oldVec&quot;])</span>
<span class="sd"> &gt;&gt;&gt; df1 = df.select(vector_to_array(&quot;vec&quot;).alias(&quot;vec&quot;),</span>
<span class="sd"> ... vector_to_array(&quot;oldVec&quot;).alias(&quot;oldVec&quot;))</span>
<span class="sd"> &gt;&gt;&gt; df1.collect()</span>
<span class="sd"> [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),</span>
<span class="sd"> Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]</span>
<span class="sd"> &gt;&gt;&gt; df2 = df.select(vector_to_array(&quot;vec&quot;, &quot;float32&quot;).alias(&quot;vec&quot;),</span>
<span class="sd"> ... vector_to_array(&quot;oldVec&quot;, &quot;float32&quot;).alias(&quot;oldVec&quot;))</span>
<span class="sd"> &gt;&gt;&gt; df2.collect()</span>
<span class="sd"> [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),</span>
<span class="sd"> Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]</span>
<span class="sd"> &gt;&gt;&gt; df1.schema.fields</span>
<span class="sd"> [StructField(&#39;vec&#39;, ArrayType(DoubleType(), False), False),</span>
<span class="sd"> StructField(&#39;oldVec&#39;, ArrayType(DoubleType(), False), False)]</span>
<span class="sd"> &gt;&gt;&gt; df2.schema.fields</span>
<span class="sd"> [StructField(&#39;vec&#39;, ArrayType(FloatType(), False), False),</span>
<span class="sd"> StructField(&#39;oldVec&#39;, ArrayType(FloatType(), False), False)]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">Column</span><span class="p">(</span>
<span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">functions</span><span class="o">.</span><span class="n">vector_to_array</span><span class="p">(</span><span class="n">_to_java_column</span><span class="p">(</span><span class="n">col</span><span class="p">),</span> <span class="n">dtype</span><span class="p">)</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="array_to_vector"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.functions.array_to_vector.html#pyspark.ml.functions.array_to_vector">[docs]</a><span class="nd">@try_remote_functions</span>
<span class="k">def</span><span class="w"> </span><span class="nf">array_to_vector</span><span class="p">(</span><span class="n">col</span><span class="p">:</span> <span class="n">Column</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Column</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Converts a column of array of numeric type into a column of pyspark.ml.linalg.DenseVector</span>
<span class="sd"> instances</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> .. versionchanged:: 3.5.0</span>
<span class="sd"> Supports Spark Connect.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> col : :py:class:`pyspark.sql.Column` or str</span>
<span class="sd"> Input column</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`pyspark.sql.Column`</span>
<span class="sd"> The converted column of dense vectors.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.functions import array_to_vector</span>
<span class="sd"> &gt;&gt;&gt; df1 = spark.createDataFrame([([1.5, 2.5],),], schema=&#39;v1 array&lt;double&gt;&#39;)</span>
<span class="sd"> &gt;&gt;&gt; df1.select(array_to_vector(&#39;v1&#39;).alias(&#39;vec1&#39;)).collect()</span>
<span class="sd"> [Row(vec1=DenseVector([1.5, 2.5]))]</span>
<span class="sd"> &gt;&gt;&gt; df2 = spark.createDataFrame([([1.5, 3.5],),], schema=&#39;v1 array&lt;float&gt;&#39;)</span>
<span class="sd"> &gt;&gt;&gt; df2.select(array_to_vector(&#39;v1&#39;).alias(&#39;vec1&#39;)).collect()</span>
<span class="sd"> [Row(vec1=DenseVector([1.5, 3.5]))]</span>
<span class="sd"> &gt;&gt;&gt; df3 = spark.createDataFrame([([1, 3],),], schema=&#39;v1 array&lt;int&gt;&#39;)</span>
<span class="sd"> &gt;&gt;&gt; df3.select(array_to_vector(&#39;v1&#39;).alias(&#39;vec1&#39;)).collect()</span>
<span class="sd"> [Row(vec1=DenseVector([1.0, 3.0]))]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">Column</span><span class="p">(</span><span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">functions</span><span class="o">.</span><span class="n">array_to_vector</span><span class="p">(</span><span class="n">_to_java_column</span><span class="p">(</span><span class="n">col</span><span class="p">)))</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">_batched</span><span class="p">(</span>
<span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">]],</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Iterator</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Generator that splits a pandas dataframe/series into batches.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">):</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">data</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">):</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">data</span><span class="p">,),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># isinstance(data, Tuple[pd.Series]):</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">index</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">data_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>
<span class="k">while</span> <span class="n">index</span> <span class="o">&lt;</span> <span class="n">data_size</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">df</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">index</span> <span class="p">:</span> <span class="n">index</span> <span class="o">+</span> <span class="n">batch_size</span><span class="p">]</span>
<span class="n">index</span> <span class="o">+=</span> <span class="n">batch_size</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_is_tensor_col</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">):</span>
<span class="k">return</span> <span class="n">data</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">object_</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="nb">list</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">dtypes</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">object_</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">any</span><span class="p">(</span>
<span class="p">[</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="nb">list</span><span class="p">))</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Unexpected data type: </span><span class="si">{}</span><span class="s2">, expected pd.Series or pd.DataFrame.&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">data</span><span class="p">))</span>
<span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_has_tensor_cols</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Check if input Series/DataFrame/Tuple contains any tensor-valued columns.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">)):</span>
<span class="k">return</span> <span class="n">_is_tensor_col</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># isinstance(data, Tuple):</span>
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">_is_tensor_col</span><span class="p">(</span><span class="n">elem</span><span class="p">)</span> <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">data</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_and_transform_multiple_inputs</span><span class="p">(</span>
<span class="n">batch</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">input_shapes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span> <span class="n">num_input_cols</span><span class="p">:</span> <span class="nb">int</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
<span class="n">multi_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">batch</span><span class="p">[</span><span class="n">col</span><span class="p">]</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="n">batch</span><span class="o">.</span><span class="n">columns</span><span class="p">]</span>
<span class="k">if</span> <span class="n">input_shapes</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_shapes</span><span class="p">)</span> <span class="o">==</span> <span class="n">num_input_cols</span><span class="p">:</span>
<span class="n">multi_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_shapes</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="c1"># type: ignore</span>
<span class="k">if</span> <span class="n">input_shapes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">else</span> <span class="n">v</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">multi_inputs</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">multi_inputs</span><span class="p">]):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Input data does not match expected shape.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;input_tensor_shapes must match columns&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">multi_inputs</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_and_transform_single_input</span><span class="p">(</span>
<span class="n">batch</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span>
<span class="n">input_shapes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">|</span> <span class="kc">None</span><span class="p">],</span>
<span class="n">has_tensors</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="n">has_tuple</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
<span class="c1"># multiple input columns for single expected input</span>
<span class="k">if</span> <span class="n">has_tensors</span><span class="p">:</span>
<span class="c1"># tensor columns</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">columns</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># one tensor column and one expected input, vstack rows</span>
<span class="n">single_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">iloc</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Multiple input columns found, but model expected a single &quot;</span>
<span class="s2">&quot;input, use `array` to combine columns into tensors.&quot;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># scalar columns</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">columns</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># single scalar column, remove extra dim</span>
<span class="n">np_batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">()</span>
<span class="n">single_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">np_batch</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">np_batch</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">np_batch</span>
<span class="k">if</span> <span class="n">input_shapes</span> <span class="ow">and</span> <span class="n">input_shapes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">[],</span> <span class="p">[</span><span class="mi">1</span><span class="p">]]:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid input_tensor_shape for scalar column.&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="ow">not</span> <span class="n">has_tuple</span><span class="p">:</span>
<span class="c1"># columns grouped via `array`, convert to single tensor</span>
<span class="n">single_input</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">()</span>
<span class="k">if</span> <span class="n">input_shapes</span> <span class="ow">and</span> <span class="n">input_shapes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">columns</span><span class="p">)]:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Input data does not match expected shape.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Multiple input columns found, but model expected a single &quot;</span>
<span class="s2">&quot;input, use `array` to combine columns into tensors.&quot;</span>
<span class="p">)</span>
<span class="c1"># if input_tensor_shapes provided, try to reshape input</span>
<span class="k">if</span> <span class="n">input_shapes</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_shapes</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">single_input</span> <span class="o">=</span> <span class="n">single_input</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_shapes</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1"># type: ignore</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">single_input</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Input data does not match expected shape.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Multiple input_tensor_shapes found, but model expected one input&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">single_input</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_and_transform_prediction_result</span><span class="p">(</span>
<span class="n">preds</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">|</span> <span class="n">Mapping</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">|</span> <span class="n">List</span><span class="p">[</span><span class="n">Mapping</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]],</span>
<span class="n">num_input_rows</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">return_type</span><span class="p">:</span> <span class="n">DataType</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span> <span class="o">|</span> <span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Validate numpy-based model predictions against the expected pandas_udf return_type and</span>
<span class="sd"> transform the predictions into an equivalent pandas DataFrame or Series.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">return_type</span><span class="p">,</span> <span class="n">StructType</span><span class="p">):</span>
<span class="n">struct_rtype</span><span class="p">:</span> <span class="n">StructType</span> <span class="o">=</span> <span class="n">return_type</span>
<span class="n">fieldNames</span> <span class="o">=</span> <span class="n">struct_rtype</span><span class="o">.</span><span class="n">names</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="c1"># dictionary of columns</span>
<span class="n">predNames</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">preds</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">struct_rtype</span><span class="o">.</span><span class="n">fields</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">field</span><span class="o">.</span><span class="n">dataType</span><span class="p">,</span> <span class="n">ArrayType</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">preds</span><span class="p">[</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Prediction results for ArrayType must be two-dimensional.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">field</span><span class="o">.</span><span class="n">dataType</span><span class="p">,</span> <span class="n">supported_scalar_types</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Prediction results for scalar types must be one-dimensional.&quot;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported field type in return struct type.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">])</span> <span class="o">!=</span> <span class="n">num_input_rows</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results must have same length as input data&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span>
<span class="c1"># rows of dictionaries</span>
<span class="n">predNames</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span> <span class="o">!=</span> <span class="n">num_input_rows</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results must have same length as input data.&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">struct_rtype</span><span class="o">.</span><span class="n">fields</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">field</span><span class="o">.</span><span class="n">dataType</span><span class="p">,</span> <span class="n">ArrayType</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Prediction results for ArrayType must be one-dimensional.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">field</span><span class="o">.</span><span class="n">dataType</span><span class="p">,</span> <span class="n">supported_scalar_types</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">]):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid scalar prediction result.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported field type in return struct type.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Prediction results for StructType must be a dictionary or &quot;</span>
<span class="s2">&quot;a list of dictionary, got: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span>
<span class="p">)</span>
<span class="c1"># check column names</span>
<span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">predNames</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">fieldNames</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Prediction result columns did not match expected return_type &quot;</span>
<span class="s2">&quot;columns: expected </span><span class="si">{}</span><span class="s2">, got: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">fieldNames</span><span class="p">,</span> <span class="n">predNames</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">return_type</span><span class="p">,</span> <span class="n">ArrayType</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span> <span class="o">!=</span> <span class="n">num_input_rows</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results must have same length as input data.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results for ArrayType must be two-dimensional.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results for ArrayType must be an ndarray.&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">return_type</span><span class="p">,</span> <span class="n">supported_scalar_types</span><span class="p">):</span>
<span class="n">preds_array</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="n">preds</span> <span class="c1"># type: ignore</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds_array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">num_input_rows</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Prediction results must have same length as input data.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="p">(</span>
<span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds_array</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">preds_array</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
<span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds_array</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
<span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid shape for scalar prediction result.&quot;</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">preds_array</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds_array</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">preds_array</span>
<span class="k">return</span> <span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">(</span><span class="n">output</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported return type&quot;</span><span class="p">)</span>
<div class="viewcode-block" id="predict_batch_udf"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf">[docs]</a><span class="k">def</span><span class="w"> </span><span class="nf">predict_batch_udf</span><span class="p">(</span>
<span class="n">make_predict_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span>
<span class="p">[],</span>
<span class="n">PredictBatchFunction</span><span class="p">,</span>
<span class="p">],</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">return_type</span><span class="p">:</span> <span class="n">DataType</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">input_tensor_shapes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span> <span class="n">Mapping</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">UserDefinedFunctionLike</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Given a function which loads a model and returns a `predict` function for inference over a</span>
<span class="sd"> batch of numpy inputs, returns a Pandas UDF wrapper for inference over a Spark DataFrame.</span>
<span class="sd"> The returned Pandas UDF does the following on each DataFrame partition:</span>
<span class="sd"> * calls the `make_predict_fn` to load the model and cache its `predict` function.</span>
<span class="sd"> * batches the input records as numpy arrays and invokes `predict` on each batch.</span>
<span class="sd"> Note: this assumes that the `make_predict_fn` encapsulates all of the necessary dependencies for</span>
<span class="sd"> running the model, or the Spark executor environment already satisfies all runtime requirements.</span>
<span class="sd"> For the conversion of the Spark DataFrame to numpy arrays, there is a one-to-one mapping between</span>
<span class="sd"> the input arguments of the `predict` function (returned by the `make_predict_fn`) and the input</span>
<span class="sd"> columns sent to the Pandas UDF (returned by the `predict_batch_udf`) at runtime. Each input</span>
<span class="sd"> column will be converted as follows:</span>
<span class="sd"> * scalar column -&gt; 1-dim np.ndarray</span>
<span class="sd"> * tensor column + tensor shape -&gt; N-dim np.ndarray</span>
<span class="sd"> Note that any tensor columns in the Spark DataFrame must be represented as a flattened</span>
<span class="sd"> one-dimensional array, and multiple scalar columns can be combined into a single tensor column</span>
<span class="sd"> using the standard :py:func:`pyspark.sql.functions.array()` function.</span>
<span class="sd"> .. versionadded:: 3.4.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> make_predict_fn : callable</span>
<span class="sd"> Function which is responsible for loading a model and returning a</span>
<span class="sd"> :py:class:`PredictBatchFunction` which takes one or more numpy arrays as input and returns</span>
<span class="sd"> one of the following:</span>
<span class="sd"> * a numpy array (for a single output)</span>
<span class="sd"> * a dictionary of named numpy arrays (for multiple outputs)</span>
<span class="sd"> * a row-oriented list of dictionaries (for multiple outputs).</span>
<span class="sd"> For a dictionary of named numpy arrays, the arrays can only be one or two dimensional, since</span>
<span class="sd"> higher dimensional arrays are not supported. For a row-oriented list of dictionaries, each</span>
<span class="sd"> element in the dictionary must be either a scalar or one-dimensional array.</span>
<span class="sd"> return_type : :py:class:`pyspark.sql.types.DataType` or str.</span>
<span class="sd"> Spark SQL datatype for the expected output:</span>
<span class="sd"> * Scalar (e.g. IntegerType, FloatType) --&gt; 1-dim numpy array.</span>
<span class="sd"> * ArrayType --&gt; 2-dim numpy array.</span>
<span class="sd"> * StructType --&gt; dict with keys matching struct fields.</span>
<span class="sd"> * StructType --&gt; list of dict with keys matching struct fields, for models like the</span>
<span class="sd"> `Huggingface pipeline for sentiment analysis</span>
<span class="sd"> &lt;https://huggingface.co/docs/transformers/quicktour#pipeline-usage&gt;`_.</span>
<span class="sd"> batch_size : int</span>
<span class="sd"> Batch size to use for inference. This is typically a limitation of the model</span>
<span class="sd"> and/or available hardware resources and is usually smaller than the Spark partition size.</span>
<span class="sd"> input_tensor_shapes : list, dict, optional.</span>
<span class="sd"> A list of shapes (each a list of ints, or None) or a dictionary of ints (key) and list of ints (value).</span>
<span class="sd"> Input tensor shapes for models with tensor inputs. This can be a list of shapes,</span>
<span class="sd"> where each shape is a list of integers or None (for scalar inputs). Alternatively, this</span>
<span class="sd"> can be represented by a &quot;sparse&quot; dictionary, where the keys are the integer indices of the</span>
<span class="sd"> inputs, and the values are the shapes. Each tensor input value in the Spark DataFrame must</span>
<span class="sd"> be represented as a single column containing a flattened 1-D array. The provided</span>
<span class="sd"> `input_tensor_shapes` will be used to reshape the flattened array into the expected tensor</span>
<span class="sd"> shape. For the list form, the order of the tensor shapes must match the order of the</span>
<span class="sd"> selected DataFrame columns. The batch dimension (typically -1 or None in the first</span>
<span class="sd"> dimension) should not be included, since it will be determined by the batch_size argument.</span>
<span class="sd"> Tabular datasets with scalar-valued columns should not provide this argument.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`UserDefinedFunctionLike`</span>
<span class="sd"> A Pandas UDF for model inference on a Spark DataFrame.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> For a pre-trained TensorFlow MNIST model with two-dimensional input images represented as a</span>
<span class="sd"> flattened tensor value stored in a single Spark DataFrame column of type `array&lt;float&gt;`:</span>
<span class="sd"> .. code-block:: python</span>
<span class="sd"> from pyspark.ml.functions import predict_batch_udf</span>
<span class="sd"> def make_mnist_fn():</span>
<span class="sd"> # load/init happens once per python worker</span>
<span class="sd"> import tensorflow as tf</span>
<span class="sd"> model = tf.keras.models.load_model(&#39;/path/to/mnist_model&#39;)</span>
<span class="sd"> # predict on batches of tasks/partitions, using cached model</span>
<span class="sd"> def predict(inputs: np.ndarray) -&gt; np.ndarray:</span>
<span class="sd"> # inputs.shape = [batch_size, 784], see input_tensor_shapes</span>
<span class="sd"> # outputs.shape = [batch_size, 10], see return_type</span>
<span class="sd"> return model.predict(inputs)</span>
<span class="sd"> return predict</span>
<span class="sd"> mnist_udf = predict_batch_udf(make_mnist_fn,</span>
<span class="sd"> return_type=ArrayType(FloatType()),</span>
<span class="sd"> batch_size=100,</span>
<span class="sd"> input_tensor_shapes=[[784]])</span>
<span class="sd"> df = spark.read.parquet(&quot;/path/to/mnist_data&quot;)</span>
<span class="sd"> df.show(5)</span>
<span class="sd"> # +--------------------+</span>
<span class="sd"> # | data|</span>
<span class="sd"> # +--------------------+</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|</span>
<span class="sd"> # +--------------------+</span>
<span class="sd"> df.withColumn(&quot;preds&quot;, mnist_udf(&quot;data&quot;)).show(5)</span>
<span class="sd"> # +--------------------+--------------------+</span>
<span class="sd"> # | data| preds|</span>
<span class="sd"> # +--------------------+--------------------+</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|[-13.511008, 8.84...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|[-5.3957458, -2.2...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|[-7.2014456, -8.8...|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|[-19.466187, -13....|</span>
<span class="sd"> # |[0.0, 0.0, 0.0, 0...|[-5.7757926, -7.8...|</span>
<span class="sd"> # +--------------------+--------------------+</span>
<span class="sd"> To demonstrate usage with different combinations of input and output types, the following</span>
<span class="sd"> examples just use simple mathematical transforms as the models.</span>
<span class="sd"> * Single scalar column</span>
<span class="sd"> Input DataFrame has a single scalar column, which will be passed to the `predict`</span>
<span class="sd"> function as a 1-D numpy array.</span>
<span class="sd"> &gt;&gt;&gt; import numpy as np</span>
<span class="sd"> &gt;&gt;&gt; import pandas as pd</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.functions import predict_batch_udf</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql.types import FloatType</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame(pd.DataFrame(np.arange(100)))</span>
<span class="sd"> &gt;&gt;&gt; df.show(5)</span>
<span class="sd"> +---+</span>
<span class="sd"> | 0|</span>
<span class="sd"> +---+</span>
<span class="sd"> | 0|</span>
<span class="sd"> | 1|</span>
<span class="sd"> | 2|</span>
<span class="sd"> | 3|</span>
<span class="sd"> | 4|</span>
<span class="sd"> +---+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> &gt;&gt;&gt; def make_times_two_fn():</span>
<span class="sd"> ... def predict(inputs: np.ndarray) -&gt; np.ndarray:</span>
<span class="sd"> ... # inputs.shape = [batch_size]</span>
<span class="sd"> ... # outputs.shape = [batch_size]</span>
<span class="sd"> ... return inputs * 2</span>
<span class="sd"> ... return predict</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; times_two_udf = predict_batch_udf(make_times_two_fn,</span>
<span class="sd"> ... return_type=FloatType(),</span>
<span class="sd"> ... batch_size=10)</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame(pd.DataFrame(np.arange(100)))</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;x2&quot;, times_two_udf(&quot;0&quot;)).show(5)</span>
<span class="sd"> +---+---+</span>
<span class="sd"> | 0| x2|</span>
<span class="sd"> +---+---+</span>
<span class="sd"> | 0|0.0|</span>
<span class="sd"> | 1|2.0|</span>
<span class="sd"> | 2|4.0|</span>
<span class="sd"> | 3|6.0|</span>
<span class="sd"> | 4|8.0|</span>
<span class="sd"> +---+---+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> * Multiple scalar columns</span>
<span class="sd"> Input DataFrame has multiple columns of scalar values. If the user-provided `predict`</span>
<span class="sd"> function expects a single input, then the user must combine the multiple columns into a</span>
<span class="sd"> single tensor using `pyspark.sql.functions.array`.</span>
<span class="sd"> &gt;&gt;&gt; import numpy as np</span>
<span class="sd"> &gt;&gt;&gt; import pandas as pd</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.functions import predict_batch_udf</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql.functions import array</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)</span>
<span class="sd"> &gt;&gt;&gt; pdf = pd.DataFrame(data, columns=[&#39;a&#39;,&#39;b&#39;,&#39;c&#39;,&#39;d&#39;])</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame(pdf)</span>
<span class="sd"> &gt;&gt;&gt; df.show(5)</span>
<span class="sd"> +----+----+----+----+</span>
<span class="sd"> | a| b| c| d|</span>
<span class="sd"> +----+----+----+----+</span>
<span class="sd"> | 0.0| 1.0| 2.0| 3.0|</span>
<span class="sd"> | 4.0| 5.0| 6.0| 7.0|</span>
<span class="sd"> | 8.0| 9.0|10.0|11.0|</span>
<span class="sd"> |12.0|13.0|14.0|15.0|</span>
<span class="sd"> |16.0|17.0|18.0|19.0|</span>
<span class="sd"> +----+----+----+----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> &gt;&gt;&gt; def make_sum_fn():</span>
<span class="sd"> ... def predict(inputs: np.ndarray) -&gt; np.ndarray:</span>
<span class="sd"> ... # inputs.shape = [batch_size, 4]</span>
<span class="sd"> ... # outputs.shape = [batch_size]</span>
<span class="sd"> ... return np.sum(inputs, axis=1)</span>
<span class="sd"> ... return predict</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; sum_udf = predict_batch_udf(make_sum_fn,</span>
<span class="sd"> ... return_type=FloatType(),</span>
<span class="sd"> ... batch_size=10,</span>
<span class="sd"> ... input_tensor_shapes=[[4]])</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;sum&quot;, sum_udf(array(&quot;a&quot;, &quot;b&quot;, &quot;c&quot;, &quot;d&quot;))).show(5)</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> | a| b| c| d| sum|</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> | 0.0| 1.0| 2.0| 3.0| 6.0|</span>
<span class="sd"> | 4.0| 5.0| 6.0| 7.0|22.0|</span>
<span class="sd"> | 8.0| 9.0|10.0|11.0|38.0|</span>
<span class="sd"> |12.0|13.0|14.0|15.0|54.0|</span>
<span class="sd"> |16.0|17.0|18.0|19.0|70.0|</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> If the `predict` function expects multiple inputs, then the number of selected input columns</span>
<span class="sd"> must match the number of expected inputs.</span>
<span class="sd"> &gt;&gt;&gt; def make_sum_fn():</span>
<span class="sd"> ... def predict(x1: np.ndarray,</span>
<span class="sd"> ... x2: np.ndarray,</span>
<span class="sd"> ... x3: np.ndarray,</span>
<span class="sd"> ... x4: np.ndarray) -&gt; np.ndarray:</span>
<span class="sd"> ... # xN.shape = [batch_size]</span>
<span class="sd"> ... # outputs.shape = [batch_size]</span>
<span class="sd"> ... return x1 + x2 + x3 + x4</span>
<span class="sd"> ... return predict</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; sum_udf = predict_batch_udf(make_sum_fn,</span>
<span class="sd"> ... return_type=FloatType(),</span>
<span class="sd"> ... batch_size=10)</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;sum&quot;, sum_udf(&quot;a&quot;, &quot;b&quot;, &quot;c&quot;, &quot;d&quot;)).show(5)</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> | a| b| c| d| sum|</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> | 0.0| 1.0| 2.0| 3.0| 6.0|</span>
<span class="sd"> | 4.0| 5.0| 6.0| 7.0|22.0|</span>
<span class="sd"> | 8.0| 9.0|10.0|11.0|38.0|</span>
<span class="sd"> |12.0|13.0|14.0|15.0|54.0|</span>
<span class="sd"> |16.0|17.0|18.0|19.0|70.0|</span>
<span class="sd"> +----+----+----+----+----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> * Multiple tensor columns</span>
<span class="sd"> Input DataFrame has multiple columns, where each column is a tensor. The number of columns</span>
<span class="sd"> should match the number of expected inputs for the user-provided `predict` function.</span>
<span class="sd"> &gt;&gt;&gt; import numpy as np</span>
<span class="sd"> &gt;&gt;&gt; import pandas as pd</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.functions import predict_batch_udf</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql.types import ArrayType, FloatType, StructType, StructField</span>
<span class="sd"> &gt;&gt;&gt; from typing import Mapping</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)</span>
<span class="sd"> &gt;&gt;&gt; pdf = pd.DataFrame(data, columns=[&#39;a&#39;,&#39;b&#39;,&#39;c&#39;,&#39;d&#39;])</span>
<span class="sd"> &gt;&gt;&gt; pdf_tensor = pd.DataFrame()</span>
<span class="sd"> &gt;&gt;&gt; pdf_tensor[&#39;t1&#39;] = pdf.values.tolist()</span>
<span class="sd"> &gt;&gt;&gt; pdf_tensor[&#39;t2&#39;] = pdf.drop(columns=&#39;d&#39;).values.tolist()</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame(pdf_tensor)</span>
<span class="sd"> &gt;&gt;&gt; df.show(5)</span>
<span class="sd"> +--------------------+------------------+</span>
<span class="sd"> | t1| t2|</span>
<span class="sd"> +--------------------+------------------+</span>
<span class="sd"> |[0.0, 1.0, 2.0, 3.0]| [0.0, 1.0, 2.0]|</span>
<span class="sd"> |[4.0, 5.0, 6.0, 7.0]| [4.0, 5.0, 6.0]|</span>
<span class="sd"> |[8.0, 9.0, 10.0, ...| [8.0, 9.0, 10.0]|</span>
<span class="sd"> |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|</span>
<span class="sd"> |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|</span>
<span class="sd"> +--------------------+------------------+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> &gt;&gt;&gt; def make_multi_sum_fn():</span>
<span class="sd"> ... def predict(x1: np.ndarray, x2: np.ndarray) -&gt; np.ndarray:</span>
<span class="sd"> ... # x1.shape = [batch_size, 4]</span>
<span class="sd"> ... # x2.shape = [batch_size, 3]</span>
<span class="sd"> ... # outputs.shape = [batch_size]</span>
<span class="sd"> ... return np.sum(x1, axis=1) + np.sum(x2, axis=1)</span>
<span class="sd"> ... return predict</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; multi_sum_udf = predict_batch_udf(</span>
<span class="sd"> ... make_multi_sum_fn,</span>
<span class="sd"> ... return_type=FloatType(),</span>
<span class="sd"> ... batch_size=5,</span>
<span class="sd"> ... input_tensor_shapes=[[4], [3]],</span>
<span class="sd"> ... )</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;sum&quot;, multi_sum_udf(&quot;t1&quot;, &quot;t2&quot;)).show(5)</span>
<span class="sd"> +--------------------+------------------+-----+</span>
<span class="sd"> | t1| t2| sum|</span>
<span class="sd"> +--------------------+------------------+-----+</span>
<span class="sd"> |[0.0, 1.0, 2.0, 3.0]| [0.0, 1.0, 2.0]| 9.0|</span>
<span class="sd"> |[4.0, 5.0, 6.0, 7.0]| [4.0, 5.0, 6.0]| 37.0|</span>
<span class="sd"> |[8.0, 9.0, 10.0, ...| [8.0, 9.0, 10.0]| 65.0|</span>
<span class="sd"> |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]| 93.0|</span>
<span class="sd"> |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|121.0|</span>
<span class="sd"> +--------------------+------------------+-----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> * Multiple outputs</span>
<span class="sd"> Some models can provide multiple outputs. These can be returned as a dictionary of named</span>
<span class="sd"> values, which can be represented in either columnar or row-based formats.</span>
<span class="sd"> &gt;&gt;&gt; def make_multi_sum_fn():</span>
<span class="sd"> ... def predict_columnar(x1: np.ndarray, x2: np.ndarray) -&gt; Mapping[str, np.ndarray]:</span>
<span class="sd"> ... # x1.shape = [batch_size, 4]</span>
<span class="sd"> ... # x2.shape = [batch_size, 3]</span>
<span class="sd"> ... return {</span>
<span class="sd"> ... &quot;sum1&quot;: np.sum(x1, axis=1),</span>
<span class="sd"> ... &quot;sum2&quot;: np.sum(x2, axis=1)</span>
<span class="sd"> ... }</span>
<span class="sd"> ... return predict_columnar</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; multi_sum_udf = predict_batch_udf(</span>
<span class="sd"> ... make_multi_sum_fn,</span>
<span class="sd"> ... return_type=StructType([</span>
<span class="sd"> ... StructField(&quot;sum1&quot;, FloatType(), True),</span>
<span class="sd"> ... StructField(&quot;sum2&quot;, FloatType(), True)</span>
<span class="sd"> ... ]),</span>
<span class="sd"> ... batch_size=5,</span>
<span class="sd"> ... input_tensor_shapes=[[4], [3]],</span>
<span class="sd"> ... )</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;preds&quot;, multi_sum_udf(&quot;t1&quot;, &quot;t2&quot;)).select(&quot;t1&quot;, &quot;t2&quot;, &quot;preds.*&quot;).show(5)</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> | t1| t2|sum1|sum2|</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> |[0.0, 1.0, 2.0, 3.0]| [0.0, 1.0, 2.0]| 6.0| 3.0|</span>
<span class="sd"> |[4.0, 5.0, 6.0, 7.0]| [4.0, 5.0, 6.0]|22.0|15.0|</span>
<span class="sd"> |[8.0, 9.0, 10.0, ...| [8.0, 9.0, 10.0]|38.0|27.0|</span>
<span class="sd"> |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|54.0|39.0|</span>
<span class="sd"> |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|70.0|51.0|</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> &gt;&gt;&gt; def make_multi_sum_fn():</span>
<span class="sd"> ... def predict_row(x1: np.ndarray, x2: np.ndarray) -&gt; list[Mapping[str, float]]:</span>
<span class="sd"> ... # x1.shape = [batch_size, 4]</span>
<span class="sd"> ... # x2.shape = [batch_size, 3]</span>
<span class="sd"> ... return [{&#39;sum1&#39;: np.sum(x1[i]), &#39;sum2&#39;: np.sum(x2[i])} for i in range(len(x1))]</span>
<span class="sd"> ... return predict_row</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; multi_sum_udf = predict_batch_udf(</span>
<span class="sd"> ... make_multi_sum_fn,</span>
<span class="sd"> ... return_type=StructType([</span>
<span class="sd"> ... StructField(&quot;sum1&quot;, FloatType(), True),</span>
<span class="sd"> ... StructField(&quot;sum2&quot;, FloatType(), True)</span>
<span class="sd"> ... ]),</span>
<span class="sd"> ... batch_size=5,</span>
<span class="sd"> ... input_tensor_shapes=[[4], [3]],</span>
<span class="sd"> ... )</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;sum&quot;, multi_sum_udf(&quot;t1&quot;, &quot;t2&quot;)).select(&quot;t1&quot;, &quot;t2&quot;, &quot;sum.*&quot;).show(5)</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> | t1| t2|sum1|sum2|</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> |[0.0, 1.0, 2.0, 3.0]| [0.0, 1.0, 2.0]| 6.0| 3.0|</span>
<span class="sd"> |[4.0, 5.0, 6.0, 7.0]| [4.0, 5.0, 6.0]|22.0|15.0|</span>
<span class="sd"> |[8.0, 9.0, 10.0, ...| [8.0, 9.0, 10.0]|38.0|27.0|</span>
<span class="sd"> |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|54.0|39.0|</span>
<span class="sd"> |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|70.0|51.0|</span>
<span class="sd"> +--------------------+------------------+----+----+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> Note that the multiple outputs can be arrays as well.</span>
<span class="sd"> &gt;&gt;&gt; def make_multi_times_two_fn():</span>
<span class="sd"> ... def predict(x1: np.ndarray, x2: np.ndarray) -&gt; Mapping[str, np.ndarray]:</span>
<span class="sd"> ... # x1.shape = [batch_size, 4]</span>
<span class="sd"> ... # x2.shape = [batch_size, 3]</span>
<span class="sd"> ... return {&quot;t1x2&quot;: x1 * 2, &quot;t2x2&quot;: x2 * 2}</span>
<span class="sd"> ... return predict</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; multi_times_two_udf = predict_batch_udf(</span>
<span class="sd"> ... make_multi_times_two_fn,</span>
<span class="sd"> ... return_type=StructType([</span>
<span class="sd"> ... StructField(&quot;t1x2&quot;, ArrayType(FloatType()), True),</span>
<span class="sd"> ... StructField(&quot;t2x2&quot;, ArrayType(FloatType()), True)</span>
<span class="sd"> ... ]),</span>
<span class="sd"> ... batch_size=5,</span>
<span class="sd"> ... input_tensor_shapes=[[4], [3]],</span>
<span class="sd"> ... )</span>
<span class="sd"> &gt;&gt;&gt; df.withColumn(&quot;x2&quot;, multi_times_two_udf(&quot;t1&quot;, &quot;t2&quot;)).select(&quot;t1&quot;, &quot;t2&quot;, &quot;x2.*&quot;).show(5)</span>
<span class="sd"> +--------------------+------------------+--------------------+------------------+</span>
<span class="sd"> | t1| t2| t1x2| t2x2|</span>
<span class="sd"> +--------------------+------------------+--------------------+------------------+</span>
<span class="sd"> |[0.0, 1.0, 2.0, 3.0]| [0.0, 1.0, 2.0]|[0.0, 2.0, 4.0, 6.0]| [0.0, 2.0, 4.0]|</span>
<span class="sd"> |[4.0, 5.0, 6.0, 7.0]| [4.0, 5.0, 6.0]|[8.0, 10.0, 12.0,...| [8.0, 10.0, 12.0]|</span>
<span class="sd"> |[8.0, 9.0, 10.0, ...| [8.0, 9.0, 10.0]|[16.0, 18.0, 20.0...|[16.0, 18.0, 20.0]|</span>
<span class="sd"> |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|[24.0, 26.0, 28.0...|[24.0, 26.0, 28.0]|</span>
<span class="sd"> |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|[32.0, 34.0, 36.0...|[32.0, 34.0, 36.0]|</span>
<span class="sd"> +--------------------+------------------+--------------------+------------------+</span>
<span class="sd"> only showing top 5 rows</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># generate a new uuid each time this is invoked on the driver to invalidate executor-side cache.</span>
<span class="n">model_uuid</span> <span class="o">=</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">predict</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Iterator</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">Series</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Iterator</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">]:</span>
<span class="c1"># TODO: adjust return type hint when Iterator[Union[pd.Series, pd.DataFrame]] is supported</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.ml.model_cache</span><span class="w"> </span><span class="kn">import</span> <span class="n">ModelCache</span>
<span class="c1"># get predict function (from cache or from running user-provided make_predict_fn)</span>
<span class="n">predict_fn</span> <span class="o">=</span> <span class="n">ModelCache</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">model_uuid</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">predict_fn</span><span class="p">:</span>
<span class="n">predict_fn</span> <span class="o">=</span> <span class="n">make_predict_fn</span><span class="p">()</span>
<span class="n">ModelCache</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">model_uuid</span><span class="p">,</span> <span class="n">predict_fn</span><span class="p">)</span>
<span class="c1"># get number of expected parameters for predict function</span>
<span class="n">signature</span> <span class="o">=</span> <span class="n">inspect</span><span class="o">.</span><span class="n">signature</span><span class="p">(</span><span class="n">predict_fn</span><span class="p">)</span>
<span class="n">num_expected_cols</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">signature</span><span class="o">.</span><span class="n">parameters</span><span class="p">)</span>
<span class="c1"># convert sparse input_tensor_shapes to dense if needed</span>
<span class="n">input_shapes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">|</span> <span class="kc">None</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_tensor_shapes</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
<span class="n">input_shapes</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_expected_cols</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">shape</span> <span class="ow">in</span> <span class="n">input_tensor_shapes</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">input_shapes</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_shapes</span> <span class="o">=</span> <span class="n">input_tensor_shapes</span> <span class="c1"># type: ignore</span>
<span class="c1"># iterate over pandas batch, invoking predict_fn with ndarrays</span>
<span class="k">for</span> <span class="n">pandas_batch</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
<span class="n">has_tuple</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pandas_batch</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">)</span> <span class="c1"># type: ignore</span>
<span class="n">has_tensors</span> <span class="o">=</span> <span class="n">_has_tensor_cols</span><span class="p">(</span><span class="n">pandas_batch</span><span class="p">)</span>
<span class="c1"># require input_tensor_shapes for any tensor columns</span>
<span class="k">if</span> <span class="n">has_tensors</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">input_shapes</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Tensor columns require input_tensor_shapes&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">_batched</span><span class="p">(</span><span class="n">pandas_batch</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
<span class="n">num_input_rows</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="n">num_input_cols</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">columns</span><span class="p">)</span>
<span class="k">if</span> <span class="n">num_input_cols</span> <span class="o">==</span> <span class="n">num_expected_cols</span> <span class="ow">and</span> <span class="n">num_expected_cols</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># input column per expected input for multiple inputs</span>
<span class="n">multi_inputs</span> <span class="o">=</span> <span class="n">_validate_and_transform_multiple_inputs</span><span class="p">(</span>
<span class="n">batch</span><span class="p">,</span> <span class="n">input_shapes</span><span class="p">,</span> <span class="n">num_input_cols</span>
<span class="p">)</span>
<span class="c1"># run model prediction function on multiple (numpy) inputs</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">predict_fn</span><span class="p">(</span><span class="o">*</span><span class="n">multi_inputs</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">num_expected_cols</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># one or more input columns for single expected input</span>
<span class="n">single_input</span> <span class="o">=</span> <span class="n">_validate_and_transform_single_input</span><span class="p">(</span>
<span class="n">batch</span><span class="p">,</span> <span class="n">input_shapes</span><span class="p">,</span> <span class="n">has_tensors</span><span class="p">,</span> <span class="n">has_tuple</span>
<span class="p">)</span>
<span class="c1"># run model prediction function on single (numpy) inputs</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">predict_fn</span><span class="p">(</span><span class="n">single_input</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">msg</span> <span class="o">=</span> <span class="s2">&quot;Model expected </span><span class="si">{}</span><span class="s2"> inputs, but received </span><span class="si">{}</span><span class="s2"> columns&quot;</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">msg</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_expected_cols</span><span class="p">,</span> <span class="n">num_input_cols</span><span class="p">))</span>
<span class="c1"># return transformed predictions to Spark</span>
<span class="k">yield</span> <span class="n">_validate_and_transform_prediction_result</span><span class="p">(</span>
<span class="n">preds</span><span class="p">,</span> <span class="n">num_input_rows</span><span class="p">,</span> <span class="n">return_type</span>
<span class="p">)</span> <span class="c1"># type: ignore</span>
<span class="k">return</span> <span class="n">pandas_udf</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">return_type</span><span class="p">)</span> <span class="c1"># type: ignore[call-overload]</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">_test</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">doctest</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pyspark.sql</span><span class="w"> </span><span class="kn">import</span> <span class="n">SparkSession</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">pyspark.ml.functions</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">sys</span>
<span class="n">globs</span> <span class="o">=</span> <span class="n">pyspark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">functions</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">&quot;local[2]&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">&quot;ml.functions tests&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;sc&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">sc</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;spark&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span>
<span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span>
<span class="n">pyspark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">functions</span><span class="p">,</span>
<span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span>
<span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span> <span class="o">|</span> <span class="n">doctest</span><span class="o">.</span><span class="n">NORMALIZE_WHITESPACE</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span>
<span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span>
<span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="n">_test</span><span class="p">()</span>
</pre></div>
</div>
<!-- Previous / next buttons -->
<div class="prev-next-area">
</div>
</main>
</div>
</div>
<script src="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>
<footer class="footer mt-5 mt-md-0">
<div class="container">
<div class="footer-item">
<p class="copyright">
&copy; Copyright .<br>
</p>
</div>
<div class="footer-item">
<p class="sphinx-version">
Created using <a href="https://www.sphinx-doc.org/">Sphinx</a> 3.0.4.<br>
</p>
</div>
</div>
</footer>
</body>
</html>