blob: 30aaec91e27fecb7f1dccb902844cb267bf24284 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file gradient_compression.cc
* \brief Gradient compression for kvstore
* \author Rahul Huilgol
*/
#include <cstdio>
#include <limits>
#include <string>
#include <vector>
#include "kvstore_local.h"
#include "gradient_compression.h"
#include "gradient_compression-inl.h"
namespace mxnet {
namespace kvstore {
// Register GradientCompressionParam with the dmlc parameter framework so that
// GradientCompressionParam::InitAllowUnknown (called from SetParams) can parse it.
DMLC_REGISTER_PARAMETER(GradientCompressionParam);
/*!
 * \brief Construct a compressor that starts with compression disabled
 *        (CompressionType::kNone) until SetParams / SetTwoBitCompression is called.
 */
GradientCompression::GradientCompression()
    : type_(CompressionType::kNone) {}
void GradientCompression::SetParams(const std::vector<std::pair<std::string, std::string> >
& kwargs) {
GradientCompressionParam params;
params.InitAllowUnknown(kwargs);
CHECK_GT(params.threshold, 0) << "threshold must be greater than 0";
if (params.type == "2bit") {
SetTwoBitCompression(params.threshold);
} else {
LOG(FATAL) << "Unknown type for gradient compression " << params.type;
}
}
/*!
 * \brief Return the currently configured compression type.
 */
CompressionType GradientCompression::get_type() {
  return this->type_;
}
/*!
 * \brief Return the compression type as its decimal integer value, e.g. for
 *        logging and for the first field of EncodeParams().
 */
std::string GradientCompression::get_type_str() {
  const int type_code = static_cast<int>(type_);
  return std::to_string(type_code);
}
void GradientCompression::SetTwoBitCompression(const float threshold) {
type_ = CompressionType::kTwoBit;
threshold_ = threshold;
}
/*!
 * \brief Serialize the compression settings as "<type>[,<threshold>]".
 *
 * The first field is the integer value of the compression type (see
 * get_type_str()). For 2-bit compression a second, comma-separated field
 * holds the threshold. DecodeParams() parses this format.
 *
 * \return the encoded parameter string.
 */
std::string GradientCompression::EncodeParams() {
  std::string rval = get_type_str();
  if (type_ == CompressionType::kTwoBit) {
    // std::to_string(float) uses "%f" (six digits after the decimal point),
    // which turns small thresholds such as 1e-7 into "0.000000" and rounds
    // others, so the decoding side would see a different threshold.
    // "%.*g" with max_digits10 significant digits round-trips through stof.
    char buf[64];
    std::snprintf(buf, sizeof(buf), "%.*g",
                  std::numeric_limits<float>::max_digits10,
                  static_cast<double>(threshold_));
    rval += ',';
    rval += buf;
  }
  return rval;
}
void GradientCompression::DecodeParams(const std::string &s) {
std::vector<std::string> elems;
mxnet::kvstore::split(s, ',', std::back_inserter(elems));
type_ = static_cast<CompressionType>(stoi(elems[0]));
if (elems.size() > 1) {
if (!elems[1].empty()) {
threshold_ = stof(elems[1]);
}
}
}
/*!
 * \brief Number of original elements packed into one compressed element.
 *
 * Returns 16 for 2-bit compression (presumably 32-bit slots / 2 bits per
 * value — confirm against Quantize2BitImpl). Any other type is a fatal error.
 */
int GradientCompression::GetCompressionFactor() {
  switch (type_) {
    case CompressionType::kTwoBit:
      return 16;
    default:
      LOG(FATAL) << "Unsupported compression type: " << get_type_str();
      return 0;
  }
}
/*!
 * \brief Number of compressed elements needed to hold original_size elements.
 *
 * Divides by GetCompressionFactor() and adds one extra element whenever the
 * division leaves a remainder.
 *
 * \param original_size number of uncompressed elements.
 * \return compressed element count.
 */
int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {
  const int factor = GetCompressionFactor();
  const int64_t whole_chunks = original_size / factor;
  const bool has_partial_chunk = (original_size % factor) != 0;
  return has_partial_chunk ? whole_chunks + 1 : whole_chunks;
}
/*!
 * \brief Schedule quantization of `from` into `to` on the MXNet engine.
 *
 * Only CompressionType::kTwoBit is supported; any other type is a fatal error.
 * The work is pushed with PushSync, using `from` as a read dependency and
 * `to` and `residual` as mutated dependencies. It runs on `from`'s context,
 * so this call may return before the quantization has executed.
 *
 * `from` and `to` must both be CPU arrays or both be GPU arrays (GPU only in
 * CUDA builds). A CPU/GPU mix is a fatal error.
 *
 * \param from source array; its shape must be known.
 * \param to destination array for the compressed output; its shape must be known.
 * \param residual residual array passed to the kernel and declared as mutated;
 *        its shape must be known.
 * \param priority engine priority for the pushed operation.
 *
 * NOTE(review): `to` and `residual` are captured as raw pointers and
 * dereferenced only when the engine runs the lambda. Callers must keep the
 * pointed-to NDArray objects alive until then — confirm at call sites.
 */
void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
                                   mxnet::NDArray *residual, const int priority) {
  CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
  CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
  CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape";
  const int a = from.ctx().dev_mask();
  const int b = to->ctx().dev_mask();
  // Copy into a local so the lambdas capture the value rather than `this`.
  const float threshold = threshold_;
  if (type_ == CompressionType::kTwoBit) {
    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
      mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
        // Kernel input order: source, residual, destination.
        std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
        Quantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
      }, from.ctx(), {from.var()}, {to->var(), residual->var()},
      mxnet::FnProperty::kNormal, priority, "QuantizeCPU");
    } else {
#if MXNET_USE_CUDA
      if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
        mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
          std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
          Quantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
          // Wait for the GPU kernel to complete before the synchronous op returns.
          ctx.get_stream<mshadow::gpu>()->Wait();
        }, from.ctx(), {from.var()}, {to->var(), residual->var()},
        mxnet::FnProperty::kNormal, priority, "QuantizeGPU");
      } else {
        // Reached when the source and destination are on different device types.
        LOG(FATAL) << "unknown device mask";
      }
#else
      // Non-CUDA build: any non CPU/CPU combination is rejected.
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
    }
  } else {
    LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
  }
}
/*!
 * \brief Schedule dequantization of `from` into `to` on the MXNet engine.
 *
 * Only CompressionType::kTwoBit is supported; any other type is a fatal error.
 * The work is pushed with PushSync, using `from` as a read dependency and
 * `to` as a mutated dependency. It runs on `from`'s context, so this call may
 * return before the dequantization has executed.
 *
 * `from` and `to` must both be CPU arrays or both be GPU arrays (GPU only in
 * CUDA builds). A CPU/GPU mix is a fatal error.
 *
 * \param from compressed source array; its shape must be known.
 * \param to destination array for the decompressed output; its shape must be known.
 * \param priority engine priority for the pushed operation.
 *
 * NOTE(review): `to` is captured as a raw pointer and dereferenced only when
 * the engine runs the lambda. Callers must keep the pointed-to NDArray alive
 * until then — confirm at call sites.
 */
void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
                                     const int priority) {
  CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
  CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
  const int a = from.ctx().dev_mask();
  const int b = to->ctx().dev_mask();
  // Copy into a local so the lambdas capture the value rather than `this`.
  const float threshold = threshold_;
  if (type_ == CompressionType::kTwoBit) {
    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
      mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
        // Kernel input order: compressed source, destination.
        std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
        Dequantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
      }, from.ctx(), {from.var()}, {to->var()},
      mxnet::FnProperty::kNormal, priority, "DequantizeCPU");
    } else {
#if MXNET_USE_CUDA
      if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
        mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
          std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
          // Wait for the GPU kernel to complete before the synchronous op returns.
          ctx.get_stream<mshadow::gpu>()->Wait();
        }, from.ctx(), {from.var()}, {to->var()},
        mxnet::FnProperty::kNormal, priority, "DequantizeGPU");
      } else {
        // Reached when the source and destination are on different device types.
        LOG(FATAL) << "unknown device mask";
      }
#else
      // Non-CUDA build: any non CPU/CPU combination is rejected.
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
    }
  } else {
    LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
  }
}
} // namespace kvstore
} // namespace mxnet