import argparse
import json
import sys
from typing import List, Optional, Tuple, TypedDict

import matplotlib.pyplot as plt
import numpy as np

# One sample of a graph, i.e. one row of a graph's `points` table (see `HELP`).
# `x` and `y` are expected on every point. `e` is optional and is the half-height
# of the shaded error band drawn from `y - e` to `y + e`. `plot` requires `e` on
# either every point of a graph or on none of them. Null values are dropped by
# `plot` column by column.
class _PointsRequired(TypedDict):
    x: float
    y: float


class Points(_PointsRequired, total=False):
    e: float


# Marker of a line: `shape` is a matplotlib marker (e.g. "o"), `size` its size.
class MarkerStyle(TypedDict, total=False):
    shape: str
    size: float


# Line appearance; every key is optional, see the "Custom style" section of `HELP`.
class LineStyle(TypedDict, total=False):
    marker: MarkerStyle
    type: Optional[str]  # matplotlib linestyle, forwarded as `linestyle`
    width: Optional[int]  # forwarded as `linewidth`
    alpha: float  # opacity of the line, between 0 and 1


# Appearance of the shaded error band.
class ErrorStyle(TypedDict, total=False):
    alpha: float  # opacity of the band, between 0 and 1


# Optional per-graph style; every key is optional.
class Style(TypedDict, total=False):
    color: Optional[str]  # shared by the line and its error band
    line: LineStyle
    error: ErrorStyle


class _GraphRequired(TypedDict):
    points: List[Points]


# One graph of the input: a list of points plus an optional legend name and style.
class Graph(_GraphRequired, total=False):
    name: Optional[str]  # unset or null keeps the graph out of the legend
    style: Style


# Keyword arguments forwarded to `plt.rc("font", ...)`; every key is optional.
# Functional syntax is required because `sans-serif` is not a valid identifier.
Font = TypedDict(
    "Font",
    {
        "family": str,
        "serif": str,
        "sans-serif": str,
        "weight": str,
        "size": int,
    },
    total=False,
)

HELP = """## Example
each graph is a record with an optional `name`, a `points` table and an optional `style`.
`points` needs `x` and `y` columns; the `e` column is optional and, when given for every
point, draws a shaded band from `y - e` to `y + e` around the line.
```nuon
[
    {
        name: "Alice", # optional, unset or set to null won't show the graph name in the legend
        points: [
            [ x, y, e ];
            [ 1, 1143, 120 ],
            [ 2, 1310, 248 ],
            [ 4, 1609, 258 ],
            [ 8, 1953, 343 ],
            [ 16, 2145, 270 ],
            [ 32, 3427, 301 ]
        ],
        style: {},  # optional, see section below
    },
    {
        name: "Bob", # optional, unset or set to null won't show the graph name in the legend
        points: [
            [ x, y, e ];
            [ 1, 2388, 374 ],
            [ 2, 2738, 355 ],
            [ 4, 3191, 470 ],
            [ 8, 3932, 671 ],
            [ 16, 4571, 334 ],
            [ 32, 4929, 1094 ]
        ],
        style: {},  # optional, see section below
    },
]
```

## Custom style
any record inside the data can have an optional "style" specification.

below is the full shape of that specification; every key is optional and the values
shown are the defaults:
```nuon
{
    color: null,  # see https://matplotlib.org/stable/users/explain/colors/colors.html
    line: {
        marker: {
            shape: "o",  # see https://matplotlib.org/stable/api/markers_api.html
            size: 5,
        },
        type: null,  # see https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
        width: null,  # an integer line width
        alpha: 1.0,  # a real number between 0 and 1
    },
    error: {
        alpha: 0.3,  # a real number between 0 and 1
    },
}
```

## Parametrizing the font
> see https://matplotlib.org/stable/users/explain/text/usetex.html

below is an example of the Helvetica font using _serifs_, best rendered with
`--use-tex`
```nuon
{
    size: 15,
    family: serif,
    sans-serif: Helvetica,
}
```"""

def _plot_fail(message: str):
    """Print *message* to stderr and exit the process with status 1; never returns."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _plot_extract_points(g: Graph, i: int) -> Tuple[List[float], List[float], List[float]]:
    """Return the x, y and error values of graph *g*, exiting on inconsistent data.

    Null or missing values are dropped column by column. *i* is the graph's
    index in the input; it only identifies the graph in error messages when
    *g* has no name.
    """
    points = g["points"]
    xs = [x for x in (p.get("x") for p in points) if x is not None]
    ys = [y for y in (p.get("y") for p in points) if y is not None]
    es = [e for e in (p.get("e") for p in points) if e is not None]

    # NOTE(review): because nulls are dropped per column, this only checks that the
    # column *counts* agree, not that the remaining x and y values come from the same points.
    if len(xs) != len(ys):
        _plot_fail(f"invalid points: found {len(xs)} x values and {len(ys)} y values for graph {g.get('name', i)}")

    if len(es) not in (0, len(xs)):
        _plot_fail(f"invalid errors: please provide error values for all points or for none, found {len(es)} errors for {len(xs)} points for graph {g.get('name', i)}")

    return xs, ys, es


def _plot_resolve_style(g: Graph) -> Tuple[dict, float]:
    """Return the `Axes.plot` keyword arguments and the error-band alpha for *g*.

    Every key missing from ``g["style"]`` keeps its default. A null ``style``,
    ``line``, ``marker`` or ``error`` record is treated like an empty one.
    """
    style = {
        "marker": "o",
        "markersize": 5,
        "linestyle": None,
        "color": None,
        "linewidth": None,
    }
    error_alpha = 0.3
    if "style" in g:
        custom = g["style"] or {}
        line = custom.get("line") or {}
        marker = line.get("marker") or {}
        style["color"] = custom.get("color", None)
        style["marker"] = marker.get("shape", style["marker"])
        style["markersize"] = marker.get("size", style["markersize"])
        style["linestyle"] = line.get("type", style["linestyle"])
        style["linewidth"] = line.get("width", style["linewidth"])
        style["alpha"] = line.get("alpha", 1.0)
        error_alpha = (custom.get("error") or {}).get("alpha", error_alpha)
    return style, error_alpha


def _plot_set_ticks(
    axis_name: str,
    set_ticks,
    ticks: Optional[List[float]],
    tick_labels: Optional[List[str]],
):
    """Place explicit ticks with *set_ticks* (``ax.set_xticks`` or ``ax.set_yticks``).

    Does nothing when *ticks* is None. Without *tick_labels*, each tick is
    labelled with its own value. Exits if ticks and labels differ in length;
    *axis_name* ("X" or "Y") is used in that error message.
    """
    if ticks is None:
        return
    labels = tick_labels if tick_labels is not None else ticks
    if len(ticks) != len(labels):
        _plot_fail(f"{axis_name} ticks and their labels should have the same length, found {len(ticks)} and {len(labels)}")
    set_ticks(ticks, labels=labels)


def plot(
    graphs: List[Graph],
    title: str,
    x_label: str,
    y_label: str,
    save: Optional[str] = None,
    save_fig_size: Tuple[float, float] = (16, 9),
    save_dpi: int = 500,
    plot_layout: str = "constrained",
    legend_loc: Optional[str] = None,
    x_scale: str = "linear",
    y_scale: str = "linear",
    x_lim: Optional[Tuple[float, float]] = None,
    y_lim: Optional[Tuple[float, float]] = None,
    x_ticks: Optional[List[float]] = None,
    x_tick_labels: Optional[List[str]] = None,
    x_ticks_rotation: Optional[float] = None,
    y_ticks: Optional[List[float]] = None,
    y_tick_labels: Optional[List[str]] = None,
    font_size: Optional[float] = None,
    font: Optional[Font] = None,
    use_tex: bool = False,
):
    """Draw *graphs* on one set of axes, then save the figure or show it.

    See `HELP` for the expected shape of *graphs* (names, points and styles).
    On invalid input, the function prints a message to stderr and exits with
    status 1. Invalid input means mismatched point columns, error values on
    only some points, or ticks and labels of different lengths.

    Args:
        graphs: the graphs to draw, all on the same axes.
        title, x_label, y_label: axes title and axis labels.
        save: output path; when None the figure is shown interactively instead.
        save_fig_size: figure size in inches, only used when saving.
        save_dpi: resolution, only used when saving.
        plot_layout: forwarded to ``plt.subplots(layout=...)``.
        legend_loc: forwarded to ``ax.legend(loc=...)``. The legend is only
            added when at least one graph has a name.
        x_scale, y_scale: axis scales, e.g. "linear" or "log". A log scale
            enables that axis' minor grid and turns minor ticks on.
        x_lim, y_lim: optional axis limits.
        x_ticks, y_ticks: optional explicit tick positions.
        x_tick_labels, y_tick_labels: labels for those ticks. They default to
            the tick values themselves.
        x_ticks_rotation: rotation of the x tick labels, in degrees.
        font_size: base font size; overrides ``font["size"]`` when both are set.
        font: keyword arguments for ``plt.rc("font", ...)``.
        use_tex: render text with LaTeX (``text.usetex``).
    """
    plt.rc("font", **(font or {}))
    if font_size is not None:
        # applied after `font` so the explicit argument wins over `font["size"]`
        plt.rc("font", size=font_size)
    plt.rcParams.update({
        "text.usetex": use_tex,
    })

    fig, ax = plt.subplots(layout=plot_layout)

    for i, g in enumerate(graphs):
        xs, ys, es = _plot_extract_points(g, i)
        style, error_alpha = _plot_resolve_style(g)

        # a None label is what matplotlib uses for unlabelled lines: they stay out of the legend
        (line,) = ax.plot(xs, ys, label=g.get("name"), **style)

        if es:
            ys_arr, es_arr = np.array(ys), np.array(es)
            # Reuse the line's colour. Lines and fills draw from separate colour cycles,
            # so an unset fill colour drifts away from its line as soon as one graph has
            # no error values.
            ax.fill_between(
                xs,
                ys_arr - es_arr,
                ys_arr + es_arr,
                alpha=error_alpha,
                color=line.get_color(),
            )

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    if x_lim is not None:
        ax.set_xlim(x_lim)
    if y_lim is not None:
        ax.set_ylim(y_lim)

    ax.set_title(title)

    # without any labelled line, `legend()` only warns and adds a legend with no entries
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc=legend_loc)

    ax.grid(True, which="major")
    if x_scale == "log":
        ax.xaxis.grid(True, which="minor")
    if y_scale == "log":
        ax.yaxis.grid(True, which="minor")
    if x_scale == "log" or y_scale == "log":
        ax.minorticks_on()

    _plot_set_ticks("X", ax.set_xticks, x_ticks, x_tick_labels)
    _plot_set_ticks("Y", ax.set_yticks, y_ticks, y_tick_labels)

    if x_ticks_rotation is not None:
        # `tick_params` rotates the labels in place and also applies to labels created
        # later (e.g. by the explicit ticks above). Re-setting the current labels with
        # `set_xticklabels` would freeze them.
        ax.tick_params(axis="x", labelrotation=x_ticks_rotation)

    if save is not None:
        fig.set_size_inches(save_fig_size, forward=False)
        fig.savefig(save, dpi=save_dpi)

        print(f"plot saved as `{save}`")
    else:
        plt.show()
