// blob: 843965256df1ddabe98402b03a126f508713c271 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file initializer.h
* \brief random initializer
* \author Zhang Chen
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_
#include <cmath>
#include <string>
#include <vector>
#include "mxnet-cpp/ndarray.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Base class for parameter initializers.
 *
 * Calling an initializer with a parameter name and an array picks one of the
 * protected Init* hooks based on the name and applies it to the array.
 * Subclasses may override individual hooks or replace operator() entirely.
 */
class Initializer {
 public:
  /*! \brief Virtual so derived initializers are destroyed correctly through a base pointer. */
  virtual ~Initializer() = default;

  /*!
   * \brief Check whether a string begins with a given prefix.
   * \param name the string to inspect
   * \param check_str the prefix to look for
   * \return true if name starts with check_str
   */
  static bool StringStartWith(const std::string& name,
                              const std::string& check_str) {
    // compare() tests in place, avoiding the temporary that substr() would build.
    return name.size() >= check_str.size() &&
           name.compare(0, check_str.size(), check_str) == 0;
  }

  /*!
   * \brief Check whether a string ends with a given suffix.
   * \param name the string to inspect
   * \param check_str the suffix to look for
   * \return true if name ends with check_str
   */
  static bool StringEndWith(const std::string& name,
                            const std::string& check_str) {
    return name.size() >= check_str.size() &&
           name.compare(name.size() - check_str.size(), check_str.size(),
                        check_str) == 0;
  }

  /*!
   * \brief Initialize an array according to the name of the parameter it holds.
   *
   * The "upsampling" prefix is tested first; after that the first matching
   * suffix in the chain below decides which hook runs. Names matching no rule
   * go to InitDefault.
   * \param name name of the parameter
   * \param arr array to initialize in place
   */
  virtual void operator()(const std::string& name, NDArray* arr) {
    if (StringStartWith(name, "upsampling")) {
      InitBilinear(arr);
    } else if (StringEndWith(name, "bias")) {
      InitBias(arr);
    } else if (StringEndWith(name, "gamma")) {
      InitGamma(arr);
    } else if (StringEndWith(name, "beta")) {
      InitBeta(arr);
    } else if (StringEndWith(name, "weight")) {
      InitWeight(arr);
    } else if (StringEndWith(name, "moving_mean")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_var")) {
      InitOne(arr);
    } else if (StringEndWith(name, "moving_inv_var")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_avg")) {
      InitZero(arr);
    } else {
      InitDefault(arr);
    }
  }

 protected:
  /*!
   * \brief Fill an array with a bilinear interpolation kernel.
   *
   * Reads shape[2] and shape[3], so the array must have at least four
   * dimensions; the same 2-D pattern is repeated over all leading indices.
   * \param arr array to fill; the values are built on the host and then copied in
   */
  virtual void InitBilinear(NDArray* arr) {
    Shape shape(arr->GetShape());
    std::vector<float> weight(shape.Size(), 0);
    const int f = static_cast<int>(std::ceil(shape[3] / 2.0));
    const float c = (2 * f - 1 - f % 2) / (2. * f);
    // Divide in floating point: with an integer f, x / f would truncate and
    // produce an asymmetric kernel.
    const float f_real = static_cast<float>(f);
    for (size_t i = 0; i < shape.Size(); ++i) {
      const int x = i % shape[3];
      const int y = (i / shape[3]) % shape[2];
      weight[i] = (1 - std::abs(x / f_real - c)) * (1 - std::abs(y / f_real - c));
    }
    (*arr).SyncCopyFromCPU(weight);
  }
  /*! \brief Set every element of arr to 0. */
  virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; }
  /*! \brief Set every element of arr to 1. */
  virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; }
  /*! \brief Hook for names ending in "bias"; assigns 0. */
  virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; }
  /*! \brief Hook for names ending in "gamma"; assigns 1. */
  virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; }
  /*! \brief Hook for names ending in "beta"; assigns 0. */
  virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; }
  /*! \brief Hook for names ending in "weight"; does nothing unless overridden. */
  virtual void InitWeight(NDArray* arr) {}
  /*! \brief Hook for names matching no rule; does nothing unless overridden. */
  virtual void InitDefault(NDArray* arr) {}
};
/*!
 * \brief Initializer that assigns one fixed value to every array,
 *  regardless of the parameter name.
 */
