Commit 1e18580d authored by Peter Jentsch

bugfix and posterior update in IntervalsModel

parent 09cf7f42
Showing with 63 additions and 120 deletions
No preview for this file type (12 changed files)
File deleted (4 files)
src/IntervalsModel.jl
 module IntervalsModel
+export main
 using Intervals
 using CSV
@@ -11,25 +13,26 @@ using ZeroWeightedDistributions
 const PACKAGE_FOLDER = dirname(dirname(pathof(IntervalsModel)))
 using DataStructures:OrderedDict
 using Serialization
+using KissABC
+using BenchmarkTools
 using Plots
 const rng = Xoroshiro128Plus(1)
 const YOUNG, MIDDLE, OLD = 1, 2, 3
 const durmax = 144
 const color_palette = palette(:seaborn_bright) #color theme for the plots
+const sparam = [60, 12]
+# Swap age brackets for numbers
+const swap_dict = OrderedDict("Y" => YOUNG, "M" => MIDDLE, "O" => OLD)
 include("interval_overlap_sampling.jl")
 include("utils.jl")
 include("hh_durations_model.jl")
-include("ws_durations_model.jl")
+include("ws_rest_durations_model.jl")
 include("plotting_functions.jl")
-using KissABC
-using BenchmarkTools
-using JLSO
-using Plots
 const μ_bounds = (1,144)
 const σ_bounds = (1,144)
@@ -46,119 +49,32 @@ const α_priors_mean_rest = (
     O = 0.092857
 )
-function alpha_matrix(alphas)
-    M = zeros(length(alphas),length(alphas))
-    for i in 1:length(alphas), j in 1:length(alphas)
-        M[i,j] = alphas[i] + alphas[j] - alphas[j]*alphas[i]
-    end
-    return [M[1,1], M[1,2],M[1,3],M[2,2],M[2,3],M[3,3]] #lol
-end
 pgfplotsx()
-function do_hh(particles)
-    dists = [
-        Normal,
-        Poisson,
-        Weibull
-    ]
-    bounds_list = map(l -> vcat(l...),[
-        ([μ_bounds for i = 1:6], [σ_bounds for i = 1:6]),
-        ([μ_bounds for i = 1:6],),
-        ([(1.0,144) for i = 1:6], [(1.0,144.0) for i = 1:6]),
-    ])
-    @show bounds_list
-    fname = "hh"
-    data_pairs = map(zip(dists,bounds_list)) do (dist,bounds)
-        priors = Factored([Uniform(l,u) for (l,u) in bounds[1:end]]...) #assume uniform priors
-        # @btime err_ws($init,$dist) #compute benchmark of the error function, not rly necessary
-        out = smc(priors,p -> err_hh(p, dist), verbose=true, nparticles=particles, alpha=0.98, parallel = true)
-        return dist => out
-    end |> OrderedDict
-    display(data_pairs)
-    serialize(joinpath(PACKAGE_FOLDER,"simulation_data","$fname.dat"),data_pairs)
-end
-function do_ws(particles)
-    dists = [
-        ZWDist{Normal},
-        ZWDist{Poisson},
-        # ZWDist{Weibull}
-    ]
-    # Set parameter bounds for fitting
-    bounds_list = map(l -> vcat(l...),[
-        [[α_bounds for i = 1:6],[μ_bounds for i = 1:6], [σ_bounds for i = 1:6]],
-        [[α_bounds for i = 1:6],[μ_bounds for i = 1:6]],
-        [[α_bounds for i = 1:6],[(0.1,144) for i = 1:6], [(0.1,144.0) for i = 1:6]],
-    ])
-    fname = "ws"
-    data_pairs = map(zip(dists,bounds_list)) do (dist,bounds)
-        init = [b[1] for b in bounds]
-        priors = Factored([TriangularDist(μ*(0.8),min(1,μ*(1.2),μ)) for μ in alpha_matrix(α_priors_mean_ws)]...,[Uniform(l,u) for (l,u) in bounds[7:end]]...) #assume uniform priors
-        # @btime err_ws($init,$dist) #compute benchmark of the error function, not rly necessary
-        out = smc(priors,p -> err_ws(p, dist), verbose=true, nparticles=particles, alpha=0.995, parallel = true)
+function main()
+    do_hh(400)
+    do_ws(400)
+    do_rest(400)
+end
+function bayesian_estimation(fname, err_func, priors_list, dists, particles; alpha = 0.995)
+    data_pairs = map(zip(dists,priors_list)) do (dist,priors)
+        out = smc(priors,p -> err_func(p, dist), verbose=true, nparticles=particles, alpha=alpha, parallel = true)
         return dist => out
     end |> OrderedDict
     display(data_pairs)
     serialize(joinpath(PACKAGE_FOLDER,"simulation_data","$fname.dat"),data_pairs)
 end
-function do_rest(particles)
-    dists = [
-        ZWDist{Normal},
-        ZWDist{Poisson},
-        ZWDist{Weibull}
-    ]
-    # Set parameter bounds for fitting
-    bounds_list = map(l -> vcat(l...),[
-        [[α_bounds for i = 1:6],[μ_bounds for i = 1:6], [σ_bounds for i = 1:6]],
-        [[α_bounds for i = 1:6],[μ_bounds for i = 1:6]],
-        [[α_bounds for i = 1:6],[(0.1,144) for i = 1:6], [(0.1,144.0) for i = 1:6]],
-    ])
-    # @show bounds_list
-    fname = "rest"
-    data_pairs = map(zip(dists,bounds_list)) do (dist,bounds)
-        init = [b[1] for b in bounds]
-        priors = Factored([TriangularDist(μ*(0.8),min(1,μ*(1.2),μ)) for μ in alpha_matrix(α_priors_mean_rest)]...,[Uniform(l,u) for (l,u) in bounds[7:end]]...) #assume uniform priors
-        # @btime err_ws($init,$dist) #compute benchmark of the error function, not rly necessary
-        out = smc(priors,p -> err_rest(p, dist), verbose=true, nparticles=particles, alpha=0.995, parallel = true) #apply sequential monte carlo with 200 particles
-        return dist => out
-    end |> OrderedDict
-    display(data_pairs)
-    serialize(joinpath(PACKAGE_FOLDER,"simulation_data","$fname.dat"),data_pairs)
-end
-function plot_estimates()
-    data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","ws.dat"))
-    plot_dists("ws",collect(keys(data)),collect(values(data)))
-    display(collect(keys(data)))
-    for (k,v) in zip(keys(data),values(data))
-        plot_posteriors("$(k)_ws",v)
-    end
-    data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","hh.dat"))
-    plot_dists("hh",collect(keys(data)),collect(values(data)))
-    for (k,v) in zip(keys(data),values(data))
-        plot_posteriors("$(k)_hh",v)
-    end
-    data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","rest.dat"))
-    plot_dists("rest",collect(keys(data)),collect(values(data)))
-    for (k,v) in zip(keys(data),values(data))
-        plot_posteriors("$(k)_rest",v)
-    end
-end
+function alpha_matrix(alphas)
+    M = zeros(length(alphas),length(alphas))
+    for i in 1:length(alphas), j in 1:length(alphas)
+        M[i,j] = alphas[i] + alphas[j] - alphas[j]*alphas[i]
+    end
+    return [M[1,1], M[1,2],M[1,3],M[2,2],M[2,3],M[3,3]] # the six unique pairwise entries: YY, YM, YO, MM, MO, OO
+end
 end # module
\ No newline at end of file
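The relocated alpha_matrix helper maps three per-age values (for YOUNG, MIDDLE, OLD) to the six unique entries of the symmetric matrix M[i,j] = α_i + α_j − α_iα_j, i.e. 1 − (1−α_i)(1−α_j), which the workplace and rest fits use to seed their TriangularDist priors. A minimal standalone sketch of that computation, with hypothetical input values:

# Standalone sketch of what alpha_matrix computes, with hypothetical inputs.
alphas = (0.4, 0.3, 0.1)                      # hypothetical per-age values (Y, M, O)
pairwise(a, b) = a + b - a*b                  # equals 1 - (1-a)*(1-b)
M = [pairwise(a, b) for a in alphas, b in alphas]
flat = [M[1,1], M[1,2], M[1,3], M[2,2], M[2,3], M[3,3]]   # YY, YM, YO, MM, MO, OO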
src/hh_durations_model.jl
@@ -5,21 +5,31 @@ const HHYMO = DataFrame(CSV.File("$PACKAGE_FOLDER/network-data/Timeuse/HH/HHYMO.
 # In particular, we avoid having to modify any strings in the error function.
 const cnst_hh = (
-    # Set the underlying parameters for the intervals model
-    Sparam = [60,12],
     # Set parameters for intervals sample and subpopulation size
-    numsamples = 10,
+    numsamples = 10_000,
     subsize = size(HHYMO)[1],
-    # Swap age brackets for numbers
-    swap = OrderedDict("Y" => YOUNG, "M" => MIDDLE, "O" => OLD),
     # Total weight in survey
     Wghttotal = sum(HHYMO[:,"WGHT_PER"]),
 )
+function do_hh(particles)
+    dists = [
+        Normal,
+        Poisson,
+    ]
+    priors_list = map(l -> Factored(vcat(l...)...),[
+        [[Uniform(μ_bounds...) for i=1:6], [Uniform(σ_bounds...) for i = 1:6]],
+        [[Uniform(μ_bounds...) for i=1:6]],
+    ])
+    fname = "hh"
+    bayesian_estimation(fname,err_hh,priors_list,dists,particles)
+end
 function make_dat_array()
     durs = hcat(
-        Int.(HHYMO[!,"YDUR"*string(cnst_hh.Sparam[2])]),
-        Int.(HHYMO[!,"MDUR"*string(cnst_hh.Sparam[2])]),
-        Int.(HHYMO[!,"ODUR"*string(cnst_hh.Sparam[2])]),
+        Int.(HHYMO[!,"YDUR"*string(sparam[2])]),
+        Int.(HHYMO[!,"MDUR"*string(sparam[2])]),
+        Int.(HHYMO[!,"ODUR"*string(sparam[2])]),
     )
     nums = hcat(
         Int.(HHYMO[!,"YNUM"]),
@@ -28,7 +38,7 @@ function make_dat_array()
     )
     WGHT = Weights(HHYMO[!,"WGHT_PER"]./cnst_hh.Wghttotal)
-    AGERESP = map(r -> cnst_hh.swap[r],HHYMO[!,"AGERESP"])
+    AGERESP = map(r -> swap_dict[r],HHYMO[!,"AGERESP"])
     return (;
         nums,
         durs,
@@ -53,7 +63,7 @@ function err_hh(p,dist)
         age_sample = AGERESP[i]
         @inbounds for age_j in YOUNG:OLD #for a given age_sample loop over possible contact ages
             durs = trunc.(Int,rand(rng,age_dists[age_sample,age_j],num_contacts_subarray[i,age_j])) .% durmax
-            expdur = tot_dur_sample(cnst_hh.numsamples,cnst_hh.Sparam,durs)
+            expdur = tot_dur_sample(cnst_hh.numsamples,durs)
             errsum += (expdur/cnst_hh.numsamples - duration_subarray[i,age_j])^2 #compute total
         end
     end
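With this refactor, each do_* driver only supplies candidate distributions and matching Factored priors, and bayesian_estimation handles the SMC run and serialization. A hedged sketch of what the Normal-only household case expands to, assuming it runs inside the IntervalsModel module (where err_hh, μ_bounds, and σ_bounds are defined); the file name and particle count here are illustrative only:

# Illustrative expansion of the Normal case in do_hh (module scope assumed).
normal_priors = Factored(vcat(
    [Uniform(μ_bounds...) for _ in 1:6],   # six pairwise mean parameters (YY, YM, YO, MM, MO, OO)
    [Uniform(σ_bounds...) for _ in 1:6],   # six pairwise spread parameters
)...)
bayesian_estimation("hh_normal_only", err_hh, [normal_priors], [Normal], 200)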
src/interval_overlap_sampling.jl
+# Set the underlying parameters for the intervals model
+const startdist = Normal(sparam...)
 function coverage!(cov,S_j,E_j)
     if E_j < S_j
@@ -8,7 +11,7 @@ function coverage!(cov,S_j,E_j)
     end
 end
 #compute the total duration of a sample of intervals
-function tot_dur_sample(n, dist,durlist)
+function tot_dur_sample(n,durlist)
     if isempty(durlist)
         return 0
     end
@@ -18,7 +21,7 @@ function tot_dur_sample(n, dist,durlist)
     int_list = Vector{Interval{Int,Closed,Closed}}()
     sizehint!(int_list,numcontact*2)
-    start_matrix = trunc.(Int,(rand(rng,dist,(numcontact,n))))
+    start_matrix = trunc.(Int,rand(rng,startdist,(numcontact,n)))
     @inbounds for i in 1:n
         empty!(int_list)
         @inbounds for j in 1:numcontact
@@ -31,7 +34,9 @@
     end
     return total_dur
 end
-function tot_dur_sample!(sample_list, dist,durlist)
+
+function tot_dur_sample!(sample_list,durlist)
     if isempty(durlist)
         sample_list .= 0.0
         return
@@ -40,7 +45,8 @@ function tot_dur_sample!(sample_list, dist,durlist)
     n = length(sample_list)
     int_list = Vector{Interval{Int,Closed,Closed}}()
     sizehint!(int_list,numcontact*2)
-    start_matrix = trunc.(Int,(rand(rng,dist,(numcontact,n))))
+    # @show rand(rng,startdist,(numcontact,n))
+    start_matrix = trunc.(Int,rand(rng,startdist,(numcontact,n)))
     for i in 1:n
         empty!(int_list)
         for j in 1:numcontact
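Moving the start-time distribution into the module constant startdist = Normal(sparam...) means tot_dur_sample no longer threads a dist argument through every call: each Monte Carlo replicate draws contact start times from startdist, forms [start, start + duration] intervals, and measures how much time their union covers. Below is a rough, self-contained sketch of that union-length idea using an endpoint sweep rather than the package's Intervals.jl-based coverage! machinery; the function name and the example numbers are illustrative only.

using Distributions, Random

# Sketch: average covered time for one set of contact durations `durs`,
# over `n` independent draws of start times from `startdist`.
function mean_union_duration(durs, startdist, n; rng = Random.default_rng())
    total = 0
    for _ in 1:n
        events = Tuple{Int,Int}[]                 # (time, +1 = interval opens, -1 = interval closes)
        for d in durs
            s = trunc(Int, rand(rng, startdist))
            push!(events, (s, 1), (s + d, -1))
        end
        sort!(events)
        covered, depth, last_t = 0, 0, 0
        for (t, e) in events
            depth > 0 && (covered += t - last_t)  # time since the last event was covered
            depth += e
            last_t = t
        end
        total += covered
    end
    return total / n
end

# e.g. mean_union_duration([30, 45, 10], Normal(60, 12), 1_000)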
src/plotting_functions.jl
@@ -4,6 +4,17 @@ using Plots
 default(dpi = 300)
 default(framestyle = :box)
 pgfplotsx()
+function plot_all()
+    fnames = ["hh","ws","rest"]
+    map(plot_estimate,fnames)
+end
+function plot_estimate(fname)
+    data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","$fname.dat"))
+    plot_dists("$fname",collect(keys(data)),collect(values(data)))
+    for (k,v) in zip(keys(data),values(data))
+        plot_posteriors("$(k)_$fname",v)
+    end
+end
 function plot_dists(fname,dist_constructors,data)
     p_estimate_as_arrays = map(d -> get_params(d.P),data)
     p_matrix = map(x -> plot(),p_estimate_as_arrays[1])
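The old plot_estimates has been split so each serialized fit can be re-plotted independently. A brief usage sketch, assuming main() has already written the .dat files to simulation_data:

plot_all()            # regenerate figures for "hh", "ws", and "rest"
plot_estimate("ws")   # or just one set of posteriors and fitted distributions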