Skip to content
Snippets Groups Projects
Commit 189c2fa4 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

improve performance in model loading

parent ff1d2569
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
const population = 14.57e6 #population of ontario const population = 14.57e6 #population of ontario
const workschool_mixing, rest_mixing = load_mixing_matrices() const workschool_mixing, rest_mixing = load_mixing_matrices()
const age_bins = [(0.0, 25.0),(25.0,65.0),(65.0,Inf)] const age_bins = [(0.0, 25.0),(25.0,65.0),(65.0,Inf)]
const household_data = read_household_data()
default(dpi = 300) default(dpi = 300)
default(framestyle = :box) default(framestyle = :box)
...@@ -13,7 +13,7 @@ Runs the model with 5k households for 500 timesteps. ...@@ -13,7 +13,7 @@ Runs the model with 5k households for 500 timesteps.
""" """
function bench() function bench()
Random.seed!(RNG,1) Random.seed!(RNG,1)
steps = 500 steps = 50
model_sol = ModelSolution(steps,get_parameters(),5000) model_sol = ModelSolution(steps,get_parameters(),5000)
recording = DebugRecorder(steps) recording = DebugRecorder(steps)
output = solve!(model_sol,recording ) output = solve!(model_sol,recording )
......
...@@ -26,23 +26,29 @@ end ...@@ -26,23 +26,29 @@ end
Generates a complete graph from a vector of household compositions. Each household composition is a 3 element vectors (one for each demographic group) of integers where each element describes the number of the corresponding demographic group present in that household. Generates a complete graph from a vector of household compositions. Each household composition is a 3 element vectors (one for each demographic group) of integers where each element describes the number of the corresponding demographic group present in that household.
This function wires together a graph such that each household is in a complete subgraph. It is much faster than the `mapreduce` solution used previously. This function wires together a graph such that each household is in a complete subgraph. It is much faster than the `mapreduce` solution used previously.
Also returns a vector of AgentDemographic representing each agent, defined by the household compositions.
""" """
function complete_graph_from_households_composition(households_composition) @views function complete_graph_from_households_composition(households_composition)
total_household_pop = sum(sum.(households_composition)) total_household_pop = sum(sum.(households_composition))
population_list = Vector{AgentDemographic}(undef,total_household_pop)
network = SimpleGraph(total_household_pop) network = SimpleGraph(total_household_pop)
vertex_pointer = 1 vertex_pointer = 1
for household in households_composition for household in households_composition
num_vertices = sum(household) num_vertices = sum(household)
for v in vertex_pointer:(vertex_pointer + num_vertices - 1) household_pointer = 0
for w in vertex_pointer:(vertex_pointer + num_vertices - 1) for (k,size) in enumerate(household)
if v != w @inbounds population_list[vertex_pointer+household_pointer:vertex_pointer+household_pointer+size-1] .= AgentDemographic(k)
add_edge!(network,v,w) household_pointer += size
end end
for v in vertex_pointer:(vertex_pointer + num_vertices - 1), w in vertex_pointer:(vertex_pointer + num_vertices - 1)
if v != w
add_edge!(network,v,w)
end end
end end
vertex_pointer+=num_vertices vertex_pointer+=num_vertices
end end
return network return network,population_list
end end
...@@ -59,19 +65,8 @@ index_vectors: Vector of 3 vectors, each of which contains the indexes of all th ...@@ -59,19 +65,8 @@ index_vectors: Vector of 3 vectors, each of which contains the indexes of all th
""" """
function generate_population(num_households) function generate_population(num_households)
households_composition = sample_household_data(num_households) households_composition = sample_household_data(num_households)
household_networks,population_list = complete_graph_from_households_composition(households_composition)
household_networks = complete_graph_from_households_composition(households_composition)
#test the complete_graph_from_households_composition function against the mapreduce version to ensure it is correct.
@c_assert all(adjacency_matrix(household_networks) .== adjacency_matrix( mapreduce(LightGraphs.complete_graph,blockdiag, sum.(households_composition); init = SimpleGraph())))
households = map( l -> [fill(AgentDemographic(i),n) for (i,n) in enumerate(l)],households_composition) |>
l -> reduce(vcat,l) |>
l -> reduce(vcat,l)
population_list = reduce(vcat,households)
index_vectors = [findall(x -> x == AgentDemographic(i), population_list) for i in 1:(AgentDemographic.size-1)] index_vectors = [findall(x -> x == AgentDemographic(i), population_list) for i in 1:(AgentDemographic.size-1)]
return (; return (;
population_list, population_list,
household_networks, household_networks,
...@@ -81,7 +76,7 @@ end ...@@ -81,7 +76,7 @@ end
""" """
Defines pretty printing methods so that `display(AgentStatus)` is more readable. Defines pretty printing methods so that `display(s::AgentStatus)` is more readable.
""" """
function Base.show(io::IO, status::AgentStatus) function Base.show(io::IO, status::AgentStatus)
if status == Susceptible if status == Susceptible
...@@ -96,7 +91,7 @@ function Base.show(io::IO, status::AgentStatus) ...@@ -96,7 +91,7 @@ function Base.show(io::IO, status::AgentStatus)
end end
""" """
Defines pretty printing methods so that `display(AgentDemographic)` is more readable. Defines pretty printing methods so that `display(s::AgentDemographic)` is more readable.
""" """
function Base.show(io::IO, status::AgentDemographic) function Base.show(io::IO, status::AgentDemographic)
if status == Young if status == Young
......
# using LoopVectorization
function contact_weight(p, contact_time) function contact_weight(p, contact_time)
return 1 - (1-p)^contact_time return 1 - (1-p)^contact_time
end end
Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
#remove Base.@propagate_inbounds if you get segfaults
@unpack notification_parameter = modelsol.params @unpack notification_parameter = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network,covid_alert_times,app_user = modelsol @unpack time_of_last_alert, app_user_index,inf_network,covid_alert_times,app_user = modelsol
for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network.graph_list[t] for (i,node) in enumerate(modelsol.app_user_index), mixing_graph in modelsol.inf_network.graph_list[t]
for j in 2:14 for j in 2:14
covid_alert_times[j-1,i] = covid_alert_times[j,i] #shift them all back covid_alert_times[j-1,i] = covid_alert_times[j,i] #shift them all back
...@@ -15,13 +24,17 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) ...@@ -15,13 +24,17 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol)
covid_alert_times[end,i] += get_weight(mixing_graph,GraphEdge(node,j)) #add the contact times for today to the back covid_alert_times[end,i] += get_weight(mixing_graph,GraphEdge(node,j)) #add the contact times for today to the back
end end
end end
if rand(RNG) < 1 - (1- notification_parameter)^sum(covid_alert_times[:,i])
covid_alert_total_exposures = 1 - (1 - notification_parameter) ^ sum(covid_alert_times[:,i])
if rand(RNG) < covid_alert_total_exposures
time_of_last_alert[i] = t time_of_last_alert[i] = t
end end
end end
end end
Base.@propagate_inbounds @views function update_infection_state!(t,modelsol) Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
#remove Base.@propagate_inbounds if you get segfaults
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params @unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals = modelsol @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network,status_totals = modelsol
...@@ -65,6 +78,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol) ...@@ -65,6 +78,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
# display(u_next_inf) # display(u_next_inf)
end end
Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections) Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,modelsol,total_infections)
#remove Base.@propagate_inbounds if you get segfaults
@unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params @unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol @unpack demographics,time_of_last_alert, nodes, soc_network,u_vac,u_next_vac,app_user,app_user_list = modelsol
......
...@@ -35,20 +35,20 @@ function find_household_composition(df_row) ...@@ -35,20 +35,20 @@ function find_household_composition(df_row)
age_distribution[age_resp_to_bin[df_row[:AGERESP]]] += 1 age_distribution[age_resp_to_bin[df_row[:AGERESP]]] += 1
return SVector{3}(age_distribution) return SVector{3}(age_distribution)
end end
function sample_household_data(n) function read_household_data()
f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',') f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',')
df = DataFrame([f[1,i] => f[2:end, i] for i = 1:length(f[1,:])]) df = DataFrame([f[1,i] => f[2:end, i] for i = 1:length(f[1,:])])
weight_vector::Vector{Float64} = df[!,:WGHT_PER]/sum(df[!,:WGHT_PER]) weight_vector::Vector{Float64} = df[!,:WGHT_PER]/sum(df[!,:WGHT_PER])
households = map(find_household_composition,eachrow(df)) households = map(find_household_composition,eachrow(df))
return sample(RNG,households, Weights(weight_vector),n) return (;households,weight_vector)
# https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en end
function sample_household_data(n)
return sample(RNG,household_data.households, Weights(household_data.weight_vector),n)
end end
function get_household_data_proportions() function get_household_data_proportions()
f = readdlm(joinpath(PACKAGE_FOLDER,"data/csv/home_compositions.csv"), ',') households_by_demographic_sum = sum.([map(l-> l[i],household_data.households) for i in 1:3])
df = DataFrame([f[1,i] => f[2:end, i] for i = 1:length(f[1,:])])
weight_vector::Vector{Float64} = df[!,:WGHT_PER]/sum(df[!,:WGHT_PER])
households = map(find_household_composition,eachrow(df))
households_by_demographic_sum = sum.([map(l-> l[i], households) for i in 1:3])
return households_by_demographic_sum./sum(households_by_demographic_sum) return households_by_demographic_sum./sum(households_by_demographic_sum)
# https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en # https://www.publichealthontario.ca/-/media/documents/ncov/epi/covid-19-severe-outcomes-ontario-epi-summary.pdf?la=en
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment