from collections import OrderedDict, defaultdict
from copy import copy
from datetime import datetime
import json
import uuid

from flask import flash, request
from markdown import markdown
from pandas.io.json import dumps
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.urls import Href
import numpy as np
import pandas as pd

from panoramix import app, utils
from panoramix.forms import FormFactory

# Shortcut to the Flask application's configuration; read e.g. for the
# ROW_LIMIT default in BaseViz.query_obj.
config = app.config


class BaseViz(object):
    """Base class for every visualization type.

    Subclasses customize the class attributes below and override
    ``query_obj``, ``get_df``, ``get_json_data`` and/or ``reassignments``
    as needed.
    """

    viz_type = None
    verbose_name = "Base Viz"
    template = None
    is_timeseries = False
    # Form fields for this viz; a tuple/list groups several fields together
    # (``flat_form_fields`` flattens the groups).
    form_fields = [
        'viz_type',
        'granularity',
        ('since', 'until'),
        'metrics', 'groupby',
    ]
    js_files = []
    css_files = []

    def __init__(self, datasource, form_data):
        """Build the viz from a datasource and the submitted form data.

        :param datasource: object exposing ``type``, ``id`` and ``query``.
        :param form_data: an ``ImmutableMultiDict`` (e.g. request args) or a
            plain mapping of field name to value.
        """
        self.orig_form_data = form_data
        self.datasource = datasource
        self.request = request
        self.viz_type = form_data.get("viz_type")

        form_class = FormFactory(self).get_form()
        defaults = form_class().data.copy()
        if isinstance(form_data, ImmutableMultiDict):
            form = form_class(form_data)
        else:
            form = form_class(**form_data)
        data = form.data.copy()

        if not form.validate():
            # Validation errors are flashed only when this is not a
            # json/async data request.
            if not data.get('json') and not data.get('async'):
                for k, v in form.errors.items():
                    flash("{}: {}".format(k, " ".join(v)), 'danger')

        # Only the fields actually present in ``form_data`` override the
        # form's defaults.
        form_values = form.data
        submitted = {
            k: form_values[k]
            for k in form_data.keys()
            if k in form_values}
        defaults.update(submitted)
        self.form_data = defaults

        self.form_data['previous_viz_type'] = self.viz_type
        self.token = self.form_data.get(
            'token', 'token_' + uuid.uuid4().hex[:8])

        self.metrics = self.form_data.get('metrics') or []
        self.groupby = self.form_data.get('groupby') or []
        self.reassignments()

    @classmethod
    def flat_form_fields(cls):
        """Return ``form_fields`` with its tuple/list groups flattened.

        ``None`` placeholders inside groups are kept as-is.
        """
        fields = []
        for obj in cls.form_fields:
            if isinstance(obj, (tuple, list)):
                fields.extend(obj)
            else:
                fields.append(obj)
        return fields

    def reassignments(self):
        """Hook run at the end of ``__init__`` to adjust ``form_data``."""
        pass

    def get_url(self, **kwargs):
        """Return this datasource's URL carrying the original form data.

        ``kwargs`` are applied with ``update`` on a copy of the original
        form data. The ``action`` key is dropped, as are values equal to
        ``False``.
        """
        d = self.orig_form_data.copy()
        d.pop('action', None)
        d.update(kwargs)
        # Remove unchecked checkboxes because HTML is weird like that.
        # Iterate over a snapshot of the keys: deleting while iterating a
        # live keys view raises RuntimeError on Python 3.
        for key in list(d.keys()):
            if d[key] == False:  # noqa: E712 -- equality test is intended
                del d[key]
        href = Href(
            '/panoramix/datasource/{self.datasource.type}/'
            '{self.datasource.id}/'.format(self=self))
        return href(d)

    def get_df(self, query_obj=None):
        """Run the query and return its results as a DataFrame.

        :param query_obj: kwargs for ``datasource.query``; defaults to
            ``self.query_obj()``.
        :raises Exception: if the query returned no rows.
        """
        if not query_obj:
            query_obj = self.query_obj()

        self.error_msg = ""
        self.results = None

        self.results = self.datasource.query(**query_obj)
        df = self.results.df
        if df is None or df.empty:
            raise Exception("No data, review your incantations!")
        if 'timestamp' in df.columns:
            df.timestamp = pd.to_datetime(df.timestamp)
        return df

    @property
    def form(self):
        """A form instance populated from ``self.form_data``."""
        return self.form_class(**self.form_data)

    @property
    def form_class(self):
        """The form class built by ``FormFactory`` for this viz."""
        return FormFactory(self).get_form()

    def query_filters(self):
        """Return the ``(column, operator, value)`` filter triples.

        Reads ``flt_col_N`` / ``flt_op_N`` / ``flt_eq_N`` for N in 1..9 and
        keeps only the triples whose three values are all truthy.
        """
        form_data = self.form_data
        filters = []
        for i in range(1, 10):
            col = form_data.get("flt_col_" + str(i))
            op = form_data.get("flt_op_" + str(i))
            eq = form_data.get("flt_eq_" + str(i))
            if col and op and eq:
                filters.append((col, op, eq))
        return filters

    def query_obj(self):
        """Build the keyword arguments passed to ``datasource.query``.

        ``metrics`` defaults to ``['count']``. When ``since`` parses to a
        future date, it is mirrored into the past around "now". When the
        range is inverted, a warning is flashed and both ends are set equal.
        """
        form_data = self.form_data
        groupby = form_data.get("groupby") or []
        metrics = form_data.get("metrics") or ['count']
        granularity = form_data.get("granularity")
        limit = int(form_data.get("limit", 0))
        row_limit = int(
            form_data.get("row_limit", config.get("ROW_LIMIT")))
        since = form_data.get("since", "1 year ago")
        from_dttm = utils.parse_human_datetime(since)
        now = datetime.now()
        if from_dttm > now:
            from_dttm = now - (from_dttm - now)
        until = form_data.get("until", "now")
        to_dttm = utils.parse_human_datetime(until)
        if from_dttm >= to_dttm:
            flash("The date range doesn't seem right.", "danger")
            from_dttm = to_dttm  # Making them identical to not raise

        # extras are used to query elements specific to a datasource type
        # for instance the extra where clause that applies only to Tables
        extras = {
            'where': form_data.get("where", ''),
        }
        return {
            'granularity': granularity,
            'from_dttm': from_dttm,
            'to_dttm': to_dttm,
            'is_timeseries': self.is_timeseries,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'filter': self.query_filters(),
            'timeseries_limit': limit,
            'extras': extras,
        }

    def get_json(self):
        """Return JSON holding the viz data and the effective form data."""
        payload = {
            'data': json.loads(self.get_json_data()),
            'form_data': self.form_data,
        }
        return json.dumps(payload)

    def get_json_data(self):
        """Return the viz data as a JSON string (empty list by default)."""
        return json.dumps([])

    def get_data_attribute(self):
        """Return JSON with the viz name, its json endpoint URL and token."""
        content = {
            'viz_name': self.viz_type,
            'json_endpoint': self.get_url(json="true"),
            'token': self.token,
        }
        return json.dumps(content)

class TableViz(BaseViz):
    """Tabular view of the raw query results."""

    viz_type = "table"
    verbose_name = "Table View"
    template = 'panoramix/viz_table.html'
    form_fields = BaseViz.form_fields + ['row_limit']
    css_files = ['lib/dataTables/dataTables.bootstrap.css']
    is_timeseries = False
    js_files = [
        'lib/dataTables/jquery.dataTables.min.js',
        'lib/dataTables/dataTables.bootstrap.js']

    def query_obj(self):
        """Base query, forced to non-timeseries with no series limit."""
        d = super(TableViz, self).query_obj()
        d['is_timeseries'] = False
        d['timeseries_limit'] = None
        return d

    def get_df(self, query_obj=None):
        """Return the results plus a ``<metric>__perc`` column per metric.

        :param query_obj: forwarded to ``BaseViz.get_df`` (same signature as
            the base class so callers can pass an explicit query).
        """
        df = super(TableViz, self).get_df(query_obj)
        if (
                self.form_data.get("granularity") == "all" and
                'timestamp' in df):
            del df['timestamp']
        for m in self.metrics:
            # Each value as a rounded percentage of the column's maximum.
            df[m + '__perc'] = np.rint((df[m] / np.max(df[m])) * 100)
        return df


class MarkupViz(BaseViz):
    """Widget displaying user-supplied markdown or HTML."""

    viz_type = "markup"
    verbose_name = "Markup Widget"
    template = 'panoramix/viz_markup.html'
    form_fields = ['viz_type', 'markup_type', 'code']
    is_timeseries = False

    def rendered(self):
        """Return the ``code`` field rendered per ``markup_type``.

        ``"markdown"`` is converted with ``markdown()``. ``"html"`` is
        returned unchanged. Any other type returns None.
        """
        markup_type = self.form_data.get("markup_type")
        # ``or ''`` also covers a key present with a None value, which the
        # ``get`` default alone does not; markdown() expects a string.
        code = self.form_data.get("code") or ''
        if markup_type == "markdown":
            return markdown(code)
        elif markup_type == "html":
            return code


