# parameter_planes.jl

# Sample count handed to `mean_solve` for every parameter set; also used to size the
# progress meters in this file.
const samples = 20
# Output directory (relative path) for the univariate sweep plots.
const univariate_path = "CovidAlertVaccinationModel/plots/univariate/" 
# NOTE(review): this is identical to `univariate_path` — looks like a copy-paste slip.
# Confirm whether it should point at a separate bivariate directory.
const bivariate_path = "CovidAlertVaccinationModel/plots/univariate/" 
"""
    univarate_test(variable, variable_range; progmeter = nothing)

Sweep a single parameter. For each value in `variable_range`, the field `variable` of the
parameters returned by `get_app_parameters()` is overridden, `mean_solve` is run with
`samples` samples, and the first component of each result is collected.

# Arguments
- `variable`: `Symbol` naming the parameter field to override.
- `variable_range`: iterable of values to assign to that field.

# Keywords
- `progmeter`: forwarded unchanged to `mean_solve`.

# Returns
Whatever `plot_model` returns for the swept parameter sets and their outputs.
"""
function univarate_test(variable, variable_range; progmeter = nothing)
    default_parameters = get_app_parameters()

    # One parameter set per swept value; all other fields keep their defaults.
    with_override(value) = merge(default_parameters, NamedTuple{(variable,)}((value,)))
    parameter_range_list = [with_override(value) for value in variable_range]

    univariate_outlist = ThreadsX.map(parameter_range_list) do parameters
        mean_solve(samples, parameters; progmeter)[1]
    end

    return plot_model(variable, parameter_range_list, univariate_outlist, default_parameters)
end

# Make sure the univariate plot directory exists when this file is loaded. `mkpath` is used
# rather than `mkdir` because the target is nested and `mkdir` throws when a parent
# directory is missing.
if !ispath(univariate_path)
    mkpath(univariate_path)
end
"""
    univariate_simulations()

Run `univarate_test` for every enabled `(parameter, values)` sweep and save each plot it
returns as `<univariate_path>/<parameter>/<index>.pdf`.

Sweeps are switched on and off by (un)commenting entries in the list inside the function.
Returns `nothing`.
"""
function univariate_simulations()
    len = 6
    sweeps = (
        # (:I_0_fraction, range(0.0, 0.05; length = len)), 
        # (:base_transmission_probability, range(0.0002, 0.002; length = len)),
        # (:recovery_rate, range(0.1, 0.5; length = len)),
        # (:immunization_loss_prob, range(0.00, 0.05; length = len)),
        # (:π_base, range(-4.5, -3.5;  length = len)),
        # (:η, range(0.0,2.0; length = len)),
        # (:κ, range(0.5, 1.5; length = len)),
        # (:ω, range(0.0, 0.01; length = len)),
        # (:ω_en, range(0.0, 5e-2; length = len)),
        # (:γ, range(0.0, 0.5; length = len)),
        # (:ξ, range(1, 15; length = len)),
        # (:notification_parameter, range(0.000, 0.001; length = len)),
        (:app_user_fraction, range(0.05, 0.8; length = len)),
        # (:notification_threshold, (1:len)),
        # (:immunization_delay, [7,10,14,20]),
    )

    # Total number of parameter sets across all enabled sweeps.
    numsim = sum(length(vals) for (_, vals) in sweeps)
    display(numsim)
    progmeter = Progress(numsim * samples)
    display(get_app_parameters())

    function run_sweep(sweep)
        out = univarate_test(sweep...; progmeter)
        display(out)
        return out
    end
    plt_list = ThreadsX.map(run_sweep, sweeps)

    # One sub-directory per swept parameter, one numbered PDF per plot.
    for (sweep, plots) in zip(sweeps, plt_list)
        varname = first(sweep)
        outdir = "$univariate_path/$varname"
        mkpath(outdir)
        for (i, p) in enumerate(plots)
            savefig(p, "$outdir/$i.pdf")
        end
    end
    return nothing
end

using AxisKeys 
"""
    multivariate_simulations()

Build the two-parameter sweep over `η` and `ω_en` (15 evenly spaced values each), pass it
to `run_multivariate_sims`, and return that call's result.
"""
function multivariate_simulations()
    len = 15
    η_range = range(0.0, 2.0; length = len)
    ω_en_range = range(0.0, 1e-1; length = len)
    app_simulations = (
        (:η, η_range),
        (:ω_en, ω_en_range),
        # (:notification_threshold, (1:len)),
    )
    return run_multivariate_sims(app_simulations)
