blob: b7b96ae4efa9cf7350fdf20b21f771da99c74029 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import pytest
from tvm.driver.tvmc.shape_parser import parse_shape_string
def test_shape_parser():
    """A single well-formed ``name:[dims]`` entry parses to a one-item dict."""
    expected = {"input": [10, 10, 10]}
    parsed = parse_shape_string("input:[10,10,10]")
    assert parsed == expected
def test_alternate_syntax():
    """An input name containing a colon (``input:0``) is kept intact as the key."""
    expected = {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]}
    parsed = parse_shape_string("input:0:[10,10,10] input2:[20,20,20,20]")
    assert parsed == expected
@pytest.mark.parametrize(
    "shape_string",
    [
        "input:[10,10,10] input2:[20,20,20,20]",
        "input: [10, 10, 10] input2: [20, 20, 20, 20]",
        "input:[10,10,10],input2:[20,20,20,20]",
    ],
)
def test_alternate_syntaxes(shape_string):
    """Each spelling of the same two shapes must parse to the same dict.

    The cases cover space-separated entries, entries with extra spaces
    after colons and commas, and comma-separated entries.
    """
    expected = {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
    assert parse_shape_string(shape_string) == expected
def test_negative_dimensions():
    """A negative dimension must come back as ``Any`` in the parsed shape."""
    parsed = parse_shape_string("input:[-1,3,224,224]")
    # Compare the string rendering, since that is what allows a
    # comparison against a shape containing Any.
    rendered = str(parsed)
    assert rendered == "{'input': [T.Any(), 3, 224, 224]}"
def test_multiple_valid_gpu_inputs():
    """Two ``gpu_N/data_N`` style inputs, one with a ``-1`` dim, both parse."""
    parsed = parse_shape_string("gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]")
    # Compared via str() because one dimension is rendered as T.Any().
    rendered = str(parsed)
    assert rendered == "{'gpu_0/data_0': [1, T.Any(), 224, 224], 'gpu_1/data_1': [7, 7]}"
def test_invalid_pattern():
    """A non-numeric dimension (``a``) must raise ``ArgumentTypeError``."""
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string("input:[a,10]")
def test_invalid_separators():
    """Shapes written without square brackets must raise ``ArgumentTypeError``."""
    unbracketed = "input:5,10 input2:10,10"
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string(unbracketed)
def test_invalid_colon():
    """A string whose second entry begins with a colon must be rejected."""
    leading_colon = "gpu_0/data_0:5,10 :test:10,10"
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string(leading_colon)
@pytest.mark.parametrize(
    "shape_string",
    [
        "gpu_0/data_0:5,10 /:10,10",
        "gpu_0/data_0:5,10 data/:10,10",
        "gpu_0/data_0:5,10 /data:10,10",
        "gpu_0/invalid/data_0:5,10 data_1:10,10",
    ],
)
def test_invalid_slashes(shape_string):
    """Every parametrized string must raise ``ArgumentTypeError``.

    The cases contain a name that is a lone slash, a name with a trailing
    slash, a name with a leading slash, and a name with two slashes.
    """
    expected_error = argparse.ArgumentTypeError
    with pytest.raises(expected_error):
        parse_shape_string(shape_string)
def test_dot():
    """An input name containing a dot (``input.1``) is accepted as the key."""
    expected = {"input.1": [10, 10, 10]}
    parsed = parse_shape_string("input.1:[10,10,10]")
    assert parsed == expected