Skip to content
Snippets Groups Projects
Commit 189c2fa4 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

improve performance in model loading

parent ff1d2569
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
const population = 14.57e6 #population of ontario
const workschool_mixing, rest_mixing = load_mixing_matrices()
const age_bins = [(0.0, 25.0),(25.0,65.0),(65.0,Inf)]
const household_data = read_household_data()
default(dpi = 300)
default(framestyle = :box)
......@@ -13,7 +13,7 @@ Runs the model with 5k households for 500 timesteps.
"""
function bench()
Random.seed!(RNG,1)
steps = 500
steps = 50
model_sol = ModelSolution(steps,get_parameters(),5000)
recording = DebugRecorder(steps)
output = solve!(model_sol,recording )
......
......@@ -26,23 +26,29 @@ end
Generates a complete graph from a vector of household compositions. Each household composition is a 3 element vectors (one for each demographic group) of integers where each element describes the number of the corresponding demographic group present in that household.
This function wires together a graph such that each household is in a complete subgraph. It is much faster than the `mapreduce` solution used previously.
Also returns a vector of AgentDemographic representing each agent, defined by the household compositions.
"""
function complete_graph_from_households_composition(households_composition)
@views function complete_graph_from_households_composition(households_composition)
total_household_pop = sum(sum.(households_composition))
population_list = Vector{AgentDemographic}(undef,total_household_pop)
network = SimpleGraph(total_household_pop)
vertex_pointer = 1
for household in households_composition
num_vertices = sum(household)
for v in vertex_pointer:(vertex_pointer + num_vertices - 1)
for w in vertex_pointer:(vertex_pointer + num_vertices - 1)
if v != w
add_edge!(network,v,w)
end
household_pointer = 0
for (k,size) in enumerate(household)
@inbounds population_list[vertex_pointer+household_pointer:vertex_pointer+household_pointer+size-1] .= AgentDemographic(k)
household_pointer += size
end
for v in vertex_pointer:(vertex_pointer + num_vertices - 1), w in vertex_pointer:(vertex_pointer + num_vertices - 1)
if v != w
add_edge!(network,v,w)
end
end
vertex_pointer+=num_vertices
end
return network
return network,population_list
end
......@@ -59,19 +65,8 @@ index_vectors: Vector of 3 vectors, each of which contains the indexes of all th
"""
function generate_population(num_households)
households_composition = sample_household_data(num_households)
household_networks = complete_graph_from_households_composition(households_composition)
#test the complete_graph_from_households_composition function against the mapreduce version to ensure it is correct.
@c_assert all(adjacency_matrix(household_networks) .== adjacency_matrix( mapreduce(LightGraphs.complete_graph,blockdiag, sum.(households_composition); init = SimpleGraph())))
households = map( l -> [fill(AgentDemographic(i),n) for (i,n) in enumerate(l)],households_composition) |>
l -> reduce(vcat,l) |>
l -> reduce(vcat,l)
population_list = reduce(vcat,households)
household_networks,population_list = complete_graph_from_households_composition(households_composition)
index_vectors = [findall(x -> x == AgentDemographic(i), population_list) for i in 1:(AgentDemographic.size-1)]
return (;
population_list,
household_networks,
......@@ -81,7 +76,7 @@ end
"""
Defines pretty printing methods so that `display(AgentStatus)` is more readable.
Defines pretty printing methods so that `display(s::AgentStatus)` is more readable.
"""
function Base.show(io::IO, status::AgentStatus)
if status == Susceptible
......@@ -96,7 +91,7 @@ function Base.show(io::IO, status::AgentStatus)
end
"""
Defines pretty printing methods so that `display(AgentDemographic)` is more readable.
Defines pretty printing methods so that `display(s::AgentDemographic)` is more readable.
"""
function Base.show(io::IO, status::AgentDemographic)
if status == Young
......
# using LoopVectorization
function contact_weight(p, contact_time)
return 1 - (1-p)^contact_time
end
Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
#remove Base.@propagate_inbounds if you get segfaults
@unpack notification_parameter = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_times,app_user = modelsol
for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network.graph_list[t]
for j in 2:14
covid_alert_times[j-1,i] = covid_alert_times[j,i] #shift them all back
......@@ -15,13 +24,17 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
covid_alert_times[end,i] += get_weight(mixing_graph,GraphEdge(node,j)) #add the contact times for today to the back
end
end
if rand(RNG) < 1 - (1- notification_parameter)^sum(covid_alert_times[:,i])
covid_alert_total_exposures = 1 - (1 - notification_parameter) ^ sum(covid_alert_times[:,i])
if rand(RNG) < covid_alert_total_exposures
time_of_last_alert[i] = t
end
end
end
Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
#remove Base.@propagate_inbounds if you get segfaults
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals = modelsol
......@@ -65,6 +78,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
# display(u_next_inf)
end
Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections)
#remove Base.@propagate_inbounds if you get segfaults
@unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
......
......@@ -35,20 +35,20 @@ function find_household_composition(df_row)
age_distribution[age_resp_to_bin[df_row[:AGERESP]]] += 1
return SVector{3}(age_distribution)
end
function sample_household_data(n)
function read_household_data()
f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',')
df = DataFrame([f[1,i] => f[2:end, i] for i = 1:length(f[1,:])])
weight_vector::Vector{Float64} = df[!,:WGHT_PER]/sum(df[!,:WGHT_PER])
households = map(find_household_composition,eachrow(df))
return sample(RNG,households, Weights(weight_vector),n)
# https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en
return (;households,weight_vector)
end
function sample_household_data(n)
return sample(RNG,household_data.households, Weights(household_data.weight_vector),n)
end
function get_household_data_proportions()
f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',')
df = DataFrame([f[1,i] => f[2:end, i] for i = 1:length(f[1,:])])
weight_vector::Vector{Float64} = df[!,:WGHT_PER]/sum(df[!,:WGHT_PER])
households = map(find_household_composition,eachrow(df))
households_by_demographic_sum = sum.([map(l-> l[i], households) for i in 1:3])
households_by_demographic_sum = sum.([map(l-> l[i],household_data.households) for i in 1:3])
return households_by_demographic_sum./sum(households_by_demographic_sum)
# https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment