Commit 2ce846f5 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

durations done

parent 7501a165
......@@ -22,9 +22,9 @@ version = "0.1.0"
[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "ef91c543a3a8094eba9b1f7171258b9ecae87dfa"
git-tree-sha1 = "b9c3166c0124f44135419a394f42912c14dcbd80"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.0.0"
version = "3.0.1"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
......@@ -50,6 +50,12 @@ git-tree-sha1 = "c3598e525718abcc440f69cc6d5f60dda0a1b61e"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
version = "1.0.6+5"
[[CSV]]
deps = ["Dates", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode"]
git-tree-sha1 = "1f79803452adf73e2d3fc84785adb7aaca14db36"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.8.3"
[[Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "e2f47f6d8337369411569fd45ae5753ca10394c6"
......@@ -58,9 +64,9 @@ version = "1.16.0+6"
[[CategoricalArrays]]
deps = ["DataAPI", "Future", "JSON", "Missings", "Printf", "Statistics", "StructTypes", "Unicode"]
git-tree-sha1 = "5861101791fa76fafe8dddefd70ffbfe4e33ecae"
git-tree-sha1 = "99809999c8ee01fa89498480b147f7394ea5450f"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.9.0"
version = "0.9.2"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
......@@ -118,15 +124,15 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.4"
[[DataAPI]]
git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.5.0"
version = "1.5.1"
[[DataFrames]]
deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "c5f9f2385a52e35c75efd90911091f8a749177a5"
git-tree-sha1 = "b0db5579803eabb33f1274ca7ca2f472fdfb7f2a"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "0.22.4"
version = "0.22.5"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
......@@ -178,6 +184,17 @@ git-tree-sha1 = "1402e52fcda25064f51c77a9655ce8680b76acf0"
uuid = "2e619515-83b5-522b-bb60-26c02a35a201"
version = "2.2.7+6"
[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"
[[EzXML]]
deps = ["Printf", "XML2_jll"]
git-tree-sha1 = "0fa3b52a04a4e210aeb1626def9c90df3ae65268"
uuid = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615"
version = "1.1.0"
[[FFMPEG]]
deps = ["FFMPEG_jll", "x264_jll"]
git-tree-sha1 = "9a73ffdc375be61b0e4516d83d880b265366fe1f"
......@@ -308,6 +325,12 @@ version = "0.2.10"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Intervals]]
deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"]
git-tree-sha1 = "323a38ed1952d30586d0fe03412cde9399d3618b"
uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
version = "1.5.0"
[[InvertedIndices]]
deps = ["Test"]
git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc"
......@@ -496,6 +519,12 @@ version = "0.4.5"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[Mocking]]
deps = ["ExprTools"]
git-tree-sha1 = "916b850daad0d46b8c71f65f719c49957e9513ed"
uuid = "78c3b35d-d492-501b-9361-3d52fe80e533"
version = "0.7.1"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
......@@ -600,9 +629,9 @@ version = "1.10.2"
[[PooledArrays]]
deps = ["DataAPI"]
git-tree-sha1 = "b1333d4eced1826e15adbdf01a4ecaccca9d353c"
git-tree-sha1 = "0e8f5c428a41a81cd71f76d76f2fc3415fe5a676"
uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
version = "0.5.3"
version = "1.1.0"
[[PrettyTables]]
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
......@@ -688,6 +717,12 @@ git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"
[[SentinelArrays]]
deps = ["Dates", "Random"]
git-tree-sha1 = "6ccde405cf0759eba835eb613130723cb8f10ff9"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
version = "1.2.16"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
......@@ -806,11 +841,17 @@ git-tree-sha1 = "269f5c1955c1194086cf6d2029aa4a0b4fb8018b"
uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
version = "0.1.7"
[[TimeZones]]
deps = ["Dates", "EzXML", "Mocking", "Pkg", "Printf", "RecipesBase", "Serialization", "Unicode"]
git-tree-sha1 = "4ba8a9579a243400db412b50300cd61d7447e583"
uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
version = "1.5.3"
[[Transducers]]
deps = ["ArgCheck", "BangBang", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
git-tree-sha1 = "02413f5795ad272f9f9912e4e4c83d9b1572750c"
git-tree-sha1 = "9550eba57ebc2f7677c4c946aaca56e149ca73ff"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
version = "0.4.58"
version = "0.4.59"
[[UUIDs]]
deps = ["Random", "SHA"]
......
......@@ -5,11 +5,13 @@ version = "0.1.0"
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ImportAll = "c65182e5-40f4-518f-8165-175b85689199"
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
......
......@@ -111,6 +111,7 @@ def ErrfunPOIS(params):
Subdf = HHYMO[Bools]
# Compute error on subpopulation
Errcol = Subdf.apply(ExpErrorPOIS, axis=1, args=[LAM])
return Errcol.sum()/subsize
# Set parameter bounds for fitting
......@@ -118,23 +119,33 @@ MUbounds = (6,12*6)
SIGMAbounds = (1,48)
BoundsNORM = [MUbounds for i in range(6)] + [SIGMAbounds for i in range(6)]
BoundsPOIS = [MUbounds for i in range(6)]
init = [MUbounds[1] for i in range(6)] + [SIGMAbounds[1] for i in range(6)]
def test():
ErrfunNORM(init)
# Run both normal and poisson fitting
numruns = 5
NORMSave = {}
POISSave = {}
for i in range(numruns):
NORMFIT = optimize.differential_evolution(ErrfunNORM, bounds = BoundsNORM,
disp = True, maxiter = 175,
popsize = 10)
POISFIT = optimize.differential_evolution(ErrfunPOIS, bounds=BoundsPOIS,
disp = True, maxiter = 150,
popsize = 10)
NORMSave[i] = list(NORMFIT.get('x')) + [float(NORMFIT.get('fun'))]
POISSave[i] = list(POISFIT.get('x')) + [float(POISFIT.get('fun'))]
pd.DataFrame.from_dict(data=NORMSave, orient='index').to_csv('NormalFit-Feb2.csv', header=False)
pd.DataFrame.from_dict(data=POISSave, orient='index').to_csv('PoissonFit-Feb2.csv', header=False)
if __name__ == "__main__":
test()
print("doot")
import timeit
print(timeit.timeit("test()", setup="from __main__ import test",number=100)/100)
# for i in range(numruns):
# NORMFIT = optimize.differential_evolution(ErrfunNORM, bounds = BoundsNORM,
# disp = True, maxiter = 175,
# popsize = 10)
# POISFIT = optimize.differential_evolution(ErrfunPOIS, bounds=BoundsPOIS,
# disp = True, maxiter = 150,
# popsize = 10)
# NORMSave[i] = list(NORMFIT.get('x')) + [float(NORMFIT.get('fun'))]
# POISSave[i] = list(POISFIT.get('x')) + [float(POISFIT.get('fun'))]
# pd.DataFrame.from_dict(data=NORMSave, orient='index').to_csv('NormalFit-Feb2.csv', header=False)
# pd.DataFrame.from_dict(data=POISSave, orient='index').to_csv('PoissonFit-Feb2.csv', header=False)
#pd.DataFrame(durparamsfit+errorfit).to_csv("HHNormalFit.csv", index=False)
......
......@@ -41,11 +41,14 @@ def StartSample(N, Sparam):
# Samples interval configurations for a given sequence of durations
def IntSample(N, Sparam, durlist):
print(durlist)
numcontact = len(durlist)
out = {}
for i in range(N):
S = StartSample(numcontact, Sparam)
E = [(S[j] + durlist[j])%(Durmax-1) for j in range(numcontact)]
out[i] = list(zip(S,E))
return out
......
File added
norm.png

72.6 KB

File added
File added
pois.png

47.6 KB

File added
module intervals_fast
using Intervals
using CSV
using DataFrames
using RandomNumbers.Xorshifts
using StatsBase
using Distributions
import AxisKeys
const HHYMO = DataFrame(CSV.File("network-data/Timeuse/HH/HHYMO.csv"))
const rng = Xoroshiro128Plus()
const YOUNG, MIDDLE,OLD = 1,2,3
const cnst = (
# Set the underlying parameters for the intervals model
Sparam = [60,12],
# Set parameters for intervals sample and subpopulation size
numsamples = 100,
subsize = size(HHYMO)[1],
durmax = 145,
# Swap age brackets for numbers
swap = Dict("Y" => YOUNG, "M" => MIDDLE, "O" => OLD),
# Total weight in survey
Wghttotal = sum(HHYMO[:,"WGHT_PER"]),
MUbounds = (6,12*6),
SIGMAbounds = (1,48),
)
function make_dat_array()
durs = hcat(
Int.(HHYMO[!,"YDUR"*string(cnst.Sparam[2])]),
Int.(HHYMO[!,"MDUR"*string(cnst.Sparam[2])]),
Int.(HHYMO[!,"ODUR"*string(cnst.Sparam[2])]),
)
nums = hcat(
Int.(HHYMO[!,"YNUM"]),
Int.(HHYMO[!,"MNUM"]),
Int.(HHYMO[!,"ONUM"]),
)
WGHT = Weights(HHYMO[!,"WGHT_PER"]./cnst.Wghttotal)
AGERESP = map(r -> cnst.swap[r],HHYMO[!,"AGERESP"])
return (;
nums,
durs,
WGHT,
AGERESP
)
end
const dat = make_dat_array()
function coverage!(cov,S_j,E_j)
if E_j < S_j
push!(cov,Interval(0,E_j))
push!(cov,Interval(S_j,cnst.durmax-1))
else
push!(cov,Interval(S_j,E_j))
end
end
function tot_dur_sample(n, dist,durlist)
if isempty(durlist)
return 0
end
total_dur = 0
numcontact = length(durlist)
for i in 1:n
cov1 = Vector{Interval{Int64, Closed, Closed}}()
for j in 1:numcontact
S_j = Int(trunc(rand(rng,dist))) % 144
E_j = (S_j + durlist[j])%(cnst.durmax-1)
coverage!(cov1,S_j,E_j)
end
union!(cov1)
total_dur += mapreduce(Intervals.span,+,cov1)
end
return total_dur
end
function err_norm(params)
μ = as_symmetric_matrix(params[1:6])
σ = as_symmetric_matrix(params[7:12])
# row_ids = sample(rng,1:length(dat.WGHT), dat.WGHT,cnst.subsize)
age_dists = [Normal(μ[i,j],σ[i,j]) for i in YOUNG:OLD, j in YOUNG:OLD]
duration_subarray = dat.durs#@view dat.durs[row_ids,:]
num_contacts_subarray = dat.nums#@view dat.nums[row_ids,:]
# display(num_contacts_subarray)
AGERESP = dat.AGERESP #@view dat.AGERESP[row_ids]
errsum = 0
@inbounds for i = 1:cnst.subsize
age_sample = AGERESP[i]
@inbounds for age_j in YOUNG:OLD
running_sum = 0
durs = Int.(trunc.(rand(rng,age_dists[age_sample,age_j],num_contacts_subarray[i,age_j]))) .% 144
expdur = tot_dur_sample(cnst.numsamples,cnst.Sparam,durs)
errsum += (expdur/cnst.numsamples - duration_subarray[i,age_j])^2
end
end
return errsum/cnst.subsize
end
function as_symmetric_matrix(l)
return [
l[1] l[2] l[3]
l[2] l[4] l[5]
l[3] l[5] l[6]
]
end
function err_poisson(params)
μ = as_symmetric_matrix(params)
# row_ids = sample(rng,1:length(dat.WGHT), dat.WGHT,cnst.subsize)
age_dists = [Poisson(μ[i,j]) for i in YOUNG:OLD, j in YOUNG:OLD]
duration_subarray = dat.durs#@view dat.durs[row_ids,:]
num_contacts_subarray = dat.nums#@view dat.nums[row_ids,:]
# display(num_contacts_subarray)
AGERESP = dat.AGERESP #@view dat.AGERESP[row_ids]
errsum = 0
@inbounds for i = 1:cnst.subsize
age_sample = AGERESP[i]
@inbounds for age_j in YOUNG:OLD
running_sum = 0
durs = Int.(trunc.(rand(rng,age_dists[age_sample,age_j],num_contacts_subarray[i,age_j]))) .% 144
expdur = tot_dur_sample(cnst.numsamples,cnst.Sparam,durs)
errsum += (expdur/cnst.numsamples - duration_subarray[i,age_j])^2
end
end
return errsum/cnst.subsize
end
using KissABC
using BenchmarkTools
using Serialization
using Plots
function bayesian_estimate()
# Set parameter bounds for fitting
BoundsNORM = vcat([cnst.MUbounds for i = 1:6], [cnst.SIGMAbounds for i = 1:6])
norm_init = vcat([cnst.MUbounds[1] for i = 1:6], [cnst.SIGMAbounds[1] for i = 1:6])
BoundsPOIS = [cnst.MUbounds for i in 1:6]
pois_init = [cnst.MUbounds[1] for i = 1:6]
priors_norm = Factored([Uniform(l,u) for (l,u) in BoundsNORM]...)
@btime err_norm($norm_init)
out_norm = smc(priors_norm,err_norm, verbose=true, nparticles=200, alpha=0.95, parallel = true)
# out_norm = KissABC.ABCDE(priors_norm,err_norm,0; verbose=true, nparticles=200, generations = 0.1e3,earlystop = false,parallel = true)
serialize("norm.dat",out_norm)
priors_pois = Factored([Uniform(l,u) for (l,u) in BoundsPOIS]...)
# out_pois = KissABC.ABCDE(priors_pois,err_poisson, 0; verbose=true, nparticles=200, generations = 0.1e3,earlystop = false,parallel = true)
out_pois = smc(priors_pois,err_poisson, verbose=true, nparticles=200, alpha=0.95, parallel = true)
serialize("pois.dat",out_pois)
end
function plot_estimates()
estimate = deserialize("norm.dat")
p_list = []
for i in 1:length(estimate.P)
a = stephist(
estimate.P[i].particles;
normalize = true,
title = i <=6 ? "μ_$i" : "σ_$i"
)
push!(p_list,a)
end
p = plot(p_list...)
savefig(p,"norm.png")
μ_estimate_as_array = as_symmetric_matrix(estimate.P[1:6])
σ_estimate_as_array = as_symmetric_matrix(estimate.P[7:12])
p_matrix = map(x -> plot(),σ_estimate_as_array)
for i in YOUNG:OLD, j in YOUNG:OLD
dist = Normal.(μ_estimate_as_array[i,j].particles,σ_estimate_as_array[i,j].particles)
data = [pdf.(dist,i) for i in 0.0:144.0]
mean_dat = median.(data)
err_down = quantile.(data,0.05)
err_up = quantile.(data,0.95)
p_matrix[i,j] = plot(0:144,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false)
end
plot!(p_matrix[end,1]; legend = true)
p = plot(p_matrix..., size = (600,400))
savefig(p,"norm_dists.pdf")
estimate = deserialize("pois.dat")
p_list = []
for i in 1:length(estimate.P)
a = stephist(
estimate.P[i].particles;
normalize = true,
title = i <=6 ? "μ_$i" : "σ_$i"
)
push!(p_list,a)
end
p = plot(p_list...)
savefig(p,"pois.png")
μ_estimate_as_array = as_symmetric_matrix(estimate.P[1:6])
p_matrix = map(x -> plot(),μ_estimate_as_array)
for i in YOUNG:OLD, j in YOUNG:OLD
dist = Poisson.(μ_estimate_as_array[i,j].particles)
data = [pdf.(dist,i) for i in 0.0:144.0]
mean_dat = median.(data)
err_down = quantile.(data,0.05)
err_up = quantile.(data,0.95)
p_matrix[i,j] = plot(0:144,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false)
end
plot!(p_matrix[end,1]; legend = true)
p = plot(p_matrix..., size = (600,400))
savefig(p,"pois_dists.pdf")
end
end
\ No newline at end of file
include("intervals_model_but_fast.jl")
using .intervals_fast
using BenchmarkTools
using Plots
pgfplotsx()
default(dpi = 300)
default(framestyle = :box)
intervals_fast.bayesian_estimate()
intervals_fast.plot_estimates()
\ No newline at end of file
......@@ -5,7 +5,7 @@ using ThreadsX
import StatsBase.mean
model_sizes = [(100,1),(1000,3),(5000,3)]
vaccination_stratgies = [vaccinate_uniformly!]
vaccination_rates = [0.0001,0.001,0.005,0.01,0.05]
vaccination_rates = [0.001,0.005,0.01,0.05]
infection_rates = [0.01,0.05,0.1]
agent_models = ThreadsX.map(model_size -> AgentModel(model_size...), model_sizes)
dem_cat = AgentDemographic.size + 1
......@@ -35,6 +35,7 @@ function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus())
steps = 300
sol1,graphs = solve!(u_0,params,steps,model,vac_strategy);
total_infections = count(x->x == AgentStatus(3),sol1[end])
display(total_infections)
return total_infections
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment