# parameter_optimization.jl
# Last committed by Peter Jentsch.
using KissABC

# Calibration targets shared by the behavioural-parameter fits in this file.

# Target cumulative vaccinated fraction; the cost functions multiply it by `model.nodes`.
const target_cumulative_vac_proportion = 0.33
# Share of vaccinations per month, one entry per month starting in August
# (Aug, Sep, Oct, Nov, Dec, Jan, Feb).
const vaccination_data = [0.0, 0.043, 0.385, 0.424, 0.115, 0.03, 0.005]
# Per-group vaccination weights, presumably young / middle / old to match the
# π_base_y / π_base_m / π_base_o parameters — confirm.
const ymo_vac = [0.255, 0.278, 0.602]
# NOTE(review): untyped (`Vector{Any}`) and not used in this file; confirm whether it is still needed.
const ymo_attack_rate = []

"""
    solve_w_parameters(default_p, p_names, new_p_list)

Run one ABM simulation with the parameters named in `p_names` overridden by the
corresponding entries of `new_p_list`. All other parameters come from `default_p`.

Returns `(out, model)`: the `DebugRecorder` passed to `abm`, and the value `abm` returns.
"""
function solve_w_parameters(default_p, p_names, new_p_list)
    overrides = NamedTuple{p_names}(ntuple(i -> new_p_list[i], length(p_names)))
    new_params = merge(default_p, overrides)
    # Size the recorder from the parameters actually simulated, so an override of
    # `sim_length` cannot make the recorder and the simulation disagree.
    out = DebugRecorder(0, new_params.sim_length)
    model = abm(new_params, out)
    return out, model
end

"""
    fit_pre_inf_behavioural_parameters(p_tuple)

Fit the baseline behaviour parameters `π_base_y`, `π_base_m` and `π_base_o` with
ABC-SMC (`smc`). Simulated per-group immunization totals are matched to the
pre-infection share (months 1:4 of `vaccination_data`) of the vaccination targets.
Infection is effectively disabled: `I_0_fraction = 0` and introduction is on the
last simulated day.

Returns a `NamedTuple` mapping each parameter name to its posterior particles.
"""
function fit_pre_inf_behavioural_parameters(p_tuple)
    p_names = (:π_base_y, :π_base_m, :π_base_o)
    priors = Factored(
        Uniform(-10.0, 2.0),
        Uniform(-10.0, 2.0),
        Uniform(-10.0, 2.0)
    )

    # Simulation begins in July. 60 days let opinion dynamics stabilize, then
    # immunization begins in September. Infection is not considered.
    sim_length = 180
    p_tuple_adjust = merge(p_tuple, (
        sim_length = sim_length,
        I_0_fraction = 0.000,
        immunization_begin_day = 60,
        infection_introduction_day = sim_length,
        immunizing = true,
    ))

    # NOTE(review): vaccination_data is indexed by month from August, so 1:4 is Aug–Nov,
    # while immunization is simulated on days 60–180; confirm this alignment is intended.
    pre_inf_vac_share = sum(vaccination_data[1:4])

    function cost(p)
        output, model = solve_w_parameters(p_tuple_adjust, p_names, p)
        target_cumulative_vaccinations = target_cumulative_vac_proportion * model.nodes
        target_ymo_vac = ymo_vac .* pre_inf_vac_share .* target_cumulative_vaccinations
        # Row sums of the immunization time series: presumably one row per age group,
        # matching the three entries of `ymo_vac`.
        total_preinfection_vaccination = sum.(eachrow(output.new_ymo_immunization))
        return sum(abs2, total_preinfection_vaccination .- target_ymo_vac)
    end

    out = smc(priors, cost; verbose = true, nparticles = 800, parallel = true)
    return NamedTuple{p_names}(ntuple(i -> out.P[i].particles, length(p_names)))
end

"""
    fit_post_inf_behavioural_parameters(p_tuple)

Fit `ω` and `base_transmission_probability` with ABC-SMC (`smc`). Per-group
immunizations from the infection-introduction day onward are matched to the
post-infection share (months 5:end of `vaccination_data`) of the vaccination targets.

Returns a `NamedTuple` mapping each parameter name to its posterior particles.
"""
function fit_post_inf_behavioural_parameters(p_tuple)
    p_names = (:ω, :base_transmission_probability)
    priors = Factored(Uniform(0.0, 0.1), Uniform(0.0, 0.1))

    # Simulation begins in July. 60 days let opinion dynamics stabilize, then
    # immunization begins in September. Infection is introduced at the beginning
    # of December.
    p_tuple_adjust = merge(p_tuple, (
        sim_length = 300,
        I_0_fraction = 0.002,
        immunization_begin_day = 60,
        infection_introduction_day = 180,
        immunizing = true,
    ))
    infection_day = p_tuple_adjust.infection_introduction_day
    post_inf_vac_share = sum(vaccination_data[5:end])

    function cost(p)
        output, model = solve_w_parameters(p_tuple_adjust, p_names, p)
        target_cumulative_vaccinations = target_cumulative_vac_proportion * model.nodes
        target_ymo_vac = ymo_vac .* post_inf_vac_share .* target_cumulative_vaccinations
        # Count only immunizations from the infection-introduction day onward.
        post_inf_ts = @view output.new_ymo_immunization[:, infection_day:end]
        total_postinf_vaccination = sum.(eachrow(post_inf_ts))
        return sum(abs2, total_postinf_vaccination .- target_ymo_vac)
    end

    # Alternative: ABCDE(priors,cost,1e6; verbose=true, nparticles=300,generations=1000, parallel = true),
    # which has better NaN handling.
    out = smc(priors, cost; verbose = true, nparticles = 200, parallel = true)
    # One particle set per fitted parameter. The previous `(out.P.particles,)` built a
    # 1-tuple for two names, so the NamedTuple construction always failed.
    return NamedTuple{p_names}(ntuple(i -> out.P[i].particles, length(p_names)))
end

"""
    plot_behavioural_fit(particles, p_tuple)

Simulate with `π_base_y`, `π_base_m` and `π_base_o` set to the modes of
`particles.P` (raw `smc` output, indexed per parameter). Display those modes and the
per-group immunization totals, then save the trajectory plot to `behaviour_fit.pdf`.

Returns the `mean_solve` output.
"""
function plot_behavioural_fit(particles, p_tuple)
    p_names = (:π_base_y, :π_base_m, :π_base_o)
    samples = 1
    p_tuple_adjust = merge(p_tuple, (
        sim_length = 210,
        I_0_fraction = 0.000,
        immunization_begin_day = 60,
        immunizing = true,
    ))

    π_modes = map(e -> mode(e.particles), particles.P)
    display(π_modes)
    π_overrides = NamedTuple{p_names}(ntuple(i -> π_modes[i], length(p_names)))
    new_params = merge(p_tuple_adjust, π_overrides)
    out = mean_solve(samples, new_params, DebugRecorder)

    # `mean.` collapses each recorded entry across samples (assumes `mean_solve`
    # stores per-sample values — confirm). Row sums then give per-group totals.
    ymo_vaccination_ts = mean.(out.new_ymo_immunization)
    total_preinfection_vaccination = sum.(eachrow(ymo_vaccination_ts))
    display(total_preinfection_vaccination)

    fig = plot_model(nothing, [nothing], [out],
        new_params.infection_introduction_day, new_params.immunization_begin_day)
    savefig(fig, "behaviour_fit.pdf")
    return out
end

# outbreak_transmission_dist = CovidAlertVaccinationModel.fit_epi_parameters(default_parameters,0.241) ##outbreak
# serialize(joinpath(PACKAGE_FOLDER,"abm_parameter_fits","outbreak_inf_parameters.dat"),outbreak_transmission_dist)
# plot_max_posterior("outbreak", outbreak_transmission_dist,default_parameters)

"""
    fit_parameters(default_parameters)

Load the serialized posterior samples from `PACKAGE_FOLDER/abm_parameter_fits`:

- transmission (`outbreak_inf_parameters.dat`)
- pre-infection behaviour (`pre_inf_behaviour_parameters.dat`)
- post-infection behaviour (`post_inf_behaviour_parameters.dat`)

Merge them into one `NamedTuple`, display the mode of each parameter, and simulate and
plot the result with `plot_fitting_posteriors("post_inf_fitting", ...)`.

The `.dat` files are produced manually, each followed by `serialize(path, result)`:

- `fit_epi_parameters(default_parameters, 0.241)` for the outbreak fit (0.073 for seasonal)
- `fit_pre_inf_behavioural_parameters(default_parameters)`
- `fit_post_inf_behavioural_parameters(merge(default_parameters, modes))`

`visualize_π_base(deserialize(pre_inf_path))` gives a corner plot of the pre-infection fit.

Returns the output of `plot_fitting_posteriors`.
"""
function fit_parameters(default_parameters)
    fits_dir = joinpath(PACKAGE_FOLDER, "abm_parameter_fits")
    # NOTE(review): named "seasonal" but reads the outbreak fit file.
    seasonal_path = joinpath(fits_dir, "outbreak_inf_parameters.dat")
    pre_inf_path = joinpath(fits_dir, "pre_inf_behaviour_parameters.dat")
    post_inf_path = joinpath(fits_dir, "post_inf_behaviour_parameters.dat")

    seasonal_fit = deserialize(seasonal_path)
    pre_inf_fit = deserialize(pre_inf_path)
    post_inf_fit = deserialize(post_inf_path)
    all_fits = (; seasonal_fit..., pre_inf_fit..., post_inf_fit...)
    display(map(mode, all_fits))

    return plot_fitting_posteriors("post_inf_fitting", all_fits, default_parameters)
end

"""
    plot_max_posterior(fname, particles, parameters)

Simulate 150 days without immunization, infection introduced on day 1, using the
mode of `particles.base_transmission_probability`. Save the trajectory plot to
`<fname>.pdf` and a 25-bin histogram of the posterior to `<fname>_posterior.pdf`.

Returns the result of the final `savefig` call.
"""
function plot_max_posterior(fname, particles, parameters)
    β_mode = mode(particles.base_transmission_probability)
    run_params = merge(parameters, (
        sim_length = 150,
        immunization_begin_day = 0,
        infection_introduction_day = 1,
        immunizing = false,
        base_transmission_probability = β_mode,
    ))
    sol = mean_solve(5, run_params, DebugRecorder)
    trajectory_fig = plot_model(nothing, [nothing], [sol],
        run_params.infection_introduction_day, run_params.immunization_begin_day)
    savefig(trajectory_fig, "$fname.pdf")

    posterior_hist = StatsBase.fit(Histogram, particles.base_transmission_probability; nbins = 25)
    hist_fig = plot(posterior_hist; legend = false)
    return savefig(hist_fig, "$(fname)_posterior.pdf")
end
using PairPlots
"""
    plot_fitting_posteriors(fname, particles_tuple, parameters)

Simulate 500 days with every parameter in `particles_tuple` set to its posterior mode
(immunization from day 60, infection introduced on day 180, `I_0_fraction = 0.002`).
Save the trajectory plot to `<fname>.pdf` and one 30-bin histogram per fitted
parameter to `<fname>_posteriors.pdf`.

Returns the `mean_solve` output.
"""
function plot_fitting_posteriors(fname, particles_tuple, parameters)
    p_tuple_adjust = merge(parameters, (
        sim_length = 500,
        I_0_fraction = 0.002,
        immunization_begin_day = 60,
        infection_introduction_day = 180,
        immunizing = true,
    ))
    fitted_params = merge(p_tuple_adjust, map(mode, particles_tuple))
    out = mean_solve(5, fitted_params, DebugRecorder)
    # Mark the days actually used in this simulation, not the caller's defaults
    # (previously `parameters.*`, which disagree with the 60/180 set above).
    p = plot_model(nothing, [nothing], [out],
        fitted_params.infection_introduction_day, fitted_params.immunization_begin_day)
    savefig(p, "$fname.pdf")

    plts = [plot() for _ in 1:length(particles_tuple)]
    for (plt, (k, v)) in zip(plts, pairs(particles_tuple))
        hist = StatsBase.fit(Histogram, v; nbins = 30)
        plot!(plt, hist; legend = false, xlabel = k)
    end
    p = plot(plts...; size = (1400, 800), bottommargin = 5Plots.mm)
    savefig(p, "$(fname)_posteriors.pdf")
    return out
end

#error function for inf
#
"""
    visualize_π_base(particles_tuple)

Select the `π_base_y`, `π_base_m` and `π_base_o` entries of `particles_tuple`, draw
a PairPlots `corner` plot of them, and display it.

Returns the result of `display`.
"""
function visualize_π_base(particles_tuple)
    π_fields = (:π_base_y, :π_base_m, :π_base_o)
    π_samples = NamedTuple{π_fields}(particles_tuple)
    figure = corner(π_samples)
    return display(figure)
end