Skip to content
Snippets Groups Projects
Commit 2bda668f authored by Peter Jentsch's avatar Peter Jentsch
Browse files

fix heatmap output

parent 9774ac9f
No related branches found
No related tags found
No related merge requests found
......@@ -13,9 +13,9 @@ function solve_and_plot_parameters()
ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
total_postinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,1:180]))
total_postinf_vaccination = mean.(out.total_postinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = mean.(out.final_size_by_age)#sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = mean.(out.total_preinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,1:180]))
target_final_size = ymo_attack_rate .*avg_populations
target_preinf_vac = ymo_vac .* sum(vaccination_data[1:4]) .* avg_populations
target_postinf_vac = ymo_vac .* sum(vaccination_data[5:end]) .*avg_populations
......
......@@ -15,10 +15,14 @@ abstract type AbstractRecorder{ElType} end
"""
DebugRecorder should store everything we might want to know about the model output.
"""
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
recorded_status_totals::ArrT1
daily_cases_by_age::ArrT3
total_vaccinators::ArrT2
unvac_final_size_by_age::ArrT2
total_postinf_vaccination::ArrT2
total_preinf_vaccination::ArrT2
final_size_by_age::ArrT2
mean_time_since_last_notification::ArrT2
daily_immunized_by_age::ArrT3
daily_unvac_cases_by_age::ArrT3
......@@ -37,7 +41,8 @@ end
"""
Initialize a DebugRecorder filled with (copies) of val.
"""
function DebugRecorder(val::T, sim_length) where T
function DebugRecorder(val::T,sim_length) where T
totals = [copy(val) for i in 1:4, j in 1:sim_length]
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
......@@ -46,21 +51,29 @@ function DebugRecorder(val::T, sim_length) where T
daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
unvac_final_size_by_age = [copy(val) for i in 1:3]
total_postinf_vaccination = [copy(val) for i in 1:3]
total_preinf_vaccination = [copy(val) for i in 1:3]
final_size_by_age = [copy(val) for i in 1:3]
return DebugRecorder(
state_totals,
daily_cases_by_age,
total_vaccinators,
unvac_final_size_by_age,
total_postinf_vaccination,
total_preinf_vaccination,
final_size_by_age,
mean_time_since_last_notification,
daily_immunized_by_age,
daily_unvac_by_age
)
end
function HeatmapRecorder(val::T,sim_length) where T<:OnlineStat
function HeatmapRecorder(sim_length) where T<:OnlineStat
daily_cases_by_age = @LArray zeros(Int,3,sim_length) (Y = (1,:),M = (2,:),O = (3,:))
final_size_by_age= [copy(val) for i in 1:3]
average_deg_of_vac= copy(val)
average_deg= copy(val)
final_size_by_age = [Variance() for i in 1:3]
average_deg_of_vac = Variance()
average_deg = Variance()
return HeatmapRecorder(
daily_cases_by_age,
final_size_by_age,
......@@ -85,6 +98,14 @@ function record!(t,modelsol, recorder::DebugRecorder)
# display(modelsol.daily_immunizations_by_age)
recorder.daily_immunized_by_age[:,t] .= modelsol.daily_immunized_by_age
recorder.daily_unvac_cases_by_age[:,t] .= modelsol.daily_unvac_cases_by_age
if modelsol.sim_length == t
for age in 1:3
recorder.unvac_final_size_by_age[age] = sum(recorder.daily_unvac_cases_by_age[age,:])
recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age,:])
recorder.total_preinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,1:modelsol.params.infection_introduction_day])
recorder.total_postinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,modelsol.params.infection_introduction_day:end])
end
end
end
function record!(t,modelsol, recorder::HeatmapRecorder)
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age
......@@ -116,11 +137,13 @@ function mean_solve(samples,parameter_tuple,recorder)
end
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder})
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
output_recorder = recorder(Variance(),parameter_tuple.sim_length)
stat_recorder = recorder(parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
for _ in 1:samples
output_recorder = recorder(parameter_tuple.sim_length)
sol = abm(parameter_tuple,output_recorder)
avg_populations .+= length.(sol.index_vectors)
fit!(stat_recorder,output_recorder)
end
......
......@@ -5,16 +5,17 @@ const samples = 10
using Random
@testset "default parameters output" begin
Random.seed!(1)
p = CovidAlertVaccinationModel.get_parameters()
display(p)
out,avg_populations = mean_solve(samples, p,DebugRecorder)
plot_model(nothing,[nothing],[out],p.infection_introduction_day,p.immunization_begin_day)
# plot_model(nothing,[nothing],[out],p.infection_introduction_day,p.immunization_begin_day)
ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
# ymo_vaccination_ts = mean.(out.daily_immunized_by_age)
total_postinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = sum.(eachrow(ymo_vaccination_ts[:,1:180]))
total_postinf_vaccination = mean.(out.total_postinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,180:end]))
final_size = mean.(out.unvac_final_size_by_age)#sum.(eachrow(mean.(out.daily_unvac_cases_by_age)))
total_preinf_vaccination = mean.(out.total_preinf_vaccination)#sum.(eachrow(ymo_vaccination_ts[:,1:180]))
target_final_size = ymo_attack_rate .*avg_populations
target_preinf_vac = ymo_vac .* sum(vaccination_data[1:4]) .* avg_populations
target_postinf_vac = ymo_vac .* sum(vaccination_data[5:end]) .*avg_populations
......@@ -26,21 +27,23 @@ using Random
@test all(abs.(final_size .- target_final_size) .< [75,100,25])
@test all(abs.(total_preinf_vaccination .- target_preinf_vac) .< [50,100,25])
@test all(abs.(total_postinf_vaccination .- target_postinf_vac) .< [120,120,40])
@test all(abs.(total_postinf_vaccination .- target_postinf_vac) .< [135,120,40])
Random.seed!(1)
p = CovidAlertVaccinationModel.get_parameters()
out_heatmap,avg_populations_heatmap = mean_solve(samples, p,HeatmapRecorder)
final_size_heatmap = mean.(out_heatmap.final_size_by_age)
display(final_size_heatmap)
@test all(abs.(final_size_heatmap .- final_size).< [75,100,25])
@test all(abs.(avg_populations_heatmap .- avg_populations).< [75,100,25])
# display(final_size_heatmap)
# display(mean.(out.final_size_by_age))
@test all(abs.(final_size_heatmap .- (mean.(out.final_size_by_age)).< [75,100,25]))
@test all(abs.(avg_populations_heatmap .- avg_populations).< 250)
end
@testset "perf" begin
b = @belapsed CovidAlertVaccinationModel.bench()
@test 4.0 < b < 4.5
@test 4.0 < b < 4.9
if Threads.nthreads() == 25
b = @belapsed CovidAlertVaccinationModel.threaded_bench()
@test 12.0 < b < 16.00
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment