Commit a27f8d5b authored by Peter Jentsch's avatar Peter Jentsch
Browse files

switch to VectorizedRNG

parent fd1ec66c
......@@ -39,9 +39,9 @@ version = "0.1.0"
[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d84e8967b7f04f52c9bca21714bae54a553a53fc"
git-tree-sha1 = "b08be763d0b8ddee6b162016dad746a69980616d"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.10"
version = "3.1.11"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
......@@ -92,9 +92,9 @@ version = "1.16.0+6"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "bd0cc939d94b8bd736dce5bbbe0d635db9f94af7"
git-tree-sha1 = "e6b23566e025d3b0d9ccc397f5c7a134af552e27"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.41"
version = "0.9.42"
[[CheapThreads]]
deps = ["ArrayInterface", "IfElse", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities", "VectorizationBase"]
......@@ -128,9 +128,9 @@ version = "0.3.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
git-tree-sha1 = "0a817fbe51c976de090aa8c997b7b719b786118d"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"
version = "3.28.0"
[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
......@@ -155,9 +155,9 @@ version = "0.1.2"
[[ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "48920211c95a6da1914a06c44ec94be70e84ffff"
git-tree-sha1 = "1dc43957fb9a1574fa1b7a449e101bd1fd3a9fb7"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.1.0"
version = "1.2.1"
[[Contour]]
deps = ["StaticArrays"]
......@@ -183,9 +183,9 @@ version = "1.6.0"
[[DataFrames]]
deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "6e5452d9cf401ed9048e1cde93815be53d951079"
git-tree-sha1 = "66ee4fe515a9294a8836ef18eea7239c6ac3db5e"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.0.2"
version = "1.1.1"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
......@@ -322,9 +322,9 @@ version = "2.8.0"
[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
git-tree-sha1 = "ae1ce4975c393bad81a55cb073e2af04bd151f85"
git-tree-sha1 = "80e1a7416cbf08fe80c8885e1834c45cfc399c61"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.3"
version = "0.12.5"
[[FixedPointNumbers]]
deps = ["Statistics"]
......@@ -415,9 +415,9 @@ version = "0.9.8"
[[Hwloc]]
deps = ["Hwloc_jll"]
git-tree-sha1 = "ffdcd4272a7cc36442007bca41aa07ca3cc5fda4"
git-tree-sha1 = "92d99146066c5c6888d5a3abc871e6a214388b91"
uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
version = "1.3.0"
version = "2.0.0"
[[Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -678,9 +678,9 @@ version = "0.4.6"
[[LoopVectorization]]
deps = ["ArrayInterface", "CheapThreads", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "427ec6a601c32d704bb664b32bf695519ef66043"
git-tree-sha1 = "2ad016117de05750443ed219dada9df3f9b9fa8f"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.15"
version = "0.12.18"
[[LsqFit]]
deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
......@@ -751,9 +751,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[MutableArithmetics]]
deps = ["LinearAlgebra", "SparseArrays", "Test"]
git-tree-sha1 = "3301e152b9a208745fad6cd4b068307a5d218a38"
git-tree-sha1 = "ad9b2bce6021631e0e20706d361972343a03e642"
uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
version = "0.2.18"
version = "0.2.19"
[[NLSolversBase]]
deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"]
......@@ -782,9 +782,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OffsetArrays]]
deps = ["Adapt"]
git-tree-sha1 = "87a728aebb76220bd72855e1c85284c5fdb9774c"
git-tree-sha1 = "47b443d2ccc8297a4c538f55f8fd828ad58599ab"
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.7.0"
version = "1.8.0"
[[Ogg_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -829,9 +829,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372"
version = "1.3.1+3"
[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"
version = "1.4.1"
[[PCRE_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -892,9 +892,9 @@ version = "1.0.10"
[[Plots]]
deps = ["Base64", "Contour", "Dates", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"]
git-tree-sha1 = "eced322de627fa8469c55f24899fdd4cce7d978c"
git-tree-sha1 = "2628e5859819173cef995470af83db42bf411ef8"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.13.2"
version = "1.14.0"
[[Polynomials]]
deps = ["Intervals", "LinearAlgebra", "MutableArithmetics", "RecipesBase"]
......@@ -910,9 +910,9 @@ version = "1.2.1"
[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"
version = "1.2.2"
[[PrettyTables]]
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
......@@ -932,9 +932,9 @@ version = "0.1.4"
[[ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "d85d8f0339a9937afac93e152c76f4745b386202"
git-tree-sha1 = "1be8800271c86f572d334fef6e3b8364eaece7d9"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.6.0"
version = "1.6.2"
[[PyCall]]
deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"]
......@@ -1025,9 +1025,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[SLEEFPirates]]
deps = ["IfElse", "Libdl", "VectorizationBase"]
git-tree-sha1 = "3e682ce17a16c9dfb9d2fde0ceb2347e23faafba"
git-tree-sha1 = "91a650350dcf6e0fc1a014b59669e704d8f579ae"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
version = "0.6.15"
version = "0.6.17"
[[Scratch]]
deps = ["Dates"]
......@@ -1099,9 +1099,9 @@ version = "0.2.4"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "2653e9c769343808781a8bd5010ee7a17c01152e"
git-tree-sha1 = "fb46e45ef2cade8be20bb445b3ffeca3c6d6f7d3"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.2"
version = "1.1.3"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
......@@ -1126,9 +1126,9 @@ version = "0.9.8"
[[StrideArraysCore]]
deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "62a9b1e31f0741a642455f42ddaa9582101b3e71"
git-tree-sha1 = "f93118d367c8dec873c26a32ad2dea84989edd7d"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.1.6"
version = "0.1.7"
[[StructArrays]]
deps = ["Adapt", "DataAPI", "Tables"]
......@@ -1195,10 +1195,10 @@ uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
version = "0.1.7"
[[TimeZones]]
deps = ["Dates", "EzXML", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"]
git-tree-sha1 = "3f6f0be07f33e33bd986a58b4cf2d6c9fd2b7f18"
deps = ["Dates", "EzXML", "LazyArtifacts", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"]
git-tree-sha1 = "960099aed321e05ac649c90d583d59c9309faee1"
uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
version = "1.5.4"
version = "1.5.5"
[[Transducers]]
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
......@@ -1225,9 +1225,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VectorizationBase]]
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"]
git-tree-sha1 = "293aa2c5cbf201e6b98810cb36d9eeafdafdafd1"
git-tree-sha1 = "c258467e1a3473e328c8a9109efcce845723593e"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.19.34"
version = "0.19.37"
[[VectorizedRNG]]
deps = ["Distributed", "Random", "UnPack", "VectorizationBase"]
......
......@@ -42,5 +42,10 @@ Run the model with given parameter tuple and output recorder. See `get_parameter
function abm(parameters, recorder)
model_sol = ModelSolution(parameters.sim_length,parameters,5000)
output = solve!(model_sol,recorder )
total_weighted_degree = map(modelsol.index_vectors) do age_group_indices
return mean(map(i -> weighted_degree(i,modelsol.inf_network),age_group_indices))
end
@show total_weighted_degree
return model_sol
end
......@@ -27,8 +27,8 @@ Fills i_to_j_contacts and j_to_i_contacts with degrees sampled from ij_dist and
Given `μz_i = mean(ij_dist)` and `μ_j = mean(ji_dist)`, these must satisfy `μ_i* length(i_to_j_contacts) == μ_j* length(j_to_i_contacts)`
"""
function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j_to_i_contacts::Vector{T}) where T
rand!(Random.default_rng(Threads.threadid()),ij_dist,i_to_j_contacts)
rand!(Random.default_rng(Threads.threadid()),ji_dist,j_to_i_contacts)
rand!(local_rng(),ij_dist,i_to_j_contacts)
rand!(local_rng(),ji_dist,j_to_i_contacts)
l_i = length(i_to_j_contacts)
l_j = length(j_to_i_contacts)
......@@ -41,10 +41,10 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j
sample_list_j = similar(sample_list_i)
while csum != 0
sample!(Random.default_rng(Threads.threadid()),1:l_i,index_list_i)
sample!(Random.default_rng(Threads.threadid()),1:l_j,index_list_j)
rand!(Random.default_rng(Threads.threadid()),ij_dist,sample_list_i)
rand!(Random.default_rng(Threads.threadid()),ji_dist,sample_list_j)
sample!(local_rng(),1:l_i,index_list_i)
sample!(local_rng(),1:l_j,index_list_j)
rand!(local_rng(),ij_dist,sample_list_i)
rand!(local_rng(),ji_dist,sample_list_j)
@inbounds for i = 1:inner_iter
if csum != 0
csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
......
......@@ -57,5 +57,5 @@ const unemployment_matrix = alpha_matrix(
Sample initial_workschool_mixing_matrix, which is the workschool distributions symmetrized for the full Canadian population, rather than subsets (as used in the ABM). This is used in IntervalsModel.
"""
@views function ws_sample(age)
return rand.(initial_workschool_mixing_matrix[age,:]) * (rand(Random.default_rng(Threads.threadid())) < (5/7))
return rand.(initial_workschool_mixing_matrix[age,:]) * (rand(local_rng()) < (5/7))
end
......@@ -37,7 +37,7 @@ function sample_mixing_graph!(mixing_graph)
# display(length.(mixing_edges.contact_array))
# display(length.(mixing_edges.sample_cache))
for i in 1:size(mixing_edges.contact_array)[1], j in 1:i #diagonal
rand!(Random.default_rng(Threads.threadid()), mixing_edges.sampler_matrix[j,i],mixing_edges.sample_cache[j,i])
rand!(local_rng(), mixing_edges.sampler_matrix[j,i],mixing_edges.sample_cache[j,i])
for k in 1:length(mixing_edges.contact_array[j,i])
edge_weight_k = mixing_edges.sample_cache[j,i][k]
set!(mixing_edges.weights_dict, mixing_edges.contact_array[j,i][k], edge_weight_k)
......@@ -98,8 +98,8 @@ function create_mixing_edges(demographic_index_vectors,mixing_matrix,weights_dis
stubs_i = Vector{Int}(undef,m)
stubs_j = Vector{Int}(undef,m)
if m>0
sample!(Random.default_rng(Threads.threadid()),demographic_index_vectors[i],Weights(num_degrees_ij./m),stubs_i)
sample!(Random.default_rng(Threads.threadid()),demographic_index_vectors[j],Weights(num_degrees_ji./m),stubs_j)
sample!(local_rng(),demographic_index_vectors[i],Weights(num_degrees_ij./m),stubs_i)
sample!(local_rng(),demographic_index_vectors[j],Weights(num_degrees_ji./m),stubs_j)
tot += m
end
......@@ -176,7 +176,7 @@ function time_dep_mixing_graphs(len,base_network,demographics,index_vectors,ws_m
if !(day_of_week == 3 || day_of_week == 4) #simulation begins on thursday I guess
push!(l, ws_static_edges)
end
if rand(Random.default_rng(Threads.threadid()))<5/7
if rand(local_rng())<5/7
push!(l, ws_weekly_edges)
push!(l, rest_weekly_edges)
end
......@@ -237,7 +237,7 @@ mutable struct WeightedGraph{G,M}
)
end
end
neighbors(g::WeightedGraph,i) = neighbors(g.g,i)
get_weight(g::WeightedGraph,e) = g.mixing_edges.weights_dict[e]
function Base.show(io::IO, g::WeightedGraph)
print(io, "WG $(ne(g.g))")
......
......@@ -28,7 +28,7 @@ function get_parameters()
end
function get_u_0(nodes,I_0_fraction,vaccinator_prob)
is_vaccinator = rand(Random.default_rng(Threads.threadid()),nodes) .< vaccinator_prob
is_vaccinator = rand(local_rng(),nodes) .< vaccinator_prob
status = fill(Susceptible,nodes)
return status,is_vaccinator
end
......@@ -42,7 +42,7 @@ function app_users(demographics,app_usage_prob)
is_app_user = Vector{Bool}(undef,length(demographics))
@inbounds for i in eachindex(demographics)
demo = demographics[i]
is_app_user[i] = rand(Random.default_rng(Threads.threadid())) < app_usage_prob*ymo_usage[Int(demo)]
is_app_user[i] = rand(local_rng()) < app_usage_prob*ymo_usage[Int(demo)]
end
return is_app_user
end
......
......@@ -20,14 +20,14 @@ Base.@propagate_inbounds @views function update_alert_durations!(t,modelsol) # B
end
total_weight_i = 0
for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,node)
for j in neighbors(mixing_graph,node)
if app_user[j]
total_weight_i+= get_weight(mixing_graph,GraphEdge(node,j))
end
end
end
coin_flip = 1 - (1 - notification_parameter)^total_weight_i
r = rand(Random.default_rng(Threads.threadid()))
r = rand(local_rng())
if r < coin_flip
covid_alert_notifications[end,i] = 1 #add the notifications for today
else
......@@ -61,9 +61,9 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
immunization_countdown[i] = 14
else
for mixing_graph in inf_network.graph_list[t]
for j in neighbors(mixing_graph.g,i)
for j in neighbors(mixing_graph,i)
if u_inf[j] == Infected && u_next_inf[i] != Infected
if rand(Random.default_rng(Threads.threadid())) < contact_weight(base_transmission_probability,get_weight(mixing_graph,GraphEdge(i,j)))
if rand(local_rng()) < contact_weight(base_transmission_probability,get_weight(mixing_graph,GraphEdge(i,j)))
modelsol.daily_cases_by_age[Int(agent_demo)]+=1
agent_transition!(i, Susceptible,Infected)
end
......@@ -72,7 +72,7 @@ Base.@propagate_inbounds @views function update_infection_state!(t,modelsol)
end
end
elseif agent_status == Infected
if rand(Random.default_rng(Threads.threadid())) < recovery_rate
if rand(local_rng()) < recovery_rate
agent_transition!(i, Infected,Recovered)
end
end
......@@ -95,10 +95,10 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
π_base = @SVector [π_base_y,π_base_m,π_base_o]
vac_payoff = 0
num_soc_nbrs = 0
random_soc_network = sample(Random.default_rng(Threads.threadid()), soc_network.graph_list[t])
random_soc_network = sample(local_rng(), soc_network.graph_list[t])
if !isempty(neighbors(random_soc_network.g,i))
random_neighbour = sample(Random.default_rng(Threads.threadid()), neighbors(random_soc_network.g,i))
if !isempty(neighbors(random_soc_network,i))
random_neighbour = sample(local_rng(), neighbors(random_soc_network.g,i))
if u_vac[random_neighbour] == u_vac[i]
vac_payoff += π_base[Int(demographics[i])] + total_infections*ω
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
......@@ -106,11 +106,11 @@ Base.@propagate_inbounds @views function update_vaccination_opinion_state!(t,mod
end
if u_vac[i]
if rand(Random.default_rng(Threads.threadid())) < 1 - Φ(vac_payoff,β)
if rand(local_rng()) < 1 - Φ(vac_payoff,β)
u_next_vac[i] = false
end
else
if rand(Random.default_rng(Threads.threadid())) < Φ(vac_payoff,β)
if rand(local_rng()) < Φ(vac_payoff,β)
u_next_vac[i] = true
end
end
......@@ -126,7 +126,7 @@ function weighted_degree(node,network::TimeDepMixingGraph)
weighted_degree = 0
for g_list in network.graph_list
for g in g_list
for j in neighbors(g.g,node)
for j in neighbors(g,node)
weighted_degree += get_weight(g,GraphEdge(node,j))
end
end
......@@ -140,17 +140,9 @@ function agents_step!(t,modelsol)
for network in modelsol.inf_network.graph_list[t] #this also resamples the soc network weights since they point to the same objects, but those are never used
sample_mixing_graph!(network) #get new contact weights
end
# @show modelsol.inf_network.graph_list[t]
# @show weighted_degree(1,modelsol.inf_network)
# @show weighted_degree(1,modelsol.soc_network)
if t == modelsol.params.infection_introduction_day
init_indices = rand(Random.default_rng(Threads.threadid()), 1:modelsol.nodes, round(Int,modelsol.nodes*modelsol.params.I_0_fraction))
init_indices = rand(local_rng(), 1:modelsol.nodes, round(Int,modelsol.nodes*modelsol.params.I_0_fraction))
modelsol.u_inf[init_indices] .= Infected
modelsol.status_totals[Int(Infected)] += length(init_indices)
end
......
......@@ -26,6 +26,7 @@ import Pandas: read_csv
using DataFrames
using StaticArrays
import LightGraphs.neighbors
using VectorizedRNG
export intervalsmodel, hh, ws, rest, abm
......
No preview for this file type
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment