Skip to content
Snippets Groups Projects
output.jl 6.6 KiB
Newer Older
#needlessly overwrought output interface
using LabelledArrays
using Plots
import OnlineStats.fit!
"""
AbstractRecorder is the type that all "Recorder" types must extend for `fit!` (and possibly other stuff) to dispatch correctly.
The point of using different recorder types is that we only want to save the output that we need, and that output will vary depending on the specific experiment.

AbstractRecorders can be parameterized to have any element type, so when they are parameterized with T<:OnlineStat, we can use `fit!` to collect statistics about a bunch of Recorders into a single Recorder.
"""
abstract type AbstractRecorder{ElType} end

"""
`DebugRecorder` stores everything we might want to know about the model output.
"""
Peter Jentsch's avatar
Peter Jentsch committed
struct DebugRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractVector{ElType},ArrT3<:AbstractArray{ElType}} <: AbstractRecorder{ElType}
Peter Jentsch's avatar
Peter Jentsch committed
    recorded_status_totals::ArrT1
Peter Jentsch's avatar
Peter Jentsch committed
    total_vaccinators::ArrT2
Peter Jentsch's avatar
Peter Jentsch committed
    unvac_final_size_by_age::ArrT2
    total_postinf_vaccination::ArrT2
    total_preinf_vaccination::ArrT2
    final_size_by_age::ArrT2
Peter Jentsch's avatar
Peter Jentsch committed
    mean_time_since_last_notification::ArrT2
    daily_immunized_by_age::ArrT3
    daily_unvac_cases_by_age::ArrT3
Peter Jentsch's avatar
Peter Jentsch committed

Peter Jentsch's avatar
Peter Jentsch committed
"""
    HeatmapRecorder{ElType, ArrT1, ArrT2} <: AbstractRecorder{ElType}

Recorder for heatmap experiments: keeps only the quantities the heatmaps need.

# Fields
- `daily_cases_by_age::ArrT2`: integer array of case counts, one row per age group and
  one column per time step (`record!` writes column `t`).
- `final_size_by_age::ArrT1`: one `ElType` entry per age group summarising the total
  number of cases in that group.
- `avg_weighted_degree_of_vaccinators::ElType`: statistic merged from the model solution
  of the same name.
- `avg_weighted_degree::ElType`: statistic merged from the model solution of the same
  name.
"""
struct HeatmapRecorder{ElType,ArrT1<:AbstractArray{ElType},ArrT2<:AbstractArray{Int}} <: AbstractRecorder{ElType}
    daily_cases_by_age::ArrT2
    final_size_by_age::ArrT1
    avg_weighted_degree_of_vaccinators::ElType
    avg_weighted_degree::ElType
end
Peter Jentsch's avatar
Peter Jentsch committed
"""
Initialize a `DebugRecorder` whose entries are all copies of `val`.
"""
Peter Jentsch's avatar
Peter Jentsch committed
function DebugRecorder(val::T,sim_length) where T
    
    totals = [copy(val) for i in 1:4, j in 1:sim_length]
    state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
    total_vaccinators = [copy(val) for j in 1:sim_length]
    mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
    totals_ymo = [copy(val) for i in 1:3, j in 1:sim_length]
    daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
    daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
    daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
Peter Jentsch's avatar
Peter Jentsch committed
    unvac_final_size_by_age =  [copy(val) for i in 1:3]
    total_postinf_vaccination =  [copy(val) for i in 1:3]
    total_preinf_vaccination =  [copy(val) for i in 1:3]
    final_size_by_age =  [copy(val) for i in 1:3]
Peter Jentsch's avatar
Peter Jentsch committed
        unvac_final_size_by_age,
        total_postinf_vaccination,
        total_preinf_vaccination,
        final_size_by_age,
Peter Jentsch's avatar
Peter Jentsch committed
        mean_time_since_last_notification,
        daily_immunized_by_age,
        daily_unvac_by_age
Peter Jentsch's avatar
Peter Jentsch committed
"""
    HeatmapRecorder(sim_length)

Create an empty `HeatmapRecorder` for a simulation of `sim_length` time steps.

- `daily_cases_by_age` is a `3 × sim_length` matrix of `Int` zeros whose rows are
  labelled `Y`, `M` and `O`.
- `final_size_by_age` holds one fresh `Variance()` per age group (three in total).
- The two degree statistics each start as their own, independent, empty `Variance()`.
"""
function HeatmapRecorder(sim_length)
    # No `where` clause: the old `where T<:OnlineStat` introduced a type variable that
    # was never used by the signature or the body.
    daily_cases_by_age = @LArray zeros(Int, 3, sim_length) (Y = (1, :), M = (2, :), O = (3, :))
    final_size_by_age = [Variance() for _ in 1:3]
    average_deg_of_vac = Variance()
    average_deg = Variance()
    return HeatmapRecorder(
        daily_cases_by_age,
        final_size_by_age,
        average_deg_of_vac,
        average_deg,
    )
end



"""
    record!(t, modelsol, recorder::DebugRecorder)

Copy the model state at time step `t` into `recorder`.

On every step it writes:
- the daily vaccinator count into `total_vaccinators[t]`;
- the case, status, immunization and unvaccinated-case vectors into column `t` of the
  matching per-day arrays;
- `mean_time_since_last_notification[t]`, the mean of `t - a` over the positive entries
  `a` of `modelsol.time_of_last_alert`, rounded to `Int` (0 if there are none).

On the final step (`t == modelsol.sim_length`) it also fills the per-age summaries by
summing each age group's daily series. Immunizations are split at
`modelsol.params.infection_introduction_day`:
- days `1:day-1` go to `total_preinf_vaccination`;
- days `day:end` go to `total_postinf_vaccination`.

The two totals therefore partition the daily immunizations. Returns `nothing`.
"""
function record!(t,modelsol, recorder::DebugRecorder)
    recorder.total_vaccinators[t] = modelsol.daily_vaccinators
    recorder.daily_cases_by_age[:, t] .= modelsol.daily_cases_by_age
    recorder.recorded_status_totals[:, t] .= modelsol.status_totals

    # Non-positive alert times are excluded (presumably "never alerted" — confirm).
    alerts = filter(>(0), modelsol.time_of_last_alert)
    recorder.mean_time_since_last_notification[t] =
        isempty(alerts) ? 0 : round(Int, mean(t .- alerts))

    recorder.daily_immunized_by_age[:, t] .= modelsol.daily_immunized_by_age
    recorder.daily_unvac_cases_by_age[:, t] .= modelsol.daily_unvac_cases_by_age

    if modelsol.sim_length == t
        intro_day = modelsol.params.infection_introduction_day
        for age in 1:3
            recorder.unvac_final_size_by_age[age] = sum(recorder.daily_unvac_cases_by_age[age, :])
            recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age, :])
            # The introduction day used to be counted in BOTH totals (1:day and day:end).
            # It now belongs only to the post-introduction window.
            # NOTE(review): confirm the intro day should count as post-introduction.
            recorder.total_preinf_vaccination[age] =
                sum(recorder.daily_immunized_by_age[age, 1:(intro_day - 1)])
            recorder.total_postinf_vaccination[age] =
                sum(recorder.daily_immunized_by_age[age, intro_day:end])
        end
    end
    return nothing
end
"""
    record!(t, modelsol, recorder::HeatmapRecorder)

Store step `t`'s per-age case counts in column `t` of `recorder.daily_cases_by_age`.

On the final step (`t == modelsol.sim_length`) it also:
- feeds each age group's total case count into `recorder.final_size_by_age` with `fit!`;
- `merge!`s the model's two weighted-degree statistics into the recorder's.
"""
function record!(t, modelsol, recorder::HeatmapRecorder)
    cases = recorder.daily_cases_by_age
    cases[:, t] .= modelsol.daily_cases_by_age
    # Summaries are only produced once, on the last simulated step.
    modelsol.sim_length == t || return nothing
    for age in 1:3
        age_total = sum(cases[age, :])
        fit!(recorder.final_size_by_age[age], age_total)
    end
    merge!(recorder.avg_weighted_degree_of_vaccinators, modelsol.avg_weighted_degree_of_vaccinators)
    return merge!(recorder.avg_weighted_degree, modelsol.avg_weighted_degree)
end

# With no recorder supplied, recording is a no-op that returns `nothing`.
record!(t, modelsol, recorder::Nothing) = nothing
Peter Jentsch's avatar
Peter Jentsch committed
function mean_solve(samples,parameter_tuple,recorder;progmeter = nothing)
    stat_recorder = recorder(Variance(), parameter_tuple.sim_length)
Peter Jentsch's avatar
Peter Jentsch committed
    output_recorder = recorder(0.0,parameter_tuple.sim_length)
Peter Jentsch's avatar
Peter Jentsch committed
    avg_populations = [0.0,0.0,0.0]
    for _ in 1:samples
        sol = abm(parameter_tuple,output_recorder)
Peter Jentsch's avatar
Peter Jentsch committed
        isnothing(progmeter) || next!(progmeter)
Peter Jentsch's avatar
Peter Jentsch committed
        avg_populations .+= length.(sol.index_vectors)
Peter Jentsch's avatar
Peter Jentsch committed
    avg_populations ./= samples
    return stat_recorder,avg_populations
Peter Jentsch's avatar
Peter Jentsch committed
"""
    mean_solve(samples, parameter_tuple, recorder::Type{HeatmapRecorder}; progmeter = nothing)

Run `abm(parameter_tuple, r)` `samples` times, each with a fresh
`HeatmapRecorder(parameter_tuple.sim_length)` `r`, and aggregate the runs with `fit!`
into one accumulating `HeatmapRecorder`.

If `progmeter` is given, `next!(progmeter)` is called once per run.

Returns `(stat_recorder, avg_populations)`, where `avg_populations` is the element-wise
mean over runs of `length.(sol.index_vectors)`.
"""
function mean_solve(samples,parameter_tuple,recorder::Type{HeatmapRecorder};progmeter = nothing)
    accumulated = recorder(parameter_tuple.sim_length)
    population_sums = zeros(3)
    for _ in 1:samples
        run_recorder = recorder(parameter_tuple.sim_length)
        sol = abm(parameter_tuple, run_recorder)
        if progmeter !== nothing
            next!(progmeter)
        end
        population_sums .+= length.(sol.index_vectors)
        fit!(accumulated, run_recorder)
    end
    population_sums ./= samples
    return accumulated, population_sums
end
function OnlineStats.fit!(accum::R,pt::R2) where {R<:AbstractRecorder,R2<:AbstractRecorder}
Peter Jentsch's avatar
Peter Jentsch committed
        stat_field = getfield(accum,field)
Peter Jentsch's avatar
Peter Jentsch committed
        combine!(stat_field,sample_field)
Peter Jentsch's avatar
Peter Jentsch committed
function combine!(accum::AbstractArray,x::AbstractArray)
    for (stat_entry,sample_entry) in zip(accum,x)
        combine!(stat_entry,sample_entry)
Peter Jentsch's avatar
Peter Jentsch committed
# Fold a single raw sample `x` into the accumulator statistic via `fit!`.
# NOTE(review): with `T` free, this signature is loose. It relies on the method below
# being more specific when both arguments are statistics — confirm there is no ambiguity.
combine!(accum::OnlineStat{<:T}, x::T) where T = fit!(accum,x)
# Merge statistic `x` into `accum` when both share the same type parameter `T`.
combine!(accum::OnlineStat{T}, x::OnlineStat{T}) where T = merge!(accum,x)
# Plain real-valued fields are not accumulated: this is a no-op returning `nothing`.
combine!(accum::Real, x::Real) = nothing