blob: 115e9d140e4b04aa2262c1de828b60c58e905ba2 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import ast
import numpy as np
import mxnet as mx
class BinaryRBM(mx.operator.CustomOp):
    """Binary restricted Boltzmann machine operator trained with persistent contrastive divergence.

    Inputs (in order):
      0. data                (num_batch, num_visible)
      1. visible_layer_bias  (num_visible,)
      2. hidden_layer_bias   (num_hidden,)
      3. interaction_weight  (num_visible, num_hidden)
    Outputs:
      0. probabilities that each visible unit is 1 after ``k`` Gibbs steps
      1. probabilities that each hidden unit is 1 conditional on ``data``
    Auxiliary states (training only):
      0. visible samples from the last forward pass, read by ``backward``
      1. hidden samples from the last forward pass, the persistent chain state
    """

    def __init__(self, k):
        # Number of Gibbs steps per forward pass (the k of persistent CD-k).
        # With k < 1 the Gibbs loop in `forward` never runs and its outputs are undefined.
        if k < 1:
            raise ValueError('BinaryRBM requires k >= 1, got %r' % (k,))
        self.k = k

    def forward(self, is_train, req, in_data, out_data, aux):
        """Run k steps of Gibbs sampling and emit the visible/hidden unit probabilities."""
        visible_layer_data = in_data[0]  # (num_batch, num_visible)
        visible_layer_bias = in_data[1]  # (num_visible,)
        hidden_layer_bias = in_data[2]  # (num_hidden,)
        interaction_weight = in_data[3]  # (num_visible, num_hidden)

        if is_train:
            _, hidden_layer_prob_1 = self.sample_hidden_layer(
                visible_layer_data, hidden_layer_bias, interaction_weight)
            # Persistent CD: the Gibbs chain resumes from the hidden state stored by the previous pass.
            hidden_layer_sample = aux[1]
        else:
            hidden_layer_sample, hidden_layer_prob_1 = self.sample_hidden_layer(
                visible_layer_data, hidden_layer_bias, interaction_weight)

        # k-step Gibbs sampling
        for _ in range(self.k):
            visible_layer_sample, visible_layer_prob_1 = self.sample_visible_layer(
                hidden_layer_sample, visible_layer_bias, interaction_weight)
            hidden_layer_sample, _ = self.sample_hidden_layer(
                visible_layer_sample, hidden_layer_bias, interaction_weight)

        if is_train:
            # Used in backward and in the next forward pass
            aux[0][:] = visible_layer_sample
            aux[1][:] = hidden_layer_sample

        self.assign(out_data[0], req[0], visible_layer_prob_1)
        self.assign(out_data[1], req[1], hidden_layer_prob_1)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Write CD parameter updates as gradients; `out_grad` is intentionally ignored.

        Each gradient is (chain-sample statistics) - (data statistics), averaged over the batch,
        i.e. the CD estimate of the gradient of the negative log-likelihood.
        """
        visible_layer_data = in_data[0]  # (num_batch, num_visible)
        visible_layer_sample = aux[0]  # (num_batch, num_visible)
        hidden_layer_prob_1 = out_data[1]  # (num_batch, num_hidden)
        hidden_layer_sample = aux[1]  # (num_batch, num_hidden)

        grad_visible_layer_bias = (visible_layer_sample - visible_layer_data).mean(axis=0)
        grad_hidden_layer_bias = (hidden_layer_sample - hidden_layer_prob_1).mean(axis=0)
        # Batched outer products (num_batch, num_visible, num_hidden), then averaged over the batch.
        grad_interaction_weight = (
            mx.nd.linalg.gemm2(visible_layer_sample.expand_dims(2), hidden_layer_sample.expand_dims(1)) -
            mx.nd.linalg.gemm2(visible_layer_data.expand_dims(2), hidden_layer_prob_1.expand_dims(1))
        ).mean(axis=0)

        # The visible layer input needs no gradient, but if one is requested write zeros
        # rather than leaving the buffer uninitialised (a no-op when req[0] is 'null').
        self.assign(in_grad[0], req[0], 0)
        self.assign(in_grad[1], req[1], grad_visible_layer_bias)
        self.assign(in_grad[2], req[2], grad_hidden_layer_bias)
        self.assign(in_grad[3], req[3], grad_interaction_weight)

    def sample_hidden_layer(self, visible_layer_batch, hidden_layer_bias, interaction_weight):
        """Sample the hidden layer given visible units; returns (sample, prob_1)."""
        return self.sample_layer(visible_layer_batch, hidden_layer_bias, interaction_weight, False)

    def sample_visible_layer(self, hidden_layer_batch, visible_layer_bias, interaction_weight):
        """Sample the visible layer given hidden units; returns (sample, prob_1)."""
        return self.sample_layer(hidden_layer_batch, visible_layer_bias, interaction_weight, True)

    def sample_layer(self, other_layer_sample, layer_bias, interaction_weight, interaction_transpose):
        """Return ``(sample, prob_1)`` for one layer given the other layer's states.

        ``prob_1 = sigmoid(other_layer_sample @ W + bias)``, where W is the interaction weight,
        transposed when ``interaction_transpose`` is true. ``sample`` is 1 where a uniform draw
        falls below ``prob_1`` and 0 elsewhere.
        """
        prob_1 = mx.nd.linalg.gemm(
            other_layer_sample,
            interaction_weight,
            layer_bias.tile(reps=(other_layer_sample.shape[0], 1)),
            transpose_b=interaction_transpose)  # (num_batch, num_units_in_layer)
        prob_1.sigmoid(out=prob_1)
        # Draw on prob_1's device and dtype: custom ops run in a worker thread whose default
        # context is the CPU, which would otherwise break GPU (or float64) execution.
        uniform = mx.nd.random.uniform(shape=prob_1.shape, ctx=prob_1.context, dtype=prob_1.dtype)
        return uniform < prob_1, prob_1
@mx.operator.register('BinaryRBM')
class BinaryRBMProp(mx.operator.CustomOpProp):
    """Operator property for the ``BinaryRBM`` custom operator (``op_type='BinaryRBM'``).

    Auxiliary states are requested only if ``for_training`` is true.
    Constructor arguments may arrive as strings (hence the parsing in ``__init__``).
    """

    def __init__(self, num_hidden, k, for_training):
        # need_top_grad=False: BinaryRBM.backward does not use out_grad.
        super(BinaryRBMProp, self).__init__(False)
        self.num_hidden = int(num_hidden)
        self.k = int(k)
        # Accept a real bool as well as its string form such as 'True'.
        if isinstance(for_training, bool):
            self.for_training = for_training
        else:
            self.for_training = ast.literal_eval(for_training)

    def list_arguments(self):
        # 0: (batch size, the number of visible units)
        # 1: (the number of visible units,)
        # 2: (the number of hidden units,)
        # 3: (the number of visible units, the number of hidden units)
        return ['data', 'visible_layer_bias', 'hidden_layer_bias', 'interaction_weight']

    def list_outputs(self):
        # 0: The probabilities that each visible unit is 1 after `k` steps of Gibbs sampling starting from the given `data`.
        #    (batch size, the number of visible units)
        # 1: The probabilities that each hidden unit is 1 conditional on the given `data`.
        #    (batch size, the number of hidden units)
        return ['visible_layer_prob_1', 'hidden_layer_prob_1']

    def list_auxiliary_states(self):
        # Used only if `self.for_training` is true.
        # 0: Stores the visible layer samples obtained in the forward pass, used in the backward pass.
        #    (batch size, the number of visible units)
        # 1: Stores the hidden layer samples obtained in the forward pass, used in the backward and next forward pass.
        #    (batch size, the number of hidden units)
        return ['aux_visible_layer_sample', 'aux_hidden_layer_sample'] if self.for_training else []

    def infer_shape(self, in_shapes):
        # Every shape is derived from the data shape (batch size, number of visible units) and num_hidden.
        visible_layer_data_shape = in_shapes[0]  # The input data
        visible_layer_bias_shape = (visible_layer_data_shape[1],)
        hidden_layer_bias_shape = (self.num_hidden,)
        interaction_shape = (visible_layer_data_shape[1], self.num_hidden)
        visible_layer_sample_shape = visible_layer_data_shape
        visible_layer_prob_1_shape = visible_layer_sample_shape
        hidden_layer_sample_shape = (visible_layer_data_shape[0], self.num_hidden)
        hidden_layer_prob_1_shape = hidden_layer_sample_shape
        aux_shapes = [visible_layer_sample_shape, hidden_layer_sample_shape] if self.for_training else []
        return ([visible_layer_data_shape, visible_layer_bias_shape, hidden_layer_bias_shape, interaction_shape],
                [visible_layer_prob_1_shape, hidden_layer_prob_1_shape],
                aux_shapes)

    def infer_type(self, in_type):
        # Every input, output and auxiliary state shares the dtype of the data.
        dtype = in_type[0]
        aux_types = [dtype, dtype] if self.for_training else []
        return [dtype, dtype, dtype, dtype], [dtype, dtype], aux_types

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return BinaryRBM(self.k)
# For gluon API
class BinaryRBMBlock(mx.gluon.HybridBlock):
    """Gluon wrapper around the ``BinaryRBM`` custom operator.

    Parameters
    ----------
    num_hidden : int
        Number of hidden units.
    k : int
        Number of Gibbs sampling steps per forward pass.
    for_training : bool
        If true, also create the two auxiliary sample buffers used by persistent CD
        and pass them to the operator.
    """

    def __init__(self, num_hidden, k, for_training, **kwargs):
        super(BinaryRBMBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.num_hidden = num_hidden
            self.k = k
            self.for_training = for_training
            # Shapes are inferred from the first input (deferred initialization).
            self.visible_layer_bias = self.params.get('visible_layer_bias', shape=(0,), allow_deferred_init=True)
            self.hidden_layer_bias = self.params.get('hidden_layer_bias', shape=(0,), allow_deferred_init=True)
            self.interaction_weight = self.params.get('interaction_weight', shape=(0, 0), allow_deferred_init=True)
            if for_training:
                # The auxiliary buffers are written by the operator itself, never by an optimizer.
                # grad_req='null' keeps a Trainer (and weight decay) away from them. Zero
                # initialization starts the persistent chain from a valid binary state.
                self.aux_visible_layer_sample = self.params.get(
                    'aux_visible_layer_sample', shape=(0, 0), init='zeros',
                    grad_req='null', allow_deferred_init=True)
                self.aux_hidden_layer_sample = self.params.get(
                    'aux_hidden_layer_sample', shape=(0, 0), init='zeros',
                    grad_req='null', allow_deferred_init=True)

    def hybrid_forward(self, F, data, visible_layer_bias, hidden_layer_bias, interaction_weight,
                       aux_visible_layer_sample=None, aux_hidden_layer_sample=None):
        inputs = [data, visible_layer_bias, hidden_layer_bias, interaction_weight]
        # As long as `for_training` is kept constant, this branch does not prevent hybridization.
        if self.for_training:
            inputs += [aux_visible_layer_sample, aux_hidden_layer_sample]
        return F.Custom(
            *inputs,
            num_hidden=self.num_hidden,
            k=self.k,
            for_training=self.for_training,
            op_type='BinaryRBM')
def estimate_log_likelihood(visible_layer_bias, hidden_layer_bias, interaction_weight, ais_batch_size, ais_num_batch, ais_intermediate_steps, ais_burn_in_steps, data, ctx):
    """Estimate the mean log-likelihood of ``data`` under a binary RBM.

    The log partition function is estimated with annealed importance sampling (AIS). AIS anneals
    from a base-rate RBM with no hidden layer (beta = 0) to the given RBM (beta = 1).

    Parameters
    ----------
    visible_layer_bias, hidden_layer_bias, interaction_weight : NDArray
        RBM parameters (expected to live on ``ctx``).
    ais_batch_size : int
        Number of AIS chains run in parallel per AIS batch.
    ais_num_batch : int
        Number of AIS batches; ``ais_batch_size * ais_num_batch`` samples in total.
    ais_intermediate_steps : int
        Number of annealing steps between the base model and the target RBM.
    ais_burn_in_steps : int
        Gibbs sweeps applied at every intermediate distribution.
    data : iterable
        An ``mx.io`` iterator yielding ``DataBatch`` objects, or a gluon ``DataLoader``
        yielding sequences whose first element is the data.
    ctx : Context
        Device on which sampling and evaluation run.

    Returns
    -------
    (log_likelihood, log_z)
        Mean log-likelihood per example and the estimated log partition function, as scalars.

    Raises
    ------
    ValueError
        If no AIS samples are requested or ``data`` yields no examples.
    """
    if ais_num_batch < 1 or ais_batch_size < 1:
        raise ValueError('ais_num_batch and ais_batch_size must both be at least 1')

    # The base-rate RBM with no hidden layer. Its visible layer bias is set to the same as the given RBM's.
    # This is not the only possible choice, but it is simple and works well.
    base_rate_visible_layer_bias = visible_layer_bias
    base_rate_visible_prob_1 = base_rate_visible_layer_bias.sigmoid()
    # log Z of the base model: sum_i log(1 + exp(b_i)).
    log_base_rate_z = base_rate_visible_layer_bias.exp().log1p().sum()

    def log_intermediate_unnormalized_prob(visible_layer_sample, beta):
        # Unnormalized log-probability of each row under the beta-intermediate model (hidden units summed out).
        p = mx.nd.dot(
            visible_layer_sample,
            (1 - beta) * base_rate_visible_layer_bias + beta * visible_layer_bias)
        if beta != 0:
            p += mx.nd.linalg.gemm(
                visible_layer_sample,
                interaction_weight,
                hidden_layer_bias.tile(reps=(visible_layer_sample.shape[0], 1)),
                transpose_b=False,
                alpha=beta,
                beta=beta).exp().log1p().sum(axis=1)
        return p

    def sample_base_rbm():
        # Independent Bernoulli visible units with the base model's probabilities.
        rands = mx.nd.random.uniform(shape=(ais_batch_size, base_rate_visible_prob_1.shape[0]), ctx=ctx)
        return rands < base_rate_visible_prob_1.tile(reps=(ais_batch_size, 1))

    def sample_intermediate_visible_layer(visible_layer_sample, beta):
        # `ais_burn_in_steps` Gibbs sweeps (v -> h -> v) under the beta-intermediate model.
        for _ in range(ais_burn_in_steps):
            hidden_prob_1 = mx.nd.linalg.gemm(
                visible_layer_sample,
                interaction_weight,
                hidden_layer_bias.tile(reps=(visible_layer_sample.shape[0], 1)),
                transpose_b=False,
                alpha=beta,
                beta=beta)
            hidden_prob_1.sigmoid(out=hidden_prob_1)
            hidden_layer_sample = mx.nd.random.uniform(shape=hidden_prob_1.shape, ctx=ctx) < hidden_prob_1
            visible_prob_1 = mx.nd.linalg.gemm(
                hidden_layer_sample,
                interaction_weight,
                visible_layer_bias.tile(reps=(hidden_layer_sample.shape[0], 1)),
                transpose_b=True,
                alpha=beta,
                beta=beta) + (1 - beta) * base_rate_visible_layer_bias
            visible_prob_1.sigmoid(out=visible_prob_1)
            visible_layer_sample = mx.nd.random.uniform(shape=visible_prob_1.shape, ctx=ctx) < visible_prob_1
        return visible_layer_sample

    def array_from_batch(batch):
        # Flattened data of one batch on `ctx`, excluding any padded examples.
        if isinstance(batch, mx.io.DataBatch):
            array = batch.data[0].as_in_context(ctx).flatten()
            # Padded examples sit at the end of the batch and are not real data.
            pad = batch.pad or 0
            return array[:array.shape[0] - pad] if pad else array
        # batch is a list/tuple in the case of gluon DataLoader
        return batch[0].as_in_context(ctx).flatten()

    # Annealed importance sampling: one log importance weight per chain.
    log_importance_weight_batches = []
    for _ in range(ais_num_batch):
        log_importance_weight = mx.nd.zeros((ais_batch_size,), ctx=ctx, dtype=visible_layer_bias.dtype)
        visible_layer_sample = sample_base_rbm()
        for n in range(1, ais_intermediate_steps + 1):
            beta = 1. * n / ais_intermediate_steps
            log_importance_weight += \
                log_intermediate_unnormalized_prob(visible_layer_sample, beta) - \
                log_intermediate_unnormalized_prob(visible_layer_sample, (n - 1.) / ais_intermediate_steps)
            visible_layer_sample = sample_intermediate_visible_layer(visible_layer_sample, beta)
        log_importance_weight_batches.append(log_importance_weight)

    # log Z = log(mean importance weight) + log Z_base. This is computed with log-sum-exp because
    # the log weights can easily exceed ~88, where exp() overflows float32.
    log_importance_weights = mx.nd.concat(*log_importance_weight_batches, dim=0)
    max_log_weight = log_importance_weights.max()
    log_z = max_log_weight + (log_importance_weights - max_log_weight).exp().mean().log() + log_base_rate_z

    log_likelihood = 0
    num_data = 0
    for batch in data:
        batch_array = array_from_batch(batch)
        # Reduce per batch so that batches of different sizes (e.g. a short final batch) can be summed.
        log_likelihood += (log_intermediate_unnormalized_prob(batch_array, 1) - log_z).sum()
        num_data += batch_array.shape[0]
    if num_data == 0:
        raise ValueError('`data` yielded no examples (was the iterator reset?)')
    return (log_likelihood / num_data).asscalar(), log_z.asscalar()