blob: 93666758b58c8131f4aee07219526d25cae03cd0 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Synthetic data generator for multinomial logistic regression.
#
# Positional arguments:
#   $1  number of records (rows of X and Y)
#   $2  number of features (columns of X)
#   $3  sparsity of X (fraction of non-zero cells), must be > 0
#   $4  number of categories (classes), must be >= 2
#   $5  1 to add a random per-category intercept to the linear terms
#   $6  output location for the feature matrix X
#   $7  output location for the label vector Y
#   $8  output format used for both X and Y
#
# Labels are integers in 1..num_categories. The first num_categories-1
# categories have explicit coefficients in B; the last category is the
# baseline that is chosen when the random draw exceeds every cumulative
# probability.

num_records = $1;
num_features = $2;
p = $3;                    # sparsity of X
num_categories = $4;       # number of classes
is_intercept = $5==1;

# Fail early with a clear message instead of a late index/dimension error.
if (num_categories < 2) {
  stop("Multinomial datagen: number of categories must be at least 2, got " + num_categories);
}
if (num_records < num_categories) {
  # The "all classes represented" step below overwrites the last
  # num_categories rows of Y, so that many rows must exist.
  stop("Multinomial datagen: number of records (" + num_records
       + ") must be at least the number of categories (" + num_categories + ")");
}
if (p <= 0) {
  # beta_range divides by sqrt(num_features * p).
  stop("Multinomial datagen: sparsity must be greater than 0, got " + p);
}

# Scale the coefficients by 1/sqrt(expected non-zeros per row of X).
stdevLT = 1.0;
beta_range = 3.0 * stdevLT / sqrt (num_features * p);

if (is_intercept) {
  intercept = Rand (rows = 1, cols = num_categories - 1, min = -1.0, max = 1.0);
}

# Feature matrix: non-zero cells are uniform in [1, 5].
X = Rand (rows = num_records,
          cols = num_features,
          min = 1,
          max = 5,
          pdf = "uniform",
          sparsity = p);

# Dense coefficient matrix, one column per non-baseline category.
B = Rand (rows = num_features,
          cols = num_categories - 1,
          min = -1.0,
          max = 1.0,
          pdf = "uniform",
          sparsity = 1.0) * beta_range;

# Linear terms, optionally shifted by the same intercept row for every record.
LT = X %*% B;
if (is_intercept) {
  LT = LT + matrix (1, rows = num_records, cols = 1) %*% intercept;
}

# Per-record category probabilities; the baseline category contributes the
# "1.0" in the denominator.
Prob = exp (LT);
Prob = Prob / (1.0 + rowSums (Prob));

# Row-wise cumulative probabilities (cumsum works column-wise, hence the
# double transpose).
Prob = t(cumsum (t(Prob)));

# Sample one label per record: count how many cumulative thresholds lie
# below the uniform draw, then shift to 1-based labels.
r = Rand (rows = num_records, cols = 1, min = 0, max = 1, pdf = "uniform");
Y = 1 + rowSums (Prob < r);

# ensure all classes are represented: the last num_categories labels are
# overwritten with 1..num_categories instead of being sampled.
Y[(num_records-num_categories+1):num_records,1] = seq(1,num_categories);

write(X, $6, format=$8);
write(Y, $7, format=$8);