Commit 9774ac9f authored by Peter Jentsch's avatar Peter Jentsch
Browse files

heatmaps

parent 61883b87
......@@ -52,6 +52,12 @@ git-tree-sha1 = "a4d07a1c313392a77042855df46c5f534076fab9"
uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950"
version = "1.0.0"
[[AxisKeys]]
deps = ["AbstractFFTs", "CovarianceEstimation", "IntervalSets", "InvertedIndices", "LazyStack", "LinearAlgebra", "NamedDims", "OffsetArrays", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "118c5c2c9f509f503efa05fa2385936bc2cad78d"
uuid = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
version = "0.1.16"
[[BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
git-tree-sha1 = "d53b1eaefd48e233545d21f5b764c8ee54df4a09"
......@@ -153,6 +159,12 @@ git-tree-sha1 = "9f02045d934dc030edad45944ea80dbd1f0ebea7"
uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
version = "0.5.7"
[[CovarianceEstimation]]
deps = ["LinearAlgebra", "Statistics", "StatsBase"]
git-tree-sha1 = "bc3930158d2be029e90b7c40d1371c4f54fa04db"
uuid = "587fd27a-f159-11e8-2dae-1979310e6154"
version = "0.2.6"
[[Crayons]]
git-tree-sha1 = "3f71217b538d7aaee0b69ab47d9b7724ca8afa0d"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
......@@ -243,6 +255,12 @@ git-tree-sha1 = "92d8f9f208637e8d2d28c664051a00569c01493d"
uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5"
version = "2.1.5+1"
[[EllipsisNotation]]
deps = ["ArrayInterface"]
git-tree-sha1 = "8041575f021cba5a099a456b4163c9a08b566a02"
uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
version = "1.1.0"
[[Expat_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b3bfd02e98aedfa5cf885665493c5598c350cd2f"
......@@ -443,6 +461,12 @@ git-tree-sha1 = "1e0e51692a3a77f1eeb51bf741bdd0439ed210e7"
uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
version = "0.13.2"
[[IntervalSets]]
deps = ["Dates", "EllipsisNotation", "Statistics"]
git-tree-sha1 = "3cc368af3f110a767ac786560045dceddfc16758"
uuid = "8197267c-284f-5f27-9208-e0e47529a953"
version = "0.5.3"
[[Intervals]]
deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"]
git-tree-sha1 = "323a38ed1952d30586d0fe03412cde9399d3618b"
......@@ -534,6 +558,12 @@ version = "0.15.1"
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
[[LazyStack]]
deps = ["LinearAlgebra", "NamedDims", "OffsetArrays", "Test", "ZygoteRules"]
git-tree-sha1 = "a8bf67afad3f1ee59d367267adb7c44ccac7fdee"
uuid = "1fad7336-0346-5a1a-a56f-a06ba010965b"
version = "0.0.7"
[[LeftChildRightSiblingTrees]]
deps = ["AbstractTrees"]
git-tree-sha1 = "71be1eb5ad19cb4f61fa8c73395c0338fd092ae0"
......@@ -722,6 +752,12 @@ git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"
[[NamedDims]]
deps = ["LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "ec71dd922fd8008f29741b1358da9254833ef6ca"
uuid = "356022a1-0364-5f58-8944-0da4b18d706f"
version = "0.2.29"
[[NamedTupleTools]]
git-tree-sha1 = "63831dcea5e11db1c0925efe5ef5fc01d528c522"
uuid = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
......
......@@ -4,6 +4,7 @@ authors = ["pjentsch <pjentsch@uwaterloo.ca> and contributors"]
version = "0.1.0"
[deps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CurveFit = "5a033b19-8c74-5913-a970-47c3779ef25c"
......@@ -35,6 +36,7 @@ Pandas = "eadc2687-ae89-51f9-a5d9-86b5a6373a9c"
Pipe = "b98c9c47-44ae-5843-9183-064241ee97a0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
......
using CovidAlertVaccinationModel
using CovidAlertVaccinationModel:get_parameters,get_app_parameters
using OnlineStats
using ThreadsX
using Plots
const samples = 10
##Univariate tests
const len = 10 #number of points to evaluate
gr()
const univarate_test_list = (
# (:I_0_fraction, range(0.0, 0.05; length = len)),
# (:base_transmission_probability, range(0.0002, 0.002; length = len)),
# (:recovery_rate, range(0.1, 0.5; length = len)),
# (:immunization_loss_prob, range(0.00, 0.05; length = len)),
# (:π_base, range(-4.5, -3.5; length = len)),
(:η, range(0.0, 0.01; length = len)),
# (:κ, range(0.5, 1.5; length = len)),
# (:ω, range(0.0, 0.01; length = len)),
(:ω_en, range(0.0, 0.0005; length = len)),
# (:γ, range(0.0, 0.5; length = len)),
# (:ξ, range(1, 15; length = len)),
# (:notification_parameter, range(0.00, 0.05; length = len)),
# (:app_user_fraction, range(0.05, 0.25; length = len)),
(:notification_threshold, (1:len)),
# (:immunization_delay, [7,10,14,20]),
)
const univariate_path = "CovidAlertVaccinationModel/plots/univariate/"
function univarate_test(variable, variable_range)
default_parameters = get_app_parameters()
parameter_range_list = [merge(default_parameters,NamedTuple{(variable,)}((value,))) for value in variable_range]
solve_fn(p) = mean_solve(samples, p,DebugRecorder)[1]
univariate_outlist = ThreadsX.map(solve_fn, parameter_range_list)
p = plot_model(variable,parameter_range_list,univariate_outlist,default_parameters.infection_introduction_day,default_parameters.immunization_begin_day)
return p
end
if !ispath(univariate_path)
mkdir(univariate_path)
end
function univariate_simulations()
plt_list = ThreadsX.map(univarate_test_list) do ur
out = univarate_test(ur...)
display("done $(ur[1])")
return out
end
for ((varname,_),p) in zip(univarate_test_list,plt_list)
savefig(p,"$univariate_path/$varname.pdf")
end
end
univariate_simulations()
using CovidAlertVaccinationModel:plot_phase_planes
multivariate_simulations()
......@@ -6,7 +6,7 @@ const household_data = read_household_data()
default(dpi = 300)
default(framestyle = :box)
import LightGraphs.neighbors
export DebugRecorder,mean_solve,plot_model, get_parameters
export DebugRecorder,mean_solve,plot_model, get_parameters, HeatmapRecorder
"""
bench()
......@@ -14,7 +14,7 @@ Runs the model with default parameters.
"""
function bench()
p = get_parameters()
recording = DebugRecorder(p.sim_length)
recording = DebugRecorder(0.0,p.sim_length)
abm(p,recording)
return recording
end
......
......@@ -73,12 +73,13 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
app_user::Vector{Bool}
app_user_list::Vector{Int}
app_user_index::Vector{Int}
status_totals::Vector{Int}
daily_vaccinators::Int
daily_cases_by_age::Vector{Int}
daily_unvac_cases_by_age::Vector{Int}
daily_immunized_by_age::Vector{Int}
avg_weighted_degree_of_vaccinators::Variance{Float64,Float64,EqualWeight}
avg_weighted_degree::Variance{Float64,Float64,EqualWeight}
ws_matrix_tuple::WSMixingDist
rest_matrix_tuple::RestMixingDist
immunization_countdown::Vector{Int}
......@@ -132,6 +133,8 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
[0,0,0],
[0,0,0],
[0,0,0],
Variance(),
Variance(),
ws_matrix_tuple,
rest_matrix_tuple,
immunization_countdown
......
#needlessly overwrought output interface
using LabelledArrays
using OnlineStats
using Plots
import OnlineStats.fit!
......@@ -24,30 +23,19 @@ struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{El
daily_immunized_by_age::ArrT3
daily_unvac_cases_by_age::ArrT3
end
"""
Initialize an empty DebugRecorder. We use a labelledarray for the state vector, so the individal timeseries can be accessed by name.
HeatmapRecorder, for heatmaps!
"""
function DebugRecorder(sim_length)
state_totals = @LArray Array{Int}(undef,4,sim_length) (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = Vector{Int}(undef,sim_length)
mean_time_since_last_notification = Vector{Int}(undef,sim_length)
daily_cases_by_age = @LArray Array{Int}(undef,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray Array{Int}(undef,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray Array{Int}(undef,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
return DebugRecorder(
state_totals,
daily_cases_by_age,
total_vaccinators,
mean_time_since_last_notification,
daily_immunized_by_age,
daily_unvac_by_age
)
struct HeatmapRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{Int}} <: AbstractRecorder{ElType}
daily_cases_by_age::ArrT2
final_size_by_age::ArrT1
avg_weighted_degree_of_vaccinators::ElType
avg_weighted_degree::ElType
end
"""
Initialize a DebugRecorder filled with (copies) of val. I should find a nicer way to combine both constructors.
Initialize a DebugRecorder filled with (copies) of val.
"""
function DebugRecorder(val::T, sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
......@@ -68,6 +56,21 @@ function DebugRecorder(val::T, sim_length) where T
)
end
function HeatmapRecorder(val::T,sim_length) where T<:OnlineStat
daily_cases_by_age = @LArray zeros(Int,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
final_size_by_age= [copy(val) for i in 1:3]
average_deg_of_vac= copy(val)
average_deg= copy(val)
return HeatmapRecorder(
daily_cases_by_age,
final_size_by_age,
average_deg_of_vac,
average_deg,
)
end
function record!(t,modelsol, recorder::DebugRecorder)
recorder.total_vaccinators[t] = modelsol.daily_vaccinators
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
......@@ -82,7 +85,16 @@ function record!(t,modelsol, recorder::DebugRecorder)
# display(modelsol.daily_immunizations_by_age)
recorder.daily_immunized_by_age[:,t] .= modelsol.daily_immunized_by_age
recorder.daily_unvac_cases_by_age[:,t] .= modelsol.daily_unvac_cases_by_age
end
function record!(t,modelsol, recorder::HeatmapRecorder)
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
if modelsol.sim_length == t
for age in 1:3
fit!(recorder.final_size_by_age[age], sum(recorder.daily_cases_by_age[age,:]))
end
merge!(recorder.avg_weighted_degree_of_vaccinators,modelsol.avg_weighted_degree_of_vaccinators)
merge!(recorder.avg_weighted_degree,modelsol.avg_weighted_degree)
end
end
function record!(t,modelsol, recorder::Nothing)
......@@ -92,7 +104,7 @@ end
function mean_solve(samples,parameter_tuple,recorder)
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
output_recorder = recorder(parameter_tuple.sim_length)
output_recorder = recorder(0.0,parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
for _ in 1:samples
sol = abm(parameter_tuple,output_recorder)
......@@ -103,59 +115,31 @@ function mean_solve(samples,parameter_tuple,recorder)
return stat_recorder,avg_populations
end
function OnlineStats.fit!(stat_type::R,pt) where {T,OS<:OnlineStat{T}, R<:AbstractRecorder{OS}}
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder})
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
output_recorder = recorder(Variance(),parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
for _ in 1:samples
sol = abm(parameter_tuple,output_recorder)
avg_populations .+= length.(sol.index_vectors)
fit!(stat_recorder,output_recorder)
end
avg_populations ./= samples
return stat_recorder,avg_populations
end
function OnlineStats.fit!(accum::R,pt::R2) where {R<:AbstractRecorder,R2<:AbstractRecorder}
for field in fieldnames(R)
stat_field = getfield(stat_type,field)
stat_field = getfield(accum,field)
sample_field = getfield(pt,field)
for (k,(stat_entry,sample_entry)) in enumerate(zip(stat_field,sample_field))
fit!(stat_entry,sample_entry)
end
combine!(stat_field,sample_field)
end
end
using Printf
const ts_colors = cgrad(:PuBu_9)
function plot_model(varname,univariate_series, output_list::Vector{T},infection_begin,vac_begin) where T<:DebugRecorder
sim_length = length(output_list[1].recorded_status_totals.S)
ts_list(data) = [
(data.recorded_status_totals.S, "Susceptible over time"),
(data.recorded_status_totals.R, "Recovered over time"),
(data.total_vaccinators, "No. vaccinators over time"),
(data.mean_time_since_last_notification, "Mean time since last notification"),
(data.daily_cases_by_age.Y,"Daily (incident) Y cases"),
(data.daily_cases_by_age.M,"Daily (incident) M cases"),
(data.daily_cases_by_age.O,"Daily (incident) O cases"),
(data.daily_immunized_by_age.Y, "new Y vaccinations each day"),
(data.daily_immunized_by_age.M, "new M vaccinations each day"),
(data.daily_immunized_by_age.O, "new O vaccinations each day"),
]
l = length(ts_list(output_list[1]))
plts = [plot() for i=1:l]
for (i,(p,data)) in enumerate(zip(univariate_series, output_list))
# display(p[varname])
if !isnothing(varname)
p_val = @sprintf "%.4f" (p[varname])
labelname = "$varname = $(p_val)"
else
labelname = nothing
end
for (plt,(ts,title)) in zip(plts,ts_list(data))
plot!(plt, mean.(ts); ribbon = std.(ts),
label = labelname, line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = title,
)
end
end
plot!(plts[begin]; legends = :topright)
for p in plts
vline!(p,[infection_begin]; label = "infection begin", line =:dot)
vline!(p,[vac_begin]; label = "vaccination begin",line = :dot)
function combine!(accum::AbstractArray,x::AbstractArray)
for (stat_entry,sample_entry) in zip(accum,x)
combine!(stat_entry,sample_entry)
end
return plot(plts...;layout = (l,1),size=(800,400*l),leftmargin = 12Plots.mm, legend = :outerright)
end
combine!(accum::OnlineStat{<:T}, x::T) where T = fit!(accum,x)
combine!(accum::OnlineStat{T}, x::OnlineStat{T}) where T = merge!(accum,x)
combine!(accum::Real, x::Real) = nothing
\ No newline at end of file
const univariate_path = "CovidAlertVaccinationModel/plots/univariate/"
const bivariate_path = "CovidAlertVaccinationModel/plots/univariate/"
function univarate_test(variable, variable_range)
default_parameters = get_app_parameters()
parameter_range_list = [merge(default_parameters,NamedTuple{(variable,)}((value,))) for value in variable_range]
solve_fn(p) = mean_solve(samples, p,DebugRecorder)[1]
univariate_outlist = ThreadsX.map(solve_fn, parameter_range_list)
p = plot_model(variable,parameter_range_list,univariate_outlist,default_parameters.infection_introduction_day,default_parameters.immunization_begin_day)
return p
end
if !ispath(univariate_path)
mkdir(univariate_path)
end
function univariate_simulations()
len = 10
univarate_test_list = (
# (:I_0_fraction, range(0.0, 0.05; length = len)),
# (:base_transmission_probability, range(0.0002, 0.002; length = len)),
# (:recovery_rate, range(0.1, 0.5; length = len)),
# (:immunization_loss_prob, range(0.00, 0.05; length = len)),
# (:π_base, range(-4.5, -3.5; length = len)),
(:η, range(0.0, 0.01; length = len)),
# (:κ, range(0.5, 1.5; length = len)),
# (:ω, range(0.0, 0.01; length = len)),
(:ω_en, range(0.0, 0.0005; length = len)),
# (:γ, range(0.0, 0.5; length = len)),
# (:ξ, range(1, 15; length = len)),
# (:notification_parameter, range(0.00, 0.05; length = len)),
# (:app_user_fraction, range(0.05, 0.25; length = len)),
(:notification_threshold, (1:len)),
# (:immunization_delay, [7,10,14,20]),
)
plt_list = ThreadsX.map(univarate_test_list) do ur
out = univarate_test(ur...)
display("done $(ur[1])")
return out
end
for ((varname,_),p) in zip(univarate_test_list,plt_list)
savefig(p,"$univariate_path/$varname.pdf")
end
end
using AxisKeys
function multivariate_simulations()
len = 10
samples = 10
app_simulations = (
(:η, range(0.0, 0.01; length = len)),
(:ω_en, range(0.0, 0.0005; length = len)),
# (:notification_threshold, (1:len)),
)
run_multivariate_sims(app_simulations,1)
# for ((varname,_),p) in zip(univarate_test_list,plt_list)
# savefig(p,"$univariate_path/$varname.pdf")
# end
end
using ProgressMeter
function run_multivariate_sims(sims,samples)
varnames, sim_ranges = zip(sims...)
default_parameters = get_app_parameters()
simvars = Iterators.product(sim_ranges...)
progmeter = Progress(length(simvars))
output = ThreadsX.map(simvars) do vars
vars_with_names = NamedTuple{varnames}(vars)
parameters = merge(default_parameters,vars_with_names)
out,_ = mean_solve(samples, parameters,DebugRecorder)
next!(progmeter)
return out
end
display(length(simvars))
fname = join(string.(varnames),"_")
keyed_output = KeyedArray(output;NamedTuple{varnames}(sim_ranges)...)
path = joinpath(PACKAGE_FOLDER,"abm_output","$fname.dat")
serialize(path,keyed_output)
return fname
end
function plot_parameter_plane(input_fname)
map(1: length(axiskeys(output)[end])) do i
end
end
\ No newline at end of file
using Printf
const ts_colors = cgrad(:PuBu_9)
function plot_model(varname,univariate_series, output_list::Vector{T},infection_begin,vac_begin) where T<:DebugRecorder
sim_length = length(output_list[1].recorded_status_totals.S)
ts_list(data) = [
(data.recorded_status_totals.S, "Susceptible over time"),
(data.recorded_status_totals.R, "Recovered over time"),
(data.total_vaccinators, "No. vaccinators over time"),
(data.mean_time_since_last_notification, "Mean time since last notification"),
(data.daily_cases_by_age.Y,"Daily (incident) Y cases"),
(data.daily_cases_by_age.M,"Daily (incident) M cases"),
(data.daily_cases_by_age.O,"Daily (incident) O cases"),
(data.daily_immunized_by_age.Y, "new Y vaccinations each day"),
(data.daily_immunized_by_age.M, "new M vaccinations each day"),
(data.daily_immunized_by_age.O, "new O vaccinations each day"),
]
l = length(ts_list(output_list[1]))
plts = [plot() for i=1:l]
for (i,(p,data)) in enumerate(zip(univariate_series, output_list))
# display(p[varname])
if !isnothing(varname)
p_val = @sprintf "%.4f" (p[varname])
labelname = "$varname = $(p_val)"
else
labelname = nothing
end
for (plt,(ts,title)) in zip(plts,ts_list(data))
plot!(plt, mean.(ts); ribbon = std.(ts),
label = labelname, line_z = i, color=:blues,
ylabel = "no. of people",
colorbar = false,
title = title,
)
end
end
plot!(plts[begin]; legends = :topright)
for p in plts
vline!(p,[infection_begin]; label = "infection begin", line =:dot)
vline!(p,[vac_begin]; label = "vaccination begin",line = :dot)
end
return plot(plts...;layout = (l,1),size=(800,400*l),leftmargin = 12Plots.mm, legend = :outerright)
end
......@@ -14,7 +14,7 @@ end
Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # Base.@propagate_inbounds
@unpack notification_parameter,notification_threshold = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_notifications,app_user = modelsol
for (i,node) in enumerate(modelsol.app_user_index)
for (i,node) in enumerate(app_user_index)
for j in 2:14
covid_alert_notifications[j-1,i] = covid_alert_notifications[j,i] #shift them all back
end
......@@ -39,9 +39,9 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # B
end
end
Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
@views function update_infection_state!(t,modelsol; record_degrees = false)
@unpack β_y,β_m,β_o,α_y,α_m,α_o,recovery_rate,immunizing,immunization_begin_day = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals, immunization_countdown = modelsol
@unpack u_inf,u_vac,u_next_inf,demographics,inf_network,status_totals, immunization_countdown = modelsol
modelsol.daily_cases_by_age .= 0
modelsol.daily_immunized_by_age .= 0
......@@ -86,12 +86,17 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
agent_transition!(i, Infected,Recovered)
end
end
weighted_degree_of_i::Int = record_degrees ? weighted_degree(t,i,inf_network) : 0
if immunization_countdown[i] == 0
modelsol.daily_immunized_by_age[Int(agent_demo)] += 1
fit!(modelsol.avg_weighted_degree_of_vaccinators,weighted_degree_of_i)
agent_transition!(i, Susceptible,Immunized)
elseif immunization_countdown[i]>0
fit!(modelsol.avg_weighted_degree,weighted_degree_of_i)
immunization_countdown[i] -= 1
else
fit!(modelsol.avg_weighted_degree,weighted_degree_of_i)
end
end
end
......@@ -130,13 +135,11 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
modelsol.daily_vaccinators = count(==(true),u_vac)
end
function weighted_degree(node,network::TimeDepMixingGraph)
function weighted_degree(t,node,network::TimeDepMixingGraph)
weighted_degree = 0
for g_list in network.graph_list
for g in g_list
for j in neighbors(g,node)
weighted_degree += get_weight(g,GraphEdge(node,j))
end
for g in network.graph_list[t]
for j in neighbors(g,node)
weighted_degree += get_weight(g,GraphEdge(node,j))
end
end
return weighted_degree
......@@ -158,7 +161,7 @@ function sample_initial_nodes(nodes,graphs,I_0_fraction)
end
function solve!(modelsol,recordings...)
function solve!(modelsol,recording::T) where T
init_indices = sample_initial_nodes(modelsol.nodes, modelsol.inf_network.graph_list[begin], modelsol.params.I_0_fraction)
for t in 1:modelsol.sim_length
#this also resamples the soc network weights since they point to the same objects, but those are never used
......@@ -176,15 +179,13 @@ function solve!(modelsol,recordings...)
end
update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)])
update_infection_state!(t,modelsol)
update_infection_state!(t,modelsol; record_degrees = (T <: HeatmapRecorder))
#advance agent states based on the new network
modelsol.u_vac .= modelsol.u_next_vac
modelsol.u_inf .= modelsol.u_next_inf