| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # THIS SCRIPT IMPLEMENTS A TRAINING ALGORITHM FOR RESTRICTED BOLTZMANN MACHINES |
| # |
| # The training algorithm uses one-step Contrastive Divergence as suggested in [Bengio, Y., Learning Deep |
| # Architectures for AI, Foundations and Trends in Machine Learning, Volume 2 Issue 1, January 2009, p. 1–127]. |
| # This implementation is slightly modified and uses mini-batches to speed up the learning process. |
| # |
| # INPUT PARAMETERS: |
| # -------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # -------------------------------------------------------------------------------------------- |
| # X String ---- Location (on HDFS) to read the matrix X of feature vectors |
| # W String ---- Location to store the estimated RBM weights |
| # A String ---- Location to store the visible units bias vector |
| # B String ---- Location to store the hidden units bias vector |
| # batchsize Int 100 Number of observations in a single batch |
| # vis Int ---- Number of units in the visible layer. If not specified this will be set |
| # to the number of features in X |
| # hid Int 2 Number of units in the hidden layer |
| # epochs Int 10 Number of training epochs |
| # alpha Double 0.1 Learning rate |
| # fmt String text Matrix output format for W, A, and B, usually "text" or "csv" |
| # -------------------------------------------------------------------------------------------- |
| # OUTPUT: Matrices containing weights and biases for the trained network |
| # |
| # OUTPUT SIZE: OUTPUT CONTENTS: |
| # W vis x hid RBM weights |
| # A 1 x vis Visible units biases |
| # B 1 x hid Hidden units biases |
| # |
| # To run an observation through the trained network use the rbm_run.dml script. |
| # Alternatively, you can simply compute P(h=1|v) using 1.0 / (1 + exp(-(X %*% W + B))) and |
| # get the values for the hidden layer by sampling from P(h=1|v). |
| # |
| |
| # HOW TO INVOKE THIS SCRIPT - EXAMPLE: |
| # hadoop jar SystemDS.jar -f rbm.dml -nvargs X=$INPUT_DIR/X W=$OUTPUT_DIR/w A=$OUTPUT_DIR/a B=$OUTPUT_DIR/b batchsize=100 vis=10 hid=2 epochs=10 alpha=0.1 |
| |
| |
| # -------------------------------------------- START OF RBM_MINIBATCH.DML -------------------------------------------- |
| |
| # Hyper-parameters: use the documented default whenever a named argument is not supplied |
| learning_rate = ifdef($alpha, 0.1) |
| hidden_units_count = ifdef($hid, 2) |
| epoch_count = ifdef($epochs, 10) |
| batch_size = ifdef($batchsize, 100) |
| output_format = ifdef($fmt, "text") |
| |
| print("BEGIN RBM MINIBATCH TRAINING SCRIPT") |
| |
| print("Reading X...") |
| X = read($X) |
| |
| # If number of visible units is not set, |
| # set it to the number of features in X |
| visible_units_count = ifdef ($vis, ncol(X)) |
| |
| # Every minibatch of X is multiplied with the (visible_units_count x hidden_units_count) |
| # weight matrix, so a vis that differs from ncol(X) can never train. Fail early with a |
| # clear message instead of a dimension mismatch inside the first matrix multiplication. |
| if (visible_units_count != ncol(X)) { |
|     stop("RBM: vis (" + visible_units_count + ") must equal the number of features in X (" + ncol(X) + ")") |
| } |
| |
| # Get the total number of observations |
| N = nrow(X) |
| |
| # Rescale to interval [0,1] (min-max scaling over the whole matrix), only if X leaves that interval |
| minX = min(X) |
| maxX = max(X) |
| if (minX < 0.0 | maxX > 1.0) { |
|     print("X not in [0,1]. Rescaling...") |
|     if (maxX > minX) { |
|         X = (X-minX)/(maxX-minX) |
|     } else { |
|         # X is constant: (maxX - minX) is 0 and the division would fill X (and then |
|         # w, a and b) with NaN. Map the single value to 0, as min-max scaling does for the minimum. |
|         print("X is constant. Setting all entries to 0.") |
|         X = matrix(0, rows = N, cols = ncol(X)) |
|     } |
| } |
| |
| # Initialize a weight matrix (visible_units_count x hidden_units_count) with uniform values scaled by 0.1 |
| w = 0.1 * rand(rows = visible_units_count, cols = hidden_units_count, pdf = "uniform") |
| |
| # Setup biases and initialize with 0 |
| a = matrix(0, rows = 1, cols = visible_units_count) # Visible units bias |
| b = matrix(0, rows = 1, cols = hidden_units_count) # Hidden units bias |
| |
| # Train the RBM with one-step Contrastive Divergence on minibatches |
| print("Training starts...") |
| |
| for (epoch in 1:epoch_count) { |
| |
|     # Sum of the per-minibatch reconstruction errors of this epoch |
|     epoch_err = 0 |
| |
|     # Loop over each batch; batch_start is the first row of the current minibatch |
|     for (batch_start in seq(1, N, batch_size)) { |
| |
|         # Last row of the minibatch, clamped to N so the final batch may be shorter |
|         batch_end = min(batch_start + batch_size - 1, N) |
| |
|         # Present the current minibatch to the visible layer |
|         v_1 = X[batch_start:batch_end,] |
| |
|         # Get the number of observations in v_1 |
|         batch_N = nrow(v_1) |
| |
|         # POSITIVE PHASE |
| |
|         # Compute the probabilities P(h=1|v) |
|         p_h1_given_v = 1.0 / (1 + exp(-(v_1 %*% w + b))) |
| |
|         # Sample binary hidden states from P(h=1|v) by comparing against uniform random numbers |
|         h1 = p_h1_given_v > rand(rows = batch_N, cols = hidden_units_count) |
| |
|         # NEGATIVE PHASE |
| |
|         # Compute the probabilities P(v2=1|h1) from the sampled hidden states |
|         p_v2_given_h1 = 1.0 / (1 + exp(-(h1 %*% t(w) + a))) |
| |
|         # Compute the probabilities P(h2=1|v2); the reconstruction probabilities are used directly (no sampling) |
|         p_h2_given_v2 = 1.0 / (1 + exp(-(p_v2_given_h1 %*% w + b))) |
| |
|         # UPDATE WEIGHTS AND BIASES |
|         # Each update is (positive-phase statistics - negative-phase statistics), averaged over the minibatch |
|         w = w + learning_rate * ((t(v_1) %*% p_h1_given_v - t(p_v2_given_h1) %*% p_h2_given_v2) / batch_N) |
|         a = a + (learning_rate / batch_N) * (colSums(v_1) - colSums(p_v2_given_h1)) |
|         b = b + (learning_rate / batch_N) * (colSums(p_h1_given_v) - colSums(p_h2_given_v2)) |
| |
|         # COMPUTE ERROR |
|         # Squared reconstruction error, averaged over the observations of the minibatch |
|         minibatch_err = sum((v_1 - p_v2_given_h1) ^ 2) / batch_N |
|         epoch_err = epoch_err + minibatch_err |
|     } |
| |
|     print("Epoch " + epoch + " completed. Cumulative error is " + epoch_err) |
| } |
| |
| # Save the RBM model |
| print("Saving RBM weights...") |
| write(w, $W, format=output_format) |
| |
| print("Saving RBM biases...") |
| write(a, $A, format=output_format) |
| write(b, $B, format=output_format) |
| |
| print("END OF RBM MINIBATCH TRAINING SCRIPT") |
| |
| # -------------------------------------------- END OF RBM_MINIBATCH.DML -------------------------------------------- |
| |