# output.jl
#needlessly overwrought output interface
using LabelledArrays
using Plots
import OnlineStats.fit!
"""
Peter Jentsch's avatar
Peter Jentsch committed
Recorder should store everything we might want to know about the model output. 
Peter Jentsch's avatar
Peter Jentsch committed
struct Recorder{ElType,ArrT1,ArrT2,ArrT3,StatAccumulator}
Peter Jentsch's avatar
Peter Jentsch committed
    recorded_status_totals::ArrT1
Peter Jentsch's avatar
Peter Jentsch committed
    total_vaccinators::ArrT2
Peter Jentsch's avatar
Peter Jentsch committed
    daily_total_notifications::ArrT2
    daily_total_notified_agents::ArrT2
Peter Jentsch's avatar
Peter Jentsch committed
    unvac_final_size_by_age::ArrT2
    total_postinf_vaccination::ArrT2
    total_preinf_vaccination::ArrT2
    final_size_by_age::ArrT2
Peter Jentsch's avatar
Peter Jentsch committed
    mean_time_since_last_notification::ArrT2
    daily_immunized_by_age::ArrT3
    daily_unvac_cases_by_age::ArrT3
    record_degrees_flag::Bool
Peter Jentsch's avatar
Peter Jentsch committed
    avg_weighted_degree_of_vaccinators::Vector{StatAccumulator}
    avg_weighted_degree::Vector{StatAccumulator}
Peter Jentsch's avatar
Peter Jentsch committed

    function Recorder(val::T,sim_length; record_degrees = false) where T
Peter Jentsch's avatar
Peter Jentsch committed
        totals = [copy(val) for i in 1:4, j in 1:sim_length]
        state_totals = @LArray totals (S = (1,:),I = (2,:),R = (3,:), V = (4,:))
        total_vaccinators = [copy(val) for j in 1:sim_length]
    
        daily_total_notifications = [copy(val) for j in 1:sim_length]
        daily_total_notified_agents = [copy(val) for j in 1:sim_length]
    
        mean_time_since_last_notification = [copy(val) for j in 1:sim_length]
        totals_ymo = [copy(val) for i in 1:3, j in 1:sim_length]
        daily_cases_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
        daily_immunized_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
        daily_unvac_by_age = @LArray deepcopy(totals_ymo) (Y = (1,:),M = (2,:),O = (3,:))
        unvac_final_size_by_age =  [copy(val) for i in 1:3]
        total_postinf_vaccination =  [copy(val) for i in 1:3]
        total_preinf_vaccination =  [copy(val) for i in 1:3]
        final_size_by_age =  [copy(val) for i in 1:3]
    
        avg_weighted_degree_of_vaccinators = [Variance() for _ in 1:3]
        avg_weighted_degree = [Variance() for _ in 1:3]
Peter Jentsch's avatar
Peter Jentsch committed
    
        return new{T,typeof(state_totals),typeof(total_vaccinators),typeof(daily_immunized_by_age),eltype(avg_weighted_degree)}(
            state_totals,
            daily_cases_by_age,
            total_vaccinators,
            daily_total_notifications,
            daily_total_notified_agents,
            unvac_final_size_by_age,
            total_postinf_vaccination,
            total_preinf_vaccination,
            final_size_by_age,
            mean_time_since_last_notification,
            daily_immunized_by_age,
            daily_unvac_by_age,
            record_degrees,
Peter Jentsch's avatar
Peter Jentsch committed
            avg_weighted_degree_of_vaccinators,
            avg_weighted_degree
        )
    end
Peter Jentsch's avatar
Peter Jentsch committed
end
Peter Jentsch's avatar
Peter Jentsch committed
Initialize a Recorder filled with (copies) of val.
Peter Jentsch's avatar
Peter Jentsch committed
"""
    record!(t, modelsol)

Update the recorder `modelsol.output_data` for day `t`.

On every call, `mean_time_since_last_notification[t]` is set to the mean of
`t - a`, rounded to an `Int`, over all positive entries `a` of
`modelsol.time_of_last_alert`. It is set to `0` if there are no positive entries.

When `t == modelsol.sim_length`, the per-age totals are also filled in from the
daily series. Here `d = modelsol.params.infection_introduction_day`.
- `final_size_by_age`, `unvac_final_size_by_age`: sums over all days.
- `total_preinf_vaccination`: immunizations on days `1` to `d - 1`.
- `total_postinf_vaccination`: immunizations on days `d` to the end.

The two vaccination ranges are disjoint, so pre + post equals the total.
Returns `nothing`.
"""
function record!(t, modelsol)
    recorder = modelsol.output_data

    # Only positive entries count as alert times (presumably 0 = never alerted).
    alerts = filter(>(0), modelsol.time_of_last_alert)
    recorder.mean_time_since_last_notification[t] =
        isempty(alerts) ? 0 : round(Int, mean(t .- alerts))

    if modelsol.sim_length == t
        intro_day = modelsol.params.infection_introduction_day
        @views for age in axes(recorder.daily_cases_by_age, 1)
            recorder.unvac_final_size_by_age[age] =
                sum(recorder.daily_unvac_cases_by_age[age, :])
            recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age, :])
            # The introduction day belongs to exactly one side. Previously both
            # ranges included it, double-counting that day's immunizations.
            recorder.total_preinf_vaccination[age] =
                sum(recorder.daily_immunized_by_age[age, 1:(intro_day - 1)])
            recorder.total_postinf_vaccination[age] =
                sum(recorder.daily_immunized_by_age[age, intro_day:end])
        end
    end
    return nothing
end
"""
    mean_solve(samples, parameter_tuple; progmeter = nothing, record_degrees = false)

Run `abm(parameter_tuple; record_degrees)` `samples` times and pool the results.

Each run's `output_data` is folded into a `Recorder` of `Variance()` accumulators
via `accumulate_model`, one run at a time, so per-run outputs are not kept around.
If `progmeter` is given, `next!(progmeter)` is called once per run.

Returns `(accumulation_recorder, avg_populations)`, where `avg_populations` is the
element-wise mean over runs of `length.(sol.index_vectors)`.

Throws `ArgumentError` if `samples` is not positive; the average is undefined then.
"""
function mean_solve(samples, parameter_tuple; progmeter = nothing, record_degrees = false)
    samples > 0 || throw(ArgumentError("samples must be positive, got $samples"))
    # Pass `record_degrees` so the pooled recorder's flag matches the runs it holds.
    accumulation_recorder =
        Recorder(Variance(), parameter_tuple.sim_length; record_degrees)
    avg_populations = [0.0, 0.0, 0.0]

    for _ in 1:samples
        sol = abm(parameter_tuple; record_degrees)
        isnothing(progmeter) || next!(progmeter)
        avg_populations .+= length.(sol.index_vectors)
        accumulate_model(accumulation_recorder, sol.output_data)
    end

    avg_populations ./= samples
    return accumulation_recorder, avg_populations
end
"""
    accumulate_model(accum::R, pt) where R

Fold the recorder `pt` into the accumulating recorder `accum`, field by field.

For every field name of `R`, calls `combine!(getfield(accum, f), getfield(pt, f))`.
`pt` must therefore have (at least) the same field names as `R`. Returns `nothing`.
"""
function accumulate_model(accum::R, pt) where {R}
    for field in fieldnames(R)
        stat_field = getfield(accum, field)
        sample_field = getfield(pt, field)
        combine!(stat_field, sample_field)
    end
    return nothing
end
"""
    combine!(accum::AbstractArray, x::AbstractArray)

Element-wise `combine!`: fold each entry of `x` into the matching entry of `accum`.

Throws `DimensionMismatch` if the sizes differ. A plain `zip` would otherwise
silently drop the unmatched entries. Returns `nothing`.
"""
function combine!(accum::AbstractArray, x::AbstractArray)
    size(accum) == size(x) || throw(DimensionMismatch(
        "cannot combine arrays of size $(size(accum)) and $(size(x))"))
    for (stat_entry, sample_entry) in zip(accum, x)
        combine!(stat_entry, sample_entry)
    end
    return nothing
end
# Fold a single observation `x` into the OnlineStats accumulator `accum`.
function combine!(accum::OnlineStat{<:T}, x::T) where {T}
    return fit!(accum, x)
end

# Merge another accumulator over the same input type into `accum`.
function combine!(accum::OnlineStat{T}, x::OnlineStat{T}) where {T}
    return merge!(accum, x)
end

# Plain real-valued entries carry no statistics to accumulate; ignore them.
function combine!(accum::Real, x::Real)
    return nothing
end