Newer
Older

Peter Jentsch
committed
#needlessly overwrought output interface

Peter Jentsch
committed
using LabelledArrays
using Plots
import OnlineStats.fit!

Peter Jentsch
committed

Peter Jentsch
committed
"""
    AbstractRecorder{ElType}

Supertype that every "Recorder" type must extend so that `fit!` (and related methods)
dispatch correctly.

Different recorder types exist so that each experiment saves only the output it needs.
A recorder is parameterized by its element type `ElType`. When that element type is an
`OnlineStat`, the recorder can act as an accumulator: `fit!(accum, sample)` collects the
contents of another recorder with the same fields into it. This is how statistics over
many runs are gathered into a single recorder.
"""
abstract type AbstractRecorder{ElType} end
"""
    DebugRecorder{ElType,ArrT1,ArrT2,ArrT3}

Recorder meant to store everything we might want to know about a model run.

- `ArrT2` fields are vectors with one entry per age group. `record!` fills them on the
  last day of the run.
- `ArrT3` fields are (age group × day) arrays. `record!` writes column `t` on day `t`.
- Elements have type `ElType`. When built by `mean_solve`, which calls
  `recorder(Variance(), ...)`, the elements are presumably `Variance` statistics.

NOTE(review): this definition looks incomplete in this view:
- `ArrT1` is not used by any field shown.
- `record!` writes `recorded_status_totals`, `total_vaccinators`,
  `mean_time_since_last_notification` and `daily_unvac_cases_by_age`, none of which is
  declared here.
- The constructor passes `state_totals` as its first positional argument.
Confirm the field list and order against the original source.
"""
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
# (age group × day) new cases; column t written by `record!` on day t
daily_cases_by_age::ArrT3
# per-age totals, filled by `record!` on the final day
unvac_final_size_by_age::ArrT2
total_postinf_vaccination::ArrT2
total_preinf_vaccination::ArrT2
final_size_by_age::ArrT2
# (age group × day) immunizations; column t written by `record!` on day t
daily_immunized_by_age::ArrT3

# NOTE(review): the two lines below are pasted git-blame text (author / "committed"),
# not Julia; they make the file unparseable. A field declaration may have been lost here.
Peter Jentsch
committed
end

Peter Jentsch
committed
"""
    HeatmapRecorder{ElType,ArrT1,ArrT2}

Lightweight recorder that keeps only the output the heatmap experiments need.

# Fields
- `daily_cases_by_age::ArrT2`: integer case counts, one row per age group and one column
  per day. `record!` overwrites column `t` on day `t`.
- `final_size_by_age::ArrT1`: one `ElType` per age group. On the final day, `record!`
  `fit!`s that run's per-age case total into each entry.
- `avg_weighted_degree_of_vaccinators::ElType`: on the final day, `record!` `merge!`s the
  model's statistic of the same name into it.
- `avg_weighted_degree::ElType`: on the final day, `record!` `merge!`s the model's
  statistic of the same name into it.
"""
struct HeatmapRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{Int}} <: AbstractRecorder{ElType}
    daily_cases_by_age::ArrT2
    final_size_by_age::ArrT1
    avg_weighted_degree_of_vaccinators::ElType
    avg_weighted_degree::ElType
end
"""
    DebugRecorder(val, sim_length)

Build a `DebugRecorder` in which every stored slot starts as its own copy of `val`.
`mean_solve` calls `recorder(Variance(), parameter_tuple.sim_length)`.

NOTE(review): the `function` line of this constructor is missing from this view. The
signature above is inferred from two things: the free names `val` and `sim_length` used
in the body, and `mean_solve`'s call. Confirm.
"""

# NOTE(review): stray git-blame text below; the constructor's header line likely stood here.
Peter Jentsch
committed
# Every slot gets its own `copy(val)` so entries do not alias one another. This matters
# when `val` is a mutable OnlineStat that is later fitted entry by entry.
totals = [copy(val) for i in 1:4, j in 1:sim_length]
# 4 × sim_length matrix with rows addressable by the labels S, I, R, V.
state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
total_vaccinators = [copy(val) for j in 1:sim_length]
mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
# Template for the (age group × day) arrays: 3 rows (labelled Y/M/O below), one column per day.
totals_ymo = [copy(val) for i in 1:3, j in 1:sim_length]
# deepcopy so the three per-age arrays do not share element objects with each other.
daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
# One entry per age group.
unvac_final_size_by_age = [copy(val) for i in 1:3]
total_postinf_vaccination = [copy(val) for i in 1:3]
total_preinf_vaccination = [copy(val) for i in 1:3]
final_size_by_age = [copy(val) for i in 1:3]

Peter Jentsch
committed
# NOTE(review): these positional arguments do not line up with the fields declared on
# `DebugRecorder` in this view; for example, `state_totals` is passed first.
# `mean_time_since_last_notification`, `daily_immunized_by_age` and `daily_unvac_by_age`
# are built above but not passed in the visible lines. Argument lines appear to be missing
# where the git-blame text sits. Confirm against the original.
return DebugRecorder(
state_totals,
daily_cases_by_age,

Peter Jentsch
committed
total_vaccinators,
unvac_final_size_by_age,
total_postinf_vaccination,
total_preinf_vaccination,
final_size_by_age,

Peter Jentsch
committed
)
end
"""
    HeatmapRecorder(sim_length)

Build an empty `HeatmapRecorder` for a run of `sim_length` days.

`daily_cases_by_age` starts as a 3 × `sim_length` matrix of integer zeros with rows
labelled `Y`, `M` and `O`. The per-age final sizes and the two weighted-degree summaries
start as fresh `Variance()` accumulators.
"""
function HeatmapRecorder(sim_length)
    # No `where T<:OnlineStat` here: `T` never appeared in the signature, which only
    # produced an "unused type variable" warning.
    daily_cases_by_age = @LArray zeros(Int, 3, sim_length) (Y = (1,:), M = (2,:), O = (3,:))
    final_size_by_age = [Variance() for _ in 1:3]
    average_deg_of_vac = Variance()
    average_deg = Variance()
    return HeatmapRecorder(
        daily_cases_by_age,
        final_size_by_age,
        average_deg_of_vac,
        average_deg,
    )
end

Peter Jentsch
committed
"""
    record!(t, modelsol, recorder::DebugRecorder)

Copy day `t`'s observations from `modelsol` into slot/column `t` of `recorder`.

On the final day (`t == modelsol.sim_length`) it also computes per-age totals from the
stored daily series:
- the unvaccinated final size and the overall final size;
- vaccinations up to and from `modelsol.params.infection_introduction_day`.

NOTE(review): the git-blame text inside this body is not Julia and must be removed; a
statement may have been lost at each of those spots. `mean` needs `Statistics` in scope,
which is not imported in this view.
"""
function record!(t,modelsol, recorder::DebugRecorder)
recorder.total_vaccinators[t] = modelsol.daily_vaccinators
recorder.daily_cases_by_age[:,t] .= modelsol.daily_cases_by_age

Peter Jentsch
committed
recorder.recorded_status_totals[:,t] .= modelsol.status_totals

Peter Jentsch
committed
# Mean of (t - alert time) over the entries of `time_of_last_alert` that are > 0,
# rounded to Int; 0 when there are none. Presumably 0 means "never alerted" — confirm.
alerts = filter(>(0),modelsol.time_of_last_alert)
if !isempty(alerts)
mean_alert_time = mean(t .- alerts)
recorder.mean_time_since_last_notification[t] = round(Int,mean_alert_time)
else
recorder.mean_time_since_last_notification[t] = 0
end
recorder.daily_immunized_by_age[:,t] .= modelsol.daily_immunized_by_age
recorder.daily_unvac_cases_by_age[:,t] .= modelsol.daily_unvac_cases_by_age
# Last day: collapse each age group's daily series (rows 1:3) into totals.
if modelsol.sim_length == t
for age in 1:3
recorder.unvac_final_size_by_age[age] = sum(recorder.daily_unvac_cases_by_age[age,:])
recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age,:])
# NOTE(review): both ranges below include `infection_introduction_day`, so vaccinations
# on that day count toward the pre- AND the post-infection totals. One side should
# probably exclude it (e.g. `1:day-1` or `day+1:end`) — confirm which is intended.
recorder.total_preinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,1:modelsol.params.infection_introduction_day])
recorder.total_postinf_vaccination[age] = sum(recorder.daily_immunized_by_age[age,modelsol.params.infection_introduction_day:end])
end
end
end
"""
    record!(t, modelsol, recorder::HeatmapRecorder)

Store day `t`'s per-age case counts from `modelsol` in column `t` of
`recorder.daily_cases_by_age`.

On the final day (`t == modelsol.sim_length`) it also:
- `fit!`s each age group's total case count for the run into `recorder.final_size_by_age`;
- `merge!`s the model's `avg_weighted_degree_of_vaccinators` and `avg_weighted_degree`
  statistics into the recorder's.

The return value is not meaningful.
"""
function record!(t, modelsol, recorder::HeatmapRecorder)
    recorder.daily_cases_by_age[:, t] .= modelsol.daily_cases_by_age
    if modelsol.sim_length == t
        # Row sums over the whole run are each age group's final size; a view avoids
        # copying the row just to sum it.
        for age in eachindex(recorder.final_size_by_age)
            fit!(recorder.final_size_by_age[age], sum(@view recorder.daily_cases_by_age[age, :]))
        end
        merge!(
            recorder.avg_weighted_degree_of_vaccinators,
            modelsol.avg_weighted_degree_of_vaccinators,
        )
        merge!(recorder.avg_weighted_degree, modelsol.avg_weighted_degree)
    end
end
# With a `nothing` recorder there is nothing to store: accept any day and solution and
# return `nothing` without doing any work.
record!(t, modelsol, recorder::Nothing) = nothing

Peter Jentsch
committed
"""
    mean_solve(samples, parameter_tuple, recorder; progmeter = nothing)

Call `abm` `samples` times and fold each run's output recorder into a single accumulator.
The accumulator is built as `recorder(Variance(), parameter_tuple.sim_length)`.

Returns `(stat_recorder, avg_populations)`, with `avg_populations` divided by `samples`.

NOTE(review): in this view:
- `output_recorder` is never assigned.
- `avg_populations` is never accumulated, so it stays `[0.0, 0.0, 0.0]`.
- `progmeter` is unused.
Lines appear to be missing where the git-blame text sits. For comparison, the
`HeatmapRecorder` method does `avg_populations .+= length.(sol.index_vectors)` inside its
loop.
"""
function mean_solve(samples,parameter_tuple,recorder;progmeter = nothing)

Peter Jentsch
committed
# Accumulator whose entries are `Variance` statistics.
stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
avg_populations = [0.0,0.0,0.0]
for _ in 1:samples
sol = abm(parameter_tuple,output_recorder)

Peter Jentsch
committed
# Fold this run's recorded output into the accumulator.
fit!(stat_recorder,output_recorder)
end
avg_populations ./= samples
return stat_recorder,avg_populations

Peter Jentsch
committed
end
"""
    mean_solve(samples, parameter_tuple, recorder::Type{HeatmapRecorder}; progmeter = nothing)

`HeatmapRecorder` variant of `mean_solve`. Each run's recorder is built with
`recorder(parameter_tuple.sim_length)`, which already holds `Variance` accumulators, and
is then folded into `stat_recorder`.

NOTE(review): this method is incomplete in this view:
- `stat_recorder`, `avg_populations` and `sol` are never assigned.
- There is no loop over `samples`.
- The `end` after `fit!` closes the function, leaving the last three lines (including a
  stray `end`) at top level, which does not parse.
Presumably the body mirrors the generic method: build `stat_recorder`, initialise
`avg_populations`, then loop `samples` times calling `abm`. Confirm against the original.

`record!` `fit!`s/`merge!`s directly into this recorder's `Variance`s, so `output_recorder`
presumably has to be rebuilt on every iteration. If `fit!` combines field-by-field,
reusing one recorder would re-merge earlier runs into `stat_recorder`.
"""
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder};progmeter = nothing)
output_recorder = recorder(parameter_tuple.sim_length)
avg_populations .+= length.(sol.index_vectors)
fit!(stat_recorder,output_recorder)
end
avg_populations ./= samples
return stat_recorder,avg_populations
end
"""
    fit!(accum::AbstractRecorder, pt::AbstractRecorder)

Intended to fold the sample recorder `pt` into the accumulator `accum`, one field at a
time. Field names are taken from `accum`'s type `R`, so `pt` must have the same fields.

NOTE(review): as shown, the loop reads `sample_field` but never uses it, so this method
has no effect. A line was presumably lost with the git-blame text, most likely
`combine!(getfield(accum, field), sample_field)`, which is what the `combine!` methods
below are for. Confirm.

OnlineStats' own `fit!` returns the updated object; consider `return accum`.
"""
function OnlineStats.fit!(accum::R,pt::R2) where {R<:AbstractRecorder,R2<:AbstractRecorder}

Peter Jentsch
committed
for field in fieldnames(R)

Peter Jentsch
committed
sample_field = getfield(pt,field)

Peter Jentsch
committed
end
end

Peter Jentsch
committed
"""
    combine!(accum::AbstractArray, x::AbstractArray)

Element-wise `combine!`: fold each entry of `x` into the matching entry of `accum`.

Entries are paired in iteration order, so both arrays must have the same size.
Returns `nothing`.

# Throws
- `DimensionMismatch` if `size(accum) != size(x)`.
"""
function combine!(accum::AbstractArray, x::AbstractArray)
    # `zip` stops at the shorter input, which would silently drop samples on a mismatch.
    size(accum) == size(x) || throw(DimensionMismatch(
        "cannot combine arrays of size $(size(accum)) and $(size(x))"))
    for (stat_entry, sample_entry) in zip(accum, x)
        combine!(stat_entry, sample_entry)
    end
    return nothing
end
# Per-entry rules used when folding one recorder into another:
# - a raw observation of the statistic's input type is fitted into the statistic;
# - two statistics with the same type parameter are merged;
# - two plain numbers are left alone: an immutable `Real` cannot be updated in place.
function combine!(accum::OnlineStat{<:T}, x::T) where {T}
    return fit!(accum, x)
end

function combine!(accum::OnlineStat{T}, x::OnlineStat{T}) where {T}
    return merge!(accum, x)
end

function combine!(accum::Real, x::Real)
    return nothing
end