Commit 90d4a13e authored by Peter Jentsch's avatar Peter Jentsch
Browse files

explicitly specific poissoncountsampler, 30% faster

parent 2f890228
......@@ -97,6 +97,12 @@ git-tree-sha1 = "44e9f638aa9ed1ad58885defc568c133010140aa"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.37"
[[CheapThreads]]
deps = ["ArrayInterface", "IfElse", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "5b11b7aba0d4b53f46559b25055d3f0d4376fb36"
uuid = "b630d9fa-e28e-4980-896d-83ce5e2106b2"
version = "0.2.2"
[[ColorSchemes]]
deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random", "StaticArrays"]
git-tree-sha1 = "d3cf83862f70d430d4b34e43ed65e74bd50ae0e0"
......@@ -597,6 +603,12 @@ git-tree-sha1 = "59b45fd91b743dff047313bb7af0f84167aef80d"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "0.4.6"
[[LoopVectorization]]
deps = ["ArrayInterface", "CheapThreads", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Requires", "SLEEFPirates", "Static", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "da82a865158e1d62e013fa72657fda7e92820a90"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.12"
[[MKL_jll]]
deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "c253236b0ed414624b083e6b72bfe891fbd2c7af"
......@@ -973,6 +985,12 @@ git-tree-sha1 = "ced55fd4bae008a8ea12508314e725df61f0ba45"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.7"
[[StrideArraysCore]]
deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "da1091034d295c8dbaf1d6ea16529221bc24afe1"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.1.5"
[[StructArrays]]
deps = ["Adapt", "DataAPI", "Tables"]
git-tree-sha1 = "44b3afd37b17422a62aea25f04c1f7e09ce6b07f"
......@@ -1025,6 +1043,12 @@ version = "0.1.3"
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[ThreadingUtilities]]
deps = ["VectorizationBase"]
git-tree-sha1 = "063f52eee44ec303f1721cd59b4d7892cae9f1cc"
uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5"
version = "0.4.1"
[[ThreadsX]]
deps = ["ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "Setfield", "SplittablesBase", "Transducers"]
git-tree-sha1 = "269f5c1955c1194086cf6d2029aa4a0b4fb8018b"
......@@ -1228,7 +1252,7 @@ uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10"
version = "1.4.0+3"
[[ZeroWeightedDistributions]]
deps = ["Distributions", "Random", "StatsBase"]
deps = ["Distributions", "KernelDensity", "Plots", "Random", "RandomNumbers", "StatsBase", "ThreadsX"]
path = "../ZeroWeightedDistributions"
uuid = "24733ad3-391a-4e41-8839-c7177de7dea4"
version = "0.1.0"
......
......@@ -19,6 +19,7 @@ LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
Pandas = "eadc2687-ae89-51f9-a5d9-86b5a6373a9c"
......
......@@ -9,6 +9,7 @@ default(framestyle = :box)
function bench()
steps = 100
Random.seed!(RNG,1)
model_sol = ModelSolution(steps,get_parameters(),5000);
recording = DebugRecorder(steps)
output = solve!(model_sol,recording )
......
......@@ -22,13 +22,14 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq,
lena = length(aseq)
lenb = length(bseq)
m = Int(sum(aseq))
@assert sum(aseq) == sum(bseq) "degree sequences must have equal sum"
astubs = sample(RNG,anodes,StatsBase.weights(aseq./m), m)
bstubs = sample(RNG,bnodes,StatsBase.weights(bseq./m), m)
for k in 1:m
add_edge!(g,astubs[k],bstubs[k])
if m>0
# @assert sum(aseq) == sum(bseq) "degree sequences must have equal sum"
adist = sampler(Categorical(aseq./m))
bdist = sampler(Categorical(bseq./m))
for k in 1:m
add_edge!(g,anodes[rand(RNG,adist)], bnodes[rand(RNG,bdist)])
end
end
return g
end
struct TimeDepMixingGraph{N,G}
......@@ -105,23 +106,23 @@ struct WeightedGraph{G, V,M}
weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
sampler_matrix = map(m -> Distributions.PoissonCountSampler.(m),weights_distribution_matrix)
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights_dict)
end
covid_alert_time = zeros(nv(g))
return new{typeof(g),typeof(weights_dict),typeof(weights_distribution_matrix)}(
return new{typeof(g),typeof(weights_dict),typeof(sampler_matrix)}(
g,
weights_dict,
weights_distribution_matrix
sampler_matrix
)
end
function WeightedGraph(g::SimpleGraph,weights_distribution_matrix)
weights_dict = RobinDict{Tuple{Int,Int},UInt8}()
return new{typeof(g),typeof(weights_dict),typeof(weights_distribution_matrix)}(
sampler_matrix = map(m -> Distributions.PoissonCountSampler.(m),weights_distribution_matrix)
return new{typeof(g),typeof(weights_dict),typeof(sampler_matrix)}(
g,
weights_dict,
weights_distribution_matrix
weights_dict,
sampler_matrix
)
end
end
......@@ -140,6 +141,10 @@ function sample_mixing_graph!(mixing_graph,population_demographics)
mixing_graph.weights_dict[(j,i)] = contact_time
end
end
@inline function reindex!(k,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
i_index = index_list_i[k]
j_index = index_list_j[k]
......
# using LoopVectorization
function contact_weight(p, contact_time)
return 1 - (1-p)^contact_time
end
......@@ -69,7 +69,7 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
for i in 1:nodes
vac_payoff = 0
soc_nbrs_vac = [0,0,0]
soc_nbrs_vac = @MArray [0,0,0]
soc_nbrs_nonvac = 0
num_soc_nbrs = 0
for sc_g in soc_network.graph_list[t]
......@@ -110,7 +110,7 @@ end
function agents_step!(t,modelsol)
remake!(modelsol.inf_network,modelsol.index_vectors,modelsol.ws_matrix_tuple.daily)
remake!(modelsol.inf_network,modelsol.index_vectors,modelsol.rest_matrix_tuple.daily)
remake!(modelsol.soc_network,modelsol.index_vectors,modelsol.rest_matrix_tuple.daily)
for network in modelsol.inf_network.graph_list[t]
sample_mixing_graph!(network,modelsol.demographics) #get new contact weights
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment