Skip to content
Snippets Groups Projects
Commit a27f8d5b authored by Peter Jentsch's avatar Peter Jentsch
Browse files

switch to VectorizedRNG

parent fd1ec66c
No related branches found
No related tags found
No related merge requests found
...@@ -39,9 +39,9 @@ version = "0.1.0" ...@@ -39,9 +39,9 @@ version = "0.1.0"
[[ArrayInterface]] [[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d84e8967b7f04f52c9bca21714bae54a553a53fc" git-tree-sha1 = "b08be763d0b8ddee6b162016dad746a69980616d"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.10" version = "3.1.11"
[[Artifacts]] [[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
...@@ -92,9 +92,9 @@ version = "1.16.0+6" ...@@ -92,9 +92,9 @@ version = "1.16.0+6"
[[ChainRulesCore]] [[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"] deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "bd0cc939d94b8bd736dce5bbbe0d635db9f94af7" git-tree-sha1 = "e6b23566e025d3b0d9ccc397f5c7a134af552e27"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.41" version = "0.9.42"
[[CheapThreads]] [[CheapThreads]]
deps = ["ArrayInterface", "IfElse", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities", "VectorizationBase"] deps = ["ArrayInterface", "IfElse", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities", "VectorizationBase"]
...@@ -128,9 +128,9 @@ version = "0.3.0" ...@@ -128,9 +128,9 @@ version = "0.3.0"
[[Compat]] [[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956" git-tree-sha1 = "0a817fbe51c976de090aa8c997b7b719b786118d"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0" version = "3.28.0"
[[CompilerSupportLibraries_jll]] [[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
...@@ -155,9 +155,9 @@ version = "0.1.2" ...@@ -155,9 +155,9 @@ version = "0.1.2"
[[ConstructionBase]] [[ConstructionBase]]
deps = ["LinearAlgebra"] deps = ["LinearAlgebra"]
git-tree-sha1 = "48920211c95a6da1914a06c44ec94be70e84ffff" git-tree-sha1 = "1dc43957fb9a1574fa1b7a449e101bd1fd3a9fb7"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.1.0" version = "1.2.1"
[[Contour]] [[Contour]]
deps = ["StaticArrays"] deps = ["StaticArrays"]
...@@ -183,9 +183,9 @@ version = "1.6.0" ...@@ -183,9 +183,9 @@ version = "1.6.0"
[[DataFrames]] [[DataFrames]]
deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "6e5452d9cf401ed9048e1cde93815be53d951079" git-tree-sha1 = "66ee4fe515a9294a8836ef18eea7239c6ac3db5e"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.0.2" version = "1.1.1"
[[DataStructures]] [[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
...@@ -322,9 +322,9 @@ version = "2.8.0" ...@@ -322,9 +322,9 @@ version = "2.8.0"
[[FiniteDifferences]] [[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
git-tree-sha1 = "ae1ce4975c393bad81a55cb073e2af04bd151f85" git-tree-sha1 = "80e1a7416cbf08fe80c8885e1834c45cfc399c61"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.3" version = "0.12.5"
[[FixedPointNumbers]] [[FixedPointNumbers]]
deps = ["Statistics"] deps = ["Statistics"]
...@@ -415,9 +415,9 @@ version = "0.9.8" ...@@ -415,9 +415,9 @@ version = "0.9.8"
[[Hwloc]] [[Hwloc]]
deps = ["Hwloc_jll"] deps = ["Hwloc_jll"]
git-tree-sha1 = "ffdcd4272a7cc36442007bca41aa07ca3cc5fda4" git-tree-sha1 = "92d99146066c5c6888d5a3abc871e6a214388b91"
uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d" uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
version = "1.3.0" version = "2.0.0"
[[Hwloc_jll]] [[Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
...@@ -678,9 +678,9 @@ version = "0.4.6" ...@@ -678,9 +678,9 @@ version = "0.4.6"
[[LoopVectorization]] [[LoopVectorization]]
deps = ["ArrayInterface", "CheapThreads", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"] deps = ["ArrayInterface", "CheapThreads", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "427ec6a601c32d704bb664b32bf695519ef66043" git-tree-sha1 = "2ad016117de05750443ed219dada9df3f9b9fa8f"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890" uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.15" version = "0.12.18"
[[LsqFit]] [[LsqFit]]
deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
...@@ -751,9 +751,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" ...@@ -751,9 +751,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[MutableArithmetics]] [[MutableArithmetics]]
deps = ["LinearAlgebra", "SparseArrays", "Test"] deps = ["LinearAlgebra", "SparseArrays", "Test"]
git-tree-sha1 = "3301e152b9a208745fad6cd4b068307a5d218a38" git-tree-sha1 = "ad9b2bce6021631e0e20706d361972343a03e642"
uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
version = "0.2.18" version = "0.2.19"
[[NLSolversBase]] [[NLSolversBase]]
deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"]
...@@ -782,9 +782,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" ...@@ -782,9 +782,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OffsetArrays]] [[OffsetArrays]]
deps = ["Adapt"] deps = ["Adapt"]
git-tree-sha1 = "87a728aebb76220bd72855e1c85284c5fdb9774c" git-tree-sha1 = "47b443d2ccc8297a4c538f55f8fd828ad58599ab"
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.7.0" version = "1.8.0"
[[Ogg_jll]] [[Ogg_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
...@@ -829,9 +829,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372" ...@@ -829,9 +829,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372"
version = "1.3.1+3" version = "1.3.1+3"
[[OrderedCollections]] [[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0" version = "1.4.1"
[[PCRE_jll]] [[PCRE_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
...@@ -892,9 +892,9 @@ version = "1.0.10" ...@@ -892,9 +892,9 @@ version = "1.0.10"
[[Plots]] [[Plots]]
deps = ["Base64", "Contour", "Dates", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"] deps = ["Base64", "Contour", "Dates", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"]
git-tree-sha1 = "eced322de627fa8469c55f24899fdd4cce7d978c" git-tree-sha1 = "2628e5859819173cef995470af83db42bf411ef8"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.13.2" version = "1.14.0"
[[Polynomials]] [[Polynomials]]
deps = ["Intervals", "LinearAlgebra", "MutableArithmetics", "RecipesBase"] deps = ["Intervals", "LinearAlgebra", "MutableArithmetics", "RecipesBase"]
...@@ -910,9 +910,9 @@ version = "1.2.1" ...@@ -910,9 +910,9 @@ version = "1.2.1"
[[Preferences]] [[Preferences]]
deps = ["TOML"] deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902" git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250" uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1" version = "1.2.2"
[[PrettyTables]] [[PrettyTables]]
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
...@@ -932,9 +932,9 @@ version = "0.1.4" ...@@ -932,9 +932,9 @@ version = "0.1.4"
[[ProgressMeter]] [[ProgressMeter]]
deps = ["Distributed", "Printf"] deps = ["Distributed", "Printf"]
git-tree-sha1 = "d85d8f0339a9937afac93e152c76f4745b386202" git-tree-sha1 = "1be8800271c86f572d334fef6e3b8364eaece7d9"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca" uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.6.0" version = "1.6.2"
[[PyCall]] [[PyCall]]
deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"]
...@@ -1025,9 +1025,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" ...@@ -1025,9 +1025,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[SLEEFPirates]] [[SLEEFPirates]]
deps = ["IfElse", "Libdl", "VectorizationBase"] deps = ["IfElse", "Libdl", "VectorizationBase"]
git-tree-sha1 = "3e682ce17a16c9dfb9d2fde0ceb2347e23faafba" git-tree-sha1 = "91a650350dcf6e0fc1a014b59669e704d8f579ae"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
version = "0.6.15" version = "0.6.17"
[[Scratch]] [[Scratch]]
deps = ["Dates"] deps = ["Dates"]
...@@ -1099,9 +1099,9 @@ version = "0.2.4" ...@@ -1099,9 +1099,9 @@ version = "0.2.4"
[[StaticArrays]] [[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"] deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "2653e9c769343808781a8bd5010ee7a17c01152e" git-tree-sha1 = "fb46e45ef2cade8be20bb445b3ffeca3c6d6f7d3"
uuid = "90137ffa-7385-5640-81b9-e52037218182" uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.2" version = "1.1.3"
[[Statistics]] [[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"] deps = ["LinearAlgebra", "SparseArrays"]
...@@ -1126,9 +1126,9 @@ version = "0.9.8" ...@@ -1126,9 +1126,9 @@ version = "0.9.8"
[[StrideArraysCore]] [[StrideArraysCore]]
deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"] deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "62a9b1e31f0741a642455f42ddaa9582101b3e71" git-tree-sha1 = "f93118d367c8dec873c26a32ad2dea84989edd7d"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.1.6" version = "0.1.7"
[[StructArrays]] [[StructArrays]]
deps = ["Adapt", "DataAPI", "Tables"] deps = ["Adapt", "DataAPI", "Tables"]
...@@ -1195,10 +1195,10 @@ uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" ...@@ -1195,10 +1195,10 @@ uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
version = "0.1.7" version = "0.1.7"
[[TimeZones]] [[TimeZones]]
deps = ["Dates", "EzXML", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"] deps = ["Dates", "EzXML", "LazyArtifacts", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"]
git-tree-sha1 = "3f6f0be07f33e33bd986a58b4cf2d6c9fd2b7f18" git-tree-sha1 = "960099aed321e05ac649c90d583d59c9309faee1"
uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
version = "1.5.4" version = "1.5.5"
[[Transducers]] [[Transducers]]
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
...@@ -1225,9 +1225,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" ...@@ -1225,9 +1225,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VectorizationBase]] [[VectorizationBase]]
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"] deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"]
git-tree-sha1 = "293aa2c5cbf201e6b98810cb36d9eeafdafdafd1" git-tree-sha1 = "c258467e1a3473e328c8a9109efcce845723593e"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.19.34" version = "0.19.37"
[[VectorizedRNG]] [[VectorizedRNG]]
deps = ["Distributed", "Random", "UnPack", "VectorizationBase"] deps = ["Distributed", "Random", "UnPack", "VectorizationBase"]
......
...@@ -42,5 +42,10 @@ Run the model with given parameter tuple and output recorder. See `get_parameter ...@@ -42,5 +42,10 @@ Run the model with given parameter tuple and output recorder. See `get_parameter
function abm(parameters, recorder) function abm(parameters, recorder)
model_sol = ModelSolution(parameters.sim_length,parameters,5000) model_sol = ModelSolution(parameters.sim_length,parameters,5000)
output = solve!(model_sol,recorder ) output = solve!(model_sol,recorder )
total_weighted_degree = map(modelsol.index_vectors) do age_group_indices
return mean(map(i -> weighted_degree(i,modelsol.inf_network),age_group_indices))
end
@show total_weighted_degree
return model_sol return model_sol
end end
...@@ -27,8 +27,8 @@ Fills i_to_j_contacts and j_to_i_contacts with degrees sampled from ij_dist and ...@@ -27,8 +27,8 @@ Fills i_to_j_contacts and j_to_i_contacts with degrees sampled from ij_dist and
Given `μz_i = mean(ij_dist)` and `μ_j = mean(ji_dist)`, these must satisfy `μ_i* length(i_to_j_contacts) == μ_j* length(j_to_i_contacts)` Given `μz_i = mean(ij_dist)` and `μ_j = mean(ji_dist)`, these must satisfy `μ_i* length(i_to_j_contacts) == μ_j* length(j_to_i_contacts)`
""" """
function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j_to_i_contacts::Vector{T}) where T function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j_to_i_contacts::Vector{T}) where T
rand!(Random.default_rng(Threads.threadid()),ij_dist,i_to_j_contacts) rand!(local_rng(),ij_dist,i_to_j_contacts)
rand!(Random.default_rng(Threads.threadid()),ji_dist,j_to_i_contacts) rand!(local_rng(),ji_dist,j_to_i_contacts)
l_i = length(i_to_j_contacts) l_i = length(i_to_j_contacts)
l_j = length(j_to_i_contacts) l_j = length(j_to_i_contacts)
...@@ -41,10 +41,10 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j ...@@ -41,10 +41,10 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j
sample_list_j = similar(sample_list_i) sample_list_j = similar(sample_list_i)
while csum != 0 while csum != 0
sample!(Random.default_rng(Threads.threadid()),1:l_i,index_list_i) sample!(local_rng(),1:l_i,index_list_i)
sample!(Random.default_rng(Threads.threadid()),1:l_j,index_list_j) sample!(local_rng(),1:l_j,index_list_j)
rand!(Random.default_rng(Threads.threadid()),ij_dist,sample_list_i) rand!(local_rng(),ij_dist,sample_list_i)
rand!(Random.default_rng(Threads.threadid()),ji_dist,sample_list_j) rand!(local_rng(),ji_dist,sample_list_j)
@inbounds for i = 1:inner_iter @inbounds for i = 1:inner_iter
if csum != 0 if csum != 0
csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j) csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
......
...@@ -57,5 +57,5 @@ const unemployment_matrix = alpha_matrix( ...@@ -57,5 +57,5 @@ const unemployment_matrix = alpha_matrix(
Sample initial_workschool_mixing_matrix, which is the workschool distributions symmetrized for the full Canadian population, rather than subsets (as used in the ABM). This is used in IntervalsModel. Sample initial_workschool_mixing_matrix, which is the workschool distributions symmetrized for the full Canadian population, rather than subsets (as used in the ABM). This is used in IntervalsModel.
""" """
@views function ws_sample(age) @views function ws_sample(age)
return rand.(initial_workschool_mixing_matrix[age,:]) * (rand(Random.default_rng(Threads.threadid())) < (5/7)) return rand.(initial_workschool_mixing_matrix[age,:]) * (rand(local_rng()) < (5/7))
end end
...@@ -37,7 +37,7 @@ function sample_mixing_graph!(mixing_graph) ...@@ -37,7 +37,7 @@ function sample_mixing_graph!(mixing_graph)
# display(length.(mixing_edges.contact_array)) # display(length.(mixing_edges.contact_array))
# display(length.(mixing_edges.sample_cache)) # display(length.(mixing_edges.sample_cache))
for i in 1:size(mixing_edges.contact_array)[1], j in 1:i #diagonal for i in 1:size(mixing_edges.contact_array)[1], j in 1:i #diagonal
rand!(Random.default_rng(Threads.threadid()), mixing_edges.sampler_matrix[j,i],mixing_edges.sample_cache[j,i]) rand!(local_rng(), mixing_edges.sampler_matrix[j,i],mixing_edges.sample_cache[j,i])
for k in 1:length(mixing_edges.contact_array[j,i]) for k in 1:length(mixing_edges.contact_array[j,i])
edge_weight_k = mixing_edges.sample_cache[j,i][k] edge_weight_k = mixing_edges.sample_cache[j,i][k]
set!(mixing_edges.weights_dict, mixing_edges.contact_array[j,i][k], edge_weight_k) set!(mixing_edges.weights_dict, mixing_edges.contact_array[j,i][k], edge_weight_k)
...@@ -98,8 +98,8 @@ function create_mixing_edges(demographic_index_vectors,mixing_matrix,weights_dis ...@@ -98,8 +98,8 @@ function create_mixing_edges(demographic_index_vectors,mixing_matrix,weights_dis
stubs_i = Vector{Int}(undef,m) stubs_i = Vector{Int}(undef,m)
stubs_j = Vector{Int}(undef,m) stubs_j = Vector{Int}(undef,m)
if m>0 if m>0
sample!(Random.default_rng(Threads.threadid()),demographic_index_vectors[i],Weights(num_degrees_ij./m),stubs_i) sample!(local_rng(),demographic_index_vectors[i],Weights(num_degrees_ij./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),demographic_index_vectors[j],Weights(num_degrees_ji./m),stubs_j) sample!(local_rng(),demographic_index_vectors[j],Weights(num_degrees_ji./m),stubs_j)
tot += m tot += m
end end
...@@ -176,7 +176,7 @@ function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_m ...@@ -176,7 +176,7 @@ function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_m
if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess
push!(l, ws_static_edges) push!(l, ws_static_edges)
end end
if rand(Random.default_rng(Threads.threadid()))<5/7 if rand(local_rng())<5/7
push!(l, ws_weekly_edges) push!(l, ws_weekly_edges)
push!(l, rest_weekly_edges) push!(l, rest_weekly_edges)
end end
...@@ -237,7 +237,7 @@ mutable struct WeightedGraph{G,M} ...@@ -237,7 +237,7 @@ mutable struct WeightedGraph{G,M}
) )
end end
end end
neighbors(g::WeightedGraph,i) = neighbors(g.g,i)
get_weight(g::WeightedGraph,e) = g.mixing_edges.weights_dict[e] get_weight(g::WeightedGraph,e) = g.mixing_edges.weights_dict[e]
function Base.show(io::IO, g::WeightedGraph) function Base.show(io::IO, g::WeightedGraph)
print(io, "WG $(ne(g.g))") print(io, "WG $(ne(g.g))")
......
...@@ -28,7 +28,7 @@ function get_parameters() ...@@ -28,7 +28,7 @@ function get_parameters()
end end
function get_u_0(nodes,I_0_fraction,vaccinator_prob) function get_u_0(nodes,I_0_fraction,vaccinator_prob)
is_vaccinator = rand(Random.default_rng(Threads.threadid()),nodes) .< vaccinator_prob is_vaccinator = rand(local_rng(),nodes) .< vaccinator_prob
status = fill(Susceptible,nodes) status = fill(Susceptible,nodes)
return status,is_vaccinator return status,is_vaccinator
end end
...@@ -42,7 +42,7 @@ function app_users(demographics,app_usage_prob) ...@@ -42,7 +42,7 @@ function app_users(demographics,app_usage_prob)
is_app_user = Vector{Bool}(undef,length(demographics)) is_app_user = Vector{Bool}(undef,length(demographics))
@inbounds for i in eachindex(demographics) @inbounds for i in eachindex(demographics)
demo = demographics[i] demo = demographics[i]
is_app_user[i] = rand(Random.default_rng(Threads.threadid())) < app_usage_prob*ymo_usage[Int(demo)] is_app_user[i] = rand(local_rng()) < app_usage_prob*ymo_usage[Int(demo)]
end end
return is_app_user return is_app_user
end end
......
...@@ -20,14 +20,14 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # B ...@@ -20,14 +20,14 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # B
end end
total_weight_i = 0 total_weight_i = 0
for mixing_graph in inf_network.graph_list[t] for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,node) for j in neighbors(mixing_graph,node)
if app_user[j] if app_user[j]
total_weight_i+= get_weight(mixing_graph,GraphEdge(node,j)) total_weight_i+= get_weight(mixing_graph,GraphEdge(node,j))
end end
end end
end end
coin_flip = 1 - (1 - notification_parameter)^total_weight_i coin_flip = 1 - (1 - notification_parameter)^total_weight_i
r = rand(Random.default_rng(Threads.threadid())) r = rand(local_rng())
if r < coin_flip if r < coin_flip
covid_alert_notifications[end,i] = 1 #add the notifications for today covid_alert_notifications[end,i] = 1 #add the notifications for today
else else
...@@ -61,9 +61,9 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol) ...@@ -61,9 +61,9 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
immunization_countdown[i] = 14 immunization_countdown[i] = 14
else else
for mixing_graph in inf_network.graph_list[t] for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,i) for j in neighbors(mixing_graph,i)
if u_inf[j] == Infected && u_next_inf[i] != Infected if u_inf[j] == Infected && u_next_inf[i] != Infected
if rand(Random.default_rng(Threads.threadid())) < contact_weight(base_transmission_probability,get_weight(mixing_graph,GraphEdge(i,j))) if rand(local_rng()) < contact_weight(base_transmission_probability,get_weight(mixing_graph,GraphEdge(i,j)))
modelsol.daily_cases_by_age[Int(agent_demo)]+=1 modelsol.daily_cases_by_age[Int(agent_demo)]+=1
agent_transition!(i, Susceptible,Infected) agent_transition!(i, Susceptible,Infected)
end end
...@@ -72,7 +72,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol) ...@@ -72,7 +72,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
end end
end end
elseif agent_status == Infected elseif agent_status == Infected
if rand(Random.default_rng(Threads.threadid())) < recovery_rate if rand(local_rng()) < recovery_rate
agent_transition!(i, Infected,Recovered) agent_transition!(i, Infected,Recovered)
end end
end end
...@@ -95,10 +95,10 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod ...@@ -95,10 +95,10 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
π_base = @SVector [π_base_y,π_base_m,π_base_o] π_base = @SVector [π_base_y,π_base_m,π_base_o]
vac_payoff = 0 vac_payoff = 0
num_soc_nbrs = 0 num_soc_nbrs = 0
random_soc_network = sample(Random.default_rng(Threads.threadid()), soc_network.graph_list[t]) random_soc_network = sample(local_rng(), soc_network.graph_list[t])
if !isempty(neighbors(random_soc_network.g,i)) if !isempty(neighbors(random_soc_network,i))
random_neighbour = sample(Random.default_rng(Threads.threadid()), neighbors(random_soc_network.g,i)) random_neighbour = sample(local_rng(), neighbors(random_soc_network.g,i))
if u_vac[random_neighbour] == u_vac[i] if u_vac[random_neighbour] == u_vac[i]
vac_payoff += π_base[Int(demographics[i])] + total_infections*ω vac_payoff += π_base[Int(demographics[i])] + total_infections*ω
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0 if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
...@@ -106,11 +106,11 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod ...@@ -106,11 +106,11 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
end end
if u_vac[i] if u_vac[i]
if rand(Random.default_rng(Threads.threadid())) < 1 - Φ(vac_payoff,β) if rand(local_rng()) < 1 - Φ(vac_payoff,β)
u_next_vac[i] = false u_next_vac[i] = false
end end
else else
if rand(Random.default_rng(Threads.threadid())) < Φ(vac_payoff,β) if rand(local_rng()) < Φ(vac_payoff,β)
u_next_vac[i] = true u_next_vac[i] = true
end end
end end
...@@ -126,7 +126,7 @@ function weighted_degree(node,network::TimeDepMixingGraph) ...@@ -126,7 +126,7 @@ function weighted_degree(node,network::TimeDepMixingGraph)
weighted_degree = 0 weighted_degree = 0
for g_list in network.graph_list for g_list in network.graph_list
for g in g_list for g in g_list
for j in neighbors(g.g,node) for j in neighbors(g,node)
weighted_degree += get_weight(g,GraphEdge(node,j)) weighted_degree += get_weight(g,GraphEdge(node,j))
end end
end end
...@@ -140,17 +140,9 @@ function agents_step!(t,modelsol) ...@@ -140,17 +140,9 @@ function agents_step!(t,modelsol)
for network in modelsol.inf_network.graph_list[t] #this also resamples the soc network weights since they point to the same objects, but those are never used for network in modelsol.inf_network.graph_list[t] #this also resamples the soc network weights since they point to the same objects, but those are never used
sample_mixing_graph!(network) #get new contact weights sample_mixing_graph!(network) #get new contact weights
end end
# @show modelsol.inf_network.graph_list[t]
# @show weighted_degree(1,modelsol.inf_network)
# @show weighted_degree(1,modelsol.soc_network)
if t == modelsol.params.infection_introduction_day if t == modelsol.params.infection_introduction_day
init_indices = rand(Random.default_rng(Threads.threadid()), 1:modelsol.nodes, round(Int,modelsol.nodes*modelsol.params.I_0_fraction)) init_indices = rand(local_rng(), 1:modelsol.nodes, round(Int,modelsol.nodes*modelsol.params.I_0_fraction))
modelsol.u_inf[init_indices] .= Infected modelsol.u_inf[init_indices] .= Infected
modelsol.status_totals[Int(Infected)] += length(init_indices) modelsol.status_totals[Int(Infected)] += length(init_indices)
end end
......
...@@ -26,6 +26,7 @@ import Pandas: read_csv ...@@ -26,6 +26,7 @@ import Pandas: read_csv
using DataFrames using DataFrames
using StaticArrays using StaticArrays
import LightGraphs.neighbors import LightGraphs.neighbors
using VectorizedRNG
export intervalsmodel, hh, ws, rest, abm export intervalsmodel, hh, ws, rest, abm
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment