#!/usr/bin/env python3

import argparse as ap
import configparser as cp
import json
import os
import re
import textwrap

import requests
import yaml


# Environments read from the config file, keyed by environment name:
# {"cloudant": {"api_key": "abc123", ...},
#  "power": {"api_key": "cde456", ...}}
# Filled in by load_environment() from the [ibmcloud] / [ibmcloud.<env>]
# sections.
#
ENV = {}

# Hard-coded VMs managed by someone else without access to cloud.ibm.com.
# Filled in by load_environment() from the [extra.<name>] sections.
EXTRA = {}

# Fallbacks used when a section does not set "api_url" / "iam_url"
IBM_CLOUD_URL = "https://us-south.iaas.cloud.ibm.com/v1"
IAM_URL = "https://iam.cloud.ibm.com/identity/token"

# Fallbacks for "api_generation" / "api_version"; params() sends them as the
# "generation" and "version" query parameters
IBM_CLOUD_GENERATION = "2"
IBM_CLOUD_VERSION = "2019-08-09"

# Authed API request sessions, one per environment
# {"cloudant": <Session>, ...}
# Filled in by load_iam_tokens().
#
SESS = {}


def load_environment():
    """Fill the module-level ENV and EXTRA dicts from the user's config file.

    The file is ``~/.couchdb-infra-cm.cfg``.  Two kinds of sections are used:

    * ``[extra.<name>]`` describes one externally managed VM and is stored as
      ``EXTRA[<name>]``.
    * ``[ibmcloud]`` / ``[ibmcloud.<env>]`` describe one IBM Cloud environment
      and are stored as ``ENV["<default>"]`` / ``ENV[<env>]``.

    All other sections are ignored.  A missing file or a malformed section
    name prints a message and exits with status 1.
    """
    cfg_path = os.path.expanduser("~/.couchdb-infra-cm.cfg")
    if not os.path.exists(cfg_path):
        print(f"Missing config file: {cfg_path}")
        exit(1)

    config = cp.ConfigParser()
    config.read([cfg_path])

    # Optional ibmcloud settings and the value used when they are absent
    optional_settings = (
        ("iam_url", IAM_URL),
        ("api_url", IBM_CLOUD_URL),
        ("api_generation", IBM_CLOUD_GENERATION),
        ("api_version", IBM_CLOUD_VERSION),
        ("crn", None),
        ("instance_id", None),
    )

    for section in config.sections():
        parts = section.split(".")

        if section.startswith("extra"):
            if len(parts) != 2:
                print(f"Invalid 'extra' section {section}")
                exit(1)

            name = parts[1]
            EXTRA[name] = {
                "name": name,
                # The instance id defaults to the section's name
                "instance_id": config.get(section, "id", fallback=name),
                "ip_addr": config.get(section, "ip_addr"),
                "system": {
                    "arch": config.get(section, "arch"),
                    "num_cpus": config.getint(section, "num_cpus"),
                    "ram": config.getint(section, "ram"),
                },
            }
            continue

        if not section.startswith("ibmcloud"):
            continue

        if len(parts) > 2:
            print(f"Invalid 'ibmcloud' section {section}")
            exit(1)
        # A bare [ibmcloud] section has no environment name of its own
        env = parts[1] if len(parts) == 2 else "<default>"

        settings = {"api_key": config.get(section, "api_key")}
        for key, default in optional_settings:
            settings[key] = config.get(section, key, fallback=default)
        ENV[env] = settings


def load_iam_tokens():
    """Create one authenticated requests session per entry in ENV.

    For every environment the API key is exchanged for a token at the
    environment's ``iam_url``.  The resulting session carries the
    ``Authorization`` header plus whatever env_headers() returns, and is
    stored in SESS under the environment name.

    A non-success response from the token endpoint raises via
    ``raise_for_status()``.
    """
    for env, settings in ENV.items():
        session = requests.session()
        resp = session.post(
            settings["iam_url"],
            headers={"Accept": "application/json"},
            data={
                "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                "apikey": settings["api_key"],
            },
        )
        resp.raise_for_status()

        payload = resp.json()
        session.headers["Authorization"] = " ".join(
            (payload["token_type"], payload["access_token"])
        )
        session.headers.update(env_headers(env))
        SESS[env] = session


def init():
    """Load the config file, then authenticate every configured environment."""
    # Order matters: load_iam_tokens() iterates the ENV that
    # load_environment() fills in.
    load_environment()
    load_iam_tokens()


def list_instances():
    """Yield every known instance record.

    Instances of each environment in ENV come first (the environment named
    "power" is listed through list_power_instances(), every other one through
    list_x86_instances()), followed by the entries of EXTRA.
    """
    for env in ENV:
        lister = list_power_instances if env == "power" else list_x86_instances
        yield from lister(env)
    yield from EXTRA.values()


def list_x86_instances(env):
    """Yield the instances of environment *env*, following pagination.

    Each yielded instance has its ``primary_network_interface`` reference
    replaced by the full interface document fetched from the reference's
    ``href``.

    Raises:
        requests.HTTPError: if any API call returns an error status.
    """
    url = ENV[env]["api_url"] + "/instances"
    sess = SESS[env]
    while url:
        resp = sess.get(url, params=params(env))
        # Fail on the HTTP error itself rather than on a confusing KeyError
        # while picking apart an error body.
        resp.raise_for_status()
        body = resp.json()
        for instance in body["instances"]:
            interface_url = instance["primary_network_interface"]["href"]
            resp = sess.get(interface_url, params=params(env))
            resp.raise_for_status()
            instance["primary_network_interface"] = resp.json()
            yield instance
        # The next page is accepted both as a plain URL string and as an
        # object carrying the URL under "href"; absent means last page.
        next_page = body.get("next")
        if isinstance(next_page, dict):
            url = next_page.get("href")
        else:
            url = next_page


def list_power_instances(env):
    """Yield the single Power instance configured for environment *env*.

    It's really just one instance, but to keep up with the shape of the rest
    of the code it is emitted as a generated item.  The record is augmented
    with ``name`` (copied from ``serverName``) and
    ``system = {"arch": "power"}``.

    Raises:
        Exception: if the environment has no ``instance_id`` setting.
        requests.HTTPError: if the API call returns an error status.
    """
    instance_id = ENV[env]["instance_id"]
    if not instance_id:
        # Actually raise; without this the URL concatenation below fails
        # with an unrelated TypeError.
        raise Exception("Power instances require an 'instance_id' setting")

    url = ENV[env]["api_url"] + "/pvm-instances/" + instance_id
    sess = SESS[env]
    resp = sess.get(url, params=params(env))
    resp.raise_for_status()
    instance = resp.json()
    instance["name"] = instance["serverName"]
    instance["system"] = {"arch": "power"}
    yield instance



def params(env):
    """Return the query parameters to send with API calls for *env*.

    Always contains ``limit=100``; ``version`` and ``generation`` are added
    when the environment has a non-empty ``api_version`` / ``api_generation``.
    """
    settings = ENV[env]
    query = {"limit": 100}
    for param_name, setting in (
        ("version", "api_version"),
        ("generation", "api_generation"),
    ):
        value = settings.get(setting)
        if value:
            query[param_name] = value
    return query


def env_headers(env):
    """Return extra request headers for *env*.

    The result holds a ``crn`` header when the environment has a non-empty
    ``crn`` setting and is empty otherwise.
    """
    crn = ENV[env].get("crn")
    return {"crn": crn} if crn else {}


def load_bastion(bastions, instance):
    """Record a running bastion *instance* in the *bastions* dict.

    Instances whose status is not "running" are ignored.  A name that is
    already present in *bastions*, or an instance without a floating IP on
    its primary network interface, prints a message and exits with status 2.
    """
    if instance["status"] != "running":
        return

    host = instance["name"]
    if host in bastions:
        print(f"Duplicate bastion found {host}")
        exit(2)

    iface = instance["primary_network_interface"]
    public_ips = iface.get("floating_ips", [])
    if not public_ips:
        print(f"Bastion is missing a public IP: {host}")
        exit(2)
    # Only the first floating IP is used
    public_ip = public_ips[0]["address"]

    details = {
        "id": instance["id"],
        "name": instance["name"],
        "created_at": instance["created_at"],
        "profile": instance["profile"]["name"],
        "vpc": instance["vpc"]["name"],
        "zone": instance["zone"]["name"],
        "subnet": iface["subnet"]["name"],
    }
    addresses = {
        "public": public_ip,
        "private": get_private_ip(instance),
    }
    hardware = {
        "arch": instance["vcpu"]["architecture"],
        "num_cpus": instance["vcpu"]["count"],
        "ram": instance["memory"],
    }
    bastions[host] = {
        "instance": details,
        "ip_addrs": addresses,
        "system": hardware,
    }


def load_ci_agent(ci_agents, instance):
    """Record a running CI worker *instance* in the *ci_agents* dict.

    Instances whose status is not "running" are ignored.  A name that is
    already present in *ci_agents* prints a message and exits with status 2.
    The bastion fields are stored as ``None`` here; assign_bastions() sets
    them later.
    """
    if instance["status"] != "running":
        return

    host = instance["name"]
    if host in ci_agents:
        print(f"Duplicate ci_agent found {host}")
        exit(2)

    iface = instance["primary_network_interface"]

    details = {
        "id": instance["id"],
        "name": instance["name"],
        "created_at": instance["created_at"],
        "profile": instance["profile"]["name"],
        "vpc": instance["vpc"]["name"],
        "zone": instance["zone"]["name"],
        "subnet": iface["subnet"]["name"],
    }
    addresses = {
        "bastion_ip": None,
        "bastion_host": None,
        "public": None,
        "private": get_private_ip(instance),
    }
    hardware = {
        "arch": instance["vcpu"]["architecture"],
        "num_cpus": instance["vcpu"]["count"],
        "ram": instance["memory"],
    }
    ci_agents[host] = {
        "instance": details,
        "ip_addrs": addresses,
        "system": hardware,
    }


def load_power_ci_agent(ci_agents, instance):
    """Record an active Power *instance* in the *ci_agents* dict.

    Instances whose status is not "ACTIVE" are ignored.  A name that is
    already present in *ci_agents* prints a message and exits with status 2.
    Both IP addresses are taken from the first entry of
    ``instance["addresses"]``.
    """
    if instance["status"] != "ACTIVE":
        return

    host = instance["name"]
    if host in ci_agents:
        print(f"Duplicate ci_agent found {host}")
        exit(2)

    # The Power record uses its own field names; map them onto the same
    # keys the other loaders produce.
    details = {
        "id": instance["pvmInstanceID"],
        "name": instance["serverName"],
        "created_at": instance["creationDate"],
        "profile": instance["procType"],
        "vpc": instance["sysType"],
        "zone": instance["placementGroup"],
        "subnet": None,
    }
    first_address = instance["addresses"][0]
    addresses = {
        "bastion_ip": None,
        "bastion_host": None,
        "public": first_address["externalIP"],
        "private": first_address["ip"],
    }
    hardware = {
        "arch": "power",
        "num_cpus": instance["virtualCores"]["assigned"],
        "ram": instance["memory"],
    }
    ci_agents[host] = {
        "instance": details,
        "ip_addrs": addresses,
        "system": hardware,
    }


def load_s390x_ci_agent(ci_agents, instance):
    """Record an externally managed *instance* in the *ci_agents* dict.

    *instance* has the shape of an EXTRA entry (``name``, ``instance_id``,
    ``ip_addr``, ``system``).  A name that is already present in *ci_agents*
    prints a message and exits with status 2.
    """
    host = instance["name"]
    if host in ci_agents:
        print(f"Duplicate ci_agent found {host}")
        exit(2)

    details = {
        "id": instance["instance_id"],
        "name": host,
        "subnet": None,
    }
    addresses = {
        "bastion_ip": None,
        "bastion_host": None,
        "public": instance["ip_addr"],
    }
    ci_agents[host] = {
        "instance": details,
        "ip_addrs": addresses,
        # The system description is passed through as-is
        "system": instance["system"],
    }


def get_private_ip(instance):
    """Return the private IPv4 address of *instance*.

    The primary network interface is preferred; when it carries no
    ``primary_ipv4_address`` the first entry of ``network_interfaces`` that
    has one is used instead.

    Raises:
        Exception: if no interface provides an address.
    """
    # Use .get() throughout so that a missing key falls through to the
    # fallback (and finally to the descriptive error below) instead of
    # surfacing as a bare KeyError.
    primary = instance.get("primary_network_interface") or {}
    ip = primary.get("primary_ipv4_address")
    if ip:
        return ip

    for iface in instance.get("network_interfaces") or []:
        ip = iface.get("primary_ipv4_address")
        if ip:
            return ip

    raise Exception("Unable to locate a private IP address")


def assign_bastions(bastions, ci_agents):
    """Set each CI agent's bastion to the bastion that shares its subnet.

    Fills ``ip_addrs["bastion_ip"]`` and ``ip_addrs["bastion_host"]`` of
    every entry in *ci_agents* in place.  Agents whose architecture is
    "power" or "s390x" are left untouched.

    Raises:
        ValueError: if two bastions share a subnet, or an agent's subnet has
            no bastion.
    """
    # Explicit checks instead of assert: these validate external data and
    # must not disappear when Python runs with -O.
    subnets = {}
    for host, bastion in bastions.items():
        subnet = bastion["instance"]["subnet"]
        if subnet in subnets:
            other_host = subnets[subnet][1]
            raise ValueError(
                f"Multiple bastions in subnet {subnet}: {other_host}, {host}"
            )
        subnets[subnet] = (bastion["ip_addrs"]["public"], host)

    for host, ci_agent in ci_agents.items():
        if ci_agent["system"]["arch"] in ("power", "s390x"):
            # Power & s390x have an external IP and need no bastion
            continue
        subnet = ci_agent["instance"]["subnet"]
        if subnet not in subnets:
            raise ValueError(
                f"No bastion found in subnet {subnet} for ci_agent {host}"
            )
        bastion_ip, bastion_host = subnets[subnet]
        ci_agent["ip_addrs"]["bastion_ip"] = bastion_ip
        ci_agent["ip_addrs"]["bastion_host"] = bastion_host


def write_inventory(fname, bastions, ci_agents):
    """Write the inventory to *fname* as block-style YAML.

    The document has the groups ``bastions`` and ``ci_agents`` under
    ``all.children``; each group's ``hosts`` is the corresponding dict.
    """
    groups = {
        "bastions": {"hosts": bastions},
        "ci_agents": {"hosts": ci_agents},
    }
    document = {"all": {"children": groups}}

    with open(fname, "w") as out:
        yaml.dump(document, stream=out, default_flow_style=False)


def write_ssh_cfg(filename, bastions, ci_agents):
    """Write an SSH client config with one ``Host`` entry per machine.

    Bastions come first, then CI agents, each group sorted by host name.
    Bastions (user ``root``) and power/s390x agents (user ``ubuntu``) get an
    entry pointing at their public IP.  Every other agent gets an entry for
    its private IP with a ``ProxyCommand`` through its bastion host.
    """
    # Entry for machines reached directly on their public IP
    direct_tmpl = textwrap.dedent("""\
        Host {host}
          Hostname {ip_addr}
          User {user}
          ForwardAgent yes
          StrictHostKeyChecking no
          ControlMaster auto
          ControlPath /tmp/ansible-%r@%h:%p
          ControlPersist 30m

        """)
    # Entry for machines reached through a bastion
    proxied_tmpl = textwrap.dedent("""\
        Host {host}
          Hostname {ip_addr}
          User {user}
          StrictHostKeyChecking no
          ProxyCommand /usr/bin/ssh -W %h:%p -q root@{bastion_host}
        """)
    direct_archs = ["power", "s390x"]

    with open(filename, "w") as out:
        for host in sorted(bastions):
            out.write(direct_tmpl.format(
                user="root",
                host=host,
                ip_addr=bastions[host]["ip_addrs"]["public"],
            ))

        for host in sorted(ci_agents):
            info = ci_agents[host]
            if info["system"]["arch"] not in direct_archs:
                addrs = info["ip_addrs"]
                entry = proxied_tmpl.format(
                    bastion_ip=addrs["bastion_ip"],
                    bastion_host=addrs["bastion_host"],
                    user="root",
                    host=host,
                    ip_addr=addrs["private"],
                )
            else:
                # Power or s390x use an external IP directly
                entry = direct_tmpl.format(
                    user="ubuntu",
                    host=host,
                    ip_addr=info["ip_addrs"]["public"],
                )
            out.write(entry)


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two optional output-file arguments are accepted: ``--inventory``
    (default "production") and ``--ssh-cfg`` (default "ssh.cfg").
    """
    cli = ap.ArgumentParser(description="Inventory Generation")

    # (flag, default, help text); both options take a filename
    file_options = (
        ("--inventory", "production", "Inventory filename to write"),
        ("--ssh-cfg", "ssh.cfg", "SSH config filename to write"),
    )
    for flag, default, help_text in file_options:
        cli.add_argument(
            flag,
            default=default,
            metavar="FILE",
            type=str,
            help=help_text,
        )

    return cli.parse_args()


def main():
    """Generate the inventory and SSH config files.

    Loads the configuration, lists all instances, sorts them into bastions
    and CI agents, links each agent to its bastion, writes both output files
    and prints the ``Include`` line for the user's SSH config.
    """
    args = parse_args()

    init()

    bastions = {}
    ci_agents = {}

    for instance in list_instances():
        name = instance["name"]
        # Only Power and EXTRA records carry a "system" key.  Look it up
        # defensively so an instance that matches neither name prefix is
        # skipped instead of raising KeyError.
        arch = (instance.get("system") or {}).get("arch")

        if name.startswith("couchdb-bastion"):
            load_bastion(bastions, instance)
        elif name.startswith("couchdb-worker"):
            load_ci_agent(ci_agents, instance)
        elif arch == "power":
            load_power_ci_agent(ci_agents, instance)
        elif arch == "s390x":
            load_s390x_ci_agent(ci_agents, instance)

    assign_bastions(bastions, ci_agents)

    write_inventory(args.inventory, bastions, ci_agents)
    sshf = args.ssh_cfg
    write_ssh_cfg(sshf, bastions, ci_agents)
    sshf_full = os.path.abspath(os.path.expanduser(os.path.expandvars(sshf)))
    print(f"Add 'Include {sshf_full}' to your ~/.ssh/config file")

# Run only when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()
