Commit 41ffd902 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

add univariate snensitivity plots, some output code, added the model size and...

add univariate snensitivity plots, some output code, added the model size and length to parameter tuple. New output is in /CovidAlertVaccinationModel/plots/univariate
parent 189c2fa4
......@@ -61,6 +61,11 @@ version = "0.3.30"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[Baselet]]
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
version = "0.1.1"
[[BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"]
git-tree-sha1 = "068fda9b756e41e6c75da7b771e6f89fa8a43d15"
......@@ -85,12 +90,6 @@ git-tree-sha1 = "e2f47f6d8337369411569fd45ae5753ca10394c6"
uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
version = "1.16.0+6"
[[CategoricalArrays]]
deps = ["DataAPI", "Future", "JSON", "Missings", "Printf", "Statistics", "StructTypes", "Unicode"]
git-tree-sha1 = "f713d583d10fc036252fd826feebc6c173c522a8"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.9.5"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358"
......@@ -171,10 +170,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.6.0"
[[DataFrames]]
deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "d50972453ef464ddcebdf489d11885468b7b83a3"
deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "56ff5833e5b755d2db654479993e949e73606b64"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "0.22.7"
version = "1.0.0"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
......@@ -218,9 +217,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "09b3d464f5cbeaf1a2a422afe20d82eff421a7ca"
git-tree-sha1 = "a837fdf80f333415b69684ba8e8ae6ba76de6aaa"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.24.17"
version = "0.24.18"
[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
......@@ -608,9 +607,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[LogExpFunctions]]
deps = ["DocStringExtensions"]
git-tree-sha1 = "9809b844f0ff853f0620e0cac7a712e1818671e5"
git-tree-sha1 = "49c5c32deda5999d15378b64ee10f2e87831ab25"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.1"
version = "0.2.2"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
......@@ -813,9 +812,9 @@ version = "1.2.1"
[[PrettyTables]]
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
git-tree-sha1 = "574a6b3ea95f04e8757c0280bb9c29f1a5e35138"
git-tree-sha1 = "a7162ad93a899333717481f448a235ffafeb5eba"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "0.11.1"
version = "1.0.0"
[[Printf]]
deps = ["Unicode"]
......@@ -1021,12 +1020,6 @@ git-tree-sha1 = "44b3afd37b17422a62aea25f04c1f7e09ce6b07f"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.5.1"
[[StructTypes]]
deps = ["Dates", "UUIDs"]
git-tree-sha1 = "ad4558dee74c5d26ab0d0324766b1a3ee6ae777a"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
version = "1.7.1"
[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
......@@ -1086,10 +1079,10 @@ uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
version = "1.5.3"
[[Transducers]]
deps = ["ArgCheck", "BangBang", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
git-tree-sha1 = "d0aa4681564aa1c68bb1e146dd181817f139697b"
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
git-tree-sha1 = "c277f1190f76f108cfdb89b9d5da87d9602e5593"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
version = "0.4.63"
version = "0.4.64"
[[URIs]]
git-tree-sha1 = "7855809b88d7b16e9b029afd17880930626f54a2"
......
using CovidAlertVaccinationModel
using OnlineStats
using ThreadsX
using Plots
const samples = 5
##Univariate tests
const len = 4 #number of points to evaluate
gr()
const default_parameters = (
sim_length = 400,
num_households = 5000,
I_0_fraction = 0.001,
base_transmission_probability = 0.005,
recovery_rate = 0.1,
immunization_loss_prob = 0.0055, #mean time of 6 months
π_base = -2.0,
η = 0.0,
κ = 0.0,
ω = 0.0,
ρ = [0.0,0.0,0.0],
ω_en = 0.02,
ρ_en = [0.0,0.0,0.0],
γ = 0.0,
β = 5.0,
notification_parameter = 0.1,
vaccinator_prob = 0.2,
app_user_fraction = 0.4,
)
const univarate_test_list = (
(:I_0_fraction, range(0.0, 0.05; length = len)),
(:base_transmission_probability, range(0.001, 0.02; length = len)),
(:recovery_rate, range(0.05, 0.5; length = len)),
(:immunization_loss_prob, range(0.00, 0.05; length = len)),
(:π_base, range(-0.1, 0.1; length = len)),
(:η, range(0.0, 1.0; length = len)),
(:κ, range(0.0, 0.1; length = len)),
(:ω, range(0.0, 0.025; length = len)),
(:ω_en, range(0.0, 0.5; length = len)),
(:γ, range(0.0, 0.5; length = len)),
(:notification_parameter, range(0.1, 1.0; length = len)),
(:app_user_fraction, range(0.05, 0.5; length = len)),
)
const univariate_path = "CovidAlertVaccinationModel/plots/univariate/"
function univarate_test(variable, variable_range)
parameter_range_list = [merge(default_parameters,NamedTuple{(variable,)}((value,))) for value in variable_range]
solve_fn(p) = mean_solve(samples, p,DebugRecorder)
univariate_outlist = ThreadsX.map(solve_fn, parameter_range_list)
p = plot_model(variable,parameter_range_list,univariate_outlist)
return p
end
if !ispath(univariate_path)
mkdir(univariate_path)
end
function univariate_simulations()
plt_list = ThreadsX.map(univarate_test_list) do ur
out = univarate_test(ur...)
display("done $(ur[1])")
return out
end
for ((varname,_),p) in zip(univarate_test_list,plt_list)
savefig(p,"$univariate_path/$varname.pdf")
end
end
univariate_simulations()
......@@ -5,25 +5,27 @@ const age_bins = [(0.0, 25.0),(25.0,65.0),(65.0,Inf)]
const household_data = read_household_data()
default(dpi = 300)
default(framestyle = :box)
import LightGraphs.neighbors
export DebugRecorder,mean_solve,plot_model, get_parameters
"""
bench()
Runs the model with 5k households for 500 timesteps.
Resets the RNG state, and runs the model with default parameters.
"""
function bench()
Random.seed!(RNG,1)
steps = 50
model_sol = ModelSolution(steps,get_parameters(),5000)
recording = DebugRecorder(steps)
output = solve!(model_sol,recording )
p = get_parameters()
model_sol = ModelSolution(p.sim_length,p,p.num_households)
recording = DebugRecorder(p.sim_length)
output = solve!(model_sol,recording)
end
"""
The only function we export. I just put whatever I am currently working on into this function to make the model easier to test from the REPL.
Run the model with given parameter tuple and output recorder. See `get_parameters` for list of parameters. See `output.jl` for the list of recorders. Currently just 'DebugRecorder`.
"""
function abm()
b1 = @benchmark bench() seconds = 20
display(b1)
println("done")
function abm(parameters, recorder)
model_sol = ModelSolution(parameters.sim_length,parameters,5000)
output = solve!(model_sol,recorder )
end
......@@ -187,8 +187,6 @@ end
"""
Completely remake all the graphs in `time_dep_mixing_graph.resampled_graphs`.
This is a huge bottleneck for the performance of the model but I doubt it can be improved much more.
"""
function remake!(time_dep_mixing_graph,demographic_index_vectors,mixing_matrix)
for weighted_graph in time_dep_mixing_graph.resampled_graphs
......
function get_parameters()
params = (
I_0_fraction = 0.01,
base_transmission_probability = 0.5,
sim_length = 500,
num_households = 5000,
I_0_fraction = 0.005,
base_transmission_probability = 0.01,
recovery_rate = 0.1,
immunization_loss_prob = 0.5,
π_base = -0.05,
immunization_loss_prob = 0.0055,
π_base = -1.0,
η = 0.0,
κ = 0.0,
ω = 0.0,
ρ = [0.0,0.0,0.0],
ω_en = 0.0,
ω_en = 0.01,
ρ_en = [0.0,0.0,0.0],
γ = 0.0,
β = 10.0,
notification_parameter = 0.0,
β = 5.0,
notification_parameter = 0.2,
vaccinator_prob = 0.2,
app_user_fraction = 0.5,
app_user_fraction = 0.4,
)
return params
end
......
#needlessly overwrought output interface
using LabelledArrays
using OnlineStats
using Plots
import OnlineStats.fit!
"""
AbstractRecorder is the type that all "Recorder" types must extend for `fit!` (and possibly other stuff) to dispatch correctly.
abstract type AbstractRecorder end
struct DebugRecorder <: AbstractRecorder
recorded_status_totals::Array{Int,2}
Total_S::Vector{Int}
Total_I::Vector{Int}
Total_R::Vector{Int}
Total_V::Vector{Int}
Total_Vaccinator::Vector{Int}
function DebugRecorder(sim_length)
return new(
Matrix{Int}(undef,AgentStatus.size,sim_length),
Vector{Int}(undef,sim_length),
Vector{Int}(undef,sim_length),
Vector{Int}(undef,sim_length),
Vector{Int}(undef,sim_length),
Vector{Int}(undef,sim_length),
)
end
The point of using different recorder types is that we only want to save the output that we need, and that output will vary depending on the specific experiment.
AbstractRecorders can be parameterized to have any element type, so when they are parameterized with T<:OnlineStat, we can use `fit!` to collect statistics about a bunch of Recorders into a single Recorder.
"""
abstract type AbstractRecorder{ElType} end
"""
DebugRecorder should store everything we might want to know about the model output.
"""
struct DebugRecorder{ElType,ArrType1<:AbstractArray{ElType},ArrType2<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
recorded_status_totals::ArrType1
total_vaccinators::ArrType2
mean_time_since_last_notification::ArrType2
end
"""
Initialize an empty DebugRecorder. We use a labelledarray for the state vector, so the individal timeseries can be accessed by name.
"""
function DebugRecorder(sim_length)
state_totals = @LArray Array{Int}(undef,4,sim_length) (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = Vector{Int}(undef,sim_length)
mean_time_since_last_notification = Vector{Int}(undef,sim_length)
return DebugRecorder(
state_totals,
total_vaccinators,
mean_time_since_last_notification,
)
end
function record!(t,modelsol, recorder::DebugRecorder)
recorder.Total_S[t] = count(==(Susceptible),modelsol.u_inf)
recorder.Total_I[t] = count(==(Infected),modelsol.u_inf)
recorder.Total_R[t] = count(==(Recovered),modelsol.u_inf)
recorder.Total_V[t] = count(==(Immunized),modelsol.u_inf)
recorder.Total_Vaccinator[t] = count(==(true),modelsol.u_vac)
"""
Initialize a DebugRecorder filled with (copies) of val. I should find a nicer way to combine both constructors.
"""
function DebugRecorder(val::T, sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
return DebugRecorder(
state_totals,
total_vaccinators,
mean_time_since_last_notification
)
end
function record!(t,modelsol, recorder::DebugRecorder)
recorder.total_vaccinators[t] = count(==(true),modelsol.u_vac)
recorder.recorded_status_totals[:,t] .= modelsol.status_totals
mean_alert_time = mean(t .- modelsol.time_of_last_alert)
recorder.mean_time_since_last_notification[t] = round(Int,mean_alert_time)
end
function record!(t,modelsol, recorder::Nothing)
#do nothing
end
function mean_solve(samples,parameter_tuple,recorder)
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
output_recorder = recorder(parameter_tuple.sim_length)
for i in 1:samples
abm(parameter_tuple,output_recorder)
fit!(stat_recorder,output_recorder)
end
return stat_recorder
end
function OnlineStats.fit!(stat_type::R,pt) where {T,OS<:OnlineStat{T}, R<:AbstractRecorder{OS}}
for field in fieldnames(R)
stat_field = getfield(stat_type,field)
sample_field = getfield(pt,field)
for (k,(stat_entry,sample_entry)) in enumerate(zip(stat_field,sample_field))
fit!(stat_entry,sample_entry)
end
end
end
using Printf
const ts_colors = cgrad(:PuBu_9)
function plot_model(varname,univariate_series, output_list::Vector{T}) where T<:DebugRecorder
plts = [plot(),plot(),plot(),plot(),plot(),plot()]
for (i,(p,data)) in enumerate(zip(univariate_series, output_list))
# display(p[varname])
p_val = @sprintf "%.3f" (p[varname])
# display(p_val)
plot!(plts[1], mean.(data.recorded_status_totals.S); ribbon = std.(data.recorded_status_totals.S),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = "Susceptible over time"
)
plot!(plts[2], mean.(data.recorded_status_totals.I); ribbon = std.(data.recorded_status_totals.I),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = "Infected over time"
)
plot!(plts[3], mean.(data.recorded_status_totals.R); ribbon = std.(data.recorded_status_totals.R),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = "Recovered over time"
)
plot!(plts[4], mean.(data.recorded_status_totals.V); ribbon = std.(data.recorded_status_totals.V),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = "Vaccinated over time"
)
plot!(plts[5], mean.(data.total_vaccinators); ribbon = std.(data.total_vaccinators),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = "No. vaccinators over time"
)
plot!(plts[6], mean.(data.mean_time_since_last_notification); ribbon = std.(data.mean_time_since_last_notification),
label = "$varname = $(p_val)", line_z = i, color=:blues,
ylabel = "Days",
colorbar = false,
title = "Mean time since last notification"
)
end
return plot(plts...;layout = (6,1),size=(800,2500),leftmargin = 5Plots.mm)
end
function plot_edges!(plt,g,network_pts, edge_alpha, edge_color)
xpts = [i[1] for i in network_pts]
ypts = [i[2] for i in network_pts]
for edge in edges(g)
e = [src(edge), dst(edge)]
plot!(plt,xpts[e],ypts[e], color = edge_color, alpha = edge_alpha)
end
end
function plot_nodes!(plt,network_pts,u_t)
for (i,var) in enumerate(instances(AgentStatus))
indices_of_var = findall(x -> x == var, u_t)
xpts = [i[1] for i in network_pts]
ypts = [i[2] for i in network_pts]
scatter!(plt,xpts[indices_of_var],ypts[indices_of_var]; color = color_palette[i], markersize = 3)
end
end
function make_layout(g)
a = LightGraphs.adjacency_matrix(g)
# network_pts = Stress.layout(a,Stress.Point2f0) # generate 2D layout
network_pts = SFDP.layout(a,SFDP.Point2f0,tol=0.1,C=1,K=1,iterations=10)
return network_pts
end
function plot_model_spatial_gif(static_graph,annealed_graph_list::AbstractVector,solution)
plt = plot()
network_pts = make_layout(static_graph)
anim = Animation()
for (t,(u_t,g_t)) in enumerate(zip(solution,annealed_graph_list))
frame_t = plot()
plot_edges!(frame_t,g_t,network_pts,0.2,:blue) ;
plot_edges!(frame_t,static_graph,network_pts,0.8,:red)
plot_nodes!(frame_t,network_pts,u_t)
plot!(frame_t;title = "t = $t", dpi = 300,xticks = nothing,yticks = nothing, legend = false)
frame(anim,frame_t)
end
gif(anim, joinpath(PACKAGE_FOLDER,"plots", "graphplot.gif"), fps = 1.0)
end
using LabelledArrays
function aggregate_timeseries(sol)
agent_statuses = map(Symbol,[AgentStatus(i) for i in 1:AgentStatus.size])
return @LArray [count(x -> AgentStatus(i) == x, u_t) for u_t in sol, i in 1:AgentStatus.size] namedtuple(agent_statuses,ntuple(i -> (:,i),length(agent_statuses)))
end
function plot_aggregate_timeseries(sol,network)
plot(sol,labels = reshape([string(k) for k in keys(keys(sol))],(1,:)))
end
\ No newline at end of file
......@@ -6,6 +6,10 @@ function contact_weight(p, contact_time)
return 1 - (1-p)^contact_time
end
function Φ(payoff,β)
return 1 / (exp(-1*β*payoff))
end
Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
......@@ -14,7 +18,7 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_times,app_user = modelsol
for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network.graph_list[t]
for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in inf_network.graph_list[t]
for j in 2:14
covid_alert_times[j-1,i] = covid_alert_times[j,i] #shift them all back
end
......@@ -143,16 +147,15 @@ end
function solve!(modelsol,recording)
function solve!(modelsol,recordings...)
for t in 1:modelsol.sim_length
#advance agent states based on the new network
record!(t,modelsol,recording)
for recording in recordings
record!(t,modelsol,recording)
end
agents_step!(t,modelsol)
end
end
function Φ(payoff,β)
return 1 / (exp(-1*β*payoff))
end
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment