blob: 78ed2dfdecc873f01e9d09b37e14974f437244c0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file initializer.h
* \brief random initializer
* \author Zhang Chen
*/
#ifndef MXNET_CPP_INITIALIZER_H_
#define MXNET_CPP_INITIALIZER_H_
#include <cmath>
#include <string>
#include <vector>
#include <random>
#include "mxnet-cpp/ndarray.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Base initializer: chooses an initialization routine for an array
 *  from the name of the parameter it belongs to.
 *
 * Derived classes may override operator() to bypass the name dispatch, or
 * override the individual Init* hooks to change a single case.
 */
class Initializer {
 public:
  /*! \brief Virtual so derived initializers can be destroyed through a base pointer. */
  virtual ~Initializer() = default;

  /*!
   * \brief Test whether a string begins with a given prefix.
   * \param name the string to inspect
   * \param check_str the prefix to look for
   * \return true if name is at least as long as check_str and starts with it
   */
  static bool StringStartWith(const std::string& name, const std::string& check_str) {
    // compare() works in place, so no temporary substring is allocated.
    return name.size() >= check_str.size() &&
           name.compare(0, check_str.size(), check_str) == 0;
  }

  /*!
   * \brief Test whether a string ends with a given suffix.
   * \param name the string to inspect
   * \param check_str the suffix to look for
   * \return true if name is at least as long as check_str and ends with it
   */
  static bool StringEndWith(const std::string& name, const std::string& check_str) {
    return name.size() >= check_str.size() &&
           name.compare(name.size() - check_str.size(), check_str.size(), check_str) == 0;
  }

  /*!
   * \brief Initialize an array according to its parameter name.
   *
   * The prefix "upsampling" is checked first, then a fixed list of suffixes;
   * the first match wins and anything unmatched goes to InitDefault.
   * \param name parameter name used to select the routine
   * \param arr array to initialize
   */
  virtual void operator()(const std::string& name, NDArray* arr) {
    if (StringStartWith(name, "upsampling")) {
      InitBilinear(arr);
    } else if (StringEndWith(name, "bias")) {
      InitBias(arr);
    } else if (StringEndWith(name, "gamma")) {
      InitGamma(arr);
    } else if (StringEndWith(name, "beta")) {
      InitBeta(arr);
    } else if (StringEndWith(name, "weight")) {
      InitWeight(arr);
    } else if (StringEndWith(name, "moving_mean")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_var")) {
      InitOne(arr);
    } else if (StringEndWith(name, "moving_inv_var")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_avg")) {
      InitZero(arr);
    } else if (StringEndWith(name, "min")) {
      InitZero(arr);
    } else if (StringEndWith(name, "max")) {
      InitOne(arr);
    } else if (StringEndWith(name, "weight_quantize")) {
      InitQuantizedWeight(arr);
    } else if (StringEndWith(name, "bias_quantize")) {
      InitQuantizedBias(arr);
    } else {
      InitDefault(arr);
    }
  }

 protected:
  /*!
   * \brief Fill the array with a separable triangular kernel built from its
   *  last two of four dimensions.
   *
   * Element i gets (1 - |x / f - c|) * (1 - |y / f - c|), where
   * x = i % shape[3], y = (i / shape[3]) % shape[2], f = ceil(shape[3] / 2)
   * and c = (2f - 1 - f % 2) / (2f).
   * NOTE(review): shape[2] and shape[3] are read unconditionally, so arr is
   * assumed to have at least four dimensions - confirm against callers.
   * \param arr array to fill
   */
  virtual void InitBilinear(NDArray* arr) {
    Shape shape(arr->GetShape());
    std::vector<float> weight(shape.Size(), 0);
    const int f = std::ceil(shape[3] / 2.0);
    const float c = (2 * f - 1 - f % 2) / (2. * f);
    for (size_t i = 0; i < shape.Size(); ++i) {
      const int x = i % shape[3];
      const int y = (i / shape[3]) % shape[2];
      // Divide in floating point: with int operands the quotient would be
      // truncated before c is subtracted, which distorts the kernel.
      const float fx = static_cast<float>(x) / f;
      const float fy = static_cast<float>(y) / f;
      weight[i] = (1 - std::abs(fx - c)) * (1 - std::abs(fy - c));
    }
    (*arr).SyncCopyFromCPU(weight);
  }

