Skip to content
Snippets Groups Projects
Commit 2f890228 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

graph optimizations, contact intervals are now drawn from IntervalsModel output, still slow tho :(

parent be06ddd0
No related branches found
No related tags found
No related merge requests found
...@@ -10,8 +10,8 @@ default(framestyle = :box) ...@@ -10,8 +10,8 @@ default(framestyle = :box)
function bench() function bench()
steps = 100 steps = 100
model_sol = ModelSolution(steps,get_parameters(),5000); model_sol = ModelSolution(steps,get_parameters(),5000);
recording = recorder(steps) recording = DebugRecorder(steps)
output = solve!(model_sol,DebugRecorder) output = solve!(model_sol,recording )
end end
function abm() function abm()
......
...@@ -31,11 +31,71 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq, ...@@ -31,11 +31,71 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq,
return g return g
end end
struct TimeDepMixingGraph{N,G}
resampled_graphs::NTuple{N,G}
graph_list::Vector{Vector{G}}
function TimeDepMixingGraph(len,resampled_graphs::NTuple{N,G},base_graph_list::Vector{G}) where {G,N}
return new{N,G}(
resampled_graphs,
[copy(base_graph_list) for i in 1:len]
)
end
end
function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes
ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_tuple.daily,contact_time_distributions.ws)
ws_weekly_edges = WeightedGraph(demographics,index_vectors,ws_matrix_tuple.twice_a_week,contact_time_distributions.ws)
ws_daily_edges = WeightedGraph(demographics,index_vectors,ws_matrix_tuple.otherwise,contact_time_distributions.ws)
rest_static_edges = WeightedGraph(demographics,index_vectors,rest_matrix_tuple.daily,contact_time_distributions.rest)
rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_tuple.twice_a_week,contact_time_distributions.rest)
rest_daily_edges = WeightedGraph(demographics,index_vectors,rest_matrix_tuple.otherwise,contact_time_distributions.rest)
inf_network_list = [home_static_edges,rest_static_edges]
soc_network_list = [home_static_edges,rest_static_edges,ws_static_edges]
infected_mixing_graph = TimeDepMixingGraph(len,(ws_daily_edges,rest_daily_edges),inf_network_list)
soc_mixing_graph = TimeDepMixingGraph(len,(ws_daily_edges,rest_daily_edges),soc_network_list)
# display(infected_mixing_graph.graph_list)
for (t,l) in enumerate(infected_mixing_graph.graph_list)
day_of_week = mod(t,7)
if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess
push!(l, ws_static_edges)
end
if rand(RNG)<5/7
push!(l, ws_weekly_edges)
push!(l, rest_weekly_edges)
end
push!(l,ws_daily_edges)
push!(l,rest_daily_edges)
end
return infected_mixing_graph,soc_mixing_graph
end
function remake!(time_dep_mixing_graph,demographic_index_vectors,mixing_matrix)
for weighted_graph in time_dep_mixing_graph.resampled_graphs
empty!.(weighted_graph.g.fadjlist) #empty all the vector edgelists
empty!(weighted_graph.weights_dict)
weighted_graph.g.ne = 0
contacts = MixingContacts(demographic_index_vectors,mixing_matrix)
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(
weighted_graph.g,
demographic_index_vectors[i],
demographic_index_vectors[j],
contacts.contact_array[i,j],
contacts.contact_array[j,i],
weighted_graph.weights_dict
)
end
end
end
#defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way #defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way
struct WeightedGraph{G, V,M} struct WeightedGraph{G, V,M}
g::G g::G
weights::V weights_dict::V
weights_distribution_matrix::M weights_distribution_matrix::M
function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix) function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix)
...@@ -43,29 +103,31 @@ struct WeightedGraph{G, V,M} ...@@ -43,29 +103,31 @@ struct WeightedGraph{G, V,M}
g = Graph(length(demographics)) g = Graph(length(demographics))
weights = RobinDict{Tuple{Int,Int},UInt8}() weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights) random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights_dict)
end end
covid_alert_time = zeros(nv(g)) covid_alert_time = zeros(nv(g))
return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}( return new{typeof(g),typeof(weights_dict),typeof(weights_distribution_matrix)}(
g, g,
weights, weights_dict,
weights_distribution_matrix weights_distribution_matrix
) )
end end
function WeightedGraph(g::SimpleGraph,weights_distribution_matrix) function WeightedGraph(g::SimpleGraph,weights_distribution_matrix)
weights = RobinDict{Tuple{Int,Int},UInt8}() weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}( return new{typeof(g),typeof(weights_dict),typeof(weights_distribution_matrix)}(
g, g,
weights, weights_dict,
weights_distribution_matrix weights_distribution_matrix
) )
end end
end end
function Base.show(io::IO, g::WeightedGraph)
print(io, "WG")
end
function sample_mixing_graph!(mixing_graph,population_demographics) function sample_mixing_graph!(mixing_graph,population_demographics)
for (k,e) in enumerate(edges(mixing_graph.g)) for (k,e) in enumerate(edges(mixing_graph.g))
i = src(e) i = src(e)
...@@ -74,8 +136,8 @@ function sample_mixing_graph!(mixing_graph,population_demographics) ...@@ -74,8 +136,8 @@ function sample_mixing_graph!(mixing_graph,population_demographics)
demo_j = Int(population_demographics[j]) demo_j = Int(population_demographics[j])
contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j]) contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j])
mixing_graph.weights[(i,j)] = contact_time mixing_graph.weights_dict[(i,j)] = contact_time
mixing_graph.weights[(j,i)] = contact_time mixing_graph.weights_dict[(j,i)] = contact_time
end end
end end
@inline function reindex!(k,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j) @inline function reindex!(k,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
......
...@@ -30,7 +30,7 @@ function get_u_0(nodes,I_0_fraction,vaccinator_prob) ...@@ -30,7 +30,7 @@ function get_u_0(nodes,I_0_fraction,vaccinator_prob)
return status,is_vaccinator return status,is_vaccinator
end end
struct ModelSolution{T,G} struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
sim_length::Int sim_length::Int
nodes::Int nodes::Int
params::T params::T
...@@ -40,8 +40,8 @@ struct ModelSolution{T,G} ...@@ -40,8 +40,8 @@ struct ModelSolution{T,G}
u_vac::Vector{Bool} u_vac::Vector{Bool}
covid_alert_times::Array{Int,2} covid_alert_times::Array{Int,2}
time_of_last_alert::Vector{Int} time_of_last_alert::Vector{Int}
inf_network_lists::Vector{Vector{G}} inf_network::InfNet
soc_networks::Vector{G} soc_network::SocNet
index_vectors::Vector{Vector{Int}} index_vectors::Vector{Vector{Int}}
demographics::Vector{AgentDemographic} demographics::Vector{AgentDemographic}
app_user::Vector{Bool} app_user::Vector{Bool}
...@@ -49,62 +49,32 @@ struct ModelSolution{T,G} ...@@ -49,62 +49,32 @@ struct ModelSolution{T,G}
app_user_index::Vector{Int} app_user_index::Vector{Int}
status_totals::Vector{Int} status_totals::Vector{Int}
status_totals_next::Vector{Int} status_totals_next::Vector{Int}
ws_matrix_tuple::WSMixingDist
rest_matrix_tuple::RestMixingDist
function ModelSolution(sim_length,params::T,num_households) where T function ModelSolution(sim_length,params::T,num_households) where T
demographics,base_network,index_vectors = generate_population(num_households) demographics,base_network,index_vectors = generate_population(num_households)
pop_sizes = length.(index_vectors) pop_sizes = length.(index_vectors)
map_symmetrize(m_tuple) = map(md -> symmetrize_means(pop_sizes,md), m_tuple) map_symmetrize(m_tuple) = map(md -> symmetrize_means(pop_sizes,md), m_tuple)
ws_matrix_list = map_symmetrize(workschool_mixing) ws_matrix_tuple = map_symmetrize(workschool_mixing)
rest_matrix_list = map_symmetrize(rest_mixing) rest_matrix_tuple = map_symmetrize(rest_mixing)
app_user_list = zeros(length(demographics)) app_user_list = zeros(length(demographics))
is_app_user = rand(RNG,length(demographics)) .< params.app_user_fraction is_app_user = rand(RNG,length(demographics)) .< params.app_user_fraction
app_user_index = findall(==(true),is_app_user) app_user_index = findall(==(true),is_app_user)
app_user_list[is_app_user] .= collect(1:length(app_user_index)) app_user_list[is_app_user] .= collect(1:length(app_user_index))
# display(app)
nodes = length(demographics) nodes = length(demographics)
u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob) u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob)
home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes infected_mixing_graph,soc_mixing_graph = time_dep_mixing_graphs(sim_length,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distributions.ws)
ws_weekly_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distributions.ws)
ws_daily_edges_vector = [WeightedGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distributions.ws) for i in 1:sim_length]
rest_static_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distributions.rest)
rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distributions.rest)
rest_daily_edges_vector = [WeightedGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distributions.rest) for i in 1:sim_length]
inf_network_lists = [
[home_static_edges,rest_static_edges] for i in 1:sim_length
]
soc_network_list = [home_static_edges,rest_static_edges,ws_static_edges]
for (t,l) in enumerate(inf_network_lists)
day_of_week = mod(t,7)
if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess
push!(l, ws_static_edges)
end
if rand(RNG)<5/7
push!(l, ws_weekly_edges)
push!(l, rest_weekly_edges)
end
push!(l,ws_daily_edges_vector[t])
push!(l,rest_daily_edges_vector[t])
end
covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values
time_of_last_alert = fill(-1,length(app_user_index)) #two weeks worth of values time_of_last_alert = fill(-1,length(app_user_index)) #time of last alert is negative if no alert has been recieved
status_totals = zeros(Int, AgentStatus.size)
status_totals[1] = count(==(Susceptible), u_0_inf) status_totals = [count(==(i), u_0_inf) for i in 1:AgentStatus.size]
status_totals[2] = count(==(Infected), u_0_inf)
status_totals[3] = count(==(Recovered), u_0_inf)
status_totals[4] = count(==(Immunized), u_0_inf)
return new{T,typeof(infected_mixing_graph),typeof(soc_mixing_graph),typeof(ws_matrix_tuple),typeof(rest_matrix_tuple)}(
return new{T,typeof(home_static_edges)}(
sim_length, sim_length,
nodes, nodes,
params, params,
...@@ -114,15 +84,17 @@ struct ModelSolution{T,G} ...@@ -114,15 +84,17 @@ struct ModelSolution{T,G}
copy(u_0_vac), copy(u_0_vac),
covid_alert_times, covid_alert_times,
time_of_last_alert, time_of_last_alert,
inf_network_lists, infected_mixing_graph,
soc_network_list, soc_mixing_graph,
index_vectors, index_vectors,
demographics, demographics,
is_app_user, is_app_user,
app_user_list, app_user_list,
app_user_index, app_user_index,
status_totals, status_totals,
copy(status_totals) copy(status_totals),
ws_matrix_tuple,
rest_matrix_tuple
) )
end end
end end
...@@ -4,16 +4,16 @@ function contact_weight(p, contact_time) ...@@ -4,16 +4,16 @@ function contact_weight(p, contact_time)
end end
function update_alert_durations!(t,modelsol) function update_alert_durations!(t,modelsol)
@unpack notification_parameter = modelsol.params @unpack notification_parameter = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network_lists,covid_alert_times,app_user = modelsol @unpack time_of_last_alert, app_user_index,inf_network,covid_alert_times,app_user = modelsol
for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network_lists[t] for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network.graph_list[t]
for j in 2:14 for j in 2:14
covid_alert_times[i,j-1] = covid_alert_times[i,j] #shift them all back covid_alert_times[i,j-1] = covid_alert_times[i,j] #shift them all back
end end
for j in neighbors(mixing_graph.g,node) for j in neighbors(mixing_graph.g,node)
if app_user[j] if app_user[j]
covid_alert_times[i,end] += mixing_graph.weights[(node,j)] #add the contact times for today to the back covid_alert_times[i,end] += mixing_graph.weights_dict[(node,j)] #add the contact times for today to the back
end end
end end
if rand(RNG) < 1 - (1- notification_parameter)^sum(covid_alert_times[i,:]) if rand(RNG) < 1 - (1- notification_parameter)^sum(covid_alert_times[i,:])
...@@ -23,7 +23,7 @@ function update_alert_durations!(t,modelsol) ...@@ -23,7 +23,7 @@ function update_alert_durations!(t,modelsol)
end end
function update_infection_state!(t,modelsol) function update_infection_state!(t,modelsol)
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params @unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists,status_totals,status_totals_next = modelsol @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals,status_totals_next = modelsol
function agent_transition!(node, from::AgentStatus,to::AgentStatus) function agent_transition!(node, from::AgentStatus,to::AgentStatus)
status_totals_next[Int(from)] -= 1 status_totals_next[Int(from)] -= 1
...@@ -41,9 +41,9 @@ function update_infection_state!(t,modelsol) ...@@ -41,9 +41,9 @@ function update_infection_state!(t,modelsol)
if is_vaccinator if is_vaccinator
agent_transition!(i, Susceptible,Immunized) agent_transition!(i, Susceptible,Immunized)
else else
for mixing_graph in inf_network_lists[t] for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,i) for j in neighbors(mixing_graph.g,i)
if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights[(i,j)]) if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights_dict[(i,j)])
agent_transition!(i, Susceptible,Infected) agent_transition!(i, Susceptible,Infected)
end end
end end
...@@ -63,7 +63,7 @@ end ...@@ -63,7 +63,7 @@ end
function update_vaccination_opinion_state!(t,modelsol,total_infections) function update_vaccination_opinion_state!(t,modelsol,total_infections)
@unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params @unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_networks,u_vac,u_next_vac,app_user,app_user_list = modelsol @unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
app_user_pointer = 0 app_user_pointer = 0
...@@ -72,8 +72,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections) ...@@ -72,8 +72,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
soc_nbrs_vac = [0,0,0] soc_nbrs_vac = [0,0,0]
soc_nbrs_nonvac = 0 soc_nbrs_nonvac = 0
num_soc_nbrs = 0 num_soc_nbrs = 0
for soc_network in soc_networks for sc_g in soc_network.graph_list[t]
soc_nbrs = neighbors(soc_network.g,i) soc_nbrs = neighbors(sc_g.g,i)
num_soc_nbrs += length(soc_nbrs) num_soc_nbrs += length(soc_nbrs)
for nbr in soc_nbrs for nbr in soc_nbrs
if u_vac[nbr] if u_vac[nbr]
...@@ -108,7 +108,10 @@ end ...@@ -108,7 +108,10 @@ end
function agents_step!(t,modelsol) function agents_step!(t,modelsol)
for network in modelsol.inf_network_lists[t]
remake!(modelsol.inf_network,modelsol.index_vectors,modelsol.ws_matrix_tuple.daily)
remake!(modelsol.inf_network,modelsol.index_vectors,modelsol.rest_matrix_tuple.daily)
for network in modelsol.inf_network.graph_list[t]
sample_mixing_graph!(network,modelsol.demographics) #get new contact weights sample_mixing_graph!(network,modelsol.demographics) #get new contact weights
end end
......
...@@ -33,9 +33,9 @@ const RNG = Xoroshiro128Star(1) ...@@ -33,9 +33,9 @@ const RNG = Xoroshiro128Star(1)
const color_palette = palette(:seaborn_pastel) #color theme for the plots const color_palette = palette(:seaborn_pastel) #color theme for the plots
include("utils.jl") include("utils.jl")
include("data.jl") include("data.jl")
include("ABM/agents.jl")
include("ABM/mixing_distributions.jl") include("ABM/mixing_distributions.jl")
include("ABM/mixing_graphs.jl") include("ABM/mixing_graphs.jl")
include("ABM/agents.jl")
include("ABM/plotting.jl") include("ABM/plotting.jl")
include("ABM/model_setup.jl") include("ABM/model_setup.jl")
include("ABM/output.jl") include("ABM/output.jl")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment