Commit df40693d authored by Peter Jentsch's avatar Peter Jentsch
Browse files

simplify mixing graphs code

parent 29a75642
using CovidAlertVaccinationModel
using CovidAlertVaccinationModel:get_parameters
using CovidAlertVaccinationModel:get_parameters,get_app_parameters
using OnlineStats
using ThreadsX
using Plots
......@@ -22,10 +22,10 @@ const univarate_test_list = (
# (:recovery_rate, range(0.1, 0.5; length = len)),
# (:immunization_loss_prob, range(0.00, 0.05; length = len)),
# (:π_base, range(-4.5, -3.5; length = len)),
# (:η, range(0.0, 0.01; length = len)),
(:η, range(0.0, 0.01; length = len)),
# (:κ, range(0.5, 1.5; length = len)),
# (:ω, range(0.0, 0.01; length = len)),
(:ω_en, range(0.0, 0.5; length = len)),
(:ω_en, range(0.0, 0.0005; length = len)),
# (:γ, range(0.0, 0.5; length = len)),
# (:ξ, range(1, 15; length = len)),
# (:notification_parameter, range(0.00, 0.05; length = len)),
......@@ -36,7 +36,7 @@ const univarate_test_list = (
const univariate_path = "CovidAlertVaccinationModel/plots/univariate/"
function univarate_test(variable, variable_range)
default_parameters = get_parameters()
default_parameters = get_app_parameters()
parameter_range_list = [merge(default_parameters,NamedTuple{(variable,)}((value,))) for value in variable_range]
solve_fn(p) = mean_solve(samples, p,DebugRecorder)[1]
......
......@@ -31,8 +31,7 @@ end
Resample all the weights in `mixing_graph`
"""
function sample_mixing_graph!(mixing_graph)
mixing_edges = mixing_graph.mixing_edges
function sample_mixing_edges!(mixing_edges)
for i in 1:3, j in 1:i #diagonal
rand!(Random.default_rng(Threads.threadid()), mixing_edges.sampler_matrix[j,i],mixing_edges.sample_cache[j,i])
for k in 1:length(mixing_edges.contact_array[j,i])
......@@ -67,16 +66,19 @@ Stores the weights used in the graph, so they can be easily resampled.
This is the matrix of distributions, from which the edge weights are sampled. Specifically, weights for edges in `contact_array[i,j]` come from the distribution in `sampler_matrix[i,j]`, and are placed into `sample_cache[i,j]`. We only use the upper triangle of this but Julia lacks a good Symmetric matrix type. See `sample_mixing_graph!`.
"""
struct MixingEdges{M}
struct MixingEdges{M,M2}
total_edges::Int
contact_array::Matrix{Vector{GraphEdge}}
sample_cache::Matrix{Vector{UInt8}}
weights_dict::Dictionary{GraphEdge,UInt8}
sampler_matrix::M
function MixingEdges(total_edges,contact_array,sampler_matrix::Matrix{M}) where M<:Sampleable{Univariate, Discrete}
mixing_matrix::Union{Nothing,M2}
function MixingEdges(total_edges,contact_array,sampler_matrix::M,mixing_matrix::M2) where {M,M2}
sample_cache = map(v-> Vector{UInt8}(undef,length(v)),contact_array)
weights_dict = Dictionary{GraphEdge,UInt8}(;sizehint = total_edges)
new{typeof(sampler_matrix)}(total_edges,contact_array,sample_cache,weights_dict,sampler_matrix)
me = new{M,M2}(total_edges,contact_array,sample_cache,weights_dict,sampler_matrix,mixing_matrix)
sample_mixing_edges!(me)
return me
end
end
......@@ -116,7 +118,7 @@ function create_mixing_edges(demographic_index_vectors,mixing_matrix,weights_dis
contact_array[j,i] = GraphEdge.(stubs_i,stubs_j)
end
end
return MixingEdges(tot,contact_array,weights_distribution_matrix)
return MixingEdges(tot,contact_array,weights_distribution_matrix,mixing_matrix)
end
......@@ -130,7 +132,7 @@ function create_mixing_edges(g::SimpleGraph,demographics,demographic_index_vecto
j = dst(e)
push!(contact_array[Int(demographics[i]),Int(demographics[j])], GraphEdge(i,j))
end
return MixingEdges(ne(g),contact_array,weights_distribution_matrix)
return MixingEdges(ne(g),contact_array,weights_distribution_matrix,nothing)
end
"""
......@@ -148,11 +150,13 @@ List of lists of graphs, one list for each day.
graph_list::Vector{Vector{G}}
"""
struct TimeDepMixingGraph{N,G,GNT}
resampled_graphs::NTuple{N,GNT}
struct TimeDepMixingGraph{G,T1,T2}
remade_graphs::T1
resampled_graphs::T2
graph_list::Vector{Vector{G}}
function TimeDepMixingGraph(len,resampled_graphs::NTuple{N,GNT},base_graph_list::Vector{G}) where {GNT,G,N}
return new{N,G,GNT}(
function TimeDepMixingGraph(len,remade_graphs::T1,resampled_graphs::T2,base_graph_list::Vector{G}) where {G,T1,T2}
return new{G,T1,T2}(
remade_graphs,
resampled_graphs,
[copy(base_graph_list) for i in 1:len]
)
......@@ -165,6 +169,7 @@ Creates the `TimeDepMixingGraph` for our specific model.
Assumes the simulation begins on Thursday arbitrarily.
"""
function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
home_static_edges = WeightedGraph(base_network,demographics,index_vectors,contact_time_distributions.hh) #network with households and LTC homes
ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_tuple.daily,contact_time_distributions.ws)
......@@ -175,12 +180,14 @@ function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_m
rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_tuple.twice_a_week,contact_time_distributions.rest)
rest_daily_edges = WeightedGraph(demographics,index_vectors,rest_matrix_tuple.otherwise,contact_time_distributions.rest)
inf_network_list = [home_static_edges,rest_static_edges]
inf_network_list = [home_static_edges,rest_static_edges,ws_daily_edges,rest_daily_edges]
soc_network_list = [home_static_edges,rest_static_edges,ws_static_edges]
infected_mixing_graph = TimeDepMixingGraph(len,((ws_daily_edges,ws_matrix_tuple.daily),(rest_daily_edges,rest_matrix_tuple.daily)),inf_network_list)
soc_mixing_graph = TimeDepMixingGraph(len,((ws_daily_edges,ws_matrix_tuple.daily),(rest_daily_edges,rest_matrix_tuple.daily)),soc_network_list)
# display(infected_mixing_graph.graph_list)
remade_graphs = (ws_daily_edges,rest_daily_edges)
resampled_graphs = (home_static_edges,rest_static_edges,ws_static_edges,rest_weekly_edges,ws_weekly_edges)
infected_mixing_graph = TimeDepMixingGraph(len,remade_graphs,resampled_graphs,inf_network_list)
for (t,l) in enumerate(infected_mixing_graph.graph_list)
day_of_week = mod(t,7)
if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess
......@@ -190,25 +197,48 @@ function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_m
push!(l, ws_weekly_edges)
push!(l, rest_weekly_edges)
end
push!(l,ws_daily_edges)
push!(l,rest_daily_edges)
end
soc_mixing_graph = TimeDepMixingGraph(len,remade_graphs,resampled_graphs,soc_network_list)
return infected_mixing_graph,soc_mixing_graph
end
"""
Completely remake all the graphs in `time_dep_mixing_graph.resampled_graphs`.
"""
function remake!(time_dep_mixing_graph,demographic_index_vectors)
for (weighted_graph,mixing_matrix) in time_dep_mixing_graph.resampled_graphs
function remake!(t,time_dep_mixing_graph,demographic_index_vectors)
for weighted_graph in time_dep_mixing_graph.remade_graphs
empty!.(weighted_graph.g.fadjlist) #empty all the vector edgelists
weighted_graph.g.ne = 0
weighted_graph.mixing_edges = create_mixing_edges(demographic_index_vectors,mixing_matrix,weighted_graph.mixing_edges.sampler_matrix)
weighted_graph.mixing_edges = create_mixing_edges(demographic_index_vectors,weighted_graph.mixing_edges.mixing_matrix,weighted_graph.mixing_edges.sampler_matrix)
graph_from_mixing_edges(weighted_graph.g,weighted_graph.mixing_edges)
end
for weighted_graph in time_dep_mixing_graph.resampled_graphs
if weighted_graph in time_dep_mixing_graph.graph_list[t]
sample_mixing_edges!(weighted_graph.mixing_edges)
end
end
end
# function display_weighted_degree(g)
# weighted_degree = 0.0
# for node in vertices(g.g)
# for j in neighbors(g,node)
# weighted_degree += get_weight(g,GraphEdge(node,j))
# end
# end
# display(weighted_degree)
# end
# function display_degree(g)
# weighted_degree = 0.0
# for node in vertices(g.g)
# for j in neighbors(g,node)
# weighted_degree += 1
# end
# end
# display(weighted_degree)
# end
"""
Add the edges defined by MixingEdges to the actual graph G. Another big bottleneck since adjancency lists don't add edges super efficiently, and there are a ton of them.
"""
......
......@@ -3,27 +3,26 @@ function get_parameters()#(0.0000,0.00048,0.0005,0.16,-1.30,-1.24,-0.8,0.35,0.35
sim_length = 500,
num_households = 5000,
I_0_fraction = 0.003,
β_y = 0.00078,
β_m = 0.00063,
β_o = 0.755,
β_y = 0.0011,
β_m = 0.00061,
β_o = 0.04,
α_y = 0.4,
α_m = 0.4,
α_o = 0.4,
recovery_rate = 1/5,
π_base_y = -1.39,
π_base_m = -1.44,
π_base_y = -1.37,
π_base_m = -1.46,
π_base_o = -0.95,
η = 0.0,
κ = 0.0,
ω = 0.0055
,
ω = 0.0055,
ω_en = 0.00,
γ = 0.0,
Γ = 1/7,
ξ = 5.0,
notification_parameter = 0.001,
vaccinator_prob = 0.6,
app_user_fraction = 0.0,
notification_threshold = 20,
notification_threshold = 2,
immunizing = true,
immunization_delay = 14,
immunization_begin_day = 60,
......@@ -32,7 +31,10 @@ function get_parameters()#(0.0000,0.00048,0.0005,0.16,-1.30,-1.24,-0.8,0.35,0.35
)
return params
end
function get_app_parameters()
return merge(get_parameters(),(app_user_fraction = 0.5,))
end
function get_u_0(nodes,vaccinator_prob)
is_vaccinator = rand(Random.default_rng(Threads.threadid()),nodes) .< vaccinator_prob
status = fill(Susceptible,nodes)
......@@ -107,11 +109,6 @@ mutable struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
status_totals = [count(==(AgentStatus(i)), u_0_inf) for i in 1:AgentStatus.size]
immunization_countdown = fill(-1, nodes) #immunization countdown is negative if not counting down
for network in infected_mixing_graph.graph_list[begin] #this also resamples the soc network weights since they point to the same objects, but those are never used
sample_mixing_graph!(network) #get new contact weights
end
return new{T,typeof(infected_mixing_graph),typeof(soc_mixing_graph),typeof(ws_matrix_tuple),typeof(rest_matrix_tuple)}(
sim_length,
......
......@@ -98,25 +98,20 @@ end
Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections)
@unpack infection_introduction_day, π_base_y,π_base_m,π_base_o, η,γ,ζ, κ, ω, ω_en,γ,ξ = modelsol.params
@unpack infection_introduction_day, π_base_y,π_base_m,π_base_o, η,Γ,ζ, ω, ω_en,ξ = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
app_user_pointer = 0
for i in 1:nodes
π_base = t<infection_introduction_day ?
(π_base_y,π_base_m,π_base_o) :
(π_base_y*ζ,π_base_m*ζ,π_base_o*ζ)
random_soc_network = sample(Random.default_rng(Threads.threadid()), soc_network.graph_list[t])
if !isempty(neighbors(random_soc_network,i))
for _ = 1:1
random_neighbour = sample(Random.default_rng(Threads.threadid()), neighbors(random_soc_network.g,i))
if u_vac[random_neighbour] == u_vac[i]
vac_payoff = π_base[Int(demographics[i])] + total_infections*ω
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]])) * (η + total_infections*ω_en)
vac_payoff += Γ^(-1*(t - time_of_last_alert[app_user_list[i]])) * (η + total_infections*ω_en)
end
if u_vac[i]
# display(1 - Φ(vac_payoff,ξ))
if rand(Random.default_rng(Threads.threadid())) < 1 - Φ(vac_payoff,ξ)
......@@ -128,10 +123,7 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
u_next_vac[i] = true
end
end
break
end
end
end
end
......@@ -150,36 +142,33 @@ function weighted_degree(node,network::TimeDepMixingGraph)
return weighted_degree
end
function agents_step!(t,modelsol, init_indices)
if t>modelsol.params.infection_introduction_day
if !isempty(init_indices)
inf_index = pop!(init_indices)
modelsol.u_inf[inf_index] = Infected
modelsol.status_totals[Int(Infected)] += 1
end
update_alert_durations!(t,modelsol)
end
update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)])
update_infection_state!(t,modelsol)
modelsol.u_vac .= modelsol.u_next_vac
modelsol.u_inf .= modelsol.u_next_inf
remake!(modelsol.inf_network,modelsol.index_vectors)
for network in modelsol.inf_network.graph_list[t] #this also resamples the soc network weights since they point to the same objects, but those are never used
sample_mixing_graph!(network) #get new contact weights
end
end
function solve!(modelsol,recordings...)
init_indices = rand(Random.default_rng(Threads.threadid()), 1:modelsol.nodes, round(Int,modelsol.nodes*modelsol.params.I_0_fraction))
for t in 1:modelsol.sim_length
agents_step!(t,modelsol,init_indices)
#this also resamples the soc network weights since they point to the same objects, but those are never used
if t>1
remake!(t,modelsol.inf_network,modelsol.index_vectors)
end
if t>modelsol.params.infection_introduction_day
if !isempty(init_indices)
inf_index = pop!(init_indices)
modelsol.u_inf[inf_index] = Infected
modelsol.status_totals[Int(Infected)] += 1
end
update_alert_durations!(t,modelsol)
end
update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)])
update_infection_state!(t,modelsol)
#advance agent states based on the new network
modelsol.u_vac .= modelsol.u_next_vac
modelsol.u_inf .= modelsol.u_next_inf
for recording in recordings
record!(t,modelsol,recording)
end
......
......@@ -30,13 +30,13 @@ using VectorizedRNG
export intervalsmodel, hh, ws, rest, abm
const DNDEBUG = false
macro c_assert(boolean) #this is a version of @assert that turns itself off when DNDEBUG=false, should use more
if DNDEBUG
message = string("Assertion: ", boolean, " failed")
:($(esc(boolean)) || error($message))
end
end
# const DNDEBUG = false
# macro c_assert(boolean) #this is a version of @assert that turns itself off when DNDEBUG=false, should use more
# if DNDEBUG
# message = string("Assertion: ", boolean, " failed")
# :($(esc(boolean)) || error($message))
# end
# end
const durmax = 144
const PACKAGE_FOLDER = dirname(dirname(pathof(CovidAlertVaccinationModel)))
......
......@@ -13,7 +13,6 @@ const model_sizes = [1000,5000]
const dem_cat = AgentDemographic.size -1
const samples = 1
#WRITE MORE TESTS
@testset "mixing matrices, size: $sz" for sz in model_sizes
for rep = 1:samples
......@@ -136,4 +135,30 @@ end
@test mean(mixing_dist[i]) mean(dist[i]) atol = 0.2
end
end
end
\ No newline at end of file
end
# @testset "edge resampling" begin
# model = abm(get_parameters(), nothing)
# @unpack demographics, inf_network, soc_network = model
# for l in inf_network.resampled_graphs
# for graphs in inf_network.graph_list
# mixing_dist = [0.0 for _ in 1:3, _ in 1:3]
# prev_mixing_dist = [0.0 for _ in 1:3, _ in 1:3]
# if l in graphs
# for v in vertices(l.g)
# demo_v = demographics[v]
# degs = zeros(3)
# for w in LightGraphs.neighbors(l.g,v)
# demo_w = demographics[w]
# degs[Int(demo_w)] += get_weight(l,GraphEdge(w,v))
# end
# for (j,d) in enumerate(degs)
# mixing_dist[Int(demo_v), j] += d
# end
# end
# end
# prev_mixing_dist = copy(mixing_dist)
# display((prev_mixing_dist,mixing_dist))
# end
# end
# end
......@@ -6,6 +6,6 @@ using ThreadsX
using BenchmarkTools
import StatsBase.mean
# include("ABM/mixing_test.jl")
include("ABM/output_test.jl")
include("ABM/mixing_test.jl")
# include("ABM/output_test.jl")
include("IntervalsModel/intervals_model_test.jl")
\ No newline at end of file
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment