blob: 287c17f228fc4be0e2c6f80f28307a34fc3c21f4 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# THIS SCRIPT USES A TRAINED RESTRICTED BOLTZMANN MACHINE TO PROCESS A SET OF OBSERVATIONS
#
# The RBM training is handled by the rbm_minibatch.dml script.
#
# INPUT PARAMETERS:
# --------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# --------------------------------------------------------------------------------------------
# X String --- Location (on HDFS) to read the matrix of observations X
# W String --- Location (on HDFS) to read the RBM's weight matrix
# A String --- Location (on HDFS) to read the RBM's visible units bias vector
# B String --- Location (on HDFS) to read the RBM's hidden units bias vector
# O String --- Location to store the processed observations
# fmt String csv Matrix output format for O, usually "text" or "csv"
# --------------------------------------------------------------------------------------------
# OUTPUT: A matrix containing a sample of P(h=1|v) for each observation in X
#
# OUTPUT SIZE: OUTPUT CONTENTS:
# O N x ncol(W) RBM outputs for each observation in X
#
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemML.jar -f rbm_run.dml -nvargs X=$INPUT_DIR/X W=$INPUT_DIR/w A=$INPUT_DIR/a B=$INPUT_DIR/b O=$OUTPUT_DIR/O fmt="csv"
# -------------------------------------------- START OF RBM_RUN.DML --------------------------------------------
# Optional argument: fall back to "csv" when fmt is not supplied
output_format = ifdef($fmt, "csv")

print("BEGIN RBM SCRIPT")
print("Reading model data...")

# Observations to be processed (one row per observation)
X = read($X)

# Trained RBM parameters: weights, visible units bias, hidden units bias.
# NOTE(review): a is read here but is not referenced by the rest of this script.
w = read($W)
a = read($A)
b = read($B)

# Dimensions used when sampling the hidden layer:
# number of observations and number of hidden units (columns of w)
N = nrow(X)
hidden_units_count = ncol(w)
# Rescale X to the interval [0,1] using min-max normalization.
# Rescaling is only applied when some value of X lies outside [0,1].
minX = min(X)
maxX = max(X)
if (minX < 0.0 | maxX > 1.0) {
    print("X not in [0,1]. Rescaling...")
    rangeX = maxX - minX
    if (rangeX > 0.0) {
        X = (X - minX) / rangeX
    } else {
        # All entries of X are equal, so the min-max formula would compute
        # 0/0 and fill X with NaN. Map the constant input to 0 instead.
        X = X - minX
    }
}
print("Running RBM...")

# Total input to each hidden unit: weighted visible units plus hidden bias
hidden_input = X %*% w + b

# Logistic function of the hidden input gives the probabilities P(h=1|v)
p_h1_given_v = 1.0 / (1 + exp(-hidden_input))

# Sample from P(h=1|v): a hidden unit is switched on when its random draw
# falls below its activation probability
uniform_draws = rand(rows = N, cols = hidden_units_count)
h1 = uniform_draws < p_h1_given_v

print("Saving hidden layer sample...")
write(h1, $O, format=output_format)

print("END RBM SCRIPT")
# -------------------------------------------- END OF RBM_RUN.DML --------------------------------------------