| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # THIS SCRIPT USES A TRAINED RESTRICTED BOLTZMANN MACHINE (RBM) TO PROCESS A SET OF OBSERVATIONS |
| # |
| # The RBM training is handled by the rbm_minibatch.dml script. |
| # |
| # INPUT PARAMETERS: |
| # -------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # -------------------------------------------------------------------------------------------- |
| # X String --- Location (on HDFS) to read the matrix of observations X |
| # W String --- Location (on HDFS) to read the RBM's weight matrix |
| # A String --- Location (on HDFS) to read the RBM's visible units bias vector |
| # B String --- Location (on HDFS) to read the RBM's hidden units bias vector |
| # O String --- Location to store the processed observations |
| # fmt String csv Matrix output format for O, usually "text" or "csv" |
| # -------------------------------------------------------------------------------------------- |
| # OUTPUT: A matrix containing a sample of P(h=1|v) for each observation in X |
| # |
| # OUTPUT SIZE: OUTPUT CONTENTS: |
| # O N x ncol(W) RBM outputs for each observation in X |
| # |
| |
| # HOW TO INVOKE THIS SCRIPT - EXAMPLE: |
| # hadoop jar SystemDS.jar -f rbm_run.dml -nvargs X=$INPUT_DIR/X W=$INPUT_DIR/w A=$INPUT_DIR/a B=$INPUT_DIR/b O=$OUTPUT_DIR/O fmt="csv" |
| |
| |
| # -------------------------------------------- START OF RBM_RUN.DML -------------------------------------------- |
| |
| # Default values |
| # Output format for O; falls back to "csv" when fmt is not passed via -nvargs. |
| output_format = ifdef($fmt, "csv") |
| |
| print("BEGIN RBM SCRIPT") |
| |
| print("Reading model data...") |
| # Read the observations to be processed |
| X = read($X) |
| |
| # Total number of observations (rows of X) |
| N = nrow(X) |
| |
| # Read model coefficients: |
| #   w - weights, multiplied as X %*% w below |
| #   a - visible units bias; read for interface parity but not used in this section, |
| #       since only the hidden layer P(h=1|v) is computed here |
| #   b - hidden units bias, added to X %*% w below |
| w = read($W) |
| a = read($A) |
| b = read($B) |
| |
| hidden_units_count = ncol(w) |
| |
| # Rescale to interval [0,1] using the global min/max of X itself. |
| # NOTE(review): this uses the statistics of the data being scored, not those of |
| # the training data - confirm this matches what the training script did. |
| minX = min(X) |
| maxX = max(X) |
| if (minX < 0.0 | maxX > 1.0) { |
|     print("X not in [0,1]. Rescaling...") |
|     rangeX = maxX - minX |
|     # Shift first so the smallest value becomes 0. |
|     X = X - minX |
|     if (rangeX > 0.0) { |
|         # Same result as (X-minX)/(maxX-minX). |
|         X = X / rangeX |
|     } else { |
|         # X is constant: dividing by a zero range would turn every cell into NaN |
|         # (0/0). After the shift X is already all zeros, which lies in [0,1]. |
|         print("X is constant. All values mapped to 0.") |
|     } |
| } |
| |
| print("Running RBM...") |
| |
| # Compute the probabilities P(h=1|v) as the logistic sigmoid of (X %*% w + b) |
| p_h1_given_v = 1.0 / (1 + exp(-(X %*% w + b))) |
| |
| # Sample from P(h=1|v): a hidden unit is switched on when its probability |
| # exceeds a random draw (rand() is called with its default distribution settings). |
| h1 = p_h1_given_v > rand(rows = N, cols = hidden_units_count) |
| |
| print("Saving hidden layer sample...") |
| write(h1, $O, format=output_format) |
| |
| print("END RBM SCRIPT") |
| |
| # -------------------------------------------- END OF RBM_RUN.DML -------------------------------------------- |