class WordCloudViz(BaseViz):
    """
    Integration with the nice library at:
    https://github.com/jasondavies/d3-cloud
    """
    viz_type = "word_cloud"
    verbose_name = "Word Cloud"
    template = 'panoramix/viz_word_cloud.html'
    is_timeseries = False
    form_fields = [
        'viz_type',
        ('since', 'until'),
        'groupby', 'metric', 'limit',
        ('size_from', 'size_to'),
        'rotation',
    ]
    js_files = [
        'lib/d3.min.js',
        'lib/d3.layout.cloud.js',
        'widgets/viz_wordcloud.js',
    ]

    def query_obj(self):
        """Query a single metric grouped by the first group-by column.

        :raises Exception: if no metric or no group-by column was chosen.
        """
        d = super(WordCloudViz, self).query_obj()
        metric = self.form_data.get('metric')
        if not metric:
            raise Exception("Pick a metric!")
        if not d['groupby']:
            raise Exception("Pick a group by!")
        d['metrics'] = [metric]
        # Only the first group-by column provides the words.
        d['groupby'] = d['groupby'][:1]
        return d

    def get_json_data(self):
        """Return the rows as JSON records of ``{text, size}``."""
        df = self.get_df()
        df.columns = ['text', 'size']
        return df.to_json(orient="records")


class NVD3Viz(BaseViz):
    """Shared base for the NVD3-rendered charts.

    Not a concrete viz itself (``viz_type`` is None); it supplies the
    common template plus the d3/nvd3 static assets.
    """

    verbose_name = "Base NVD3 Viz"
    viz_type = None
    is_timeseries = False
    template = 'panoramix/viz_nvd3.html'
    css_files = ['lib/nvd3/nv.d3.css', 'widgets/viz_nvd3.css']
    js_files = [
        'lib/d3.min.js', 'lib/nvd3/nv.d3.min.js', 'widgets/viz_nvd3.js',
    ]


class BubbleViz(NVD3Viz):
    """NVD3 bubble chart: x, y and size metrics per entity, grouped by series."""

    viz_type = "bubble"
    verbose_name = "Bubble Chart"
    is_timeseries = False
    form_fields = [
        'viz_type',
        ('since', 'until'),
        ('series', 'entity'),
        ('x', 'y'),
        ('size', 'limit'),
        ('x_log_scale', 'y_log_scale'),
        ('show_legend', None),
    ]

    def query_obj(self):
        """Group by series/entity and query the size, x and y metrics.

        Also stores the chosen fields on ``self`` for ``get_df``.

        :raises Exception: if a metric, the entity or the series is missing.
        """
        form_data = self.form_data
        d = super(BubbleViz, self).query_obj()
        self.x_metric = form_data.get('x')
        self.y_metric = form_data.get('y')
        self.z_metric = form_data.get('size')
        self.entity = form_data.get('entity')
        self.series = form_data.get('series')

        # De-duplicated in a stable order (series first). A set gives an
        # arbitrary order, so the generated query could vary between runs.
        groupby = []
        for col in (self.series, self.entity):
            if col not in groupby:
                groupby.append(col)
        d['groupby'] = groupby

        d['metrics'] = [
            self.z_metric,
            self.x_metric,
            self.y_metric,
        ]
        if not all(d['metrics'] + [self.entity, self.series]):
            raise Exception("Pick a metric for x, y and size")
        return d

    def get_df(self):
        """Return the results with NVD3 bubble columns added.

        Adds ``x``, ``y``, ``size``, ``shape`` and ``group`` columns;
        missing values are filled with 0.
        """
        df = super(BubbleViz, self).get_df()
        df = df.fillna(0)
        # Select single columns as Series rather than assigning one-column
        # DataFrames (``df[[col]]``) to a column.
        df['x'] = df[self.x_metric]
        df['y'] = df[self.y_metric]
        df['size'] = df[self.z_metric]
        df['shape'] = 'circle'
        df['group'] = df[self.series]
        return df

    def get_json_data(self):
        """Return JSON with one NVD3 series per group plus query metadata."""
        df = self.get_df()
        series = defaultdict(list)
        for row in df.to_dict(orient='records'):
            series[row['group']].append(row)
        chart_data = [
            {'key': k, "color": utils.color(k), 'values': v}
            for k, v in series.items()]
        return dumps({
            'chart_data': chart_data,
            'query': self.results.query,
            'duration': self.results.duration,
        })

class BigNumberViz(BaseViz):
    """A single metric's time series, shown as a big number widget."""

    viz_type = "big_number"
    verbose_name = "Big Number"
    template = 'panoramix/viz_bignumber.html'
    is_timeseries = True
    js_files = [
        'lib/d3.min.js',
        'widgets/viz_bignumber.js',
    ]
    css_files = [
        'widgets/viz_bignumber.css',
    ]
    # NOTE: ('rolling_type', 'rolling_periods') is currently not offered.
    form_fields = [
        'viz_type',
        'granularity',
        ('since', 'until'),
        'metric',
        'compare_lag',
        'compare_suffix',
    ]

    def reassignments(self):
        """Fall back to the original ``metrics`` value when no ``metric``."""
        metric = self.form_data.get('metric')
        if not metric:
            self.form_data['metric'] = self.orig_form_data.get('metrics')

    def query_obj(self):
        """Base query restricted to the single chosen metric.

        :raises Exception: if no metric is set.
        """
        d = super(BigNumberViz, self).query_obj()
        metric = self.form_data.get('metric')
        if not metric:
            raise Exception("Pick a metric!")
        d['metrics'] = [metric]
        return d

    def get_json_data(self):
        """Return JSON with the rows plus compare_lag / compare_suffix.

        Rows are sorted by the first column. That column is presumably the
        timestamp (TODO confirm); it is written to ``timestamp`` as epoch
        seconds.
        """
        form_data = self.form_data
        df = self.get_df()
        first_col = df.columns[0]
        df = df.sort(columns=first_col)
        # datetime64[ns] -> int64 nanoseconds -> seconds.
        df['timestamp'] = df[first_col].astype(np.int64) // 10**9
        # compare_lag may arrive as None or a number rather than a string;
        # only non-negative integer strings are accepted, anything else is 0.
        raw_lag = form_data.get("compare_lag") or ""
        if not isinstance(raw_lag, basestring):
            raw_lag = "{}".format(raw_lag)
        compare_lag = int(raw_lag) if raw_lag.isdigit() else 0
        d = {
            'data': df.values.tolist(),
            'compare_lag': compare_lag,
            'compare_suffix': form_data.get('compare_suffix', ''),
        }
        return json.dumps(d)


class NVD3TimeSeriesViz(NVD3Viz):
    """NVD3 line chart of one or more metrics over time."""

    viz_type = "line"
    verbose_name = "Time Series - Line Chart"
    # When True, series columns are ordered by descending total.
    sort_series = False
    is_timeseries = True
    form_fields = [
        'viz_type',
        'granularity', ('since', 'until'),
        'metrics',
        'groupby', 'limit',
        ('rolling_type', 'rolling_periods'),
        ('time_compare', 'num_period_compare'),
        ('line_interpolation', None),
        ('show_brush', 'show_legend'),
        ('rich_tooltip', 'y_axis_zero'),
        ('y_log_scale', 'contribution'),
    ]

    def get_df(self, query_obj=None):
        """Return the results pivoted to one column per series.

        The frame is indexed by ``timestamp``, with columns per metric
        (and group-by value). Depending on the form, it is then sorted,
        normalized to contributions, divided by a shifted copy of itself
        and smoothed with a rolling window.

        :raises Exception: if granularity is ``"all"``.
        """
        form_data = self.form_data
        df = super(NVD3TimeSeriesViz, self).get_df(query_obj)

        df = df.fillna(0)
        if form_data.get("granularity") == "all":
            raise Exception("Pick a time granularity for your time series")

        df = df.pivot_table(
            index="timestamp",
            columns=form_data.get('groupby'),
            values=form_data.get('metrics'))

        if self.sort_series:
            dfs = df.sum()
            dfs.sort(ascending=False)
            df = df[dfs.index]

        if form_data.get("contribution"):
            # Rescale each row so that its values sum to 1.
            dft = df.T
            df = (dft / dft.sum()).T

        num_period_compare = form_data.get("num_period_compare")
        if num_period_compare:
            num_period_compare = int(num_period_compare)
            # Ratio to the value num_period_compare rows earlier; the first
            # rows have no reference value and are dropped.
            df = df / df.shift(num_period_compare)
            df = df[num_period_compare:]

        rolling_periods = form_data.get("rolling_periods")
        rolling_type = form_data.get("rolling_type")
        if rolling_periods and rolling_type:
            if rolling_type == 'mean':
                df = pd.rolling_mean(df, int(rolling_periods))
            elif rolling_type == 'std':
                df = pd.rolling_std(df, int(rolling_periods))
            elif rolling_type == 'sum':
                df = pd.rolling_sum(df, int(rolling_periods))
        return df

    def to_series(self, df, classed='', title_suffix=''):
        """Convert a pivoted frame into a list of NVD3 series dicts.

        Each dict has ``key``, ``color``, ``classed`` and ``values`` (a list
        of ``{'x': timestamp, 'y': value}``). Non-numeric columns are
        skipped. ``df`` is not modified.
        """
        series = df.to_dict('series')
        # The x axis is shared by every series: convert it once, without
        # writing a helper column into the caller's frame.
        timestamps = pd.to_datetime(df.index, utc=False)
        multi_metric = len(self.form_data.get('metrics') or []) > 1
        chart_data = []
        for name in df.T.index.tolist():
            ys = series[name]
            if df[name].dtype.kind not in "biufc":
                continue
            if isinstance(name, basestring):
                series_title = name
            else:
                parts = ["{}".format(s) for s in name]
                if multi_metric:
                    series_title = ", ".join(parts)
                else:
                    # Single metric: leave its name out of the title.
                    series_title = ", ".join(parts[1:])
            color = utils.color(series_title)
            if title_suffix:
                series_title += title_suffix

            chart_data.append({
                "key": series_title,
                "color": color,
                "classed": classed,
                # Pair values with timestamps positionally; ``ys[i]`` on a
                # DatetimeIndex relied on an ambiguous integer fallback.
                "values": [
                    {'x': ds, 'y': y}
                    for ds, y in zip(timestamps, ys)],
            })
        return chart_data

    def get_json_data(self):
        """Return JSON with the series, plus time-shifted ones if requested.

        With ``time_compare``, the query is re-run shifted back by that
        delta, re-aligned onto the current range, and added as dashed series
        whose keys end in ``---``.
        """
        df = self.get_df()
        chart_data = self.to_series(df)

        time_compare = self.form_data.get('time_compare')
        if time_compare:
            query_object = self.query_obj()
            delta = utils.parse_human_timedelta(time_compare)
            query_object['inner_from_dttm'] = query_object['from_dttm']
            query_object['inner_to_dttm'] = query_object['to_dttm']
            query_object['from_dttm'] -= delta
            query_object['to_dttm'] -= delta

            df2 = self.get_df(query_object)
            df2.index += delta
            chart_data += self.to_series(
                df2, classed='dashed', title_suffix="---")
            chart_data = sorted(chart_data, key=lambda x: x['key'])

        data = {
            'chart_data': chart_data,
            'query': self.results.query,
            'duration': self.results.duration,
        }
        return dumps(data)


class NVD3TimeSeriesBarViz(NVD3TimeSeriesViz):
    """Time series rendered as bars; reuses the line chart's data logic."""

    verbose_name = "Time Series - Bar Chart"
    viz_type = "bar"
    form_fields = [
        'viz_type', 'granularity', ('since', 'until'),
        'metrics', 'groupby', 'limit',
        ('rolling_type', 'rolling_periods'), 'show_legend',
    ]


class NVD3CompareTimeSeriesViz(NVD3TimeSeriesViz):
    """Percent-change time series; reuses the line chart's data logic."""

    verbose_name = "Time Series - Percent Change"
    viz_type = 'compare'
    form_fields = [
        'viz_type', 'granularity', ('since', 'until'),
        'metrics', 'groupby', 'limit',
        ('rolling_type', 'rolling_periods'), 'show_legend',
    ]


class NVD3TimeSeriesStackedViz(NVD3TimeSeriesViz):
    """Stacked area time series, with series sorted by descending total."""

    verbose_name = "Time Series - Stacked"
    viz_type = "area"
    sort_series = True
    form_fields = [
        'viz_type', 'granularity', ('since', 'until'),
        'metrics', 'groupby', 'limit',
        ('rolling_type', 'rolling_periods'), ('rich_tooltip', 'show_legend'),
    ]


class DistributionPieViz(NVD3Viz):
    """NVD3 pie chart of the first metric across the group-by values."""

    viz_type = "pie"
    verbose_name = "Distribution - NVD3 - Pie Chart"
    is_timeseries = False
    form_fields = [
        'viz_type',
        ('since', 'until'),
        'metrics', 'groupby',
        'limit',
        ('donut', 'show_legend'),
    ]

    def query_obj(self):
        """Base query, forced to non-timeseries."""
        d = super(DistributionPieViz, self).query_obj()
        d['is_timeseries'] = False
        return d

    def get_df(self, query_obj=None):
        """Return the first metric pivoted by group-by, largest first.

        :param query_obj: forwarded to ``BaseViz.get_df``.
        """
        df = super(DistributionPieViz, self).get_df(query_obj)
        df = df.pivot_table(
            index=self.groupby,
            values=[self.metrics[0]])
        df = df.sort(self.metrics[0], ascending=False)
        return df

    def get_json_data(self):
        """Return JSON with ``{x, y, color}`` records plus query metadata.

        Assumes a single group-by column, so the reset frame has exactly two
        columns (label, value).
        """
        df = self.get_df()
        df = df.reset_index()
        df.columns = ['x', 'y']
        # A list, not map(): map() is lazy on Python 3 and would not be
        # expanded into a column.
        df['color'] = [utils.color(x) for x in df.x]
        return dumps({
            'chart_data': df.to_dict(orient="records"),
            'query': self.results.query,
            'duration': self.results.duration,
        })


class DistributionBarViz(DistributionPieViz):
    """NVD3 bar chart of each metric across the group-by values."""

    viz_type = "dist_bar"
    verbose_name = "Distribution - Bar Chart"
    is_timeseries = False
    form_fields = [
        'viz_type', 'metrics', 'groupby',
        ('since', 'until'),
        'limit',
        ('show_legend', None),
    ]

    def get_df(self, query_obj=None):
        """Return all metrics pivoted by group-by, sorted by the first one.

        :param query_obj: forwarded to ``BaseViz.get_df``.
        """
        # Deliberately skips DistributionPieViz.get_df, which keeps only the
        # first metric, and calls the base implementation instead.
        df = super(DistributionPieViz, self).get_df(query_obj)
        df = df.pivot_table(
            index=self.groupby,
            values=self.metrics)
        df = df.sort(self.metrics[0], ascending=False)
        return df

    def get_json_data(self):
        """Return JSON with one NVD3 series per numeric metric column.

        Each series' ``values`` pair the group-by labels (x) with that
        column's values (y), in the frame's row order.
        """
        df = self.get_df()
        # The x values are the group-by labels themselves. They must not go
        # through pd.to_datetime: that turns integer or date-like labels
        # into timestamps (or raises on newer pandas).
        labels = list(df.index)
        chart_data = []
        # Iterate columns in frame order, for a deterministic series order.
        for name in df.columns.tolist():
            ys = df[name]
            if ys.dtype.kind not in "biufc":
                continue
            if isinstance(name, basestring):
                series_title = name
            else:
                parts = ["{}".format(s) for s in name]
                if len(self.metrics) > 1:
                    series_title = ", ".join(parts)
                else:
                    series_title = ", ".join(parts[1:])
            chart_data.append({
                "key": series_title,
                "color": utils.color(series_title),
                # Positional pairing: ``ys[i]`` is label-based on an integer
                # index and would return the wrong rows.
                "values": [
                    {'x': x, 'y': y}
                    for x, y in zip(labels, ys)],
            })
        return dumps({
            'chart_data': chart_data,
            'query': self.results.query,
            'duration': self.results.duration,
        })


# All visualization classes made available; viz_types below keeps this order.
viz_types_list = [
    TableViz,
    NVD3TimeSeriesViz, NVD3CompareTimeSeriesViz,
    NVD3TimeSeriesStackedViz, NVD3TimeSeriesBarViz,
    DistributionBarViz, DistributionPieViz,
    BubbleViz,
    MarkupViz,
    WordCloudViz,
    BigNumberViz,
]
# Lookup from each class's ``viz_type`` string to the class itself,
# ordered as in viz_types_list.
viz_types = OrderedDict(
    (viz_cls.viz_type, viz_cls) for viz_cls in viz_types_list)