  /*! \brief Assign 0 to the array. */
  virtual void InitZero(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Assign 1 to the array. */
  virtual void InitOne(NDArray* arr) {
    (*arr) = 1.0f;
  }

  /*! \brief Hook for names ending in "bias": assigns 0. */
  virtual void InitBias(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Hook for names ending in "gamma": assigns 1. */
  virtual void InitGamma(NDArray* arr) {
    (*arr) = 1.0f;
  }

  /*! \brief Hook for names ending in "beta": assigns 0. */
  virtual void InitBeta(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Hook for names ending in "weight"; does nothing in the base class. */
  virtual void InitWeight(NDArray* arr) {}

  /*!
   * \brief Hook for names ending in "weight_quantize": assigns one integer
   *  drawn from [-127, 127] to the array.
   *
   * NOTE(review): the engine is default-constructed on every call, so each
   * call draws the same value, and only a single scalar is assigned to the
   * array. Kept as-is; confirm whether per-element sampling was intended.
   */
  virtual void InitQuantizedWeight(NDArray* arr) {
    std::default_random_engine generator;
    std::uniform_int_distribution<int32_t> _val(-127, 127);
    (*arr) = _val(generator);
  }

  /*! \brief Hook for names ending in "bias_quantize": assigns 0. */
  virtual void InitQuantizedBias(NDArray* arr) {
    (*arr) = 0;
  }

  /*! \brief Fallback for unmatched names; does nothing in the base class. */
  virtual void InitDefault(NDArray* arr) {}
};
class Constant : public Initializer {
public:
explicit Constant(float value) : value(value) {}
void operator()(const std::string& name, NDArray* arr) override {
(*arr) = value;
}
protected:
float value;
};
/*! \brief Constant initializer whose stored scalar is 0. */
class Zero : public Constant {
 public:
  Zero() : Constant{0.0f} {}
};
/*! \brief Constant initializer whose stored scalar is 1. */
class One : public Constant {
 public:
  One() : Constant{1.0f} {}
};
/*!
 * \brief Initializer that hands the array to NDArray::SampleUniform with a
 *  stored pair of bounds, except for quantized parameters.
 */
class Uniform : public Initializer {
 public:
  /*! \brief Symmetric form: bounds are (-scale, scale). */
  explicit Uniform(float scale) : Uniform(-scale, scale) {}

  /*!
   * \param begin first bound passed to SampleUniform
   * \param end second bound passed to SampleUniform
   */
  Uniform(float begin, float end) : begin(begin), end(end) {}

  /*!
   * \brief Initialize one array.
   *
   * Names ending in "weight_quantize" or "bias_quantize" are routed to the
   * base-class quantized hooks; every other name is sampled.
   * \param name parameter name
   * \param arr array to initialize
   */
  void operator()(const std::string& name, NDArray* arr) override {
    if (StringEndWith(name, "weight_quantize")) {
      InitQuantizedWeight(arr);
    } else if (StringEndWith(name, "bias_quantize")) {
      InitQuantizedBias(arr);
    } else {
      NDArray::SampleUniform(begin, end, arr);
    }
  }

 protected:
  /*! \brief bounds forwarded to SampleUniform, in this order */
  float begin, end;
};
class Normal : public Initializer {
public:
Normal(float mu, float sigma) : mu(mu), sigma(sigma) {}
void operator()(const std::string& name, NDArray* arr) override {
if (StringEndWith(name, "weight_quantize")) {
InitQuantizedWeight(arr);
return;
}
if (StringEndWith(name, "bias_quantize")) {
InitQuantizedBias(arr);
return;
}
NDArray::SampleGaussian(mu, sigma, arr);
}
protected:
float mu, sigma;
};
/*!
 * \brief Initializer that applies the base-class InitBilinear routine to
 *  every array, except for quantized parameters.
 */
class Bilinear : public Initializer {
 public:
  Bilinear() {}

  /*!
   * \brief Initialize one array.
   *
   * Names ending in "weight_quantize" or "bias_quantize" are routed to the
   * base-class quantized hooks; every other name goes to InitBilinear.
   * \param name parameter name
   * \param arr array to initialize
   */
  void operator()(const std::string& name, NDArray* arr) override {
    if (StringEndWith(name, "weight_quantize")) {
      InitQuantizedWeight(arr);
    } else if (StringEndWith(name, "bias_quantize")) {
      InitQuantizedBias(arr);
    } else {
      InitBilinear(arr);
    }
  }
};
class Xavier : public Initializer {
public:
enum RandType { gaussian, uniform } rand_type;
enum FactorType { avg, in, out } factor_type;
float magnitude;
Xavier(RandType rand_type = gaussian, // NOLINT
FactorType factor_type = avg, // NOLINT
float magnitude = 3) // NOLINT
: rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
void operator()(const std::string& name, NDArray* arr) override {
if (StringEndWith(name, "weight_quantize")) {
InitQuantizedWeight(arr);
return;
}
if (StringEndWith(name, "bias_quantize")) {
InitQuantizedBias(arr);
return;
}
Shape shape(arr->GetShape());
float hw_scale = 1.0f;
if (shape.ndim() > 2) {
for (size_t i = 2; i < shape.ndim(); ++i) {
hw_scale *= shape[i];
}
}
float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
float factor = 1.0f;
switch (factor_type) {
case avg:
factor = (fan_in + fan_out) / 2.0;
break;
case in:
factor = fan_in;
break;
case out:
factor = fan_out;
}
float scale = std::sqrt(magnitude / factor);
switch (rand_type) {
case uniform:
NDArray::SampleUniform(-scale, scale, arr);
break;
case gaussian:
NDArray::SampleGaussian(0, scale, arr);
break;
}
}
};
/*!
 * \brief Xavier initializer fixed to the gaussian sampler, with the magnitude
 *  computed from a slope as 2 / (1 + slope^2).
 */
class MSRAPrelu : public Xavier {
 public:
  /*!
   * \param factor_type factor selection forwarded to Xavier
   * \param slope value from which the magnitude is derived
   */
  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
      : Xavier(Xavier::gaussian,
               factor_type,
               static_cast<float>(2. / (1 + slope * slope))) {}
};
} // namespace cpp
} // namespace mxnet
#endif // MXNET_CPP_INITIALIZER_H_