| # -*- coding: utf-8 -*- |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import json |
| import os |
| import pyarrow as pa |
| import pyarrow.jvm as pa_jvm |
| import pytest |
| import six |
| import sys |
| import xml.etree.ElementTree as ET |
| |
| |
| jpype = pytest.importorskip("jpype") |
| |
| |
@pytest.fixture(scope="session")
def root_allocator():
    """Start a JVM with the Arrow tools jar and return a RootAllocator.

    This test requires Arrow Java to be built in the same source tree;
    the jar location can be overridden via the ARROW_TOOLS_JAR env var.
    """
    java_dir = os.path.join(
        os.path.dirname(__file__), '..', '..', '..', 'java')
    # Read the Arrow version from the Java pom.xml to locate the built jar
    pom = ET.parse(os.path.join(java_dir, 'pom.xml'))
    version = pom.getroot().find(
        'POM:version',
        namespaces={'POM': 'http://maven.apache.org/POM/4.0.0'}).text
    default_jar = os.path.join(
        java_dir, 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    jar_path = os.getenv("ARROW_TOOLS_JAR", default_jar)
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)
| |
| |
def test_jvm_buffer(root_allocator):
    """Round-trip a JVM-allocated buffer into a pyarrow Buffer."""
    # Fill an 8-byte JVM buffer with descending values 8..1
    jvm_buffer = root_allocator.buffer(8)
    for pos, value in enumerate(range(8, 0, -1)):
        jvm_buffer.setByte(pos, value)

    # Wrap it on the Python side and verify the bytes match
    buf = pa_jvm.jvm_buffer(jvm_buffer)
    assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'
| |
| |
def _jvm_field(jvm_spec):
    """Deserialize a JSON field spec into a Java ``pojo.Field`` via Jackson."""
    field_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    mapper = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    return mapper.readValue(jvm_spec, field_cls)
| |
| |
def _jvm_schema(jvm_spec, metadata=None):
    """Build a Java ``Schema`` with the single field described by ``jvm_spec``.

    If ``metadata`` is a non-empty mapping, it is attached to the schema as
    custom key/value metadata.
    """
    schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
    fields = jpype.JClass('java.util.ArrayList')()
    fields.add(_jvm_field(jvm_spec))
    if not metadata:
        return schema_cls(fields)
    jvm_map = jpype.JClass('java.util.HashMap')()
    for key, value in six.iteritems(metadata):
        jvm_map.put(key, value)
    return schema_cls(fields, jvm_map)
| |
| |
# In the following, we use the JSON serialization of the Field objects in Java.
# This ensures that we don't rely on the exact mechanics of constructing
# them in Java code, and it also lets us define them as test parameters
# without having to invoke the JVM.
| # |
| # The specifications were created using: |
| # |
| # om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() |
| # field = … # Code to instantiate the field |
| # jvm_spec = om.writeValueAsString(field) |
@pytest.mark.parametrize('pa_type,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
        '"timezone":null}'),
    (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
        '"timezone":null}'),
    (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
        ',"timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
        '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
    """Check that a Java Field/Schema converts to the expected pyarrow type.

    The Java objects are constructed by deserializing ``jvm_spec`` (a JSON
    Field specification) and converted back with ``pa_jvm.field`` /
    ``pa_jvm.schema``; the result must equal the pyarrow-native equivalent.
    """
    # Wrap the type spec in a full Field spec so Jackson can deserialize it
    spec = {
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    result = pa_jvm.field(jvm_field)
    expected_field = pa.field('field_name', pa_type, nullable=nullable)
    assert result == expected_field

    # Single-field schema without metadata
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field])

    # Schema with custom metadata
    jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field], {'meta': 'data'})
| |
| |
| # These test parameters mostly use an integer range as an input as this is |
| # often the only type that is understood by both Python and Java |
| # implementations of Arrow. |
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
    """Fill a JVM vector with ``py_data`` and compare the converted result
    against a pyarrow array built from the same data."""
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    # Convert the JVM vector to a pyarrow Array and compare for equality
    py_array = pa.array(py_data, type=pa_type)
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)
| |
| |
| # These test parameters mostly use an integer range as an input as this is |
| # often the only type that is understood by both Python and Java |
| # implementations of Arrow. |
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
    # TODO: null
    (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
    (
        pa.uint8(),
        list(range(128)),
        'UInt1Vector',
        '{"name":"int","bitWidth":8,"isSigned":false}'
    ),
    (
        pa.uint16(),
        list(range(128)),
        'UInt2Vector',
        '{"name":"int","bitWidth":16,"isSigned":false}'
    ),
    (
        pa.uint32(),
        list(range(128)),
        'UInt4Vector',
        '{"name":"int","bitWidth":32,"isSigned":false}'
    ),
    (
        pa.uint64(),
        list(range(128)),
        'UInt8Vector',
        '{"name":"int","bitWidth":64,"isSigned":false}'
    ),
    (
        pa.int8(),
        list(range(128)),
        'TinyIntVector',
        '{"name":"int","bitWidth":8,"isSigned":true}'
    ),
    (
        pa.int16(),
        list(range(128)),
        'SmallIntVector',
        '{"name":"int","bitWidth":16,"isSigned":true}'
    ),
    (
        pa.int32(),
        list(range(128)),
        'IntVector',
        '{"name":"int","bitWidth":32,"isSigned":true}'
    ),
    (
        pa.int64(),
        list(range(128)),
        'BigIntVector',
        '{"name":"int","bitWidth":64,"isSigned":true}'
    ),
    # TODO: float16
    (
        pa.float32(),
        list(range(128)),
        'Float4Vector',
        '{"name":"floatingpoint","precision":"SINGLE"}'
    ),
    (
        pa.float64(),
        list(range(128)),
        'Float8Vector',
        '{"name":"floatingpoint","precision":"DOUBLE"}'
    ),
    (
        pa.timestamp('s'),
        list(range(128)),
        'TimeStampSecVector',
        '{"name":"timestamp","unit":"SECOND","timezone":null}'
    ),
    (
        pa.timestamp('ms'),
        list(range(128)),
        'TimeStampMilliVector',
        '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
    ),
    (
        pa.timestamp('us'),
        list(range(128)),
        'TimeStampMicroVector',
        '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
    ),
    (
        pa.timestamp('ns'),
        list(range(128)),
        'TimeStampNanoVector',
        '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
    ),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (
        pa.date32(),
        list(range(128)),
        'DateDayVector',
        '{"name":"date","unit":"DAY"}'
    ),
    (
        pa.date64(),
        list(range(128)),
        'DateMilliVector',
        '{"name":"date","unit":"MILLISECOND"}'
    ),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
                          jvm_spec):
    """Build a single-column JVM VectorSchemaRoot and compare its conversion
    against a pyarrow RecordBatch constructed from the same data."""
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    # Create field
    spec = {
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))

    # Create VectorSchemaRoot
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_field)
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data))

    # Batch equality compares data only, so differing column names
    # ('col' vs 'field_name') do not matter here.
    py_record_batch = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    jvm_record_batch = pa_jvm.record_batch(jvm_vsr)

    assert py_record_batch.equals(jvm_record_batch)
| |
| |
def _string_to_varchar_holder(ra, string):
    """Convert a Python string (or None) to a JVM NullableVarCharHolder.

    Parameters
    ----------
    ra : org.apache.arrow.memory.RootAllocator
        Allocator used to back the holder's data buffer.
    string : str or None
        Value to encode as UTF-8; ``None`` yields an unset (null) holder.
    """
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        # Bug fix: encode the actual argument instead of the hard-coded
        # literal "string", which made every non-null holder identical.
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder
| |
| |
# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
                          "primitive arrays yet")
def test_jvm_string_array(root_allocator):
    """Fill a JVM VarCharVector (including a null) and compare the converted
    result against a pyarrow string array built from the same data."""
    data = [u"string", None, u"töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        # Bug fix: pass the loop value (which may be None) instead of the
        # hard-coded literal "string", so nulls and non-ASCII are exercised.
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
        jvm_vector.setValueCount(i + 1)

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)