import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# install jumpdiff using: pip install jumpdiff
import jumpdiff as jd

# %% ######### First, we showcase the code exactly as it appears in the paper ########

# Paper listing, kept verbatim: simulate one jump-diffusion trajectory.
# NOTE: this section runs before np.random.seed(4229) further down, so its
# output is not reproducible between runs.
# integration time and time sampling
time = 10000
delta_t = 0.001
# define the drift function a(x)
def a(x):
    """Linear drift -0.5 * x (elementwise when x is a NumPy array)."""
    return -0.5*x
# define the diffusion function b(x)
def b(x):
    """Constant diffusion 0.75; the argument x is ignored."""
    return 0.75
# define jump amplitude and rate
xi = 1.5
lamb = 1.75
# generate the jump-diffusion process
# (arguments are positional here, in the order time, delta_t, a, b, xi, lamb;
# `time` is a plain int, not the stdlib module)
X = jd.jd_process(time, delta_t, a, b, xi, lamb)

# Paper listing, kept verbatim: Kramers-Moyal estimation and jump parameters.
# extract the Kramers-Moyal and state-space without second-order corrections
edge, simple_mom = jd.moments(timeseries=X, correction=False)
# and with second-order corrections
# (this rebinds `edge`; the corrected moments `mom` are not used further in
# this section)
edge, mom = jd.moments(timeseries=X, correction=True)

# To estimate the jump amplitude (from the uncorrected moments)
xi_est = jd.jump_amplitude(moments=simple_mom)
# and to estimate the jump rate (also from the uncorrected moments; no xi_est
# is passed here, unlike the Figure 3 loop further down)
lamb_est = jd.jump_rate(moments=simple_mom)

# Paper listing, kept verbatim: Q-ratio of the trajectory X.
# Use numpy to generate logarithmic spaced arrays
# and matplotlib to plot in double logarithmic scale.
# Take a sequence of integers lag
# (the cast to int truncates, so the smallest lags repeat, e.g. 1 appears twice)
lag = np.logspace(0, 4, 25, dtype=int)
# Recover the Q-ratio of the time series X
# (jd.q_ratio returns a (lag, Q) pair; `lag` is rebound to the returned lags)
lag, Q = jd.q_ratio(lag, X)

# plot in a log-log scale (on the current, implicitly created, figure)
plt.loglog(lag, Q)

# %% #################### Fixing a seed for reproducibility ####################

np.random.seed(4229)

# %% ############################## FIGURE 1 ###################################

# Simulate the jump-diffusion trajectory shown in Figure 1. X is reused by
# Figure 2; a, b, xi and lamb are reused by Figure 3.
# Simulation horizon and integration step:
t_final, delta_t = 10000, 0.001


def a(x):
    """Drift term: linear pull towards zero, a(x) = -0.5 x."""
    return -0.5 * x


def b(x):
    """Diffusion term: constant 0.75, independent of x."""
    return 0.75


# Jump amplitude and jump rate, then integrate the stochastic path
xi, lamb = 1.5, 1.75
X = jd.jd_process(t_final, delta_t, a=a, b=b, xi=xi, lamb=lamb)

# %% Plot the trajectory of the jump-diffusion process
fig, ax = plt.subplots(1, 2, figsize=(12, 3))

# Left panel: the whole trajectory against the sample index N
ax[0].plot(X, color='black')
ax[0].set_xlabel(r'$N$', fontsize=22)
ax[0].set_ylabel(r'$X(t)$', fontsize=22, labelpad=-10)

# Right panel: zoom on the samples from 20 before to 199 after the global
# maximum, clipped to the valid index range (assumes X is 1-D). The integer
# indices themselves are the x-values, so every point sits on its true sample
# number (a 220-point linspace over that span would not).
peak = int(np.argmax(X))
zoom = np.arange(max(peak - 20, 0), min(peak + 200, len(X)))
ax[1].plot(zoom, X[zoom], '.-', color='black')
ax[1].set_xlabel(r'$N$', fontsize=22)
ax[1].set_ylabel(r'$X(t)$', fontsize=22, labelpad=0)

fig.subplots_adjust(left=0.06, bottom=0.21, right=0.99, top=0.99, hspace=0.05,
    wspace=0.15)
# savefig does not create missing directories
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/Figure1.png')

# %% ############################## FIGURE 2 ###################################

# Number of points of the target (state) space
bins = np.array([5000])

# Kernel bandwidth
bw = 0.35

### The paper contains the two calls quoted below. Here we perform the same
### estimation for three increasingly coarse subsamplings of the stochastic
### process X (every 1st, 20th and 100th sample) to showcase the limitations
### arising from finite-time effects.

################################################################################
# # extract the Kramers-Moyal and state-space without second-order corrections
# edge, simple_mom = jd.moments(timeseries=X, correction=False)
# # and with second-order corrections
# edge, mom = jd.moments(timeseries=X, correction=True)
################################################################################

# First-order (uncorrected) KM estimates for each subsampling ...
(ed_1, kmc_1), (ed_2, kmc_2), (ed_3, kmc_3) = (
    jd.moments(X[::step], bins=bins, bw=bw, lag=[1], correction=False)
    for step in (1, 20, 100)
)
# ... the index of the state-space point closest to x = 0 in each ...
h_1, h_2, h_3 = (np.argmin(edges[:, 0] ** 2) for edges in (ed_1, ed_2, ed_3))

# ... and the second-order (corrected) estimates; their edges are discarded
F_1, F_2, F_3 = (
    jd.moments(X[::step], bins=bins, bw=bw, lag=[1], correction=True)[1]
    for step in (1, 20, 100)
)

# 1/m! normalisation for m = 0..6 (only m = 1, 2, 4, 6 are plotted)
norm = np.array([1, 1, 1/2, 1/6, 1/24, 1/120, 1/720])

fig, ax = plt.subplots(4, 3, figsize=(16, 6))

# Number of points left and right of x = 0 to plot
ofs = 650

# plot KMs 1, 2, 4, and 6 (one per row)
sit = [1, 2, 4, 6]

colours = ['#ff7f00', '#a6cee3', '#6a3d9a', '#33a02c']

# One column per subsampling: (subsampling factor, state-space edges, index
# closest to x = 0, first-order moments, second-order moments). Subsampled
# points are `step * delta_t` apart, hence the extra division by `step`.
columns = [(1, ed_1, h_1, kmc_1, F_1),
           (20, ed_2, h_2, kmc_2, F_2),
           (100, ed_3, h_3, kmc_3, F_3)]

for col, (step, ed, h, first, second) in enumerate(columns):
    window = slice(h - ofs, h + ofs)
    x = ed[window, 0]

    # Analytical references (before the 1/m! factor). The factors 3 and 15
    # are the 4th/6th moments of a zero-mean Gaussian, so xi is presumably the
    # jump-size variance.
    theory = {
        1: a(x),
        2: np.full_like(x, b(0) ** 2 + lamb * xi),
        4: np.full_like(x, 3 * lamb * xi ** 2),
        6: np.full_like(x, 15 * lamb * xi ** 3),
    }

    # Invisible dummy lines that only populate the legend
    ax[0, col].plot(0, 0, '--', color='black', linewidth=2,
        label='first-order')
    ax[0, col].plot(0, 0, '-', color='black', linewidth=2,
        label='second-order')

    for row, (m, colour) in enumerate(zip(sit, colours)):
        scale = norm[m] / (delta_t * step)
        ax[row, col].plot(x, scale * first[m, window, 0], '--',
            color=colour, linewidth=2)
        ax[row, col].plot(x, scale * second[m, window, 0], '-',
            color=colour, linewidth=2)
        ax[row, col].plot(x, norm[m] * theory[m], ':', color='black',
            linewidth=2)

# Per-row y-range and ticks for D_1, D_2, D_4 and D_6
ylims = [[-1.6, 1.6], [1.3, 2.3], [0.2, 0.9], [-0.1, 0.5]]
yticks = [[-1, 0, 1], [1.5, 2.0], [0.4, 0.6, 0.8], [0.0, 0.2, 0.4]]

for row in range(4):
    for col in range(3):
        # y axes only on the first column, x axes only on the bottom row
        if col > 0:
            ax[row, col].yaxis.set_visible(False)
        if row < 3:
            ax[row, col].xaxis.set_visible(False)
        ax[row, col].set_ylim(ylims[row])
        ax[row, col].set_yticks(yticks[row])
    ax[row, 0].set_ylabel(rf'$D_{sit[row]}$', labelpad=0, fontsize=28)

for col, title in enumerate([r'$s_f=1$', r'$s_f=0.05$', r'$s_f=0.01$']):
    ax[0, col].set_title(title, loc='left')
    ax[3, col].set_xlabel(r'$x$', fontsize=28)
    ax[0, col].legend(loc=1, fontsize=16, ncol=2, bbox_to_anchor=(1.0, 1.45),
        handlelength=1.1, columnspacing=0.5, handletextpad=0.4,
        borderpad=0.2)

fig.subplots_adjust(left=0.06, bottom=0.12, right=.99, top=0.93, hspace=0.12,
    wspace=0.03)
# savefig does not create missing directories
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/Figure2.png')

# %% ############################## FIGURE 3 ###################################

# Figure 3 takes a long time to compute, so instead of averaging over 10
# iterations (as in the paper) we average over only 2. We also reduced the
# number of points to 10.

# In the paper the basis for this code is:
################################################################################
# # To estimate the jump amplitude
# xi_est = jd.jump_amplitude(moments=simple_mom)
# # and to estimated the jump rate
# lamb_est = jd.jump_rate(moments=simple_mom)
################################################################################
# Herein we consider a very large set of stochastic processes X with growing
# length that serve to test the accuracy of the non-parametric estimates

# number of points in plot (in the paper num_points = 20)
num_points = 10
# independent realisations averaged per point (in the paper num_iter = 10)
num_iter = 2

# segments of increasing time span to test convergence of estimators
timer = np.logspace(2, np.log10(5e4), num_points).astype(int)

# Estimates indexed [point, realisation, order]: order 0 uses the uncorrected
# (first-order) moments, order 1 the corrected (second-order) ones.
amp = np.zeros([num_points, num_iter, 2])
amp_std = np.zeros([num_points, num_iter, 2])
rate = np.zeros([num_points, num_iter, 2])
rate_std = np.zeros([num_points, num_iter, 2])

delta_t = 0.001
km_bins = np.array([5000])

# %%
# Uses the Figure 1 process parameters a, b, xi and lamb defined above.
for i in range(num_points):
    print(i)  # progress indicator
    for j in range(num_iter):
        X = jd.jd_process(time=timer[i], delta_t=delta_t, a=a, b=b, xi=xi,
            lamb=lamb, init=0)
        # Both moment sets first, then the estimators (same call order as
        # before)
        km_moments = [jd.moments(X, bins=km_bins, correction=correction)[1]
                      for correction in (False, True)]
        for order, m in enumerate(km_moments):
            amp[i,j,order], amp_std[i,j,order] = jd.jump_amplitude(m,
                full=True, verbose=False)
            rate[i,j,order], rate_std[i,j,order] = jd.jump_rate(m,
                xi_est=amp[i,j,order], full=True, verbose=False)

# Mean and spread of the estimates across realisations
amp_ = np.mean(amp, axis=1)
amp_std_ = np.std(amp, axis=1)
rate_ = np.mean(rate, axis=1)
rate_std_ = np.std(rate, axis=1)

# %% Figure
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

# Twin (top) x-axes that carry the dashed reference lines below
ax0 = ax[0].twiny()
ax1 = ax[1].twiny()

# Bottom x-axis: number of samples N = timer / delta_t of each trajectory
n_samples = timer / delta_t

# Mean estimate with a +/- one-standard-deviation band over the realisations,
# for the first-order (index 0) and second-order (index 1) estimators. The jump
# rate is rescaled by 1 / delta_t so it is comparable with lamb.
for order, fmt, colour, label in ((0, 'o-', '#33a02c', 'first-order'),
                                  (1, '-', '#ff7f00', 'second-order')):
    panels = ((ax[0], amp_[:, order], amp_std_[:, order]),
              (ax[1], rate_[:, order] / delta_t, rate_std_[:, order] / delta_t))
    for axis, mean, std in panels:
        axis.semilogx(n_samples, mean, fmt, color=colour, label=label)
        axis.fill_between(n_samples, mean - std, mean + std, color=colour,
            alpha=0.2)

# Dashed black lines: the xi and lamb used to simulate the data, drawn against
# timer * lamb (presumably the expected number of jumps) on the top axes
ax0.semilogx(timer * lamb, np.full(timer.shape, xi), '--', color='black')
ax1.semilogx(timer * lamb, np.full(timer.shape, lamb), '--', color='black')

# Axis
ax[0].set_ylabel(r'$\hat{\sigma}_\xi^2$', fontsize=22, labelpad=0)
ax[1].set_ylabel(r'$\hat{\lambda}$', fontsize=22, labelpad=0)
ax[0].set_ylim([0, 3])
ax[1].set_ylim([0, 6])
for axis in ax:
    axis.set_xlabel(r'$N$', fontsize=22)
    axis.set_xticks([1e5, 3e5, 1e6, 3e6, 1e7, 3e7])
    axis.legend(loc=1, fontsize=16, ncol=2, handlelength=1.5, columnspacing=1,
        handletextpad=0.4)

fig.subplots_adjust(left=0.06, bottom=0.17, right=0.99, top=0.8, hspace=0.05,
    wspace=0.15)
# savefig does not create missing directories
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/Figure3.png')

# %% ############################## FIGURE 4 ###################################

# Here we will generate a set of shorter trajectories, with varying jump height,
# in order to show the workings of Eq. 16

# In the paper the basis for this code is:
################################################################################
# # import numpy to generate logarithmic spaced arrays
# # and matplotlib to plot in double logarithmic scale
# import numpy as np
# import matplotlib.pyplot as plt
# # Take a sequence of integers lag
# lag = np.logspace(0, 4, 25, dtype=int)
# # Recover the Q-ratio of the time series X
# lag, Q = jd.q_ratio(lag, X)
# # plot in a log-log scale
# plt.loglog(lag, Q)
################################################################################
# Here we test and showcase the Q-ratio for jump-diffusion processes ranging
# from very small jumps (nearly purely diffusive, sigma_xi^2 = 0.05) up to jumps
# with sigma_xi^2 = 1

# integration time and time sampling
t_final = 1000
delta_t = 0.001


# A drift function (twice the mean reversion of the Figure 1 process)
def a(x):
    return -1*x


# and a (constant) diffusion term
def b(x):
    return 1


# Now define only the jump rate
lamb = 0.1

# Number of jump-amplitude values, and independent realisations per value
n_xi = 20
n_realisations = 3

# Jump-amplitude values (shown as sigma_xi^2 on the colour bar), 0.05 to 1
xi_list = np.linspace(0.05, 1, n_xi)

# lag for the Q-ratio: unique log-spaced integers from 2 to N // 1000 + 1,
# where N is the number of integration steps. round() guards against
# t_final / delta_t landing just below an integer in floating point.
n_steps = round(t_final / delta_t)
lag = np.unique(np.logspace(0, np.log10(n_steps // 1000),
    100).astype(int)+1)

# Q-ratio per lag, jump amplitude and realisation
ratio = np.zeros([lag.size, n_xi, n_realisations])

for j, xi in enumerate(xi_list):
    print(j)  # progress indicator
    for k in range(n_realisations):
        # integrate process with the j-th jump amplitude
        X = jd.jd_process(t_final, delta_t, a=a, b=b, xi=xi, lamb=lamb)
        _, ratio[:,j,k] = jd.q_ratio(lag=lag, timeseries=X)

# averaging the results over realisations
ratio_mean = np.mean(ratio, axis=2)

# %% Figure
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap works on both old and new versions.
cmap = plt.get_cmap('BuGn')
cmap_r = plt.get_cmap('BuGn_r')

# One curve per jump amplitude; curve i is coloured cmap((n - i/1.1) / n)
n_curves = ratio_mean.shape[1]
for i in range(n_curves):
    ax.loglog(lag[:80], ratio_mean[:80, i],
        color=cmap((n_curves - i / 1.1) / n_curves))

# Dotted guide lines: the constant level Q = 3 and a line proportional to tau
ax.loglog(lag[:80], np.ones_like(lag[:80]) * 3, ':', color='black')
ax.loglog(lag[10:80], lag[10:80] * 1.8e-4, ':', color='black')

# Axis
ax.set_ylabel(r'$Q$-ratio', fontsize=22, labelpad=0)
ax.set_xlabel(r'$\tau$', fontsize=22)
ax.set_xticks([1e1, 1e2])
ax.set_ylim([1e-3, 5e1])
ax.set_yticks([1e-2, 1e-1, 1e0, 1e1])

# Text
fig.text(0.85, 0.3, r'diffusive', fontsize=20, rotation=16, ha='right')
fig.text(0.84, 0.81, r'jumpy', fontsize=20, ha='right')
fig.text(0.9, 0.06, r'$\sigma_\xi^2$', fontsize=22)

# Colorbar for a mappable not attached to any plotted artist: give it an
# explicit 0..1 norm, plus an empty array (needed by Matplotlib < 3.1)
sm = plt.cm.ScalarMappable(cmap=cmap_r,
    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
sm.set_array([])
cax = fig.add_axes([0.90, 0.21, 0.03, 0.76])

clb = fig.colorbar(sm, cax=cax)
clb.set_ticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

fig.subplots_adjust(left=0.14, bottom=0.21, right=.88, top=0.97, hspace=0.4,
    wspace=0.3)
# savefig does not create missing directories
os.makedirs('Figures', exist_ok=True)
fig.savefig('Figures/Figure4.png')
