Commit ae9a2aa2 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

output simplify wip

parent ca29f6ee
......@@ -60,9 +60,9 @@ version = "0.1.16"
[[BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
git-tree-sha1 = "d53b1eaefd48e233545d21f5b764c8ee54df4a09"
git-tree-sha1 = "e239020994123f08905052b9603b4ca14f8c5807"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
version = "0.3.30"
version = "0.3.31"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
......@@ -98,9 +98,9 @@ version = "1.16.0+6"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
git-tree-sha1 = "ea05cadc30c15f9185b61ea418b9d47d53b55bc2"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.4"
version = "0.10.8"
[[ColorSchemes]]
deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random", "StaticArrays"]
......@@ -122,9 +122,9 @@ version = "0.12.8"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"
version = "3.31.0"
[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
......@@ -149,9 +149,9 @@ version = "0.1.2"
[[ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "1dc43957fb9a1574fa1b7a449e101bd1fd3a9fb7"
git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.2.1"
version = "1.3.0"
[[Contour]]
deps = ["StaticArrays"]
......@@ -235,9 +235,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "013020ec9a5cdf1dd454eba3704dbffa69d3047e"
git-tree-sha1 = "62e1ac52e9adf4234285cd88c94954924aa3f9ef"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.3"
version = "0.25.5"
[[DocStringExtensions]]
deps = ["LibGit2"]
......@@ -292,9 +292,9 @@ version = "4.3.1+4"
[[FFTW]]
deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"]
git-tree-sha1 = "ae8de3350af3026008be4ba23e1e905ab2011d20"
git-tree-sha1 = "f985af3b9f4e278b1d24434cbb546d6092fca661"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.4.2"
version = "1.4.3"
[[FFTW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -409,9 +409,9 @@ version = "2.0.0"
[[Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "aac91e34ef4c166e0857e3d6052a3467e5732ceb"
git-tree-sha1 = "3395d4d4aeb3c9d31f5929d32760d8baeee88aaf"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.4.1+0"
version = "2.5.0+0"
[[IfElse]]
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
......@@ -621,9 +621,9 @@ version = "1.42.0+0"
[[Libiconv_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "8d22e127ea9a0917bc98ebd3755c8bd31989381e"
git-tree-sha1 = "42b62845d70a619f063a7da093d995ec8e15e778"
uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531"
version = "1.16.1+0"
version = "1.16.1+1"
[[Libmount_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -670,15 +670,15 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[LoggingExtras]]
deps = ["Dates", "Logging"]
git-tree-sha1 = "59b45fd91b743dff047313bb7af0f84167aef80d"
git-tree-sha1 = "dfeda1c1130990428720de0024d4516b1902ce98"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "0.4.6"
version = "0.4.7"
[[LoopVectorization]]
deps = ["ArrayInterface", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Polyester", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "b16dde45ba9e2506358d4d7fe13f746330e8e622"
git-tree-sha1 = "0d353c52a418e1d97b7a39d192331157f71b2389"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.37"
version = "0.12.42"
[[MKL_jll]]
deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
......@@ -743,9 +743,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[MutableArithmetics]]
deps = ["LinearAlgebra", "SparseArrays", "Test"]
git-tree-sha1 = "ad9b2bce6021631e0e20706d361972343a03e642"
git-tree-sha1 = "3927848ccebcc165952dc0d9ac9aa274a87bfe01"
uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
version = "0.2.19"
version = "0.2.20"
[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
......@@ -774,9 +774,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OffsetArrays]]
deps = ["Adapt"]
git-tree-sha1 = "1381a7142eefd4cd12f052a4d2d790fe21bd1d55"
git-tree-sha1 = "e436bb81d2ce4f01fb02374c4410e5a9229c85f9"
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.9.2"
version = "1.10.0"
[[Ogg_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1028,9 +1028,9 @@ version = "1.1.0"
[[SentinelArrays]]
deps = ["Dates", "Random"]
git-tree-sha1 = "bc967c221ccdb0b85511709bda96ee489396f544"
git-tree-sha1 = "ffae887d0f0222a19c406a11c3831776d1383e3d"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
version = "1.3.2"
version = "1.3.3"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
......@@ -1090,9 +1090,9 @@ version = "0.2.5"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
git-tree-sha1 = "745914ebcd610da69f3cb6bf76cb7bb83dcb8c9a"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.2"
version = "1.2.4"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
......@@ -1216,9 +1216,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VectorizationBase]]
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"]
git-tree-sha1 = "7c8974c7de377a2dc67e778017c78f96fc8f0fc6"
git-tree-sha1 = "d3d40e06daf09599a5f6524d1ae4224c573639d0"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.20.16"
version = "0.20.18"
[[VersionParsing]]
git-tree-sha1 = "80229be1f670524750d905f8fc8148e5a8c4537f"
......@@ -1227,9 +1227,9 @@ version = "1.2.0"
[[Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
git-tree-sha1 = "dc643a9b774da1c2781413fd7b6dcd2c56bb8056"
git-tree-sha1 = "3e61f0b86f90dacb0bc0e73a0c5a83f6a8636e23"
uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89"
version = "1.17.0+4"
version = "1.19.0+0"
[[Wayland_protocols_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll"]
......
......@@ -14,9 +14,8 @@ Runs the model with default parameters.
"""
function bench()
p = get_parameters()
recording = DebugRecorder(0.0,p.sim_length)
abm(p,recording)
return recording
out = abm(p)
return out
end
......@@ -34,8 +33,8 @@ end
"""
Run the model with given parameter tuple and output recorder. See `get_parameters` for list of parameters. See `output.jl` for the list of recorders. Currently just 'DebugRecorder`.
"""
function abm(parameters, recorder)
function abm(parameters)
model_sol = ModelSolution(parameters.sim_length,parameters,parameters.num_households)
output = solve!(model_sol,recorder )
solve!(model_sol)
return model_sol
end
......@@ -117,7 +117,7 @@ end
"""
Completely remake all the graphs in `time_dep_mixing_graph.resampled_graphs`.
"""
function remake!(t,time_dep_mixing_graph,index_vectors,demographics)
function remake_all!(t,time_dep_mixing_graph,index_vectors,demographics)
for wg in time_dep_mixing_graph.remade_graphs
remake!(wg,index_vectors)
end
......
......@@ -56,7 +56,7 @@ function app_users(demographics,app_usage_prob)
end
mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist,RecorderType}
sim_length::Int
nodes::Int
params::T
......@@ -74,17 +74,10 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
app_user_list::Vector{Int}
app_user_index::Vector{Int}
status_totals::Vector{Int}
daily_vaccinators::Int
daily_cases_by_age::Vector{Int}
daily_unvac_cases_by_age::Vector{Int}
daily_immunized_by_age::Vector{Int}
daily_total_notifications::Int
daily_total_notified_agents::Int
avg_weighted_degree_of_vaccinators::Variance{Float64,Float64,EqualWeight}
avg_weighted_degree::Variance{Float64,Float64,EqualWeight}
ws_matrix_tuple::WSMixingDist
rest_matrix_tuple::RestMixingDist
immunization_countdown::Vector{Int}
output_data::RecorderType
@views function ModelSolution(sim_length,params::T,num_households) where T
demographics,base_network,index_vectors = generate_population(num_households)
nodes = length(demographics)
......@@ -112,8 +105,10 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
status_totals = [count(==(AgentStatus(i)), u_0_inf) for i in 1:AgentStatus.size]
immunization_countdown = fill(-1, nodes) #immunization countdown is negative if not counting down
return new{T,typeof(infected_mixing_graph),typeof(soc_mixing_graph),typeof(ws_matrix_tuple),typeof(rest_matrix_tuple)}(
output_data = Recorder(0,sim_length)
return new{T,typeof(infected_mixing_graph),
typeof(soc_mixing_graph),typeof(ws_matrix_tuple),
typeof(rest_matrix_tuple), typeof(output_data)}(
sim_length,
nodes,
params,
......@@ -131,17 +126,10 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
app_user_list,
app_user_index,
status_totals,
0,
[0,0,0],
[0,0,0],
[0,0,0],
0,
0,
Variance(),
Variance(),
ws_matrix_tuple,
rest_matrix_tuple,
immunization_countdown
immunization_countdown,
output_data
)
end
end
......@@ -2,20 +2,10 @@
using LabelledArrays
using Plots
import OnlineStats.fit!
"""
AbstractRecorder is the type that all "Recorder" types must extend for `fit!` (and possibly other stuff) to dispatch correctly.
The point of using different recorder types is that we only want to save the output that we need, and that output will vary depending on the specific experiment.
AbstractRecorders can be parameterized to have any element type, so when they are parameterized with T<:OnlineStat, we can use `fit!` to collect statistics about a bunch of Recorders into a single Recorder.
"""
abstract type AbstractRecorder{ElType} end
"""
DebugRecorder should store everything we might want to know about the model output.
Recorder should store everything we might want to know about the model output.
"""
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
struct Recorder{ElType,ArrT1,ArrT2,ArrT3,StatAccumulator}
recorded_status_totals::ArrT1
daily_cases_by_age::ArrT3
total_vaccinators::ArrT2
......@@ -28,77 +18,58 @@ struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{E
mean_time_since_last_notification::ArrT2
daily_immunized_by_age::ArrT3
daily_unvac_cases_by_age::ArrT3
end
avg_weighted_degree_of_vaccinators::Vector{StatAccumulator}
avg_weighted_degree::Vector{StatAccumulator}
"""
HeatmapRecorder, for heatmaps!
"""
struct HeatmapRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{Int}} <: AbstractRecorder{ElType}
daily_cases_by_age::ArrT2
final_size_by_age::ArrT1
avg_weighted_degree_of_vaccinators::ElType
avg_weighted_degree::ElType
function Recorder(val::T,sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
daily_total_notifications = [copy(val) for j in 1:sim_length]
daily_total_notified_agents = [copy(val) for j in 1:sim_length]
mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
totals_ymo = [copy(val) for i in 1:3, j in 1:sim_length]
daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
unvac_final_size_by_age = [copy(val) for i in 1:3]
total_postinf_vaccination = [copy(val) for i in 1:3]
total_preinf_vaccination = [copy(val) for i in 1:3]
final_size_by_age = [copy(val) for i in 1:3]
avg_weighted_degree_of_vaccinators = [Variance() for _ in 1:3]
avg_weighted_degree = [Variance() for _ in 1:3]
return new{T,typeof(state_totals),typeof(total_vaccinators),typeof(daily_immunized_by_age),eltype(avg_weighted_degree)}(
state_totals,
daily_cases_by_age,
total_vaccinators,
daily_total_notifications,
daily_total_notified_agents,
unvac_final_size_by_age,
total_postinf_vaccination,
total_preinf_vaccination,
final_size_by_age,
mean_time_since_last_notification,
daily_immunized_by_age,
daily_unvac_by_age,
avg_weighted_degree_of_vaccinators,
avg_weighted_degree
)
end
end
"""
Initialize a DebugRecorder filled with (copies) of val.
Initialize a Recorder filled with (copies) of val.
"""
function DebugRecorder(val::T,sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
daily_total_notifications = [copy(val) for j in 1:sim_length]
daily_total_notified_agents = [copy(val) for j in 1:sim_length]
mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
totals_ymo = [copy(val) for i in 1:3, j in 1:sim_length]
daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
unvac_final_size_by_age = [copy(val) for i in 1:3]
total_postinf_vaccination = [copy(val) for i in 1:3]
total_preinf_vaccination = [copy(val) for i in 1:3]
final_size_by_age = [copy(val) for i in 1:3]
return DebugRecorder(
state_totals,
daily_cases_by_age,
total_vaccinators,
daily_total_notifications,
daily_total_notified_agents,
unvac_final_size_by_age,
total_postinf_vaccination,
total_preinf_vaccination,
final_size_by_age,
mean_time_since_last_notification,
daily_immunized_by_age,
daily_unvac_by_age
)
end
function HeatmapRecorder(sim_length) where T<:OnlineStat
daily_cases_by_age = @LArray zeros(Int,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
final_size_by_age = [Variance() for i in 1:3]
average_deg_of_vac = Variance()
average_deg = Variance()
return HeatmapRecorder(
daily_cases_by_age,
final_size_by_age,
average_deg_of_vac,
average_deg,
)
end
function record!(t,modelsol, recorder::DebugRecorder)
recorder.total_vaccinators[t] = modelsol.daily_vaccinators
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
function record!(t,modelsol)
recorder = modelsol.output_data
recorder.recorded_status_totals[:,t] .= modelsol.status_totals
recorder.daily_total_notifications[t] = modelsol.daily_total_notifications
recorder.daily_total_notified_agents[t] = modelsol.daily_total_notified_agents
alerts = filter(>(0),modelsol.time_of_last_alert)
if !isempty(alerts)
mean_alert_time = mean(t .- alerts)
......@@ -106,9 +77,6 @@ function record!(t,modelsol, recorder::DebugRecorder)
else
recorder.mean_time_since_last_notification[t] = 0
end
# display(modelsol.daily_immunizations_by_age)
recorder.daily_immunized_by_age[:,t] .= modelsol.daily_immunized_by_age
recorder.daily_unvac_cases_by_age[:,t] .= modelsol.daily_unvac_cases_by_age
if modelsol.sim_length == t
for age in 1:3
recorder.unvac_final_size_by_age[age] = sum(recorder.daily_unvac_cases_by_age[age,:])
......@@ -118,20 +86,6 @@ function record!(t,modelsol, recorder::DebugRecorder)
end
end
end
function record!(t,modelsol, recorder::HeatmapRecorder)
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
if modelsol.sim_length == t
for age in 1:3
fit!(recorder.final_size_by_age[age], sum(recorder.daily_cases_by_age[age,:]))
end
merge!(recorder.avg_weighted_degree_of_vaccinators,modelsol.avg_weighted_degree_of_vaccinators)
merge!(recorder.avg_weighted_degree,modelsol.avg_weighted_degree)
end
end
function record!(t,modelsol, recorder::Nothing)
#do nothing
end
function mean_solve(samples,parameter_tuple,recorder;progmeter = nothing)
......@@ -151,23 +105,7 @@ function mean_solve(samples,parameter_tuple,recorder;progmeter = nothing)
return stat_recorder,avg_populations
end
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder};progmeter = nothing)
stat_recorder = recorder(parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
sol_list = map(1:samples) do _
output_recorder = recorder(parameter_tuple.sim_length)
sol = abm(parameter_tuple,output_recorder)
isnothing(progmeter) || next!(progmeter)
return output_recorder,length.(sol.index_vectors)
end
for (output_recorder,pop) in sol_list
avg_populations .+= pop
fit!(stat_recorder,output_recorder)
end
avg_populations ./= samples
return stat_recorder,avg_populations
end
function OnlineStats.fit!(accum::R,pt::R2) where {R<:AbstractRecorder,R2<:AbstractRecorder}
function accumulate_model(accum,pt)
for field in fieldnames(R)
stat_field = getfield(accum,field)
sample_field = getfield(pt,field)
......
......@@ -5,7 +5,7 @@ using Printf
const ts_colors = cgrad(:PuBu_9)
function plot_model(varname,univariate_series, output_list::Vector{T},infection_begin,vac_begin) where T<:DebugRecorder
function plot_model(varname,univariate_series, output_list::Vector{T},infection_begin,vac_begin) where T
sim_length = length(output_list[1].recorded_status_totals.S)
ts_list(data) = [
......
......@@ -13,10 +13,8 @@ end
Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # Base.@propagate_inbounds
@unpack notification_parameter,notification_threshold = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_notifications,app_user = modelsol
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_notifications,app_user, output_data = modelsol
modelsol.daily_total_notifications = 0
for (i,node) in enumerate(app_user_index)
for j in 2:14
covid_alert_notifications[j-1,i] = covid_alert_notifications[j,i] #shift them all back
......@@ -37,7 +35,7 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # B
covid_alert_notifications[end,i] = 0
end
if sum(covid_alert_notifications[:,i])>=notification_threshold
modelsol.daily_total_notifications += 1
output_data.daily_total_notifications[t] += 1
time_of_last_alert[i] = t
end
end
......@@ -45,11 +43,7 @@ end
Base.@propagate_inbounds @views function update_infection_state!(t,modelsol; record_degrees = false)
@unpack β_y,β_m,β_o,α_y,α_m,α_o,recovery_rate,immunizing,immunization_begin_day = modelsol.params
@unpack u_inf,u_vac,u_next_inf,demographics,inf_network,status_totals, immunization_countdown = modelsol
modelsol.daily_cases_by_age .= 0
modelsol.daily_immunized_by_age .= 0
modelsol.daily_unvac_cases_by_age .= 0
@unpack u_inf,u_vac,u_next_inf,demographics,inf_network,status_totals, immunization_countdown, output_data = modelsol
function agent_transition!(node, from::AgentStatus,to::AgentStatus)
immunization_countdown[node] = -1
......@@ -74,11 +68,11 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol; rec
if rand(Random.default_rng(Threads.threadid())) < contact_weight(β_vec[Int(agent_demo)],get_weight(mixing_graph,GraphEdge(i,j)))
if agent_status == Immunized && rand(Random.default_rng(Threads.threadid())) < 1- α_vec[Int(agent_demo)]
agent_transition!(i, Immunized,Infected)
modelsol.daily_cases_by_age[Int(agent_demo)]+=1
output_data.daily_cases_by_age[Int(agent_demo),t]+=1
elseif agent_status == Susceptible
modelsol.daily_cases_by_age[Int(agent_demo)]+=1
modelsol.daily_unvac_cases_by_age[Int(agent_demo)]+=1
output_data.daily_cases_by_age[Int(agent_demo),t]+=1
output_data.daily_unvac_cases_by_age[Int(agent_demo),t]+=1
agent_transition!(i, Susceptible,Infected)
end
end
......@@ -92,15 +86,15 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol; rec
end
weighted_degree_of_i::Int = record_degrees ? weighted_degree(t,i,inf_network) : 0
if immunization_countdown[i] == 0
modelsol.daily_immunized_by_age[Int(agent_demo)] += 1
fit!(modelsol.avg_weighted_degree_of_vaccinators,weighted_degree_of_i)
output_data.daily_immunized_by_age[Int(agent_demo),t] += 1
fit!(output_data.avg_weighted_degree_of_vaccinators[Int(agent_demo)],weighted_degree_of_i)
agent_transition!(i, Susceptible,Immunized)
elseif immunization_countdown[i]>0
fit!(modelsol.avg_weighted_degree,weighted_degree_of_i)
fit!(output_data.avg_weighted_degree[Int(agent_demo)],weighted_degree_of_i)
immunization_countdown[i] -= 1
else
fit!(modelsol.avg_weighted_degree,weighted_degree_of_i)
fit!(output_data.avg_weighted_degree[Int(agent_demo)],weighted_degree_of_i)
end
end
end
......@@ -108,10 +102,8 @@ end
Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections)
@unpack infection_introduction_day, π_base_y,π_base_m,π_base_o, η,Γ,ζ, ω, ω_en,ξ = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list,output_data = modelsol
modelsol.daily_total_notified_agents = 0
# display(time_of_last_alert[1:20])
for i in 1:nodes
π_base = t<infection_introduction_day ?
(π_base_y,π_base_m,π_base_o) :
......@@ -121,7 +113,7 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
random_neighbour = sample(Random.default_rng(Threads.threadid()), neighbors(random_soc_network.g,i))
app_vac_payoff = 0.0
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
modelsol.daily_total_notified_agents += 1
output_data.daily_total_notified_agents[t] += 1
app_vac_payoff = Γ^((t - time_of_last_alert[app_user_list[i]])) * (η + total_infections*ω_en)
# display(t - time_of_last_alert[app_user_list[i]])
end
......@@ -141,7 +133,7 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
end
end
end
modelsol.daily_vaccinators = count(==(true),u_vac)
output_data.total_vaccinators[t] = count(==(true),u_vac)
end
function weighted_degree(t,node,network::TimeDepMixingGraph)
......@@ -170,12 +162,12 @@ function sample_initial_nodes(nodes,graphs,I_0_fraction)
end
function solve!(modelsol,recording::T) where T
function solve!(modelsol)
init_indices = sample_initial_nodes(modelsol.nodes, modelsol.inf_network.graph_list[begin], modelsol.params.I_0_fraction)
for t in 1:modelsol.sim_length
#this also resamples the soc network weights since they point to the same objects, but those are never used
if t>1
remake!(t,modelsol.inf_network,modelsol.index_vectors,modelsol.demographics)
remake_all!(t,modelsol.inf_network,modelsol.index_vectors,modelsol.demographics)
end
if t>modelsol.params.infection_introduction_day
......@@ -190,12 +182,12 @@ function solve!(modelsol,recording::T) where T
update_alert_durations!(t,modelsol