/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Reflection utilities.
 * \file node/reflection.cc
 */
#include <tvm/ir/attrs.h>
#include <tvm/node/container.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/registry.h>

namespace tvm {

using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
 * \brief AttrVisitor that copies the value of one named field into a TVMRetValue.
 *
 * Every Visit overload compares the visited key with `skey`; only the matching
 * field writes to `*ret`. `found_ref_object` records that the matched field was an
 * NDArray or ObjectRef, because such a field may hold null and would otherwise be
 * indistinguishable from "no field matched" (ret left as kTVMNullptr).
 */
class AttrGetter : public AttrVisitor {
 public:
  /*! \brief Name of the field being looked up (referenced, must outlive the getter). */
  const String& skey;
  /*! \brief Destination for the field value; written only when `skey` matches. */
  TVMRetValue* ret;

  AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {}

  /*! \brief Set when the matched field is an NDArray or ObjectRef (possibly null). */
  bool found_ref_object{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    if (skey == key) {
      // The value is returned as int64_t, so only the requested field has to fit.
      // Checking before the key match would abort lookups of unrelated fields
      // whenever the object carries a large uint64_t member.
      CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
          << "cannot return too big constant";
      *ret = static_cast<int64_t>(value[0]);
    }
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, void** value) final {
    if (skey == key) *ret = static_cast<void*>(value[0]);
  }
  void Visit(const char* key, DataType* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }

  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::ObjectRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
};

/*!
 * \brief Read one field of `self` by name through reflection.
 *
 * The name "type_key" is answered with the object's type key. For a DictAttrsNode
 * the name is looked up in its `dict`; any other object is walked with AttrGetter
 * via VisitAttrs.
 *
 * \param self The object to inspect.
 * \param field_name Name of the field to read.
 * \return The value of the field.
 * \note Issues LOG(FATAL) with an "AttributeError: " prefixed message when no such
 *       field exists.
 */
runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const {
  runtime::TVMRetValue ret;
  AttrGetter getter(field_name, &ret);

  bool success = false;
  if (getter.skey == "type_key") {
    ret = self->GetTypeKey();
    success = true;
  } else if (!self->IsInstance<DictAttrsNode>()) {
    VisitAttrs(self, &getter);
    // A matched NDArray/ObjectRef field may be null, so a null `ret` alone does not
    // mean the field is missing.
    success = getter.found_ref_object || ret.type_code() != kTVMNullptr;
  } else {
    // specially handle dict attr: its fields are stored in `dict`
    DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
    auto it = dnode->dict.find(getter.skey);
    if (it != dnode->dict.end()) {
      ret = (*it).second;
      success = true;
    }
  }
  if (!success) {
    LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attribute "
               << getter.skey;
  }
  return ret;
}

/*!
 * \brief AttrVisitor that appends the key of every visited field to `names`.
 *
 * The field's type and value are ignored; keys are recorded in visit order.
 * `names` must point to a valid vector before the visitor is used.
 */
class AttrDir : public AttrVisitor {
 public:
  /*! \brief Output list that receives the visited keys (not owned). */
  std::vector<std::string>* names;

  void Visit(const char* key, double*) final { Record(key); }
  void Visit(const char* key, int64_t*) final { Record(key); }
  void Visit(const char* key, uint64_t*) final { Record(key); }
  void Visit(const char* key, bool*) final { Record(key); }
  void Visit(const char* key, int*) final { Record(key); }
  void Visit(const char* key, void**) final { Record(key); }
  void Visit(const char* key, DataType*) final { Record(key); }
  void Visit(const char* key, std::string*) final { Record(key); }
  void Visit(const char* key, runtime::NDArray*) final { Record(key); }
  void Visit(const char* key, runtime::ObjectRef*) final { Record(key); }

 private:
  /*! \brief Append one field key to the output list. */
  void Record(const char* key) { names->push_back(key); }
};

/*!
 * \brief List the reflected field names of `self`.
 *
 * For a DictAttrsNode the names are the keys of its `dict`, in the map's iteration
 * order. For any other object they are collected by visiting its attributes with
 * AttrDir.
 *
 * \param self The object to inspect.
 * \return The field names.
 */
std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const {
  std::vector<std::string> result;
  if (self->IsInstance<DictAttrsNode>()) {
    // dict attrs keep their fields in `dict`, so enumerate its keys directly
    const DictAttrsNode* dict_node = static_cast<const DictAttrsNode*>(self);
    for (const auto& entry : dict_node->dict) {
      result.push_back(entry.first);
    }
    return result;
  }
  AttrDir collector;
  collector.names = &result;
  VisitAttrs(self, &collector);
  return result;
}

/*!
 * \brief Access the process-wide reflection table.
 * \return Pointer to a function-local static table, constructed on first call and
 *         shared by every caller.
 */
ReflectionVTable* ReflectionVTable::Global() {
  static ReflectionVTable global_vtable;
  ReflectionVTable* const instance = &global_vtable;
  return instance;
}

/*!
 * \brief Create a fresh object of the given type through its registered creator.
 *
 * \param type_key Registered type key of the object to create.
 * \param repr_bytes Forwarded verbatim to the type's creator function.
 * \return The newly created object.
 * \note Issues LOG(FATAL) with a "TypeError: " message when no creator is registered
 *       for the type's index.
 */
ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key,
                                                     const std::string& repr_bytes) const {
  const uint32_t type_index = Object::TypeKey2Index(type_key);
  const bool has_creator = type_index < fcreate_.size() && fcreate_[type_index] != nullptr;
  if (!has_creator) {
    LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE";
  }
  return fcreate_[type_index](repr_bytes);
}

class NodeAttrSetter : public AttrVisitor {
 public:
  std::string type_key;
  std::unordered_map<std::string, runtime::TVMArgValue> attrs;

  void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
  void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
  void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
  void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
  void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
  void Visit(const char* key, std::string* value) final {
    *value = GetAttr(key).operator std::string();
  }
  void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
  void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
  void Visit(const char* key, runtime::NDArray* value) final {
    *value = GetAttr(key).operator runtime::NDArray();
  }
  void Visit(const char* key, ObjectRef* value) final {
    *value = GetAttr(key).operator ObjectRef();
  }

 private:
  runtime::TVMArgValue GetAttr(const char* key) {
    auto it = attrs.find(key);
    if (it == attrs.end()) {
      LOG(FATAL) << type_key << ": require field " << key;
    }
    runtime::TVMArgValue v = it->second;
    attrs.erase(it);
    return v;
  }
};

/*!
 * \brief Initialize the fields of `n` from packed key/value arguments.
 *
 * \param reflection Reflection table used to visit the object's fields.
 * \param n The object to initialize.
 * \param args Flat argument list: key1, value1, ..., key_n, value_n. Keys are
 *        converted to std::string. If a key is repeated, only its first value is kept.
 * \note CHECK-fails on an odd argument count. LOG(FATAL)s when a field has no value
 *       or when a supplied key matches no field.
 */
void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) {
  NodeAttrSetter setter;
  setter.type_key = n->GetTypeKey();
  CHECK_EQ(args.size() % 2, 0);
  for (int key_pos = 0; key_pos < args.size(); key_pos += 2) {
    const std::string field_name = args[key_pos].operator std::string();
    setter.attrs.emplace(field_name, args[key_pos + 1]);
  }
  reflection->VisitAttrs(n, &setter);

  if (!setter.attrs.empty()) {
    // Every key still pending was not consumed by any visited field.
    std::ostringstream msg;
    msg << setter.type_key << " does not contain field ";
    for (const auto& leftover : setter.attrs) {
      msg << " " << leftover.first;
    }
    LOG(FATAL) << msg.str();
  }
}

/*!
 * \brief Create an object of `type_key` and initialize it from packed key/value args.
 *
 * BaseAttrsNode subclasses initialize themselves via InitByPackedArgs. Every other
 * object is filled field by field through InitNodeByPackedArgs.
 *
 * \param type_key Registered type key of the object to create.
 * \param kwargs Flat key/value argument list.
 * \return Reference to the new object.
 */
ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) {
  ObjectPtr<Object> obj = this->CreateInitObject(type_key);
  Object* raw = obj.get();
  if (!raw->IsInstance<BaseAttrsNode>()) {
    InitNodeByPackedArgs(this, raw, kwargs);
  } else {
    static_cast<BaseAttrsNode*>(raw)->InitByPackedArgs(kwargs);
  }
  return ObjectRef(obj);
}

/*!
 * \brief Create an object of `type_key` and initialize it from a String -> ObjectRef map.
 *
 * The map is flattened into alternating key/value packed arguments and forwarded to
 * the TVMArgs overload. This is not the most efficient path, but CreateObject is not
 * meant for hot code; it is a flexible entry point for frontends.
 *
 * \param type_key Registered type key of the object to create.
 * \param kwargs Field values keyed by field name.
 * \return Reference to the new object.
 */
ObjectRef ReflectionVTable::CreateObject(const std::string& type_key,
                                         const Map<String, ObjectRef>& kwargs) {
  const size_t num_args = kwargs.size() * 2;
  std::vector<TVMValue> values(num_args);
  std::vector<int32_t> tcodes(num_args);
  runtime::TVMArgsSetter setter(values.data(), tcodes.data());

  // The packed values point into objects owned by `kwargs`, which outlives the call.
  const MapNode* node = static_cast<const MapNode*>(kwargs.get());
  int slot = 0;
  for (const auto& entry : *node) {
    const String key = Downcast<String>(entry.first);
    setter(slot, key.c_str());
    setter(slot + 1, entry.second);
    slot += 2;
  }

  return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), num_args));
}

// Expose to FFI APIs.
/*!
 * \brief Packed-function body: args = (object handle, field name); returns the field value.
 */
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
  const TVMArgValue& target = args[0];
  CHECK_EQ(target.type_code(), kTVMObjectHandle);
  Object* obj = static_cast<Object*>(target.value().v_handle);
  ReflectionVTable* vtable = ReflectionVTable::Global();
  *ret = vtable->GetAttr(obj, args[1]);
}

/*!
 * \brief Packed-function body: args = (object handle). Returns a PackedFunc that
 *        enumerates the object's field names.
 *
 * The returned function, called with -1, yields the number of names. Called with an
 * index i in [0, count), it yields the i-th name. The name list is captured by
 * shared_ptr so it stays alive as long as the returned function does.
 */
void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  Object* self = static_cast<Object*>(args[0].value().v_handle);

  auto names =
      std::make_shared<std::vector<std::string> >(ReflectionVTable::Global()->ListAttrNames(self));

  *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) {
    int64_t i = args[0];
    if (i == -1) {
      *rv = static_cast<int64_t>(names->size());
    } else {
      // The index comes from the caller; guard against out-of-range access (UB).
      CHECK(i >= 0 && static_cast<uint64_t>(i) < names->size())
          << "attribute name index " << i << " out of range [0, " << names->size() << ")";
      *rv = (*names)[static_cast<size_t>(i)];
    }
  });
}

// API function to make node.
// args format:
//   type_key, key1, value1, ..., key_n, value_n
// args[0] is the type key of the object to create; the remaining arguments are the
// key/value pairs forwarded to ReflectionVTable::CreateObject.
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
  std::string type_key = args[0];
  TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
  *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs);
}

// Register the reflection helpers above as global packed functions under "node.".
// Reads a named field: (object, name) -> value.
TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);

// Lists field names: (object) -> enumerator function (-1 => count, i => i-th name).
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);

// Constructs an object: (type_key, key1, value1, ...) -> object.
TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
}  // namespace tvm