end
using ProgressMeter
"""
    run_multivariate_sims(sims)

Run `mean_solve` on the Cartesian product of the given parameter sweeps and serialize the
results, together with one baseline run, to `abm_output/<names>.dat` under
`PACKAGE_FOLDER`.

# Arguments
- `sims`: collection of `(name::Symbol, values)` pairs, one per swept parameter.

# Returns
The base file name (the parameter names joined with `_`, without extension), suitable for
passing to `plot_parameter_plane`.
"""
function run_multivariate_sims(sims)
    # Split the (name, values) pairs into a tuple of names and a tuple of value ranges.
    varnames, sim_ranges = zip(sims...)

    simvars = Iterators.product(sim_ranges...)
    # The `+ 1` accounts for the baseline run below.
    progmeter = Progress((length(simvars) + 1) * samples)

    # Baseline run with the parameters from `get_parameters()`.
    without_app, _ = mean_solve(samples, get_parameters(); progmeter, record_degrees = true)
    # NOTE(review): `mean_solve` is already given `progmeter`; if it advances the meter
    # itself, this extra tick overshoots the total by one — confirm.
    next!(progmeter)

    default_parameters = get_app_parameters()
    app_output = ThreadsX.map(simvars) do vars
        vars_with_names = NamedTuple{varnames}(vars)
        parameters = merge(default_parameters, vars_with_names)
        out, _ = mean_solve(samples, parameters; progmeter, record_degrees = true)
        return out
    end
    display(length(simvars))

    fname = join(string.(varnames), "_")
    # Attach the swept values as axis keys so the plotting code can recover them.
    keyed_output = KeyedArray(app_output; NamedTuple{varnames}(sim_ranges)...)
    path = joinpath(PACKAGE_FOLDER, "abm_output", "$fname.dat")
    # Create the output directory first so a fresh checkout does not lose the whole
    # sweep to a failed write at the very end.
    mkpath(dirname(path))
    serialize(path, (without_app, keyed_output))
    return fname
end
using ColorSchemes
using LaTeXStrings
"""
    plot_parameter_plane(input_fname)

Load the `(baseline, keyed_output)` pair stored in `abm_output/<input_fname>.dat` under
`PACKAGE_FOLDER` and save five heatmaps over the two swept parameters into
`plots/app_heatmaps`:

- three maps (indices 1:3, saved with suffixes `Y`, `M`, `O`) of the mean weighted degree
  of vaccinators minus the mean weighted degree,
- the change in mean final size relative to the baseline run,
- the standard deviation of the final size.

# Arguments
- `input_fname`: base file name without extension, as returned by
  `run_multivariate_sims`.

Returns `nothing`.
"""
function plot_parameter_plane(input_fname)
    path = joinpath(PACKAGE_FOLDER, "abm_output", "$input_fname.dat")
    output_no_app, output = deserialize(path)
    var_ranges = axiskeys(output)
    # NOTE(review): axis labels are hard-coded for the (η, ω_en) sweep; other inputs will
    # be mislabelled.
    vars = (L"\eta", L"\omega_{en}")

    # Pool the per-age-group entries into a fresh accumulator. `merge!` mutates its first
    # argument, and this is evaluated for both the mean and the std maps, so merging into
    # the stored first entry could count the other groups more than once. `deepcopy` is
    # used because the element type is not known here.
    function pooled_final_size(p)
        groups = p.final_size_by_age
        return foldl(merge!, Iterators.drop(groups, 1); init = deepcopy(first(groups)))
    end
    mean_final_size(p) = mean(pooled_final_size(p))
    std_final_size(p) = std(pooled_final_size(p))

    base_outcome = mean_final_size(output_no_app)
    final_size_map = map(x -> mean_final_size(x) - base_outcome, output)

    function mean_weighted_degree_change(p, age)
        return mean(p.avg_weighted_degree_of_vaccinators[age]) -
               mean(p.avg_weighted_degree[age])
    end
    weighted_degree_map = [map(p -> mean_weighted_degree_change(p, i), output) for i in 1:3]

    cs = cgrad(:blues)
    # Was `map(std_final_size, output_data)`: `output_data` is not defined in this
    # function, so the call raised an UndefVarError.
    datamaps = (weighted_degree_map..., final_size_map, map(std_final_size, output))
    fnames = [
        "wdg_change_Y.pdf",
        "wdg_change_M.pdf",
        "wdg_change_O.pdf",
        "final_size_change.pdf",
        "final_size_standard_dev.pdf",
    ]
    titles = [
        "Average w. deg. of vaccinators minus average w. deg., Y",
        "Average w. deg. of vaccinators minus average w. deg., M",
        "Average w. deg. of vaccinators minus average w. deg., O",
        "Effect of notifications on tot. infections",
        "Standard deviation from the mean of total infection size",
    ]

    outdir = joinpath(PACKAGE_FOLDER, "plots", "app_heatmaps")
    mkpath(outdir)  # savefig does not necessarily create missing directories
    for (fname, title, datamap) in zip(fnames, titles, datamaps)
        # NOTE(review): the first data axis is passed as x. Check that `heatmap` does not
        # expect the matrix as (length(y), length(x)); with a square grid a transposed
        # plot would go unnoticed.
        p = heatmap(
            var_ranges[1],
            var_ranges[2],
            datamap;
            title,
            xlabel = vars[1],
            ylabel = vars[2],
            seriescolor = cs,
            size = (600, 400),
        )
        savefig(p, joinpath(outdir, "$fname"))
    end
    return nothing
end