#!/usr/bin/env python
# fmt: off

import matplotlib as mpl
import warnings

# Ignore every UserWarning for the remainder of the script.
warnings.simplefilter("ignore", UserWarning)

# Figure styling shared by every figure produced below.
mpl.rcParams.update({
    'figure.titlesize': 'x-large',
    'figure.constrained_layout.use': True,
})

# Basis

from skfda.representation.basis import MonomialBasis, FourierBasis, BSplineBasis
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

# One panel per basis family, each showing five basis functions.
basis_panels = (
    (MonomialBasis, "Monomial"),
    (BSplineBasis, "B-spline"),
    (FourierBasis, "Fourier"),
)
for ax, (basis_cls, title) in zip(axes, basis_panels):
    basis_cls(n_basis=5).plot(ax)
    ax.set_title(title)

fig.savefig("../Figures/basis.pdf")

# Regularization
import skfda
import matplotlib.pyplot as plt

# Effect of the roughness penalty on a B-spline basis smoother: one panel per
# value of the regularization parameter lambda.
reg_param_list = (0, 1, 10)
fig, axes = plt.subplots(1, len(reg_param_list), figsize=(10, 3))

# Neither the data nor the basis depend on the regularization parameter, so
# load and build them once instead of once per panel.
X, y = skfda.datasets.fetch_phoneme(return_X_y=True)
X = X.coordinates[0]

basis = skfda.representation.basis.BSplineBasis(
    domain_range=X.domain_range,
    n_basis=40,
)

for ax, reg_param in zip(axes, reg_param_list):
    # L2 penalty applied to a differential operator of order 2, weighted by
    # the current lambda.
    regularization = skfda.misc.regularization.L2Regularization(
        skfda.misc.operators.LinearDifferentialOperator(2),
        regularization_parameter=reg_param,
    )

    smoother = skfda.preprocessing.smoothing.BasisSmoother(
        basis=basis,
        regularization=regularization,
        return_basis=True,
    )

    X_basis = smoother.fit_transform(X)
    X_basis[:10].plot(ax)
    # Same y-range in every panel so the amount of smoothing is comparable.
    ax.set_ylim(0, 23)
    ax.set_title(fr"$\lambda$ = {reg_param:.1f}")

fig.suptitle(None)
fig.savefig("../Figures/regularization.pdf")

# Boxplot

import skfda
import matplotlib.pyplot as plt

X, _ = skfda.datasets.fetch_growth(return_X_y=True)

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

# Left: the raw growth curves.
X.plot(axes=axes[0])

# Middle: functional boxplot with the default central region.
# Right: same boxplot with several nested central regions.
boxplot_panels = (
    (axes[1], {}),
    (axes[2], {"prob": [0.75, 0.5, 0.25]}),
)
for ax, extra_kwargs in boxplot_panels:
    boxplot = skfda.exploratory.visualization.Boxplot(
        X,
        depth_method=skfda.exploratory.depth.ModifiedBandDepth(),
        axes=ax,
        **extra_kwargs,
    )
    boxplot.plot()
fig.suptitle(None)
fig.savefig("../Figures/boxplot.pdf")

# Elastic registration (synthetic)

import skfda
from skfda.preprocessing.registration import FisherRaoElasticRegistration
import matplotlib.pyplot as plt

# Synthetic two-mode sample with a fixed seed.
X = skfda.datasets.make_multimodal_samples(
    n_modes=2,
    start=-2,
    stop=2,
    random_state=1,
)

registration = FisherRaoElasticRegistration()
X_aligned = registration.fit_transform(X)

# Left: unregistered curves; right: elastically registered curves.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
for ax, curves in zip(axes, (X, X_aligned)):
    curves.plot(axes=ax)
fig.savefig("../Figures/elastic_registration_synthetic.pdf")

# Elastic registration

import skfda
from skfda.preprocessing.registration import (
    FisherRaoElasticRegistration,
    LeastSquaresShiftRegistration,
)
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_growth(return_X_y=True)

X_aligned_elastic = FisherRaoElasticRegistration().fit_transform(X)
X_aligned_shift = LeastSquaresShiftRegistration().fit_transform(X)

# (curves, title) for each panel, left to right.
registration_panels = [
    (X, X.dataset_name),
    (X_aligned_shift, "Shift registration"),
    (X_aligned_elastic, "Elastic registration"),
]

fig, axes = plt.subplots(1, len(registration_panels), figsize=(10, 3))
for ax, (curves, title) in zip(axes, registration_panels):
    curves.plot(axes=ax)
    ax.set_title(title)
fig.suptitle(None)
fig.savefig("../Figures/registration.pdf")

# FDataBasis

import skfda
from skfda.representation.basis import FourierBasis, BSplineBasis
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_phoneme(return_X_y=True)

fig, axes = plt.subplots(1, 3, figsize=(10, 2.5))

# The same sample in its discretized form and projected onto two
# five-element bases; only the first ten curves are drawn.
representation_panels = [
    ("Original", X),
    ("B-spline", X.to_basis(BSplineBasis(n_basis=5))),
    ("Fourier", X.to_basis(FourierBasis(n_basis=5))),
]
for ax, (title, curves) in zip(axes, representation_panels):
    curves[:10].plot(ax)
    ax.set_title(title)
fig.suptitle(None)

fig.savefig("../Figures/fdatabasis.pdf")

# Fetch data

import skfda
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
left_ax, middle_ax, right_ax = axes

# Left: UCR GunPoint, coloured by its target.
dataset = skfda.datasets.fetch_ucr("GunPoint")
gunpoint = dataset["data"]
gunpoint.plot(group=dataset["target"], axes=left_ax)
left_ax.set_title(gunpoint.dataset_name)

# Middle: first coordinate of the weather dataset.
X, _ = skfda.datasets.fetch_weather(return_X_y=True)
X.coordinates[0].plot(axes=middle_ax)
middle_ax.set_title(X.dataset_name)

# Right: growth curves, coloured by their target.
X, y = skfda.datasets.fetch_growth(return_X_y=True)
X.plot(group=y, axes=right_ax)
right_ax.set_title(X.dataset_name)

fig.suptitle(None)

fig.savefig("../Figures/fetch_datasets.pdf")

# Gaussian

import skfda
from skfda.misc.covariances import Brownian, Gaussian, Exponential
import matplotlib.pyplot as plt

# Covariance kernels to compare, keyed by panel title.
cov_dict = {
    "Brownian": Brownian(variance=1),
    "Exponential": Exponential(variance=1, length_scale=1),
    "Gaussian (RBF)": Gaussian(length_scale=0.1),
}

fig, axes = plt.subplots(1, len(cov_dict), figsize=(10, 3))
for ax, (name, cov) in zip(axes, cov_dict.items()):
    # Same seed for every kernel, so only the covariance differs per panel.
    fd = skfda.datasets.make_gaussian_process(
        n_samples=50, n_features=100, mean=0, cov=cov, random_state=0,
    )
    fd.plot(axes=ax)
    ax.set_title(name)

fig.savefig("../Figures/gaussian.pdf")

# Grid example

import skfda

# Three curves observed on a shared, irregularly spaced grid: each row of
# data_matrix holds one curve's values at the six grid points.
grid_points = [0.0, 0.1, 0.3, 0.4, 0.7, 1.0]
data_matrix = [
    [109.5, 115.8, 121.9, 130.0, 138.2, 141.1],
    [104.6, 112.3, 118.9, 125.0, 130.1, 133.0],
    [100.4, 107.1, 112.3, 118.6, 124.0, 126.5],
]

fd = skfda.FDataGrid(grid_points=grid_points, data_matrix=data_matrix)

# MS-plot

import skfda
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_weather(return_X_y=True)
X = X.coordinates[0]

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
trajectories_ax, ms_ax = axes

# Right panel first: its outlier flags are needed to colour the left panel.
ms_plot = skfda.exploratory.visualization.MagnitudeShapePlot(X, axes=ms_ax)
ms_plot.plot()
ms_ax.set_title("MS-Plot")

# Left panel: the trajectories, grouped (blue/red) by the MS-plot outlier flag.
X.plot(
    group=ms_plot.outliers,
    group_colors=["blue", "red"],
    axes=trajectories_ax,
)
trajectories_ax.set_title("Trajectories")
fig.suptitle(None)
fig.savefig("../Figures/ms_plot.pdf")

# Notation

import skfda
import matplotlib.pyplot as plt
import numpy as np

# From https://stackoverflow.com/a/61454455


def draw_brace(ax, xspan, yy, text):
    """Draw a horizontal curly brace spanning ``xspan`` and label it.

    Adapted from https://stackoverflow.com/a/61454455.

    Args:
        ax: Axes to draw on. Its current x/y limits determine how sharp the
            brace's curls are and how tall it is drawn.
        xspan: ``(xmin, xmax)`` pair, in data coordinates, covered by the
            brace.
        yy: Vertical data coordinate the brace is anchored at.
        text: Label written centred beneath the brace.
    """
    left, right = xspan
    width = right - left

    view_left, view_right = ax.get_xlim()
    view_width = view_right - view_left
    view_bottom, view_top = ax.get_ylim()
    view_height = view_top - view_bottom

    # Always an odd number of samples, so the brace has one middle point.
    n_points = int(width / view_width * 100) * 2 + 1
    # Larger beta gives sharper (smaller-radius) curls.
    beta = 300. / view_width

    xs = np.linspace(left, right, n_points)
    left_xs = xs[:n_points // 2 + 1]
    # Left half: sum of two logistic steps centred at its two ends.
    half = (1 / (1. + np.exp(-beta * (left_xs - left_xs[0])))
            + 1 / (1. + np.exp(-beta * (left_xs - left_xs[-1]))))
    # Mirror the left half, without duplicating its last (middle) sample.
    shape = np.concatenate((half, half[-2::-1]))
    ys = yy - (.05 * shape - .01) * view_height

    # Freeze the view limits so the brace, drawn unclipped, does not move them.
    ax.autoscale(False)
    ax.plot(xs, ys, color='black', clip_on=False, lw=1)

    ax.text((right + left) / 2., yy - .17 * view_height,
            text, ha='center', va='bottom', fontsize=20)


# Notation figure: a few growth curves on an irregular grid, annotated with
# the x_i(t_j) / t_j notation and a brace marking the domain T.

# NOTE: this changes the global legend font size for every later figure too.
plt.rc('legend', fontsize=15)

X, y = skfda.datasets.fetch_growth(return_X_y=True)

# Six consecutive observations of each curve, re-indexed on a grid in [0, 1].
X2 = skfda.FDataGrid(X.data_matrix[:, 6:12], [0, 0.1, 0.3, 0.4, 0.7, 1])

X2 = X2[:3]


fig = X2.scatter()
X2.plot(
    fig,
    group=[f"$x_{i + 1}$" for i in range(len(X2))],
    legend=True,
)

ax = fig.axes[0]
t_grid = X2.grid_points[0]
# data_matrix is indexed (sample, grid point, codomain); keep the only
# codomain dimension.
values = X2.data_matrix[..., 0]

# Remember the limits chosen by the plots: the guide lines below start at
# x=-1 / y=0 and must not enlarge the view.
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# Dashed guide from the y-axis to the point x_1(t_4).
ax.hlines(values[0, 3], -1, t_grid[3], linestyle="dashed")

# Dashed guides from the y-axis to every curve's first observation x_i(t_1).
for first_value in values[:, 0]:
    ax.hlines(first_value, -1, t_grid[0], linestyle="dashed")

# Dashed guides from the x-axis up to each observation of the first curve.
for t, value in zip(t_grid, values[0]):
    ax.vlines(t, 0, value, linestyle="dashed")

ax.set_yticks([values[0, 3]] + [values[i, 0] for i in range(len(X2))])
ax.set_yticklabels(
    ["$x_1(t_4)$"] + [f"$x_{i+1}(t_1)$" for i in range(len(X2))], fontsize=15)

ax.set_xticks(t_grid)
ax.set_xticklabels(
    [f"$t_{i+1}$" for i in range(len(t_grid))], fontsize=15)

ax.set_xlim(xlim)
ax.set_ylim(ylim)

# Brace over the whole domain, anchored 6 units below x_3(t_1).
draw_brace(ax, X2.domain_range[0], values[2, 0] - 6, "$\\mathcal{T}$")

# NOTE(review): constrained layout is enabled globally via rcParams at the top
# of this script; matplotlib skips subplots_adjust on figures using an
# incompatible layout engine and emits a UserWarning instead (silenced at the
# top of the script). Confirm this bottom margin is actually applied.
plt.subplots_adjust(bottom=0.2)

fig.savefig("../Figures/notation.pdf")

# Outliergram

import skfda
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_weather(return_X_y=True)

X = X.coordinates[0]

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
trajectories_ax, outliergram_ax = axes

# Left: raw trajectories; right: outliergram of the same sample.
X.plot(axes=trajectories_ax)
trajectories_ax.set_title("Trajectories")
skfda.exploratory.visualization.Outliergram(X, axes=outliergram_ax).plot()
outliergram_ax.set_title("Outliergram")
fig.suptitle(None)
fig.savefig("../Figures/outliergram.pdf")

# PCA

import skfda
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_growth(return_X_y=True)

# Two-component functional PCA of the growth curves.
fpca = skfda.preprocessing.dim_reduction.FPCA(n_components=2)
fpca.fit(X)

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

# First two panels: FPCAPlot of the mean curve and each component.
skfda.exploratory.visualization.FPCAPlot(
    X.mean(), fpca.components_, factor=30, axes=axes[:2],
).plot()
component_titles = ("First principal component", "Second principal component")
for ax, title in zip(axes[:2], component_titles):
    ax.set_title(title)

# Third panel: PC scores, one scatter series per class (label 0, then 1).
scores = fpca.transform(X)
scores_class_0 = scores[y == 0]
scores_class_1 = scores[y == 1]

score_ax = axes[2]
for class_scores in (scores_class_0, scores_class_1):
    score_ax.scatter(class_scores[:, 0], class_scores[:, 1])
score_ax.set_xlabel("PC1 score")
score_ax.set_ylabel("PC2 score")

fig.suptitle(None)
fig.savefig("../Figures/fpca.pdf")

# Sklearn example

import skfda

from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

import skfda.preprocessing.smoothing as smoothing
import skfda.preprocessing.dim_reduction as dimred
from skfda.misc.hat_matrix import KNeighborsHatMatrix

X, y = skfda.datasets.fetch_phoneme(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Pipeline: kernel smoothing -> FPCA -> SVM classifier.
smoothing_step = smoothing.KernelSmoother(
    kernel_estimator=KNeighborsHatMatrix(),
)
dimred_step = dimred.FPCA(n_components=3)
classification_step = SVC()

pipeline = Pipeline([
    ('smoothing', smoothing_step),
    ('dimred', dimred_step),
    ('classification', classification_step),
])

# Step hyper-parameters are addressed as '<step name>__<parameter>'.
param_grid = {
    'smoothing__kernel_estimator__n_neighbors': [3, 5, 7],
    'dimred__n_components': [1, 2, 3],
    'classification__C': [0.001, 0.01, 0.1, 1, 10],
}
grid = GridSearchCV(pipeline, param_grid=param_grid)

grid.fit(X_train, y_train)
score = grid.score(X_test, y_test)
# Test-set score of the best configuration, to 3 significant digits.
print(f"{score:.3}")

# Smoothing

import skfda
from skfda.preprocessing.smoothing import KernelSmoother, validation
from skfda.misc.hat_matrix import KNeighborsHatMatrix
import matplotlib.pyplot as plt

X, y = skfda.datasets.fetch_phoneme(return_X_y=True)

# Select the number of neighbours of a k-NN kernel smoother with the
# generalized cross-validation scorer using the Shibata penalty.
candidate_neighbors = [2, 3, 4, 5]
grid = validation.SmoothingParameterSearch(
    KernelSmoother(KNeighborsHatMatrix()),
    candidate_neighbors,
    param_name="kernel_estimator__n_neighbors",
    scoring=validation.LinearSmootherGeneralizedCVScorer(validation.shibata),
)

grid.fit(X)
X_smooth = grid.transform(X)

# Left: first five raw curves; right: the same curves after smoothing.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
for ax, curves in zip(axes, (X, X_smooth)):
    curves[:5].plot(axes=ax)
fig.suptitle(None)
fig.savefig("../Figures/smoothing.pdf")

# Stats

import skfda
import matplotlib.pyplot as plt
import numpy as np

X, _ = skfda.datasets.fetch_weather(return_X_y=True)
X = X.coordinates[0]

# Location and dispersion summaries of the sample.
mean = X.mean()
var = X.var()
std = np.sqrt(var)
trim_mean = skfda.exploratory.stats.trim_mean(X, 0.1)
geo_median = skfda.exploratory.stats.geometric_median(X)
depth_median = skfda.exploratory.stats.depth_based_median(X)

fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Faint sample curves as background.
X.plot(fig=fig, color="grey", alpha=0.2)

# Shaded band: mean +/- one pointwise standard deviation.
band_lower = (mean - std).data_matrix[0, ..., 0]
band_upper = (mean + std).data_matrix[0, ..., 0]
ax.fill_between(X.grid_points[0], band_lower, band_upper, alpha=0.3)

summaries = (
    (mean, "mean"),
    (trim_mean, "trimmed mean"),
    (geo_median, "geometric median"),
    (depth_median, "depth based median"),
)
for summary, label in summaries:
    summary.plot(fig=fig, label=label)
ax.legend(loc='lower right')
fig.suptitle(None)

fig.savefig("../Figures/stats.pdf")

# Interactive

import skfda
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

X, y = skfda.datasets.fetch_weather(return_X_y=True)
X = X.coordinates[0]

# Every earlier figure in this script has already been saved but never
# closed; close them so plt.show() below only opens the interactive figure.
plt.close("all")

# Linked interactive view of the same sample: trajectories, MS-plot and
# outliergram, plus a slider driven by each curve's modified band depth.
fig, axes = plt.subplots(2, 2, figsize=(8, 6))
# NOTE(review): only three displays are passed below but the grid has four
# axes; presumably MultipleDisplay uses axes[1, 1] for the slider — confirm.
graph_plot = skfda.exploratory.visualization.representation.GraphPlot(X)
ms_plot = skfda.exploratory.visualization.MagnitudeShapePlot(X)
outliergram_plot = skfda.exploratory.visualization.Outliergram(X)
mbd = skfda.exploratory.depth.ModifiedBandDepth()
# Kept in a module-level name: matplotlib widgets only stay responsive while
# a reference to them (and their owner) is alive.
interactive_plot = skfda.exploratory.visualization.MultipleDisplay(
    [graph_plot, ms_plot, outliergram_plot],
    criteria=mbd(X),
    sliders=Slider,
    label_sliders=["MBD"],
    fig=fig,
)

interactive_plot.plot()

axes[0, 0].set_title("Trajectories")
axes[0, 1].set_title("MS-Plot")
axes[1, 0].set_title("Outliergram")
fig.suptitle(None)

plt.show()

