# Activate the project environment located in the current working directory and
# install its dependencies before any package is loaded.
import Pkg
Pkg.activate(".")
Pkg.instantiate(verbose = true)

using CSV, DelimitedFiles, Distances, Random, Pathogen, Plots, Plots.PlotMeasures, DataFrames;

# To ensure reproducibility of results, use Julia v1.4.2
# versioninfo() prints the running Julia version and platform details, so the version
# actually used is recorded alongside the results.
using InteractiveUtils;
versioninfo()

# Set seed for RNG
# Must run before any random draw below (MCMC initialization, location jitter).
Random.seed!(4321);

# POPULATION INFORMATION
# Use CSV.jl for DataFrames I/O
risks = CSV.read("data/measles_hagelloch_1861_risk_factors.csv", DataFrame)

# Precompute every pairwise quantity once, indexed [i, j] over all individuals
individuals = 1:nrow(risks)

# Euclidean distance between the (x, y) locations of each pair
distance = [euclidean([risks[i, :x], risks[i, :y]], [risks[j, :x], risks[j, :y]])
            for i in individuals, j in individuals]

# A pair counts as "same class" when the product of their class labels is 1 or 4
# (presumably labels are 0/1/2, so this means both in class 1 or both in class 2 -
# confirm against the data)
temp1 = [prod(risks[[i, j], :class]) for i in individuals, j in individuals]
sameclass = in.(temp1, Ref([1, 4]))

# Zero distance marks a shared household (this includes each individual with itself).
# Those distances are then set to Inf so the distance term is not evaluated at 0.
samehousehold = iszero.(distance)
distance[samehousehold] .= Inf

# One (distance, sameclass, samehousehold) tuple per pair
dist = tuple.(distance, sameclass, samehousehold)

pop = Population(risks, dist)

# Define risk functions/TN-ILM structure
# Constant risk: returns the first (and only expected) parameter, ignoring the
# population and the individual index.
_constant(params::Vector{Float64}, pop::Population, i::Int64) = params[1]

# Neutral risk: always 1.0, regardless of parameters, population or individual.
_one(params::Vector{Float64}, pop::Population, i::Int64) = 1.0

# Risk between individuals k and i: a power law in the first element of
# pop.distances[k, i] plus additive terms weighted by its second and third elements.
#   params[1] - power-law scale
#   params[2] - power-law exponent (applied negated)
#   params[3] - weight on the second tuple element
#   params[4] - weight on the third tuple element
# NOTE(review): presumably the tuple is (distance, sameclass, samehousehold) as built for
# the Population in this script - confirm that Population stores it as `distances`.
function _powerlaw_plus(params::Vector{Float64}, pop::Population, i::Int64, k::Int64)
  separation, classmates, housemates = pop.distances[k, i]
  spatial = params[1] * separation^(-params[2])
  return spatial + params[3] * classmates + params[4] * housemates
end

# Assemble the SEIR risk functions; the constructor takes them positionally
rf = RiskFunctions{SEIR}(
  _constant,      # sparks function
  _one,           # susceptibility function
  _powerlaw_plus, # infectivity function
  _one,           # transmissibility function
  _constant,      # latency function
  _constant,      # removal function
)

obsdata = CSV.read("data/measles_hagelloch_1861_observations.csv", DataFrame)

# Removal observation: the day the rash appears + 4.0, or the day of death when that is
# earlier (fatal cases). A NaN in :death is treated as "no death recorded"; it has to be
# detected with isnan because min() with a NaN argument would itself return NaN.
# Iterating over the columns (rather than a fixed 1:188) keeps `removed` the same length
# as `infected` for any number of rows.
removed = [isnan(death) ? rash + 4.0 : min(rash + 4.0, death)
           for (rash, death) in zip(obsdata[:, :rash], obsdata[:, :death])]

# Use the day of prodrome as the infection observation for every individual
infected = obsdata[:, :prodrome]
# Disabled alternative, kept for reference: treat early prodromes as initial conditions
# starting_states = [i <= 10.0 ? State_I : State_S for i in infected]
# infected[infected .<= 10.0] .= -Inf
obs = EventObservations{SEIR}(infected, removed)

# Priors for the risk parameters, one vector per risk function. The vector lengths match
# the number of parameters each function passed to RiskFunctions reads (an empty vector
# for the parameter-free `_one`).
rpriors = RiskPriors{SEIR}(
  [Uniform(0.0, 0.1)],        # sparks
  UnivariateDistribution[],   # susceptibility
  [Uniform(0.0, 7.0), Uniform(0.0, 7.0), Uniform(0.0, 1.0), Uniform(0.0, 1.0)], # infectivity
  UnivariateDistribution[],   # transmissibility
  [Uniform(0.0, 1.0)],        # latency
  [Uniform(0.0, 1.0)],        # removal
)

# Extents for event data augmentation, based on CDC measles information:
#   - exposure up to 2 weeks before infectiousness, and at least 5 days before it
#   - infectiousness up to 3 days before prodrome
#   - removal within 2-4 days after rash
ee = EventExtents{SEIR}((5.0, 14.0), 3.0, 2.0)

# Initialize MCMC: 1 chain, with 100k initialization attempts
mcmc = MCMC(obs, ee, pop, rf, rpriors)
start!(mcmc; attempts = 100_000)

# Run MCMC for 200k iterations
iterate!(mcmc, 200_000, 1.0; condition_on_network = true, event_batches = 10)

gr(dpi = 200) # GR backend for Plots.jl with DPI=200

# Trace plot of the TN-ILM parameters (every 20th iteration, log scale)
p1 = plot(1:20:200001, mcmc.markov_chains[1].risk_parameters, yscale = :log10, title = "TN-ILM parameters")
png(p1, joinpath(@__DIR__, "1861_trace.png"))

