Commit 7a0347d8 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

refactor graph rewiring so we can turn off degree sampling, fixed error but still not working yet

parent b4fa17e4
...@@ -36,9 +36,9 @@ See `load_contact_time_distributions()`. ...@@ -36,9 +36,9 @@ See `load_contact_time_distributions()`.
""" """
const contact_time_distributions = load_contact_time_distributions() const contact_time_distributions = load_contact_time_distributions()
function shift_contact_distributions(sampler_matrix::AbstractMatrix{T},proportion) where T<:Distributions.PoissonADSampler function shift_contact_distributions(sampler_matrix::AbstractMatrix{T},proportion) where T
return map(sampler_matrix) do sampler return map(sampler_matrix) do sampler
μ_old = sampler.μ μ_old = mean(sampler)
return Poisson(μ_old*proportion) return Poisson(μ_old*proportion)
end end
end end
......
...@@ -82,9 +82,9 @@ Assumes the simulation begins on Thursday arbitrarily. ...@@ -82,9 +82,9 @@ Assumes the simulation begins on Thursday arbitrarily.
function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple) function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
#weekly multiply durations by 1/5 #weekly multiply durations by 1/5
daily_shift(dur_dists) = shift_contact_distributions(dur_dists,1/5) daily_shift(dur_dists) = dur_dists#shift_contact_distributions(dur_dists,1/5)
#daily multiply by 1/2 #daily multiply by 1/2
weekly_shift(dur_dists) = shift_contact_distributions(dur_dists,1/2) weekly_shift(dur_dists) = dur_dists#shift_contact_distributions(dur_dists,1/2)
home_static_edges = WeightedGraph(base_network,demographics,contact_time_distributions.hh) #network with households and LTC homes home_static_edges = WeightedGraph(base_network,demographics,contact_time_distributions.hh) #network with households and LTC homes
...@@ -166,15 +166,18 @@ struct WeightedGraph{G,M1,M2} ...@@ -166,15 +166,18 @@ struct WeightedGraph{G,M1,M2}
weights_dict::Dictionary{GraphEdge,UInt8} weights_dict::Dictionary{GraphEdge,UInt8}
mixing_matrix::M1 mixing_matrix::M1
sampler_matrix::M2 sampler_matrix::M2
degrees_matrix::Union{Nothing,Matrix{Vector{Int}}}
function WeightedGraph(demographics::AbstractVector,index_vectors,mixing_matrix::M1, sampler_matrix::M2) where {M1,M2} function WeightedGraph(demographics::AbstractVector,index_vectors,mixing_matrix::M1, sampler_matrix::M2) where {M1,M2}
g = Graph(length(demographics)) g = Graph(length(demographics))
weights_dict = Dictionary{GraphEdge,UInt8}() weights_dict = Dictionary{GraphEdge,UInt8}()
assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix) degrees_matrix = [similar(index_vectors[i]) for i = 1:3, j = 1:3]
assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix,degrees_matrix)
return new{typeof(g),M1,M2}( return new{typeof(g),M1,M2}(
g, g,
weights_dict, weights_dict,
mixing_matrix, mixing_matrix,
sampler_matrix, sampler_matrix,
degrees_matrix
) )
end end
function WeightedGraph(g::G,demographics,sampler_matrix::M2) where {G<:SimpleGraph,M2} function WeightedGraph(g::G,demographics,sampler_matrix::M2) where {G<:SimpleGraph,M2}
...@@ -190,6 +193,7 @@ struct WeightedGraph{G,M1,M2} ...@@ -190,6 +193,7 @@ struct WeightedGraph{G,M1,M2}
weights_dict, weights_dict,
nothing, nothing,
sampler_matrix, sampler_matrix,
nothing
) )
end end
end end
...@@ -198,16 +202,22 @@ function remake!(wg::WeightedGraph,index_vectors) ...@@ -198,16 +202,22 @@ function remake!(wg::WeightedGraph,index_vectors)
empty!.(wg.g.fadjlist) #empty all the vector edgelists empty!.(wg.g.fadjlist) #empty all the vector edgelists
wg.g.ne = 0 wg.g.ne = 0
empty!(wg.weights_dict) empty!(wg.weights_dict)
assemble_graph!(wg.g,wg.weights_dict,index_vectors,wg.mixing_matrix,wg.sampler_matrix) assemble_graph!(wg.g,wg.weights_dict,index_vectors,wg.mixing_matrix,wg.sampler_matrix,wg.degrees_matrix; resample_degrees = true)
end end
function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix) function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix,degree_matrix; resample_degrees=true)
for i in 1:3, j in 1:i #diagonal for i in 1:3, j in 1:i #diagonal
if i != j if i != j
edges = fast_chung_lu_bipartite(index_vectors[i],index_vectors[j],mixing_matrix[i,j],mixing_matrix[j,i]) if resample_degrees
generate_contact_vectors!(mixing_matrix[i,j],mixing_matrix[j,i],degree_matrix[i,j],degree_matrix[j,i])
end
edges = fast_chung_lu_bipartite(degree_matrix[i,j],degree_matrix[j,i],index_vectors[i],index_vectors[j])
else #from one group to itself we need another algorithm else #from one group to itself we need another algorithm
edges = fast_chung_lu(index_vectors[i],mixing_matrix[i,i]) if resample_degrees
generate_contact_vectors!(mixing_matrix[i,j],degree_matrix[i,j])
end
edges = fast_chung_lu(degree_matrix[i,j],index_vectors[i])
end end
for e in edges for e in edges
edge_weight_k = rand(Random.default_rng(Threads.threadid()),sampler_matrix[j,i]) edge_weight_k = rand(Random.default_rng(Threads.threadid()),sampler_matrix[j,i])
...@@ -218,34 +228,32 @@ function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matr ...@@ -218,34 +228,32 @@ function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matr
add_edge!(g,e.a,e.b) add_edge!(g,e.a,e.b)
end end
end end
function fast_chung_lu_bipartite(pop_i,pop_j,mixing_dist_ij,mixing_dist_ji)
num_degrees_ij = similar(pop_i) function fast_chung_lu_bipartite(degrees_ij,degrees_ji,index_vectors_i,index_vectors_j)
num_degrees_ji = similar(pop_j) m = sum(degrees_ij)
generate_contact_vectors!(mixing_dist_ij,mixing_dist_ji,num_degrees_ij,num_degrees_ji) @assert m == sum(degrees_ji)
num_edges = sum(num_degrees_ij) stubs_i = Vector{Int}(undef,m)
stubs_i = Vector{Int}(undef,num_edges)
stubs_j = similar(stubs_i) stubs_j = similar(stubs_i)
if num_edges>0 if m>0
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ij./num_edges),stubs_i) sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ij./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),pop_j,Weights(num_degrees_ji./num_edges),stubs_j) sample!(Random.default_rng(Threads.threadid()),index_vectors_j,Weights(degrees_ji./m),stubs_j)
end end
return GraphEdge.(stubs_i,stubs_j) return GraphEdge.(stubs_i,stubs_j)
end end
function fast_chung_lu(pop_i,mixing_dist) function fast_chung_lu(degrees_ii,index_vectors_i)
num_degrees_ii = similar(pop_i) m = sum(degrees_ii)
generate_contact_vectors!(mixing_dist,num_degrees_ii)
m = sum(num_degrees_ii)
num_edges= div(m,2) num_edges= div(m,2)
stubs_i = Vector{Int}(undef,num_edges) stubs_i = Vector{Int}(undef,num_edges)
stubs_j = similar(stubs_i) stubs_j = similar(stubs_i)
if m>0 if m>0
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ii./m),stubs_i) sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ii./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ii./m),stubs_j) sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ii./m),stubs_j)
end end
return GraphEdge.(stubs_i,stubs_j) return GraphEdge.(stubs_i,stubs_j)
end end
neighbors(g::WeightedGraph,i) = LightGraphs.neighbors(g.g,i) neighbors(g::WeightedGraph,i) = LightGraphs.neighbors(g.g,i)
get_weight(g::WeightedGraph,e) = g.weights_dict[e] get_weight(g::WeightedGraph,e) = g.weights_dict[e]
function Base.show(io::IO, g::WeightedGraph) function Base.show(io::IO, g::WeightedGraph)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
κ = 0.0, κ = 0.0,
ω = 0.0061, ω = 0.0061,
ω_en = 0.05, ω_en = 0.05,
Γ = 0.906, Γ = 1/7,#0.906,
ξ = 5.0, ξ = 5.0,
notification_parameter = 0.0005, notification_parameter = 0.0005,
vaccinator_prob = 0.6, vaccinator_prob = 0.6,
......
module CovidAlertVaccinationModel module CovidAlertVaccinationModel
using Distributions: PoissonADSampler
using Intervals: Ending
using Base: Float64, NamedTuple
using LightGraphs using LightGraphs
using RandomNumbers.Xorshifts using RandomNumbers.Xorshifts
using Random using Random
......
...@@ -142,6 +142,8 @@ end ...@@ -142,6 +142,8 @@ end
fit!(mixing_dist[Int(demo_v), j],d) fit!(mixing_dist[Int(demo_v), j],d)
end end
end end
# display(mean.(mixing_dist))
# display(expected_dist_mean)
for i in eachindex(mixing_dist) for i in eachindex(mixing_dist)
@test mean(mixing_dist[i]) expected_dist_mean[i] atol = 0.05 @test mean(mixing_dist[i]) expected_dist_mean[i] atol = 0.05
@test mean(mixing_dist[i]) mean(dist[i]) atol = 0.2 @test mean(mixing_dist[i]) mean(dist[i]) atol = 0.2
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment