/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/

#include "./stub.h"

#include <glog/logging.h>
#include <unistd.h>
#include <map>
#include <thread>
#include <set>
#include "mshadow/tensor.h"
#include "proto/common.pb.h"
#include "utils/cluster.h"
#include "utils/common.h"
#include "utils/tinydir.h"

namespace singa {

using std::vector;
using std::string;

/***********************Stub****************************/
/**
 * Destroy the stub and release its router.
 *
 * router_ is the Router allocated with `new` in Setup().
 */
Stub::~Stub() {
  delete router_;
}
void Stub::Setup() {
  router_ = new Router();
  router_->Bind(kInprocRouterEndpoint);
  auto cluster = Cluster::Get();
  const string hostip = cluster->hostip();
  int port = router_->Bind("tcp://" + hostip + ":*");
  endpoint_ = hostip + ":" + std::to_string(port);
}
/**
 * Compute the key under which a Param of a worker group is stored.
 *
 * The key is grp_id * 997 + param_id. Keys of two different
 * (group, param) pairs cannot collide as long as every param id lies in
 * [0, 997).
 *
 * @param grp_id id of the worker group
 * @param param_id id of the param within that group
 * @return the combined integer key
 */
inline int Hash(int grp_id, int param_id) {
  const int kGroupStride = 997;
  const int group_base = grp_id * kGroupStride;
  return group_base + param_id;
}
const std::unordered_map<int, ParamEntry*>  CreateParamShard(
    const vector<Worker*>& workers) {
  std::unordered_map<int, ParamEntry*> shard;
  // grp id -> net
  std::unordered_map<int, NeuralNet*> grp2net;
  // grp id -> worker id range
  std::unordered_map<int, std::pair<int, int>> grp2workers;
  for (auto worker : workers) {
    int grp = worker->grp_id(), id = worker->id();
    if (grp2net.find(grp) == grp2net.end()) {
      grp2net[grp] = worker->train_net();
      grp2workers[grp] = std::make_pair(id, id + 1);
    } else {
      CHECK_EQ(grp2net[grp], worker->train_net());
      int start = grp2workers[grp].first, end = grp2workers[grp].second;
      if (start > id) start = id;
      if (end < id + 1) end = id + 1;
      grp2workers[grp] = std::make_pair(start, end);
    }
  }

  for (const auto entry : grp2net) {
    int grp = entry.first;
    int wstart = grp2workers[grp].first, wend = grp2workers[grp].second;
    for (auto layer : entry.second->layers()) {
      int partition = layer->partition_id();
      bool local =  partition >= wstart && partition < wend;
      for (auto param : layer->GetParams()) {
        int hash = Hash(grp, param->owner());
        if (shard.find(hash) == shard.end())
          shard[hash] = new ParamEntry();
        shard[hash]->AddParam(local, param);
      }
    }
  }
  return shard;
}

/**
 * Message loop of the stub.
 *
 * Messages are taken from an internal queue of messages generated by the
 * stub itself when it is non-empty, otherwise from router_. Each message is
 * then handled according to its destination:
 *  - destination type kStub and (destination process is this process or
 *    destination group is -1): the message is consumed here.
 *      * kConnect: discarded.
 *      * kStop: decrements the count of live servers or workers depending on
 *        the source type; the loop ends once both counts reach zero.
 *      * kUpdate/kGet/kPut: the matching ParamEntry is looked up with the
 *        group of the message source and the request handler's output
 *        messages are queued.
 *      * kRUpdate/kRGet: the matching ParamEntry is looked up with the group
 *        of the message destination and the response handler is invoked.
 *      * anything else: logged and dropped.
 *  - otherwise the message is forwarded, through a Dealer connected to the
 *    destination process (created on first use) or through router_ when the
 *    destination lives in this process.
 *
 * On exit the dealers and the ParamEntry table created for this run are
 * released.
 *
 * @param slice2server maps a slice id to a server id; copied into
 * slice2server_
 * @param workers workers of this process; their count is the number of
 * kStop messages expected from kWorkerParam sources
 * @param servers servers of this process; their count is the number of
 * kStop messages expected from kServer sources
 */
void Stub::Run(const vector<int>& slice2server,
    const vector<Worker*>& workers, const vector<Server*>& servers) {
  slice2server_ = slice2server;
  int nworkers = workers.size(), nservers = servers.size();
  auto cluster = Cluster::Get();
  int procs_id = cluster->procs_id();
  LOG(INFO) << "Stub in process " << procs_id << " starts";
  auto shard = CreateParamShard(workers);
  std::map<int, Dealer*> inter_dealers;  // for sending msg to other procs
  std::queue<Msg*> msg_queue;  // messages generated by the request handlers
  while (true) {
    Msg* msg = nullptr;
    // locally generated messages take priority over incoming ones
    if (msg_queue.empty()) {
      msg = router_->Receive();
    } else {
      msg = msg_queue.front();
      msg_queue.pop();
    }
    int type = msg->type(), dst = msg->dst(), flag = AddrType(dst);
    if (flag == kStub && (AddrProc(dst) == procs_id || AddrGrp(dst) == -1)) {
      //  the following statements are ordered!
      if (type == kConnect) {
        DeleteMsg(&msg);
      } else if (type == kStop) {
        int src_flag = AddrType(msg->src());
        if (src_flag == kServer) nservers--;
        else if (src_flag == kWorkerParam) nworkers--;
        DeleteMsg(&msg);
        if (nworkers == 0 && nservers == 0) break;
      } else {
        int grp;
        int paramid = ParamID(msg->trgt_val());
        ParamEntry *entry = nullptr;
        // requests are keyed by the group of the sender, responses by the
        // group of the receiver
        switch (type) {
          case kUpdate:
            grp = AddrGrp(msg->src());
            entry = shard.at(Hash(grp, paramid));
            for (auto update_msg : HandleUpdateRequest(entry, &msg))
              msg_queue.push(update_msg);
            break;
          case kRUpdate:
            grp = AddrGrp(msg->dst());
            entry = shard.at(Hash(grp, paramid));
            HandleUpdateResponse(entry, &msg);
            break;
          case kGet:
            grp = AddrGrp(msg->src());
            entry = shard.at(Hash(grp, paramid));
            for (auto get_msg : HandleGetRequest(entry, &msg))
              msg_queue.push(get_msg);
            break;
          case kRGet:
            grp = AddrGrp(msg->dst());
            entry = shard.at(Hash(grp, paramid));
            HandleGetResponse(entry, &msg);
            break;
          case kPut:
            grp = AddrGrp(msg->src());
            entry = shard.at(Hash(grp, paramid));
            for (auto put_msg : HandlePutRequest(entry, &msg))
              msg_queue.push(put_msg);
            break;
          default:
            LOG(ERROR) << "Unknow message type:" << type;
            // no handler takes ownership of the message, so free it here
            DeleteMsg(&msg);
            break;
        }
      }
    } else {
      int dst_procs = AddrProc(dst);
      if (flag != kStub)
        dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag);
      if (dst_procs != procs_id) {
        // one dealer per remote process, created on first use
        auto dealer = inter_dealers.find(dst_procs);
        if (dealer == inter_dealers.end()) {
          dealer = inter_dealers.emplace(
              dst_procs, CreateInterProcsDealer(dst_procs)).first;
        }
        dealer->second->Send(&msg);
      } else {
        router_->Send(&msg);
      }
    }
  }
  LOG(ERROR) << "Stub in process " << procs_id << " stops";
  for (auto& entry : inter_dealers)
    delete entry.second;
  // the ParamEntry objects were allocated by CreateParamShard for this run
  // and are referenced only through this table
  for (auto& entry : shard)
    delete entry.second;
}

/**
 * Create a dealer connected to another process.
 *
 * Blocks while the cluster reports an empty endpoint for dst_procs: each
 * round sleeps for 3 seconds and then logs that it is still waiting. Once an
 * endpoint is available the dealer is connected to "tcp://<endpoint>".
 *
 * @param dst_procs id of the process to connect to
 * @return a Dealer allocated with `new`; Run() deletes the dealers it
 * obtains from this function
 */
Dealer* Stub::CreateInterProcsDealer(int dst_procs) {
  auto cluster = Cluster::Get();
  Dealer* link = new Dealer();
  const std::chrono::milliseconds retry_interval(3000);
  for (;;) {
    if (!cluster->endpoint(dst_procs).empty())
      break;
    std::this_thread::sleep_for(retry_interval);
    LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
  }
  const string address = "tcp://" + cluster->endpoint(dst_procs);
  link->Connect(address);
  return link;
}

void Stub::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
                      vector<Msg*> *ret) {
  int procs_id = Cluster::Get()->procs_id();
  int src_grp = AddrGrp(msg->src());
  int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
  auto param = entry->shares.at(0);
  for (int idx = 0 ; idx < param->num_slices(); idx++) {
    int slice_id = param->slice_start() + idx;
    int server = slice2server_[slice_id];
    int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
    Msg* new_msg = nullptr;
    if (type == kPut) {
      CHECK_GT(entry->num_total, 0);
      new_msg = param->GenPutMsg(dst_procs != procs_id, idx);
      new_msg->AddFormatFrame("i", entry->num_total);
    } else if (type == kGet) {
      new_msg = param->GenGetMsg(dst_procs != procs_id, idx);
    } else if (type == kUpdate) {
      new_msg = param->GenUpdateMsg(dst_procs != procs_id, idx);
      new_msg->AddFormatFrame("i", entry->num_local);
    } else {
      LOG(FATAL) << "Wrong type";
    }
    new_msg->set_trgt(ParamTrgt(param->owner(), slice_id), version);
    new_msg->set_src(Addr(src_grp, procs_id, kStub));
    new_msg->set_dst(Addr(dst_grp, server, kServer));
    ret->push_back(new_msg);
  }
}

/**
 * Handle a get request addressed to this stub.
 *
 * Get messages are generated only when the requested version is greater
 * than entry->next_version; in that case next_version is raised to the
 * requested version first. Requests for a version that is not newer produce
 * no messages.
 *
 * @param entry entry of the requested param
 * @param msg the request; it is always deleted before returning
 * @return the generated get messages, possibly empty
 */
const vector<Msg*> Stub::HandleGetRequest(ParamEntry* entry, Msg** msg) {
  vector<Msg*> requests;
  const int wanted = (*msg)->trgt_version();
  const bool newer = entry->next_version < wanted;
  if (newer) {
    entry->next_version = wanted;
    GenMsgs(kGet, wanted, entry, *msg, &requests);
  }
  DeleteMsg(msg);
  return requests;
}

/**
 * Handle an update request addressed to this stub.
 *
 * Every call increments entry->num_update. Nothing is generated until the
 * counter reaches entry->num_local. When it does:
 *  - if the entry has more than one local share, the gradients of the other
 *    shares are added into the gradient of the first share (a plain sum, no
 *    division is done here; GenMsgs attaches num_local to every update
 *    message);
 *  - update messages are generated with the version carried by this request;
 *  - the counter is reset to zero.
 *
 * @param entry entry of the param being updated
 * @param msg the request; it is always deleted before returning
 * @return the generated update messages, empty while requests are still
 * being counted
 */
const vector<Msg*> Stub::HandleUpdateRequest(ParamEntry *entry, Msg** msg) {
  vector<Msg*> updates;
  entry->num_update++;
  if (entry->num_update < entry->num_local) {
    // still waiting for the remaining local requests
    DeleteMsg(msg);
    return updates;
  }
  if (entry->num_local > 1) {
    // accumulate the gradients of all shares into the first share;
    // NOTE(review): every share is viewed with the first share's size --
    // presumably all shares have the same size, confirm
    auto share = entry->shares.begin();
    auto len = mshadow::Shape1((*share)->size());
    mshadow::Tensor<mshadow::cpu, 1> total((*share)->mutable_cpu_grad(), len);
    while (++share != entry->shares.end()) {
      mshadow::Tensor<mshadow::cpu, 1> part((*share)->mutable_cpu_grad(), len);
      total += part;
    }
  }
  const int step = (*msg)->trgt_version();
  GenMsgs(kUpdate, step, entry, *msg, &updates);
  entry->num_update = 0;
  DeleteMsg(msg);
  return updates;
}

/**
 * Handle a put request addressed to this stub.
 *
 * Put messages are generated unconditionally, using the version carried by
 * the request.
 *
 * @param entry entry of the param being put
 * @param msg the request; it is deleted before returning
 * @return the generated put messages
 */
const vector<Msg*> Stub::HandlePutRequest(ParamEntry* entry, Msg** msg) {
  vector<Msg*> puts;
  Msg* request = *msg;
  GenMsgs(kPut, request->trgt_version(), entry, request, &puts);
  DeleteMsg(msg);
  return puts;
}

/**
 * Handle the response to a get request.
 *
 * The response is parsed by the first share of the entry, which is told the
 * index of the slice relative to its slice_start(). If the parse call
 * returns true, the share's version is set to the version carried by the
 * response.
 *
 * @param entry entry of the param the response belongs to
 * @param msg the response; it is deleted before returning
 */
void Stub::HandleGetResponse(ParamEntry* entry, Msg** msg) {
  Msg* response = *msg;
  const int new_version = response->trgt_version();
  const int global_slice = SliceID(response->trgt_val());
  auto target = entry->shares.at(0);
  const int local_slice = global_slice - target->slice_start();
  if (target->ParseGetResponseMsg(response, local_slice)) {
    target->set_version(new_version);
  }
  DeleteMsg(msg);
}

/**
 * Handle the response to an update request.
 *
 * The response is parsed by the first share of the entry, which is told the
 * index of the slice relative to its slice_start(). If the parse call
 * returns true, the share's version is set to the version carried by the
 * response.
 *
 * @param entry entry of the param the response belongs to
 * @param msg the response; it is deleted before returning
 */
void Stub::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
  Msg* response = *msg;
  const int new_version = response->trgt_version();
  const int global_slice = SliceID(response->trgt_val());
  auto target = entry->shares.at(0);
  const int local_slice = global_slice - target->slice_start();
  if (target->ParseUpdateResponseMsg(response, local_slice)) {
    target->set_version(new_version);
  }
  DeleteMsg(msg);
}
}  // namespace singa
