| <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "https://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> |
| <html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en"> |
| <head> |
| <!-- Encoding declaration: text/html is the registered media type (text/xhtml is not one). --> |
| <meta http-equiv="Content-Type" content="text/html;charset=UTF-8"/> |
| <!-- Requests the IE9 document mode from legacy Internet Explorer. --> |
| <meta http-equiv="X-UA-Compatible" content="IE=9"/> |
| <meta name="generator" content="Doxygen 1.8.17"/> |
| <meta name="viewport" content="width=device-width, initial-scale=1"/> |
| <title>mxnet: /work/mxnet/include/mxnet/engine.h Source File</title> |
| <!-- Stylesheets and scripts are listed in the generated order; jquery.js precedes the scripts loaded after it. --> |
| <link href="tabs.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript" src="jquery.js"></script> |
| <script type="text/javascript" src="dynsections.js"></script> |
| <link href="search/search.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript" src="search/searchdata.js"></script> |
| <script type="text/javascript" src="search/search.js"></script> |
| <link href="doxygen.css" rel="stylesheet" type="text/css" /> |
| </head> |
| <body> |
| <div id="top"><!-- do not remove this div, it is closed by doxygen! --> |
| <!-- Project banner. The table only positions the project name, so it is marked presentational for assistive technology. --> |
| <div id="titlearea"> |
| <table cellspacing="0" cellpadding="0" role="presentation"> |
| <tbody> |
| <tr style="height: 56px;"> |
| <td id="projectalign" style="padding-left: 0.5em;"> |
| <div id="projectname">mxnet |
| </div> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| <!-- end header part --> |
| <!-- Generated by Doxygen 1.8.17 --> |
| <script type="text/javascript"> |
| /* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */ |
| /* Global search controller; the inline handlers on #MSearchSelectWindow call its methods by this name. */ |
| var searchBox = new SearchBox('searchBox', 'search', false, 'Search'); |
| /* @license-end */ |
| </script> |
| <script type="text/javascript" src="menudata.js"></script> |
| <script type="text/javascript" src="menu.js"></script> |
| <script type="text/javascript"> |
| /* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */ |
| /* When the DOM is ready: build the menu, then register search initialisation as a further ready callback. */ |
| $(document).ready(function () { |
| initMenu('', true, false, 'search.php', 'Search'); |
| $(function () { init_search(); }); |
| }); |
| /* @license-end */</script> |
| <div id="main-nav"></div> |
| <!-- Popup listing the search filter options; mouse and key events are forwarded to the global searchBox object. --> |
| <div id="MSearchSelectWindow" onmouseover="return searchBox.OnSearchSelectShow()" onmouseout="return searchBox.OnSearchSelectHide()" onkeydown="return searchBox.OnSearchSelectKey(event)"> |
| </div> |
| |
| <!-- iframe showing the search results (closed by default). The title gives the frame an accessible name. --> |
| <div id="MSearchResultsWindow"> |
| <iframe src="javascript:void(0)" frameborder="0" name="MSearchResults" id="MSearchResults" title="Search results"> |
| </iframe> |
| </div> |
| |
| <!-- Breadcrumb trail to the directory containing this file: include, then mxnet. --> |
| <div id="nav-path" class="navpath"> |
| <ul> |
| <li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li> |
| <li class="navelem"><a class="el" href="dir_1143c7affb9ebd026cb6818dd282def7.html">mxnet</a></li> |
| </ul> |
| </div> |
| </div><!-- top --> |
| <!-- Page heading: the name of the file whose source is listed below. --> |
| <div class="header"> |
| <div class="headertitle"> |
| <div class="title">engine.h</div> |
| </div> |
| </div><!--header--> |
| <div class="contents"> |
| <a href="engine_8h.html">Go to the documentation of this file.</a><div class="fragment"><div class="line"><a name="l00001"></a><span class="lineno"> 1</span> <span class="comment">/*</span></div> |
| <div class="line"><a name="l00002"></a><span class="lineno"> 2</span> <span class="comment"> * Licensed to the Apache Software Foundation (ASF) under one</span></div> |
| <div class="line"><a name="l00003"></a><span class="lineno"> 3</span> <span class="comment"> * or more contributor license agreements. See the NOTICE file</span></div> |
| <div class="line"><a name="l00004"></a><span class="lineno"> 4</span> <span class="comment"> * distributed with this work for additional information</span></div> |
| <div class="line"><a name="l00005"></a><span class="lineno"> 5</span> <span class="comment"> * regarding copyright ownership. The ASF licenses this file</span></div> |
| <div class="line"><a name="l00006"></a><span class="lineno"> 6</span> <span class="comment"> * to you under the Apache License, Version 2.0 (the</span></div> |
| <div class="line"><a name="l00007"></a><span class="lineno"> 7</span> <span class="comment"> * "License"); you may not use this file except in compliance</span></div> |
| <div class="line"><a name="l00008"></a><span class="lineno"> 8</span> <span class="comment"> * with the License. You may obtain a copy of the License at</span></div> |
| <div class="line"><a name="l00009"></a><span class="lineno"> 9</span> <span class="comment"> *</span></div> |
| <div class="line"><a name="l00010"></a><span class="lineno"> 10</span> <span class="comment"> * http://www.apache.org/licenses/LICENSE-2.0</span></div> |
| <div class="line"><a name="l00011"></a><span class="lineno"> 11</span> <span class="comment"> *</span></div> |
| <div class="line"><a name="l00012"></a><span class="lineno"> 12</span> <span class="comment"> * Unless required by applicable law or agreed to in writing,</span></div> |
| <div class="line"><a name="l00013"></a><span class="lineno"> 13</span> <span class="comment"> * software distributed under the License is distributed on an</span></div> |
| <div class="line"><a name="l00014"></a><span class="lineno"> 14</span> <span class="comment"> * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY</span></div> |
| <div class="line"><a name="l00015"></a><span class="lineno"> 15</span> <span class="comment"> * KIND, either express or implied. See the License for the</span></div> |
| <div class="line"><a name="l00016"></a><span class="lineno"> 16</span> <span class="comment"> * specific language governing permissions and limitations</span></div> |
| <div class="line"><a name="l00017"></a><span class="lineno"> 17</span> <span class="comment"> * under the License.</span></div> |
| <div class="line"><a name="l00018"></a><span class="lineno"> 18</span> <span class="comment"> */</span></div> |
| <div class="line"><a name="l00019"></a><span class="lineno"> 19</span>  </div> |
| <div class="line"><a name="l00024"></a><span class="lineno"> 24</span> <span class="preprocessor">#ifndef MXNET_ENGINE_H_</span></div> |
| <div class="line"><a name="l00025"></a><span class="lineno"> 25</span> <span class="preprocessor">#define MXNET_ENGINE_H_</span></div> |
| <div class="line"><a name="l00026"></a><span class="lineno"> 26</span>  </div> |
| <div class="line"><a name="l00027"></a><span class="lineno"> 27</span> <span class="preprocessor">#if DMLC_USE_CXX11</span></div> |
| <div class="line"><a name="l00028"></a><span class="lineno"> 28</span> <span class="preprocessor">#include &lt;algorithm&gt;</span></div> |
| <div class="line"><a name="l00029"></a><span class="lineno"> 29</span> <span class="preprocessor">#include &lt;memory&gt;</span></div> |
| <div class="line"><a name="l00030"></a><span class="lineno"> 30</span> <span class="preprocessor">#include &lt;functional&gt;</span></div> |
| <div class="line"><a name="l00031"></a><span class="lineno"> 31</span> <span class="preprocessor">#endif</span></div> |
| <div class="line"><a name="l00032"></a><span class="lineno"> 32</span> <span class="preprocessor">#include &lt;utility&gt;</span></div> |
| <div class="line"><a name="l00033"></a><span class="lineno"> 33</span> <span class="preprocessor">#include &lt;vector&gt;</span></div> |
| <div class="line"><a name="l00034"></a><span class="lineno"> 34</span> <span class="preprocessor">#include "<a class="code" href="include_2mxnet_2base_8h.html">./base.h</a>"</span></div> |
| <div class="line"><a name="l00035"></a><span class="lineno"> 35</span>  </div> |
| <div class="line"><a name="l00036"></a><span class="lineno"> 36</span> <span class="keyword">namespace </span><a class="code" href="namespacemxnet.html">mxnet</a> {</div> |
| <div class="line"><a name="l00037"></a><span class="lineno"> 37</span>  </div> |
| <div class="line"><a name="l00038"></a><span class="lineno"> 38</span> <span class="comment">// forward declare engine</span></div> |
| <div class="line"><a name="l00039"></a><span class="lineno"> 39</span> <span class="keyword">class </span>Engine;</div> |
| <div class="line"><a name="l00040"></a><span class="lineno"> 40</span>  </div> |
| <div class="line"><a name="l00042"></a><span class="lineno"><a class="line" href="namespacemxnet_1_1engine.html"> 42</a></span> <span class="keyword">namespace </span>engine {</div> |
| <div class="line"><a name="l00043"></a><span class="lineno"> 43</span> <span class="preprocessor">#if MXNET_USE_CUDA</span></div> |
| <div class="line"><a name="l00044"></a><span class="lineno"> 44</span> <span class="comment">/* \brief The class wrapping CUDA event with timing disabled. */</span></div> |
| <div class="line"><a name="l00045"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEvent.html"> 45</a></span> <span class="keyword">class </span><a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html">CUDAEvent</a> final {</div> |
| <div class="line"><a name="l00046"></a><span class="lineno"> 46</span>  <span class="keyword">public</span>:</div> |
| <div class="line"><a name="l00047"></a><span class="lineno"> 47</span>  <span class="keyword">explicit</span> <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#ae6c6574f8191012ec061cd816a64ff20">CUDAEvent</a>(<a class="code" href="structmxnet_1_1Context.html">Context</a> <span class="keyword">const</span>& ctx);</div> |
| <div class="line"><a name="l00048"></a><span class="lineno"> 48</span>  </div> |
| <div class="line"><a name="l00049"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEvent.html#a0ff9222e821d258fadba48133a9d1eb6"> 49</a></span>  <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#a0ff9222e821d258fadba48133a9d1eb6">CUDAEvent</a>(<a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html">CUDAEvent</a>&& other) : event_(other.event_), dev_id_(other.dev_id_) {</div> |
| <div class="line"><a name="l00050"></a><span class="lineno"> 50</span>  other.event_ = <span class="keyword">nullptr</span>;</div> |
| <div class="line"><a name="l00051"></a><span class="lineno"> 51</span>  }</div> |
| <div class="line"><a name="l00052"></a><span class="lineno"> 52</span>  </div> |
| <div class="line"><a name="l00053"></a><span class="lineno"> 53</span>  <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#ae6c6574f8191012ec061cd816a64ff20">CUDAEvent</a>(<span class="keyword">const</span> <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html">CUDAEvent</a>& other) = <span class="keyword">delete</span>;</div> |
| <div class="line"><a name="l00054"></a><span class="lineno"> 54</span>  <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#a3743c3d6d08920567950e8c917b37c64">operator=</a>(<span class="keyword">const</span> <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html">CUDAEvent</a>& other) = <span class="keyword">delete</span>;</div> |
| <div class="line"><a name="l00055"></a><span class="lineno"> 55</span>  </div> |
| <div class="line"><a name="l00056"></a><span class="lineno"> 56</span>  <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#ab11e75549f6842a01ec4941dbb6a1c20">~CUDAEvent</a>();</div> |
| <div class="line"><a name="l00057"></a><span class="lineno"> 57</span>  </div> |
| <div class="line"><a name="l00058"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEvent.html#a53ea01451a92b23950a4ddce8c71cdb8"> 58</a></span>  <span class="keyword">inline</span> std::weak_ptr<cudaEvent_t> <a class="code" href="classmxnet_1_1engine_1_1CUDAEvent.html#a53ea01451a92b23950a4ddce8c71cdb8">GetEvent</a>() noexcept {</div> |
| <div class="line"><a name="l00059"></a><span class="lineno"> 59</span>  <span class="keywordflow">return</span> event_;</div> |
| <div class="line"><a name="l00060"></a><span class="lineno"> 60</span>  }</div> |
| <div class="line"><a name="l00061"></a><span class="lineno"> 61</span>  </div> |
| <div class="line"><a name="l00062"></a><span class="lineno"> 62</span>  <span class="keyword">private</span>:</div> |
| <div class="line"><a name="l00063"></a><span class="lineno"> 63</span>  std::shared_ptr&lt;cudaEvent_t&gt; event_;</div> |
| <div class="line"><a name="l00064"></a><span class="lineno"> 64</span>  <span class="keywordtype">int</span> dev_id_;</div> |
| <div class="line"><a name="l00065"></a><span class="lineno"> 65</span> };</div> |
| <div class="line"><a name="l00066"></a><span class="lineno"> 66</span>  </div> |
| <div class="line"><a name="l00067"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEventPool.html"> 67</a></span> <span class="keyword">class </span><a class="code" href="classmxnet_1_1engine_1_1CUDAEventPool.html">CUDAEventPool</a> final {</div> |
| <div class="line"><a name="l00068"></a><span class="lineno"> 68</span>  <span class="keyword">public</span>:</div> |
| <div class="line"><a name="l00069"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEventPool.html#aad9425ed11adf77479af5a666f14f38d"> 69</a></span>  <span class="keyword">explicit</span> <a class="code" href="classmxnet_1_1engine_1_1CUDAEventPool.html#aad9425ed11adf77479af5a666f14f38d">CUDAEventPool</a>(<a class="code" href="structmxnet_1_1Context.html">Context</a> <span class="keyword">const</span>& ctx) : counter_(0) {</div> |
| <div class="line"><a name="l00070"></a><span class="lineno"> 70</span>  <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> i = 0; i < kPoolSize; ++i) {</div> |
| <div class="line"><a name="l00071"></a><span class="lineno"> 71</span>  events_.emplace_back(ctx);</div> |
| <div class="line"><a name="l00072"></a><span class="lineno"> 72</span>  }</div> |
| <div class="line"><a name="l00073"></a><span class="lineno"> 73</span>  }</div> |
| <div class="line"><a name="l00074"></a><span class="lineno"> 74</span>  </div> |
| <div class="line"><a name="l00075"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a4e756deaf6e48abd5568cdcdd8f58140"> 75</a></span>  <span class="keyword">inline</span> std::weak_ptr<cudaEvent_t> <a class="code" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a4e756deaf6e48abd5568cdcdd8f58140">GetEvent</a>(<span class="keywordtype">size_t</span> i) noexcept {</div> |
| <div class="line"><a name="l00076"></a><span class="lineno"> 76</span>  <span class="keywordflow">return</span> events_.at(i).GetEvent();</div> |
| <div class="line"><a name="l00077"></a><span class="lineno"> 77</span>  }</div> |
| <div class="line"><a name="l00078"></a><span class="lineno"> 78</span>  </div> |
| <div class="line"><a name="l00079"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a3e8378ca127166af434256007c649f64"> 79</a></span>  <span class="keyword">inline</span> std::pair<std::weak_ptr<cudaEvent_t>, uint64_t> <a class="code" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a3e8378ca127166af434256007c649f64">GetNextEvent</a>() noexcept {</div> |
| <div class="line"><a name="l00080"></a><span class="lineno"> 80</span>  uint64_t c = counter_++;</div> |
| <div class="line"><a name="l00081"></a><span class="lineno"> 81</span>  <span class="keywordflow">return</span> {events_.at((c) % kPoolSize).GetEvent(), c};</div> |
| <div class="line"><a name="l00082"></a><span class="lineno"> 82</span>  }</div> |
| <div class="line"><a name="l00083"></a><span class="lineno"> 83</span>  </div> |
| <div class="line"><a name="l00084"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a7449bbaca7f9d382999930d38d2710ef"> 84</a></span>  <span class="keyword">inline</span> uint64_t <a class="code" href="classmxnet_1_1engine_1_1CUDAEventPool.html#a7449bbaca7f9d382999930d38d2710ef">GetCounterValue</a>() noexcept {</div> |
| <div class="line"><a name="l00085"></a><span class="lineno"> 85</span>  <span class="keywordflow">return</span> counter_.load();</div> |
| <div class="line"><a name="l00086"></a><span class="lineno"> 86</span>  }</div> |
| <div class="line"><a name="l00087"></a><span class="lineno"> 87</span>  </div> |
| <div class="line"><a name="l00088"></a><span class="lineno"> 88</span>  <span class="keyword">private</span>:</div> |
| <div class="line"><a name="l00089"></a><span class="lineno"> 89</span>  <span class="keyword">static</span> constexpr <span class="keywordtype">size_t</span> kPoolSize = 64;</div> |
| <div class="line"><a name="l00090"></a><span class="lineno"> 90</span>  std::vector&lt;CUDAEvent&gt; events_;</div> |
| <div class="line"><a name="l00091"></a><span class="lineno"> 91</span>  std::atomic&lt;uint64_t&gt; counter_;</div> |
| <div class="line"><a name="l00092"></a><span class="lineno"> 92</span> };</div> |
| <div class="line"><a name="l00093"></a><span class="lineno"> 93</span>  </div> |
| <div class="line"><a name="l00095"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1EventInfo.html"> 95</a></span> <span class="keyword">struct </span><a class="code" href="structmxnet_1_1engine_1_1EventInfo.html">EventInfo</a> {</div> |
| <div class="line"><a name="l00096"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1EventInfo.html#a1906da3e8772f14bae7e2a545e393263"> 96</a></span>  std::weak_ptr<cudaEvent_t> <a class="code" href="structmxnet_1_1engine_1_1EventInfo.html#a1906da3e8772f14bae7e2a545e393263">event</a>;</div> |
| <div class="line"><a name="l00097"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1EventInfo.html#a0c134825ed31fe9c210affafe29614ae"> 97</a></span>  cudaStream_t <a class="code" href="structmxnet_1_1engine_1_1EventInfo.html#a0c134825ed31fe9c210affafe29614ae">stream</a>;</div> |
| <div class="line"><a name="l00098"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1EventInfo.html#a275fb1793ee6bdcf4d78c1536b1dd95e"> 98</a></span>  uint64_t <a class="code" href="structmxnet_1_1engine_1_1EventInfo.html#a275fb1793ee6bdcf4d78c1536b1dd95e">pool_index</a>;</div> |
| <div class="line"><a name="l00099"></a><span class="lineno"> 99</span> };</div> |
| <div class="line"><a name="l00101"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1SyncObject.html"> 101</a></span> <span class="keyword">struct </span><a class="code" href="structmxnet_1_1engine_1_1SyncObject.html">SyncObject</a> {</div> |
| <div class="line"><a name="l00102"></a><span class="lineno"> 102</span>  <span class="comment">// vector can carry multiple reader events</span></div> |
| <div class="line"><a name="l00103"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1SyncObject.html#a0f52005af80e0b61afa08e1295830a0f"> 103</a></span>  std::vector<EventInfo> <a class="code" href="structmxnet_1_1engine_1_1SyncObject.html#a0f52005af80e0b61afa08e1295830a0f">reader_events</a>;</div> |
| <div class="line"><a name="l00104"></a><span class="lineno"> 104</span>  <span class="comment">// vector should carry only 1 writer event</span></div> |
| <div class="line"><a name="l00105"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1SyncObject.html#a4c2ae5228111a794e9651e5b88861588"> 105</a></span>  std::vector<EventInfo> <a class="code" href="structmxnet_1_1engine_1_1SyncObject.html#a4c2ae5228111a794e9651e5b88861588">writer_event</a>;</div> |
| <div class="line"><a name="l00106"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1SyncObject.html#a218d57ab487ba20d18dc79c2c4e985ea"> 106</a></span>  std::mutex <a class="code" href="structmxnet_1_1engine_1_1SyncObject.html#a218d57ab487ba20d18dc79c2c4e985ea">mutex</a>;</div> |
| <div class="line"><a name="l00107"></a><span class="lineno"> 107</span> };</div> |
| <div class="line"><a name="l00108"></a><span class="lineno"> 108</span> <span class="preprocessor">#endif</span></div> |
| <div class="line"><a name="l00109"></a><span class="lineno"> 109</span>  </div> |
| <div class="line"><a name="l00111"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1Var.html"> 111</a></span> <span class="keyword">struct </span><a class="code" href="structmxnet_1_1engine_1_1Var.html">Var</a> {</div> |
| <div class="line"><a name="l00112"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1Var.html#a5ae9f575784f18598df8824ccd421c18"> 112</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">size_t</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html#a5ae9f575784f18598df8824ccd421c18">version</a>() {</div> |
| <div class="line"><a name="l00113"></a><span class="lineno"> 113</span>  <span class="keywordflow">return</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html#a43edd2b4bd955d283b3dff5e3f66279d">version_</a>;</div> |
| <div class="line"><a name="l00114"></a><span class="lineno"> 114</span>  }</div> |
| <div class="line"><a name="l00115"></a><span class="lineno"> 115</span>  <span class="keyword">virtual</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html#af9884f59707511d65ecdb04a6dab0423">~Var</a>() = <span class="keywordflow">default</span>;</div> |
| <div class="line"><a name="l00121"></a><span class="lineno"> 121</span>  <span class="keyword">template</span> <<span class="keyword">typename</span> T></div> |
| <div class="line"><a name="l00122"></a><span class="lineno"> 122</span>  <span class="keyword">inline</span> T* <a class="code" href="structmxnet_1_1engine_1_1Var.html#af71c04da3c374220356efa98cc215590">Cast</a>();</div> |
| <div class="line"><a name="l00127"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1Var.html#a43edd2b4bd955d283b3dff5e3f66279d"> 127</a></span>  <span class="keywordtype">size_t</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html#a43edd2b4bd955d283b3dff5e3f66279d">version_</a>{0};</div> |
| <div class="line"><a name="l00128"></a><span class="lineno"> 128</span> <span class="preprocessor">#if MXNET_USE_CUDA</span></div> |
| <div class="line"><a name="l00129"></a><span class="lineno"> 129</span>  </div> |
| <div class="line"><a name="l00132"></a><span class="lineno"><a class="line" href="structmxnet_1_1engine_1_1Var.html#a10a695c677bc0019d1c9eeed6aab7ee4"> 132</a></span>  <a class="code" href="structmxnet_1_1engine_1_1SyncObject.html">SyncObject</a> <a class="code" href="structmxnet_1_1engine_1_1Var.html#a10a695c677bc0019d1c9eeed6aab7ee4">sync_object</a>;</div> |
| <div class="line"><a name="l00133"></a><span class="lineno"> 133</span> <span class="preprocessor">#endif</span></div> |
| <div class="line"><a name="l00134"></a><span class="lineno"> 134</span> }; <span class="comment">// struct Var</span></div> |
| <div class="line"><a name="l00135"></a><span class="lineno"> 135</span>  </div> |
| <div class="line"><a name="l00137"></a><span class="lineno"><a class="line" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf"> 137</a></span> <span class="keyword">struct </span>Opr;</div> |
| <div class="line"><a name="l00139"></a><span class="lineno"> 139</span> <span class="keyword">typedef</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html">Var</a>* <a class="code" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">VarHandle</a>;</div> |
| <div class="line"><a name="l00141"></a><span class="lineno"><a class="line" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94"> 141</a></span> <span class="keyword">typedef</span> Opr* <a class="code" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">OprHandle</a>;</div> |
| <div class="line"><a name="l00146"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnStart.html"> 146</a></span> <span class="keyword">class </span><a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">CallbackOnStart</a> {</div> |
| <div class="line"><a name="l00147"></a><span class="lineno"> 147</span>  <span class="keyword">public</span>:</div> |
| <div class="line"><a name="l00148"></a><span class="lineno"> 148</span>  <span class="comment">// use implicit copy and assign</span></div> |
| <div class="line"><a name="l00150"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnStart.html#aaef3cbe661d77ef471d084feb058a613"> 150</a></span> <span class="comment"></span> <span class="keyword">inline</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html#aaef3cbe661d77ef471d084feb058a613">operator()</a>(<span class="keyword">const</span> dmlc::Error* error = <span class="keyword">nullptr</span>)<span class="keyword"> const </span>{</div> |
| <div class="line"><a name="l00151"></a><span class="lineno"> 151</span>  <span class="keywordflow">if</span> (callback_ != <span class="keyword">nullptr</span>)</div> |
| <div class="line"><a name="l00152"></a><span class="lineno"> 152</span>  (*callback_)(engine_, param_, error);</div> |
| <div class="line"><a name="l00153"></a><span class="lineno"> 153</span>  }</div> |
| <div class="line"><a name="l00154"></a><span class="lineno"> 154</span>  </div> |
| <div class="line"><a name="l00155"></a><span class="lineno"> 155</span>  <span class="keyword">private</span>:</div> |
| <div class="line"><a name="l00157"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnStart.html#aa365d0be0a3a369cdc5da59787f93297"> 157</a></span>  <span class="keyword">friend</span> class ::mxnet::Engine;</div> |
| <div class="line"><a name="l00159"></a><span class="lineno"> 159</span>  void (*callback_)(<a class="code" href="classmxnet_1_1Engine.html">Engine</a>*, <span class="keywordtype">void</span>*, <span class="keyword">const</span> dmlc::Error*);</div> |
| <div class="line"><a name="l00161"></a><span class="lineno"> 161</span>  <a class="code" href="classmxnet_1_1Engine.html">Engine</a>* engine_;</div> |
| <div class="line"><a name="l00163"></a><span class="lineno"> 163</span>  <span class="keywordtype">void</span>* param_;</div> |
| <div class="line"><a name="l00164"></a><span class="lineno"> 164</span> };</div> |
| <div class="line"><a name="l00169"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnComplete.html"> 169</a></span> <span class="keyword">class </span><a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">CallbackOnComplete</a> {</div> |
| <div class="line"><a name="l00170"></a><span class="lineno"> 170</span>  <span class="keyword">public</span>:</div> |
| <div class="line"><a name="l00171"></a><span class="lineno"> 171</span>  <span class="comment">// use implicit copy and assign</span></div> |
| <div class="line"><a name="l00173"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnComplete.html#aadfc4d64cee7555f9ef105172d855fc5"> 173</a></span> <span class="comment"></span> <span class="keyword">inline</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html#aadfc4d64cee7555f9ef105172d855fc5">operator()</a>(<span class="keyword">const</span> dmlc::Error* error = <span class="keyword">nullptr</span>)<span class="keyword"> const </span>{</div> |
| <div class="line"><a name="l00174"></a><span class="lineno"> 174</span>  (*callback_)(engine_, param_, error);</div> |
| <div class="line"><a name="l00175"></a><span class="lineno"> 175</span>  }</div> |
| <div class="line"><a name="l00176"></a><span class="lineno"> 176</span>  </div> |
| <div class="line"><a name="l00177"></a><span class="lineno"> 177</span>  <span class="keyword">private</span>:</div> |
| <div class="line"><a name="l00179"></a><span class="lineno"><a class="line" href="classmxnet_1_1engine_1_1CallbackOnComplete.html#aa365d0be0a3a369cdc5da59787f93297"> 179</a></span>  <span class="keyword">friend</span> class ::mxnet::Engine;</div> |
| <div class="line"><a name="l00181"></a><span class="lineno"> 181</span>  void (*callback_)(<a class="code" href="classmxnet_1_1Engine.html">Engine</a>*, <span class="keywordtype">void</span>*, <span class="keyword">const</span> dmlc::Error*);</div> |
| <div class="line"><a name="l00183"></a><span class="lineno"> 183</span>  <a class="code" href="classmxnet_1_1Engine.html">Engine</a>* engine_;</div> |
| <div class="line"><a name="l00185"></a><span class="lineno"> 185</span>  <span class="keywordtype">void</span>* param_;</div> |
| <div class="line"><a name="l00186"></a><span class="lineno"> 186</span> };</div> |
| <div class="line"><a name="l00187"></a><span class="lineno"> 187</span> } <span class="comment">// namespace engine</span></div> |
| <div class="line"><a name="l00188"></a><span class="lineno"> 188</span>  </div> |
| <div class="line"><a name="l00189"></a><span class="lineno"> 189</span> <span class="preprocessor">#if DMLC_USE_CXX11</span></div> |
| <div class="line"><a name="l00190"></a><span class="lineno"> 190</span>  </div> |
| <div class="line"><a name="l00191"></a><span class="lineno"><a class="line" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b"> 191</a></span> <span class="keyword">enum class</span> <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b">FnProperty</a> {</div> |
| <div class="line"><a name="l00193"></a><span class="lineno"> 193</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416">kNormal</a>,</div> |
| <div class="line"><a name="l00195"></a><span class="lineno"> 195</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba739f2f416f05f4728c217f09e93958d1">kCopyFromGPU</a>,</div> |
| <div class="line"><a name="l00197"></a><span class="lineno"> 197</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba6cd75f41e0ec8d61b0a2f0e20ef6d1e8">kCopyToGPU</a>,</div> |
| <div class="line"><a name="l00199"></a><span class="lineno"> 199</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac41ceb98eeb9b2e208e3e242a7357142">kCPUPrioritized</a>,</div> |
| <div class="line"><a name="l00201"></a><span class="lineno"> 201</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba9f2b960005d2a3a5f35ac32809d84db7">kAsync</a>,</div> |
| <div class="line"><a name="l00203"></a><span class="lineno"> 203</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac41fda8552e9d327ad3b06b1bafa663a">kDeleteVar</a>,</div> |
| <div class="line"><a name="l00205"></a><span class="lineno"> 205</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba4879e1f172ddec6223d17ea90e691ddd">kGPUPrioritized</a>,</div> |
| <div class="line"><a name="l00207"></a><span class="lineno"> 207</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac76ad5584b42e2ba06f1b5e9ebac8b2a">kNoSkip</a></div> |
| <div class="line"><a name="l00208"></a><span class="lineno"> 208</span> }; <span class="comment">// enum class FnProperty</span></div> |
| <div class="line"><a name="l00209"></a><span class="lineno"> 209</span>  </div> |
| <div class="line"><a name="l00213"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html"> 213</a></span> <span class="keyword">class </span><a class="code" href="include_2mxnet_2base_8h.html#a91a09948aaaffa1eb64508db79009a05">MXNET_API</a> <a class="code" href="classmxnet_1_1Engine.html">Engine</a> {</div> |
| <div class="line"><a name="l00214"></a><span class="lineno"> 214</span>  <span class="keyword">public</span>:</div> |
| <div class="line"><a name="l00216"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a86c85108347be2bfbb61696b9937316b"> 216</a></span>  <span class="keyword">typedef</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">engine::CallbackOnStart</a> <a class="code" href="classmxnet_1_1Engine.html#a86c85108347be2bfbb61696b9937316b">CallbackOnStart</a>;</div> |
| <div class="line"><a name="l00218"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a16b757432556f835d27f1b5e1dbe1b06"> 218</a></span>  <span class="keyword">typedef</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">engine::CallbackOnComplete</a> <a class="code" href="classmxnet_1_1Engine.html#a16b757432556f835d27f1b5e1dbe1b06">CallbackOnComplete</a>;</div> |
| <div class="line"><a name="l00220"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a07f30ab85fca436e1bbcc72cd4d8bb35"> 220</a></span>  <span class="keyword">typedef</span> std::function&lt;void(<a class="code" href="structmxnet_1_1RunContext.html">RunContext</a>)&gt; <a class="code" href="classmxnet_1_1Engine.html#a07f30ab85fca436e1bbcc72cd4d8bb35">SyncFn</a>;</div> |
| <div class="line"><a name="l00222"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#af3c96ff8fb3f7e86ff0d79ff1d930001"> 222</a></span>  <span class="keyword">typedef</span> std::function&lt;void(<a class="code" href="structmxnet_1_1RunContext.html">RunContext</a>, <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">CallbackOnStart</a>, <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">CallbackOnComplete</a>)&gt; <a class="code" href="classmxnet_1_1Engine.html#af3c96ff8fb3f7e86ff0d79ff1d930001">AsyncFn</a>;</div> |
| <div class="line"><a name="l00224"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#aac31510c793a12944c33f9cac6150491"> 224</a></span>  <span class="keyword">typedef</span> <a class="code" href="structmxnet_1_1engine_1_1Var.html">engine::VarHandle</a> <a class="code" href="classmxnet_1_1Engine.html#aac31510c793a12944c33f9cac6150491">VarHandle</a>;</div> |
| <div class="line"><a name="l00226"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a832436e413a075291aa1a631942c3f01"> 226</a></span>  <span class="keyword">typedef</span> <a class="code" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">engine::OprHandle</a> <a class="code" href="classmxnet_1_1Engine.html#a832436e413a075291aa1a631942c3f01">OprHandle</a>;</div> |
| <div class="line"><a name="l00234"></a><span class="lineno"> 234</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> NotifyShutdown() = 0;</div> |
| <div class="line"><a name="l00238"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a30f8947a7a1ad86b4c2cdb5b15b6d1e1"> 238</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1Engine.html#a30f8947a7a1ad86b4c2cdb5b15b6d1e1">Stop</a>() {</div> |
| <div class="line"><a name="l00239"></a><span class="lineno"> 239</span>  LOG(FATAL) &lt;&lt; <span class="stringliteral">"Engine cannot be stopped"</span>;</div> |
| <div class="line"><a name="l00240"></a><span class="lineno"> 240</span>  }</div> |
| <div class="line"><a name="l00244"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a136a1d60db77a0846faa8462d0fb6237"> 244</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1Engine.html#a136a1d60db77a0846faa8462d0fb6237">Start</a>() {</div> |
| <div class="line"><a name="l00245"></a><span class="lineno"> 245</span>  LOG(FATAL) &lt;&lt; <span class="stringliteral">"Engine cannot be restarted"</span>;</div> |
| <div class="line"><a name="l00246"></a><span class="lineno"> 246</span>  }</div> |
| <div class="line"><a name="l00253"></a><span class="lineno"> 253</span>  <span class="keyword">virtual</span> <a class="code" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">VarHandle</a> NewVariable() = 0;</div> |
| <div class="line"><a name="l00266"></a><span class="lineno"> 266</span>  <span class="keyword">virtual</span> <a class="code" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">OprHandle</a> NewOperator(AsyncFn fn,</div> |
| <div class="line"><a name="l00267"></a><span class="lineno"> 267</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; const_vars,</div> |
| <div class="line"><a name="l00268"></a><span class="lineno"> 268</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; mutable_vars,</div> |
| <div class="line"><a name="l00269"></a><span class="lineno"> 269</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b">FnProperty</a> prop = <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416">FnProperty::kNormal</a>,</div> |
| <div class="line"><a name="l00270"></a><span class="lineno"> 270</span>  <span class="keyword">const</span> <span class="keywordtype">char</span>* opr_name = <span class="keyword">nullptr</span>,</div> |
| <div class="line"><a name="l00271"></a><span class="lineno"> 271</span>  <span class="keywordtype">bool</span> wait = <span class="keyword">false</span>) = 0;</div> |
| <div class="line"><a name="l00279"></a><span class="lineno"> 279</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> DeleteOperator(<a class="code" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">OprHandle</a> op) = 0;</div> |
| <div class="line"><a name="l00287"></a><span class="lineno"> 287</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> Push(<a class="code" href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">OprHandle</a> op, <a class="code" href="structmxnet_1_1Context.html">Context</a> exec_ctx, <span class="keywordtype">int</span> priority = 0, <span class="keywordtype">bool</span> profiling = <span class="keyword">false</span>) = 0;</div> |
| <div class="line"><a name="l00302"></a><span class="lineno"> 302</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> PushAsync(AsyncFn exec_fun,</div> |
| <div class="line"><a name="l00303"></a><span class="lineno"> 303</span>  <a class="code" href="structmxnet_1_1Context.html">Context</a> exec_ctx,</div> |
| <div class="line"><a name="l00304"></a><span class="lineno"> 304</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; const_vars,</div> |
| <div class="line"><a name="l00305"></a><span class="lineno"> 305</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; mutable_vars,</div> |
| <div class="line"><a name="l00306"></a><span class="lineno"> 306</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b">FnProperty</a> prop = <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416">FnProperty::kNormal</a>,</div> |
| <div class="line"><a name="l00307"></a><span class="lineno"> 307</span>  <span class="keywordtype">int</span> priority = 0,</div> |
| <div class="line"><a name="l00308"></a><span class="lineno"> 308</span>  <span class="keyword">const</span> <span class="keywordtype">char</span>* opr_name = <span class="keyword">nullptr</span>,</div> |
| <div class="line"><a name="l00309"></a><span class="lineno"> 309</span>  <span class="keywordtype">bool</span> wait = <span class="keyword">false</span>) = 0;</div> |
| <div class="line"><a name="l00321"></a><span class="lineno"> 321</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> DeleteVariable(SyncFn delete_fn, <a class="code" href="structmxnet_1_1Context.html">Context</a> exec_ctx, <a class="code" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">VarHandle</a> var) = 0;</div> |
| <div class="line"><a name="l00327"></a><span class="lineno"> 327</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> WaitForVar(<a class="code" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">VarHandle</a> var) = 0;</div> |
| <div class="line"><a name="l00331"></a><span class="lineno"> 331</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> WaitForAll() = 0;</div> |
| <div class="line"><a name="l00333"></a><span class="lineno"> 333</span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> Throw(<a class="code" href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">VarHandle</a> var) = 0;</div> |
| <div class="line"><a name="l00335"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#aff025321827e15096c02342225f2395b"> 335</a></span>  <span class="keyword">virtual</span> <a class="code" href="classmxnet_1_1Engine.html#aff025321827e15096c02342225f2395b">~Engine</a>() noexcept(false) {}</div> |
| <div class="line"><a name="l00339"></a><span class="lineno"> 339</span>  <span class="keyword">static</span> <a class="code" href="classmxnet_1_1Engine.html">Engine</a>* Get();</div> |
| <div class="line"><a name="l00348"></a><span class="lineno"> 348</span>  <span class="keyword">static</span> <span class="keyword">const</span> std::shared_ptr&lt;Engine&gt;&amp; _GetSharedRef();</div> |
| <div class="line"><a name="l00361"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#ac4c1d74e906699cf074a34dfdc968d8c"> 361</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1Engine.html#ac4c1d74e906699cf074a34dfdc968d8c">PushSync</a>(<a class="code" href="classmxnet_1_1Engine.html#a07f30ab85fca436e1bbcc72cd4d8bb35">SyncFn</a> exec_fn,</div> |
| <div class="line"><a name="l00362"></a><span class="lineno"> 362</span>  <a class="code" href="structmxnet_1_1Context.html">Context</a> exec_ctx,</div> |
| <div class="line"><a name="l00363"></a><span class="lineno"> 363</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; const_vars,</div> |
| <div class="line"><a name="l00364"></a><span class="lineno"> 364</span>  std::vector&lt;VarHandle&gt; <span class="keyword">const</span>&amp; mutable_vars,</div> |
| <div class="line"><a name="l00365"></a><span class="lineno"> 365</span>  <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b">FnProperty</a> prop = <a class="code" href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416">FnProperty::kNormal</a>,</div> |
| <div class="line"><a name="l00366"></a><span class="lineno"> 366</span>  <span class="keywordtype">int</span> priority = 0,</div> |
| <div class="line"><a name="l00367"></a><span class="lineno"> 367</span>  <span class="keyword">const</span> <span class="keywordtype">char</span>* opr_name = <span class="keyword">nullptr</span>) {</div> |
| <div class="line"><a name="l00368"></a><span class="lineno"> 368</span>  this->PushAsync(</div> |
| <div class="line"><a name="l00369"></a><span class="lineno"> 369</span>  [exec_fn](<a class="code" href="structmxnet_1_1RunContext.html">RunContext</a> ctx, <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">CallbackOnStart</a> on_start, <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">CallbackOnComplete</a> on_complete) {</div> |
| <div class="line"><a name="l00370"></a><span class="lineno"> 370</span>  on_start();</div> |
| <div class="line"><a name="l00371"></a><span class="lineno"> 371</span>  exec_fn(ctx);</div> |
| <div class="line"><a name="l00372"></a><span class="lineno"> 372</span>  on_complete();</div> |
| <div class="line"><a name="l00373"></a><span class="lineno"> 373</span>  },</div> |
| <div class="line"><a name="l00374"></a><span class="lineno"> 374</span>  exec_ctx,</div> |
| <div class="line"><a name="l00375"></a><span class="lineno"> 375</span>  const_vars,</div> |
| <div class="line"><a name="l00376"></a><span class="lineno"> 376</span>  mutable_vars,</div> |
| <div class="line"><a name="l00377"></a><span class="lineno"> 377</span>  prop,</div> |
| <div class="line"><a name="l00378"></a><span class="lineno"> 378</span>  priority,</div> |
| <div class="line"><a name="l00379"></a><span class="lineno"> 379</span>  opr_name);</div> |
| <div class="line"><a name="l00380"></a><span class="lineno"> 380</span>  }</div> |
| <div class="line"><a name="l00381"></a><span class="lineno"> 381</span>  </div> |
| <div class="line"><a name="l00387"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#ad1064effa398035423931135e6607c5a"> 387</a></span>  <span class="keyword">inline</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">CallbackOnStart</a> <a class="code" href="classmxnet_1_1Engine.html#ad1064effa398035423931135e6607c5a">CreateOnStart</a>(<span class="keywordtype">void</span> (*callback)(<a class="code" href="classmxnet_1_1Engine.html">Engine</a>*, <span class="keywordtype">void</span>*, <span class="keyword">const</span> dmlc::Error*),</div> |
| <div class="line"><a name="l00388"></a><span class="lineno"> 388</span>  <span class="keywordtype">void</span>* param) {</div> |
| <div class="line"><a name="l00389"></a><span class="lineno"> 389</span>  <a class="code" href="classmxnet_1_1engine_1_1CallbackOnStart.html">CallbackOnStart</a> ret;</div> |
| <div class="line"><a name="l00390"></a><span class="lineno"> 390</span>  ret.callback_ = callback;</div> |
| <div class="line"><a name="l00391"></a><span class="lineno"> 391</span>  ret.engine_ = <span class="keyword">this</span>;</div> |
| <div class="line"><a name="l00392"></a><span class="lineno"> 392</span>  ret.param_ = param;</div> |
| <div class="line"><a name="l00393"></a><span class="lineno"> 393</span>  <span class="keywordflow">return</span> ret;</div> |
| <div class="line"><a name="l00394"></a><span class="lineno"> 394</span>  }</div> |
| <div class="line"><a name="l00395"></a><span class="lineno"> 395</span>  </div> |
| <div class="line"><a name="l00401"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a53659c833b8cac3cbcd15a152275a809"> 401</a></span>  <span class="keyword">inline</span> <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">CallbackOnComplete</a> <a class="code" href="classmxnet_1_1Engine.html#a53659c833b8cac3cbcd15a152275a809">CreateCallback</a>(<span class="keywordtype">void</span> (*callback)(<a class="code" href="classmxnet_1_1Engine.html">Engine</a>*, <span class="keywordtype">void</span>*, <span class="keyword">const</span> dmlc::Error*),</div> |
| <div class="line"><a name="l00402"></a><span class="lineno"> 402</span>  <span class="keywordtype">void</span>* param) {</div> |
| <div class="line"><a name="l00403"></a><span class="lineno"> 403</span>  <a class="code" href="classmxnet_1_1engine_1_1CallbackOnComplete.html">CallbackOnComplete</a> ret;</div> |
| <div class="line"><a name="l00404"></a><span class="lineno"> 404</span>  ret.callback_ = callback;</div> |
| <div class="line"><a name="l00405"></a><span class="lineno"> 405</span>  ret.engine_ = <span class="keyword">this</span>;</div> |
| <div class="line"><a name="l00406"></a><span class="lineno"> 406</span>  ret.param_ = param;</div> |
| <div class="line"><a name="l00407"></a><span class="lineno"> 407</span>  <span class="keywordflow">return</span> ret;</div> |
| <div class="line"><a name="l00408"></a><span class="lineno"> 408</span>  }</div> |
| <div class="line"><a name="l00409"></a><span class="lineno"> 409</span>  <span class="comment">// For each var vector, sort it and remove the duplicated vars.</span></div> |
| <div class="line"><a name="l00410"></a><span class="lineno"> 410</span>  <span class="comment">// Also remove vars from read_vars if it also appears in write_vars</span></div> |
| <div class="line"><a name="l00411"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#a9d8ac1987a6844dba9b0463030fb3430"> 411</a></span>  <span class="keyword">inline</span> <span class="keywordtype">void</span> <a class="code" href="classmxnet_1_1Engine.html#a9d8ac1987a6844dba9b0463030fb3430">DeduplicateVarHandle</a>(std::vector&lt;engine::VarHandle&gt;* read_vars,</div> |
| <div class="line"><a name="l00412"></a><span class="lineno"> 412</span>  std::vector&lt;engine::VarHandle&gt;* write_vars) {</div> |
| <div class="line"><a name="l00413"></a><span class="lineno"> 413</span>  std::sort(write_vars->begin(), write_vars->end());</div> |
| <div class="line"><a name="l00414"></a><span class="lineno"> 414</span>  write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) - write_vars->begin());</div> |
| <div class="line"><a name="l00415"></a><span class="lineno"> 415</span>  std::sort(read_vars->begin(), read_vars->end());</div> |
| <div class="line"><a name="l00416"></a><span class="lineno"> 416</span>  read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) - read_vars->begin());</div> |
| <div class="line"><a name="l00417"></a><span class="lineno"> 417</span>  <span class="keyword">auto</span> wit = write_vars->begin();</div> |
| <div class="line"><a name="l00418"></a><span class="lineno"> 418</span>  <span class="keyword">auto</span> rtop = read_vars->begin();</div> |
| <div class="line"><a name="l00419"></a><span class="lineno"> 419</span>  <span class="keywordflow">for</span> (<span class="keyword">auto</span> rit = read_vars->begin(); rit != read_vars->end(); ++rit) {</div> |
| <div class="line"><a name="l00420"></a><span class="lineno"> 420</span>  <span class="keywordflow">while</span> (wit != write_vars->end() &amp;&amp; *wit &lt; *rit)</div> |
| <div class="line"><a name="l00421"></a><span class="lineno"> 421</span>  ++wit;</div> |
| <div class="line"><a name="l00422"></a><span class="lineno"> 422</span>  <span class="keywordflow">if</span> (wit == write_vars->end() || *wit != *rit) {</div> |
| <div class="line"><a name="l00423"></a><span class="lineno"> 423</span>  *rtop = *rit;</div> |
| <div class="line"><a name="l00424"></a><span class="lineno"> 424</span>  ++rtop;</div> |
| <div class="line"><a name="l00425"></a><span class="lineno"> 425</span>  }</div> |
| <div class="line"><a name="l00426"></a><span class="lineno"> 426</span>  }</div> |
| <div class="line"><a name="l00427"></a><span class="lineno"> 427</span>  read_vars->resize(rtop - read_vars->begin());</div> |
| <div class="line"><a name="l00428"></a><span class="lineno"> 428</span>  }</div> |
| <div class="line"><a name="l00430"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#aa3cc0ead60033689efdc28dba2d3a372"> 430</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">int</span> <a class="code" href="classmxnet_1_1Engine.html#aa3cc0ead60033689efdc28dba2d3a372">bulk_size</a>()<span class="keyword"> const </span>{</div> |
| <div class="line"><a name="l00431"></a><span class="lineno"> 431</span>  <span class="keywordflow">return</span> 0;</div> |
| <div class="line"><a name="l00432"></a><span class="lineno"> 432</span>  }</div> |
| <div class="line"><a name="l00434"></a><span class="lineno"><a class="line" href="classmxnet_1_1Engine.html#ac33a02d6827b43c27c04371d8ef04ef4"> 434</a></span>  <span class="keyword">virtual</span> <span class="keywordtype">int</span> <a class="code" href="classmxnet_1_1Engine.html#ac33a02d6827b43c27c04371d8ef04ef4">set_bulk_size</a>(<span class="keywordtype">int</span>) {</div> |
| <div class="line"><a name="l00435"></a><span class="lineno"> 435</span>  <span class="keywordflow">return</span> 0;</div> |
| <div class="line"><a name="l00436"></a><span class="lineno"> 436</span>  }</div> |
| <div class="line"><a name="l00437"></a><span class="lineno"> 437</span> }; <span class="comment">// class Engine</span></div> |
| <div class="line"><a name="l00438"></a><span class="lineno"> 438</span> <span class="preprocessor">#endif // DMLC_USE_CXX11</span></div> |
| <div class="line"><a name="l00439"></a><span class="lineno"> 439</span> } <span class="comment">// namespace mxnet</span></div> |
| <div class="line"><a name="l00440"></a><span class="lineno"> 440</span> <span class="preprocessor">#endif // MXNET_ENGINE_H_</span></div> |
| </div><!-- fragment --></div><!-- contents --> |
| <div class="ttc" id="anamespacemxnet_html"><div class="ttname"><a href="namespacemxnet.html">mxnet</a></div><div class="ttdoc">namespace of mxnet</div><div class="ttdef"><b>Definition:</b> api_registry.h:33</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a86c85108347be2bfbb61696b9937316b"><div class="ttname"><a href="classmxnet_1_1Engine.html#a86c85108347be2bfbb61696b9937316b">mxnet::Engine::CallbackOnStart</a></div><div class="ttdeci">engine::CallbackOnStart CallbackOnStart</div><div class="ttdoc">on start</div><div class="ttdef"><b>Definition:</b> engine.h:216</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html_a3743c3d6d08920567950e8c917b37c64"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html#a3743c3d6d08920567950e8c917b37c64">mxnet::engine::CUDAEvent::operator=</a></div><div class="ttdeci">void operator=(const CUDAEvent &amp;other)=delete</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CallbackOnComplete_html"><div class="ttname"><a href="classmxnet_1_1engine_1_1CallbackOnComplete.html">mxnet::engine::CallbackOnComplete</a></div><div class="ttdoc">OnComplete Callback to the engine, called by AsyncFn when action completes.</div><div class="ttdef"><b>Definition:</b> engine.h:169</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CallbackOnStart_html"><div class="ttname"><a href="classmxnet_1_1engine_1_1CallbackOnStart.html">mxnet::engine::CallbackOnStart</a></div><div class="ttdoc">OnStart callback to the engine, called by AsyncFn before the action.</div><div class="ttdef"><b>Definition:</b> engine.h:146</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba07fa7a19aa722c635a15e94cb7f50416">mxnet::FnProperty::kNormal</a></div><div class="ttdeci">@ kNormal</div><div class="ttdoc">Normal operation.</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html_a43edd2b4bd955d283b3dff5e3f66279d"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html#a43edd2b4bd955d283b3dff5e3f66279d">mxnet::engine::Var::version_</a></div><div class="ttdeci">size_t version_</div><div class="ttdoc">version number of the var. Every time the object it is associated with is modified,...</div><div class="ttdef"><b>Definition:</b> engine.h:127</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html_a53ea01451a92b23950a4ddce8c71cdb8"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html#a53ea01451a92b23950a4ddce8c71cdb8">mxnet::engine::CUDAEvent::GetEvent</a></div><div class="ttdeci">std::weak_ptr&lt; cudaEvent_t &gt; GetEvent() noexcept</div><div class="ttdef"><b>Definition:</b> engine.h:58</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a16b757432556f835d27f1b5e1dbe1b06"><div class="ttname"><a href="classmxnet_1_1Engine.html#a16b757432556f835d27f1b5e1dbe1b06">mxnet::Engine::CallbackOnComplete</a></div><div class="ttdeci">engine::CallbackOnComplete CallbackOnComplete</div><div class="ttdoc">callback on complete</div><div class="ttdef"><b>Definition:</b> engine.h:218</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1EventInfo_html_a0c134825ed31fe9c210affafe29614ae"><div class="ttname"><a href="structmxnet_1_1engine_1_1EventInfo.html#a0c134825ed31fe9c210affafe29614ae">mxnet::engine::EventInfo::stream</a></div><div class="ttdeci">cudaStream_t stream</div><div class="ttdef"><b>Definition:</b> engine.h:97</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1EventInfo_html"><div class="ttname"><a href="structmxnet_1_1engine_1_1EventInfo.html">mxnet::engine::EventInfo</a></div><div class="ttdoc">full event info for the sync object.</div><div class="ttdef"><b>Definition:</b> engine.h:95</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1EventInfo_html_a275fb1793ee6bdcf4d78c1536b1dd95e"><div class="ttname"><a href="structmxnet_1_1engine_1_1EventInfo.html#a275fb1793ee6bdcf4d78c1536b1dd95e">mxnet::engine::EventInfo::pool_index</a></div><div class="ttdeci">uint64_t pool_index</div><div class="ttdef"><b>Definition:</b> engine.h:98</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3bac41fda8552e9d327ad3b06b1bafa663a"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac41fda8552e9d327ad3b06b1bafa663a">mxnet::FnProperty::kDeleteVar</a></div><div class="ttdeci">@ kDeleteVar</div><div class="ttdoc">Delete variable call.</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEventPool_html_aad9425ed11adf77479af5a666f14f38d"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEventPool.html#aad9425ed11adf77479af5a666f14f38d">mxnet::engine::CUDAEventPool::CUDAEventPool</a></div><div class="ttdeci">CUDAEventPool(Context const &amp;ctx)</div><div class="ttdef"><b>Definition:</b> engine.h:69</div></div> |
| <div class="ttc" id="astructmxnet_1_1RunContext_html"><div class="ttname"><a href="structmxnet_1_1RunContext.html">mxnet::RunContext</a></div><div class="ttdoc">execution time context. The information needed in runtime for actual execution.</div><div class="ttdef"><b>Definition:</b> base.h:343</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3ba9f2b960005d2a3a5f35ac32809d84db7"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba9f2b960005d2a3a5f35ac32809d84db7">mxnet::FnProperty::kAsync</a></div><div class="ttdeci">@ kAsync</div><div class="ttdoc">Asynchronous function call.</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEventPool_html_a3e8378ca127166af434256007c649f64"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEventPool.html#a3e8378ca127166af434256007c649f64">mxnet::engine::CUDAEventPool::GetNextEvent</a></div><div class="ttdeci">std::pair&lt; std::weak_ptr&lt; cudaEvent_t &gt;, uint64_t &gt; GetNextEvent() noexcept</div><div class="ttdef"><b>Definition:</b> engine.h:79</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html_a0ff9222e821d258fadba48133a9d1eb6"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html#a0ff9222e821d258fadba48133a9d1eb6">mxnet::engine::CUDAEvent::CUDAEvent</a></div><div class="ttdeci">CUDAEvent(CUDAEvent &amp;&amp;other)</div><div class="ttdef"><b>Definition:</b> engine.h:49</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a30f8947a7a1ad86b4c2cdb5b15b6d1e1"><div class="ttname"><a href="classmxnet_1_1Engine.html#a30f8947a7a1ad86b4c2cdb5b15b6d1e1">mxnet::Engine::Stop</a></div><div class="ttdeci">virtual void Stop()</div><div class="ttdoc">Stop all workers in the engine.</div><div class="ttdef"><b>Definition:</b> engine.h:238</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_aff025321827e15096c02342225f2395b"><div class="ttname"><a href="classmxnet_1_1Engine.html#aff025321827e15096c02342225f2395b">mxnet::Engine::~Engine</a></div><div class="ttdeci">virtual ~Engine() noexcept(false)</div><div class="ttdoc">virtual destructor</div><div class="ttdef"><b>Definition:</b> engine.h:335</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3bac76ad5584b42e2ba06f1b5e9ebac8b2a"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac76ad5584b42e2ba06f1b5e9ebac8b2a">mxnet::FnProperty::kNoSkip</a></div><div class="ttdeci">@ kNoSkip</div><div class="ttdoc">Operation not to be skipped even with associated exception.</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a07f30ab85fca436e1bbcc72cd4d8bb35"><div class="ttname"><a href="classmxnet_1_1Engine.html#a07f30ab85fca436e1bbcc72cd4d8bb35">mxnet::Engine::SyncFn</a></div><div class="ttdeci">std::function&lt; void(RunContext)&gt; SyncFn</div><div class="ttdoc">Synchronous operation to pass to engine.</div><div class="ttdef"><b>Definition:</b> engine.h:220</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_af3c96ff8fb3f7e86ff0d79ff1d930001"><div class="ttname"><a href="classmxnet_1_1Engine.html#af3c96ff8fb3f7e86ff0d79ff1d930001">mxnet::Engine::AsyncFn</a></div><div class="ttdeci">std::function&lt; void(RunContext, CallbackOnStart, CallbackOnComplete)&gt; AsyncFn</div><div class="ttdoc">Asynchronous operation to pass to engine.</div><div class="ttdef"><b>Definition:</b> engine.h:222</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html_af9884f59707511d65ecdb04a6dab0423"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html#af9884f59707511d65ecdb04a6dab0423">mxnet::engine::Var::~Var</a></div><div class="ttdeci">virtual ~Var()=default</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a136a1d60db77a0846faa8462d0fb6237"><div class="ttname"><a href="classmxnet_1_1Engine.html#a136a1d60db77a0846faa8462d0fb6237">mxnet::Engine::Start</a></div><div class="ttdeci">virtual void Start()</div><div class="ttdoc">Restart all workers in the engine.</div><div class="ttdef"><b>Definition:</b> engine.h:244</div></div> |
| <div class="ttc" id="anamespacemxnet_1_1engine_html_a2d9b14b658e3f3c4e03ca49cd38ace94"><div class="ttname"><a href="namespacemxnet_1_1engine.html#a2d9b14b658e3f3c4e03ca49cd38ace94">mxnet::engine::OprHandle</a></div><div class="ttdeci">Opr * OprHandle</div><div class="ttdoc">Operator pointer type, usually held by user.</div><div class="ttdef"><b>Definition:</b> engine.h:141</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html_ae6c6574f8191012ec061cd816a64ff20"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html#ae6c6574f8191012ec061cd816a64ff20">mxnet::engine::CUDAEvent::CUDAEvent</a></div><div class="ttdeci">CUDAEvent(Context const &amp;ctx)</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3ba4879e1f172ddec6223d17ea90e691ddd"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba4879e1f172ddec6223d17ea90e691ddd">mxnet::FnProperty::kGPUPrioritized</a></div><div class="ttdeci">@ kGPUPrioritized</div><div class="ttdoc">Prioritized sync operation on GPU.</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a832436e413a075291aa1a631942c3f01"><div class="ttname"><a href="classmxnet_1_1Engine.html#a832436e413a075291aa1a631942c3f01">mxnet::Engine::OprHandle</a></div><div class="ttdeci">engine::OprHandle OprHandle</div><div class="ttdoc">Operator pointer.</div><div class="ttdef"><b>Definition:</b> engine.h:226</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html"><div class="ttname"><a href="classmxnet_1_1Engine.html">mxnet::Engine</a></div><div class="ttdoc">Dependency engine that schedules operations.</div><div class="ttdef"><b>Definition:</b> engine.h:213</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CallbackOnComplete_html_aadfc4d64cee7555f9ef105172d855fc5"><div class="ttname"><a href="classmxnet_1_1engine_1_1CallbackOnComplete.html#aadfc4d64cee7555f9ef105172d855fc5">mxnet::engine::CallbackOnComplete::operator()</a></div><div class="ttdeci">void operator()(const dmlc::Error *error=nullptr) const</div><div class="ttdoc">invoke the callback</div><div class="ttdef"><b>Definition:</b> engine.h:173</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3bac41ceb98eeb9b2e208e3e242a7357142"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3bac41ceb98eeb9b2e208e3e242a7357142">mxnet::FnProperty::kCPUPrioritized</a></div><div class="ttdeci">@ kCPUPrioritized</div><div class="ttdoc">Prioritized sync operation on CPU.</div></div> |
| <div class="ttc" id="astructmxnet_1_1Context_html"><div class="ttname"><a href="structmxnet_1_1Context.html">mxnet::Context</a></div><div class="ttdoc">Context information about the execution environment.</div><div class="ttdef"><b>Definition:</b> base.h:90</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_ac4c1d74e906699cf074a34dfdc968d8c"><div class="ttname"><a href="classmxnet_1_1Engine.html#ac4c1d74e906699cf074a34dfdc968d8c">mxnet::Engine::PushSync</a></div><div class="ttdeci">virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector&lt; VarHandle &gt; const &amp;const_vars, std::vector&lt; VarHandle &gt; const &amp;mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)</div><div class="ttdoc">Push a synchronous operation to the engine.</div><div class="ttdef"><b>Definition:</b> engine.h:361</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1SyncObject_html_a218d57ab487ba20d18dc79c2c4e985ea"><div class="ttname"><a href="structmxnet_1_1engine_1_1SyncObject.html#a218d57ab487ba20d18dc79c2c4e985ea">mxnet::engine::SyncObject::mutex</a></div><div class="ttdeci">std::mutex mutex</div><div class="ttdef"><b>Definition:</b> engine.h:106</div></div> |
| <div class="ttc" id="ainclude_2mxnet_2base_8h_html_a91a09948aaaffa1eb64508db79009a05"><div class="ttname"><a href="include_2mxnet_2base_8h.html#a91a09948aaaffa1eb64508db79009a05">MXNET_API</a></div><div class="ttdeci">#define MXNET_API</div><div class="ttdoc">define dllexport for Visual Studio</div><div class="ttdef"><b>Definition:</b> base.h:49</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_aa3cc0ead60033689efdc28dba2d3a372"><div class="ttname"><a href="classmxnet_1_1Engine.html#aa3cc0ead60033689efdc28dba2d3a372">mxnet::Engine::bulk_size</a></div><div class="ttdeci">virtual int bulk_size() const</div><div class="ttdoc">query current limit for bulk size</div><div class="ttdef"><b>Definition:</b> engine.h:430</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html_ab11e75549f6842a01ec4941dbb6a1c20"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html#ab11e75549f6842a01ec4941dbb6a1c20">mxnet::engine::CUDAEvent::~CUDAEvent</a></div><div class="ttdeci">~CUDAEvent()</div></div> |
| <div class="ttc" id="anamespacemxnet_1_1engine_html_a9d36c4f33eae8531586dc2edf83ae7cf"><div class="ttname"><a href="namespacemxnet_1_1engine.html#a9d36c4f33eae8531586dc2edf83ae7cf">mxnet::engine::VarHandle</a></div><div class="ttdeci">Var * VarHandle</div><div class="ttdoc">Variable pointer type, usually held by the user and used to specify dependencies.</div><div class="ttdef"><b>Definition:</b> engine.h:137</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html">mxnet::engine::Var</a></div><div class="ttdoc">base class of engine variables.</div><div class="ttdef"><b>Definition:</b> engine.h:111</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEventPool_html_a4e756deaf6e48abd5568cdcdd8f58140"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEventPool.html#a4e756deaf6e48abd5568cdcdd8f58140">mxnet::engine::CUDAEventPool::GetEvent</a></div><div class="ttdeci">std::weak_ptr&lt; cudaEvent_t &gt; GetEvent(size_t i) noexcept</div><div class="ttdef"><b>Definition:</b> engine.h:75</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1EventInfo_html_a1906da3e8772f14bae7e2a545e393263"><div class="ttname"><a href="structmxnet_1_1engine_1_1EventInfo.html#a1906da3e8772f14bae7e2a545e393263">mxnet::engine::EventInfo::event</a></div><div class="ttdeci">std::weak_ptr&lt; cudaEvent_t &gt; event</div><div class="ttdef"><b>Definition:</b> engine.h:96</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1SyncObject_html_a4c2ae5228111a794e9651e5b88861588"><div class="ttname"><a href="structmxnet_1_1engine_1_1SyncObject.html#a4c2ae5228111a794e9651e5b88861588">mxnet::engine::SyncObject::writer_event</a></div><div class="ttdeci">std::vector&lt; EventInfo &gt; writer_event</div><div class="ttdef"><b>Definition:</b> engine.h:105</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a9d8ac1987a6844dba9b0463030fb3430"><div class="ttname"><a href="classmxnet_1_1Engine.html#a9d8ac1987a6844dba9b0463030fb3430">mxnet::Engine::DeduplicateVarHandle</a></div><div class="ttdeci">void DeduplicateVarHandle(std::vector&lt; engine::VarHandle &gt; *read_vars, std::vector&lt; engine::VarHandle &gt; *write_vars)</div><div class="ttdef"><b>Definition:</b> engine.h:411</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_ad1064effa398035423931135e6607c5a"><div class="ttname"><a href="classmxnet_1_1Engine.html#ad1064effa398035423931135e6607c5a">mxnet::Engine::CreateOnStart</a></div><div class="ttdeci">CallbackOnStart CreateOnStart(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)</div><div class="ttdoc">factory function to create OnStart callback.</div><div class="ttdef"><b>Definition:</b> engine.h:387</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEventPool_html"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEventPool.html">mxnet::engine::CUDAEventPool</a></div><div class="ttdef"><b>Definition:</b> engine.h:67</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_aac31510c793a12944c33f9cac6150491"><div class="ttname"><a href="classmxnet_1_1Engine.html#aac31510c793a12944c33f9cac6150491">mxnet::Engine::VarHandle</a></div><div class="ttdeci">engine::VarHandle VarHandle</div><div class="ttdoc">Variable pointer.</div><div class="ttdef"><b>Definition:</b> engine.h:224</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1SyncObject_html_a0f52005af80e0b61afa08e1295830a0f"><div class="ttname"><a href="structmxnet_1_1engine_1_1SyncObject.html#a0f52005af80e0b61afa08e1295830a0f">mxnet::engine::SyncObject::reader_events</a></div><div class="ttdeci">std::vector&lt; EventInfo &gt; reader_events</div><div class="ttdef"><b>Definition:</b> engine.h:103</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_ac33a02d6827b43c27c04371d8ef04ef4"><div class="ttname"><a href="classmxnet_1_1Engine.html#ac33a02d6827b43c27c04371d8ef04ef4">mxnet::Engine::set_bulk_size</a></div><div class="ttdeci">virtual int set_bulk_size(int)</div><div class="ttdoc">set maximum limit for bulk size</div><div class="ttdef"><b>Definition:</b> engine.h:434</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html_af71c04da3c374220356efa98cc215590"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html#af71c04da3c374220356efa98cc215590">mxnet::engine::Var::Cast</a></div><div class="ttdeci">T * Cast()</div><div class="ttdoc">cast variable to derived type T</div></div> |
| <div class="ttc" id="aclassmxnet_1_1Engine_html_a53659c833b8cac3cbcd15a152275a809"><div class="ttname"><a href="classmxnet_1_1Engine.html#a53659c833b8cac3cbcd15a152275a809">mxnet::Engine::CreateCallback</a></div><div class="ttdeci">CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)</div><div class="ttdoc">factory function to create OnComplete callback.</div><div class="ttdef"><b>Definition:</b> engine.h:401</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3ba6cd75f41e0ec8d61b0a2f0e20ef6d1e8"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba6cd75f41e0ec8d61b0a2f0e20ef6d1e8">mxnet::FnProperty::kCopyToGPU</a></div><div class="ttdeci">@ kCopyToGPU</div><div class="ttdoc">Copy operation from CPU to other devices.</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html_a5ae9f575784f18598df8824ccd421c18"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html#a5ae9f575784f18598df8824ccd421c18">mxnet::engine::Var::version</a></div><div class="ttdeci">virtual size_t version()</div><div class="ttdef"><b>Definition:</b> engine.h:112</div></div> |
| <div class="ttc" id="ainclude_2mxnet_2base_8h_html"><div class="ttname"><a href="include_2mxnet_2base_8h.html">base.h</a></div><div class="ttdoc">configuration of MXNet as well as basic data structure.</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1SyncObject_html"><div class="ttname"><a href="structmxnet_1_1engine_1_1SyncObject.html">mxnet::engine::SyncObject</a></div><div class="ttdoc">struct containing cuda events and variables needed for the dependencies.</div><div class="ttdef"><b>Definition:</b> engine.h:101</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3ba739f2f416f05f4728c217f09e93958d1"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3ba739f2f416f05f4728c217f09e93958d1">mxnet::FnProperty::kCopyFromGPU</a></div><div class="ttdeci">@ kCopyFromGPU</div><div class="ttdoc">Copy operation from GPU to other devices.</div></div> |
| <div class="ttc" id="anamespacemxnet_html_a998b74220fab2b012cf8a179650e1b3b"><div class="ttname"><a href="namespacemxnet.html#a998b74220fab2b012cf8a179650e1b3b">mxnet::FnProperty</a></div><div class="ttdeci">FnProperty</div><div class="ttdoc">Function property, used to hint what action is pushed to engine.</div><div class="ttdef"><b>Definition:</b> engine.h:191</div></div> |
| <div class="ttc" id="astructmxnet_1_1engine_1_1Var_html_a10a695c677bc0019d1c9eeed6aab7ee4"><div class="ttname"><a href="structmxnet_1_1engine_1_1Var.html#a10a695c677bc0019d1c9eeed6aab7ee4">mxnet::engine::Var::sync_object</a></div><div class="ttdeci">SyncObject sync_object</div><div class="ttdoc">struct containing cuda events and variables needed for the dependencies.</div><div class="ttdef"><b>Definition:</b> engine.h:132</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEvent_html"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEvent.html">mxnet::engine::CUDAEvent</a></div><div class="ttdef"><b>Definition:</b> engine.h:45</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CUDAEventPool_html_a7449bbaca7f9d382999930d38d2710ef"><div class="ttname"><a href="classmxnet_1_1engine_1_1CUDAEventPool.html#a7449bbaca7f9d382999930d38d2710ef">mxnet::engine::CUDAEventPool::GetCounterValue</a></div><div class="ttdeci">uint64_t GetCounterValue() noexcept</div><div class="ttdef"><b>Definition:</b> engine.h:84</div></div> |
| <div class="ttc" id="aclassmxnet_1_1engine_1_1CallbackOnStart_html_aaef3cbe661d77ef471d084feb058a613"><div class="ttname"><a href="classmxnet_1_1engine_1_1CallbackOnStart.html#aaef3cbe661d77ef471d084feb058a613">mxnet::engine::CallbackOnStart::operator()</a></div><div class="ttdeci">void operator()(const dmlc::Error *error=nullptr) const</div><div class="ttdoc">invoke the callback</div><div class="ttdef"><b>Definition:</b> engine.h:150</div></div> |
| <!-- start footer part --> |
| <hr class="footer"/><address class="footer"><small> |
| Generated on Sat Nov 5 2022 01:16:57 for mxnet by  <a href="http://www.doxygen.org/index.html"> |
| <img class="footer" src="doxygen.png" alt="doxygen"/> |
| </a> 1.8.17 |
| </small></address> |
| </body> |
| </html> |