diff --git a/CovidAlertVaccinationModel/src/ABM/abm.jl b/CovidAlertVaccinationModel/src/ABM/abm.jl index 250c1d0c53d6f3ce5735952407ea5c194ddda8c7..084fd4d9b50bec979248cd63be3c53e967553a1a 100644 --- a/CovidAlertVaccinationModel/src/ABM/abm.jl +++ b/CovidAlertVaccinationModel/src/ABM/abm.jl @@ -8,10 +8,10 @@ default(framestyle = :box) function bench() - steps = 100 model_sol = ModelSolution(steps,get_parameters(),5000); - solve!(model_sol,identity) + recording = recorder(steps) + output = solve!(model_sol,DebugRecorder) end function abm() diff --git a/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl b/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl index 10f045b8f82526f321d8bec447ee80b7021acfcd..8b4ad25750ae22a5cbba2342a939d09c259007e2 100644 --- a/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl +++ b/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl @@ -17,7 +17,7 @@ end -const contact_time_distribution_matrix = [Geometric() for i in 1:(AgentDemographic.size-1), j in 1:(AgentDemographic.size-1)] +const contact_time_distributions = load_contact_time_distributions() function alpha_matrix(alphas) diff --git a/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl b/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl index 1c26f96739b4f560e63541ba45d6e20fcfb203da..c038e83329d3754a8e7f239ef4ecd647637ae913 100644 --- a/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl +++ b/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl @@ -31,59 +31,49 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq, return g end -using DataStructures -struct MixingGraph{G, V, M} + +#defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way +struct WeightedGraph{G, V,M} g::G weights::V - covid_alert_time::M - function MixingGraph(demographics,demographic_index_vectors,mixing_matrix) + weights_distribution_matrix::M + function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix) contacts = MixingContacts(demographic_index_vectors,mixing_matrix) g = Graph(length(demographics)) weights = RobinDict{Tuple{Int,Int},UInt8}() + for i in 1:length(demographic_index_vectors), j in 1:i #diagonal random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights) end covid_alert_time = zeros(nv(g)) - return new{typeof(g),typeof(weights),typeof(covid_alert_time)}( + return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}( g, weights, - covid_alert_time + weights_distribution_matrix ) end - function MixingGraph(g::SimpleGraph) + function WeightedGraph(g::SimpleGraph,weights_distribution_matrix) weights = RobinDict{Tuple{Int,Int},UInt8}() - covid_alert_time = zeros(nv(g)) - return new{typeof(g),typeof(weights),typeof(covid_alert_time)}( + return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}( g, weights, - covid_alert_time + weights_distribution_matrix ) end end -function MixingGraph(g::SimpleGraph,demographics,contact_time_distribution_matrix) - mg = MixingGraph(g) - return mg -end -function MixingGraph(demographics,demographic_index_vectors,mixing_matrix,contact_time_distribution_matrix) - mg = MixingGraph(demographics,demographic_index_vectors,mixing_matrix) - return mg -end - -function sample_mixing_graph!(mixing_graph,population_demographics, contact_time_distributions) +function sample_mixing_graph!(mixing_graph,population_demographics) for (k,e) in enumerate(edges(mixing_graph.g)) i = src(e) j = dst(e) demo_i = Int(population_demographics[i]) demo_j = Int(population_demographics[j]) - contact_time = rand(RNG, contact_time_distributions[demo_i,demo_j]) - mixing_graph.covid_alert_time[i] += contact_time - mixing_graph.covid_alert_time[j] += contact_time + contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j]) mixing_graph.weights[(i,j)] = contact_time mixing_graph.weights[(j,i)] = contact_time end @@ -116,7 +106,7 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j sample!(RNG,1:l_j,index_list_j) rand!(RNG,ij_dist,sample_list_i) rand!(RNG,ji_dist,sample_list_j) - @inbounds @simd for i = 1:inner_iter + @inbounds for i = 1:inner_iter if csum != 0 csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j) end diff --git a/CovidAlertVaccinationModel/src/ABM/model_setup.jl b/CovidAlertVaccinationModel/src/ABM/model_setup.jl index efb69fb0ecfba1f3020ca6705e61c4348b0afaba..6d623e82ac61196f3e620148b7b7513459732cc7 100644 --- a/CovidAlertVaccinationModel/src/ABM/model_setup.jl +++ b/CovidAlertVaccinationModel/src/ABM/model_setup.jl @@ -1,11 +1,11 @@ function get_parameters() params = ( - I_0_fraction = 0.0, + I_0_fraction = 0.05, base_transmission_probability = 0.5, recovery_rate = 0.1, - immunization_loss_prob = 0.01, - Ï€_base = -0.1, + immunization_loss_prob = 0.5, + Ï€_base = -0.05, η = 0.0, κ = 0.0, ω = 0.0, @@ -13,7 +13,7 @@ function get_parameters() ω_en = 0.0, Ï_en = [0.0,0.0,0.0], γ = 0.0, - β = 1000.0, + β = 10.0, notification_parameter = 0.0, vaccinator_prob = 0.5, app_user_fraction = 0.5, @@ -48,7 +48,7 @@ struct ModelSolution{T,G} app_user_list::Vector{Int} app_user_index::Vector{Int} status_totals::Vector{Int} - + status_totals_next::Vector{Int} function ModelSolution(sim_length,params::T,num_households) where T demographics,base_network,index_vectors = generate_population(num_households) pop_sizes = length.(index_vectors) @@ -66,14 +66,14 @@ struct ModelSolution{T,G} u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob) - home_static_edges = MixingGraph(base_network,demographics,contact_time_distribution_matrix) #network with households and LTC homes - ws_static_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distribution_matrix) - ws_weekly_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distribution_matrix) - ws_daily_edges_vector = [MixingGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length] + home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes + ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distributions.ws) + ws_weekly_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distributions.ws) + ws_daily_edges_vector = [WeightedGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distributions.ws) for i in 1:sim_length] - rest_static_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distribution_matrix) - rest_weekly_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distribution_matrix) - rest_daily_edges_vector = [MixingGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length] + rest_static_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distributions.rest) + rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distributions.rest) + rest_daily_edges_vector = [WeightedGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distributions.rest) for i in 1:sim_length] inf_network_lists = [ [home_static_edges,rest_static_edges] for i in 1:sim_length @@ -96,6 +96,14 @@ struct ModelSolution{T,G} covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values time_of_last_alert = fill(-1,length(app_user_index)) #two weeks worth of values + status_totals = zeros(Int, AgentStatus.size) + + status_totals[1] = count(==(Susceptible), u_0_inf) + status_totals[2] = count(==(Infected), u_0_inf) + status_totals[3] = count(==(Recovered), u_0_inf) + status_totals[4] = count(==(Immunized), u_0_inf) + + return new{T,typeof(home_static_edges)}( sim_length, nodes, @@ -112,7 +120,9 @@ struct ModelSolution{T,G} demographics, is_app_user, app_user_list, - app_user_index + app_user_index, + status_totals, + copy(status_totals) ) end end diff --git a/CovidAlertVaccinationModel/src/ABM/output.jl b/CovidAlertVaccinationModel/src/ABM/output.jl index dbd0bbf51183a4468efcfcc682ba62d9dd1322f8..89229449ab599b34c31f99beb3af47c8417c03a7 100644 --- a/CovidAlertVaccinationModel/src/ABM/output.jl +++ b/CovidAlertVaccinationModel/src/ABM/output.jl @@ -1,5 +1,37 @@ -abstract type AbstractOutputData end +#needlessly overwrought output interface -struct DebugOutputData <: AbstractOutputData - TotalInfected:: -end \ No newline at end of file + +abstract type AbstractRecorder end + +struct DebugRecorder <: AbstractRecorder + recorded_status_totals::Array{Int,2} + Total_S::Vector{Int} + Total_I::Vector{Int} + Total_R::Vector{Int} + Total_V::Vector{Int} + Total_Vaccinator::Vector{Int} + function DebugRecorder(sim_length) + return new( + zeros(Int,AgentStatus.size,sim_length), + zeros(Int,sim_length), + zeros(Int,sim_length), + zeros(Int,sim_length), + zeros(Int,sim_length), + zeros(Int,sim_length), + ) + end +end + +function record!(t,modelsol, recorder::DebugRecorder) + recorder.Total_S[t] = count(==(Susceptible),modelsol.u_inf) + recorder.Total_I[t] = count(==(Infected),modelsol.u_inf) + recorder.Total_R[t] = count(==(Recovered),modelsol.u_inf) + recorder.Total_V[t] = count(==(Immunized),modelsol.u_inf) + recorder.Total_Vaccinator[t] = count(==(true),modelsol.u_vac) + + recorder.recorded_status_totals[:,t] .= modelsol.status_totals +end + +function record!(t,modelsol, recorder::Nothing) + #do nothing +end diff --git a/CovidAlertVaccinationModel/src/ABM/solve.jl b/CovidAlertVaccinationModel/src/ABM/solve.jl index 0ccdc1c7cca4e8696ab0cf86dd9dfa2bb8b34957..817aa6a9c4ed9f1f7edb84d143e2409436975345 100644 --- a/CovidAlertVaccinationModel/src/ABM/solve.jl +++ b/CovidAlertVaccinationModel/src/ABM/solve.jl @@ -2,14 +2,7 @@ function contact_weight(p, contact_time) return 1 - (1-p)^contact_time end - - -function EN_payoff() - return 0.0 -end - -function update_alert_durations!(t,modelsol) - +function update_alert_durations!(t,modelsol) @unpack notification_parameter = modelsol.params @unpack time_of_last_alert, app_user_index,inf_network_lists,covid_alert_times,app_user = modelsol @@ -30,31 +23,39 @@ function update_alert_durations!(t,modelsol) end function update_infection_state!(t,modelsol) @unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params - @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists = modelsol - + @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists,status_totals,status_totals_next = modelsol + + function agent_transition!(node, from::AgentStatus,to::AgentStatus) + status_totals_next[Int(from)] -= 1 + status_totals_next[Int(to)] += 1 + u_next_inf[node] = to + end + + u_next_inf .= u_inf + status_totals_next .= status_totals for i in 1:modelsol.nodes agent_status = u_inf[i] is_vaccinator = u_vac[i] agent_demo = demographics[i] if agent_status == Susceptible if is_vaccinator - u_next_inf[i] = Immunized + agent_transition!(i, Susceptible,Immunized) else for mixing_graph in inf_network_lists[t] for j in neighbors(mixing_graph.g,i) if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights[(i,j)]) - u_next_inf[i] = Infected + agent_transition!(i, Susceptible,Infected) end end end end elseif agent_status == Infected if rand(RNG) < recovery_rate - u_next_inf[i] = Recovered + agent_transition!(i, Infected,Recovered) end elseif agent_status == Immunized if rand(RNG) < immunization_loss_prob - u_next_inf[i] = Susceptible + agent_transition!(i, Immunized,Susceptible) end end end @@ -64,6 +65,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections) @unpack Ï€_base, η,γ, κ, ω, Ï, ω_en,Ï_en,γ,β = modelsol.params @unpack demographics,time_of_last_alert, nodes, soc_networks,u_vac,u_next_vac,app_user,app_user_list = modelsol app_user_pointer = 0 + + for i in 1:nodes vac_payoff = 0 soc_nbrs_vac = [0,0,0] @@ -80,20 +83,24 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections) end end end - vac_payoff += Ï€_base + dot(Ï,soc_nbrs_vac) + total_infections*ω + ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0) + vac_payoff += Ï€_base + dot(Ï,soc_nbrs_vac) + total_infections*ω + + ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0) if app_user[i] && time_of_last_alert[app_user_list[i]]>=0 vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]]))* (η + dot(Ï_en,soc_nbrs_vac) + total_infections*ω_en) end + if u_vac[i] if rand(RNG) < 1 - Φ(vac_payoff,β) - # display("$i switch") u_next_vac[i] = !u_vac[i] + else + u_next_vac[i] = u_vac[i] end else if rand(RNG) < Φ(vac_payoff,β) - # display("$i switchback") u_next_vac[i] = !u_vac[i] + else + u_next_vac[i] = u_vac[i] end end end @@ -102,24 +109,29 @@ end function agents_step!(t,modelsol) for network in modelsol.inf_network_lists[t] - sample_mixing_graph!(network,modelsol.demographics, contact_time_distribution_matrix) #get new contact weights + sample_mixing_graph!(network,modelsol.demographics) #get new contact weights end + update_alert_durations!(t,modelsol) - update_vaccination_opinion_state!(t,modelsol,count(==(Infected),modelsol.u_inf)) + update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)]) update_infection_state!(t,modelsol) + modelsol.u_vac .= modelsol.u_next_vac modelsol.u_inf .= modelsol.u_next_inf + modelsol.status_totals .= modelsol.status_totals_next end -function solve!(modelsol,recording_function) +function solve!(modelsol,recording) for t in 1:modelsol.sim_length #advance agent states based on the new network agents_step!(t,modelsol) - display((count(==(true),modelsol.u_vac),count(==(Immunized),modelsol.u_inf))) + + record!(t,modelsol,recording) + end - # return solution, network_lists + return recording end diff --git a/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl b/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl index 20f401c7cbf3f65fca0dd78731ec9c63ccbecda1..6ef2ce693d8f2443d7add25c39352c0b7a0e54e1 100644 --- a/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl +++ b/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl @@ -8,6 +8,7 @@ using Distributions using StatsBase using Dates using LinearAlgebra +using CovidAlertVaccinationModel using ThreadsX using DelimitedFiles using KernelDensity @@ -15,6 +16,7 @@ using NamedTupleTools using NetworkLayout:Stress using NetworkLayout:SFDP using ZeroWeightedDistributions +using DataStructures using Serialization using BenchmarkTools using Intervals @@ -31,13 +33,15 @@ const RNG = Xoroshiro128Star(1) const color_palette = palette(:seaborn_pastel) #color theme for the plots include("utils.jl") include("data.jl") -include("ABM/abm.jl") include("ABM/agents.jl") include("ABM/mixing_distributions.jl") include("ABM/mixing_graphs.jl") include("ABM/plotting.jl") include("ABM/model_setup.jl") +include("ABM/output.jl") include("ABM/solve.jl") +include("ABM/abm.jl") + include("IntervalsModel/intervals_model.jl") include("IntervalsModel/interval_overlap_sampling.jl") include("IntervalsModel/hh_durations_model.jl") diff --git a/CovidAlertVaccinationModel/src/data.jl b/CovidAlertVaccinationModel/src/data.jl index 9540d82a5548e00569bdefc2368ed6c0574f0233..907ba53f3b415a8c81fe8bc84376242e65bb5baa 100644 --- a/CovidAlertVaccinationModel/src/data.jl +++ b/CovidAlertVaccinationModel/src/data.jl @@ -101,8 +101,17 @@ function load_mixing_matrices() end function load_contact_time_distributions() - dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output/simulation_output/hh.dat")) - return dat + distkey = "Distributions.Poisson" + fnames = ( + hh = "hh", + ws = "ws", + rest = "rest" + ) + contact_distributions_tuple = map(fnames) do fname + dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output","simulation_output","$fname.dat")) + return map(p -> Poisson(mode(p.particles)), as_symmetric_matrix(dat[distkey].P)) + end + return contact_distributions_tuple end """ Load rest data from `data/canada-network-data/Timeuse/Rest/RData`. diff --git a/CovidAlertVaccinationModel/test/ABM/abm.jl b/CovidAlertVaccinationModel/test/ABM/abm_test.jl similarity index 77% rename from CovidAlertVaccinationModel/test/ABM/abm.jl rename to CovidAlertVaccinationModel/test/ABM/abm_test.jl index 8d8970e6161aefd613dec8747548f777fe3f208b..e35004d11d1889999778fe83dc46dcc50891a2f5 100644 --- a/CovidAlertVaccinationModel/test/ABM/abm.jl +++ b/CovidAlertVaccinationModel/test/ABM/abm_test.jl @@ -1,9 +1,5 @@ -model_sizes = [100,1000,5000] -vaccination_strategies = [vaccinate_uniformly!] -vaccination_rates = [0.000,0.005,0.01,0.05] -infection_rates = [0.01,0.05,0.1] -agent_models = ThreadsX.map(model_size -> AgentModel(model_size...), model_sizes) -dem_cat = AgentDemographic.size -1 +const model_sizes = [100,1000,5000] +const dem_cat = AgentDemographic.size -1 #network generation @@ -26,9 +22,8 @@ dem_cat = AgentDemographic.size -1 end -function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus()) - u_0 = get_u_0(length(model.demographics)) - params = merge(get_parameters(),(vaccines_per_day = vac_rate,)) +function vaccinator_opinion_test(model,vac_strategy, Ï€_base; rng = Xoroshiro128Plus()) + params = merge(get_parameters(),(I_0_fraction = 0.0,Ï€_base)) steps = 300 sol1,_ = solve!(u_0,params,steps,model,vac_strategy); total_infections = count(x->x == AgentStatus(3),sol1[end]) @@ -38,7 +33,6 @@ end function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus()) - params = merge(get_parameters(),(p = inf_parameter,)) steps = 300 # display(params) sol1,_ = solve!(params,steps,model,vaccinate_uniformly!); @@ -65,3 +59,12 @@ end @test test_comparison(x->infection_rate_test(m,x),infection_rates,<) end + +@testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes) + @test test_comparison(x->infection_rate_test(m,x),infection_rates,<) +end + + + + + diff --git a/CovidAlertVaccinationModel/test/runtests.jl b/CovidAlertVaccinationModel/test/runtests.jl index e662c5b70b9adfc5e969a4abdfd04772e744acf7..4bfd2b0987477a5577acb08c2006dbb94b053f12 100644 --- a/CovidAlertVaccinationModel/test/runtests.jl +++ b/CovidAlertVaccinationModel/test/runtests.jl @@ -4,4 +4,4 @@ using Test using ThreadsX import StatsBase.mean -include("ABM/abm.jl") \ No newline at end of file +include("ABM/abm_test.jl") \ No newline at end of file diff --git a/ZeroWeightedDistributions/Project.toml b/ZeroWeightedDistributions/Project.toml index 8597f5b54ffbc568ef9ab6b4e1e79172c0969af7..f246a6e0b1d65accb6db9a9545ea2cc80f7a39dc 100644 --- a/ZeroWeightedDistributions/Project.toml +++ b/ZeroWeightedDistributions/Project.toml @@ -5,8 +5,12 @@ version = "0.1.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" [compat] julia = "1" diff --git a/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl b/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl index f346fe55621be7c856a0998d1fa2debdcd717702..b9dc8fe8a16c5c022818d850d22de8226ba4a3d8 100644 --- a/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl +++ b/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl @@ -21,10 +21,6 @@ struct ZWDist{BaseDistType <: Sampleable,T} <: Sampleable{Univariate,T} return new{DType,S}(α,DType(p...)) end end - -# function from_mean(::Type{ZWDist{DistType,T}},α,μ) where {DistType <: Distribution{Univariate,W} where W, T} -# return ZWDist(α,from_mean(DistType,μ/(1-α))) -# end StatsBase.mean(d::ZWDist{Dist,T}) where {Dist,T} = (1 - d.α)*StatsBase.mean(d.base_dist) @@ -32,7 +28,7 @@ function Distributions.pdf(d::ZWDist, x) if x == 0 return d.α + (1-d.α)*pdf(d.base_dist,0) else - return pdf(d.base_dist,x) + return (1-d.α)*pdf(d.base_dist,x) end end @@ -40,22 +36,6 @@ function Distributions.mean(d::ZWDist, x) return (1-d.α)*StatsBase.mean(d.base_dist,0) end - -function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}, n::Int) where {DType, S} - l = Vector{eltype(DType)}(undef,n) - Random.rand!(rng,l) - l[l .< s.α] .= zero(eltype(DType)) - Random.rand!(rng,s.base_dist,@view l[l .>= s.α]) - return l -end - -function Random.rand!(rng::AbstractRNG, s::ZWDist{DType,S}, l::T) where {T<:AbstractVector, DType,S} - Random.rand!(rng,l) - l[l .< s.α] .= zero(eltype(DType)) - Random.rand!(rng,s.base_dist,@view l[l .>= s.α]) - return l -end - function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}) where {DType,S} return ifelse(Base.rand(rng) < s.α, zero(eltype(DType)), Base.rand(rng,s.base_dist)) end diff --git a/ZeroWeightedDistributions/test/runtests.jl b/ZeroWeightedDistributions/test/runtests.jl index 3a2ef2aee5c205dcb6144a505d4af35d5b1a0231..15633ee1e1f5777e20e109b7e4528bce9b424692 100644 --- a/ZeroWeightedDistributions/test/runtests.jl +++ b/ZeroWeightedDistributions/test/runtests.jl @@ -4,41 +4,37 @@ using Distributions using Test using ThreadsX using RandomNumbers.Xorshifts -using Plots -# const dist_list = [Poisson,Geometric] -# const params_list = collect(0.01:0.5:0.9) -# const α_list = collect(0.0:0.1:1.0) -# const RNG = Xoroshiro128Star(1) -# const tol = 1e-1 -function test_dist(d) - # samples_vec = rand(RNG,d, 100) - # display(samples_vec) - # samples_map = map(x->rand(RNG,d), 1:10_000) - # epdf = kde(samples_vec) - # vec_err = all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol) - # epdf = kde(samples_map) - # map_err= all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol) - # # displlay((vec_err,map_err)) - - # p = scatter([pdf(epdf,k) for k in 0.0:0.1:10]) - # scatter!(p,[pdf(d,k) for k in 0.0:0.1:10]) - # display(p) - # return vec_err && map_err -end - +# using Plots +const dist_list = [Poisson] +const params_list = collect(0.1:0.2:0.9) +const α_list = collect(0.2:0.2:1.0) +const RNG = Xoroshiro128Star(1) +const tol = 0.01 +const sample_length = 1_000_000 +epdf(samples,k) = count(==(k), samples)/sample_length - -function test() - map(Iterators.product(dist_list,params_list,α_list)) do (d,p,α) - # println((d,p,α)) - test_dist(ZWDist(d,α,p)) - end +function test_dist_vec(d) + samples_vec = rand(RNG,d, sample_length) + vec_err = all(abs(epdf(samples_vec,k) - pdf(d,k)) < tol for k in 0:1:10) + return vec_err +end +function test_dist_map(d) + samples_map = map(x->rand(RNG,d), 1:sample_length) + map_err= all(abs(epdf(samples_map,k) - pdf(d,k)) < tol for k in 0:1:10) + return map_err end -test() -# @testset "ZeroWeightedDistributions.jl" begin +@testset "ZeroWeightedDistributions.jl" begin -# @test -# end + @testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list) + println((d,p,α)) + @test test_dist_vec(ZWDist(d,α,p)) + end + + @testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list) + println((d,p,α)) + @test test_dist_map(ZWDist(d,α,p)) + end +end