// blob: f40b45f5a51378d2b43506db864c05dc343c11ae [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file gradient_compression.h
* \brief Gradient compression for kvstore
* \author Rahul Huilgol
*/
#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
#define MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
#include <dmlc/parameter.h>
#include <string>
#include <utility>
#include <vector>
#include "mxnet/ndarray.h"
namespace mxnet {
namespace kvstore {
/*!
 * \brief Kinds of gradient compression that can be selected.
 * Enumerator values are left implicit, so they count up from zero in
 * declaration order.
 */
enum class CompressionType {
  /*! \brief no gradient compression has been set */
  kNone,
  /*! \brief two bit gradient compression */
  kTwoBit
};
/*!
 * \brief Parameters accepted when configuring gradient compression.
 * Derives from dmlc::Parameter and registers its fields through
 * DMLC_DECLARE_PARAMETER; GradientCompression::SetParams documents that the
 * key/value pairs it receives are parsed by this struct.
 */
struct GradientCompressionParam : public dmlc::Parameter<GradientCompressionParam> {
/*! \brief type of gradient compression to use, e.g. `2bit` */
std::string type;
/*! \brief threshold to use for 2bit gradient compression */
float threshold;
DMLC_DECLARE_PARAMETER(GradientCompressionParam) {
// `type` is registered without a set_default() call.
DMLC_DECLARE_FIELD(type)
.describe("Type of gradient compression to use, like `2bit` for example");
// `threshold` is registered with a default of 0.5.
DMLC_DECLARE_FIELD(threshold).set_default(0.5)
.describe("Threshold to use for 2bit gradient compression");
}
};
/*!
 * \brief Holds the gradient compression settings (type and threshold) for
 * kvstore and issues the quantize / dequantize operations that use them.
 */
class GradientCompression {
 public:
  GradientCompression();

  virtual ~GradientCompression() = default;

  /*!
   * \brief sets parameters for gradient compression
   * \param kwargs a vector of (key, value) string pairs, each pair naming one
   * parameter and its value. Will be parsed by GradientCompressionParam
   */
  void SetParams(const std::vector<std::pair<std::string, std::string> >& kwargs);

  /*!
   * \brief returns type of compression if any
   * \return the CompressionType currently set
   */
  CompressionType get_type();

  /*!
   * \brief returns as string the enum value of compression type
   */
  std::string get_type_str();

  /*!
   * \brief sets two bit gradient compression
   * \param threshold float value used for thresholding gradients
   */
  void SetTwoBitCompression(const float threshold);

  /*!
   * \brief encodes parameters of gc into a string
   * \return string holding the encoded parameters
   */
  std::string EncodeParams();

  /*!
   * \brief decodes parameters of gc from a string and assigns them to member variables
   * \param s string holding encoded parameters
   */
  void DecodeParams(const std::string &s);

  /*!
   * \brief returns compression factor, which is the factor by which size of gradient
   * reduces when using a particular type of compression
   */
  int GetCompressionFactor();

  /*!
   * \brief returns the size of compressed gradients given an original sized gradient array
   * \param original_size size of the original gradient array
   */
  int64_t GetCompressedSize(const int64_t original_size);

  /*!
   * \brief Issues quantize operation to be scheduled by the engine
   * Compresses `from` into `to` and accumulates the quantization error
   * into 'residual', using the quantization of type `type_`
   * \param from the ndarray containing original data to be quantized
   * \param to the target ndarray which contains quantized data
   * \param residual the ndarray which accumulates quantization error
   * \param priority Priority of the action.
   */
  void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
                mxnet::NDArray *residual, const int priority);

  /*!
   * \brief Issues dequantize operation to be scheduled by the engine
   * Decompresses `from` into `to` using current parameters of `type` and `threshold`
   * \param from the ndarray containing quantized data
   * \param to the target ndarray which contains final dequantized data
   * \param priority Priority of the action.
   */
  void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority);

 private:
  /*!
   * \brief denotes the type of gradient compression which has been set
   * Initialized in-class to kNone so the member never holds an indeterminate
   * value, whatever the out-of-line constructor assigns.
   */
  CompressionType type_ = CompressionType::kNone;

  /*!
   * \brief denotes threshold used for quantization and dequantization
   * Must be a positive value. All positive gradients will be thresholded to `threshold_` and
   * all negative gradients will be thresholded to -1*`threshold_`
   */
  float threshold_ = 0;
};
} // namespace kvstore
} // namespace mxnet
#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_H_