blob: a2b51c93db8fd5ef8a99dfd2bfac26aa898d2f6c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.isolve.lsqr;
import java.util.Arrays;
import com.github.fommil.netlib.BLAS;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
/**
* Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}.
*/
public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
    /** Dataset. */
    private final Dataset<LSQRPartitionContext, SimpleLabeledDatasetData> dataset;

    /**
     * Constructs a new instance of OnHeap LSQR algorithm implementation.
     *
     * @param datasetBuilder Dataset builder.
     * @param envBuilder Learning environment builder.
     * @param partDataBuilder Partition data builder.
     * @param localLearningEnv Learning environment that is forwarded to the dataset builder
     *      together with {@code envBuilder}.
     */
    public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
        LearningEnvironmentBuilder envBuilder,
        PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partDataBuilder,
        LearningEnvironment localLearningEnv) {
        this.dataset = datasetBuilder.build(
            envBuilder,
            // Every partition starts with an empty context; vector "u" is filled in by bnorm().
            (env, upstream, upstreamSize) -> new LSQRPartitionContext(),
            partDataBuilder,
            localLearningEnv
        );
    }

    /**
     * {@inheritDoc}
     * <p>
     * For every partition the labels are copied into the partition context as vector {@code u} and their
     * Euclidean norm is computed; the partial norms are then combined into a single norm. Partitions without
     * labels are skipped, and {@code 0.0} is returned when no partition produced a value.
     */
    @Override protected double bnorm() {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            double[] labels = data.getLabels();

            // Same empty-partition convention as beta()/iter(): contribute nothing instead of failing.
            if (labels == null)
                return null;

            ctx.setU(Arrays.copyOf(labels, labels.length));

            return BLAS.getInstance().dnrm2(labels.length, labels, 1);
        }, LSQROnHeap::mergeNorms);

        // A null reduction means there was nothing to measure; the norm of an empty vector is zero.
        return norm == null ? 0.0 : norm;
    }

    /**
     * {@inheritDoc}
     * <p>
     * For every partition updates the context vector in place as {@code u = alfa * A * x + beta * u}, where
     * {@code A} is the partition's feature array read as a column-major matrix of {@code rows x cols}, and
     * returns the combined Euclidean norm of all updated {@code u} vectors. Partitions without features are
     * skipped, and {@code 0.0} is returned when no partition produced a value.
     */
    @Override protected double beta(double[] x, double alfa, double beta) {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            double[] features = data.getFeatures();

            if (features == null)
                return null;

            int rows = data.getRows();
            int cols = features.length / rows;
            double[] u = ctx.getU();

            BLAS.getInstance().dgemv("N", rows, cols, alfa, features, Math.max(1, rows), x, 1, beta, u, 1);

            return BLAS.getInstance().dnrm2(u.length, u, 1);
        }, LSQROnHeap::mergeNorms);

        // Avoid unboxing a null reduction result.
        return norm == null ? 0.0 : norm;
    }

    /**
     * {@inheritDoc}
     * <p>
     * For every partition scales the context vector {@code u} by {@code 1 / bnorm} in place and computes
     * {@code A^T * u}; the per-partition vectors are summed and the sum is added to {@code target}.
     * If no partition produced a vector, {@code target} is returned unchanged.
     *
     * @return The same {@code target} array, updated in place.
     */
    @Override protected double[] iter(double bnorm, double[] target) {
        double[] res = dataset.computeWithCtx((ctx, data) -> {
            double[] features = data.getFeatures();

            if (features == null)
                return null;

            int rows = data.getRows();
            int cols = features.length / rows;
            double[] u = ctx.getU();

            BLAS.getInstance().dscal(u.length, 1 / bnorm, u, 1);

            double[] v = new double[cols];

            BLAS.getInstance().dgemv("T", rows, cols, 1.0, features, Math.max(1, rows), u, 1, 0, v, 1);

            return v;
        }, (a, b) -> {
            if (a == null)
                return b;
            else if (b == null)
                return a;
            else {
                // Accumulate into b in place: b = b + a.
                BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);

                return b;
            }
        });

        // Both the mapper and the reducer may yield null (no features anywhere); nothing to add then.
        if (res != null)
            BLAS.getInstance().daxpy(res.length, 1.0, res, 1, target, 1);

        return target;
    }

    /**
     * Returns number of columns in dataset.
     * <p>
     * The value is derived per partition as {@code features.length / rows}; partitions without features
     * contribute {@code null}. When two results are combined, a non-null one wins and two {@code null}s
     * become {@code 0}.
     *
     * @return number of columns; NOTE(review): may be {@code null} if the dataset hands back a single
     *      {@code null} partition result without invoking the reducer - confirm against the dataset
     *      implementation.
     */
    @Override protected Integer getColumns() {
        return dataset.compute(
            data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(),
            (a, b) -> {
                if (a == null)
                    return b == null ? 0 : b;
                if (b == null)
                    return a;
                return b;
            }
        );
    }

    /** {@inheritDoc} */
    @Override public void close() throws Exception {
        dataset.close();
    }

    /**
     * Combines two partial Euclidean norms into the norm of the concatenated vector. A {@code null} argument
     * stands for "no data" and is ignored. Kept static so that lambdas using it do not capture the enclosing
     * instance.
     *
     * @param a First partial norm, may be {@code null}.
     * @param b Second partial norm, may be {@code null}.
     * @return Combined norm, or {@code null} if both arguments are {@code null}.
     */
    private static Double mergeNorms(Double a, Double b) {
        if (a == null)
            return b;

        if (b == null)
            return a;

        return Math.sqrt(a * a + b * b);
    }
}