blob: cee6c0d2c1549aed768957c46d12046743031832 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.matrix.data;
import java.io.Serializable;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.util.DnnUtils;
/**
 * This class is a container that stores the parameters required for executing the following operations:
 * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward
 *
 * <p>The sizes follow the NCHW layout for the input and the KCRS layout for the filter (see
 * {@link #toString()}); P and Q are derived from (H, R, stride_h, pad_h) and (W, S, stride_w, pad_w)
 * respectively. A negative value denotes an unknown size (see {@link #setIfUnknown}).</p>
 */
public class DnnParameters implements Serializable
{
	private static final long serialVersionUID = -212362627205772829L;

	public int N, C, H, W, K, R, S, P, Q;
	public int stride_h, stride_w, pad_h, pad_w;
	public int numThreads;
	// Optional variables used by ConvolutionCPInstruction
	public boolean enableNative = false;
	public MatrixBlock input1;
	public MatrixBlock input2;
	public MatrixBlock output;
	public MatrixBlock bias;
	public int [] start_indexes_h, end_indexes_h, start_indexes_w, end_indexes_w;
	public double minValForMaxPoolOperations = -Double.MAX_VALUE;

	/**
	 * Creates parameters from long-valued sizes. Every value is narrowed to an int; negative
	 * values denote unknown sizes. P (resp. Q) is computed as
	 * {@code (H + 2*pad_h - R) / stride_h + 1} (resp. with W, S, stride_w, pad_w) when all of
	 * its inputs are known, and set to -1 otherwise.
	 *
	 * @throws DMLRuntimeException if a value exceeds the int range, if a known stride is 0,
	 *         or if a computed output size exceeds the int range
	 */
	public DnnParameters(long N, long C, long H, long W,
		long K, long R, long S, long stride_h, long stride_w,
		long pad_h, long pad_w, int numThreads) {
		this.N = convertToInt(N);
		this.C = convertToInt(C);
		this.H = convertToInt(H);
		this.W = convertToInt(W);
		this.K = convertToInt(K);
		this.R = convertToInt(R);
		this.S = convertToInt(S);
		this.stride_h = convertToInt(stride_h);
		this.stride_w = convertToInt(stride_w);
		this.pad_h = convertToInt(pad_h);
		this.pad_w = convertToInt(pad_w);
		P = computeOutputSize(H, R, stride_h, pad_h);
		Q = computeOutputSize(W, S, stride_w, pad_w);
		this.numThreads = numThreads;
	}

	/**
	 * Creates parameters from int-valued sizes. P (resp. Q) is computed via
	 * {@code DnnUtils.getP} (resp. {@code DnnUtils.getQ}) when the input and filter sizes are
	 * positive and the stride and padding are non-negative; otherwise it is set to -1.
	 */
	public DnnParameters(int N, int C, int H, int W,
		int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int numThreads) {
		this.N = N;
		this.C = C;
		this.H = H;
		this.W = W;
		this.K = K;
		this.R = R;
		this.S = S;
		this.stride_h = stride_h;
		this.stride_w = stride_w;
		this.pad_h = pad_h;
		this.pad_w = pad_w;
		if(H <= 0 || R <= 0 || stride_h < 0 || pad_h < 0)
			P = -1;
		else
			P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
		if(W <= 0 || S <= 0 || stride_w < 0 || pad_w < 0)
			Q = -1;
		else
			Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
		this.numThreads = numThreads;
	}

	/**
	 * Computes the output size {@code (size + 2*pad - filter) / stride + 1} along one spatial
	 * dimension for the long-valued constructor.
	 *
	 * @return the output size, or -1 if any input is unknown (negative)
	 * @throws DMLRuntimeException if the stride is 0, or if the result exceeds the int range
	 */
	private static int computeOutputSize(long size, long filter, long stride, long pad) {
		if(size < 0 || pad < 0 || filter < 0 || stride < 0)
			return -1;
		// Without this check a zero stride surfaces as an opaque ArithmeticException ("/ by zero").
		if(stride == 0)
			throw new DMLRuntimeException("The stride for DnnParameters must be positive: " + stride);
		// Go through convertToInt so an overflowing result is reported instead of silently
		// wrapping (possibly into a negative value that would look like "unknown").
		return convertToInt((size + 2 * pad - filter) / stride + 1);
	}

	/**
	 * Narrows a long to an int.
	 *
	 * @throws DMLRuntimeException if the value is larger than {@link Integer#MAX_VALUE}
	 */
	private static int convertToInt(long val) {
		if( val > Integer.MAX_VALUE )
			throw new DMLRuntimeException("The value for DnnParameters is too large:" + val);
		// Negative values denote unknown sizes; map values below the int range to -1 instead of
		// letting the narrowing cast wrap them (e.g., -4294967295L would otherwise become 1).
		if( val < Integer.MIN_VALUE )
			return -1;
		return (int) val;
	}

	/**
	 * Compares the size, stride, padding and thread-count parameters of two instances.
	 * P, Q and the optional fields (inputs, output, bias, index arrays, flags) are not compared.
	 *
	 * @return true if all compared parameters are equal
	 */
	public boolean compare(DnnParameters that) {
		return this.N == that.N && this.C == that.C && this.H == that.H && this.W == that.W
			&& this.K == that.K && this.R == that.R && this.S == that.S
			&& this.stride_h == that.stride_h && this.stride_w == that.stride_w
			&& this.pad_h == that.pad_h && this.pad_w == that.pad_w
			&& this.numThreads == that.numThreads;
	}

	@Override
	public String toString() {
		return "(NCHW=[" + N + " " + C + " " + H + " " + W + "], KCRS=[" + K + " " + R + " " + S + "], stride=[" + stride_h + "," + stride_w +
			"], pad=[" + pad_h + "," + pad_w + "])";
	}

	/**
	 * Fills in every unknown (negative) parameter from the size information of the corresponding
	 * hop, then computes P and Q if they are still unknown and their inputs are now non-negative.
	 * The thread count is always overwritten.
	 */
	public void setIfUnknown(Hop N, Hop C, Hop H, Hop W,
		Hop K, Hop R, Hop S, Hop stride_h, Hop stride_w, Hop pad_h, Hop pad_w, int numThreads) {
		this.N = resolveIfUnknown(this.N, N);
		this.C = resolveIfUnknown(this.C, C);
		this.H = resolveIfUnknown(this.H, H);
		this.W = resolveIfUnknown(this.W, W);
		this.K = resolveIfUnknown(this.K, K);
		this.R = resolveIfUnknown(this.R, R);
		this.S = resolveIfUnknown(this.S, S);
		this.stride_h = resolveIfUnknown(this.stride_h, stride_h);
		this.stride_w = resolveIfUnknown(this.stride_w, stride_w);
		this.pad_h = resolveIfUnknown(this.pad_h, pad_h);
		this.pad_w = resolveIfUnknown(this.pad_w, pad_w);
		if(this.P < 0 && this.H >= 0 && this.R >= 0 && this.stride_h >= 0 && this.pad_h >= 0) {
			this.P = (int) DnnUtils.getP(this.H, this.R, this.stride_h, this.pad_h);
		}
		if(this.Q < 0 && this.W >= 0 && this.S >= 0 && this.stride_w >= 0 && this.pad_w >= 0) {
			this.Q = (int) DnnUtils.getQ(this.W, this.S, this.stride_w, this.pad_w);
		}
		this.numThreads = numThreads;
	}

	/**
	 * Returns {@code current} if it is known (non-negative); otherwise evaluates the size
	 * information of {@code hop} and narrows it to an int. The hop is only evaluated when needed.
	 */
	private static int resolveIfUnknown(int current, Hop hop) {
		return (current < 0) ? convertToInt(Hop.computeSizeInformation(hop)) : current;
	}

	/** Delegates to {@code output.isThreadSafe()}; requires {@link #output} to be set. */
	public boolean isOutputThreadSafe() {
		return output.isThreadSafe();
	}

	/** @return true if both strides are 1 and both paddings are 0 */
	public boolean isStride1Pad0() {
		return (stride_h==1 && stride_w==1
			&& pad_h==0 && pad_w==0);
	}

	/**
	 * @return true if every given value equals 1 (also true when called without arguments);
	 *         a null element causes a NullPointerException on unboxing
	 */
	public boolean isAllOnes(Integer...params) {
		boolean ret = true;
		for(int param : params)
			ret &= (param == 1);
		return ret;
	}
}