Skip to content
Snippets Groups Projects
Commit 28de5607 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

posteriors added, ZWDist pdf bugfix

parent cd6462b3
No related branches found
No related tags found
No related merge requests found
......@@ -12,9 +12,11 @@ const PACKAGE_FOLDER = dirname(dirname(pathof(IntervalsModel)))
using DataStructures:OrderedDict
using Serialization
using Plots
const rng = Xoroshiro128Plus(1)
const YOUNG, MIDDLE,OLD = 1,2,3
const durmax = 144
const color_palette = palette(:seaborn_bright) #color theme for the plots
include("interval_overlap_sampling.jl")
include("utils.jl")
......@@ -29,10 +31,10 @@ using JLSO
using Plots
const μ_bounds = (6,12*6)
const σ_bounds = (1,48)
const μ_bounds = (1,144)
const σ_bounds = (1,144)
const α_bounds = (0.0,0.8)
pgfplotsx()
function do_hh(particles)
dists = [
Normal,
......@@ -42,7 +44,7 @@ function do_hh(particles)
bounds_list = map(l -> vcat(l...),[
([μ_bounds for i = 1:6], [σ_bounds for i = 1:6]),
([μ_bounds for i = 1:6],),
([μ_bounds for i = 1:6], [σ_bounds for i = 1:6]),
([(0.1,144) for i = 1:6], [(0.1,144.0) for i = 1:6]),
])
bayesian_estimate("hh",err_hh,dists,bounds_list,particles)
end
......@@ -51,14 +53,14 @@ function do_ws(particles)
dists = [
ZWDist{Normal},
ZWDist{Poisson},
ZWDist{Weibull}
# ZWDist{Weibull}
]
# Set parameter bounds for fitting
bounds_list = map(l -> vcat(l...),[
[[α_bounds for i = 1:6],[μ_bounds for i = 1:6], [σ_bounds for i = 1:6]],
[[α_bounds for i = 1:6],[μ_bounds for i = 1:6]],
[[α_bounds for i = 1:6],[(0.1,12*6) for i = 1:6], [(0.1,48.0) for i = 1:6]],
[[α_bounds for i = 1:6],[(0.1,144) for i = 1:6], [(0.1,144.0) for i = 1:6]],
])
@show bounds_list
bayesian_estimate("ws",err_ws,dists,bounds_list,particles)
......@@ -75,7 +77,7 @@ function do_rest(particles)
bounds_list = map(l -> vcat(l...),[
[[α_bounds for i = 1:6],[μ_bounds for i = 1:6], [σ_bounds for i = 1:6]],
[[α_bounds for i = 1:6],[μ_bounds for i = 1:6]],
[[α_bounds for i = 1:6],[μ_bounds for i = 1:6], [σ_bounds for i = 1:6]],
[[α_bounds for i = 1:6],[(0.1,144) for i = 1:6], [(0.1,144.0) for i = 1:6]],
])
@show bounds_list
bayesian_estimate("rest",err_rest,dists,bounds_list,particles)
......@@ -89,7 +91,7 @@ function bayesian_estimate(fname,err_func,dists,bounds_list,particles)
@btime err_ws($init,$dist) #compute benchmark of the error function, not rly necessary
out = smc(priors,p -> err_func(p, dist), verbose=true, nparticles=particles, alpha=0.90, parallel = true) #apply sequential monte carlo with 200 particles
out = smc(priors,p -> err_func(p, dist), verbose=true, nparticles=particles, alpha=0.99, M = 2, parallel = true) #apply sequential monte carlo with 200 particles
return dist => out
end |> OrderedDict
......@@ -100,14 +102,21 @@ end
function plot_estimates()
data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","ws.dat"))
plot_dists("ws",collect(keys(data)),collect(values(data)))
display(collect(values(data)))
display(collect(keys(data)))
for (k,v) in zip(keys(data),values(data))
plot_posteriors("$(k)_ws",v)
end
data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","hh.dat"))
plot_dists("hh",collect(keys(data)),collect(values(data)))
for (k,v) in zip(keys(data),values(data))
plot_posteriors("$(k)_hh",v)
end
data = deserialize(joinpath(PACKAGE_FOLDER,"simulation_data","rest.dat"))
plot_dists("rest",collect(keys(data)),collect(values(data)))
for (k,v) in zip(keys(data),values(data))
plot_posteriors("$(k)_rest",v)
end
end
......
......@@ -8,16 +8,20 @@ function plot_dists(fname,dist_constructors,data)
p_estimate_as_arrays = map(d -> get_params(d.P),data)
p_matrix = map(x -> plot(),p_estimate_as_arrays[1])
ymo = ["Y","M","O"]
x_range = 0.0:144.0
for i in YOUNG:OLD, j in YOUNG:OLD
for (dist_constructor,p_estimate) in zip(dist_constructors,p_estimate_as_arrays)
for (k,(dist_constructor,p_estimate)) in enumerate(zip(dist_constructors,p_estimate_as_arrays))
dists = map(p -> p.particles, p_estimate[i,j]) |>
t -> zip(t...) |>
l -> map(t -> dist_constructor(t...),l)
dist_pts = [pdf.(dists,i) for i in 0.0:144.0]
dist_pts = [pdf.(dists,i) for i in x_range]
mean_dat = median.(dist_pts)
err_down = quantile.(dist_pts,0.05)
err_up = quantile.(dist_pts,0.95)
plot!(p_matrix[i,j] ,0:144,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false,label = string(dist_constructor))
display(typeof(dist_pts))
hasnans = any(any.(map(l -> isnan.(l),dist_pts)))
err_down = hasnans ? 0 : quantile.(dist_pts,0.05)
err_up = hasnans ? 0 : quantile.(dist_pts,0.95)
plot!(p_matrix[i,j] ,x_range,mean_dat; ribbon = ( mean_dat .- err_down,err_up .- mean_dat),legend = false,label = string(dist_constructor),seriescolor = color_palette[k])
end
annotate!(p_matrix[i,j],compute_x_pos(p_matrix[i,j]),compute_y_pos(p_matrix[i,j]), Plots.text("$(ymo[i])→$(ymo[j])", :left, 10))
end
......@@ -27,25 +31,28 @@ function plot_dists(fname,dist_constructors,data)
end
compute_x_pos(p) = xlims(p)[1] + 0.02*((xlims(p)[2] - xlims(p)[1]))
compute_y_pos(p) = ylims(p)[2] - 0.11*((ylims(p)[2] - ylims(p)[1]))
function plot_posteriors(fname,parameter_names,data)
function plot_posteriors(fname,data)
p_list = map(x -> plot(),1:length(data.P))
for i in 1:length(data.P)
a = stephist(
data.P[i].particles;
normalize = true,
title = "$(parameter_names[div(i,6)])_(i % 6)"
)
push!(p_list,a)
# display(data.P[i].particles)
hist = fit(Histogram,data.P[i].particles; nbins = 50)
kde_est = kde(data.P[i].particles)
kernel_data = [pdf(kde_est,x) for x in minimum(data.P[i].particles):maximum(data.P[i].particles)]
plot!(p_list[i],hist;legend = false)
# display(kernel_data)
# vline!(p_list[i],[argmax(kernel_data)]; seriescolor = color_palette[2])
end
p = plot(p_list...;layout=(length(p_list) ÷ 6,6), size = (1000,(length(p_list) ÷ 6)*300), seriescolor = color_palette[1])
savefig(p,joinpath(PACKAGE_FOLDER,"plots","$fname.pdf"))
end
function add_subplot_letters!(plot_list; pos = :top)
for (i,sp) in enumerate(plot_list)
letter = string(Char(i+96))
if pos == :top
annotate!(sp,xlims(sp)[1] + 0.02*((xlims(sp)[2] - xlims(sp)[1])),ylims(sp)[2] - 0.11*((ylims(sp)[2] - ylims(sp)[1])), Plots.text("$letter)", :left, 18))
elseif pos == :bottom
annotate!(sp,xlims(sp)[1] + 0.02*((xlims(sp)[2] - xlims(sp)[1])),ylims(sp)[1] + 0.11*((ylims(sp)[2] - ylims(sp)[1])), Plots.text("$letter)", :left, 18))
end
end
p = plot(p_list...)
savefig(p,"$fname.png")
end
# function add_subplot_letters!(plot_list; pos = :top)
# for (i,sp) in enumerate(plot_list)
# letter = string(Char(i+96))
# if pos == :top
# annotate!(sp,xlims(sp)[1] + 0.02*((xlims(sp)[2] - xlims(sp)[1])),ylims(sp)[2] - 0.11*((ylims(sp)[2] - ylims(sp)[1])), Plots.text("$letter)", :left, 18))
# elseif pos == :bottom
# annotate!(sp,xlims(sp)[1] + 0.02*((xlims(sp)[2] - xlims(sp)[1])),ylims(sp)[1] + 0.11*((ylims(sp)[2] - ylims(sp)[1])), Plots.text("$letter)", :left, 18))
# end
# end
# end
\ No newline at end of file
......@@ -16,7 +16,7 @@ const rest_data = (
M = CSV.File("$PACKAGE_FOLDER/network-data/Timeuse/Rest/RDataM.csv") |> Tables.matrix |> x -> dropdims(x;dims = 2),
O = CSV.File("$PACKAGE_FOLDER/network-data/Timeuse/Rest/RDataO.csv") |> Tables.matrix |> x -> dropdims(x;dims = 2),
)
const comparison_samples = 2000
const comparison_samples = 1000
ws_distributions = CovidAlertVaccinationModel.initial_workschool_mixing_matrix
......@@ -40,12 +40,12 @@ function err_ws(p,dist)
durs = trunc.(Int,rand(rng,age_dists[age_sample,age_j],neighourhoods[age_sample,age_j])) .% durmax
# display(durs)
tot_dur_sample!(sample_list,cnst_hh.Sparam,durs)
# kde_est = kde(sample_list)
kde_est = kde(sample_list)
# err = (1 - pvalue(KSampleADTest(ws_samples[age_sample],sample_list)))^2 #need to maximize probability of null hypothesis, not rly valid but everyone does it so idk
# errsum += mapreduce(+,0:0.05:durmax) do i
# return (pdf(kde_est,i) - pdf(kerneldensity_data[age_sample],i))^2
# end
errsum += (mean(sample_list) - mean(ws_samples[age_sample]))^2
errsum += mapreduce(+,0:2:durmax) do i
return (pdf(kde_est,i) - pdf(kerneldensity_data[age_sample],i))^2
end
# errsum += err#(mean(sample_list) - mean(ws_samples[age_sample]))^2
end
end
end
......
# This file is machine-generated - editing it directly is not advised
[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "ee07ae00e3cc277dcfa5507ce25be522313ecc3e"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.1"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[Hwloc]]
deps = ["Hwloc_jll"]
git-tree-sha1 = "2e3d1d4ab0e7296354539b2be081f71f4b694c0b"
uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
version = "1.2.0"
[[Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "1179250d910c99810d8a7ff55c50c4ed68c77a58"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.4.0+0"
[[IfElse]]
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.0"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[JLLWrappers]]
git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.2.0"
[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
version = "1.0.2"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VectorizationBase]]
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "338930400f561a120b9b317a456c7c1cd62eac13"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.18.10"
[[VectorizedRNG]]
deps = ["Distributed", "Random", "UnPack", "VectorizationBase"]
git-tree-sha1 = "59c95a188efd11c6ed762154397ed5ea94a95e30"
uuid = "33b4df10-0173-11e9-2a0c-851a7edac40e"
version = "0.2.7"
[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
[deps]
VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"
......@@ -25,7 +25,7 @@ function Distributions.pdf(d::ZWDist, x)
if x == 0
return d.α + (1-d.α)*pdf(d.base_dist,0)
else
return pdf(d.base_dist,0)
return pdf(d.base_dist,x)
end
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment