Commit 2bda668f authored by Peter Jentsch's avatar Peter Jentsch
Browse files

fix heatmap output

parent 9774ac9f
......@@ -13,9 +13,9 @@ function solve_and_plot_parameters()
ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
total_postinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,1:180]))
total_postinf_vaccination = mean.(out.total_postinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = mean.(out.final_size_by_age)#sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = mean.(out.total_preinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,1:180]))
target_final_size = ymo_attack_rate .*avg_populations
target_preinf_vac = ymo_vac .* sum(vaccination_data[1:4]) .* avg_populations
target_postinf_vac = ymo_vac .* sum(vaccination_data[5:end]) .*avg_populations
......
......@@ -15,10 +15,14 @@ abstract type AbstractRecorder{ElType} end
"""
DebugRecorder should store everything we might want to know about the model output.
"""
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
recorded_status_totals::ArrT1
daily_cases_by_age::ArrT3
total_vaccinators::ArrT2
unvac_final_size_by_age::ArrT2
total_postinf_vaccination::ArrT2
total_preinf_vaccination::ArrT2
final_size_by_age::ArrT2
mean_time_since_last_notification::ArrT2
daily_immunized_by_age::ArrT3
daily_unvac_cases_by_age::ArrT3
......@@ -37,7 +41,8 @@ end
"""
Initialize a DebugRecorder filled with (copies) of val.
"""
function DebugRecorder(val::T, sim_length) where T
function DebugRecorder(val::T,sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
......@@ -46,21 +51,29 @@ function DebugRecorder(val::T, sim_length) where T
daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
unvac_final_size_by_age = [copy(val) for i in 1:3]
total_postinf_vaccination = [copy(val) for i in 1:3]
total_preinf_vaccination = [copy(val) for i in 1:3]
final_size_by_age = [copy(val) for i in 1:3]
return DebugRecorder(
state_totals,
daily_cases_by_age,
total_vaccinators,
unvac_final_size_by_age,
total_postinf_vaccination,
total_preinf_vaccination,
final_size_by_age,
mean_time_since_last_notification,
daily_immunized_by_age,
daily_unvac_by_age
)
end
function HeatmapRecorder(val::T,sim_length) where T<:OnlineStat
function HeatmapRecorder(sim_length) where T<:OnlineStat
daily_cases_by_age = @LArray zeros(Int,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
final_size_by_age= [copy(val) for i in 1:3]
average_deg_of_vac= copy(val)
average_deg= copy(val)
final_size_by_age = [Variance() for i in 1:3]
average_deg_of_vac = Variance()
average_deg = Variance()
return HeatmapRecorder(
daily_cases_by_age,
final_size_by_age,
......@@ -85,6 +98,14 @@ function record!(t,modelsol, recorder::DebugRecorder)
# display(modelsol.daily_immunizations_by_age)
recorder.daily_immunized_by_age[:,t] .= modelsol.daily_immunized_by_age
recorder.daily_unvac_cases_by_age[:,t] .= modelsol.daily_unvac_cases_by_age
if modelsol.sim_length == t
for age in 1:3
recorder.unvac_final_size_by_age[age] = sum(recorder.daily_unvac_cases_by_age[age,:])
recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age,:])
recorder.total_preinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,1:modelsol.params.infection_introduction_day])
recorder.total_postinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,modelsol.params.infection_introduction_day:end])
end
end
end
function record!(t,modelsol, recorder::HeatmapRecorder)
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
......@@ -116,11 +137,13 @@ function mean_solve(samples,parameter_tuple,recorder)
end
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder})
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
output_recorder = recorder(Variance(),parameter_tuple.sim_length)
stat_recorder = recorder(parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
for _ in 1:samples
output_recorder = recorder(parameter_tuple.sim_length)
sol = abm(parameter_tuple,output_recorder)
avg_populations .+= length.(sol.index_vectors)
fit!(stat_recorder,output_recorder)
end
......
......@@ -5,16 +5,17 @@ const samples = 10
using Random
@testset "default parameters output" begin
Random.seed!(1)
p = CovidAlertVaccinationModel.get_parameters()
display(p)
out,avg_populations = mean_solve(samples, p,DebugRecorder)
plot_model(nothing,[nothing],[out],p.infection_introduction_day,p.immunization_begin_day)
# plot_model(nothing,[nothing],[out],p.infection_introduction_day,p.immunization_begin_day)
ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
# ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
total_postinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,1:180]))
total_postinf_vaccination = mean.(out.total_postinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = mean.(out.unvac_final_size_by_age)#sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = mean.(out.total_preinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,1:180]))
target_final_size = ymo_attack_rate .*avg_populations
target_preinf_vac = ymo_vac .* sum(vaccination_data[1:4]) .* avg_populations
target_postinf_vac = ymo_vac .* sum(vaccination_data[5:end]) .*avg_populations
......@@ -26,21 +27,23 @@ using Random
@test all(abs.(final_size .- target_final_size) .< [75,100,25])
@test all(abs.(total_preinf_vaccination .- target_preinf_vac) .< [50,100,25])
@test all(abs.(total_postinf_vaccination .- target_postinf_vac) .< [120,120,40])
@test all(abs.(total_postinf_vaccination .- target_postinf_vac) .< [135,120,40])
Random.seed!(1)
p = CovidAlertVaccinationModel.get_parameters()
out_heatmap,avg_populations_heatmap = mean_solve(samples, p,HeatmapRecorder)
final_size_heatmap = mean.(out_heatmap.final_size_by_age)
display(final_size_heatmap)
@test all(abs.(final_size_heatmap .- final_size).< [75,100,25])
@test all(abs.(avg_populations_heatmap .- avg_populations).< [75,100,25])
# display(final_size_heatmap)
# display(mean.(out.final_size_by_age))
@test all(abs.(final_size_heatmap .- (mean.(out.final_size_by_age)).< [75,100,25]))
@test all(abs.(avg_populations_heatmap .- avg_populations).< 250)
end
@testset "perf" begin
b = @belapsed CovidAlertVaccinationModel.bench()
@test 4.0 < b < 4.5
@test 4.0 < b < 4.9
if Threads.nthreads() == 25
b = @belapsed CovidAlertVaccinationModel.threaded_bench()
@test 12.0 < b < 16.00
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment