| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| class TestTensor < Test::Unit::TestCase |
| include Helper::Omittable |
| |
| def setup |
| @raw_data = [ |
| 1, 2, |
| 3, 4, |
| |
| 5, 6, |
| 7, 8, |
| |
| 9, 10, |
| 11, 12, |
| ] |
| data = Arrow::Buffer.new(@raw_data.pack("c*")) |
| @shape = [3, 2, 2] |
| strides = [] |
| names = ["a", "b", "c"] |
| @tensor = Arrow::Tensor.new(Arrow::Int8DataType.new, |
| data, |
| @shape, |
| strides, |
| names) |
| end |
| |
| def test_equal |
| data = Arrow::Buffer.new(@raw_data.pack("c*")) |
| strides = [] |
| names = ["a", "b", "c"] |
| other_tensor = Arrow::Tensor.new(Arrow::Int8DataType.new, |
| data, |
| @shape, |
| strides, |
| names) |
| assert_equal(@tensor, |
| other_tensor) |
| end |
| |
| def test_value_data_type |
| assert_equal(Arrow::Int8DataType, @tensor.value_data_type.class) |
| end |
| |
| def test_value_type |
| assert_equal(Arrow::Type::INT8, @tensor.value_type) |
| end |
| |
| def test_buffer |
| assert_equal(@raw_data, @tensor.buffer.data.to_s.unpack("c*")) |
| end |
| |
| def test_shape |
| require_gi_bindings(3, 3, 1) |
| assert_equal(@shape, @tensor.shape) |
| end |
| |
| def test_strides |
| require_gi_bindings(3, 3, 1) |
| assert_equal([4, 2, 1], @tensor.strides) |
| end |
| |
| def test_n_dimensions |
| assert_equal(@shape.size, @tensor.n_dimensions) |
| end |
| |
| def test_dimension_name |
| dimension_names = @tensor.n_dimensions.times.collect do |i| |
| @tensor.get_dimension_name(i) |
| end |
| assert_equal(["a", "b", "c"], |
| dimension_names) |
| end |
| |
| def test_size |
| assert_equal(@raw_data.size, @tensor.size) |
| end |
| |
| def test_mutable? |
| assert do |
| not @tensor.mutable? |
| end |
| end |
| |
| def test_contiguous? |
| assert do |
| @tensor.contiguous? |
| end |
| end |
| |
| def test_row_major? |
| assert do |
| @tensor.row_major? |
| end |
| end |
| |
| def test_column_major? |
| assert do |
| not @tensor.column_major? |
| end |
| end |
| |
| def test_io |
| buffer = Arrow::ResizableBuffer.new(0) |
| output = Arrow::BufferOutputStream.new(buffer) |
| output.write_tensor(@tensor) |
| input = Arrow::BufferInputStream.new(buffer) |
| assert_equal(@tensor, input.read_tensor) |
| end |
| end |