Commit 44e3c3e6 authored by Peter Jentsch

replace mapreduce with a faster graph construction, add tests for output, fix bug in output

parent 90d4a13e
Showing with 180 additions and 125 deletions
@@ -8,15 +8,16 @@ default(framestyle = :box)
function bench()
steps = 100
Random.seed!(RNG,1)
model_sol = ModelSolution(steps,get_parameters(),5000);
steps = 500
model_sol = ModelSolution(steps,get_parameters(),5000)
recording = DebugRecorder(steps)
output = solve!(model_sol,recording )
display(output.recorded_status_totals[2,:])
end
function abm()
b1 = @benchmark bench() seconds = 20
b1 = @benchmark bench() seconds = 20
display(b1)
println("done")
end
@@ -12,22 +12,43 @@ end
end
function complete_graph_from_households_composition(households_composition)
total_household_pop = sum(sum.(households_composition))
network = SimpleGraph(total_household_pop)
vertex_pointer = 1
for household in households_composition
num_vertices = sum(household)
for v in vertex_pointer:(vertex_pointer + num_vertices - 1)
for w in vertex_pointer:(vertex_pointer + num_vertices - 1)
if v != w
add_edge!(network,v,w)
end
end
end
vertex_pointer+=num_vertices
end
return network
end
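For reference, the new single-pass construction should produce the same graph as the old per-household complete_graph/blockdiag reduction (the commented-out assertion further down checks exactly this). A minimal standalone sketch with toy compositions, assuming LightGraphs is loaded; not part of the commit:

using LightGraphs

toy_compositions = [[2, 1, 0], [1, 0, 2]]   # counts per demographic bin, one vector per household
new_g = complete_graph_from_households_composition(toy_compositions)
old_g = mapreduce(LightGraphs.complete_graph, blockdiag, sum.(toy_compositions); init = SimpleGraph())
@assert adjacency_matrix(new_g) == adjacency_matrix(old_g)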
#generate population with households distributed according to household_size_distribution
function generate_population(num_households)
households_composition = sample_household_data(num_households)
households = map( l -> [fill(AgentDemographic(i),n) for (i,n) in enumerate(l)],households_composition) |>
l -> reduce(vcat,l) |>
l -> reduce(vcat,l)
total_household_pop = sum(sum.(households_composition))
household_networks = map(LightGraphs.complete_graph, sum.(households_composition))
household_networks = complete_graph_from_households_composition(households_composition)
# @assert all(adjacency_matrix(household_networks) .== adjacency_matrix( mapreduce(LightGraphs.complete_graph,blockdiag, sum.(households_composition); init = SimpleGraph())))
households = map( l -> [fill(AgentDemographic(i),n) for (i,n) in enumerate(l)],households_composition) |>
l -> reduce(vcat,l) |>
l -> reduce(vcat,l)
population_list = reduce(vcat,households)
network = reduce(blockdiag,household_networks; init = SimpleGraph())
index_vectors = [findall(x -> x == AgentDemographic(i), population_list) for i in 1:(AgentDemographic.size-1)]
return (;
population_list,
network,
household_networks,
index_vectors
)
end
......
#A type that defines a matrix of vectors, such that the sum of the ijth vector is equal to the sum of the jith vector
struct MixingContacts{V}
struct MixingEdges{V}
total_edges::Int
contact_array::V
function MixingContacts(demographic_index_vectors,mixing_matrix)
function MixingEdges(demographic_index_vectors,mixing_matrix)
contacts = map(CartesianIndices(mixing_matrix)) do ind
zeros(Int,length(demographic_index_vectors[ind[1]]))
end
@@ -15,23 +15,6 @@ struct MixingContacts{V}
end
end
#modify g so that nodes specified in anodes and bnodes are connected by a bipartite graph with expected degrees given by aseq and bseq
#implemented from Aksoy, S. G., Kolda, T. G., & Pinar, A. (2017). Measuring and modeling bipartite graphs with community structure
#simple algorithm, lightgraphs does not allow parallel edges
function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq, weights_dict)
lena = length(aseq)
lenb = length(bseq)
m = Int(sum(aseq))
if m>0
# @assert sum(aseq) == sum(bseq) "degree sequences must have equal sum"
adist = sampler(Categorical(aseq./m))
bdist = sampler(Categorical(bseq./m))
for k in 1:m
add_edge!(g,anodes[rand(RNG,adist)], bnodes[rand(RNG,bdist)])
end
end
end
struct TimeDepMixingGraph{N,G}
resampled_graphs::NTuple{N,G}
graph_list::Vector{Vector{G}}
@@ -42,6 +25,35 @@ struct TimeDepMixingGraph{N,G}
)
end
end
#modify g so that nodes specified in anodes and bnodes are connected by a bipartite graph with expected degrees given by aseq and bseq
#implemented from Aksoy, S. G., Kolda, T. G., & Pinar, A. (2017). Measuring and modeling bipartite graphs with community structure
#simple algorithm, lightgraphs does not allow parallel edges.
#(also not actually that fast, this is a major bottleneck)
function random_bipartite_graphs_fast!(g, demographic_index_vectors,mixing_contacts)
astubs = Vector{Int}(undef,mixing_contacts.total_edges)
bstubs = Vector{Int}(undef,mixing_contacts.total_edges)
stubs_ptr = 1
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
anodes = demographic_index_vectors[i]
bnodes = demographic_index_vectors[j]
aseq = mixing_contacts.contact_array[i,j]
bseq = mixing_contacts.contact_array[j,i]
# @show sum(contacts[i,j])
m = Int(sum(aseq))
if m>0
astubs_ij = @view astubs[stubs_ptr:stubs_ptr + m - 1]
bstubs_ij = @view bstubs[stubs_ptr:stubs_ptr + m - 1]
@assert sum(aseq) == sum(bseq) "degree sequences must have equal sum"
rand!(RNG,DiscreteNonParametric(anodes,aseq./m),astubs_ij)
rand!(RNG,DiscreteNonParametric(bnodes,bseq./m),bstubs_ij)
for k in 1:m
add_edge!(g,astubs_ij[k], bstubs_ij[k])
end
stubs_ptr += m
end
end
return astubs,bstubs
end
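The sampling here is the expected-degree (Chung-Lu style) scheme from Aksoy et al.: draw m endpoint pairs, each endpoint chosen with probability proportional to its expected degree, and record the endpoints as stubs so edges can be revisited later without re-enumerating edges(g). A minimal sketch of the idea with toy values, using Categorical samplers over positions instead of the DiscreteNonParametric-over-node-ids form above; names and numbers below are illustrative only:

using Distributions, Random, LightGraphs

rng    = MersenneTwister(1)
anodes = [1, 2, 3]           # node ids in group A
bnodes = [4, 5]              # node ids in group B
aseq   = [2.0, 1.0, 1.0]     # expected degrees for A
bseq   = [3.0, 1.0]          # expected degrees for B (same total)
m      = Int(sum(aseq))

g      = SimpleGraph(5)
adist  = sampler(Categorical(aseq ./ m))
bdist  = sampler(Categorical(bseq ./ m))
astubs = [anodes[rand(rng, adist)] for _ in 1:m]
bstubs = [bnodes[rand(rng, bdist)] for _ in 1:m]
for k in 1:m
    add_edge!(g, astubs[k], bstubs[k])   # duplicate pairs collapse, since SimpleGraph has no parallel edges
end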
function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes
@@ -78,51 +90,48 @@ function remake!(time_dep_mixing_graph,demographic_index_vectors,mixing_matrix)
empty!.(weighted_graph.g.fadjlist) #empty all the vector edgelists
empty!(weighted_graph.weights_dict)
weighted_graph.g.ne = 0
contacts = MixingContacts(demographic_index_vectors,mixing_matrix)
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(
weighted_graph.g,
demographic_index_vectors[i],
demographic_index_vectors[j],
contacts.contact_array[i,j],
contacts.contact_array[j,i],
weighted_graph.weights_dict
)
end
contacts = MixingEdges(demographic_index_vectors,mixing_matrix)
astubs,bstubs = random_bipartite_graphs_fast!(weighted_graph.g,demographic_index_vectors,contacts)
weighted_graph.astubs = astubs
weighted_graph.bstubs = bstubs
end
end
#defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way
struct WeightedGraph{G, V,M}
#defining my own weighted graph type cause we need to be able to resample the edge weights in a very particular way
mutable struct WeightedGraph{G, V,M}
g::G
weights_dict::V
weights_distribution_matrix::M
astubs::Vector{Int}
bstubs::Vector{Int}
function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix)
contacts = MixingContacts(demographic_index_vectors,mixing_matrix)
mixing_contacts = MixingEdges(demographic_index_vectors,mixing_matrix)
g = Graph(length(demographics))
weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
sampler_matrix = map(m -> Distributions.PoissonCountSampler.(m),weights_distribution_matrix)
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights_dict)
end
sampler_matrix = map(m -> Distributions.PoissonCountSampler.(m),weights_distribution_matrix)
astubs,bstubs = random_bipartite_graphs_fast!(g,demographic_index_vectors,mixing_contacts)
return new{typeof(g),typeof(weights_dict),typeof(sampler_matrix)}(
g,
weights_dict,
sampler_matrix
sampler_matrix,
astubs,
bstubs
)
end
function WeightedGraph(g::SimpleGraph,weights_distribution_matrix)
astubs = Vector{Int}(undef,g.ne)
bstubs = Vector{Int}(undef,g.ne)
for (k,e) in enumerate(edges(g))
astubs[k] = src(e)
bstubs[k] = dst(e)
end
weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
sampler_matrix = map(m -> Distributions.PoissonCountSampler.(m),weights_distribution_matrix)
return new{typeof(g),typeof(weights_dict),typeof(sampler_matrix)}(
g,
weights_dict,
sampler_matrix
sampler_matrix,
astubs,
bstubs
)
end
end
@@ -130,21 +139,16 @@ function Base.show(io::IO, g::WeightedGraph)
print(io, "WG")
end
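#draw a fresh contact-time weight for every edge; iterating the stored stub arrays avoids re-enumerating edges(g) on every resample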
function sample_mixing_graph!(mixing_graph,population_demographics)
for (k,e) in enumerate(edges(mixing_graph.g))
i = src(e)
j = dst(e)
for k in eachindex(mixing_graph.astubs)
i = mixing_graph.astubs[k]
j = mixing_graph.bstubs[k]
demo_i = Int(population_demographics[i])
demo_j = Int(population_demographics[j])
contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j])
mixing_graph.weights_dict[(i,j)] = contact_time
mixing_graph.weights_dict[(j,i)] = contact_time
end
end
@inline function reindex!(k,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
i_index = index_list_i[k]
j_index = index_list_j[k]
@@ -179,6 +183,5 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j
end
end
end
return nothing
end
\ No newline at end of file
function get_parameters()
params = (
I_0_fraction = 0.05,
I_0_fraction = 0.01,
base_transmission_probability = 0.5,
recovery_rate = 0.1,
immunization_loss_prob = 0.5,
@@ -15,7 +15,7 @@ function get_parameters()
γ = 0.0,
β = 10.0,
notification_parameter = 0.0,
vaccinator_prob = 0.5,
vaccinator_prob = 0.2,
app_user_fraction = 0.5,
)
return params
@@ -72,7 +72,7 @@ struct ModelSolution{T,InfNet,SocNet,WSMixingDist,RestMixingDist}
covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values
time_of_last_alert = fill(-1,length(app_user_index)) #time of last alert is negative if no alert has been received
status_totals = [count(==(i), u_0_inf) for i in 1:AgentStatus.size]
status_totals = [count(==(AgentStatus(i)), u_0_inf) for i in 1:AgentStatus.size]
return new{T,typeof(infected_mixing_graph),typeof(soc_mixing_graph),typeof(ws_matrix_tuple),typeof(rest_matrix_tuple)}(
sim_length,
......
@@ -30,6 +30,7 @@ function record!(t,modelsol, recorder::DebugRecorder)
recorder.Total_Vaccinator[t] = count(==(true),modelsol.u_vac)
recorder.recorded_status_totals[:,t] .= modelsol.status_totals
end
function record!(t,modelsol, recorder::Nothing)
......
@@ -23,16 +23,15 @@ function update_alert_durations!(t,modelsol)
end
function update_infection_state!(t,modelsol)
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals,status_totals_next = modelsol
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals = modelsol
function agent_transition!(node, from::AgentStatus,to::AgentStatus)
status_totals_next[Int(from)] -= 1
status_totals_next[Int(to)] += 1
status_totals[Int(from)] -= 1
status_totals[Int(to)] += 1
u_next_inf[node] = to
end
u_next_inf .= u_inf
status_totals_next .= status_totals
for i in 1:modelsol.nodes
agent_status = u_inf[i]
is_vaccinator = u_vac[i]
@@ -43,8 +42,10 @@ function update_infection_state!(t,modelsol)
else
for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,i)
if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights_dict[(i,j)])
agent_transition!(i, Susceptible,Infected)
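#the added u_next_inf[i] != Infected guard lets the transition fire at most once per node per step, so status_totals is not decremented repeatedly when a node has several infected neighbours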
if u_inf[j] == Infected && u_next_inf[i] != Infected
if rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights_dict[(i,j)])
agent_transition!(i, Susceptible,Infected)
end
end
end
end
@@ -121,7 +122,7 @@ function agents_step!(t,modelsol)
modelsol.u_vac .= modelsol.u_next_vac
modelsol.u_inf .= modelsol.u_next_inf
modelsol.status_totals .= modelsol.status_totals_next
# modelsol.status_totals .= modelsol.status_totals_next
end
@@ -129,10 +130,8 @@ end
function solve!(modelsol,recording)
for t in 1:modelsol.sim_length
#advance agent states based on the new network
agents_step!(t,modelsol)
record!(t,modelsol,recording)
agents_step!(t,modelsol)
end
return recording
end
......
@@ -26,6 +26,14 @@ import Pandas: read_csv
using DataFrames
export intervalsmodel, hh, ws, rest, abm
const DNDEBUG = false
macro c_assert(boolean)
if DNDEBUG
message = string("Assertion: ", boolean, " failed")
:($(esc(boolean)) || error($message))
end
end
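Hypothetical usage of the macro (the check is only compiled in when DNDEBUG is true; with the default false it expands to nothing, so it costs nothing in normal runs):

@c_assert sum(aseq) == sum(bseq)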
const PACKAGE_FOLDER = dirname(dirname(pathof(CovidAlertVaccinationModel)))
const RNG = Xoroshiro128Star(1)
#consts that give us nicer names for the indices
......
@@ -18,6 +18,7 @@ function get_canada_case_fatality()::Tuple{Vector{Tuple{Float64,Float64}},Vector
# https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en
end
function find_household_composition(df_row)
# display(typeof(df_row))
age_resp_to_bin = Dict(
"Y" => 1,
"M" => 2,
@@ -26,13 +27,13 @@ function find_household_composition(df_row)
u25_bins = [:U15CHILD,:O15CHILD,:YSPOUSE]
m_bins = [:MPAR, :MCHILD,:MHHADULT]
o_bins = [:OPAR, :OSPOUSE,:OHHADULT]
age_distribution = [0,0,0]
age_distribution[1] += sum(df_row[field ] for field in u25_bins)
age_distribution[2] += sum(df_row[field] for field in m_bins)
age_distribution[3] += sum(df_row[field] for field in o_bins)
age_distribution =[
sum(Int(df_row[field]) for field in u25_bins),
sum(Int(df_row[field]) for field in m_bins),
sum(Int(df_row[field]) for field in o_bins),
]
age_distribution[age_resp_to_bin[df_row[:AGERESP]]] += 1
return age_distribution
return SVector{3}(age_distribution)
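As a rough illustration of the new version, with a hypothetical row carrying the survey fields used above (a NamedTuple stand-in for the CSV row):

row = (U15CHILD = 2, O15CHILD = 0, YSPOUSE = 0,
       MPAR = 0, MCHILD = 0, MHHADULT = 1,
       OPAR = 0, OSPOUSE = 1, OHHADULT = 0,
       AGERESP = "M")
find_household_composition(row)   # 3-element SVector: [2, 2, 1] (respondent counted in the middle bin)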
end
function sample_household_data(n)
f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',')
......
using CovidAlertVaccinationModel:ModelSolution,AgentDemographic,mean,AgentStatus,get_u_0,get_parameters,solve!, DebugRecorder
using Random
const model_sizes = [100,1000,5000]
const dem_cat = AgentDemographic.size -1
const steps = 300
const reps = 5
#network generation
#covidalert
@testset "mixing matrices, size: $sz" for (m,sz) in zip(agent_models,model_sizes)
ws_dist = m.ws_matrix_list
r_dist = m.rest_matrix_list
index_vec =m.demographic_index_vectors
@testset "workschool" for i in dem_cat, j in dem_cat
for t in 1:length(ws_dist)
@test mean(ws_dist[t][i,j])*length(index_vec[i]) == mean(ws_dist[t][j,i])*length(index_vec[j])
@testset "mixing matrices, size: $sz" for sz in model_sizes
for rep = 1:reps
m = ModelSolution(10,get_parameters(),sz)
ws_dist = m.ws_matrix_tuple
r_dist = m.rest_matrix_tuple
index_vec =m.index_vectors
@testset "workschool" for i in dem_cat, j in dem_cat
for t in 1:length(ws_dist)
@test mean(ws_dist[t][i,j])*length(index_vec[i]) == mean(ws_dist[t][j,i])*length(index_vec[j])
end
end
end
@testset "rest" for i in dem_cat, j in dem_cat
for t in 1:length(ws_dist)
@test mean(r_dist[t][i,j])*length(index_vec[i]) == mean(r_dist[t][j,i])*length(index_vec[j])
@testset "rest" for i in dem_cat, j in dem_cat
for t in 1:length(ws_dist)
@test mean(r_dist[t][i,j])*length(index_vec[i]) == mean(r_dist[t][j,i])*length(index_vec[j])
end
end
end
end
function vaccinator_opinion_test(model,vac_strategy, π_base; rng = Xoroshiro128Plus())
params = merge(get_parameters(),(I_0_fraction = 0.0,π_base))
steps = 300
sol1,_ = solve!(u_0,params,steps,model,vac_strategy);
total_infections = count(x->x == AgentStatus(3),sol1[end])
display(total_infections)
return total_infections
@testset "status counter, size: $sz" for sz in model_sizes
for rep = 1:reps
Random.seed!(CovidAlertVaccinationModel.RNG,1)
m = ModelSolution(steps,get_parameters(),sz)
recording = DebugRecorder(steps)
solve!(m,recording)
@test all(recording.Total_S .== recording.recorded_status_totals[1,:])
@test all(recording.Total_I .== recording.recorded_status_totals[2,:])
@test all(recording.Total_R .== recording.recorded_status_totals[3,:])
@test all(recording.Total_V .== recording.recorded_status_totals[4,:])
end
end
function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus())
steps = 300
# display(params)
sol1,_ = solve!(params,steps,model,vaccinate_uniformly!);
total_infections = count(x->x == AgentStatus(3),sol1[end])
# function vaccinator_opinion_test(model,vac_strategy, π_base; rng = Xoroshiro128Plus())
# params = merge(get_parameters(),(I_0_fraction = 0.0,π_base))
# steps = 300
# sol1,_ = solve!(u_0,params,steps,model,vac_strategy);
# total_infections = count(x->x == AgentStatus(3),sol1[end])
# display(total_infections)
# return total_infections
# end
# display(total_infections)
return total_infections
end
function test_comparison(f,xpts,comparison)
xpts_sorted = sort(xpts)
ypts = ThreadsX.map(f,xpts)
return all(comparison(ypts[i],ypts[i+1]) for i in 1:length(ypts)-1)
end
# function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus())
# steps = 300
# # display(params)
# sol1,_ = solve!(params,steps,model,vaccinate_uniformly!);
# total_infections = count(x->x == AgentStatus(3),sol1[end])
# # display(total_infections)
# return total_infections
# end
# function test_comparison(f,xpts,comparison)
# xpts_sorted = sort(xpts)
# ypts = ThreadsX.map(f,xpts)
# return all(comparison(ypts[i],ypts[i+1]) for i in 1:length(ypts)-1)
# end
# @testset "vaccination efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
# @show vac_rate_test(m,vaccination_strategies[1],vaccination_rates[1])
@@ -55,14 +74,14 @@ end
# end
# end
@testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
@test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
end
# @testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
# @test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
# end
@testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
@test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
end
# @testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
# @test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
# end
......
#todo: write tests for this
\ No newline at end of file
using CovidAlertVaccinationModel:AgentModel,AgentDemographic,mean,vaccinate_uniformly!,AgentStatus,get_u_0,get_parameters,solve!
using RandomNumbers.Xorshifts
using Test
using ThreadsX
import StatsBase.mean
include("ABM/abm_test.jl")
\ No newline at end of file
include("ABM/abm_test.jl")
include("IntervalsModel/intervals_model_test.jl")
\ No newline at end of file