diff --git a/Manifest.toml b/Manifest.toml index 2d810023493375a88b69c7723e7c8f4dd590180f..cbf9bb04055aff9877322afa9e81de096751af6d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -22,9 +22,9 @@ version = "0.1.0" [[ArrayInterface]] deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"] -git-tree-sha1 = "ef91c543a3a8094eba9b1f7171258b9ecae87dfa" +git-tree-sha1 = "b9c3166c0124f44135419a394f42912c14dcbd80" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3.0.0" +version = "3.0.1" [[Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -50,6 +50,12 @@ git-tree-sha1 = "c3598e525718abcc440f69cc6d5f60dda0a1b61e" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.6+5" +[[CSV]] +deps = ["Dates", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode"] +git-tree-sha1 = "1f79803452adf73e2d3fc84785adb7aaca14db36" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.8.3" + [[Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] git-tree-sha1 = "e2f47f6d8337369411569fd45ae5753ca10394c6" @@ -58,9 +64,9 @@ version = "1.16.0+6" [[CategoricalArrays]] deps = ["DataAPI", "Future", "JSON", "Missings", "Printf", "Statistics", "StructTypes", "Unicode"] -git-tree-sha1 = "5861101791fa76fafe8dddefd70ffbfe4e33ecae" +git-tree-sha1 = "99809999c8ee01fa89498480b147f7394ea5450f" uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.9.0" +version = "0.9.2" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -118,15 +124,15 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.0.4" [[DataAPI]] -git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f" +git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.5.0" +version = "1.5.1" [[DataFrames]] deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "c5f9f2385a52e35c75efd90911091f8a749177a5" +git-tree-sha1 = "b0db5579803eabb33f1274ca7ca2f472fdfb7f2a" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "0.22.4" +version = "0.22.5" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -178,6 +184,17 @@ git-tree-sha1 = "1402e52fcda25064f51c77a9655ce8680b76acf0" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" version = "2.2.7+6" +[[ExprTools]] +git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.3" + +[[EzXML]] +deps = ["Printf", "XML2_jll"] +git-tree-sha1 = "0fa3b52a04a4e210aeb1626def9c90df3ae65268" +uuid = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615" +version = "1.1.0" + [[FFMPEG]] deps = ["FFMPEG_jll", "x264_jll"] git-tree-sha1 = "9a73ffdc375be61b0e4516d83d880b265366fe1f" @@ -308,6 +325,12 @@ version = "0.2.10" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[Intervals]] +deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"] +git-tree-sha1 = "323a38ed1952d30586d0fe03412cde9399d3618b" +uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5" +version = "1.5.0" + [[InvertedIndices]] deps = ["Test"] git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc" @@ -496,6 +519,12 @@ version = "0.4.5" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[Mocking]] +deps = ["ExprTools"] +git-tree-sha1 = "916b850daad0d46b8c71f65f719c49957e9513ed" +uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" +version = "0.7.1" + [[MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" @@ -600,9 +629,9 @@ version = "1.10.2" [[PooledArrays]] deps = ["DataAPI"] -git-tree-sha1 = "b1333d4eced1826e15adbdf01a4ecaccca9d353c" +git-tree-sha1 = "0e8f5c428a41a81cd71f76d76f2fc3415fe5a676" uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "0.5.3" +version = "1.1.0" [[PrettyTables]] deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] @@ -688,6 +717,12 @@ git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6" uuid = "6c6a2e73-6563-6170-7368-637461726353" version = "1.0.3" +[[SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "6ccde405cf0759eba835eb613130723cb8f10ff9" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.2.16" + [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -806,11 +841,17 @@ git-tree-sha1 = "269f5c1955c1194086cf6d2029aa4a0b4fb8018b" uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" version = "0.1.7" +[[TimeZones]] +deps = ["Dates", "EzXML", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"] +git-tree-sha1 = "4ba8a9579a243400db412b50300cd61d7447e583" +uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" +version = "1.5.3" + [[Transducers]] deps = ["ArgCheck", "BangBang", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "02413f5795ad272f9f9912e4e4c83d9b1572750c" +git-tree-sha1 = "9550eba57ebc2f7677c4c946aaca56e149ca73ff" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.58" +version = "0.4.59" [[UUIDs]] deps = ["Random", "SHA"] diff --git a/Project.toml b/Project.toml index 078e8686c335ed0bca0065806f77c21affc3e97d..923bdb0023c373be5345e28486135f2a81fad8b7 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,13 @@ version = "0.1.0" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ImportAll = "c65182e5-40f4-518f-8165-175b85689199" +Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" diff --git a/network-data/Timeuse/HH/Duration-fitting.py b/network-data/Timeuse/HH/Duration-fitting.py index 83b920de16d649a50d21861ec83160d70b83dcff..5414990f3ff076064e0e9dd9c05bb06c1462bdb1 100644 --- a/network-data/Timeuse/HH/Duration-fitting.py +++ b/network-data/Timeuse/HH/Duration-fitting.py @@ -111,6 +111,7 @@ def ErrfunPOIS(params): Subdf = HHYMO[Bools] # Compute error on subpopulation Errcol = Subdf.apply(ExpErrorPOIS, axis=1, args=[LAM]) + return Errcol.sum()/subsize # Set parameter bounds for fitting @@ -118,23 +119,33 @@ MUbounds = (6,12*6) SIGMAbounds = (1,48) BoundsNORM = [MUbounds for i in range(6)] + [SIGMAbounds for i in range(6)] BoundsPOIS = [MUbounds for i in range(6)] +init = [MUbounds[1] for i in range(6)] + [SIGMAbounds[1] for i in range(6)] + +def test(): + ErrfunNORM(init) # Run both normal and poisson fitting numruns = 5 NORMSave = {} POISSave = {} -for i in range(numruns): - NORMFIT = optimize.differential_evolution(ErrfunNORM, bounds = BoundsNORM, - disp = True, maxiter = 175, - popsize = 10) - POISFIT = optimize.differential_evolution(ErrfunPOIS, bounds=BoundsPOIS, - disp = True, maxiter = 150, - popsize = 10) - NORMSave[i] = list(NORMFIT.get('x')) + [float(NORMFIT.get('fun'))] - POISSave[i] = list(POISFIT.get('x')) + [float(POISFIT.get('fun'))] - -pd.DataFrame.from_dict(data=NORMSave, orient='index').to_csv('NormalFit-Feb2.csv', header=False) -pd.DataFrame.from_dict(data=POISSave, orient='index').to_csv('PoissonFit-Feb2.csv', header=False) + +if __name__ == "__main__": + test() + print("doot") + import timeit + print(timeit.timeit("test()", setup="from __main__ import test",number=100)/100) +# for i in range(numruns): +# NORMFIT = optimize.differential_evolution(ErrfunNORM, bounds = BoundsNORM, +# disp = True, maxiter = 175, +# popsize = 10) +# POISFIT = optimize.differential_evolution(ErrfunPOIS, bounds=BoundsPOIS, +# disp = True, maxiter = 150, +# popsize = 10) +# NORMSave[i] = list(NORMFIT.get('x')) + [float(NORMFIT.get('fun'))] +# POISSave[i] = list(POISFIT.get('x')) + [float(POISFIT.get('fun'))] + +# pd.DataFrame.from_dict(data=NORMSave, orient='index').to_csv('NormalFit-Feb2.csv', header=False) +# pd.DataFrame.from_dict(data=POISSave, orient='index').to_csv('PoissonFit-Feb2.csv', header=False) #pd.DataFrame(durparamsfit+errorfit).to_csv("HHNormalFit.csv", index=False) diff --git a/network-data/Timeuse/HH/Intervals_Model.py b/network-data/Timeuse/HH/Intervals_Model.py index 7df3280805b991e2ca44cd42c6205f57b821ccd7..e4f1f7427aad9b63e26d564552d406980b17d898 100644 --- a/network-data/Timeuse/HH/Intervals_Model.py +++ b/network-data/Timeuse/HH/Intervals_Model.py @@ -41,11 +41,14 @@ def StartSample(N, Sparam): # Samples interval configurations for a given sequence of durations def IntSample(N, Sparam, durlist): + + print(durlist) numcontact = len(durlist) out = {} for i in range(N): S = StartSample(numcontact, Sparam) E = [(S[j] + durlist[j])%(Durmax-1) for j in range(numcontact)] + out[i] = list(zip(S,E)) return out diff --git a/network-data/Timeuse/HH/__pycache__/Intervals_Model.cpython-39.pyc b/network-data/Timeuse/HH/__pycache__/Intervals_Model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ac28a40a3e48286a0fa1a256af3d3443042c77 Binary files /dev/null and b/network-data/Timeuse/HH/__pycache__/Intervals_Model.cpython-39.pyc differ diff --git a/norm.dat b/norm.dat new file mode 100644 index 0000000000000000000000000000000000000000..1dc47ed8ace146c314760eff22cc4907eac137a1 Binary files /dev/null and b/norm.dat differ diff --git a/norm.png b/norm.png new file mode 100644 index 0000000000000000000000000000000000000000..fc4a66770ad609ecc108c30afe99271849ad6265 Binary files /dev/null and b/norm.png differ diff --git a/norm_dists.pdf b/norm_dists.pdf new file mode 100644 index 0000000000000000000000000000000000000000..7a0f7d30ecc8afc366b354e564250291be6a60d5 Binary files /dev/null and b/norm_dists.pdf differ diff --git a/pois.dat b/pois.dat new file mode 100644 index 0000000000000000000000000000000000000000..e35cbfe5f46b9f81d8d2c1e946e6764d2787095f Binary files /dev/null and b/pois.dat differ diff --git a/pois.png b/pois.png new file mode 100644 index 0000000000000000000000000000000000000000..c3c48ee3eafc5580839626066f70dce589b2641c Binary files /dev/null and b/pois.png differ diff --git a/pois_dists.pdf b/pois_dists.pdf new file mode 100644 index 0000000000000000000000000000000000000000..beeb87127c818fa4f1e9319e4c4e7075a6e418c4 Binary files /dev/null and b/pois_dists.pdf differ diff --git a/scratch_code/intervals_model_but_fast.jl b/scratch_code/intervals_model_but_fast.jl new file mode 100644 index 0000000000000000000000000000000000000000..a57b1b23641c8829f62696f3d948cbd874937a82 --- /dev/null +++ b/scratch_code/intervals_model_but_fast.jl @@ -0,0 +1,231 @@ +module intervals_fast + +using Intervals +using CSV +using DataFrames +using RandomNumbers.Xorshifts +using StatsBase +using Distributions +import AxisKeys +const HHYMO = DataFrame(CSV.File("network-data/Timeuse/HH/HHYMO.csv")) +const rng = Xoroshiro128Plus() +const YOUNG, MIDDLE,OLD = 1,2,3 +const cnst = ( + # Set the underlying parameters for the intervals model + Sparam = [60,12], + # Set parameters for intervals sample and subpopulation size + numsamples = 100, + subsize = size(HHYMO)[1], + durmax = 145, + # Swap age brackets for numbers + swap = Dict("Y" => YOUNG, "M" => MIDDLE, "O" => OLD), + # Total weight in survey + Wghttotal = sum(HHYMO[:,"WGHT_PER"]), + + MUbounds = (6,12*6), + SIGMAbounds = (1,48), +) +function make_dat_array() + durs = hcat( + Int.(HHYMO[!,"YDUR"*string(cnst.Sparam[2])]), + Int.(HHYMO[!,"MDUR"*string(cnst.Sparam[2])]), + Int.(HHYMO[!,"ODUR"*string(cnst.Sparam[2])]), + ) + nums = hcat( + Int.(HHYMO[!,"YNUM"]), + Int.(HHYMO[!,"MNUM"]), + Int.(HHYMO[!,"ONUM"]), + ) + + WGHT = Weights(HHYMO[!,"WGHT_PER"]./cnst.Wghttotal) + AGERESP = map(r -> cnst.swap[r],HHYMO[!,"AGERESP"]) + return (; + nums, + durs, + WGHT, + AGERESP + ) +end + +const dat = make_dat_array() + + +function coverage!(cov,S_j,E_j) + if E_j < S_j + push!(cov,Interval(0,E_j)) + push!(cov,Interval(S_j,cnst.durmax-1)) + else + push!(cov,Interval(S_j,E_j)) + end +end + +function tot_dur_sample(n, dist,durlist) + if isempty(durlist) + return 0 + end + total_dur = 0 + numcontact = length(durlist) + for i in 1:n + cov1 = Vector{Interval{Int64, Closed, Closed}}() + for j in 1:numcontact + S_j = Int(trunc(rand(rng,dist))) % 144 + E_j = (S_j + durlist[j])%(cnst.durmax-1) + coverage!(cov1,S_j,E_j) + end + union!(cov1) + total_dur += mapreduce(Intervals.span,+,cov1) + end + return total_dur +end +function err_norm(params) + μ = as_symmetric_matrix(params[1:6]) + σ = as_symmetric_matrix(params[7:12]) + # row_ids = sample(rng,1:length(dat.WGHT), dat.WGHT,cnst.subsize) + age_dists = [Normal(μ[i,j],σ[i,j]) for i in YOUNG:OLD, j in YOUNG:OLD] + duration_subarray = dat.durs#@view dat.durs[row_ids,:] + num_contacts_subarray = dat.nums#@view dat.nums[row_ids,:] + + # display(num_contacts_subarray) + AGERESP = dat.AGERESP #@view dat.AGERESP[row_ids] + errsum = 0 + @inbounds for i = 1:cnst.subsize + age_sample = AGERESP[i] + @inbounds for age_j in YOUNG:OLD + running_sum = 0 + durs = Int.(trunc.(rand(rng,age_dists[age_sample,age_j],num_contacts_subarray[i,age_j]))) .% 144 + expdur = tot_dur_sample(cnst.numsamples,cnst.Sparam,durs) + errsum += (expdur/cnst.numsamples - duration_subarray[i,age_j])^2 + end + end + return errsum/cnst.subsize +end +function as_symmetric_matrix(l) + return [ + l[1] l[2] l[3] + l[2] l[4] l[5] + l[3] l[5] l[6] + ] +end +function err_poisson(params) + μ = as_symmetric_matrix(params) + # row_ids = sample(rng,1:length(dat.WGHT), dat.WGHT,cnst.subsize) + age_dists = [Poisson(μ[i,j]) for i in YOUNG:OLD, j in YOUNG:OLD] + duration_subarray = dat.durs#@view dat.durs[row_ids,:] + num_contacts_subarray = dat.nums#@view dat.nums[row_ids,:] + + # display(num_contacts_subarray) + AGERESP = dat.AGERESP #@view dat.AGERESP[row_ids] + + errsum = 0 + @inbounds for i = 1:cnst.subsize + age_sample = AGERESP[i] + @inbounds for age_j in YOUNG:OLD + running_sum = 0 + durs = Int.(trunc.(rand(rng,age_dists[age_sample,age_j],num_contacts_subarray[i,age_j]))) .% 144 + expdur = tot_dur_sample(cnst.numsamples,cnst.Sparam,durs) + errsum += (expdur/cnst.numsamples - duration_subarray[i,age_j])^2 + end + end + return errsum/cnst.subsize +end + +using KissABC +using BenchmarkTools +using Serialization +using Plots +function bayesian_estimate() + +# Set parameter bounds for fitting + BoundsNORM = vcat([cnst.MUbounds for i = 1:6], [cnst.SIGMAbounds for i = 1:6]) + + norm_init = vcat([cnst.MUbounds[1] for i = 1:6], [cnst.SIGMAbounds[1] for i = 1:6]) + BoundsPOIS = [cnst.MUbounds for i in 1:6] + pois_init = [cnst.MUbounds[1] for i = 1:6] + + priors_norm = Factored([Uniform(l,u) for (l,u) in BoundsNORM]...) + @btime err_norm($norm_init) + out_norm = smc(priors_norm,err_norm, verbose=true, nparticles=200, alpha=0.95, parallel = true) + # out_norm = KissABC.ABCDE(priors_norm,err_norm,0; verbose=true, nparticles=200, generations = 0.1e3,earlystop = false,parallel = true) + + serialize("norm.dat",out_norm) + + + priors_pois = Factored([Uniform(l,u) for (l,u) in BoundsPOIS]...) + # out_pois = KissABC.ABCDE(priors_pois,err_poisson, 0; verbose=true, nparticles=200, generations = 0.1e3,earlystop = false,parallel = true) + out_pois = smc(priors_pois,err_poisson, verbose=true, nparticles=200, alpha=0.95, parallel = true) + + serialize("pois.dat",out_pois) +end + +function plot_estimates() + + estimate = deserialize("norm.dat") + p_list = [] + + for i in 1:length(estimate.P) + a = stephist( + estimate.P[i].particles; + normalize = true, + title = i <=6 ? "μ_$i" : "σ_$i" + ) + push!(p_list,a) + end + p = plot(p_list...) + savefig(p,"norm.png") + + μ_estimate_as_array = as_symmetric_matrix(estimate.P[1:6]) + σ_estimate_as_array = as_symmetric_matrix(estimate.P[7:12]) + p_matrix = map(x -> plot(),σ_estimate_as_array) + for i in YOUNG:OLD, j in YOUNG:OLD + + dist = Normal.(μ_estimate_as_array[i,j].particles,σ_estimate_as_array[i,j].particles) + + data = [pdf.(dist,i) for i in 0.0:144.0] + mean_dat = median.(data) + err_down = quantile.(data,0.05) + err_up = quantile.(data,0.95) + p_matrix[i,j] = plot(0:144,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false) + + end + plot!(p_matrix[end,1]; legend = true) + p = plot(p_matrix..., size = (600,400)) + savefig(p,"norm_dists.pdf") + + + + + estimate = deserialize("pois.dat") + p_list = [] + for i in 1:length(estimate.P) + a = stephist( + estimate.P[i].particles; + normalize = true, + title = i <=6 ? "μ_$i" : "σ_$i" + ) + push!(p_list,a) + end + p = plot(p_list...) + savefig(p,"pois.png") + + μ_estimate_as_array = as_symmetric_matrix(estimate.P[1:6]) + p_matrix = map(x -> plot(),μ_estimate_as_array) + for i in YOUNG:OLD, j in YOUNG:OLD + + dist = Poisson.(μ_estimate_as_array[i,j].particles) + + data = [pdf.(dist,i) for i in 0.0:144.0] + mean_dat = median.(data) + err_down = quantile.(data,0.05) + err_up = quantile.(data,0.95) + p_matrix[i,j] = plot(0:144,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false) + + end + plot!(p_matrix[end,1]; legend = true) + p = plot(p_matrix..., size = (600,400)) + savefig(p,"pois_dists.pdf") + + +end + + +end \ No newline at end of file diff --git a/scratch_code/test.jl b/scratch_code/test.jl new file mode 100644 index 0000000000000000000000000000000000000000..0bb07a405ab02aaaa3c0d1067b50edaa0f67e63e --- /dev/null +++ b/scratch_code/test.jl @@ -0,0 +1,14 @@ + +include("intervals_model_but_fast.jl") + +using .intervals_fast + +using BenchmarkTools +using Plots + +pgfplotsx() + +default(dpi = 300) +default(framestyle = :box) +intervals_fast.bayesian_estimate() +intervals_fast.plot_estimates() \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8eda42451d45ceb7492805320f2147005b106037..07a338e53674a5bf74bdd85c2604f7142f3b612a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using ThreadsX import StatsBase.mean model_sizes = [(100,1),(1000,3),(5000,3)] vaccination_stratgies = [vaccinate_uniformly!] -vaccination_rates = [0.0001,0.001,0.005,0.01,0.05] +vaccination_rates = [0.001,0.005,0.01,0.05] infection_rates = [0.01,0.05,0.1] agent_models = ThreadsX.map(model_size -> AgentModel(model_size...), model_sizes) dem_cat = AgentDemographic.size + 1 @@ -35,6 +35,7 @@ function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus()) steps = 300 sol1,graphs = solve!(u_0,params,steps,model,vac_strategy); total_infections = count(x->x == AgentStatus(3),sol1[end]) + display(total_infections) return total_infections end