# Overlay epidemic curves for one disease state: start from the events at iteration
# 100000, then add every 50th iteration from 100050 through 200000.
function _epicurve_overlay(events, state, panel_title)
  fonts = (xguidefontsize = 8, yguidefontsize = 8,
           xtickfontsize = 7, ytickfontsize = 7, titlefontsize = 11)
  panel = plot(events[100000], state; linealpha = 0.01, title = panel_title, fonts...)
  for draw in 100050:50:200000
    plot!(panel, events[draw], state, linealpha = 0.02)
  end
  return panel
end

p2 = _epicurve_overlay(mcmc.markov_chains[1].events, State_S, "S")
p3 = _epicurve_overlay(mcmc.markov_chains[1].events, State_E, "E")

p4 = _epicurve_overlay(mcmc.markov_chains[1].events, State_I, "I")
plot!(p4, obs, State_I, linecolor = :black, linewidth = 1.5) # Show infection observations (day of prodrome)

p5 = _epicurve_overlay(mcmc.markov_chains[1].events, State_R, "R")
plot!(p5, obs, State_R, linecolor = :black, linewidth = 1.5) # Show removal observations (day of appearance of rash + 4)

# Four panels side by side with a shared y axis
l = @layout [a b c d]
combinedplots1 = plot(p2, p3, p4, p5, layout = l, link = :y, size = (800,200))
png(combinedplots1, joinpath(@__DIR__, "1861_curves.png"))

# Transmission network posterior, using the second half of the chain thinned by 50
tnp = TransmissionNetworkPosterior(mcmc, burnin = 100000, thin = 50)

# Reference transmission network read from the Oesterle external/internal source files
tnoesterle = TransmissionNetwork(BitArray(readdlm("data/oesterle_tn_external.csv")[:]), BitArray(readdlm("data/oesterle_tn_internal.csv")))

# As there are several individuals at single locations, jitter locations to better illustrate the
# transmission network distribution
# The jitter length follows the number of rows rather than a hard-coded 188, so it always
# matches the data. x is drawn before y, which keeps the seeded RNG stream unchanged.
xyjitter = select(risks, :x, :y)
xyjitter[:, :x] = xyjitter[:, :x] + rand(Normal(0, 3), nrow(xyjitter))
xyjitter[:, :y] = xyjitter[:, :y] + rand(Normal(0, 3), nrow(xyjitter))
plotpop = Population(xyjitter)

p6 = plot(tnp, plotpop, title = "Transmission Network Posterior Distribution", titlefontsize = 11, framestyle = :box, markeralpha = 0.5, size = (400, 400))

p7 = plot(tnoesterle, plotpop, title = "Oesterle (1992) Transmission Network", titlefontsize = 11, framestyle = :box, markeralpha = 0.5, size = (400, 400))

# Posterior and reference networks side by side
l = @layout [a b]
combinedplots2 = plot(p6, p7, layout = l, size = (800, 400))
png(combinedplots2, joinpath(@__DIR__, "1861_tn.png"))

# Marginal posterior distribution summary of TN-ILM parameters
param_summary = summary(mcmc; burnin = 100000, thin = 50)
println(param_summary)

# Column 2 of the summary is used as the point estimate for each parameter row.
# NOTE(review): row meanings below are inferred from the order of the priors
# (sparks, 4 infectivity parameters, latency, removal) - confirm against the printed table.
let powerlaw_scale = param_summary[2, 2],
    powerlaw_exponent = param_summary[3, 2],
    classmate = param_summary[4, 2],
    household = param_summary[5, 2],
    latency_rate = param_summary[6, 2],
    removal_rate = param_summary[7, 2]

  # Implied posterior mean latent period
  println(latency_rate^-1)

  # Implied posterior mean infectious period
  println(removal_rate^-1)

  # Infectious pressure from infectious family member relative to infectious classmate
  println(household / classmate)

  # Infectious pressure from individual residing 15m away relative to infectious classmate
  println((powerlaw_scale * 15^-powerlaw_exponent) / classmate)

  # Infectious pressure from individual residing 30m away relative to infectious classmate
  println((powerlaw_scale * 30^-powerlaw_exponent) / classmate)
end

# Extra calculations to support in-text result discussion
# What is the outdegree of individual 45 (Oesterle network first, then the posterior)?
for network in (tnoesterle, tnp)
  println(Pathogen._outdegree(network)[45])
end

# How many transmissions are common to the posterior mode network and Oesterle?
# Internal and external transmissions are counted separately and then added.
tnp_mode = mode(tnp)
shared_internal = sum(tnp_mode.internal .* tnoesterle.internal)
shared_external = sum(tnp_mode.external .* tnoesterle.external)
println(shared_internal + shared_external)

# How many external transmissions are in the posterior transmission network
println(sum(tnp.external))

# Calculate posterior mean events (iterations 100000 through 200000, every 50th)
pm_events = mean(mcmc.markov_chains[1].events[100000:50:200000])

# Latent period: posterior mean exposure to posterior mean infection
latent_periods = pm_events.infection .- pm_events.exposure
histogram(latent_periods, legend = :none, ylabel = "Frequency", xlabel = "Latent period")
println(mean(latent_periods))

# Incubation period: posterior mean exposure to the observed infection time
incubation_periods = obs.infection .- pm_events.exposure
histogram(incubation_periods, legend = :none, ylabel = "Frequency", xlabel = "Incubation period")
println(mean(incubation_periods))

# Print all infection observations in ascending order (earliest first, latest last)
println(sort(obs.infection))

# Find the individual with the latest infection observation
println(argmax(obs.infection))

# Show the posterior probability of individual 141 having an external transmission source
println(tnp.external[141])
