blob: b90de753f67927bf7063208d7670c67e5797025e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Block config to hold an output block shape and a corresponding input block shape"""
from typing import List
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object("contrib.ethosu.cascader.BlockConfig")
class BlockConfig(Object):
    """BlockConfig class.

    Holds an output block shape together with the corresponding input block
    shape and two cycle counts. All four values are passed to the registered
    FFI constructor and read back through read-only properties.

    BlockConfigs are partially ordered by their output shapes only:

    * ``a >= b`` when every output dimension of ``a`` is at least the matching
      dimension of ``b``.
    * ``a < b`` when every output dimension of ``a`` is at most the matching
      dimension of ``b`` and the two shapes are not identical.

    Configs whose output shapes have different lengths are incomparable: these
    comparisons return ``False`` for them. Comparing against a non-BlockConfig
    returns ``NotImplemented``, so Python falls back to the reflected operation
    or raises ``TypeError``.
    """

    def __init__(
        self,
        input_shape: List[int],
        output_shape: List[int],
        compute_cycles: int,
        output_cycles: int,
    ) -> None:
        """Create the BlockConfig via the registered FFI constructor.

        Parameters
        ----------
        input_shape : List[int]
            The input block shape corresponding to ``output_shape``.
        output_shape : List[int]
            The output block shape.
        compute_cycles : int
            Compute cycle count associated with this block config.
            NOTE(review): exact unit/scope (per block?) is defined by the FFI
            side — confirm there.
        output_cycles : int
            Output cycle count associated with this block config.
            NOTE(review): exact unit/scope is defined by the FFI side — confirm.
        """
        self.__init_handle_by_constructor__(
            _ffi_api.BlockConfig, input_shape, output_shape, compute_cycles, output_cycles
        )

    @property
    def input_shape(self) -> List[int]:
        """The input block shape, copied into a new Python list."""
        return list(self._input_shape)

    @property
    def output_shape(self) -> List[int]:
        """The output block shape, copied into a new Python list."""
        return list(self._output_shape)

    @property
    def compute_cycles(self) -> int:
        """The compute cycle count, converted to a Python int."""
        return int(self._compute_cycles)

    @property
    def output_cycles(self) -> int:
        """The output cycle count, converted to a Python int."""
        return int(self._output_cycles)

    def __ge__(self, other: "BlockConfig"):
        """Return True if each output dimension of self is >= other's.

        Returns False when the output shapes differ in length.
        """
        if not isinstance(other, BlockConfig):
            return NotImplemented
        # Each property access builds a fresh list from the FFI object, so
        # read each shape once.
        mine = self.output_shape
        theirs = other.output_shape
        if len(mine) != len(theirs):
            return False
        return all(a >= b for a, b in zip(mine, theirs))

    def __lt__(self, other: "BlockConfig"):
        """Return True if self's output shape is strictly dominated by other's.

        Every output dimension of self must be <= the matching dimension of
        other, and the shapes must not be identical. Returns False when the
        output shapes differ in length.
        """
        if not isinstance(other, BlockConfig):
            return NotImplemented
        mine = self.output_shape
        theirs = other.output_shape
        if len(mine) != len(theirs):
            return False
        # `<` must be irreflexive. The previous `other >= self` made any config
        # "less than" one with an identical output shape, including itself.
        return mine != theirs and all(a <= b for a, b in zip(mine, theirs))

    def __repr__(self) -> str:
        return f"BlockConfig(output_shape={self.output_shape})"