blob: 53ecae217e8e14d3790bd9dc2e640923bb240703 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                   client, use_gpu, dtype):
    """Round-trip two tensors of ``dtype`` through the Plasma TensorFlow ops.

    A random tensor and a tensor of ones are written to the Plasma store as
    one object with ``tensor_to_plasma``, then read back with
    ``plasma_to_tensor`` and compared with the originals. The same object is
    then fetched through the Python ``client`` and checked for 64-byte
    alignment and equal contents.

    Parameters
    ----------
    tf : module
        The imported ``tensorflow`` module (TF1-style ``Session`` API).
    plasma : module
        The imported ``pyarrow.plasma`` module, with ``tf_plasma_op`` loaded.
    plasma_store_name : str
        Socket name of the running Plasma store.
    client
        Plasma client connected to that store.
    use_gpu : bool
        Place the ops on ``/gpu`` instead of ``/cpu``.
    dtype : numpy dtype
        Element type of the tensors under test.
    """
    shape = (3, 244, 244)
    force_device = '/gpu' if use_gpu else '/cpu'
    object_id = np.random.bytes(20)
    data = np.random.randn(*shape).astype(dtype)
    ones = np.ones(shape).astype(dtype)

    with tf.device(force_device):
        to_plasma = plasma.tf_plasma_op.tensor_to_plasma(
            [tf.constant(data), tf.constant(ones)],
            object_id,
            plasma_store_socket_name=plasma_store_name)
        from_plasma = plasma.tf_plasma_op.plasma_to_tensor(
            object_id,
            dtype=tf.as_dtype(dtype),
            plasma_store_socket_name=plasma_store_name)
        z = from_plasma + 1

    # Use the session as a context manager so it is closed deterministically,
    # including when an assertion fails, instead of being left to the GC.
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(to_plasma)
        # NOTE(zongheng): currently it returns a flat 1D tensor.
        # So reshape manually.
        out = sess.run(from_plasma)
        # Smoke check that the read op can feed a downstream op; the value
        # itself is not inspected.
        sess.run(z)

    out0, out1 = [part.reshape(shape) for part in np.split(out, 2)]
    assert np.array_equal(data, out0), "Data not equal!"
    assert np.array_equal(ones, out1), "Data not equal!"

    # Try getting the data from Python.
    obj = client.get(plasma.ObjectID(object_id))
    # Deserialized Tensor should be 64-byte aligned.
    assert obj.ctypes.data % 64 == 0, "Tensor buffer is not 64-byte aligned!"
    result0, result1 = [part.reshape(shape) for part in np.split(obj, 2)]
    assert np.array_equal(data, result0), "Data not equal!"
    assert np.array_equal(ones, result1), "Data not equal!"
@pytest.mark.plasma
@pytest.mark.tensorflow
@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
def test_plasma_tf_op(use_gpu=False):
    """Exercise the Plasma TensorFlow ops for several numeric dtypes.

    Builds the TensorFlow op and skips if it is unavailable. It then starts a
    Plasma store, runs ``run_tensorflow_test_with_dtype`` for each dtype, and
    finally checks that no object in the store is still referenced.

    Parameters
    ----------
    use_gpu : bool, optional
        Place the ops on the GPU. pytest does not treat parameters that have
        defaults as fixtures, so this stays ``False`` under pytest.
    """
    import pyarrow.plasma as plasma
    import tensorflow as tf

    plasma.build_plasma_tensorflow_op()

    if plasma.tf_plasma_op is None:
        pytest.skip("TensorFlow Op not found")

    with plasma.start_plasma_store(10**8) as (plasma_store_name, _):
        client = plasma.connect(plasma_store_name)
        try:
            for dtype in [np.float32, np.float64,
                          np.int8, np.int16, np.int32, np.int64]:
                run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                               client, use_gpu, dtype)

            # Make sure the objects have been released.
            for object_id, info in client.list().items():
                assert info['ref_count'] == 0, object_id
        finally:
            # Disconnect explicitly rather than relying on garbage collection
            # before the store is shut down.
            client.disconnect()