Commit e27ac728 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

performance improvements: removed sum calc in bipartite sampling,...

performance improvements: removed sum calc in bipartite sampling, array-specialized ZWDist sampling, ensured union splitting for mixing distribution matrices
parent 2ce4277d
......@@ -11,9 +11,6 @@ git-tree-sha1 = "dedbbb2ddb876f899585c4ec4433265e3017215a"
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
version = "2.1.0"
[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
[[ArnoldiMethod]]
deps = ["LinearAlgebra", "Random", "StaticArrays"]
git-tree-sha1 = "f87e559f87a45bece9c9ed97458d3afe98b1ebb9"
......@@ -21,7 +18,10 @@ uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
version = "0.1.0"
[[Artifacts]]
deps = ["Pkg"]
git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744"
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
version = "1.3.0"
[[BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
......@@ -87,8 +87,10 @@ uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.25.0"
[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70"
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"
[[CompositionsBase]]
git-tree-sha1 = "f3955eb38944e5dd0fabf8ca1e267d94941d34a5"
......@@ -156,10 +158,6 @@ git-tree-sha1 = "f0e06a5b5ccda38e2fb8f59d91316e657b67047d"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.24.12"
[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[EarCut_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "92d8f9f208637e8d2d28c664051a00569c01493d"
......@@ -353,22 +351,10 @@ git-tree-sha1 = "3a0084cec7bf157edcb45a67fac0647f88fe5eaf"
uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
version = "0.14.7"
[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
[[LibVPX_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "85fcc80c3052be96619affa2fe2e6d2da3908e11"
......@@ -456,8 +442,10 @@ uuid = "739be429-bea8-5141-9913-cc70e7f3736d"
version = "1.0.3"
[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "0eef589dd1c26a3ac9d753fe1a8bcad63f956fa6"
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.16.8+1"
[[Measures]]
git-tree-sha1 = "e498ddeee6f9fdb4551ce855a46f54dbd900245f"
......@@ -479,9 +467,6 @@ version = "0.4.5"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
......@@ -498,9 +483,6 @@ git-tree-sha1 = "01b5715cdd1b7c5d493c26cc05e4af663ba9a052"
uuid = "46757867-2c16-5918-afeb-47bfcb05e46a"
version = "0.3.0"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[Ogg_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "a42c0f138b9ebe8b58eba2271c5053773bde52d0"
......@@ -560,7 +542,7 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc"
version = "0.40.0+0"
[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[PlotThemes]]
......@@ -610,7 +592,7 @@ uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.4.1"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
......@@ -759,10 +741,6 @@ version = "1.2.3"
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
[[TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e"
......@@ -775,12 +753,8 @@ git-tree-sha1 = "8dc2bb7d3548e315d890706547b24502ed79504f"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.3.1"
[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[ThreadsX]]
......@@ -958,8 +932,10 @@ uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10"
version = "1.4.0+3"
[[Zlib_jll]]
deps = ["Libdl"]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6"
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.11+18"
[[Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -997,10 +973,6 @@ git-tree-sha1 = "fa14ac25af7a4b8a7f61b287a124df7aab601bcd"
uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a"
version = "1.3.6+6"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
[[x264_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d713c1ce4deac133e3334ee12f4adff07f81778f"
......
graphplot.gif

4.38 MB | W: | H:

graphplot.gif

5.08 MB | W: | H:

graphplot.gif
graphplot.gif
graphplot.gif
graphplot.gif
  • 2-up
  • Swipe
  • Onion skin
......@@ -17,7 +17,7 @@ using NamedTupleTools
using NetworkLayout:Stress
using NetworkLayout:SFDP
import Base.rand
import LightGraphs.add_edge!
include("agents.jl")
include("mixing_distributions.jl")
......@@ -34,27 +34,19 @@ const PACKAGE_FOLDER = dirname(dirname(pathof(CovidAlertVaccinationModel)))
const infection_data = parse_cases_data()
const demographic_distribution = get_canada_demographic_distribution()
#shift distribution,shift means, geometric changes to poisson if mean goes below 1
#ZW expectation is (1 - α)(expecation of base distribution)
#expectation/(1 - α)
# CUDA.allowscalar(false)
default(dpi = 300)
default(framestyle = :box)
using BenchmarkTools
function main()
rng = Xoroshiro128Plus()
agent_model = AgentModel(rng,500,2)
rng = Xoroshiro128Star(1)
agent_model = AgentModel(rng,5000,2)
# display(size.(agent_model.demographic_index_vectors))
u_0 = get_u_0(rng,length(agent_model.demographics))
steps = 300
# sol1,graphs = solve!(rng,u_0,get_parameters(),steps,agent_model,vaccinate_uniformly!);
steps = 10
@btime solve!($rng,$u_0,$get_parameters(),$steps,$agent_model,$vaccinate_uniformly!);
# plot_model(agent_model.base_network,graphs,sol1)
# @btime solve!($rng,$u_0,$get_parameters(),$steps,$agent_model,$vaccinate_uniformly!);
end
#returns the total infected given a certain num_households, vaccines per day, and strategy
end
\ No newline at end of file
......@@ -7,14 +7,31 @@ function generate_mixing_graph!(rng,g,index_vectors,mixing_matrix)
i_to_j_contacts = rand(rng,agent_contact_dist_i,length(index_vectors[i]))
j_to_i_contacts = rand(rng,agent_contact_dist_j,length(index_vectors[j]))
while sum(i_to_j_contacts) - sum(j_to_i_contacts) != 0
i_to_j_contacts[sample(rng,1:length(index_vectors[i]))] = rand(rng,agent_contact_dist_i)
j_to_i_contacts[sample(rng,1:length(index_vectors[j]))] = rand(rng,agent_contact_dist_j)
# equalize_degree_lists!(rng,i_to_j_contacts,j_to_i_contacts)
contacts_sums = sum(i_to_j_contacts) - sum(j_to_i_contacts)
while contacts_sums != 0
i_index = sample(rng,1:length(index_vectors[i]))
j_index = sample(rng,1:length(index_vectors[j]))
contacts_sums -= i_to_j_contacts[i_index]
contacts_sums += j_to_i_contacts[j_index]
i_to_j_contacts[i_index] = rand(rng,agent_contact_dist_i)
j_to_i_contacts[j_index] = rand(rng,agent_contact_dist_j)
contacts_sums += i_to_j_contacts[i_index]
contacts_sums -= j_to_i_contacts[j_index]
end
random_bipartite_graph_fast_CL!(rng,g,index_vectors[i],index_vectors[j],i_to_j_contacts,j_to_i_contacts)
# random_bipartite_graph_fast_CL!(rng,g,index_vectors[i],index_vectors[j],i_to_j_contacts,j_to_i_contacts)
end
return g
end
# function searchsortedfirst(l,things_to_find::T) where T<:AbstractVector
# end
#modify g so that nodes specified in anodes and bnodes are connected by a bipartite graph with expected degrees given by aseq and bseq
#implemented from Aksoy, S. G., Kolda, T. G., & Pinar, A. (2017). Measuring and modeling bipartite graphs with community structure
......@@ -24,8 +41,8 @@ function random_bipartite_graph_fast_CL!(rng,g,anodes,bnodes,aseq,bseq)
lenb = length(bseq)
m = sum(aseq)
@assert sum(aseq) == sum(bseq) "degree sequences must have equal sum"
astubs = sample(rng,anodes,StatsBase.weights(aseq./m),m;replace = true)
bstubs = sample(rng,bnodes,StatsBase.weights(bseq./m),m;replace = true)
astubs = sample(rng,anodes,StatsBase.weights(aseq./m), m; replace = true)
bstubs = sample(rng,bnodes,StatsBase.weights(bseq./m), m; replace = true)
for k in 1:m
add_edge!(g,astubs[k],bstubs[k])
end
......
......@@ -11,9 +11,10 @@ function Base.rand(rng::AbstractRNG, s::ZWDist)
end
end
# function from_mean(::Type{Geometric{T}},μ) where T
# return Geometric(1/μ)
# end
function Base.rand(rng::AbstractRNG, s::ZWDist, n::T) where T<:Int
return ifelse.(Base.rand(rng,n) .< s.α, 0, Base.rand(rng,s.base_dist,n))
end
function from_mean(::Type{Geometric{T}},μ) where T
if μ > 1.0
return Geometric(1/(μ+1))
......@@ -36,6 +37,7 @@ function ZeroGeometric(α,p)
end
StatsBase.mean(d::ZWDist{Dist,T}) where {Dist,T} = (1 - d.α)*StatsBase.mean(d.base_dist)
const initial_workschool_type = Union{ZWDist{Geometric{Float64},Float64},ZWDist{Poisson{Float64},Float64}}
const initial_workschool_mixing_matrix = map(t->from_mean(t...),[
(ZWDist{Geometric{Float64},Float64}, 0.433835,4.104848) (ZWDist{Geometric{Float64},Float64},0.406326,2.568782) (ZWDist{Poisson{Float64},Float64},0.888015,0.017729) (ZWDist{Geometric{Float64},Float64},0.406326,2.568782) (ZWDist{Poisson{Float64},Float64},0.888015,0.017729)
......@@ -46,6 +48,7 @@ const initial_workschool_mixing_matrix = map(t->from_mean(t...),[
])
const initial_rest_type = Union{Geometric{Float64},Poisson{Float64}}
const initial_rest_mixing_matrix = map(t->from_mean(t...),[
(Geometric{Float64},2.728177) (Geometric{Float64},1.382557) (Poisson{Float64},0.206362) (Geometric{Float64},1.382557) (Poisson{Float64},0.206362)
(Poisson{Float64},1.139072) (Geometric{Float64},3.245594) (Poisson{Float64},0.785297) (Geometric{Float64},3.245594) (Poisson{Float64},0.785297)
......@@ -54,6 +57,6 @@ const initial_rest_mixing_matrix = map(t->from_mean(t...),[
(Poisson{Float64},0.264822) (Poisson{Float64},0.734856) (Poisson{Float64},0.667099) (Poisson{Float64},0.734856) (Poisson{Float64},0.667099)
])
const contact_time_distribution_matrix = [Hypergeometric(5,2,4) for i in 1:AgentDemographic.size + 1, j in 1:AgentDemographic.size + 1]
const contact_time_distribution_matrix = [Geometric() for i in 1:AgentDemographic.size + 1, j in 1:AgentDemographic.size + 1]
......@@ -53,9 +53,9 @@ function solve!(rng,u_0,params,steps,agent_model,vaccination_algorithm!)
for t in 1:steps
graph_t = deepcopy(base_network) #copy static network to modify with dynamic workschool/rest contacts
generate_mixing_graph!(rng,graph_t,index_vectors,agent_model.workschool_contacts_mean_adjusted) #add workschool contacts
generate_mixing_graph!(rng,graph_t,index_vectors,agent_model.rest_contacts_mean_adjusted) #add rest contacts
push!(graphs,graph_t) #add the generated graph for this timestep onto the list of graphs
agents_step!(rng,t,solution[t+1],solution[t],population_list,graph_t,params,index_vectors,vaccination_algorithm!)
# generate_mixing_graph!(rng,graph_t,index_vectors,agent_model.rest_contacts_mean_adjusted) #add rest contacts
# push!(graphs,graph_t) #add the generated graph for this timestep onto the list of graphs
# agents_step!(rng,t,solution[t+1],solution[t],population_list,graph_t,params,index_vectors,vaccination_algorithm!)
#advance agent states based on the new network, vaccination process given by vaccination_algorithm!, which is just a function defined as above
end
......
import Base.Order: Ordering, Forward, ord, lt
#sorted merging is from https://github.com/vvjn/MergeSorted.jl, an old unmaintained package
function mergesorted!(
v::AbstractVector,
lo::Int, hi::Int,
vl::AbstractVector,
lol::Int, hil::Int,
vr::AbstractVector,
lor::Int, hir::Int,
order::Ordering
)
c = lol
p = lor
nl = hil
nr = hir
i = lo
@inbounds while c <= nl && p <= nr && i <= hi
if lt(order, vr[p], vl[c])
v[i] = vr[p]
p = p+1
i = i+1
else
v[i] = vl[c]
c = c+1
i = i+1
end
end
@inbounds while p <= nr && i <= hi
v[i] = vr[p]
i = i+1
p = p+1
end
@inbounds while c <= nl && i <= hi
v[i] = vl[c]
i = i+1
c = c+1
end
v
end
function mergesorted!(v::AbstractVector, vl::AbstractVector,
vr::AbstractVector, order::Ordering)
inds = eachindex(v)
indsl = eachindex(vl)
indsr = eachindex(vr)
mergesorted!(v,first(inds),last(inds),vl,first(indsl),last(indsl),
vr,first(indsr),last(indsr),order)
end
function mergesorted!(v::AbstractVector,
vl::AbstractVector,
vr::AbstractVector;
lt=isless,
by=identity,
rev::Bool=false,
order::Ordering=Forward
)
ordr = ord(lt,by,rev,order)
mergesorted!(v, vl, vr, ordr)
end
using LightGraphs:SimpleGraphEdge
function add_edges!(g::SimpleGraph{T}, edges::Vector{SimpleGraphEdge{T}}) where T
verts = vertices(g)
@inbounds list = g.fadjlist[s]
index = searchsortedfirst(list, d)
insert!(list, index, d)
@inbounds list = g.fadjlist[d]
index = searchsortedfirst(list, s)
insert!(list, index, s)
g.ne += length(edges)
return true # edge successfully added
end
\ No newline at end of file
......@@ -32,7 +32,6 @@ function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus())
steps = 300
sol1,graphs = solve!(rng,u_0,params,steps,model,vac_strategy);
total_infections = count(x->x == AgentStatus(3),sol1[end])
display((vac_rate,total_infections))
return total_infections
end
......@@ -44,7 +43,6 @@ function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus())
sol1,graphs = solve!(rng,u_0,params,steps,model,vaccinate_uniformly!);
total_infections = count(x->x == AgentStatus(3),sol1[end])
display((inf_parameter,total_infections))
return total_infections
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment