Commit 7a0347d8 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

refactor graph rewiring so we can turn off degree sampling, fixed error but still not working yet

parent b4fa17e4
......@@ -36,9 +36,9 @@ See `load_contact_time_distributions()`.
"""
const contact_time_distributions = load_contact_time_distributions()
function shift_contact_distributions(sampler_matrix::AbstractMatrix{T},proportion) where T<:Distributions.PoissonADSampler
function shift_contact_distributions(sampler_matrix::AbstractMatrix{T},proportion) where T
return map(sampler_matrix) do sampler
μ_old = sampler.μ
μ_old = mean(sampler)
return Poisson(μ_old*proportion)
end
end
......
......@@ -82,9 +82,9 @@ Assumes the simulation begins on Thursday arbitrarily.
function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_matrix_tuple,rest_matrix_tuple)
#weekly multiply durations by 1/5
daily_shift(dur_dists) = shift_contact_distributions(dur_dists,1/5)
daily_shift(dur_dists) = dur_dists#shift_contact_distributions(dur_dists,1/5)
#daily multiply by 1/2
weekly_shift(dur_dists) = shift_contact_distributions(dur_dists,1/2)
weekly_shift(dur_dists) = dur_dists#shift_contact_distributions(dur_dists,1/2)
home_static_edges = WeightedGraph(base_network,demographics,contact_time_distributions.hh) #network with households and LTC homes
......@@ -166,15 +166,18 @@ struct WeightedGraph{G,M1,M2}
weights_dict::Dictionary{GraphEdge,UInt8}
mixing_matrix::M1
sampler_matrix::M2
degrees_matrix::Union{Nothing,Matrix{Vector{Int}}}
function WeightedGraph(demographics::AbstractVector,index_vectors,mixing_matrix::M1, sampler_matrix::M2) where {M1,M2}
g = Graph(length(demographics))
weights_dict = Dictionary{GraphEdge,UInt8}()
assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix)
degrees_matrix = [similar(index_vectors[i]) for i = 1:3, j = 1:3]
assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix,degrees_matrix)
return new{typeof(g),M1,M2}(
g,
weights_dict,
mixing_matrix,
sampler_matrix,
degrees_matrix
)
end
function WeightedGraph(g::G,demographics,sampler_matrix::M2) where {G<:SimpleGraph,M2}
......@@ -190,6 +193,7 @@ struct WeightedGraph{G,M1,M2}
weights_dict,
nothing,
sampler_matrix,
nothing
)
end
end
......@@ -198,16 +202,22 @@ function remake!(wg::WeightedGraph,index_vectors)
empty!.(wg.g.fadjlist) #empty all the vector edgelists
wg.g.ne = 0
empty!(wg.weights_dict)
assemble_graph!(wg.g,wg.weights_dict,index_vectors,wg.mixing_matrix,wg.sampler_matrix)
assemble_graph!(wg.g,wg.weights_dict,index_vectors,wg.mixing_matrix,wg.sampler_matrix,wg.degrees_matrix; resample_degrees = true)
end
function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix)
function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix,degree_matrix; resample_degrees=true)
for i in 1:3, j in 1:i #diagonal
if i != j
edges = fast_chung_lu_bipartite(index_vectors[i],index_vectors[j],mixing_matrix[i,j],mixing_matrix[j,i])
if resample_degrees
generate_contact_vectors!(mixing_matrix[i,j],mixing_matrix[j,i],degree_matrix[i,j],degree_matrix[j,i])
end
edges = fast_chung_lu_bipartite(degree_matrix[i,j],degree_matrix[j,i],index_vectors[i],index_vectors[j])
else #from one group to itself we need another algorithm
edges = fast_chung_lu(index_vectors[i],mixing_matrix[i,i])
if resample_degrees
generate_contact_vectors!(mixing_matrix[i,j],degree_matrix[i,j])
end
edges = fast_chung_lu(degree_matrix[i,j],index_vectors[i])
end
for e in edges
edge_weight_k = rand(Random.default_rng(Threads.threadid()),sampler_matrix[j,i])
......@@ -218,34 +228,32 @@ function assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matr
add_edge!(g,e.a,e.b)
end
end
function fast_chung_lu_bipartite(pop_i,pop_j,mixing_dist_ij,mixing_dist_ji)
num_degrees_ij = similar(pop_i)
num_degrees_ji = similar(pop_j)
generate_contact_vectors!(mixing_dist_ij,mixing_dist_ji,num_degrees_ij,num_degrees_ji)
num_edges = sum(num_degrees_ij)
stubs_i = Vector{Int}(undef,num_edges)
function fast_chung_lu_bipartite(degrees_ij,degrees_ji,index_vectors_i,index_vectors_j)
m = sum(degrees_ij)
@assert m == sum(degrees_ji)
stubs_i = Vector{Int}(undef,m)
stubs_j = similar(stubs_i)
if num_edges>0
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ij./num_edges),stubs_i)
sample!(Random.default_rng(Threads.threadid()),pop_j,Weights(num_degrees_ji./num_edges),stubs_j)
if m>0
sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ij./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),index_vectors_j,Weights(degrees_ji./m),stubs_j)
end
return GraphEdge.(stubs_i,stubs_j)
end
function fast_chung_lu(pop_i,mixing_dist)
num_degrees_ii = similar(pop_i)
generate_contact_vectors!(mixing_dist,num_degrees_ii)
m = sum(num_degrees_ii)
function fast_chung_lu(degrees_ii,index_vectors_i)
m = sum(degrees_ii)
num_edges= div(m,2)
stubs_i = Vector{Int}(undef,num_edges)
stubs_j = similar(stubs_i)
if m>0
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ii./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),pop_i,Weights(num_degrees_ii./m),stubs_j)
sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ii./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),index_vectors_i,Weights(degrees_ii./m),stubs_j)
end
return GraphEdge.(stubs_i,stubs_j)
end
neighbors(g::WeightedGraph,i) = LightGraphs.neighbors(g.g,i)
get_weight(g::WeightedGraph,e) = g.weights_dict[e]
function Base.show(io::IO, g::WeightedGraph)
......
......@@ -21,7 +21,7 @@
κ = 0.0,
ω = 0.0061,
ω_en = 0.05,
Γ = 0.906,
Γ = 1/7,#0.906,
ξ = 5.0,
notification_parameter = 0.0005,
vaccinator_prob = 0.6,
......
module CovidAlertVaccinationModel
using Distributions: PoissonADSampler
using Intervals: Ending
using Base: Float64, NamedTuple
using LightGraphs
using RandomNumbers.Xorshifts
using Random
......
......@@ -142,6 +142,8 @@ end
fit!(mixing_dist[Int(demo_v), j],d)
end
end
# display(mean.(mixing_dist))
# display(expected_dist_mean)
for i in eachindex(mixing_dist)
@test mean(mixing_dist[i]) expected_dist_mean[i] atol = 0.05
@test mean(mixing_dist[i]) mean(dist[i]) atol = 0.2
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment