blob: a1182003196314c883e71d244b41b65e12ceedea [file] [log] [blame]
<!DOCTYPE html><html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><meta name="generator" content="rustdoc"><meta name="description" content="Source of the Rust file `/root/.cargo/registry/src/github.com-1ecc6299db9ec823/matrixmultiply-0.1.15/src/gemm.rs`."><meta name="keywords" content="rust, rustlang, rust-lang"><title>gemm.rs - source</title><link rel="preload" as="font" type="font/woff2" crossorigin href="../../SourceSerif4-Regular.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../FiraSans-Regular.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../FiraSans-Medium.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../SourceCodePro-Regular.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../SourceSerif4-Bold.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../SourceCodePro-Semibold.ttf.woff2"><link rel="stylesheet" href="../../normalize.css"><link rel="stylesheet" href="../../rustdoc.css" id="mainThemeStyle"><link rel="stylesheet" href="../../ayu.css" disabled><link rel="stylesheet" href="../../dark.css" disabled><link rel="stylesheet" href="../../light.css" id="themeStyle"><script id="default-settings" ></script><script src="../../storage.js"></script><script defer src="../../source-script.js"></script><script defer src="../../source-files.js"></script><script defer src="../../main.js"></script><noscript><link rel="stylesheet" href="../../noscript.css"></noscript><link rel="alternate icon" type="image/png" href="../../favicon-16x16.png"><link rel="alternate icon" type="image/png" href="../../favicon-32x32.png"><link rel="icon" type="image/svg+xml" href="../../favicon.svg"></head><body class="rustdoc source"><!--[if lte IE 11]><div class="warning">This old browser is unsupported and will most likely display funky things.</div><![endif]--><nav 
class="sidebar"><a class="sidebar-logo" href="../../matrixmultiply/index.html"><div class="logo-container"><img class="rust-logo" src="../../rust-logo.svg" alt="logo"></div></a></nav><main><div class="width-limiter"><nav class="sub"><a class="sub-logo-container" href="../../matrixmultiply/index.html"><img class="rust-logo" src="../../rust-logo.svg" alt="logo"></a><form class="search-form"><div class="search-container"><span></span><input class="search-input" name="search" autocomplete="off" spellcheck="false" placeholder="Click or press ‘S’ to search, ‘?’ for more options…" type="search"><div id="help-button" title="help" tabindex="-1"><a href="../../help.html">?</a></div><div id="settings-menu" tabindex="-1"><a href="../../settings.html" title="settings"><img width="22" height="22" alt="Change settings" src="../../wheel.svg"></a></div></div></form></nav><section id="main-content" class="content"><div class="example-wrap"><pre class="src-line-numbers"><span id="1">1</span>
<span id="2">2</span>
<span id="3">3</span>
<span id="4">4</span>
<span id="5">5</span>
<span id="6">6</span>
<span id="7">7</span>
<span id="8">8</span>
<span id="9">9</span>
<span id="10">10</span>
<span id="11">11</span>
<span id="12">12</span>
<span id="13">13</span>
<span id="14">14</span>
<span id="15">15</span>
<span id="16">16</span>
<span id="17">17</span>
<span id="18">18</span>
<span id="19">19</span>
<span id="20">20</span>
<span id="21">21</span>
<span id="22">22</span>
<span id="23">23</span>
<span id="24">24</span>
<span id="25">25</span>
<span id="26">26</span>
<span id="27">27</span>
<span id="28">28</span>
<span id="29">29</span>
<span id="30">30</span>
<span id="31">31</span>
<span id="32">32</span>
<span id="33">33</span>
<span id="34">34</span>
<span id="35">35</span>
<span id="36">36</span>
<span id="37">37</span>
<span id="38">38</span>
<span id="39">39</span>
<span id="40">40</span>
<span id="41">41</span>
<span id="42">42</span>
<span id="43">43</span>
<span id="44">44</span>
<span id="45">45</span>
<span id="46">46</span>
<span id="47">47</span>
<span id="48">48</span>
<span id="49">49</span>
<span id="50">50</span>
<span id="51">51</span>
<span id="52">52</span>
<span id="53">53</span>
<span id="54">54</span>
<span id="55">55</span>
<span id="56">56</span>
<span id="57">57</span>
<span id="58">58</span>
<span id="59">59</span>
<span id="60">60</span>
<span id="61">61</span>
<span id="62">62</span>
<span id="63">63</span>
<span id="64">64</span>
<span id="65">65</span>
<span id="66">66</span>
<span id="67">67</span>
<span id="68">68</span>
<span id="69">69</span>
<span id="70">70</span>
<span id="71">71</span>
<span id="72">72</span>
<span id="73">73</span>
<span id="74">74</span>
<span id="75">75</span>
<span id="76">76</span>
<span id="77">77</span>
<span id="78">78</span>
<span id="79">79</span>
<span id="80">80</span>
<span id="81">81</span>
<span id="82">82</span>
<span id="83">83</span>
<span id="84">84</span>
<span id="85">85</span>
<span id="86">86</span>
<span id="87">87</span>
<span id="88">88</span>
<span id="89">89</span>
<span id="90">90</span>
<span id="91">91</span>
<span id="92">92</span>
<span id="93">93</span>
<span id="94">94</span>
<span id="95">95</span>
<span id="96">96</span>
<span id="97">97</span>
<span id="98">98</span>
<span id="99">99</span>
<span id="100">100</span>
<span id="101">101</span>
<span id="102">102</span>
<span id="103">103</span>
<span id="104">104</span>
<span id="105">105</span>
<span id="106">106</span>
<span id="107">107</span>
<span id="108">108</span>
<span id="109">109</span>
<span id="110">110</span>
<span id="111">111</span>
<span id="112">112</span>
<span id="113">113</span>
<span id="114">114</span>
<span id="115">115</span>
<span id="116">116</span>
<span id="117">117</span>
<span id="118">118</span>
<span id="119">119</span>
<span id="120">120</span>
<span id="121">121</span>
<span id="122">122</span>
<span id="123">123</span>
<span id="124">124</span>
<span id="125">125</span>
<span id="126">126</span>
<span id="127">127</span>
<span id="128">128</span>
<span id="129">129</span>
<span id="130">130</span>
<span id="131">131</span>
<span id="132">132</span>
<span id="133">133</span>
<span id="134">134</span>
<span id="135">135</span>
<span id="136">136</span>
<span id="137">137</span>
<span id="138">138</span>
<span id="139">139</span>
<span id="140">140</span>
<span id="141">141</span>
<span id="142">142</span>
<span id="143">143</span>
<span id="144">144</span>
<span id="145">145</span>
<span id="146">146</span>
<span id="147">147</span>
<span id="148">148</span>
<span id="149">149</span>
<span id="150">150</span>
<span id="151">151</span>
<span id="152">152</span>
<span id="153">153</span>
<span id="154">154</span>
<span id="155">155</span>
<span id="156">156</span>
<span id="157">157</span>
<span id="158">158</span>
<span id="159">159</span>
<span id="160">160</span>
<span id="161">161</span>
<span id="162">162</span>
<span id="163">163</span>
<span id="164">164</span>
<span id="165">165</span>
<span id="166">166</span>
<span id="167">167</span>
<span id="168">168</span>
<span id="169">169</span>
<span id="170">170</span>
<span id="171">171</span>
<span id="172">172</span>
<span id="173">173</span>
<span id="174">174</span>
<span id="175">175</span>
<span id="176">176</span>
<span id="177">177</span>
<span id="178">178</span>
<span id="179">179</span>
<span id="180">180</span>
<span id="181">181</span>
<span id="182">182</span>
<span id="183">183</span>
<span id="184">184</span>
<span id="185">185</span>
<span id="186">186</span>
<span id="187">187</span>
<span id="188">188</span>
<span id="189">189</span>
<span id="190">190</span>
<span id="191">191</span>
<span id="192">192</span>
<span id="193">193</span>
<span id="194">194</span>
<span id="195">195</span>
<span id="196">196</span>
<span id="197">197</span>
<span id="198">198</span>
<span id="199">199</span>
<span id="200">200</span>
<span id="201">201</span>
<span id="202">202</span>
<span id="203">203</span>
<span id="204">204</span>
<span id="205">205</span>
<span id="206">206</span>
<span id="207">207</span>
<span id="208">208</span>
<span id="209">209</span>
<span id="210">210</span>
<span id="211">211</span>
<span id="212">212</span>
<span id="213">213</span>
<span id="214">214</span>
<span id="215">215</span>
<span id="216">216</span>
<span id="217">217</span>
<span id="218">218</span>
<span id="219">219</span>
<span id="220">220</span>
<span id="221">221</span>
<span id="222">222</span>
<span id="223">223</span>
<span id="224">224</span>
<span id="225">225</span>
<span id="226">226</span>
<span id="227">227</span>
<span id="228">228</span>
<span id="229">229</span>
<span id="230">230</span>
<span id="231">231</span>
<span id="232">232</span>
<span id="233">233</span>
<span id="234">234</span>
<span id="235">235</span>
<span id="236">236</span>
<span id="237">237</span>
<span id="238">238</span>
<span id="239">239</span>
<span id="240">240</span>
<span id="241">241</span>
<span id="242">242</span>
<span id="243">243</span>
<span id="244">244</span>
<span id="245">245</span>
<span id="246">246</span>
<span id="247">247</span>
<span id="248">248</span>
<span id="249">249</span>
<span id="250">250</span>
<span id="251">251</span>
<span id="252">252</span>
<span id="253">253</span>
<span id="254">254</span>
<span id="255">255</span>
<span id="256">256</span>
<span id="257">257</span>
<span id="258">258</span>
<span id="259">259</span>
<span id="260">260</span>
<span id="261">261</span>
<span id="262">262</span>
<span id="263">263</span>
<span id="264">264</span>
<span id="265">265</span>
<span id="266">266</span>
<span id="267">267</span>
<span id="268">268</span>
<span id="269">269</span>
<span id="270">270</span>
<span id="271">271</span>
<span id="272">272</span>
<span id="273">273</span>
<span id="274">274</span>
<span id="275">275</span>
<span id="276">276</span>
<span id="277">277</span>
<span id="278">278</span>
<span id="279">279</span>
<span id="280">280</span>
<span id="281">281</span>
<span id="282">282</span>
<span id="283">283</span>
<span id="284">284</span>
<span id="285">285</span>
<span id="286">286</span>
<span id="287">287</span>
<span id="288">288</span>
<span id="289">289</span>
<span id="290">290</span>
<span id="291">291</span>
<span id="292">292</span>
<span id="293">293</span>
<span id="294">294</span>
<span id="295">295</span>
<span id="296">296</span>
<span id="297">297</span>
<span id="298">298</span>
<span id="299">299</span>
<span id="300">300</span>
<span id="301">301</span>
<span id="302">302</span>
<span id="303">303</span>
<span id="304">304</span>
<span id="305">305</span>
<span id="306">306</span>
<span id="307">307</span>
<span id="308">308</span>
<span id="309">309</span>
<span id="310">310</span>
<span id="311">311</span>
<span id="312">312</span>
<span id="313">313</span>
<span id="314">314</span>
<span id="315">315</span>
<span id="316">316</span>
<span id="317">317</span>
<span id="318">318</span>
<span id="319">319</span>
<span id="320">320</span>
<span id="321">321</span>
<span id="322">322</span>
<span id="323">323</span>
<span id="324">324</span>
<span id="325">325</span>
<span id="326">326</span>
<span id="327">327</span>
<span id="328">328</span>
<span id="329">329</span>
<span id="330">330</span>
<span id="331">331</span>
<span id="332">332</span>
<span id="333">333</span>
<span id="334">334</span>
<span id="335">335</span>
<span id="336">336</span>
<span id="337">337</span>
<span id="338">338</span>
<span id="339">339</span>
<span id="340">340</span>
<span id="341">341</span>
<span id="342">342</span>
<span id="343">343</span>
<span id="344">344</span>
<span id="345">345</span>
<span id="346">346</span>
<span id="347">347</span>
<span id="348">348</span>
<span id="349">349</span>
<span id="350">350</span>
<span id="351">351</span>
<span id="352">352</span>
<span id="353">353</span>
<span id="354">354</span>
<span id="355">355</span>
<span id="356">356</span>
<span id="357">357</span>
<span id="358">358</span>
<span id="359">359</span>
<span id="360">360</span>
<span id="361">361</span>
<span id="362">362</span>
<span id="363">363</span>
<span id="364">364</span>
<span id="365">365</span>
</pre><pre class="rust"><code><span class="comment">// Copyright 2016 bluss
//
// Licensed under the Apache License, Version 2.0 &lt;LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0&gt; or the MIT license
// &lt;LICENSE-MIT or http://opensource.org/licenses/MIT&gt;, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
</span><span class="kw">use </span>std::cmp::min;
<span class="kw">use </span>std::mem::size_of;
<span class="kw">use </span>util::range_chunk;
<span class="kw">use </span>util::round_up_to;
<span class="kw">use </span>kernel::GemmKernel;
<span class="kw">use </span>kernel::Element;
<span class="kw">use </span>sgemm_kernel;
<span class="kw">use </span>dgemm_kernel;
<span class="kw">use </span>rawpointer::PointerExt;
<span class="doccomment">/// General matrix multiplication (f32)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs&lt;em&gt;x&lt;/em&gt;: row stride of *x*
/// + cs&lt;em&gt;x&lt;/em&gt;: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other, for example they can not be zero.
///
/// If β is zero, then C does not need to be initialized.
///
/// This is a thin wrapper: every argument is forwarded unchanged to
/// `gemm_loop` instantiated with `sgemm_kernel::Gemm`.
///
/// The strides of C are applied with pointer `offset` arithmetic in
/// `gemm_loop`, so they count elements rather than bytes. The strides of A
/// and B are presumably in the same unit (TODO confirm against
/// `PointerExt::stride_offset`).
///
/// # Safety
///
/// The raw pointers are dereferenced without any bounds checks. The caller
/// must guarantee that every element position addressed through the given
/// dimensions and strides is valid to read for `a` and `b` and valid to
/// write for `c`. Unless β is zero, the existing contents of `c` are also
/// read.
</span><span class="kw">pub unsafe fn </span>sgemm(
m: usize, k: usize, n: usize,
alpha: f32,
a: <span class="kw-2">*const </span>f32, rsa: isize, csa: isize,
b: <span class="kw-2">*const </span>f32, rsb: isize, csb: isize,
beta: f32,
c: <span class="kw-2">*mut </span>f32, rsc: isize, csc: isize)
{
gemm_loop::&lt;sgemm_kernel::Gemm&gt;(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}
<span class="doccomment">/// General matrix multiplication (f64)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs&lt;em&gt;x&lt;/em&gt;: row stride of *x*
/// + cs&lt;em&gt;x&lt;/em&gt;: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other, for example they can not be zero.
///
/// If β is zero, then C does not need to be initialized.
///
/// This is a thin wrapper: every argument is forwarded unchanged to
/// `gemm_loop` instantiated with `dgemm_kernel::Gemm`.
///
/// The strides of C are applied with pointer `offset` arithmetic in
/// `gemm_loop`, so they count elements rather than bytes. The strides of A
/// and B are presumably in the same unit (TODO confirm against
/// `PointerExt::stride_offset`).
///
/// # Safety
///
/// The raw pointers are dereferenced without any bounds checks. The caller
/// must guarantee that every element position addressed through the given
/// dimensions and strides is valid to read for `a` and `b` and valid to
/// write for `c`. Unless β is zero, the existing contents of `c` are also
/// read.
</span><span class="kw">pub unsafe fn </span>dgemm(
m: usize, k: usize, n: usize,
alpha: f64,
a: <span class="kw-2">*const </span>f64, rsa: isize, csa: isize,
b: <span class="kw-2">*const </span>f64, rsb: isize, csb: isize,
beta: f64,
c: <span class="kw-2">*mut </span>f64, rsc: isize, csc: isize)
{
gemm_loop::&lt;dgemm_kernel::Gemm&gt;(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}
<span class="doccomment">/// Ensure that GemmKernel parameters are supported
/// (alignment, microkernel size).
///
/// This function is optimized out for a supported configuration.
///
/// The checks are:
///
/// + `K::mr()` and `K::nr()` are each in the range 1 to 8 inclusive
/// + one `mr` by `nr` tile of `K::Elem` occupies at most 8 * 4 * 8 = 256
///   bytes, the size of the mask buffer that `gemm_packed` places on the
///   stack
/// + `K::align_to()` is at most 32
/// + `K::align_to()` is at most the size of one element times `min(mr, nr)`
///
/// # Panics
///
/// Panics (via `assert!`) if any of the checks above fails.
</span><span class="attribute">#[inline(always)]
</span><span class="kw">fn </span>ensure_kernel_params&lt;K&gt;()
<span class="kw">where </span>K: GemmKernel
{
<span class="kw">let </span>mr = K::mr();
<span class="kw">let </span>nr = K::nr();
<span class="macro">assert!</span>(mr &gt; <span class="number">0 </span>&amp;&amp; mr &lt;= <span class="number">8</span>);
<span class="macro">assert!</span>(nr &gt; <span class="number">0 </span>&amp;&amp; nr &lt;= <span class="number">8</span>);
<span class="macro">assert!</span>(mr * nr * size_of::&lt;K::Elem&gt;() &lt;= <span class="number">8 </span>* <span class="number">4 </span>* <span class="number">8</span>);
<span class="macro">assert!</span>(K::align_to() &lt;= <span class="number">32</span>);
<span class="comment">// one row/col of the kernel is limiting the max align we can provide
// (the bound is the byte size of the shorter of the two kernel dimensions)
</span><span class="kw">let </span>max_align = size_of::&lt;K::Elem&gt;() * min(mr, nr);
<span class="macro">assert!</span>(K::align_to() &lt;= max_align);
}
<span class="doccomment">/// Implement matrix multiply using packed buffers and a microkernel
/// strategy, the type parameter `K` is the gemm microkernel.
///
/// Computes C ← α A B + β C for an m by k matrix A, a k by n matrix B and
/// an m by n matrix C, each given as a raw pointer plus row/column strides.
///
/// Outline:
///
/// 1. If any of m, k, n is zero there is no product to add; C is only
///    scaled by β (or set to zero when β is zero) and the function returns.
/// 2. `ensure_kernel_params` validates the kernel, then one buffer is
///    obtained from `packing_vec` and aligned with `make_aligned_vec_ptr`.
///    Packed A (`app`) starts at the aligned pointer and packed B (`bpp`)
///    starts `bp_offset` elements after it.
/// 3. Loops 5, 4 and 3 walk over n, k and m in steps of `K::nc()`,
///    `K::kc()` and `K::mc()`. `range_chunk` is used here as yielding
///    (chunk index, chunk length) pairs. B is packed once per (n, k) chunk
///    and A once per (n, k, m) chunk, after which `gemm_packed` runs loops
///    2 and 1.
///
/// The two `debug_assert!`s check that the row (column) stride of C is
/// non-zero whenever C has more than one row (column) and is not empty.
///
/// `dprint!` and `debug!` are macros defined outside this file; the `debug!`
/// call presumably overwrites the packing buffer with ones in debug
/// configurations only (TODO confirm).
</span><span class="kw">unsafe fn </span>gemm_loop&lt;K&gt;(
m: usize, k: usize, n: usize,
alpha: K::Elem,
a: <span class="kw-2">*const </span>K::Elem, rsa: isize, csa: isize,
b: <span class="kw-2">*const </span>K::Elem, rsb: isize, csb: isize,
beta: K::Elem,
c: <span class="kw-2">*mut </span>K::Elem, rsc: isize, csc: isize)
<span class="kw">where </span>K: GemmKernel
{
<span class="macro">debug_assert!</span>(m &lt;= <span class="number">1 </span>|| n == <span class="number">0 </span>|| rsc != <span class="number">0</span>);
<span class="macro">debug_assert!</span>(m == <span class="number">0 </span>|| n &lt;= <span class="number">1 </span>|| csc != <span class="number">0</span>);
<span class="comment">// if A or B have no elements, compute C ← βC and return
// (when m or n is zero the loops below do not execute at all, so only the
// k == 0 case actually touches C)
</span><span class="kw">if </span>m == <span class="number">0 </span>|| k == <span class="number">0 </span>|| n == <span class="number">0 </span>{
<span class="kw">for </span>i <span class="kw">in </span><span class="number">0</span>..m {
<span class="kw">for </span>j <span class="kw">in </span><span class="number">0</span>..n {
<span class="kw">let </span>cptr = c.offset(rsc * i <span class="kw">as </span>isize + csc * j <span class="kw">as </span>isize);
<span class="kw">if </span>beta.is_zero() {
<span class="kw-2">*</span>cptr = K::Elem::zero(); <span class="comment">// initialize C
</span>} <span class="kw">else </span>{
(<span class="kw-2">*</span>cptr).scale_by(beta);
}
}
}
<span class="kw">return</span>;
}
<span class="kw">let </span>knc = K::nc();
<span class="kw">let </span>kkc = K::kc();
<span class="kw">let </span>kmc = K::mc();
ensure_kernel_params::&lt;K&gt;();
<span class="kw">let </span>(<span class="kw-2">mut </span>packv, bp_offset) = packing_vec::&lt;K&gt;(m, k, n);
<span class="kw">let </span>app = make_aligned_vec_ptr(K::align_to(), <span class="kw-2">&amp;mut </span>packv);
<span class="kw">let </span>bpp = app.offset(bp_offset);
<span class="comment">// LOOP 5: split n into nc parts
// (b and c are advanced along their column stride to the current chunk)
</span><span class="kw">for </span>(l5, nc) <span class="kw">in </span>range_chunk(n, knc) {
<span class="macro">dprint!</span>(<span class="string">&quot;LOOP 5, {}, nc={}&quot;</span>, l5, nc);
<span class="kw">let </span>b = b.stride_offset(csb, knc * l5);
<span class="kw">let </span>c = c.stride_offset(csc, knc * l5);
<span class="comment">// LOOP 4: split k in kc parts
// (b is advanced along its row stride and a along its column stride)
</span><span class="kw">for </span>(l4, kc) <span class="kw">in </span>range_chunk(k, kkc) {
<span class="macro">dprint!</span>(<span class="string">&quot;LOOP 4, {}, kc={}&quot;</span>, l4, kc);
<span class="kw">let </span>b = b.stride_offset(rsb, kkc * l4);
<span class="kw">let </span>a = a.stride_offset(csa, kkc * l4);
<span class="macro">debug!</span>(<span class="kw">for </span>elt <span class="kw">in </span><span class="kw-2">&amp;mut </span>packv { <span class="kw-2">*</span>elt = &lt;<span class="kw">_</span>&gt;::one(); });
<span class="comment">// Pack B -&gt; B~
// The stride arguments are swapped relative to the A call below, so each
// micropanel of B~ groups nr columns of B.
</span>pack(kc, nc, K::nr(), bpp, b, csb, rsb);
<span class="comment">// LOOP 3: split m into mc parts
// (a and c are advanced along their row stride to the current chunk)
</span><span class="kw">for </span>(l3, mc) <span class="kw">in </span>range_chunk(m, kmc) {
<span class="macro">dprint!</span>(<span class="string">&quot;LOOP 3, {}, mc={}&quot;</span>, l3, mc);
<span class="kw">let </span>a = a.stride_offset(rsa, kmc * l3);
<span class="kw">let </span>c = c.stride_offset(rsc, kmc * l3);
<span class="comment">// Pack A -&gt; A~
</span>pack(kc, mc, K::mr(), app, a, rsa, csa);
<span class="comment">// First time writing to C, use user&#39;s `beta`, else accumulate
// (l4 is the index of the current k chunk; for later chunks C already
// holds a partial result, so it is kept with a factor of one)
</span><span class="kw">let </span>betap = <span class="kw">if </span>l4 == <span class="number">0 </span>{ beta } <span class="kw">else </span>{ &lt;<span class="kw">_</span>&gt;::one() };
<span class="comment">// LOOP 2 and 1
</span>gemm_packed::&lt;K&gt;(nc, kc, mc,
alpha,
app, bpp,
betap,
c, rsc, csc);
}
}
}
}
<span class="doccomment">/// Loops 1 and 2 around the µ-kernel
///
/// + app: packed A (A~)
/// + bpp: packed B (B~)
/// + nc: columns of packed B
/// + kc: columns of packed A / rows of packed B
/// + mc: rows of packed A
/// + alpha, beta: scalars passed on to the kernel for every tile
/// + c, rsc, csc: pointer to the first element of the C block to update,
///   with its row and column strides
///
/// Loop 2 steps through packed B one micropanel (`kc * nr` elements) at a
/// time and loop 1 steps through packed A one micropanel (`kc * mr`
/// elements) at a time; `c` is moved by `nr` columns and `mr` rows to the
/// matching tile. Tiles that are smaller than `mr` by `nr`, or all tiles
/// when `K::always_masked()` is true, go through `masked_kernel` with a
/// scratch buffer on the stack; full tiles call `K::kernel` directly on C.
///
/// NOTE(review): the `continue;` after the `masked_kernel` call is the last
/// statement executed in the loop body on that path, so it appears to have
/// no effect.
</span><span class="kw">unsafe fn </span>gemm_packed&lt;K&gt;(nc: usize, kc: usize, mc: usize,
alpha: K::Elem,
app: <span class="kw-2">*const </span>K::Elem, bpp: <span class="kw-2">*const </span>K::Elem,
beta: K::Elem,
c: <span class="kw-2">*mut </span>K::Elem, rsc: isize, csc: isize)
<span class="kw">where </span>K: GemmKernel,
{
<span class="kw">let </span>mr = K::mr();
<span class="kw">let </span>nr = K::nr();
<span class="comment">// make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment
// (256 bytes for the tile plus 31 spare bytes so that the start can be
// moved forwards to a 32-byte boundary)
</span><span class="macro">assert!</span>(mr * nr * size_of::&lt;K::Elem&gt;() &lt;= <span class="number">256 </span>&amp;&amp; K::align_to() &lt;= <span class="number">32</span>);
<span class="kw">let </span><span class="kw-2">mut </span>mask_buf = [<span class="number">0u8</span>; <span class="number">256 </span>+ <span class="number">31</span>];
<span class="kw">let </span>mask_ptr = align_ptr(<span class="number">32</span>, mask_buf.as_mut_ptr()) <span class="kw">as </span><span class="kw-2">*mut </span>K::Elem;
<span class="comment">// LOOP 2: through micropanels in packed `b`
// (nr_ is the number of valid columns in this micropanel)
</span><span class="kw">for </span>(l2, nr_) <span class="kw">in </span>range_chunk(nc, nr) {
<span class="kw">let </span>bpp = bpp.stride_offset(<span class="number">1</span>, kc * nr * l2);
<span class="kw">let </span>c = c.stride_offset(csc, nr * l2);
<span class="comment">// LOOP 1: through micropanels in packed `a` while `b` is constant
// (mr_ is the number of valid rows in this micropanel)
</span><span class="kw">for </span>(l1, mr_) <span class="kw">in </span>range_chunk(mc, mr) {
<span class="kw">let </span>app = app.stride_offset(<span class="number">1</span>, kc * mr * l1);
<span class="kw">let </span>c = c.stride_offset(rsc, mr * l1);
<span class="comment">// GEMM KERNEL
// NOTE: For the rust kernels, it performs better to simply
// always use the masked kernel function!
// The masked path is also required for partial tiles, where writing a
// full mr by nr tile would go past the valid rows or columns of C.
</span><span class="kw">if </span>K::always_masked() || nr_ &lt; nr || mr_ &lt; mr {
masked_kernel::&lt;<span class="kw">_</span>, K&gt;(kc, alpha, <span class="kw-2">&amp;*</span>app, <span class="kw-2">&amp;*</span>bpp,
beta, <span class="kw-2">&amp;mut *</span>c, rsc, csc,
mr_, nr_, mask_ptr);
<span class="kw">continue</span>;
} <span class="kw">else </span>{
K::kernel(kc, alpha, app, bpp, beta, c, rsc, csc);
}
}
}
}
<span class="doccomment">/// Allocate a vector of uninitialized data to be used for both packing buffers.
///
/// + A~ needs to be KC x MC
/// + B~ needs to be KC x NC
/// but we can make them smaller if the matrix is smaller than this (just ensure
/// we have rounded up to a multiple of the kernel size).
///
/// The dimensions are first clamped to `K::mc()`, `K::kc()` and `K::nc()`.
/// The A~ part then holds `k * round_up_to(m, K::mr())` elements and the B~
/// part `k * round_up_to(n, K::nr())` elements, stored back to back in one
/// vector.
///
/// Return packing vector and offset to start of b
/// (the offset is the element count of the A~ part).
///
/// NOTE(review): the vector length is set with `set_len` directly after
/// `Vec::with_capacity`, so its elements are never initialized here; the
/// caller is expected to write them (via `pack`) before they are read.
</span><span class="kw">unsafe fn </span>packing_vec&lt;K&gt;(m: usize, k: usize, n: usize) -&gt; (Vec&lt;K::Elem&gt;, isize)
<span class="kw">where </span>K: GemmKernel,
{
<span class="kw">let </span>m = min(m, K::mc());
<span class="kw">let </span>k = min(k, K::kc());
<span class="kw">let </span>n = min(n, K::nc());
<span class="comment">// round up m, n to multiples of mr, nr
// k is used as clamped above and is not rounded
</span><span class="kw">let </span>apack_size = k * round_up_to(m, K::mr());
<span class="kw">let </span>bpack_size = k * round_up_to(n, K::nr());
<span class="kw">let </span>nelem = apack_size + bpack_size;
<span class="kw">let </span><span class="kw-2">mut </span>v = Vec::with_capacity(nelem);
v.set_len(nelem);
<span class="macro">dprint!</span>(<span class="string">&quot;packed nelem={}, apack={}, bpack={},
m={} k={} n={}&quot;</span>,
nelem, apack_size, bpack_size,
m,k,n);
<span class="comment">// max alignment requirement is a multiple of min(MR, NR) * sizeof&lt;Elem&gt;
// because apack_size is a multiple of MR, start of b aligns fine
</span>(v, apack_size <span class="kw">as </span>isize)
}
<span class="doccomment">/// Align a pointer into the vec. Will reallocate to fit &amp; shift the pointer
/// forwards if needed. This invalidates any previous pointers into the v.
///
/// + align_to: requested alignment in bytes; zero means no requirement
/// + v: vector whose buffer the returned pointer points into
///
/// If the start of the buffer is already aligned (or `align_to` is zero) it
/// is returned unchanged. Otherwise the vector is asked for just enough
/// spare capacity to cover the worst-case forward shift, and the pointer is
/// moved forwards with `align_ptr`.
///
/// Assuming `align_to` is a multiple of the element size, `v.len()` elements
/// starting at the returned pointer lie inside the vector&#39;s capacity.
</span><span class="kw">unsafe fn </span>make_aligned_vec_ptr&lt;U&gt;(align_to: usize, v: <span class="kw-2">&amp;mut </span>Vec&lt;U&gt;) -&gt; <span class="kw-2">*mut </span>U {
<span class="kw">let </span><span class="kw-2">mut </span>ptr = v.as_mut_ptr();
<span class="kw">if </span>align_to != <span class="number">0 </span>{
<span class="kw">if </span>v.as_ptr() <span class="kw">as </span>usize % align_to != <span class="number">0 </span>{
<span class="comment">// `reserve_exact` takes the number of *additional* elements beyond
// `len`, so only the worst-case shift (one element less than a full
// alignment unit) is requested. Adding the current capacity on top, as
// was done before, asked for roughly twice the buffer.
</span>v.reserve_exact(align_to / size_of::&lt;U&gt;() - <span class="number">1</span>);
<span class="comment">// the buffer may have moved, so align the fresh start pointer
</span>ptr = align_ptr(align_to, v.as_mut_ptr());
}
}
ptr
}
<span class="doccomment">/// offset the ptr forwards to align to a specific byte count
///
/// + align_to: requested alignment in bytes; zero leaves `ptr` unchanged
/// + ptr: pointer to move forwards
///
/// The distance to the next multiple of `align_to` is computed in bytes and
/// converted to a whole number of elements with an integer division by the
/// element size, because `offset` counts elements. If `ptr` is already
/// aligned it is returned unchanged.
///
/// NOTE(review): when the byte distance is not a multiple of the element
/// size the division truncates and the result is not aligned; the callers in
/// this file presumably only pass combinations where it divides evenly.
</span><span class="kw">unsafe fn </span>align_ptr&lt;U&gt;(align_to: usize, <span class="kw-2">mut </span>ptr: <span class="kw-2">*mut </span>U) -&gt; <span class="kw-2">*mut </span>U {
<span class="kw">if </span>align_to != <span class="number">0 </span>{
<span class="kw">let </span>cur_align = ptr <span class="kw">as </span>usize % align_to;
<span class="kw">if </span>cur_align != <span class="number">0 </span>{
ptr = ptr.offset(((align_to - cur_align) / size_of::&lt;U&gt;()) <span class="kw">as </span>isize);
}
}
ptr
}
<span class="doccomment">/// Pack matrix into `pack`
///
/// + kc: length of the micropanel
/// + mc: number of rows/columns in the matrix to be packed
/// + mr: kernel rows/columns that we round up to
/// + pack: packing buffer
/// + a: matrix,
/// + rsa: row stride
/// + csa: column stride
///
/// Layout written to `pack`: the `mc` rows are taken in groups of `mr`. For
/// each group, and for each `j` in `0..kc`, the `mr` values at (row `i` of
/// the group, column `j`) are stored one after another. A final partial
/// group of `mc % mr` rows is stored the same way with the missing rows
/// written as zero, so the buffer always holds whole `mr` wide micropanels.
///
/// The function is also used with the two strides swapped, in which case
/// &quot;rows&quot; above are the columns of the source matrix.
///
/// `pack.inc()` comes from `PointerExt`; it presumably advances the pointer
/// by one element (TODO confirm).
</span><span class="kw">unsafe fn </span>pack&lt;T&gt;(kc: usize, mc: usize, mr: usize, pack: <span class="kw-2">*mut </span>T,
a: <span class="kw-2">*const </span>T, rsa: isize, csa: isize)
<span class="kw">where </span>T: Element
{
<span class="kw">let </span><span class="kw-2">mut </span>pack = pack;
<span class="kw">for </span>ir <span class="kw">in </span><span class="number">0</span>..mc/mr {
<span class="kw">let </span>row_offset = ir * mr;
<span class="kw">for </span>j <span class="kw">in </span><span class="number">0</span>..kc {
<span class="kw">for </span>i <span class="kw">in </span><span class="number">0</span>..mr {
<span class="kw-2">*</span>pack = <span class="kw-2">*</span>a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
pack.inc();
}
}
}
<span class="kw">let </span>zero = &lt;<span class="kw">_</span>&gt;::zero();
<span class="comment">// Pad with zeros to multiple of kernel size (uneven mc)
// `rest` is the number of real rows in the last, partial group; rows from
// `rest` up to `mr` do not exist in the source and are written as zero.
</span><span class="kw">let </span>rest = mc % mr;
<span class="kw">if </span>rest &gt; <span class="number">0 </span>{
<span class="kw">let </span>row_offset = (mc/mr) * mr;
<span class="kw">for </span>j <span class="kw">in </span><span class="number">0</span>..kc {
<span class="kw">for </span>i <span class="kw">in </span><span class="number">0</span>..mr {
<span class="kw">if </span>i &lt; rest {
<span class="kw-2">*</span>pack = <span class="kw-2">*</span>a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
} <span class="kw">else </span>{
<span class="kw-2">*</span>pack = zero;
}
pack.inc();
}
}
}
}
<span class="doccomment">/// Call the GEMM kernel with a &quot;masked&quot; output C.
///
/// Simply redirect the MR by NR kernel output to the passed
/// in `mask_buf`, and copy the non masked region to the real
/// C.
///
/// + k: length of the packed micropanels, passed on to the kernel
/// + alpha, beta: scalars of C ← α A B + β C, applied here rather than in
///   the kernel
/// + a, b: packed micropanels of A and B
/// + c, rsc, csc: first element of the destination tile in C and its strides
/// + rows: rows of kernel unmasked
/// + cols: cols of kernel unmasked
/// + mask_buf: scratch space with room for `K::mr() * K::nr()` elements
///
/// Only the positions with row index below `rows` and column index below
/// `cols` are written to C; the rest of the kernel output is discarded.
///
/// `scale_by` and `scaled_add` come from the `Element` trait, defined
/// outside this file; they presumably multiply in place and add a scaled
/// value in place (TODO confirm).
</span><span class="attribute">#[inline(never)]
</span><span class="kw">unsafe fn </span>masked_kernel&lt;T, K&gt;(k: usize, alpha: T,
a: <span class="kw-2">*const </span>T,
b: <span class="kw-2">*const </span>T,
beta: T,
c: <span class="kw-2">*mut </span>T, rsc: isize, csc: isize,
rows: usize, cols: usize,
mask_buf: <span class="kw-2">*mut </span>T)
<span class="kw">where </span>K: GemmKernel&lt;Elem=T&gt;, T: Element,
{
<span class="kw">let </span>mr = K::mr();
<span class="kw">let </span>nr = K::nr();
<span class="comment">// use column major order for `mask_buf`
// (row stride 1, column stride mr). The kernel is run with alpha = one and
// beta = zero, so `mask_buf` receives the plain product and does not need
// to be initialized; the caller&#39;s alpha and beta are applied in the loop
// below, which walks `mask_buf` in the same column major order.
</span>K::kernel(k, T::one(), a, b, T::zero(), mask_buf, <span class="number">1</span>, mr <span class="kw">as </span>isize);
<span class="kw">let </span><span class="kw-2">mut </span>ab = mask_buf;
<span class="kw">for </span>j <span class="kw">in </span><span class="number">0</span>..nr {
<span class="kw">for </span>i <span class="kw">in </span><span class="number">0</span>..mr {
<span class="kw">if </span>i &lt; rows &amp;&amp; j &lt; cols {
<span class="kw">let </span>cptr = c.offset(rsc * i <span class="kw">as </span>isize + csc * j <span class="kw">as </span>isize);
<span class="kw">if </span>beta.is_zero() {
<span class="kw-2">*</span>cptr = T::zero(); <span class="comment">// initialize C
</span>} <span class="kw">else </span>{
(<span class="kw-2">*</span>cptr).scale_by(beta);
}
(<span class="kw-2">*</span>cptr).scaled_add(alpha, <span class="kw-2">*</span>ab);
}
ab.inc();
}
}
}
</code></pre></div>
</section></div></main><div id="rustdoc-vars" data-root-path="../../" data-current-crate="matrixmultiply" data-themes="ayu,dark,light" data-resource-suffix="" data-rustdoc-version="1.66.0-nightly (5c8bff74b 2022-10-21)" ></div></body></html>