Hexagon conv2d -> conv2d schedules
This is a baseline 1x1 conv2d -> 1x1 conv2d schedule for Hexagon.
pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-1-1-float32-llvm]"
Parameter | Value |
---|---|
Batch | 1 |
Input Size | 64x64 |
Input Channel | 128 |
Conv2d #1 Pad | 0 |
Conv2d #1 Stride | 1 |
Conv2d #1 Kernel Size | 1 |
Conv2d #1 Output Channel | 128 |
Conv2d #2 Stride | 1 |
Conv2d #2 Kernel Size | 1 |
Conv2d #2 Output Channel | 128 |
k_split | 1 |
h_split | 1 |
Constant | Value |
---|---|
Conv2d #2 Pad | 0 |
Conv2d #1 Dilation | 1 |
Conv2d #2 Dilation | 1 |
The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors.
Tensor | Type | Layout | Shape | Logical Layout | Logical Shape |
---|---|---|---|---|---|
Input | Logical | NHWC | [1, 64, 64, 128] | ||
Padded Input | Logical | NHWC | [1, 64, 64, 128] | ||
Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] |
Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] |
Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
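For reference, the NHWC -> NHWC8h8w32c packing can be expressed as a reshape plus transpose. This is a minimal numpy sketch of the transform for the shapes above (illustrative only, not the code the test uses):

```python
import numpy as np

# Split H, W, and C into (outer, inner) pairs, then reorder so the inner
# 8h / 8w / 32c axes come last.
x = np.arange(1 * 64 * 64 * 128, dtype="float32").reshape(1, 64, 64, 128)  # NHWC
packed = x.reshape(1, 8, 8, 8, 8, 4, 32)        # n, ho, hi, wo, wi, co, ci
packed = packed.transpose(0, 1, 3, 5, 2, 4, 6)  # n, ho, wo, co, hi, wi, ci
assert packed.shape == (1, 8, 8, 4, 8, 8, 32)   # NHWC8h8w32c
```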
This is the conv2d compute schedule:
```
for (ko.outer: int32, 0, 4) {
  for (ho.outer: int32, 0, 8) {
    // input cache read
    for (ko.outer_1: int32, 0, 4) {
      // filter #1 cache read
      // conv2d #1
      for (wo: int32, 0, 8) {
        for (rc.outer: int32, 0, 4) {
          for (hi: int32, 0, 8) {
            for (wi: int32, 0, 8) {
              for (ki: int32, 0, 32) {
                for (rc.inner: int32, 0, 32) {
    } // end ko.outer_1
    // filter #2 cache read
    // conv2d #2
    for (wo.c: int32, 0, 8) {
      for (rc.outer_1: int32, 0, 4) {
        for (hi.c: int32, 0, 8) {
          for (wi.c: int32, 0, 8) {
            for (ki.c: int32, 0, 32) {
              for (rc.inner_1: int32, 0, 32) {
    // write back output cache
  } // end ho.outer
} // end ko.outer
```
Note that conv2d #1 has an independent loop over the channel-out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2; each conv2d computes over all of its input channels, so we must compute all output channels of conv2d #1 before we compute conv2d #2.
```
for (ko.outer_1: int32, 0, 4) {
```
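A small numpy sketch of why this ordering is required (illustrative shapes only): a 1x1 convolution is a matmul over channels, so every output element of conv2d #2 reads all 128 channels produced by conv2d #1.

```python
import numpy as np

x  = np.random.rand(1, 64, 64, 128).astype("float32")  # NHWC input
w1 = np.random.rand(128, 128).astype("float32")        # conv2d #1, [out_c, in_c]
w2 = np.random.rand(128, 128).astype("float32")        # conv2d #2, [out_c, in_c]
t = np.einsum("nhwc,kc->nhwk", x, w1)  # conv2d #1 (1x1)
y = np.einsum("nhwc,kc->nhwk", t, w2)  # conv2d #2 reduces over every channel of t
```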
Input Cache
We compute over the WC8h8w32c portion of the input, so we need 8 * 4 * 8 * 8 * 32 = 65,536 float32 elements (256 KB) for the input cache.
```
allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global;
```
Filter Cache
We compute over the IHW8i32o4i portion of each filter, so we need 4 * 1 * 1 * 8 * 32 * 4 = 4,096 float32 elements (16 KB) for the filter cache.
```
allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global;
```
Note that there is just one filter cache, which is reused for conv2d #1 / filter #1 and conv2d #2 / filter #2.
Output Cache
We compute over the WK8h8w32k portion of the output, where k denotes the output channel. The output cache is computed for each `ko.outer`, which means it should be W * 8h * 8w * 32k = 8 * 8 * 8 * 32 = 16,384 float32 elements (64 KB). And, in fact, this is the case for a single conv2d. But, as already noted, for this conv2d -> conv2d case “the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2”. This means that the output cache must grow accordingly to K * W * 8h * 8w * 32k = 4 * 8 * 8 * 8 * 32 = 65,536 float32 elements (256 KB). There is a temporary allocation to store the results of conv2d #1:
```
allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global;
```
Note that the input cache is reused to store the results of conv2d #2.
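Putting the three allocations together (a summary of the arithmetic above; sizes in float32 elements, multiply by 4 for bytes):

```python
# Cache sizes for the baseline case, matching the allocate() statements:
input_cache  = 8 * 4 * 8 * 8 * 32      # W * C * 8h * 8w * 32c = 65536 (256 KB)
filter_cache = 4 * 1 * 1 * 8 * 32 * 4  # I * H * W * 8i * 32o * 4i = 4096 (16 KB)
temp_output  = 4 * 8 * 8 * 8 * 32      # K * W * 8h * 8w * 32k = 65536 (256 KB)
```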
```
slice 0: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global
slice 1: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global
```
In this case the store from slice 0 (`cache0 -> store -> global`) can potentially block the load in slice 1 (`global -> load -> cache0`).
StorageRewrite is responsible for planning these caches; we'll need to understand how to avoid this reuse in the async case.
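One way to break this dependency would be to ping-pong between two input/output caches so that consecutive slices never share a buffer; this is a hypothetical sketch (not what StorageRewrite currently plans), and `load`, `store`, `conv2d_0`, `conv2d_1`, and the buffer arguments are assumed helpers:

```python
def pipeline(slices, cache0a, cache0b, cache1):
    caches = [cache0a, cache0b]
    for i, slc in enumerate(slices):
        cache0 = caches[i % 2]   # alternate buffers per slice
        load(slc, cache0)        # slice i+1 load no longer waits on slice i store
        conv2d_0(cache0, cache1)
        conv2d_1(cache1, cache0)
        store(cache0, slc)
```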
```
primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c
             placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i
             placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i
             placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc
  buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} {
  allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global;
  allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global;
  allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global;
  for (ko.outer: int32, 0, 4) {
    for (ho.outer: int32, 0, 8) {
      // input cache read
      for (wo: int32, 0, 8) {
        for (co: int32, 0, 4) {
          for (hi: int32, 0, 8) {
            for (wi: int32, 0, 8) {
              for (ci: int32, 0, 32) {
                packed_input.global[(((((wo*8192) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] =
                  (float32*)placeholder_8[((((((ho.outer*65536) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)]
              }
            }
          }
        }
      }
      // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2
      for (ko.outer_1: int32, 0, 4) {
        // filter #1 cache read
        for (co: int32, 0, 4) {
          for (cio: int32, 0, 8) {
            for (ki: int32, 0, 32) {
              for (cii: int32, 0, 4) {
                packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] =
                  (float32*)placeholder_7[(((((ko.outer_1*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)]
              }
            }
          }
        }
        // conv2d #1
        for (wo: int32, 0, 8) {
          // init temp output to zero
          for (hi.init: int32, 0, 8) {
            for (wi.init: int32, 0, 8) {
              for (ki.init: int32, 0, 32) {
                temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32
              }
            }
          }
          // compute
          for (rc.outer: int32, 0, 4) {
            for (hi: int32, 0, 8) {
              for (wi: int32, 0, 8) {
                for (ki: int32, 0, 32) {
                  for (rc.inner: int32, 0, 32) {
                    temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] = (
                      (float32*)temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] + (
                        (float32*)packed_input.global[(((((wo*8192) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] *
                        (float32*)packed_filter.global[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))]
                      )
                    )
                  }
                }
              }
            }
          }
        }
      }
      // filter #2 cache read
      // NOTE: reusing same filter cache
      for (co: int32, 0, 4) {
        for (cio: int32, 0, 8) {
          for (ki: int32, 0, 32) {
            for (cii: int32, 0, 4) {
              packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] =
                (float32*)placeholder_6[(((((ko.outer*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)]
            }
          }
        }
      }
      // conv2d #2
      for (wo.c: int32, 0, 8) {
        // init output cache to zero
        // NOTE: reusing the input cache as the output cache
        for (hi.c.init: int32, 0, 8) {
          for (wi.c.init: int32, 0, 8) {
            for (ki.c.init: int32, 0, 32) {
              packed_input.global[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
            }
          }
        }
        // compute
        for (rc.outer_1: int32, 0, 4) {
          for (hi.c: int32, 0, 8) {
            for (wi.c: int32, 0, 8) {
              for (ki.c: int32, 0, 32) {
                for (rc.inner_1: int32, 0, 32) {
                  packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = (
                    (float32*)packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + (
                      (float32*)temp_output[(((((wo.c*8192) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] *
                      (float32*)packed_filter.global[((((rc.outer_1*1024) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))]
                    )
                  )
                }
              }
            }
          }
        }
      }
      // write back output cache
      for (wo_1: int32, 0, 8) {
        for (hi_1: int32, 0, 8) {
          for (wi_1: int32, 0, 8) {
            for (ki_1: int32, 0, 32) {
              output_2[((((((ho.outer*65536) + (wo_1*8192)) + (ko.outer*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] =
                (float32*)packed_input.global[((((wo_1*2048) + (hi_1*256)) + (wi_1*32)) + ki_1)]
            }
          }
        }
      }
    }
  }
}
```
This schedule uses the parameters `k_split` and `h_split`, which create loop splits on the outer channel-out `ko` and height `ho` loops, producing `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer`, which means that the cache allocations grow in proportion to the `k_split` and `h_split` factors.
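A self-contained TE sketch of how `k_split` / `h_split` become outer/inner loop pairs (the compute here is illustrative; the real schedule lives in the test's TE code):

```python
import tvm
from tvm import te

k_split, h_split = 2, 2
X = te.placeholder((1, 8, 8, 4, 8, 8, 32), name="X")
Y = te.compute(X.shape, lambda n, ho, wo, ko, hi, wi, ki: X[n, ho, wo, ko, hi, wi, ki], name="Y")
s = te.create_schedule(Y.op)
n, ho, wo, ko, hi, wi, ki = s[Y].op.axis
koo, koi = s[Y].split(ko, nparts=k_split)  # -> ko.outer / ko.inner
hoo, hoi = s[Y].split(ho, nparts=h_split)  # -> ho.outer / ho.inner
s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki)
print(tvm.lower(s, [X, Y], simple_mode=True))
```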
pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-2-2-float32-llvm]"
Parameter | Value |
---|---|
Batch | 1 |
Input Size | 64x64 |
Input Channel | 128 |
Conv2d #1 Pad | 0 |
Conv2d #1 Stride | 1 |
Conv2d #1 Kernel Size | 1 |
Conv2d #1 Output Channel | 128 |
Conv2d #2 Stride | 1 |
Conv2d #2 Kernel Size | 1 |
Conv2d #2 Output Channel | 128 |
k_split | 2 ^ |
h_split | 2 ^ |
^ Changes from above
Constant | Value |
---|---|
Conv2d #2 Pad | 0 |
Conv2d #1 Dilation | 1 |
Conv2d #2 Dilation | 1 |
The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors.
Tensor | Type | Layout | Shape | Logical Layout | Logical Shape |
---|---|---|---|---|---|
Input | Logical | NHWC | [1, 64, 64, 128] | ||
Padded Input | Logical | NHWC | [1, 64, 64, 128] | ||
Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] |
Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] |
Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
This is the conv2d compute schedule:
```
for (ko.outer: int32, 0, 2) {
  for (ho.outer: int32, 0, 4) {
    // input cache read
    for (ho.inner: int32, 0, 2) {
      ...
    }
    for (ko.outer_1: int32, 0, 2) {
      // filter #1 cache read
      for (ko.inner: int32, 0, 2) {
        ...
      }
      // conv2d #1
      for (ko.inner: int32, 0, 2) {
        for (ho.inner: int32, 0, 2) {
          for (wo: int32, 0, 8) {
            for (rc.outer: int32, 0, 4) {
              for (hi: int32, 0, 8) {
                for (wi: int32, 0, 8) {
                  for (ki: int32, 0, 32) {
                    for (rc.inner: int32, 0, 32) {
    } // end ko.outer_1
    // filter #2 cache read
    for (ko.inner: int32, 0, 2) {
      ...
    }
    // conv2d #2
    for (ko.c.inner: int32, 0, 2) {
      for (ho.c.inner: int32, 0, 2) {
        for (wo.c: int32, 0, 8) {
          for (rc.outer_1: int32, 0, 4) {
            for (hi.c: int32, 0, 8) {
              for (wi.c: int32, 0, 8) {
                for (ki.c: int32, 0, 32) {
                  for (rc.inner_1: int32, 0, 32) {
    // write back output cache
  } // end ho.outer
} // end ko.outer
```
The major change here versus above is the presence of `inner` loops for both the channel-out `ko` and height `ho` dimensions, created from the `k_split` and `h_split` schedule parameters respectively, for example:
```
for (ko.c.inner: int32, 0, 2) {
  for (ho.c.inner: int32, 0, 2) {
```
The effect of this change is increased cache usage, given where the caches are computed in the schedule. Specifically, the input cache is now computed at `ho.inner` and the filter caches are computed at `ko.inner`, which grows the size of each cache. Details below.
(Same as above) Note that conv2d #1 has an independent loop over the channel-out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2; each conv2d computes over all of its input channels, so we must compute all output channels of conv2d #1 before we compute conv2d #2.
```
for (ko.outer_1: int32, 0, 2) {
```
Input Cache
The input cache grows by a factor of `h_split = 2` compared with above:
```
allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global;
```
Filter Cache
The filter cache grows by a factor of `k_split = 2` compared with above:
```
allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global;
```
(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2.
Output Cache
The output cache grows by a factor of `h_split = 2` compared with above (note the new `ho.inner` dimension in the `temp_output` indexing in the TIR below; with `k_split = h_split = 2` here the factor is the same either way):
```
allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global;
```
(Same as above) Note that the input cache is reused to store the results of conv2d #2.
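Summarizing the growth relative to the baseline (sizes in float32 elements, matching the allocations above):

```python
input_cache  = 2 * (8 * 4 * 8 * 8 * 32)      # x h_split -> 131072
filter_cache = 2 * (4 * 1 * 1 * 8 * 32 * 4)  # x k_split ->   8192
temp_output  = 2 * (4 * 8 * 8 * 8 * 32)      # x h_split -> 131072
```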
```
primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c
             placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i
             placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i
             placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc
  buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} {
  allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global;
  allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global;
  allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global;
  for (ko.outer: int32, 0, 2) {
    for (ho.outer: int32, 0, 4) {
      // input cache read
      for (ho.inner: int32, 0, 2) {
        for (wo: int32, 0, 8) {
          for (co: int32, 0, 4) {
            for (hi: int32, 0, 8) {
              for (wi: int32, 0, 8) {
                for (ci: int32, 0, 32) {
                  packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] =
                    (float32*)placeholder_8[(((((((ho.outer*131072) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)]
                }
              }
            }
          }
        }
      }
      // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2
      for (ko.outer_1: int32, 0, 2) {
        // filter #1 cache read
        for (ko.inner: int32, 0, 2) {
          for (co: int32, 0, 4) {
            for (cio: int32, 0, 8) {
              for (ki: int32, 0, 32) {
                for (cii: int32, 0, 4) {
                  packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] =
                    (float32*)placeholder_7[((((((ko.outer_1*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)]
                }
              }
            }
          }
        }
        // conv2d #1
        for (ko.inner: int32, 0, 2) {
          for (ho.inner: int32, 0, 2) {
            for (wo: int32, 0, 8) {
              // init temp output to zero
              for (hi.init: int32, 0, 8) {
                for (wi.init: int32, 0, 8) {
                  for (ki.init: int32, 0, 32) {
                    temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32
                  }
                }
              }
              // compute
              for (rc.outer: int32, 0, 4) {
                for (hi: int32, 0, 8) {
                  for (wi: int32, 0, 8) {
                    for (ki: int32, 0, 32) {
                      for (rc.inner: int32, 0, 32) {
                        temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = (
                          (float32*)temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + (
                            (float32*)packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] *
                            (float32*)packed_filter.global[(((((ko.inner*4096) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))]
                          )
                        )
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      // filter #2 cache read
      // NOTE: reusing same filter cache
      for (ko.inner: int32, 0, 2) {
        for (co: int32, 0, 4) {
          for (cio: int32, 0, 8) {
            for (ki: int32, 0, 32) {
              for (cii: int32, 0, 4) {
                packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] =
                  (float32*)placeholder_6[((((((ko.outer*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)]
              }
            }
          }
        }
      }
      // conv2d #2
      for (ko.c.inner: int32, 0, 2) {
        for (ho.c.inner: int32, 0, 2) {
          for (wo.c: int32, 0, 8) {
            // init output cache to zero
            // NOTE: reusing the input cache as the output cache
            for (hi.c.init: int32, 0, 8) {
              for (wi.c.init: int32, 0, 8) {
                for (ki.c.init: int32, 0, 32) {
                  packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
                }
              }
            }
            // compute
            for (rc.outer_1: int32, 0, 4) {
              for (hi.c: int32, 0, 8) {
                for (wi.c: int32, 0, 8) {
                  for (ki.c: int32, 0, 32) {
                    for (rc.inner_1: int32, 0, 32) {
                      packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = (
                        (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + (
                          (float32*)temp_output[((((((ho.c.inner*65536) + (wo.c*8192)) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] *
                          (float32*)packed_filter.global[(((((ko.c.inner*4096) + (rc.outer_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))]
                        )
                      )
                    }
                  }
                }
              }
            }
          }
        }
      }
      // write back output cache
      for (ko.inner_1: int32, 0, 2) {
        for (ho.inner_1: int32, 0, 2) {
          for (wo_1: int32, 0, 8) {
            for (hi_1: int32, 0, 8) {
              for (wi_1: int32, 0, 8) {
                for (ki_1: int32, 0, 32) {
                  output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] =
                    (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)]
                }
              }
            }
          }
        }
      }
    }
  }
}
```
Change from a 1x1 filter to a 3x3 filter.
pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-3-128-1-3-128-2-2-float32-llvm]"
Parameter | Value |
---|---|
Batch | 1 |
Input Size | 64x64 |
Input Channel | 128 |
Conv2d #1 Pad | 0 |
Conv2d #1 Stride | 1 |
Conv2d #1 Kernel Size | 3 ^ |
Conv2d #1 Output Channel | 128 |
Conv2d #2 Stride | 1 |
Conv2d #2 Kernel Size | 3 ^ |
Conv2d #2 Output Channel | 128 |
k_split | 2 |
h_split | 2 |
^ Changes from above
Constant | Value |
---|---|
Conv2d #2 Pad | 0 |
Conv2d #1 Dilation | 1 |
Conv2d #2 Dilation | 1 |
The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors.
Tensor | Type | Layout | Shape | Logical Layout | Logical Shape |
---|---|---|---|---|---|
Input | Logical | NHWC | [1, 64, 64, 128] | ||
Padded Input | Logical | NHWC | [1, 64, 64, 128] | ||
Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] |
Filter 1 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] |
Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 62, 62, 128] |
Filter 2 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] |
Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 60, 60, 128] |
This is the conv2d compute schedule:
```
for (ko.outer: int32, 0, 2) {
  for (ho.outer: int32, 0, 4) {
    for (ko.outer_1: int32, 0, 2) {
      for (ho.outer_1: int32, 0, 2) {
        // input cache read
        for (ho.inner: int32, 0, 3) {
          if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) {
            ...
          }
        }
        // filter #1 cache read
        for (ko.inner: int32, 0, 2) {
          ...
        }
        // conv2d #1
        for (ko.inner: int32, 0, 2) {
          for (ho.inner: int32, 0, 2) {
            for (wo: int32, 0, 8) {
              if (((ho.outer_1*2) + ho.inner) < 3) {
                if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) {
                  for (rc.outer: int32, 0, 4) {
                    for (hi: int32, 0, 8) {
                      for (wi: int32, 0, 8) {
                        for (rh: int32, 0, 3) {
                          for (rw: int32, 0, 3) {
                            for (ki: int32, 0, 32) {
                              for (rc.inner: int32, 0, 32) {
      } // end ho.outer_1
    } // end ko.outer_1
    // filter #2 cache read
    for (ko.inner: int32, 0, 2) {
      ...
    }
    // conv2d #2
    for (ko.c.inner: int32, 0, 2) {
      for (ho.c.inner: int32, 0, 2) {
        for (wo.c: int32, 0, 8) {
          for (rc.outer_1: int32, 0, 4) {
            for (hi.c: int32, 0, 8) {
              for (wi.c: int32, 0, 8) {
                for (rh_1: int32, 0, 3) {
                  for (rw_1: int32, 0, 3) {
                    for (ki.c: int32, 0, 32) {
                      for (rc.inner_1: int32, 0, 32) {
    // write back output cache
  } // end ho.outer
} // end ko.outer
```
There are two major changes here.

First, the 3x3 filter introduces kernel height `rh` and width `rw` iterators in the filter cache reads and the conv2d computes, for example:

```
for (rh_1: int32, 0, 3) {
  for (rw_1: int32, 0, 3) {
```

The effect of this change is to grow the filter cache by the size of the kernel. Details below.
Second, the schedule must still produce `h_split` (2) “full width” and “full channel depth” slices from each conv2d. Given the 3x3 kernel size, there are several changes to the schedule regarding the handling of the height dimension.

First, notice that in order to produce `h_split` (2) “full width” and “full channel depth” slices from conv2d #1 we need `h_split + 1` (3) “full width” and “full channel depth” slices of the input. This is because a 3x3 kernel (as opposed to a 1x1 kernel) creates a many-to-one relationship between the spatial coordinates of the input and those of the output. To illustrate, the 3x3 kernel will “fall off the bottom” of the 2nd input slice, requiring values from the vertically adjacent 3rd input slice in order to produce the 2nd full output slice. Hence, we have the following input cache read over `h_split + 1` (3) input slices:
```
for (ho.inner: int32, 0, 3) {
  if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) {
```
The `if` statement above guards against prefetching the vertically adjacent slice at the “bottom” of the input, since it does not exist.
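A worked version of this height accounting (slices are 8 rows tall; a 3x3 kernel with no padding reads 2 extra rows beyond its output rows):

```python
import math

h_split   = 2
out_rows  = h_split * 8             # 16 rows of conv2d #1 output per iteration
in_rows   = out_rows + (3 - 1)      # 18 input rows are read to produce them
in_slices = math.ceil(in_rows / 8)  # 3 slices, i.e. h_split + 1
```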
Second, notice that conv2d #1 must produce sufficient output in the height dimension before conv2d #2 can proceed. This is similar to the requirement on conv2d #1 in regard to the channel-out dimension, but different in that we do not require all output in the height dimension, only sufficient output. How much output in the height dimension is required? The intuitive guess might be `h_split + 1` (3) slices, but that is wrong: the output spatial coordinates “shrink” relative to the input coordinate space due to the lack of padding. Hence 2 output slices from conv2d #1 are sufficient as input to compute 2 output slices of conv2d #2, and we get the following independent loop over `ho.outer_1` for conv2d #1:
```
for (ho.outer_1: int32, 0, 2) {
```
There are similar `if` statements in the conv2d compute schedule to prevent computing off the “bottom” of the input and output.
(Same as above) Note that conv2d #1 has an independent loop over the channel-out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2; each conv2d computes over all of its input channels, so we must compute all output channels of conv2d #1 before we compute conv2d #2.
```
for (ko.outer_1: int32, 0, 2) {
```
Input Cache
The input cache grows to hold the vertically adjacent slice:
```
allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global;
```
Filter Cache
The filter cache grows to hold the 3x3 filter:
```
allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global;
```
(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2.
Output Cache
The output cache scales with the input cache:
```
allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global;
```
(Same as above) Note that the input cache is reused to store the results of conv2d #2.
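Summarizing the sizes for this 3x3 case (in float32 elements, matching the allocations above):

```python
input_cache  = 3 * (8 * 4 * 8 * 8 * 32)    # h_split + 1 = 3 slices -> 196608
filter_cache = 2 * 4 * 3 * 3 * 8 * 32 * 4  # k_split * I * 3h * 3w * 8i * 32o * 4i -> 73728
temp_output  = 3 * (8 * 4 * 8 * 8 * 32)    # scales with the input cache -> 196608
```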
```
primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c
             placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i
             placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i
             placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc
  buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} {
  allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global;
  allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global;
  allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global;
  for (ko.outer: int32, 0, 2) {
    for (ho.outer: int32, 0, 4) {
      // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2
      for (ko.outer_1: int32, 0, 2) {
        // NOTE: compute enough height of conv2d #1 before computing conv2d #2
        for (ho.outer_1: int32, 0, 2) {
          // input cache read
          for (ho.inner: int32, 0, 3) {
            if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) {
              for (wo: int32, 0, 8) {
                for (co: int32, 0, 4) {
                  for (hi: int32, 0, 8) {
                    for (wi: int32, 0, 8) {
                      for (ci: int32, 0, 32) {
                        packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] =
                          (float32*)placeholder_8[((((((((ho.outer_1*131072) + (ho.outer*131072)) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)]
                      }
                    }
                  }
                }
              }
            }
          }
          // filter #1 cache read
          for (ko.inner: int32, 0, 2) {
            for (co: int32, 0, 4) {
              for (rh: int32, 0, 3) {
                for (rw: int32, 0, 3) {
                  for (cio: int32, 0, 8) {
                    for (ki: int32, 0, 32) {
                      for (cii: int32, 0, 4) {
                        packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] =
                          (float32*)placeholder_7[((((((((ko.outer_1*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)]
                      }
                    }
                  }
                }
              }
            }
          }
          // conv2d #1
          for (ko.inner: int32, 0, 2) {
            for (ho.inner: int32, 0, 2) {
              for (wo: int32, 0, 8) {
                // init temp output to zero
                if (((ho.outer_1*2) + ho.inner) < 3) {
                  for (hi.init: int32, 0, 8) {
                    for (wi.init: int32, 0, 8) {
                      for (ki.init: int32, 0, 32) {
                        temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32
                      }
                    }
                  }
                }
                // compute
                if (((ho.outer_1*2) + ho.inner) < 3) {
                  if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) {
                    for (rc.outer: int32, 0, 4) {
                      for (hi: int32, 0, 8) {
                        for (wi: int32, 0, 8) {
                          for (rh: int32, 0, 3) {
                            for (rw: int32, 0, 3) {
                              for (ki: int32, 0, 32) {
                                for (rc.inner: int32, 0, 32) {
                                  temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = (
                                    (float32*)temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + (
                                      (float32*)packed_input.global[((((((((floordiv((hi + rh), 8)*65536) + (ho.inner*65536)) + (floordiv((wi + rw), 8)*8192)) + (wo*8192)) + (rc.outer*2048)) + (floormod((hi + rh), 8)*256)) + (floormod((wi + rw), 8)*32)) + rc.inner)] *
                                      (float32*)packed_filter.global[(((((((ko.inner*36864) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))]
                                    )
                                  )
                                }
                              }
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      // filter #2 cache read
      // NOTE: reusing same filter cache
      for (ko.inner: int32, 0, 2) {
        for (co: int32, 0, 4) {
          for (rh: int32, 0, 3) {
            for (rw: int32, 0, 3) {
              for (cio: int32, 0, 8) {
                for (ki: int32, 0, 32) {
                  for (cii: int32, 0, 4) {
                    packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] =
                      (float32*)placeholder_6[((((((((ko.outer*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)]
                  }
                }
              }
            }
          }
        }
      }
      // conv2d #2
      for (ko.c.inner: int32, 0, 2) {
        for (ho.c.inner: int32, 0, 2) {
          for (wo.c: int32, 0, 8) {
            // init output cache to zero
            // NOTE: reusing the input cache as the output cache
            for (hi.c.init: int32, 0, 8) {
              for (wi.c.init: int32, 0, 8) {
                for (ki.c.init: int32, 0, 32) {
                  packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
                }
              }
            }
            // compute
            for (rc.outer_1: int32, 0, 4) {
              for (hi.c: int32, 0, 8) {
                for (wi.c: int32, 0, 8) {
                  for (rh_1: int32, 0, 3) {
                    for (rw_1: int32, 0, 3) {
                      for (ki.c: int32, 0, 32) {
                        for (rc.inner_1: int32, 0, 32) {
                          packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = (
                            (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + (
                              (float32*)temp_output[((((((((floordiv((hi.c + rh_1), 8)*65536) + (ho.c.inner*65536)) + (floordiv((wi.c + rw_1), 8)*8192)) + (wo.c*8192)) + (rc.outer_1*2048)) + (floormod((hi.c + rh_1), 8)*256)) + (floormod((wi.c + rw_1), 8)*32)) + rc.inner_1)] *
                              (float32*)packed_filter.global[(((((((ko.c.inner*36864) + (rc.outer_1*9216)) + (rh_1*3072)) + (rw_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))]
                            )
                          )
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      // write back output cache
      for (ko.inner_1: int32, 0, 2) {
        for (ho.inner_1: int32, 0, 2) {
          for (wo_1: int32, 0, 8) {
            for (hi_1: int32, 0, 8) {
              for (wi_1: int32, 0, 8) {
                for (ki_1: int32, 0, 32) {
                  output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] =
                    (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)]
                }
              }
            }
          }
        }
      }
    }
  }
}
```