Commit 0c9e58ee authored by Peter Jentsch's avatar Peter Jentsch
Browse files

behavoiral fitting WIP

parent d596e111
......@@ -8,9 +8,9 @@ version = "1.0.1"
[[AbstractMCMC]]
deps = ["BangBang", "ConsoleProgressMonitor", "Distributed", "Logging", "LoggingExtras", "ProgressLogging", "Random", "StatsBase", "TerminalLoggers", "Transducers"]
git-tree-sha1 = "06e46b94299b6d0d9fb4f7833c7b1df3df6f9678"
git-tree-sha1 = "29683bc1b52e1879ac0951253d8b0e2f60bf4cb4"
uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
version = "3.0.0"
version = "3.1.0"
[[AbstractTrees]]
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
......@@ -39,9 +39,9 @@ version = "0.1.0"
[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "f5cb35e4e2cbb488846d95b471f479be1ff6b173"
git-tree-sha1 = "d84e8967b7f04f52c9bca21714bae54a553a53fc"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.9"
version = "3.1.10"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
......@@ -236,10 +236,10 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.0.2"
[[Distances]]
deps = ["LinearAlgebra", "Statistics"]
git-tree-sha1 = "366715149014943abd71aa647a07a43314158b2d"
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
git-tree-sha1 = "abe4ad222b26af3337262b8afb28fab8d215e9f8"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.2"
version = "0.10.3"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
......@@ -298,9 +298,9 @@ version = "4.3.1+4"
[[FFTW]]
deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"]
git-tree-sha1 = "1dc6ca6ad69eb9beadd3ce82b90910f4fa63d7c3"
git-tree-sha1 = "746f68839306977040653ebbd249e39c15420b8a"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.4.0"
version = "1.4.1"
[[FFTW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -515,9 +515,9 @@ version = "2.0.1+3"
[[KernelDensity]]
deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"]
git-tree-sha1 = "09aeec87bdc9c1fa70d0b508dfa94a21acd280d9"
git-tree-sha1 = "591e8dc09ad18386189610acafb970032c519707"
uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
version = "0.6.2"
version = "0.6.3"
[[KissABC]]
deps = ["AbstractMCMC", "Distributions", "MonteCarloMeasurements", "Random"]
......@@ -678,9 +678,9 @@ version = "0.4.6"
[[LoopVectorization]]
deps = ["ArrayInterface", "CheapThreads", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "50652d45832a4f907dbec3c57ac93af6f7d1f28c"
git-tree-sha1 = "427ec6a601c32d704bb664b32bf695519ef66043"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.14"
version = "0.12.15"
[[LsqFit]]
deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
......@@ -1120,9 +1120,9 @@ version = "0.9.8"
[[StrideArraysCore]]
deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "da1091034d295c8dbaf1d6ea16529221bc24afe1"
git-tree-sha1 = "62a9b1e31f0741a642455f42ddaa9582101b3e71"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.1.5"
version = "0.1.6"
[[StructArrays]]
deps = ["Adapt", "DataAPI", "Tables"]
......
......@@ -5,49 +5,60 @@ using StatsBase
const parameters = (
sim_length = 600,
num_households = 5000,
I_0_fraction = 0.005,
I_0_fraction = 0.002,
base_transmission_probability = 0.001,
recovery_rate = 1/7,
immunization_loss_prob = 0.0055, #mean time of 6 months
π_base = -4.0,
π_base = -0.3,
η = 0.0,
κ = 1.0,
ω = 0.001,
κ = 0.8,
ω = 0.00005,
ρ = [0.0,0.0,0.0],
ω_en = 0.00,
ρ_en = [0.0,0.0,0.0],
γ = 0.0,
β = 5.0,
β = 10.0,
notification_parameter = 0.001,
vaccinator_prob = 0.2,
app_user_fraction = 0.4,
notification_threshold = 2,
app_user_fraction = 0.2,
notification_threshold = 17,
immunizing = true,
immunization_delay = 14,
immunization_begin_day = 50,
immunization_begin_day = 30,
infection_introduction_day = 100
)
# seasonal_transmission_dist = CovidAlertVaccinationModel.fit_epi_parameters(parameters,0.073) ##seasonal
# outbreak_transmission_dist = CovidAlertVaccinationModel.fit_epi_parameters(parameters,0.241) ##outbreak
# function plot_max_posterior(fname,particles)
# samples = 5
# base_transmission = mode(particles.P.particles)
# p_tuple_without_vac = merge(parameters,
# (
# sim_length = 150,
# immunization_begin_day = 0,
# infection_introduction_day = 1,
# immunizing = false,
# )
# )
# new_params = merge(p_tuple_without_vac, (base_transmission_probability = base_transmission,))
# out = mean_solve(samples, new_params ,DebugRecorder)
# p = plot_model(nothing,[nothing],[out],new_params.infection_introduction_day,new_params.immunization_begin_day)
# savefig(p,"$fname.pdf")
# hist = fit(Histogram,particles.P.particles; nbins = 25)
# p = plot(hist;legend = false)
# savefig(p,"$(fname)_posterior.pdf")
# end
# plot_max_posterior("seasonal", seasonal_transmission_dist)
# plot_max_posterior("outbreak", outbreak_transmission_dist)
function plot_max_posterior(fname,particles)
samples = 5
base_transmission = mode(particles.P.particles)
p_tuple_without_vac = merge(parameters,
(
sim_length = 150,
immunization_begin_day = 0,
infection_introduction_day = 1,
immunizing = false,
)
)
new_params = merge(p_tuple_without_vac, (base_transmission_probability = base_transmission,))
out = mean_solve(samples, new_params ,DebugRecorder)
p = plot_model(nothing,[nothing],[out],new_params.infection_introduction_day,new_params.immunization_begin_day)
savefig(p,"$fname.pdf")
end
plot_max_posterior("seasonal", seasonal_transmission_dist)
plot_max_posterior("outbreak", outbreak_transmission_dist)
\ No newline at end of file
a = CovidAlertVaccinationModel.fit_behavioural_parameters(parameters,nothing)
function solve_and_plot_parameters()
out = mean_solve(samples, parameters ,DebugRecorder)
p = plot_model(nothing,[nothing],[out],parameters.infection_introduction_day,parameters.immunization_begin_day)
savefig(p,"timeseries.pdf")
return out
end
\ No newline at end of file
using CovidAlertVaccinationModel
using OnlineStats
using Plots
const samples = 2
const samples = 1
const parameters = (
sim_length = 600,
num_households = 5000,
......@@ -9,28 +9,32 @@ const parameters = (
base_transmission_probability = 0.001,
recovery_rate = 1/7,
immunization_loss_prob = 0.0055, #mean time of 6 months
π_base = -4.0,
π_base = -0.3,
η = 0.0,
κ = 1.0,
ω = 0.001,
κ = 0.8,
ω = 0.00005,
ρ = [0.0,0.0,0.0],
ω_en = 0.00,
ρ_en = [0.0,0.0,0.0],
γ = 0.0,
β = 5.0,
β = 10.0,
notification_parameter = 0.001,
vaccinator_prob = 0.2,
app_user_fraction = 0.4,
notification_threshold = 2,
app_user_fraction = 0.2,
notification_threshold = 17,
immunizing = true,
immunization_delay = 14,
immunization_begin_delay = 50,
infection_introduction_delay = 100
immunization_begin_day = 30,
infection_introduction_day = 100
)
function solve_and_plot_parameters()
out = mean_solve(samples, parameters ,DebugRecorder)
p = plot_model(nothing,[nothing],[out],parameters.infection_introduction_delay,parameters.immunization_begin_delay)
p = plot_model(nothing,[nothing],[out],parameters.infection_introduction_day,parameters.immunization_begin_day)
savefig(p,"timeseries.pdf")
return out
end
solve_and_plot_parameters()
\ No newline at end of file
out = solve_and_plot_parameters()
println(mean.(out.new_ymo_immunization.Y))
println(mean.(out.new_ymo_immunization.M))
println(mean.(out.new_ymo_immunization.O))
\ No newline at end of file
......@@ -28,4 +28,5 @@ Run the model with given parameter tuple and output recorder. See `get_parameter
function abm(parameters, recorder)
model_sol = ModelSolution(parameters.sim_length,parameters,5000)
output = solve!(model_sol,recorder )
return model_sol
end
......@@ -11,7 +11,9 @@ function get_parameters()
η = 0.0,
κ = 0.0,
ω = 0.0005,
ρ = [0.0,0.0,0.0],
ρ_y = 0.0,
ρ_m = 0.0,
ρ_o = 0.0,
ω_en = 0.00,
ρ_en = [0.0,0.0,0.0],
γ = 0.0,
......@@ -67,6 +69,7 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
status_totals::Vector{Int}
daily_vaccinators::Int
daily_infected::Int
daily_immunizations_by_age::Vector{Int}
ws_matrix_tuple::WSMixingDist
rest_matrix_tuple::RestMixingDist
immunization_countdown::Vector{Int}
......@@ -113,6 +116,7 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
status_totals,
0,
0,
[0,0,0],
ws_matrix_tuple,
rest_matrix_tuple,
immunization_countdown
......
......@@ -16,13 +16,18 @@ abstract type AbstractRecorder{ElType} end
"""
DebugRecorder should store everything we might want to know about the model output.
"""
struct DebugRecorder{ElType,ArrType1<:AbstractArray{ElType},ArrType2<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
recorded_status_totals::ArrType1
daily_cases::ArrType2
total_vaccinators::ArrType2
mean_time_since_last_notification::ArrType2
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
recorded_status_totals::ArrT1
daily_cases::ArrT2
total_vaccinators::ArrT2
mean_time_since_last_notification::ArrT2
new_ymo_immunization::ArrT3
end
# struct AgeSpecificVaccination{ElType,ArrType1<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
# new_ymo_vaccinations::ArrType1
# daily_cases::ArrType2
# total_vaccinators::ArrType2
# end
"""
Initialize an empty DebugRecorder. We use a labelledarray for the state vector, so the individal timeseries can be accessed by name.
"""
......@@ -31,14 +36,17 @@ function DebugRecorder(sim_length)
total_vaccinators = Vector{Int}(undef,sim_length)
mean_time_since_last_notification = Vector{Int}(undef,sim_length)
daily_cases= Vector{Int}(undef,sim_length)
new_ymo_immunization = @LArray Array{Int}(undef,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
return DebugRecorder(
state_totals,
daily_cases,
total_vaccinators,
mean_time_since_last_notification,
new_ymo_immunization
)
end
end
"""
Initialize a DebugRecorder filled with (copies) of val. I should find a nicer way to combine both constructors.
......@@ -50,11 +58,14 @@ function DebugRecorder(val::T, sim_length) where T
mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
daily_cases = [copy(val) for j in 1:sim_length]
totals_immunization = [copy(val) for i in 1:3, j in 1:sim_length]
new_ymo_immunization = @LArray totals_immunization (Y = (1,:),M = (2,:),O = (3,:))
return DebugRecorder(
state_totals,
daily_cases,
total_vaccinators,
mean_time_since_last_notification
mean_time_since_last_notification,
new_ymo_immunization
)
end
......@@ -69,6 +80,8 @@ function record!(t,modelsol, recorder::DebugRecorder)
else
recorder.mean_time_since_last_notification[t] = 0
end
# display(modelsol.daily_immunizations_by_age)
recorder.new_ymo_immunization[:,t].=modelsol.daily_immunizations_by_age
end
function record!(t,modelsol, recorder::Nothing)
......@@ -161,7 +174,7 @@ function plot_model(varname,univariate_series, output_list::Vector{T},infection_
colorbar = false,
title = "Daily (incident) cases",
legend= false
)
)
end
for p in plts
vline!(p,[infection_begin]; label = "infection begin", line =:dot)
......
......@@ -50,14 +50,79 @@ end
function fit_behavioural_parameters(p_tuple, target_growth_rate)
samples = 1
priors = Factored(Uniform(0.0001,0.005))
p_names = (:π_base,:κ, :ρ_y,:ρ_m,:ρ_o)
priors = Factored(Uniform(-2.0,0.0),Uniform(0.0,1.0),Uniform(0.0,1.0),Uniform(0.0,1.0),Uniform(0.0,1.0))
#simulation begins in august
#30 days for opinion dynamics to stabilize, then immunization begins in september,
#infection is introduced at the end of november
sim_length = 210
p_tuple_adjust = merge(p_tuple,
(
sim_length = sim_length,
I_0_fraction = 0.000,
immunization_begin_day =30,
infection_introduction_day = 90,
immunizing = true,
)
)
target_cumulative_vac_proportion = 0.33
vaccination_data = @SVector [0.0,0.043,0.385,0.424,0.115,0.03,0.005] #by month starting in august
ymo_vac = @SVector [0.255,0.278,0.602]
function cost(p)
new_params = merge(p_tuple, (base_transmission_probability = p[1],))
out = mean_solve(samples, new_params ,DebugRecorder)
incident_cases = mean.(out.daily_cases)
_,exp_growth_rate = exp_growth_rate_estimation(incident_cases)
return (target_growth_rate - exp_growth_rate)^2
new_params = merge(p_tuple_adjust, NamedTuple{p_names}(ntuple(i -> p[i],length(p_names))))
out = DebugRecorder(0,sim_length)
model = abm(new_params,out)
vaccination_ts = mean.(out.recorded_status_totals.V)
# vaccination_ts = mean.(out.total_vaccinators)
# display(out.new_ymo_immunization)
ymo_vaccination_ts = out.new_ymo_immunization
monthly_ymo_vaccination = [sum.(eachrow(ymo_vaccination_ts[:,i:min(i+30,sim_length)])) for i in 1:30:sim_length]
target_cumulative_vaccinations = target_cumulative_vac_proportion*model.nodes
# display(target_cumulative_vaccinations)
total_err = 0
for (j,(monthly_ymo_vac,target_montly_vac)) in enumerate(zip(monthly_ymo_vaccination,vaccination_data))
target_monthly_ymo_vac = ymo_vac .* target_montly_vac .* target_cumulative_vaccinations
# display((j,p[1],p[2],monthly_ymo_vac))
total_err += sum((monthly_ymo_vac .- target_monthly_ymo_vac).^2)
end
# display(daily_vaccination_y_ts)
# p1 = plot(vaccination_ts; label = false)
# p2 = plot(collect(eachrow(daily_vaccination_y_ts)); label = ["Y" "M" "O"])
# p = plot(p1,p2)
# display(p)
# return (mean(vaccination_ts[end - 10:end]) - 3000.0)^2
return total_err
end
out = ABCDE(priors,cost,0.01; verbose=true, nparticles=100,generations=20, parallel = true) #this one has better NaN handling
#smc(priors,cost; verbose = true, nparticles = 100, parallel = true)#
out = ABCDE(priors,cost,1e6; verbose=true, nparticles=100,generations=50, parallel = true) #this one has better NaN handling
return out
end
function plot_behavioural_fit(particles,p_tuple)
p_names = (:π_base, :κ, :ρ_y,:ρ_m,:ρ_o)
sim_length = 210
samples = 1
p_tuple_adjust = merge(p_tuple,
(
sim_length = sim_length,
I_0_fraction = 0.005,
immunization_begin_day =30,
infection_introduction_day = 90,
immunizing = true,
)
)
p = map(e -> mode(e.particles),particles.P)
display(p)
new_params = merge(p_tuple_adjust, NamedTuple{p_names}(ntuple(i -> p[i],length(p_names))))
out = mean_solve(samples, new_params ,DebugRecorder)
display(new_params)
p = plot_model(nothing,[nothing],[out],new_params.infection_introduction_day,new_params.immunization_begin_day)
savefig(p,"behaviour_fit.pdf")
return out
end
\ No newline at end of file
......@@ -42,8 +42,9 @@ end
Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate,immunizing,immunization_begin_day = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals, immunization_countdown = modelsol
daily_cases = 0
modelsol.daily_infected = 0
modelsol.daily_immunizations_by_age.= 0
function agent_transition!(node, from::AgentStatus,to::AgentStatus)
immunization_countdown[node] = -1
status_totals[Int(from)] -= 1
......@@ -63,7 +64,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
for j in neighbors(mixing_graph.g,i)
if u_inf[j] == Infected && u_next_inf[i] != Infected
if rand(RNG) < contact_weight(base_transmission_probability,get_weight(mixing_graph,GraphEdge(i,j)))
daily_cases+=1
modelsol.daily_infected+=1
agent_transition!(i, Susceptible,Infected)
end
end
......@@ -81,17 +82,15 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
end
if immunization_countdown[i] == 0
modelsol.daily_immunizations_by_age[Int(agent_demo)] += 1
agent_transition!(i, Susceptible,Immunized)
elseif immunization_countdown[i]>0
# display(immunization_countdown)
immunization_countdown[i] -= 1
end
end
modelsol.daily_infected = daily_cases
end
Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections)
@unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
@unpack π_base, η,γ, κ, ω, ρ_y,ρ_m,ρ_o, ω_en,ρ_en,γ,β = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
app_user_pointer = 0
for i in 1:nodes
......@@ -110,8 +109,9 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
end
end
end
ρ = @SVector [ρ_y,ρ_m,ρ_o]
vac_payoff += π_base + dot(ρ,soc_nbrs_vac) + total_infections*ω +
ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0)
ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac)/num_soc_nbrs),0)
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]]))* (η + dot(ρ_en,soc_nbrs_vac) + total_infections*ω_en)
......@@ -132,7 +132,7 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
end
end
modelsol.daily_vaccinators = count(==(true),u_vac)
modelsol.daily_vaccinators = count(==(true),u_vac) #could maybe make this more efficient
end
......
No preview for this file type
No preview for this file type
File deleted
No preview for this file type
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment