# blob: a5838d7dbbba88cb0da2c48c34bab041f3d15dcf
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Synthetic data generator for low-rank matrix factorization.
#
# Produces a sparse m x n matrix X whose observed cells equal the product of
# two random factors W (m x r) and H (r x n) plus Gaussian noise, and writes
# X, W and H to the given output locations.
#
# Arguments:
#   $X     output location of the generated matrix X (m x n)
#   $W     output location of the original row factor W (m x r)
#   $H     output location of the original column factor H (r x n)
#   $rows  number of rows m of X
#   $cols  number of columns n of X
#   $rank  rank r of the factorization
#   $nnz   number of (row, col) positions sampled for X; because positions
#          are sampled with replacement, this is an upper bound on the
#          number of nonzeros actually present in X
#   $sigma variance of the Gaussian noise (optional, default 0.01)
#   $fmt   output format for all three matrices (optional, default "binary")

Xfile = $X; # input matrix X of size m x n
Wfile = $W; # original row factor of size m x r
Hfile = $H; # original col factor of size r x n
m = $rows; # no. of rows of X
n = $cols; # no. of cols of X
r = $rank; # rank of factorization
nnz = $nnz; # no. of nonzeros in X
sigma = ifdef ($sigma, 0.01); # variance of Gaussian noise
fmt = ifdef ($fmt, "binary"); # output format

# generate original factors by sampling from a normal(0,1.0) distribution
W = rand(rows = m, cols = r, pdf = "normal", seed = 123);
H = rand(rows = r, cols = n, pdf = "normal", seed = 456);

# sample nnz cell positions: floor of a uniform draw from [1, m + 0.999999999]
# (resp. n) yields integer row (resp. column) indices in 1..m (resp. 1..n)
I = floor(rand(rows = nnz, cols = 1, min = 1, max = m + 0.999999999));
J = floor(rand(rows = nnz, cols = 1, min = 1, max = n + 0.999999999));

# one noise value per sampled position; scaling a standard normal draw by
# sqrt(sigma) gives noise with variance sigma
noise = rand(rows = nnz, cols = 1, pdf = "normal") * sqrt(sigma);

# scatter the noise values into an m x n matrix (duplicate positions are
# summed by table). The output dimensions are passed explicitly: without
# them the result is only as large as the largest sampled index, which can
# be smaller than m x n and then no longer conforms with W %*% H below.
N = table(I, J, noise, m, n);

# indicator of the observed (sampled) cells
T = (N != 0);

# observed cells carry the low-rank signal plus noise; all other cells are 0
X = T * (W %*% H) + T * N;

write(X, Xfile, format = fmt);
write(W, Wfile, format = fmt);
write(H, Hfile, format = fmt);