#needlessly overwrought output interface
using LabelledArrays
using Plots
import OnlineStats.fit!
"""
    Recorder(val, sim_length; record_degrees = false)

Container for everything recorded about the model output. It holds one run's raw series,
or, when `val` is an OnlineStats statistic such as `Variance()`, statistics accumulated
over many runs.

The constructor returns a `Recorder` filled with independent copies of `val`.
`copy(val)` is called once per slot, so no two slots share a mutable value.

- `recorded_status_totals`: 4×`sim_length` labelled array with rows `S`, `I`, `R`, `V`.
- `daily_cases_by_age`, `daily_immunized_by_age`, `daily_unvac_cases_by_age`:
  3×`sim_length` labelled arrays with rows `Y`, `M`, `O` (presumably age groups —
  confirm against the model).
- `total_vaccinators`, `daily_total_notifications`, `daily_total_notified_agents`,
  `mean_time_since_last_notification`: vectors of length `sim_length`.
- `unvac_final_size_by_age`, `final_size_by_age`, `total_preinf_vaccination`,
  `total_postinf_vaccination`: vectors of length 3.
- `avg_weighted_degree_of_vaccinators`, `avg_weighted_degree`: three fresh `Variance()`
  accumulators each, independent of `val`.
- `record_degrees_flag`: the value of the `record_degrees` keyword.
"""
struct Recorder{ElType,ArrT1,ArrT2,ArrT3,StatAccumulator}
    recorded_status_totals::ArrT1
    daily_cases_by_age::ArrT3
    total_vaccinators::ArrT2
    daily_total_notifications::ArrT2
    daily_total_notified_agents::ArrT2
    unvac_final_size_by_age::ArrT2
    total_postinf_vaccination::ArrT2
    total_preinf_vaccination::ArrT2
    final_size_by_age::ArrT2
    mean_time_since_last_notification::ArrT2
    daily_immunized_by_age::ArrT3
    daily_unvac_cases_by_age::ArrT3
    record_degrees_flag::Bool
    avg_weighted_degree_of_vaccinators::Vector{StatAccumulator}
    avg_weighted_degree::Vector{StatAccumulator}

    function Recorder(val::T, sim_length; record_degrees = false) where {T}
        # Each builder returns a brand-new container of fresh `copy(val)`s, so no
        # storage (and no mutable element) is shared between fields.
        per_day() = [copy(val) for _ in 1:sim_length]
        per_age() = [copy(val) for _ in 1:3]
        grid(nrows) = [copy(val) for i in 1:nrows, j in 1:sim_length]
        degree_stats() = [Variance() for _ in 1:3]

        status = @LArray grid(4) (S = (1, :), I = (2, :), R = (3, :), V = (4, :))
        cases = @LArray grid(3) (Y = (1, :), M = (2, :), O = (3, :))
        immunized = @LArray grid(3) (Y = (1, :), M = (2, :), O = (3, :))
        unvac_cases = @LArray grid(3) (Y = (1, :), M = (2, :), O = (3, :))

        vaccinators = per_day()
        notifications = per_day()
        notified_agents = per_day()
        alert_age = per_day()

        unvac_final = per_age()
        postinf = per_age()
        preinf = per_age()
        final_size = per_age()

        vaccinator_degrees = degree_stats()
        all_degrees = degree_stats()

        # Positional arguments must follow the field declaration order above.
        return new{T,typeof(status),typeof(vaccinators),typeof(immunized),eltype(all_degrees)}(
            status,
            cases,
            vaccinators,
            notifications,
            notified_agents,
            unvac_final,
            postinf,
            preinf,
            final_size,
            alert_age,
            immunized,
            unvac_cases,
            record_degrees,
            vaccinator_degrees,
            all_degrees,
        )
    end
end
"""
    record!(t, modelsol)

Write the summary statistics for day `t` into `modelsol.output_data` (a `Recorder`).

Every day, `mean_time_since_last_notification[t]` is set to the mean of `t - s` over the
positive entries `s` of `modelsol.time_of_last_alert`, rounded to an `Int`. If there are
no positive entries, it is set to `0`.

On the last day (`t == modelsol.sim_length`), the per-age totals are filled by summing
the corresponding daily by-age series over time. These totals are
`unvac_final_size_by_age`, `final_size_by_age`, `total_preinf_vaccination` and
`total_postinf_vaccination`.

Days strictly before `modelsol.params.infection_introduction_day` count as
pre-infection, and all remaining days count as post-infection. The two vaccination
totals therefore add up to the row total of `daily_immunized_by_age`.

Returns `nothing`.
"""
function record!(t, modelsol)
    recorder = modelsol.output_data
    # Non-positive entries are excluded from the mean.
    alerts = filter(>(0), modelsol.time_of_last_alert)
    recorder.mean_time_since_last_notification[t] =
        isempty(alerts) ? 0 : round(Int, mean(t .- alerts))

    if t == modelsol.sim_length
        immunized = recorder.daily_immunized_by_age
        ndays = size(immunized, 2)
        # Split the series at the introduction day: days 1:(cut - 1) are pre-infection
        # and days cut:ndays are post-infection. Summing 1:day and day:end would count
        # the introduction day twice. `clamp` keeps both ranges in bounds.
        # NOTE(review): confirm that vaccinations on the introduction day belong to the
        # post-infection total rather than the pre-infection one.
        cut = clamp(modelsol.params.infection_introduction_day, 1, ndays + 1)
        for age in 1:3
            recorder.unvac_final_size_by_age[age] =
                sum(recorder.daily_unvac_cases_by_age[age, :])
            recorder.final_size_by_age[age] = sum(recorder.daily_cases_by_age[age, :])
            recorder.total_preinf_vaccination[age] = sum(immunized[age, 1:(cut - 1)])
            recorder.total_postinf_vaccination[age] = sum(immunized[age, cut:ndays])
        end
    end
    return nothing
end


"""
    mean_solve(samples, parameter_tuple; progmeter = nothing, record_degrees = false)

Run `abm(parameter_tuple; record_degrees)` `samples` times and aggregate the results.

Each run's `output_data` recorder is folded, via `accumulate_model`, into a recorder
built with `Recorder(Variance(), parameter_tuple.sim_length; record_degrees)`. If
`progmeter` is not `nothing`, `next!(progmeter)` is called once after every run.

Returns `(accumulation_recorder, avg_populations)`. `avg_populations` is the
element-wise mean over runs of `length.(sol.index_vectors)`, which is assumed to have
3 entries.
"""
function mean_solve(samples, parameter_tuple; progmeter = nothing, record_degrees = false)
    # Pass `record_degrees` through so the aggregate's flag matches the runs it holds.
    accumulation_recorder =
        Recorder(Variance(), parameter_tuple.sim_length; record_degrees)
    avg_populations = [0.0, 0.0, 0.0]
    for _ in 1:samples
        sol = abm(parameter_tuple; record_degrees)
        isnothing(progmeter) || next!(progmeter)
        # Fold each run in as soon as it finishes instead of keeping every run's
        # output alive until the end; the accumulation order is unchanged.
        avg_populations .+= length.(sol.index_vectors)
        accumulate_model(accumulation_recorder, sol.output_data)
    end
    avg_populations ./= samples
    return accumulation_recorder, avg_populations
end

"""
    accumulate_model(accum, pt)

Fold the sample `pt` into `accum` field by field. For every field name of `accum`'s
type, the field of `pt` with the same name is combined into it with `combine!`.
Returns `nothing`.
"""
function accumulate_model(accum::R, pt) where {R}
    foreach(fieldnames(R)) do name
        combine!(getfield(accum, name), getfield(pt, name))
    end
end

"""
    combine!(accum, x)

Fold the sample `x` into the accumulator `accum`. The behavior depends on the argument
types:

- `accum::AbstractArray, x::AbstractArray`: combine entries pairwise, in iteration
  order. Throws `DimensionMismatch` if the arrays differ in size, instead of silently
  dropping the surplus entries as `zip` would.
- `accum::OnlineStat, x` (a plain value): `fit!` the value `x` into the statistic.
- `accum::OnlineStat{T}, x::OnlineStat{T}`: `merge!` `x` into `accum`.
- `accum::Real, x::Real`: no-op. This covers scalar fields such as a `Bool` flag.
"""
function combine! end

function combine!(accum::AbstractArray, x::AbstractArray)
    size(accum) == size(x) || throw(DimensionMismatch(
        "cannot combine an array of size $(size(x)) into one of size $(size(accum))"))
    for (stat_entry, sample_entry) in zip(accum, x)
        combine!(stat_entry, sample_entry)
    end
    return nothing
end
combine!(accum::OnlineStat{<:T}, x::T) where {T} = fit!(accum, x)
combine!(accum::OnlineStat{T}, x::OnlineStat{T}) where {T} = merge!(accum, x)
combine!(accum::Real, x::Real) = nothing