blob: 2cfd55ac8ba50eeab8be149d56d27ee980cbdaf1 [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "singa/neuralnet/connection_layer.h"
#include "singa/comm/msg.h"
namespace singa {
using std::vector;
void BridgeLayer::MakePaired(Layer* pair, int grp_id, Dealer* dealer,
std::unordered_map<std::string, Layer*>* name2bridge) {
pair_ = pair;
group_id_ = grp_id;
dealer_ = dealer;
name2bridge_ = name2bridge;
}
void BridgeLayer::SendBlobs(bool handle_data) {
CHECK(dealer_) << "NULL dealer for bridges in worker (" << group_id_
<< ", " << partition_id() << ")";
Msg *msg = new Msg();
msg->set_src(Addr(group_id_, partition_id(), kWorkerLayer));
msg->set_dst(Addr(group_id_, pair_->partition_id(), kWorkerLayer));
msg->AddFrame(pair_->name().c_str(), pair_->name().length());
auto const& blob = handle_data ? data(nullptr) : grad(nullptr);
msg->AddFrame(blob.cpu_data(), blob.count() * sizeof(float));
dealer_->Send(&msg);
}
void BridgeLayer::ReceiveBlobs(bool handle_data) {
CHECK(dealer_) << "NULL dealer for bridges in worker (" << group_id_
<< ", " << partition_id() << ")";
while (!ready()) {
auto msg = dealer_->Receive();
CHECK_EQ(AddrGrp(msg->src()), group_id_);
string name(static_cast<char*>(msg->FrameData()), msg->FrameSize());
auto receive_layer = name2bridge_->at(name);
auto blob = handle_data ? receive_layer->mutable_data(nullptr) :
receive_layer -> mutable_grad(nullptr);
msg->NextFrame();
memcpy(blob->mutable_cpu_data(), msg->FrameData(), msg->FrameSize());
dynamic_cast<BridgeLayer*>(receive_layer)->set_ready(true);
delete msg;
}
}
void BridgeSrcLayer::Setup(const LayerProto& conf,
const vector<Layer*>& srclayers) {
CHECK_GE(srclayers.size(), 1);
Layer::Setup(conf, srclayers);
data_.Reshape(srclayers[0]->data(this).shape());
grad_.ReshapeLike(data_);
data_.ShareData(srclayers[0]->mutable_data(this), false);
grad_.ShareData(srclayers[0]->mutable_grad(this), false);
}
void BridgeSrcLayer::ComputeFeature(int flag, const vector<Layer*>& srcs) {
// send data
SendBlobs(true);
// reset flag for receiving gradient in compute gradient phase
set_ready(false);
}
void BridgeSrcLayer::ComputeGradient(int flag, const vector<Layer*>& srcs) {
// receive gradient
ReceiveBlobs(false);
}
void BridgeDstLayer::Setup(const LayerProto& conf,
const vector<Layer*>& srclayers) {
CHECK_EQ(srclayers.size(), 1);
Layer::Setup(conf, srclayers);
data_.Reshape(srclayers[0]->data(this).shape());
grad_.ReshapeLike(data_);
}
void BridgeDstLayer::ComputeFeature(int flag, const vector<Layer*>& srcs) {
// receive data
ReceiveBlobs(true);
}
void BridgeDstLayer::ComputeGradient(int flag, const vector<Layer*>& srcs) {
// send gradient
SendBlobs(false);
// reset flag for receiving data in compute feature phase
set_ready(false);
}
} // namespace singa