blob: e0d507ce4b9b570fd95d280c959a1172193bcac6 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import pyarrow as pa
import pyarrow.jvm as pa_jvm
import pytest
import six
import sys
import xml.etree.ElementTree as ET
jpype = pytest.importorskip("jpype")
@pytest.fixture(scope="session")
def root_allocator():
    """Start the JVM and return an Arrow Java ``RootAllocator``.

    The Arrow tools jar is placed on the JVM class path. Its location is
    read from the ``ARROW_TOOLS_JAR`` environment variable when that is set.
    Otherwise it is derived from the project version found in
    ``java/pom.xml``, which requires Arrow Java to be built in the same
    source tree.
    """
    jar_path = os.getenv("ARROW_TOOLS_JAR")
    if jar_path is None:
        # Only consult the source tree when no explicit jar was given, so
        # the override also works when java/pom.xml is not available.
        java_dir = os.path.join(
            os.path.dirname(__file__), '..', '..', '..', 'java')
        tree = ET.parse(os.path.join(java_dir, 'pom.xml'))
        version = tree.getroot().find(
            'POM:version',
            namespaces={
                'POM': 'http://maven.apache.org/POM/4.0.0'
            }).text
        jar_path = os.path.join(
            java_dir, 'tools', 'target',
            'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)
def test_jvm_buffer(root_allocator):
    """A buffer allocated in the JVM is readable as a pyarrow buffer."""
    size = 8
    java_buf = root_allocator.buffer(size)
    # Fill with a descending byte pattern: 8, 7, ..., 1
    for offset, byte in enumerate(range(size, 0, -1)):
        java_buf.setByte(offset, byte)

    # Convert to Python and compare the raw content
    py_buf = pa_jvm.jvm_buffer(java_buf)
    expected = b'\x08\x07\x06\x05\x04\x03\x02\x01'
    assert py_buf.to_pybytes() == expected
def _jvm_field(jvm_spec):
    """Deserialize the JSON string ``jvm_spec`` into a Java ``Field``.

    The JSON is read with a Jackson ``ObjectMapper`` into an instance of
    ``org.apache.arrow.vector.types.pojo.Field``.
    """
    field_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    mapper_cls = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')
    return mapper_cls().readValue(jvm_spec, field_cls)
def _jvm_schema(jvm_spec, metadata=None):
    """Build a single-field Java ``Schema`` from a JSON field specification.

    When ``metadata`` is non-empty, its items are copied into a
    ``java.util.HashMap`` that is passed to the ``Schema`` constructor
    together with the field list.
    """
    java_fields = jpype.JClass('java.util.ArrayList')()
    java_fields.add(_jvm_field(jvm_spec))
    schema_ctor = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')

    if not metadata:
        return schema_ctor(java_fields)

    java_metadata = jpype.JClass('java.util.HashMap')()
    for key, value in six.iteritems(metadata):
        java_metadata.put(key, value)
    return schema_ctor(java_fields, java_metadata)
# In the following, we use the JSON serialization of the Field objects in Java.
# This ensures that we do not rely on the exact mechanics of how to construct
# them using Java code, and it also enables us to define them as parameters
# without having to invoke the JVM.
#
# The specifications were created using:
#
# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
# field = … # Code to instantiate the field
# jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('pa_type,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'),
     '{"name":"timestamp","unit":"SECOND","timezone":null}'),
    (pa.timestamp('ms'),
     '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'),
    (pa.timestamp('us'),
     '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'),
    (pa.timestamp('ns'),
     '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'),
    (pa.timestamp('ns', tz='UTC'),
     '{"name":"timestamp","unit":"NANOSECOND","timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'),
     '{"name":"timestamp","unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
    """Java fields and schemas convert to the equivalent pyarrow objects."""
    # The same serialized field description feeds all three checks below.
    field_json = json.dumps({
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    })
    expected_field = pa.field('field_name', pa_type, nullable=nullable)

    # Single field
    converted_field = pa_jvm.field(_jvm_field(field_json))
    assert converted_field == expected_field

    # Schema without metadata
    plain_schema = pa_jvm.schema(_jvm_schema(field_json))
    assert plain_schema == pa.schema([expected_field])

    # Schema with custom metadata
    annotated_schema = pa_jvm.schema(
        _jvm_schema(field_json, {'meta': 'data'}))
    assert annotated_schema == pa.schema([expected_field], {'meta': 'data'})
# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
    """A populated Java vector converts to an equal pyarrow array."""
    n_values = len(py_data)

    # Instantiate the Java vector class named by the parameter and fill it
    vector_cls = jpype.JClass("org.apache.arrow.vector.{}".format(jvm_type))
    java_vector = vector_cls("vector", root_allocator)
    java_vector.allocateNew(n_values)
    for index, value in enumerate(py_data):
        java_vector.setSafe(index, value)
    java_vector.setValueCount(n_values)

    expected = pa.array(py_data, type=pa_type)
    converted = pa_jvm.array(java_vector)
    assert expected.equals(converted)
# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
    # TODO: null
    (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
    (pa.uint8(), list(range(128)), 'UInt1Vector',
     '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), list(range(128)), 'UInt2Vector',
     '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), list(range(128)), 'UInt4Vector',
     '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), list(range(128)), 'UInt8Vector',
     '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.int8(), list(range(128)), 'TinyIntVector',
     '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), list(range(128)), 'SmallIntVector',
     '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), list(range(128)), 'IntVector',
     '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), list(range(128)), 'BigIntVector',
     '{"name":"int","bitWidth":64,"isSigned":true}'),
    # TODO: float16
    (pa.float32(), list(range(128)), 'Float4Vector',
     '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), list(range(128)), 'Float8Vector',
     '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector',
     '{"name":"timestamp","unit":"SECOND","timezone":null}'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector',
     '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector',
     '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector',
     '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector',
     '{"name":"date","unit":"DAY"}'),
    (pa.date64(), list(range(128)), 'DateMilliVector',
     '{"name":"date","unit":"MILLISECOND"}'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
                          jvm_spec):
    """A Java VectorSchemaRoot converts to an equal pyarrow RecordBatch."""
    n_rows = len(py_data)

    # Populate the Java vector
    vector_cls = jpype.JClass("org.apache.arrow.vector.{}".format(jvm_type))
    java_vector = vector_cls("vector", root_allocator)
    java_vector.allocateNew(n_rows)
    for index, value in enumerate(py_data):
        java_vector.setSafe(index, value)
    java_vector.setValueCount(n_rows)

    # Describe the column as a non-nullable Java field
    field_json = json.dumps({
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    })
    java_field = _jvm_field(field_json)

    # Wrap the field and the vector in a VectorSchemaRoot
    array_list_cls = jpype.JClass('java.util.ArrayList')
    java_fields = array_list_cls()
    java_fields.add(java_field)
    java_vectors = array_list_cls()
    java_vectors.add(java_vector)
    vsr_cls = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    java_root = vsr_cls(java_fields, java_vectors, n_rows)

    # NOTE(review): the expected batch names its column 'col' while the JVM
    # field is called 'field_name'; presumably equals() does not compare
    # field names here -- confirm.
    expected = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    converted = pa_jvm.record_batch(java_root)
    assert expected.equals(converted)
def _string_to_varchar_holder(ra, string):
    """Wrap a Python string in a Java ``NullableVarCharHolder``.

    Parameters
    ----------
    ra : Java allocator
        Its ``buffer(size)`` method is used to allocate the buffer that
        backs the holder.
    string : str or None
        Value to store. ``None`` yields a holder with ``isSet = 0`` and no
        buffer.

    Returns
    -------
    The ``NullableVarCharHolder`` instance.
    """
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        # Build the Java string from the argument itself so that the holder
        # carries the caller's value, then encode it as UTF-8 in the JVM.
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder
# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
                          "primitive arrays yet")
def test_jvm_string_array(root_allocator):
    """A Java ``VarCharVector`` converts to an equal pyarrow string array.

    The data contains a plain ASCII value, a null and a non-ASCII value.
    """
    data = [u"string", None, u"töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        # Pass the current element, not a fixed literal, so that the null
        # and the non-ASCII entries are actually written to the vector.
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
        jvm_vector.setValueCount(i + 1)

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)
    assert py_array.equals(jvm_array)