blob: 3079a55242b8f5635fb07a11a54845500f311f4e [file] [log] [blame]
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
import numpy as np
from py4j.java_gateway import JavaClass, JavaObject, JVMView
def numpy_to_matrix_block(sds: 'SystemDSContext', np_arr: np.array):
    """Converts a given numpy array, to internal matrix block representation.

    Input that is not already a numpy array (e.g. a nested list) is first
    converted to a float64 array. Arrays of dtype uint8, int32 and float32 are
    sent in their own element type; every other dtype is cast to float64
    before being sent.

    :param sds: The current systemds context.
    :param np_arr: the numpy array to convert to matrixblock. It may have at
        most 2 dimensions; a 1-dimensional input is sent as a single column
        and a 0-dimensional input as a 1x1 block.
    :return: the result of ``Py4jConverterUtils.convertPy4JArrayToMB`` for the
        data. If that call raises, the exception is handed to
        ``sds.exception_and_close`` instead.
    """
    # Normalize first: ndim, shape and dtype below only exist on numpy arrays,
    # so array-likes must be converted before anything is read from them.
    if not isinstance(np_arr, np.ndarray):
        np_arr = np.asarray(np_arr, dtype=np.float64)

    assert (np_arr.ndim <= 2), "np_arr invalid, because it has more than 2 dimensions"
    rows = np_arr.shape[0] if np_arr.ndim >= 1 else 1
    cols = np_arr.shape[1] if np_arr.ndim == 2 else 1

    jvm: JVMView = sds.java_gateway.jvm
    value_types = jvm.org.apache.sysds.common.Types.ValueType

    # Flatten and pick the JVM value type. dtypes are compared with ``==``
    # because equality, not object identity, is what defines "same dtype".
    dtype = np_arr.dtype
    if dtype == np.dtype(np.uint8):
        arr = np_arr.ravel()
        value_type = value_types.UINT8
    elif dtype == np.dtype(np.int32):
        arr = np_arr.ravel()
        value_type = value_types.INT32
    elif dtype == np.dtype(np.float32):
        arr = np_arr.ravel()
        value_type = value_types.FP32
    else:
        arr = np_arr.ravel().astype(np.float64)
        value_type = value_types.FP64

    buf = bytearray(arr.tobytes())

    # Send data to java.
    try:
        j_class: JavaClass = jvm.org.apache.sysds.runtime.util.Py4jConverterUtils
        return j_class.convertPy4JArrayToMB(buf, rows, cols, value_type)
    except Exception as e:
        sds.exception_and_close(e)
def matrix_block_to_numpy(jvm: JVMView, mb: JavaObject):
    """Builds a 2-dimensional float64 numpy array from a MatrixBlock in the JVM.

    :param jvm: The current JVM instance running systemds.
    :param mb: A pointer to the JVM's MatrixBlock object.
    :return: numpy array of shape (rows, columns) created with
        ``np.frombuffer`` over the buffer returned by the JVM-side converter.
    """
    n_rows = mb.getNumRows()
    n_cols = mb.getNumColumns()

    # The JVM-side helper hands the block back as one flat buffer.
    converter = jvm.org.apache.sysds.runtime.util.Py4jConverterUtils
    raw = converter.convertMBtoPy4JDenseArr(mb)

    # Interpret the buffer as float64 values, then restore the 2-D shape.
    flat = np.frombuffer(raw, dtype=np.float64, count=n_rows * n_cols)
    return flat.reshape((n_rows, n_cols))