parameter_optimization.jl 10.7 KB
Newer Older
Peter Jentsch's avatar
Peter Jentsch committed
1
2
using KissABC

3
4
# Calibration targets shared by the fitting routines in this file.

# Observed vaccination uptake by month, starting in August (index 1 = August).
# NOTE(review): the entries sum to ≈1.002, so they look like monthly shares of the
# season's total uptake — confirm against the data source.
const vaccination_data = [0.0, 0.043, 0.385, 0.424, 0.115, 0.03, 0.005]

# Vaccination coverage fractions for the young / middle / old age groups
# (same y/m/o ordering as the `*_y`, `*_m`, `*_o` fitted parameters).
const ymo_vac = [0.255, 0.278, 0.602]

# Attack rates for the young / middle / old age groups, given in percent and
# converted to fractions here.
const ymo_attack_rate = [10.376, 5.636, 7.2] ./ 100
6

Peter Jentsch's avatar
Peter Jentsch committed
7

8
"""
    solve_w_parameters(default_p, p_names, new_p_list)

Run a single ABM simulation and return `(out, model)`.

The parameters named in `p_names` are overridden with the positionally corresponding
entries of `new_p_list` (`new_p_list[i]` is the value for `p_names[i]`); every other
parameter is taken from `default_p`.

`out` is a fresh `DebugRecorder` passed to `abm`, which records into it; `model` is
whatever `abm` returns.
"""
function solve_w_parameters(default_p, p_names, new_p_list)
    overrides = NamedTuple{p_names}(ntuple(i -> new_p_list[i], length(p_names)))
    new_params = merge(default_p, overrides)
    # Size the recorder from the *merged* parameters, so that fitting `sim_length`
    # itself cannot leave the recorder out of sync with the simulated run length.
    out = DebugRecorder(0, new_params.sim_length)
    model = abm(new_params, out)
    return out, model
end
Peter Jentsch's avatar
Peter Jentsch committed
15

16
"""
    fit_pre_inf_behavioural_parameters(p_tuple)

Fit the baseline behaviour parameters `π_base_y`, `π_base_m`, `π_base_o` by ABC
sequential Monte Carlo (`smc`, 150 particles) with independent `Uniform(-10, 2)`
priors.

`p_tuple` supplies all other model parameters. The run is infection-free
(`I_0_fraction = 0`) and lasts 180 days, with immunization starting on day 60.

The cost is the squared error, per age group, between total simulated immunizations
and `ymo_vac` scaled by the uptake share `sum(vaccination_data[1:4])` (August–November)
and by `length.(model.index_vectors)`.

Returns a `NamedTuple` mapping each fitted parameter name to its posterior particles.
"""
function fit_pre_inf_behavioural_parameters(p_tuple)
    p_names = (:π_base_y, :π_base_m, :π_base_o)
    priors = Factored(
        Uniform(-10.0, 2.0),
        Uniform(-10.0, 2.0),
        Uniform(-10.0, 2.0)
    )

    # Simulation begins in July: 60 days for opinion dynamics to stabilize, then
    # immunization begins (September). Infection is not considered: no initial
    # infections, and introduction is scheduled for the final day of the run.
    sim_length = 180
    p_tuple_adjust = merge(p_tuple, (
        sim_length = sim_length,
        I_0_fraction = 0.000,
        immunization_begin_day = 60,
        infection_introduction_day = 180,
        immunizing = true,
    ))

    function cost(p)
        output, model = solve_w_parameters(p_tuple_adjust, p_names, p)
        # Scale the per-group coverage by the August–November uptake share and by
        # the group sizes (presumably one index vector per age group — TODO confirm).
        target_ymo_vac = ymo_vac .* sum(vaccination_data[1:4]) .* length.(model.index_vectors)
        # The whole run is infection-free, so every recorded day counts as pre-infection.
        total_preinfection_vaccination = sum.(eachrow(output.daily_immunized_by_age))
        return sum(abs2, total_preinfection_vaccination .- target_ymo_vac)
    end

    out = smc(priors, cost; verbose = true, nparticles = 150, parallel = true)
    return NamedTuple{p_names}(ntuple(i -> out.P[i].particles, length(p_names)))
end

53
"""
    fit_post_inf_behavioural_parameters(p_tuple)

Fit `ω`, `β_y`, `β_m`, `β_o` by ABC sequential Monte Carlo (`smc`, 400 particles).
Priors are `Uniform(0, 0.1)` for the first three and `Uniform(0, 1)` for `β_o`.

`p_tuple` supplies all other model parameters. The run lasts 300 days, with
immunization from day 60 and infection introduced on day 180 (`I_0_fraction = 0.005`).

The cost is the sum of two per-age-group squared errors:
- `1e-2 ×` error between immunizations from the introduction day onwards and `ymo_vac`
  scaled by the total uptake share `sum(vaccination_data)`;
- error between total cases (`daily_cases_by_age`) and `ymo_attack_rate`.

Both targets are scaled by `length.(model.index_vectors)`.

Returns a `NamedTuple` mapping each fitted parameter name to its posterior particles.
"""
function fit_post_inf_behavioural_parameters(p_tuple)
    p_names = (:ω, :β_y, :β_m, :β_o)
    priors = Factored(Uniform(0.0, 0.1), Uniform(0.0, 0.1), Uniform(0.0, 0.1), Uniform(0.0, 1.0))

    # Timeline in days from the start of the run (which begins in July): 60 days for
    # opinion dynamics to stabilize, immunization from day 60, infection from day 180.
    sim_length = 300
    infection_day = 180
    p_tuple_adjust = merge(p_tuple, (
        sim_length = sim_length,
        I_0_fraction = 0.005,
        immunization_begin_day = 60,
        infection_introduction_day = infection_day,
        immunizing = true,
    ))

    function cost(p)
        output, model = solve_w_parameters(p_tuple_adjust, p_names, p)
        group_sizes = length.(model.index_vectors)
        target_ymo_vac = ymo_vac .* sum(vaccination_data) .* group_sizes

        # Columns from the infection introduction day onwards, without copying.
        immunized_after = @view output.daily_immunized_by_age[:, infection_day:end]
        total_postinf_vaccination = sum.(eachrow(immunized_after))

        final_size = sum.(eachrow(output.daily_cases_by_age))
        target_final_size = ymo_attack_rate .* group_sizes

        return 1e-2 * sum(abs2, total_postinf_vaccination .- target_ymo_vac) +
               sum(abs2, final_size .- target_final_size)
    end

    # Alternative kept from earlier experiments (noted as having better NaN handling):
    # ABCDE(priors, cost, 1e6; verbose = true, nparticles = 300, generations = 1000, parallel = true)
    out = smc(priors, cost; verbose = true, nparticles = 400, parallel = true)
    return NamedTuple{p_names}(ntuple(i -> out.P[i].particles, length(p_names)))
end
Peter Jentsch's avatar
Peter Jentsch committed
88

89
"""
    fit_all_parameters(p_tuple)

Jointly fit `ω`, `β_y`, `β_m`, `β_o`, `π_base_y`, `π_base_m`, `π_base_o`, `α_y`, `α_m`
and `α_o` by ABC sequential Monte Carlo (`smc`, 1000 particles).

`p_tuple` supplies all other model parameters. The run lasts 300 days, with
immunization from day 60 and infection introduced on day 180 (`I_0_fraction = 0.005`).

The cost sums three per-age-group squared errors (young / middle / old):
1. immunizations *before* the introduction day vs. `ymo_vac` scaled by the
   August–November uptake share `sum(vaccination_data[1:4])`;
2. immunizations *from* the introduction day onwards vs. `ymo_vac` scaled by the
   total uptake share `sum(vaccination_data)`;
3. `daily_unvac_cases_by_age` summed over the run vs. `ymo_attack_rate`.

All targets are scaled by `length.(model.index_vectors)`.

Returns a `NamedTuple` mapping each fitted parameter name to its posterior particles.
"""
function fit_all_parameters(p_tuple)
    p_names = (:ω, :β_y, :β_m, :β_o, :π_base_y, :π_base_m, :π_base_o, :α_y, :α_m, :α_o)
    priors = Factored(
        Uniform(0.0, 0.01),  # ω
        Uniform(0.0, 0.2),   # β_y
        Uniform(0.0, 0.1),   # β_m
        Uniform(0.0, 1.0),   # β_o
        Uniform(-5.0, 0.0),  # π_base_y
        Uniform(-5.0, 0.0),  # π_base_m
        Uniform(-5.0, 5.0),  # π_base_o
        Uniform(0.0, 1.0),   # α_y
        Uniform(0.0, 1.0),   # α_m
        Uniform(0.0, 1.0),   # α_o
    )

    # Timeline in days from the start of the run (which begins in July): 60 days for
    # opinion dynamics to stabilize, immunization from day 60, infection from day 180.
    sim_length = 300
    infection_day = 180
    p_tuple_adjust = merge(p_tuple, (
        sim_length = sim_length,
        I_0_fraction = 0.005,
        immunization_begin_day = 60,
        infection_introduction_day = infection_day,
        immunizing = true,
    ))

    function cost(p)
        output, model = solve_w_parameters(p_tuple_adjust, p_names, p)
        group_sizes = length.(model.index_vectors)
        immunized = output.daily_immunized_by_age

        # Split the daily immunization record at the introduction day so that the two
        # vaccination terms cover disjoint periods (previously the "pre-infection"
        # total summed the whole run, double counting the post-infection days).
        immunized_before = @view immunized[:, 1:(infection_day - 1)]
        immunized_after = @view immunized[:, infection_day:end]
        total_preinfection_vaccination = sum.(eachrow(immunized_before))
        total_postinf_vaccination = sum.(eachrow(immunized_after))

        # Separate targets per period (previously one variable was reassigned, so
        # the post-infection term was compared against the pre-infection target).
        target_preinfection_vac = ymo_vac .* sum(vaccination_data[1:4]) .* group_sizes
        target_postinf_vac = ymo_vac .* sum(vaccination_data) .* group_sizes

        final_size = sum.(eachrow(output.daily_unvac_cases_by_age))
        target_final_size = ymo_attack_rate .* group_sizes

        return sum(abs2, total_preinfection_vaccination .- target_preinfection_vac) +
               sum(abs2, total_postinf_vaccination .- target_postinf_vac) +
               sum(abs2, final_size .- target_final_size)
    end

    # Alternative kept from earlier experiments (noted as having better NaN handling):
    # ABCDE(priors, cost, 1e6; verbose = true, nparticles = 300, generations = 1000, parallel = true)
    out = smc(priors, cost; verbose = true, nparticles = 1000, parallel = true)
    return NamedTuple{p_names}(ntuple(i -> out.P[i].particles, length(p_names)))
end

Peter Jentsch's avatar
Peter Jentsch committed
137
"""
    plot_behavioural_fit(particles, p_tuple)

Simulate the model at the posterior mode of `π_base_y`, `π_base_m`, `π_base_o` and
save the resulting plot to `behaviour_fit.pdf`.

The run lasts 210 days without infection (`I_0_fraction = 0`), with immunization from
day 60. The posterior modes and the per-age-group immunization totals are displayed.
Returns the `mean_solve` output.

NOTE(review): `particles` is read via `particles.P[i].particles`, i.e. the raw `smc`
result, not the `NamedTuple` returned by `fit_pre_inf_behavioural_parameters` — confirm
the intended caller.
"""
function plot_behavioural_fit(particles, p_tuple)
    p_names = (:π_base_y, :π_base_m, :π_base_o)
    samples = 1
    p_tuple_adjust = merge(p_tuple, (
        sim_length = 210,
        I_0_fraction = 0.000,
        immunization_begin_day = 60,
        infection_introduction_day = 180,
        immunizing = true,
    ))

    # Posterior mode of each fitted parameter, in `p_names` order.
    modes = map(e -> mode(e.particles), particles.P)
    display(modes)
    new_params = merge(p_tuple_adjust, NamedTuple{p_names}(ntuple(i -> modes[i], length(p_names))))
    out = mean_solve(samples, new_params, DebugRecorder)

    ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
    total_preinfection_vaccination = sum.(eachrow(ymo_vaccination_ts))
    display(total_preinfection_vaccination)

    fig = plot_model(nothing, [nothing], [out],
                     new_params.infection_introduction_day, new_params.immunization_begin_day)
    savefig(fig, "behaviour_fit.pdf")
    return out
end
166
167
168
169
170
171
172
# outbreak_transmission_dist = CovidAlertVaccinationModel.fit_epi_parameters(default_parameters,0.241) ##outbreak
# serialize(joinpath(PACKAGE_FOLDER,"abm_parameter_fits","outbreak_inf_parameters.dat"),outbreak_transmission_dist)
# plot_max_posterior("outbreak", outbreak_transmission_dist,default_parameters)

"""
    fit_parameters(default_parameters)

Load the cached joint-fit posterior (`abm_parameter_fits/fit_all_parameters.dat` under
`PACKAGE_FOLDER`) and plot it with `plot_fitting_posteriors("post_inf_fitting", …)`.
Returns whatever that call returns.

The fitting steps are kept below as commented-out code; uncomment a step to regenerate
the corresponding `.dat` file.
"""
function fit_parameters(default_parameters)
    fits_dir = joinpath(PACKAGE_FOLDER, "abm_parameter_fits")
    pre_inf_behaviour_parameters_path = joinpath(fits_dir, "pre_inf_behaviour_parameters.dat")
    post_inf_behaviour_parameters_path = joinpath(fits_dir, "post_inf_behaviour_parameters.dat")
    fit_all_parameters_path = joinpath(fits_dir, "fit_all_parameters.dat")

    # Optional step: fit baseline behaviour parameters without infection.
    # pre_inf_behaviour_parameters = fit_pre_inf_behavioural_parameters(default_parameters)
    # serialize(pre_inf_behaviour_parameters_path, pre_inf_behaviour_parameters)

    # Optional step: fit post-infection parameters conditioned on the pre-infection modes.
    # pre_inf_behaviour_parameters = deserialize(pre_inf_behaviour_parameters_path)
    # display(map(mode, pre_inf_behaviour_parameters))
    # post_inf_behaviour_parameters = fit_post_inf_behavioural_parameters(
    #     merge(default_parameters, map(mode, pre_inf_behaviour_parameters)))
    # serialize(post_inf_behaviour_parameters_path, post_inf_behaviour_parameters)
    # fitted_parameter_tuple = (;
    #     deserialize(pre_inf_behaviour_parameters_path)...,
    #     deserialize(post_inf_behaviour_parameters_path)...
    # )
    # display(map(mode, fitted_parameter_tuple))

    # Optional step: fit all parameters jointly and cache the posterior.
    # output = fit_all_parameters(default_parameters)
    # serialize(fit_all_parameters_path, output)

    fitted_parameter_tuple = deserialize(fit_all_parameters_path)
    fitted_sol = plot_fitting_posteriors("post_inf_fitting", fitted_parameter_tuple, default_parameters)
    return fitted_sol
end

"""
    plot_max_posterior(fname, particles, parameters; samples = 5)

Simulate a 150-day, vaccination-free outbreak (infection introduced on day 1) at the
posterior mode of `particles.base_transmission_probability`.

- The trajectory plot is saved to `"\$fname.pdf"`.
- A 25-bin histogram of the posterior samples is saved to `"\$(fname)_posterior.pdf"`.
- `samples` is forwarded to `mean_solve`.

Returns the `mean_solve` output, matching the other plotting helpers in this file.
"""
function plot_max_posterior(fname, particles, parameters; samples = 5)
    base_transmission = mode(particles.base_transmission_probability)
    new_params = merge(parameters, (
        sim_length = 150,
        immunization_begin_day = 0,
        infection_introduction_day = 1,
        immunizing = false,
        base_transmission_probability = base_transmission,
    ))
    out = mean_solve(samples, new_params, DebugRecorder)

    trajectory_plot = plot_model(nothing, [nothing], [out],
                                 new_params.infection_introduction_day,
                                 new_params.immunization_begin_day)
    savefig(trajectory_plot, "$fname.pdf")

    hist = StatsBase.fit(Histogram, particles.base_transmission_probability; nbins = 25)
    posterior_plot = plot(hist; legend = false)
    savefig(posterior_plot, "$(fname)_posterior.pdf")
    return out
end
223
using PairPlots
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
    plot_fitting_posteriors(fname, particles_tuple, parameters)

Simulate the model at the posterior mode of every parameter in `particles_tuple` and
plot both the trajectory and the marginal posteriors.

- The run lasts 500 days, with immunization from day 60 and infection from day 180.
- `mean_solve` is run with 5 samples.
- The trajectory plot is saved to `"\$fname.pdf"`.
- A grid with one 30-bin histogram per parameter is saved to `"\$(fname)_posteriors.pdf"`.

Returns the `mean_solve` output.
"""
function plot_fitting_posteriors(fname, particles_tuple, parameters)
    p_tuple_adjust = merge(parameters, (
        sim_length = 500,
        # NOTE(review): the fitting routines use I_0_fraction = 0.005 — confirm 0.002
        # is intended for this posterior-predictive run.
        I_0_fraction = 0.002,
        immunization_begin_day = 60,
        infection_introduction_day = 180,
        immunizing = true,
    ))
    sim_params = merge(p_tuple_adjust, map(mode, particles_tuple))
    out = mean_solve(5, sim_params, DebugRecorder)

    # Mark the event days actually used in this run, not the caller's defaults.
    trajectory_plot = plot_model(nothing, [nothing], [out],
                                 sim_params.infection_introduction_day,
                                 sim_params.immunization_begin_day)
    savefig(trajectory_plot, "$fname.pdf")

    histogram_plots = [plot() for _ in 1:length(particles_tuple)]
    for (plt, (k, v)) in zip(histogram_plots, pairs(particles_tuple))
        hist = StatsBase.fit(Histogram, v; nbins = 30)
        plot!(plt, hist; legend = false, xlabel = k)
    end
    posterior_grid = plot(histogram_plots...; size = (1400, 800), bottommargin = 5Plots.mm)
    savefig(posterior_grid, "$(fname)_posteriors.pdf")
    return out
end
247

248
# function visualize_π_base(particles_tuple)
249

250
251
252
#     param_keys = [:π_base_y,:π_base_m,:π_base_o]
#     # π_bases_array = map(f -> getproperty(particles_tuple,f),param_keys) |> l -> mapreduce(t-> [t...],hcat,zip(l...))
#     # display(π_bases_array)
253

254
255
256
257
258
#     # p = cornerplot(π_bases_array'; labels = string.(param_keys))
#     params = NamedTuple{(param_keys...,)}(particles_tuple)
#     # display(params)
#     p = corner(params)
#     display(p)
259
260


261
262
#     # display(scatter(map(f -> getproperty(particles_tuple,f),param_keys)...))
# end