/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/script_printer.h>

#include <algorithm>

namespace tvm {

using AccessPath = ffi::reflection::AccessPath;

// Runs PrinterConfigNode::RegisterReflection() from a static-initialization block.
TVM_FFI_STATIC_INIT_BLOCK() {
  PrinterConfigNode::RegisterReflection();
}

/*!
 * \brief Access the dispatch table used by TVMScriptPrinter.
 * \return A reference to a single function-local static FType instance; every
 *         call returns the same object.
 */
TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
  static FType dispatch_table;
  return dispatch_table;
}

/*!
 * \brief Render `node` as text.
 *
 * When the dispatch table returned by vtable() can dispatch on `node`, the
 * registered printer is invoked with `cfg`, or with a default-constructed
 * PrinterConfig if `cfg` holds no value. Otherwise the node is printed through
 * ReprPrinter and `cfg` is not used.
 *
 * \param node The object to print.
 * \param cfg Optional printer configuration.
 * \return The printed text.
 */
std::string TVMScriptPrinter::Script(const ObjectRef& node,
                                     const ffi::Optional<PrinterConfig>& cfg) {
  FType& dispatch = TVMScriptPrinter::vtable();
  if (dispatch.can_dispatch(node)) {
    return dispatch(node, cfg.value_or(PrinterConfig()));
  }
  // Nothing to dispatch to for this node: fall back to the generic repr output.
  std::ostringstream fallback;
  ReprPrinter repr(fallback);
  repr.Print(node);
  return fallback.str();
}

/*!
 * \brief Check whether `name` is a plain ASCII Python-style identifier.
 *
 * Equivalent to matching the regex "^[a-zA-Z_][a-zA-Z0-9_]*$":
 *  1. The name is not empty.
 *  2. The first character is an ASCII letter or an underscore.
 *  3. Every remaining character is an ASCII letter, an ASCII digit or an underscore.
 *
 * \param name The candidate identifier.
 * \return true if `name` satisfies all three conditions, false otherwise.
 */
bool IsIdentifier(const std::string& name) {
  // `std::regex` would cause a symbol conflict with PyTorch, so we avoid using it in the
  // codebase and spell the pattern out by hand.
  //
  // The character classes are written as explicit ASCII ranges rather than
  // std::isalpha / std::isalnum: those functions depend on the current C locale and
  // have undefined behavior when handed a negative `char` (e.g. a non-ASCII byte on
  // platforms where `char` is signed).
  const auto is_ascii_letter = [](char c) {
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  };
  const auto is_ascii_digit = [](char c) { return c >= '0' && c <= '9'; };

  if (name.empty()) {
    return false;
  }
  if (!is_ascii_letter(name[0]) && name[0] != '_') {
    return false;
  }
  return std::all_of(name.begin() + 1, name.end(), [&](char c) {
    return is_ascii_letter(c) || is_ascii_digit(c) || c == '_';
  });
}

/*!
 * \brief Build a PrinterConfig from a string-keyed option map.
 *
 * A fresh PrinterConfigNode is created and, for every recognized key that is
 * present in `config_dict`, the matching field is overwritten with the supplied
 * value. Fields whose keys are absent are left as the new node has them, and
 * keys other than the ones read below are not looked at.
 *
 * After all options are applied, `ir_prefix`, `tir_prefix` and `relax_prefix`
 * must pass IsIdentifier; `module_alias` must either be empty or pass
 * IsIdentifier. A violation trips the corresponding CHECK.
 *
 * \param config_dict Map from option name to option value.
 */
PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any> config_dict) {
  runtime::ObjectPtr<PrinterConfigNode> node = ffi::make_object<PrinterConfigNode>();

  // Data types are supplied as strings and parsed into DataType values.
  auto parse_dtype = [](const ffi::String& text) {
    return DataType(ffi::StringToDLDataType(text));
  };

  // --- Naming: binding name, dialect prefixes and module alias.
  if (auto entry = config_dict.Get("name")) {
    node->binding_names.push_back(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("show_meta")) {
    node->show_meta = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("ir_prefix")) {
    node->ir_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("tir_prefix")) {
    node->tir_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("relax_prefix")) {
    node->relax_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("module_alias")) {
    node->module_alias = Downcast<ffi::String>(entry.value());
  }

  // --- Data types.
  if (auto entry = config_dict.Get("buffer_dtype")) {
    node->buffer_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("int_dtype")) {
    node->int_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("float_dtype")) {
    node->float_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }

  // --- Layout and verbosity.
  if (auto entry = config_dict.Get("verbose_expr")) {
    node->verbose_expr = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("indent_spaces")) {
    node->indent_spaces = entry.value().cast<int>();
  }
  if (auto entry = config_dict.Get("print_line_numbers")) {
    node->print_line_numbers = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("num_context_lines")) {
    node->num_context_lines = entry.value().cast<int>();
  }

  // --- Underline / annotation targets. These are downcast to an Optional, and an
  // Optional without a value is replaced by an empty container.
  if (auto entry = config_dict.Get("path_to_underline")) {
    node->path_to_underline =
        Downcast<ffi::Optional<ffi::Array<AccessPath>>>(entry).value_or(ffi::Array<AccessPath>());
  }
  if (auto entry = config_dict.Get("path_to_annotate")) {
    node->path_to_annotate =
        Downcast<ffi::Optional<ffi::Map<AccessPath, ffi::String>>>(entry).value_or(
            ffi::Map<AccessPath, ffi::String>());
  }
  if (auto entry = config_dict.Get("obj_to_underline")) {
    node->obj_to_underline =
        Downcast<ffi::Optional<ffi::Array<ObjectRef>>>(entry).value_or(ffi::Array<ObjectRef>());
  }
  if (auto entry = config_dict.Get("obj_to_annotate")) {
    node->obj_to_annotate =
        Downcast<ffi::Optional<ffi::Map<ObjectRef, ffi::String>>>(entry).value_or(
            ffi::Map<ObjectRef, ffi::String>());
  }

  // --- Remaining boolean switches.
  if (auto entry = config_dict.Get("syntax_sugar")) {
    node->syntax_sugar = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("show_object_address")) {
    node->show_object_address = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("show_all_struct_info")) {
    node->show_all_struct_info = entry.value().cast<bool>();
  }

  // The prefixes (and the module alias, when one is set) must be valid Python identifiers.
  CHECK(IsIdentifier(node->ir_prefix)) << "Invalid `ir_prefix`: " << node->ir_prefix;
  CHECK(IsIdentifier(node->tir_prefix)) << "Invalid `tir_prefix`: " << node->tir_prefix;
  CHECK(IsIdentifier(node->relax_prefix)) << "Invalid `relax_prefix`: " << node->relax_prefix;
  CHECK(node->module_alias.empty() || IsIdentifier(node->module_alias))
      << "Invalid `module_alias`: " << node->module_alias;

  data_ = std::move(node);
}

/*!
 * \brief Collect the names this config reserves for the printer.
 * \return An array holding `ir_prefix`, `tir_prefix` and `relax_prefix`, in that
 *         order, followed by `module_alias` when it is not empty.
 */
ffi::Array<ffi::String> PrinterConfigNode::GetBuiltinKeywords() {
  ffi::Array<ffi::String> keywords;
  keywords.push_back(ir_prefix);
  keywords.push_back(tir_prefix);
  keywords.push_back(relax_prefix);
  if (module_alias.empty()) {
    return keywords;
  }
  keywords.push_back(module_alias);
  return keywords;
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("node.PrinterConfig",
           [](ffi::Map<ffi::String, Any> config_dict) { return PrinterConfig(config_dict); })
      .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script);
}

}  // namespace tvm
