Skip to content
Snippets Groups Projects
Commit be06ddd0 authored by Peter Jentsch's avatar Peter Jentsch
Browse files

added tests to ZeroWeightedDistributions, removed vector rand function cause...

added tests to ZeroWeightedDistributions, removed vector rand function cause it was wrong, but we don't use it in the simulations anyway
parent f1c256e9
No related branches found
No related tags found
No related merge requests found
Showing
with 173 additions and 133 deletions
...@@ -8,10 +8,10 @@ default(framestyle = :box) ...@@ -8,10 +8,10 @@ default(framestyle = :box)
function bench() function bench()
steps = 100 steps = 100
model_sol = ModelSolution(steps,get_parameters(),5000); model_sol = ModelSolution(steps,get_parameters(),5000);
solve!(model_sol,identity) recording = recorder(steps)
output = solve!(model_sol,DebugRecorder)
end end
function abm() function abm()
......
...@@ -17,7 +17,7 @@ end ...@@ -17,7 +17,7 @@ end
const contact_time_distribution_matrix = [Geometric() for i in 1:(AgentDemographic.size-1), j in 1:(AgentDemographic.size-1)] const contact_time_distributions = load_contact_time_distributions()
function alpha_matrix(alphas) function alpha_matrix(alphas)
......
...@@ -31,59 +31,49 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq, ...@@ -31,59 +31,49 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq,
return g return g
end end
using DataStructures
struct MixingGraph{G, V, M} #defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way
struct WeightedGraph{G, V,M}
g::G g::G
weights::V weights::V
covid_alert_time::M weights_distribution_matrix::M
function MixingGraph(demographics,demographic_index_vectors,mixing_matrix) function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix)
contacts = MixingContacts(demographic_index_vectors,mixing_matrix) contacts = MixingContacts(demographic_index_vectors,mixing_matrix)
g = Graph(length(demographics)) g = Graph(length(demographics))
weights = RobinDict{Tuple{Int,Int},UInt8}() weights = RobinDict{Tuple{Int,Int},UInt8}()
for i in 1:length(demographic_index_vectors), j in 1:i #diagonal for i in 1:length(demographic_index_vectors), j in 1:i #diagonal
random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights) random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights)
end end
covid_alert_time = zeros(nv(g)) covid_alert_time = zeros(nv(g))
return new{typeof(g),typeof(weights),typeof(covid_alert_time)}( return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}(
g, g,
weights, weights,
covid_alert_time weights_distribution_matrix
) )
end end
function MixingGraph(g::SimpleGraph) function WeightedGraph(g::SimpleGraph,weights_distribution_matrix)
weights = RobinDict{Tuple{Int,Int},UInt8}() weights = RobinDict{Tuple{Int,Int},UInt8}()
covid_alert_time = zeros(nv(g)) return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}(
return new{typeof(g),typeof(weights),typeof(covid_alert_time)}(
g, g,
weights, weights,
covid_alert_time weights_distribution_matrix
) )
end end
end end
function MixingGraph(g::SimpleGraph,demographics,contact_time_distribution_matrix) function sample_mixing_graph!(mixing_graph,population_demographics)
mg = MixingGraph(g)
return mg
end
function MixingGraph(demographics,demographic_index_vectors,mixing_matrix,contact_time_distribution_matrix)
mg = MixingGraph(demographics,demographic_index_vectors,mixing_matrix)
return mg
end
function sample_mixing_graph!(mixing_graph,population_demographics, contact_time_distributions)
for (k,e) in enumerate(edges(mixing_graph.g)) for (k,e) in enumerate(edges(mixing_graph.g))
i = src(e) i = src(e)
j = dst(e) j = dst(e)
demo_i = Int(population_demographics[i]) demo_i = Int(population_demographics[i])
demo_j = Int(population_demographics[j]) demo_j = Int(population_demographics[j])
contact_time = rand(RNG, contact_time_distributions[demo_i,demo_j]) contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j])
mixing_graph.covid_alert_time[i] += contact_time
mixing_graph.covid_alert_time[j] += contact_time
mixing_graph.weights[(i,j)] = contact_time mixing_graph.weights[(i,j)] = contact_time
mixing_graph.weights[(j,i)] = contact_time mixing_graph.weights[(j,i)] = contact_time
end end
...@@ -116,7 +106,7 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j ...@@ -116,7 +106,7 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j
sample!(RNG,1:l_j,index_list_j) sample!(RNG,1:l_j,index_list_j)
rand!(RNG,ij_dist,sample_list_i) rand!(RNG,ij_dist,sample_list_i)
rand!(RNG,ji_dist,sample_list_j) rand!(RNG,ji_dist,sample_list_j)
@inbounds @simd for i = 1:inner_iter @inbounds for i = 1:inner_iter
if csum != 0 if csum != 0
csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j) csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
end end
......
function get_parameters() function get_parameters()
params = ( params = (
I_0_fraction = 0.0, I_0_fraction = 0.05,
base_transmission_probability = 0.5, base_transmission_probability = 0.5,
recovery_rate = 0.1, recovery_rate = 0.1,
immunization_loss_prob = 0.01, immunization_loss_prob = 0.5,
π_base = -0.1, π_base = -0.05,
η = 0.0, η = 0.0,
κ = 0.0, κ = 0.0,
ω = 0.0, ω = 0.0,
...@@ -13,7 +13,7 @@ function get_parameters() ...@@ -13,7 +13,7 @@ function get_parameters()
ω_en = 0.0, ω_en = 0.0,
ρ_en = [0.0,0.0,0.0], ρ_en = [0.0,0.0,0.0],
γ = 0.0, γ = 0.0,
β = 1000.0, β = 10.0,
notification_parameter = 0.0, notification_parameter = 0.0,
vaccinator_prob = 0.5, vaccinator_prob = 0.5,
app_user_fraction = 0.5, app_user_fraction = 0.5,
...@@ -48,7 +48,7 @@ struct ModelSolution{T,G} ...@@ -48,7 +48,7 @@ struct ModelSolution{T,G}
app_user_list::Vector{Int} app_user_list::Vector{Int}
app_user_index::Vector{Int} app_user_index::Vector{Int}
status_totals::Vector{Int} status_totals::Vector{Int}
status_totals_next::Vector{Int}
function ModelSolution(sim_length,params::T,num_households) where T function ModelSolution(sim_length,params::T,num_households) where T
demographics,base_network,index_vectors = generate_population(num_households) demographics,base_network,index_vectors = generate_population(num_households)
pop_sizes = length.(index_vectors) pop_sizes = length.(index_vectors)
...@@ -66,14 +66,14 @@ struct ModelSolution{T,G} ...@@ -66,14 +66,14 @@ struct ModelSolution{T,G}
u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob) u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob)
home_static_edges = MixingGraph(base_network,demographics,contact_time_distribution_matrix) #network with households and LTC homes home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes
ws_static_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distribution_matrix) ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distributions.ws)
ws_weekly_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distribution_matrix) ws_weekly_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distributions.ws)
ws_daily_edges_vector = [MixingGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length] ws_daily_edges_vector = [WeightedGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distributions.ws) for i in 1:sim_length]
rest_static_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distribution_matrix) rest_static_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distributions.rest)
rest_weekly_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distribution_matrix) rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distributions.rest)
rest_daily_edges_vector = [MixingGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length] rest_daily_edges_vector = [WeightedGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distributions.rest) for i in 1:sim_length]
inf_network_lists = [ inf_network_lists = [
[home_static_edges,rest_static_edges] for i in 1:sim_length [home_static_edges,rest_static_edges] for i in 1:sim_length
...@@ -96,6 +96,14 @@ struct ModelSolution{T,G} ...@@ -96,6 +96,14 @@ struct ModelSolution{T,G}
covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values
time_of_last_alert = fill(-1,length(app_user_index)) #two weeks worth of values time_of_last_alert = fill(-1,length(app_user_index)) #two weeks worth of values
status_totals = zeros(Int, AgentStatus.size)
status_totals[1] = count(==(Susceptible), u_0_inf)
status_totals[2] = count(==(Infected), u_0_inf)
status_totals[3] = count(==(Recovered), u_0_inf)
status_totals[4] = count(==(Immunized), u_0_inf)
return new{T,typeof(home_static_edges)}( return new{T,typeof(home_static_edges)}(
sim_length, sim_length,
nodes, nodes,
...@@ -112,7 +120,9 @@ struct ModelSolution{T,G} ...@@ -112,7 +120,9 @@ struct ModelSolution{T,G}
demographics, demographics,
is_app_user, is_app_user,
app_user_list, app_user_list,
app_user_index app_user_index,
status_totals,
copy(status_totals)
) )
end end
end end
abstract type AbstractOutputData end #needlessly overwrought output interface
struct DebugOutputData <: AbstractOutputData
TotalInfected:: abstract type AbstractRecorder end
end
\ No newline at end of file struct DebugRecorder <: AbstractRecorder
recorded_status_totals::Array{Int,2}
Total_S::Vector{Int}
Total_I::Vector{Int}
Total_R::Vector{Int}
Total_V::Vector{Int}
Total_Vaccinator::Vector{Int}
function DebugRecorder(sim_length)
return new(
zeros(Int,AgentStatus.size,sim_length),
zeros(Int,sim_length),
zeros(Int,sim_length),
zeros(Int,sim_length),
zeros(Int,sim_length),
zeros(Int,sim_length),
)
end
end
function record!(t,modelsol, recorder::DebugRecorder)
recorder.Total_S[t] = count(==(Susceptible),modelsol.u_inf)
recorder.Total_I[t] = count(==(Infected),modelsol.u_inf)
recorder.Total_R[t] = count(==(Recovered),modelsol.u_inf)
recorder.Total_V[t] = count(==(Immunized),modelsol.u_inf)
recorder.Total_Vaccinator[t] = count(==(true),modelsol.u_vac)
recorder.recorded_status_totals[:,t] .= modelsol.status_totals
end
function record!(t,modelsol, recorder::Nothing)
#do nothing
end
...@@ -2,14 +2,7 @@ ...@@ -2,14 +2,7 @@
function contact_weight(p, contact_time) function contact_weight(p, contact_time)
return 1 - (1-p)^contact_time return 1 - (1-p)^contact_time
end end
function update_alert_durations!(t,modelsol)
function EN_payoff()
return 0.0
end
function update_alert_durations!(t,modelsol)
@unpack notification_parameter = modelsol.params @unpack notification_parameter = modelsol.params
@unpack time_of_last_alert, app_user_index,inf_network_lists,covid_alert_times,app_user = modelsol @unpack time_of_last_alert, app_user_index,inf_network_lists,covid_alert_times,app_user = modelsol
...@@ -30,31 +23,39 @@ function update_alert_durations!(t,modelsol) ...@@ -30,31 +23,39 @@ function update_alert_durations!(t,modelsol)
end end
function update_infection_state!(t,modelsol) function update_infection_state!(t,modelsol)
@unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params @unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
@unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists = modelsol @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists,status_totals,status_totals_next = modelsol
function agent_transition!(node, from::AgentStatus,to::AgentStatus)
status_totals_next[Int(from)] -= 1
status_totals_next[Int(to)] += 1
u_next_inf[node] = to
end
u_next_inf .= u_inf
status_totals_next .= status_totals
for i in 1:modelsol.nodes for i in 1:modelsol.nodes
agent_status = u_inf[i] agent_status = u_inf[i]
is_vaccinator = u_vac[i] is_vaccinator = u_vac[i]
agent_demo = demographics[i] agent_demo = demographics[i]
if agent_status == Susceptible if agent_status == Susceptible
if is_vaccinator if is_vaccinator
u_next_inf[i] = Immunized agent_transition!(i, Susceptible,Immunized)
else else
for mixing_graph in inf_network_lists[t] for mixing_graph in inf_network_lists[t]
for j in neighbors(mixing_graph.g,i) for j in neighbors(mixing_graph.g,i)
if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights[(i,j)]) if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights[(i,j)])
u_next_inf[i] = Infected agent_transition!(i, Susceptible,Infected)
end end
end end
end end
end end
elseif agent_status == Infected elseif agent_status == Infected
if rand(RNG) < recovery_rate if rand(RNG) < recovery_rate
u_next_inf[i] = Recovered agent_transition!(i, Infected,Recovered)
end end
elseif agent_status == Immunized elseif agent_status == Immunized
if rand(RNG) < immunization_loss_prob if rand(RNG) < immunization_loss_prob
u_next_inf[i] = Susceptible agent_transition!(i, Immunized,Susceptible)
end end
end end
end end
...@@ -64,6 +65,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections) ...@@ -64,6 +65,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
@unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params @unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
@unpack demographics,time_of_last_alert, nodes, soc_networks,u_vac,u_next_vac,app_user,app_user_list = modelsol @unpack demographics,time_of_last_alert, nodes, soc_networks,u_vac,u_next_vac,app_user,app_user_list = modelsol
app_user_pointer = 0 app_user_pointer = 0
for i in 1:nodes for i in 1:nodes
vac_payoff = 0 vac_payoff = 0
soc_nbrs_vac = [0,0,0] soc_nbrs_vac = [0,0,0]
...@@ -80,20 +83,24 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections) ...@@ -80,20 +83,24 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
end end
end end
end end
vac_payoff += π_base + dot(ρ,soc_nbrs_vac) + total_infections*ω + ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0) vac_payoff += π_base + dot(ρ,soc_nbrs_vac) + total_infections*ω +
ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0)
if app_user[i] && time_of_last_alert[app_user_list[i]]>=0 if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]]))* (η + dot(ρ_en,soc_nbrs_vac) + total_infections*ω_en) vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]]))* (η + dot(ρ_en,soc_nbrs_vac) + total_infections*ω_en)
end end
if u_vac[i] if u_vac[i]
if rand(RNG) < 1 - Φ(vac_payoff,β) if rand(RNG) < 1 - Φ(vac_payoff,β)
# display("$i switch")
u_next_vac[i] = !u_vac[i] u_next_vac[i] = !u_vac[i]
else
u_next_vac[i] = u_vac[i]
end end
else else
if rand(RNG) < Φ(vac_payoff,β) if rand(RNG) < Φ(vac_payoff,β)
# display("$i switchback")
u_next_vac[i] = !u_vac[i] u_next_vac[i] = !u_vac[i]
else
u_next_vac[i] = u_vac[i]
end end
end end
end end
...@@ -102,24 +109,29 @@ end ...@@ -102,24 +109,29 @@ end
function agents_step!(t,modelsol) function agents_step!(t,modelsol)
for network in modelsol.inf_network_lists[t] for network in modelsol.inf_network_lists[t]
sample_mixing_graph!(network,modelsol.demographics, contact_time_distribution_matrix) #get new contact weights sample_mixing_graph!(network,modelsol.demographics) #get new contact weights
end end
update_alert_durations!(t,modelsol) update_alert_durations!(t,modelsol)
update_vaccination_opinion_state!(t,modelsol,count(==(Infected),modelsol.u_inf)) update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)])
update_infection_state!(t,modelsol) update_infection_state!(t,modelsol)
modelsol.u_vac .= modelsol.u_next_vac modelsol.u_vac .= modelsol.u_next_vac
modelsol.u_inf .= modelsol.u_next_inf modelsol.u_inf .= modelsol.u_next_inf
modelsol.status_totals .= modelsol.status_totals_next
end end
function solve!(modelsol,recording_function) function solve!(modelsol,recording)
for t in 1:modelsol.sim_length for t in 1:modelsol.sim_length
#advance agent states based on the new network #advance agent states based on the new network
agents_step!(t,modelsol) agents_step!(t,modelsol)
display((count(==(true),modelsol.u_vac),count(==(Immunized),modelsol.u_inf)))
record!(t,modelsol,recording)
end end
# return solution, network_lists return recording
end end
......
...@@ -8,6 +8,7 @@ using Distributions ...@@ -8,6 +8,7 @@ using Distributions
using StatsBase using StatsBase
using Dates using Dates
using LinearAlgebra using LinearAlgebra
using CovidAlertVaccinationModel
using ThreadsX using ThreadsX
using DelimitedFiles using DelimitedFiles
using KernelDensity using KernelDensity
...@@ -15,6 +16,7 @@ using NamedTupleTools ...@@ -15,6 +16,7 @@ using NamedTupleTools
using NetworkLayout:Stress using NetworkLayout:Stress
using NetworkLayout:SFDP using NetworkLayout:SFDP
using ZeroWeightedDistributions using ZeroWeightedDistributions
using DataStructures
using Serialization using Serialization
using BenchmarkTools using BenchmarkTools
using Intervals using Intervals
...@@ -31,13 +33,15 @@ const RNG = Xoroshiro128Star(1) ...@@ -31,13 +33,15 @@ const RNG = Xoroshiro128Star(1)
const color_palette = palette(:seaborn_pastel) #color theme for the plots const color_palette = palette(:seaborn_pastel) #color theme for the plots
include("utils.jl") include("utils.jl")
include("data.jl") include("data.jl")
include("ABM/abm.jl")
include("ABM/agents.jl") include("ABM/agents.jl")
include("ABM/mixing_distributions.jl") include("ABM/mixing_distributions.jl")
include("ABM/mixing_graphs.jl") include("ABM/mixing_graphs.jl")
include("ABM/plotting.jl") include("ABM/plotting.jl")
include("ABM/model_setup.jl") include("ABM/model_setup.jl")
include("ABM/output.jl")
include("ABM/solve.jl") include("ABM/solve.jl")
include("ABM/abm.jl")
include("IntervalsModel/intervals_model.jl") include("IntervalsModel/intervals_model.jl")
include("IntervalsModel/interval_overlap_sampling.jl") include("IntervalsModel/interval_overlap_sampling.jl")
include("IntervalsModel/hh_durations_model.jl") include("IntervalsModel/hh_durations_model.jl")
......
...@@ -101,8 +101,17 @@ function load_mixing_matrices() ...@@ -101,8 +101,17 @@ function load_mixing_matrices()
end end
function load_contact_time_distributions() function load_contact_time_distributions()
dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output/simulation_output/hh.dat")) distkey = "Distributions.Poisson"
return dat fnames = (
hh = "hh",
ws = "ws",
rest = "rest"
)
contact_distributions_tuple = map(fnames) do fname
dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output","simulation_output","$fname.dat"))
return map(p -> Poisson(mode(p.particles)), as_symmetric_matrix(dat[distkey].P))
end
return contact_distributions_tuple
end end
""" """
Load rest data from `data/canada-network-data/Timeuse/Rest/RData`. Load rest data from `data/canada-network-data/Timeuse/Rest/RData`.
......
model_sizes = [100,1000,5000] const model_sizes = [100,1000,5000]
vaccination_strategies = [vaccinate_uniformly!] const dem_cat = AgentDemographic.size -1
vaccination_rates = [0.000,0.005,0.01,0.05]
infection_rates = [0.01,0.05,0.1]
agent_models = ThreadsX.map(model_size -> AgentModel(model_size...), model_sizes)
dem_cat = AgentDemographic.size -1
#network generation #network generation
...@@ -26,9 +22,8 @@ dem_cat = AgentDemographic.size -1 ...@@ -26,9 +22,8 @@ dem_cat = AgentDemographic.size -1
end end
function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus()) function vaccinator_opinion_test(model,vac_strategy, π_base; rng = Xoroshiro128Plus())
u_0 = get_u_0(length(model.demographics)) params = merge(get_parameters(),(I_0_fraction = 0.0,π_base))
params = merge(get_parameters(),(vaccines_per_day = vac_rate,))
steps = 300 steps = 300
sol1,_ = solve!(u_0,params,steps,model,vac_strategy); sol1,_ = solve!(u_0,params,steps,model,vac_strategy);
total_infections = count(x->x == AgentStatus(3),sol1[end]) total_infections = count(x->x == AgentStatus(3),sol1[end])
...@@ -38,7 +33,6 @@ end ...@@ -38,7 +33,6 @@ end
function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus()) function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus())
params = merge(get_parameters(),(p = inf_parameter,))
steps = 300 steps = 300
# display(params) # display(params)
sol1,_ = solve!(params,steps,model,vaccinate_uniformly!); sol1,_ = solve!(params,steps,model,vaccinate_uniformly!);
...@@ -65,3 +59,12 @@ end ...@@ -65,3 +59,12 @@ end
@test test_comparison(x->infection_rate_test(m,x),infection_rates,<) @test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
end end
@testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
@test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
end
...@@ -4,4 +4,4 @@ using Test ...@@ -4,4 +4,4 @@ using Test
using ThreadsX using ThreadsX
import StatsBase.mean import StatsBase.mean
include("ABM/abm.jl") include("ABM/abm_test.jl")
\ No newline at end of file \ No newline at end of file
...@@ -5,8 +5,12 @@ version = "0.1.0" ...@@ -5,8 +5,12 @@ version = "0.1.0"
[deps] [deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
[compat] [compat]
julia = "1" julia = "1"
......
...@@ -21,10 +21,6 @@ struct ZWDist{BaseDistType <: Sampleable,T} <: Sampleable{Univariate,T} ...@@ -21,10 +21,6 @@ struct ZWDist{BaseDistType <: Sampleable,T} <: Sampleable{Univariate,T}
return new{DType,S}(α,DType(p...)) return new{DType,S}(α,DType(p...))
end end
end end
# function from_mean(::Type{ZWDist{DistType,T}},α,μ) where {DistType <: Distribution{Univariate,W} where W, T}
# return ZWDist(α,from_mean(DistType,μ/(1-α)))
# end
StatsBase.mean(d::ZWDist{Dist,T}) where {Dist,T} = (1 - d.α)*StatsBase.mean(d.base_dist) StatsBase.mean(d::ZWDist{Dist,T}) where {Dist,T} = (1 - d.α)*StatsBase.mean(d.base_dist)
...@@ -32,7 +28,7 @@ function Distributions.pdf(d::ZWDist, x) ...@@ -32,7 +28,7 @@ function Distributions.pdf(d::ZWDist, x)
if x == 0 if x == 0
return d.α + (1-d.α)*pdf(d.base_dist,0) return d.α + (1-d.α)*pdf(d.base_dist,0)
else else
return pdf(d.base_dist,x) return (1-d.α)*pdf(d.base_dist,x)
end end
end end
...@@ -40,22 +36,6 @@ function Distributions.mean(d::ZWDist, x) ...@@ -40,22 +36,6 @@ function Distributions.mean(d::ZWDist, x)
return (1-d.α)*StatsBase.mean(d.base_dist,0) return (1-d.α)*StatsBase.mean(d.base_dist,0)
end end
function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}, n::Int) where {DType, S}
l = Vector{eltype(DType)}(undef,n)
Random.rand!(rng,l)
l[l .< s.α] .= zero(eltype(DType))
Random.rand!(rng,s.base_dist,@view l[l .>= s.α])
return l
end
function Random.rand!(rng::AbstractRNG, s::ZWDist{DType,S}, l::T) where {T<:AbstractVector, DType,S}
Random.rand!(rng,l)
l[l .< s.α] .= zero(eltype(DType))
Random.rand!(rng,s.base_dist,@view l[l .>= s.α])
return l
end
function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}) where {DType,S} function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}) where {DType,S}
return ifelse(Base.rand(rng) < s.α, zero(eltype(DType)), Base.rand(rng,s.base_dist)) return ifelse(Base.rand(rng) < s.α, zero(eltype(DType)), Base.rand(rng,s.base_dist))
end end
......
...@@ -4,41 +4,37 @@ using Distributions ...@@ -4,41 +4,37 @@ using Distributions
using Test using Test
using ThreadsX using ThreadsX
using RandomNumbers.Xorshifts using RandomNumbers.Xorshifts
using Plots # using Plots
# const dist_list = [Poisson,Geometric] const dist_list = [Poisson]
# const params_list = collect(0.01:0.5:0.9) const params_list = collect(0.1:0.2:0.9)
# const α_list = collect(0.0:0.1:1.0) const α_list = collect(0.2:0.2:1.0)
# const RNG = Xoroshiro128Star(1) const RNG = Xoroshiro128Star(1)
# const tol = 1e-1 const tol = 0.01
function test_dist(d) const sample_length = 1_000_000
# samples_vec = rand(RNG,d, 100) epdf(samples,k) = count(==(k), samples)/sample_length
# display(samples_vec)
# samples_map = map(x->rand(RNG,d), 1:10_000)
# epdf = kde(samples_vec)
# vec_err = all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol)
# epdf = kde(samples_map)
# map_err= all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol)
# # displlay((vec_err,map_err))
# p = scatter([pdf(epdf,k) for k in 0.0:0.1:10])
# scatter!(p,[pdf(d,k) for k in 0.0:0.1:10])
# display(p)
# return vec_err && map_err
end
function test_dist_vec(d)
function test() samples_vec = rand(RNG,d, sample_length)
map(Iterators.product(dist_list,params_list,α_list)) do (d,p,α) vec_err = all(abs(epdf(samples_vec,k) - pdf(d,k)) < tol for k in 0:1:10)
# println((d,p,α)) return vec_err
test_dist(ZWDist(d,α,p)) end
end function test_dist_map(d)
samples_map = map(x->rand(RNG,d), 1:sample_length)
map_err= all(abs(epdf(samples_map,k) - pdf(d,k)) < tol for k in 0:1:10)
return map_err
end end
test()
# @testset "ZeroWeightedDistributions.jl" begin @testset "ZeroWeightedDistributions.jl" begin
# @test @testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list)
# end println((d,p,α))
@test test_dist_vec(ZWDist(d,α,p))
end
@testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list)
println((d,p,α))
@test test_dist_map(ZWDist(d,α,p))
end
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment