# Activate the project environment and install any dependencies it declares
# (per Pkg.instantiate: packages from Manifest.toml if present, else resolved from
# Project.toml).
# NOTE(review): "." is the current working directory, not necessarily this script's
# directory (figures below are written relative to @__DIR__) — confirm the script is
# always launched from the project root, or consider Pkg.activate(@__DIR__).
import Pkg
let project = "."
  Pkg.activate(project)
  Pkg.instantiate(verbose = true)
end

# Required packages
using Distances,
      Pathogen,
      Random,
      Plots,
      Plots.PlotMeasures; # Plots.PlotMeasures is used for advanced formatting control for figure output

# Report the Julia version and platform in the output. Random-number streams are not
# guaranteed to be identical across Julia releases, so exact reproduction of the
# results requires Julia v1.4.2.
using InteractiveUtils
versioninfo()

# Fix the seed of the global RNG so that later draws from it are repeatable
Random.seed!(11235);

# Generate a population of n individuals: x ~ Uniform(0, 15), y ~ Uniform(0, 30),
# and one risk factor drawn from Gamma().
# NOTE: keyword arguments are evaluated left to right, so the RNG is consumed for
# x, then y, then riskfactor1. Reordering the columns changes the population for a
# given seed.
n = 100
risks = DataFrame(x = rand(Uniform(0, 15), n),
                  y = rand(Uniform(0, 30), n),
                  riskfactor1 = rand(Gamma(), n))

# Precalculate the n × n matrix of Euclidean distances between individuals' (x, y)
# locations; entry [i, j] is the distance between individuals i and j
locations = [[risks[i, :x], risks[i, :y]] for i = 1:n]
dists = [euclidean(locations[i], locations[j]) for i = 1:n, j = 1:n]

pop = Population(risks, dists)

# Define risk functions/TN-ILM structure
# Constant risk: returns θ[1] for every individual i (pop and i are not used).
# Used below as the sparks function.
_constant(θ::Vector{Float64}, pop::Population, i::Int64) = θ[1]

# Unit risk: always 1.0, independent of θ, pop and i. Suits risk functions that take
# no parameters.
_one(θ::Vector{Float64}, pop::Population, i::Int64) = one(Float64)

# Linear risk: θ[1] times individual i's value in the `riskfactor1` column of pop.risks
_linear(θ::Vector{Float64}, pop::Population, i::Int64) =
  θ[1] * pop.risks[i, :riskfactor1]

# Power-law distance kernel: pop.distances[k, i] raised to the power -θ[1].
# NOTE(review): yields Inf when that distance is 0 (k == i, or coincident locations);
# presumably never evaluated for k == i — confirm.
_powerlaw(θ::Vector{Float64}, pop::Population, i::Int64, k::Int64) =
  pop.distances[k, i]^(-θ[1])

# Assemble the TN-ILM structure. RiskFunctions{SIR} and RiskParameters{SIR} take their
# five arguments positionally, in the order:
# sparks, susceptibility, infectivity, transmissibility, removal.
rf = RiskFunctions{SIR}(_constant,  # sparks function
                        _one,       # susceptibility function
                        _powerlaw,  # infectivity function
                        _one,       # transmissibility function
                        _linear)    # removal function

# True parameter values used for the simulation; Float64[] means "no parameters"
rparams = RiskParameters{SIR}(
  [0.0001],   # sparks: the constant θ[1]
  Float64[],  # susceptibility: none
  [4.0],      # infectivity: power-law exponent θ[1]
  Float64[],  # transmissibility: none
  [0.1])      # removal: multiplier on riskfactor1

# Individual 1 starts infectious; the remaining n - 1 start susceptible
starting_states = vcat(State_I, fill(State_S, n - 1))

# Build the simulation and run it with tmax = 200.0
sim = Simulation(pop, starting_states, rf, rparams)
simulate!(sim; tmax = 200.0)

# Use the GR backend for Plots.jl, rendering at 200 DPI
gr(dpi = 200)

# Epidemic curve between times 0.0 and 200.0
p1 = plot(sim.events, 0.0, 200.0;
          legendfont = font(6), xaxis = font(10), bottom_margin = 30px)

# Population / transmission-network snapshots, one per time point, created in order
p2, p3, p4, p5, p6 = map((0, 10, 20, 30, 50)) do t
  plot(sim.transmission_network, sim.population, sim.events, Float64(t);
       title = "Time = $t", titlefontsize = 8)
end

# Epidemic curve across the top, the five snapshots in a row beneath it
l = @layout [a
             grid(1, 5)]
combinedplots1 = plot(p1, p2, p3, p4, p5, p6, layout = l)
png(combinedplots1, joinpath(@__DIR__, "sim_epiplot.png"))

# Observe the simulated epidemic, using a Uniform(0.5, 2.5) observation delay for both
# infection and removal
observation_delay = Uniform(0.5, 2.5)
obs = observe(sim, observation_delay, observation_delay; force = true)

# Optimistically assume the functional form of the epidemic is known (i.e. reuse the
# risk functions from the simulation) and place priors on their parameters. Priors are
# given positionally: sparks, susceptibility, infectivity, transmissibility, removal.
rpriors = RiskPriors{SIR}([Exponential(0.0001)],     # sparks
                          UnivariateDistribution[],  # susceptibility: no parameters
                          [Uniform(1.0, 7.0)],       # infectivity
                          UnivariateDistribution[],  # transmissibility: no parameters
                          [Uniform(0.0, 1.0)])       # removal

# Extents for event data augmentation (infection, removal).
# NOTE(review): presumably bounds how far an augmented event time may lie from its
# observed time — confirm against Pathogen's EventExtents documentation.
ee = EventExtents{SIR}(5.0, 5.0)

# Initialize MCMC: 1 chain, with up to 50k attempts to find a starting state
mcmc = MCMC(obs, ee, pop, starting_states, rf, rpriors)
start!(mcmc; attempts = 50000)

# Run 50k MCMC iterations with condition_on_network = true and event_batches = 5.
# NOTE(review): the positional 1.0 is presumably the transition-kernel scale — confirm
# against Pathogen's iterate! documentation.
iterate!(mcmc, 50000, 1.0; condition_on_network = true, event_batches = 5)

# MCMC and posterior plots

# Trace of the TN-ILM parameters on a log scale.
# NOTE(review): 1:20:50001 presumably selects every 20th stored sample of the chain —
# confirm against Pathogen's plot recipe for risk-parameter vectors.
p1 = plot(1:20:50001, mcmc.markov_chains[1].risk_parameters,
          yscale = :log10, title = "TN-ILM parameters",
          xguidefontsize = 8, yguidefontsize = 8, xtickfontsize = 7, ytickfontsize = 7,
          titlefontsize = 11, bottom_margin = 30px)

"""
    _posterior_state_plot(chain, truth, state, title; draws = 10000:20:50000)

Plot the `state` compartment curve of each posterior event sample `chain.events[i]`,
`i ∈ draws`, as faint lines. Then overlay the curve of `truth` (the true simulated
events) as a solid black line, and return the plot.

The first draw creates the plot and carries the title and font settings; later draws
are added on top. The default `draws` starts at sample 10000 (burn-in) and keeps
every 20th sample.

Throws `ArgumentError` if `draws` is empty.
"""
function _posterior_state_plot(chain, truth, state, title; draws = 10000:20:50000)
  isempty(draws) && throw(ArgumentError("draws must be non-empty"))
  p = plot(chain.events[first(draws)], state,
           linealpha = 0.01, title = title,
           xguidefontsize = 8, yguidefontsize = 8,
           xtickfontsize = 7, ytickfontsize = 7, titlefontsize = 11)
  for i in Iterators.drop(draws, 1)
    plot!(p, chain.events[i], state, linealpha = 0.01)
  end
  plot!(p, truth, state, linecolor = :black, linewidth = 1.5)
  return p
end

# Posterior S, I and R curves against the truth
p2 = _posterior_state_plot(mcmc.markov_chains[1], sim.events, State_S, "S")
p3 = _posterior_state_plot(mcmc.markov_chains[1], sim.events, State_I, "I")
p4 = _posterior_state_plot(mcmc.markov_chains[1], sim.events, State_R, "R")

# Parameter trace across the top, the S/I/R posteriors in a row beneath it
l = @layout [a; [b c d]]
combinedplots2 = plot(p1, p2, p3, p4, layout = l)
png(combinedplots2, joinpath(@__DIR__, "sim_posterior.png"))

# Shared settings: plot styling for both network panels, and the burn-in/thinning
# applied to the posterior network and the parameter summary
network_plot_style = (titlefontsize = 11, framestyle = :box)
posterior_sampling = (burnin = 10000, thin = 20)

# True transmission network next to its posterior distribution
p1 = plot(sim.transmission_network, sim.population;
          title = "True Transmission\nNetwork", network_plot_style...)

tnp = TransmissionNetworkPosterior(mcmc; posterior_sampling...)
p2 = plot(tnp, sim.population;
          title = "Transmission Network\nPosterior Distribution", network_plot_style...)

combinedplots3 = plot(p1, p2, layout = (1, 2))
png(combinedplots3, joinpath(@__DIR__, "sim_posterior_tn.png"))

# Marginal posterior distribution summary of the TN-ILM parameters
println(summary(mcmc; posterior_sampling...))
