| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition |
| """Schedule for reduce operators""" |
| from __future__ import absolute_import as _abs |
| import tvm |
| from tvm import te |
| from .. import tag |
| from .injective import schedule_injective_from_existing |
| |
| |
def _schedule_reduce(op, sch, is_idx_reduce=False):
    """Schedule a single reduction operator with a cross-thread reduction.

    The reduce axes are fused, split by ``num_thread`` and the inner part is
    rfactor-ed; the remaining reduce axis is bound to ``threadIdx.x``. When
    spatial output axes exist they are fused and split across
    ``threadIdx.y`` / ``blockIdx.x``.

    Parameters
    ----------
    op : tvm.te.Operation
        The operator to schedule. For an index reduction this is the op that
        consumes the (multi-output) reduction, which is ``op.input_tensors[0]``.
    sch : tvm.te.Schedule
        The schedule, modified in place.
    is_idx_reduce : bool
        Whether ``op`` is an index reduction (tag ``comm_reduce_idx``).

    Returns
    -------
    sch : tvm.te.Schedule
        The same schedule; if the op has no reduce axis, the return value of
        ``schedule_injective_from_existing`` instead.
    """
    if is_idx_reduce:
        # ``op`` only consumes the reduction; the reduction is its first input.
        data_out = op.input_tensors[0]
    else:
        data_out = op.output(0)

    # Nothing to reduce: treat it as an ordinary injective op.
    if not sch[data_out].op.reduce_axis:
        return schedule_injective_from_existing(sch, op.output(0))

    if len(sch[data_out].op.axis) > 0:
        # Partial reduction: spatial outputs remain, so a 2-D thread layout is
        # used (threadIdx.x over the reduction, threadIdx.y over outputs).
        all_reduce = False
        num_thread = 32
        target = tvm.target.Target.current()
        if target and target.kind.name == "opencl":
            # With 32 threads per dimension, CL_INVALID_WORK_GROUP_SIZE was
            # observed when running test_topi_reduce.py; the root cause is
            # unknown. NOTE(review): possibly a 32 x 32 work-group exceeds the
            # device's maximum work-group size -- unconfirmed.
            num_thread = 16
        block_x = te.thread_axis("blockIdx.x")
        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
    else:
        # Reduction to a scalar: use as many threads as the current target
        # allows. A target must be in scope (allow_none=False).
        all_reduce = True
        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

    # Fuse all reduce axes and split by the thread count. Rfactor the inner
    # part so each thread produces a partial result that is then combined
    # across threadIdx.x.
    fused_reduce = sch[data_out].fuse(*sch[data_out].op.reduce_axis)
    _, ki = sch[data_out].split(fused_reduce, factor=num_thread)
    if is_idx_reduce:
        # The index reduction has two outputs, so rfactor returns two tensors;
        # the first one is enough to place the rfactor stage.
        data_out_rf, _ = sch.rfactor(data_out, ki)
    else:
        data_out_rf = sch.rfactor(data_out, ki)
    # Re-read after rfactor: the stage's op (and its reduce axes) changed.
    tx = sch[data_out].op.reduce_axis[0]
    sch[data_out].bind(tx, thread_x)
    sch[data_out_rf].compute_at(sch[data_out], tx)

    if is_idx_reduce:
        real_output = op.output(0)
        temp_idx_input = data_out.op.output(0)
        temp_val_input = data_out.op.output(1)
    else:
        real_output = data_out

    if not all_reduce:
        # Fuse the spatial output axes and split them: the inner part maps to
        # threadIdx.y, the outer part to blockIdx.x.
        fused_outer = sch[real_output].fuse(*sch[real_output].op.axis)
        bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)

        sch[real_output].bind(outer_in, thread_y)
        sch[real_output].bind(bx, block_x)
        if is_idx_reduce:
            # Both tensors are outputs of the same op, so these two calls
            # address the same stage.
            sch[temp_idx_input].compute_at(sch[real_output], outer_in)
            sch[temp_val_input].compute_at(sch[real_output], outer_in)
    elif is_idx_reduce:
        spatial_axis = sch[real_output].fuse(*sch[real_output].op.axis)
        sch[real_output].bind(spatial_axis, te.thread_axis("blockIdx.x"))
        sch[temp_idx_input].compute_at(sch[real_output], spatial_axis)
        sch[temp_val_input].compute_at(sch[real_output], spatial_axis)

    # Only threadIdx.x == 0 stores the final value, so it is written once
    # rather than once per thread.
    sch[real_output].set_store_predicate(thread_x.equal(0))
    return sch
| |
| |
def schedule_reduce(outs):
    """Schedule for inject->reduce->bcast ops.

    Walks the graph starting at ``outs[0]``:

    * broadcast/elementwise consumers of the reduction get an injective
      schedule;
    * reductions (``comm_reduce`` / ``comm_reduce_idx``) get a cross-thread
      reduction schedule;
    * injective producers feeding a reduction are inlined.

    Every operator is handled at most once, even when it is reachable through
    several paths.

    NOTE(review): only ``outs[0]`` is traversed; any other outputs keep their
    default schedule.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of reduce in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.

    Raises
    ------
    RuntimeError
        If an operator on the traversal path has an unsupported tag.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    sch = te.create_schedule([x.op for x in outs])
    # Operators already handled; checked on entry to both traversal helpers.
    scheduled_ops = []

    def traverse_before_reduce(operator):
        """Inline the injective producers that feed a reduction."""
        if operator in scheduled_ops or isinstance(operator, tvm.te.PlaceholderOp):
            return
        if not tag.is_injective(operator.tag):
            raise RuntimeError("Unsupported operator: %s" % operator.tag)
        sch[operator].compute_inline()
        for tensor in operator.input_tensors:
            traverse_before_reduce(tensor.op)
        scheduled_ops.append(operator)

    def traverse_after_reduce(operator):
        """Schedule the reduction and the ops downstream of it."""
        # In a diamond-shaped graph the same op (including a reduction) can be
        # reached through several consumers. Scheduling a reduction twice would
        # re-fuse/re-split axes that are already bound, so visit each op once.
        if operator in scheduled_ops:
            return
        if tag.is_broadcast(operator.tag):
            schedule_injective_from_existing(sch, operator.output(0))
            for tensor in operator.input_tensors:
                traverse_after_reduce(tensor.op)
        elif operator.tag == "comm_reduce":
            _schedule_reduce(operator, sch, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                traverse_before_reduce(tensor.op)
        elif operator.tag == "comm_reduce_idx":
            _schedule_reduce(operator, sch, is_idx_reduce=True)
            # The producers to inline sit behind the multi-output reduction
            # that ``operator`` consumes.
            for tensor in operator.input_tensors[0].op.input_tensors:
                traverse_before_reduce(tensor.op)
        elif isinstance(operator, tvm.te.PlaceholderOp):
            pass
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    traverse_after_reduce(outs[0].op)
    return sch