class Constant : public Initializer {
 public:
  /*!
   * \brief Create a constant initializer.
   * \param value the value assigned to each array
   */
  explicit Constant(float value) : value(value) {}

  /*!
   * \brief Assign the stored value to arr.
   * \param name name of the parameter (not consulted)
   * \param arr array to overwrite
   */
  void operator()(const std::string &name, NDArray *arr) override {
    *arr = value;
  }

 protected:
  /*! \brief value written by operator() */
  float value;
};
/*! \brief Constant initializer whose value is 0. */
class Zero : public Constant {
 public:
  Zero() : Constant{0.0f} {}
};
/*! \brief Constant initializer whose value is 1. */
class One : public Constant {
 public:
  One() : Constant{1.0f} {}
};
/*!
 * \brief Initializer that fills every array through NDArray::SampleUniform,
 *  regardless of the parameter name.
 */
class Uniform : public Initializer {
 public:
  /*!
   * \brief Sample using explicit bounds.
   * \param begin first bound handed to SampleUniform
   * \param end second bound handed to SampleUniform
   */
  Uniform(float begin, float end)
      : begin(begin), end(end) {}

  /*!
   * \brief Sample using the symmetric bounds -scale and scale.
   * \param scale magnitude of both bounds
   */
  explicit Uniform(float scale)
      : Uniform(-scale, scale) {}

  /*!
   * \brief Fill arr with samples drawn using the stored bounds.
   * \param name name of the parameter (not consulted)
   * \param arr array to fill
   */
  void operator()(const std::string &name, NDArray *arr) override {
    NDArray::SampleUniform(begin, end, arr);
  }

 protected:
  /*! \brief first bound passed to SampleUniform */
  float begin;
  /*! \brief second bound passed to SampleUniform */
  float end;
};
/*!
 * \brief Initializer that fills every array through NDArray::SampleGaussian,
 *  regardless of the parameter name.
 */
class Normal : public Initializer {
 public:
  /*!
   * \brief Create a Gaussian initializer.
   * \param mu first argument handed to SampleGaussian
   * \param sigma second argument handed to SampleGaussian
   */
  Normal(float mu, float sigma)
      : mu(mu), sigma(sigma) {}

  /*!
   * \brief Fill arr with samples drawn using the stored mu and sigma.
   * \param name name of the parameter (not consulted)
   * \param arr array to fill
   */
  void operator()(const std::string &name, NDArray *arr) override {
    NDArray::SampleGaussian(mu, sigma, arr);
  }

 protected:
  /*! \brief first argument passed to SampleGaussian */
  float mu;
  /*! \brief second argument passed to SampleGaussian */
  float sigma;
};
class Bilinear : public Initializer {
public:
Bilinear() {}
void operator()(const std::string &name, NDArray *arr) override {
InitBilinear(arr);
}
};
class Xavier : public Initializer {
public:
enum RandType {
gaussian,
uniform
} rand_type;
enum FactorType {
avg,
in,
out
} factor_type;
float magnitude;
Xavier(RandType rand_type = gaussian, FactorType factor_type = avg,
float magnitude = 3)
: rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
void operator()(const std::string &name, NDArray* arr) override {
Shape shape(arr->GetShape());
float hw_scale = 1.0f;
if (shape.ndim() > 2) {
for (size_t i = 2; i < shape.ndim(); ++i) {
hw_scale *= shape[i];
}
}
float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
float factor = 1.0f;
switch (factor_type) {
case avg:
factor = (fan_in + fan_out) / 2.0;
break;
case in:
factor = fan_in;
break;
case out:
factor = fan_out;
}
float scale = std::sqrt(magnitude / factor);
switch (rand_type) {
case uniform:
NDArray::SampleUniform(-scale, scale, arr);
break;
case gaussian:
NDArray::SampleGaussian(0, scale, arr);
break;
}
}
};
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_