diff --git a/CovidAlertVaccinationModel/src/ABM/abm.jl b/CovidAlertVaccinationModel/src/ABM/abm.jl
index 250c1d0c53d6f3ce5735952407ea5c194ddda8c7..084fd4d9b50bec979248cd63be3c53e967553a1a 100644
--- a/CovidAlertVaccinationModel/src/ABM/abm.jl
+++ b/CovidAlertVaccinationModel/src/ABM/abm.jl
@@ -8,10 +8,10 @@ default(framestyle = :box)
 
 
 function bench()
-
     steps = 100
     model_sol = ModelSolution(steps,get_parameters(),5000);
-    solve!(model_sol,identity)
+    recording = recorder(steps)
+    output = solve!(model_sol,DebugRecorder)
 end
 
 function abm()
diff --git a/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl b/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl
index 10f045b8f82526f321d8bec447ee80b7021acfcd..8b4ad25750ae22a5cbba2342a939d09c259007e2 100644
--- a/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl
+++ b/CovidAlertVaccinationModel/src/ABM/mixing_distributions.jl
@@ -17,7 +17,7 @@ end
 
 
 
-const contact_time_distribution_matrix = [Geometric() for i in 1:(AgentDemographic.size-1), j in 1:(AgentDemographic.size-1)]
+const contact_time_distributions = load_contact_time_distributions()
 
 
 function alpha_matrix(alphas)
diff --git a/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl b/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl
index 1c26f96739b4f560e63541ba45d6e20fcfb203da..c038e83329d3754a8e7f239ef4ecd647637ae913 100644
--- a/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl
+++ b/CovidAlertVaccinationModel/src/ABM/mixing_graphs.jl
@@ -31,59 +31,49 @@ function random_bipartite_graph_fast_CL!(g::SimpleGraph,anodes,bnodes,aseq,bseq,
     return g    
 end
 
-using DataStructures
-struct MixingGraph{G, V, M}
+
+#defining my own weighted graph type cause we need to be able to resample the edge weights in a particular way
+struct WeightedGraph{G, V,M} 
     g::G
     weights::V
-    covid_alert_time::M
-    function MixingGraph(demographics,demographic_index_vectors,mixing_matrix)
+    weights_distribution_matrix::M
+    function WeightedGraph(demographics,demographic_index_vectors,mixing_matrix,weights_distribution_matrix)
         
         contacts = MixingContacts(demographic_index_vectors,mixing_matrix)
         
         g = Graph(length(demographics))
         
         weights = RobinDict{Tuple{Int,Int},UInt8}()
+
         for i in 1:length(demographic_index_vectors), j in 1:i  #diagonal
             random_bipartite_graph_fast_CL!(g,demographic_index_vectors[i],demographic_index_vectors[j],contacts.contact_array[i,j],contacts.contact_array[j,i],weights)
         end
 
         covid_alert_time = zeros(nv(g))
-        return new{typeof(g),typeof(weights),typeof(covid_alert_time)}(
+        return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}(
             g,
             weights,
-            covid_alert_time
+            weights_distribution_matrix
         )
     end
-    function MixingGraph(g::SimpleGraph)
+    function WeightedGraph(g::SimpleGraph,weights_distribution_matrix)
         weights = RobinDict{Tuple{Int,Int},UInt8}()
-        covid_alert_time = zeros(nv(g))
-        return new{typeof(g),typeof(weights),typeof(covid_alert_time)}(
+        return new{typeof(g),typeof(weights),typeof(weights_distribution_matrix)}(
             g,
             weights,
-            covid_alert_time
+            weights_distribution_matrix
         )
     end
 end
 
-function MixingGraph(g::SimpleGraph,demographics,contact_time_distribution_matrix)
-    mg = MixingGraph(g)
-    return mg
-end
-function MixingGraph(demographics,demographic_index_vectors,mixing_matrix,contact_time_distribution_matrix)
-    mg = MixingGraph(demographics,demographic_index_vectors,mixing_matrix)
-    return mg
-end
-
-function sample_mixing_graph!(mixing_graph,population_demographics, contact_time_distributions)
+function sample_mixing_graph!(mixing_graph,population_demographics)
     for (k,e) in enumerate(edges(mixing_graph.g))
         i = src(e)
         j = dst(e)
         demo_i = Int(population_demographics[i])
         demo_j = Int(population_demographics[j])
 
-        contact_time = rand(RNG, contact_time_distributions[demo_i,demo_j])
-        mixing_graph.covid_alert_time[i] += contact_time
-        mixing_graph.covid_alert_time[j] += contact_time
+        contact_time = rand(RNG, mixing_graph.weights_distribution_matrix[demo_i,demo_j])
         mixing_graph.weights[(i,j)] = contact_time
         mixing_graph.weights[(j,i)] = contact_time
     end
@@ -116,7 +106,7 @@ function generate_contact_vectors!(ij_dist,ji_dist,i_to_j_contacts::Vector{T}, j
         sample!(RNG,1:l_j,index_list_j)
         rand!(RNG,ij_dist,sample_list_i)
         rand!(RNG,ji_dist,sample_list_j)
-        @inbounds @simd for i = 1:inner_iter
+        @inbounds for i = 1:inner_iter
             if csum != 0
                 csum = reindex!(i,csum,index_list_i,index_list_j,j_to_i_contacts,i_to_j_contacts,sample_list_i,sample_list_j)
             end
diff --git a/CovidAlertVaccinationModel/src/ABM/model_setup.jl b/CovidAlertVaccinationModel/src/ABM/model_setup.jl
index efb69fb0ecfba1f3020ca6705e61c4348b0afaba..6d623e82ac61196f3e620148b7b7513459732cc7 100644
--- a/CovidAlertVaccinationModel/src/ABM/model_setup.jl
+++ b/CovidAlertVaccinationModel/src/ABM/model_setup.jl
@@ -1,11 +1,11 @@
 
 function get_parameters()
     params = (
-        I_0_fraction = 0.0,
+        I_0_fraction = 0.05,
         base_transmission_probability = 0.5,
         recovery_rate = 0.1,
-        immunization_loss_prob = 0.01,
-        π_base = -0.1,
+        immunization_loss_prob = 0.5,
+        π_base = -0.05,
         η = 0.0,
         κ = 0.0,
         ω = 0.0,
@@ -13,7 +13,7 @@ function get_parameters()
         ω_en = 0.0,
         ρ_en =  [0.0,0.0,0.0],
         γ = 0.0,
-        β = 1000.0,
+        β = 10.0,
         notification_parameter = 0.0,
         vaccinator_prob = 0.5,
         app_user_fraction = 0.5,
@@ -48,7 +48,7 @@ struct ModelSolution{T,G}
     app_user_list::Vector{Int}
     app_user_index::Vector{Int}
     status_totals::Vector{Int}
-
+    status_totals_next::Vector{Int}
     function ModelSolution(sim_length,params::T,num_households) where T
         demographics,base_network,index_vectors = generate_population(num_households)
         pop_sizes = length.(index_vectors)
@@ -66,14 +66,14 @@ struct ModelSolution{T,G}
         
         u_0_inf,u_0_vac = get_u_0(nodes,params.I_0_fraction,params.vaccinator_prob)
 
-        home_static_edges = MixingGraph(base_network,demographics,contact_time_distribution_matrix) #network with households and LTC homes
-        ws_static_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distribution_matrix)
-        ws_weekly_edges = MixingGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distribution_matrix)
-        ws_daily_edges_vector = [MixingGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length]
+        home_static_edges = WeightedGraph(base_network,contact_time_distributions.hh) #network with households and LTC homes
+        ws_static_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.daily,contact_time_distributions.ws)
+        ws_weekly_edges = WeightedGraph(demographics,index_vectors,ws_matrix_list.twice_a_week,contact_time_distributions.ws)
+        ws_daily_edges_vector = [WeightedGraph(demographics,index_vectors,ws_matrix_list.otherwise,contact_time_distributions.ws) for i in 1:sim_length]
     
-        rest_static_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distribution_matrix)
-        rest_weekly_edges = MixingGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distribution_matrix)
-        rest_daily_edges_vector = [MixingGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distribution_matrix) for i in 1:sim_length]
+        rest_static_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.daily,contact_time_distributions.rest)
+        rest_weekly_edges = WeightedGraph(demographics,index_vectors,rest_matrix_list.twice_a_week,contact_time_distributions.rest)
+        rest_daily_edges_vector = [WeightedGraph(demographics,index_vectors,rest_matrix_list.otherwise,contact_time_distributions.rest) for i in 1:sim_length]
 
         inf_network_lists = [
             [home_static_edges,rest_static_edges] for i in 1:sim_length
@@ -96,6 +96,14 @@ struct ModelSolution{T,G}
         covid_alert_times = zeros(Int,length(app_user_index),14) #two weeks worth of values
         time_of_last_alert = fill(-1,length(app_user_index)) #two weeks worth of values
 
+        status_totals = zeros(Int, AgentStatus.size)
+
+        status_totals[1] = count(==(Susceptible), u_0_inf)
+        status_totals[2] = count(==(Infected), u_0_inf)
+        status_totals[3] = count(==(Recovered), u_0_inf)
+        status_totals[4] = count(==(Immunized), u_0_inf)
+    
+
         return new{T,typeof(home_static_edges)}(
             sim_length,
             nodes,
@@ -112,7 +120,9 @@ struct ModelSolution{T,G}
             demographics,
             is_app_user,
             app_user_list,
-            app_user_index
+            app_user_index,
+            status_totals,
+            copy(status_totals)
         )
     end
 end
diff --git a/CovidAlertVaccinationModel/src/ABM/output.jl b/CovidAlertVaccinationModel/src/ABM/output.jl
index dbd0bbf51183a4468efcfcc682ba62d9dd1322f8..89229449ab599b34c31f99beb3af47c8417c03a7 100644
--- a/CovidAlertVaccinationModel/src/ABM/output.jl
+++ b/CovidAlertVaccinationModel/src/ABM/output.jl
@@ -1,5 +1,37 @@
-abstract type AbstractOutputData end
+#needlessly overwrought output interface
 
-struct DebugOutputData <: AbstractOutputData
-    TotalInfected::
-end
\ No newline at end of file
+
+abstract type AbstractRecorder end
+
+struct DebugRecorder <: AbstractRecorder
+    recorded_status_totals::Array{Int,2}
+    Total_S::Vector{Int}
+    Total_I::Vector{Int}
+    Total_R::Vector{Int}
+    Total_V::Vector{Int}
+    Total_Vaccinator::Vector{Int}
+    function DebugRecorder(sim_length)
+        return new(
+            zeros(Int,AgentStatus.size,sim_length),
+            zeros(Int,sim_length),
+            zeros(Int,sim_length),
+            zeros(Int,sim_length),
+            zeros(Int,sim_length),
+            zeros(Int,sim_length),
+        )
+    end
+end
+
+function record!(t,modelsol, recorder::DebugRecorder)
+    recorder.Total_S[t] = count(==(Susceptible),modelsol.u_inf)
+    recorder.Total_I[t] = count(==(Infected),modelsol.u_inf)
+    recorder.Total_R[t] = count(==(Recovered),modelsol.u_inf)
+    recorder.Total_V[t] = count(==(Immunized),modelsol.u_inf)
+    recorder.Total_Vaccinator[t] = count(==(true),modelsol.u_vac)
+
+    recorder.recorded_status_totals[:,t] .= modelsol.status_totals
+end
+
+function record!(t,modelsol, recorder::Nothing)
+    #do nothing
+end
diff --git a/CovidAlertVaccinationModel/src/ABM/solve.jl b/CovidAlertVaccinationModel/src/ABM/solve.jl
index 0ccdc1c7cca4e8696ab0cf86dd9dfa2bb8b34957..817aa6a9c4ed9f1f7edb84d143e2409436975345 100644
--- a/CovidAlertVaccinationModel/src/ABM/solve.jl
+++ b/CovidAlertVaccinationModel/src/ABM/solve.jl
@@ -2,14 +2,7 @@
 function contact_weight(p, contact_time) 
     return 1 - (1-p)^contact_time
 end
-
-
-function EN_payoff()
-    return 0.0
-end
-
-function update_alert_durations!(t,modelsol)
-    
+function update_alert_durations!(t,modelsol)  
     @unpack notification_parameter = modelsol.params
     @unpack time_of_last_alert, app_user_index,inf_network_lists,covid_alert_times,app_user = modelsol
 
@@ -30,31 +23,39 @@ function update_alert_durations!(t,modelsol)
 end
 function update_infection_state!(t,modelsol)
     @unpack base_transmission_probability,immunization_loss_prob,recovery_rate = modelsol.params
-    @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists = modelsol
-
+    @unpack u_inf,u_vac,u_next_inf,u_next_vac,demographics,inf_network_lists,status_totals,status_totals_next = modelsol
+    
+    function agent_transition!(node, from::AgentStatus,to::AgentStatus) 
+        status_totals_next[Int(from)] -= 1
+        status_totals_next[Int(to)] += 1
+        u_next_inf[node] = to
+    end
+    
+    u_next_inf .= u_inf
+    status_totals_next .= status_totals
     for i in 1:modelsol.nodes
         agent_status = u_inf[i]
         is_vaccinator = u_vac[i]
         agent_demo = demographics[i]
         if agent_status == Susceptible
             if is_vaccinator
-                u_next_inf[i] = Immunized
+                agent_transition!(i, Susceptible,Immunized)
             else
                 for mixing_graph in inf_network_lists[t]
                     for j in neighbors(mixing_graph.g,i)    
                         if u_inf[j] == Infected && rand(RNG) < contact_weight(base_transmission_probability,mixing_graph.weights[(i,j)])
-                            u_next_inf[i] = Infected
+                            agent_transition!(i, Susceptible,Infected)
                         end
                     end
                 end
             end
         elseif agent_status == Infected
             if rand(RNG) < recovery_rate
-                u_next_inf[i] = Recovered
+               agent_transition!(i, Infected,Recovered)
             end
         elseif agent_status == Immunized
             if rand(RNG) < immunization_loss_prob
-                u_next_inf[i] = Susceptible
+               agent_transition!(i, Immunized,Susceptible)
             end
         end
     end
@@ -64,6 +65,8 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
     @unpack π_base, η,γ, κ, ω, ρ, ω_en,ρ_en,γ,β = modelsol.params
     @unpack demographics,time_of_last_alert, nodes, soc_networks,u_vac,u_next_vac,app_user,app_user_list = modelsol
     app_user_pointer = 0
+
+
     for i in 1:nodes
         vac_payoff = 0
         soc_nbrs_vac = [0,0,0]
@@ -80,20 +83,24 @@ function update_vaccination_opinion_state!(t,modelsol,total_infections)
                 end
             end
         end
-        vac_payoff += π_base + dot(ρ,soc_nbrs_vac) + total_infections*ω + ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0)
+        vac_payoff += π_base + dot(ρ,soc_nbrs_vac) + total_infections*ω +
+            ifelse(num_soc_nbrs> 0, κ * ((sum(soc_nbrs_vac) - soc_nbrs_nonvac/num_soc_nbrs)),0)
         
         if app_user[i] && time_of_last_alert[app_user_list[i]]>=0
             vac_payoff += γ^(-1*(t - time_of_last_alert[app_user_list[i]]))* (η + dot(ρ_en,soc_nbrs_vac) + total_infections*ω_en) 
         end
+        
         if u_vac[i]
             if rand(RNG) < 1 - Φ(vac_payoff,β)
-                # display("$i switch")
                 u_next_vac[i] = !u_vac[i]
+            else
+                u_next_vac[i] = u_vac[i]
             end
         else
             if rand(RNG) < Φ(vac_payoff,β)
-                # display("$i switchback")
                 u_next_vac[i] = !u_vac[i]
+            else
+                u_next_vac[i] = u_vac[i]
             end
         end
     end
@@ -102,24 +109,29 @@ end
 
 function agents_step!(t,modelsol)
     for network in modelsol.inf_network_lists[t]
-        sample_mixing_graph!(network,modelsol.demographics, contact_time_distribution_matrix) #get new contact weights
+        sample_mixing_graph!(network,modelsol.demographics) #get new contact weights
     end
+
     update_alert_durations!(t,modelsol)
-    update_vaccination_opinion_state!(t,modelsol,count(==(Infected),modelsol.u_inf))
+    update_vaccination_opinion_state!(t,modelsol,modelsol.status_totals[Int(Infected)])
     update_infection_state!(t,modelsol)
+
     modelsol.u_vac .= modelsol.u_next_vac
     modelsol.u_inf .= modelsol.u_next_inf
+    modelsol.status_totals .= modelsol.status_totals_next
 end
 
 
 
-function solve!(modelsol,recording_function)
+function solve!(modelsol,recording)
     for t in 1:modelsol.sim_length
         #advance agent states based on the new network
         agents_step!(t,modelsol) 
-        display((count(==(true),modelsol.u_vac),count(==(Immunized),modelsol.u_inf)))
+
+        record!(t,modelsol,recording)
+
     end
-    # return solution, network_lists
+    return recording
 end
 
 
diff --git a/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl b/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl
index 20f401c7cbf3f65fca0dd78731ec9c63ccbecda1..6ef2ce693d8f2443d7add25c39352c0b7a0e54e1 100644
--- a/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl
+++ b/CovidAlertVaccinationModel/src/CovidAlertVaccinationModel.jl
@@ -8,6 +8,7 @@ using Distributions
 using StatsBase
 using Dates
 using LinearAlgebra
+using CovidAlertVaccinationModel
 using ThreadsX
 using DelimitedFiles
 using KernelDensity
@@ -15,6 +16,7 @@ using NamedTupleTools
 using NetworkLayout:Stress
 using NetworkLayout:SFDP
 using ZeroWeightedDistributions
+using DataStructures
 using Serialization
 using BenchmarkTools
 using Intervals
@@ -31,13 +33,15 @@ const RNG = Xoroshiro128Star(1)
 const color_palette = palette(:seaborn_pastel) #color theme for the plots
 include("utils.jl")
 include("data.jl")
-include("ABM/abm.jl")
 include("ABM/agents.jl")
 include("ABM/mixing_distributions.jl")
 include("ABM/mixing_graphs.jl")
 include("ABM/plotting.jl")
 include("ABM/model_setup.jl")
+include("ABM/output.jl")
 include("ABM/solve.jl")
+include("ABM/abm.jl")
+
 include("IntervalsModel/intervals_model.jl")
 include("IntervalsModel/interval_overlap_sampling.jl")
 include("IntervalsModel/hh_durations_model.jl")
diff --git a/CovidAlertVaccinationModel/src/data.jl b/CovidAlertVaccinationModel/src/data.jl
index 9540d82a5548e00569bdefc2368ed6c0574f0233..907ba53f3b415a8c81fe8bc84376242e65bb5baa 100644
--- a/CovidAlertVaccinationModel/src/data.jl
+++ b/CovidAlertVaccinationModel/src/data.jl
@@ -101,8 +101,17 @@ function load_mixing_matrices()
 end
 
 function load_contact_time_distributions()
-    dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output/simulation_output/hh.dat"))
-    return dat
+    distkey = "Distributions.Poisson"
+    fnames = (
+        hh = "hh",
+        ws = "ws",
+        rest = "rest"
+    )
+    contact_distributions_tuple = map(fnames) do fname
+        dat = deserialize(joinpath(PACKAGE_FOLDER,"intervals_model_output","simulation_output","$fname.dat"))
+        return map(p -> Poisson(mode(p.particles)), as_symmetric_matrix(dat[distkey].P))
+    end
+    return contact_distributions_tuple
 end
 """
 Load rest data from `data/canada-network-data/Timeuse/Rest/RData`.
diff --git a/CovidAlertVaccinationModel/test/ABM/abm.jl b/CovidAlertVaccinationModel/test/ABM/abm_test.jl
similarity index 77%
rename from CovidAlertVaccinationModel/test/ABM/abm.jl
rename to CovidAlertVaccinationModel/test/ABM/abm_test.jl
index 8d8970e6161aefd613dec8747548f777fe3f208b..e35004d11d1889999778fe83dc46dcc50891a2f5 100644
--- a/CovidAlertVaccinationModel/test/ABM/abm.jl
+++ b/CovidAlertVaccinationModel/test/ABM/abm_test.jl
@@ -1,9 +1,5 @@
-model_sizes = [100,1000,5000]
-vaccination_strategies = [vaccinate_uniformly!]
-vaccination_rates = [0.000,0.005,0.01,0.05]
-infection_rates = [0.01,0.05,0.1]
-agent_models = ThreadsX.map(model_size -> AgentModel(model_size...), model_sizes)
-dem_cat = AgentDemographic.size -1
+const model_sizes = [100,1000,5000]
+const dem_cat = AgentDemographic.size -1
 
 
 #network generation
@@ -26,9 +22,8 @@ dem_cat = AgentDemographic.size -1
 end
 
 
-function vac_rate_test(model,vac_strategy, vac_rate; rng = Xoroshiro128Plus()) 
-    u_0 = get_u_0(length(model.demographics))
-    params = merge(get_parameters(),(vaccines_per_day = vac_rate,))
+function vaccinator_opinion_test(model,vac_strategy, π_base; rng = Xoroshiro128Plus()) 
+    params = merge(get_parameters(),(I_0_fraction = 0.0,Ï€_base))
     steps = 300
     sol1,_ = solve!(u_0,params,steps,model,vac_strategy);
     total_infections = count(x->x == AgentStatus(3),sol1[end])
@@ -38,7 +33,6 @@ end
 
 
 function infection_rate_test(model, inf_parameter; rng = Xoroshiro128Plus()) 
-    params = merge(get_parameters(),(p = inf_parameter,))
     steps = 300
     # display(params)
     sol1,_ = solve!(params,steps,model,vaccinate_uniformly!);
@@ -65,3 +59,12 @@ end
     @test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
 end
 
+
+@testset "infection efficacy $sz" for (m,sz) in zip(deepcopy(agent_models),model_sizes)
+    @test test_comparison(x->infection_rate_test(m,x),infection_rates,<)
+end
+
+
+
+
+
diff --git a/CovidAlertVaccinationModel/test/runtests.jl b/CovidAlertVaccinationModel/test/runtests.jl
index e662c5b70b9adfc5e969a4abdfd04772e744acf7..4bfd2b0987477a5577acb08c2006dbb94b053f12 100644
--- a/CovidAlertVaccinationModel/test/runtests.jl
+++ b/CovidAlertVaccinationModel/test/runtests.jl
@@ -4,4 +4,4 @@ using Test
 using ThreadsX
 import StatsBase.mean
 
-include("ABM/abm.jl")
\ No newline at end of file
+include("ABM/abm_test.jl")
\ No newline at end of file
diff --git a/ZeroWeightedDistributions/Project.toml b/ZeroWeightedDistributions/Project.toml
index 8597f5b54ffbc568ef9ab6b4e1e79172c0969af7..f246a6e0b1d65accb6db9a9545ea2cc80f7a39dc 100644
--- a/ZeroWeightedDistributions/Project.toml
+++ b/ZeroWeightedDistributions/Project.toml
@@ -5,8 +5,12 @@ version = "0.1.0"
 
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
 
 [compat]
 julia = "1"
diff --git a/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl b/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl
index f346fe55621be7c856a0998d1fa2debdcd717702..b9dc8fe8a16c5c022818d850d22de8226ba4a3d8 100644
--- a/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl
+++ b/ZeroWeightedDistributions/src/ZeroWeightedDistributions.jl
@@ -21,10 +21,6 @@ struct ZWDist{BaseDistType <: Sampleable,T} <: Sampleable{Univariate,T}
         return new{DType,S}(α,DType(p...))
     end
 end
-
-# function from_mean(::Type{ZWDist{DistType,T}},α,μ) where {DistType <: Distribution{Univariate,W} where W, T}
-#     return ZWDist(α,from_mean(DistType,μ/(1-α)))
-# end
 StatsBase.mean(d::ZWDist{Dist,T}) where {Dist,T} = (1 - d.α)*StatsBase.mean(d.base_dist)
 
 
@@ -32,7 +28,7 @@ function Distributions.pdf(d::ZWDist, x)
     if x == 0 
         return d.α + (1-d.α)*pdf(d.base_dist,0)
     else
-        return pdf(d.base_dist,x)
+        return (1-d.α)*pdf(d.base_dist,x)
     end
 end
 
@@ -40,22 +36,6 @@ function Distributions.mean(d::ZWDist, x)
     return (1-d.α)*StatsBase.mean(d.base_dist,0)
 end
 
-
-function Base.rand(rng::AbstractRNG,  s::ZWDist{DType,S},  n::Int) where {DType, S}
-    l = Vector{eltype(DType)}(undef,n)
-    Random.rand!(rng,l)
-    l[l .< s.α] .= zero(eltype(DType))
-    Random.rand!(rng,s.base_dist,@view l[l .>= s.α])
-    return l
-end
-
-function Random.rand!(rng::AbstractRNG, s::ZWDist{DType,S}, l::T) where {T<:AbstractVector, DType,S}
-    Random.rand!(rng,l)
-    l[l .< s.α] .= zero(eltype(DType))
-    Random.rand!(rng,s.base_dist,@view l[l .>= s.α])
-    return l
-end
-
 function Base.rand(rng::AbstractRNG, s::ZWDist{DType,S}) where {DType,S} 
     return ifelse(Base.rand(rng) < s.α, zero(eltype(DType)), Base.rand(rng,s.base_dist))
 end
diff --git a/ZeroWeightedDistributions/test/runtests.jl b/ZeroWeightedDistributions/test/runtests.jl
index 3a2ef2aee5c205dcb6144a505d4af35d5b1a0231..15633ee1e1f5777e20e109b7e4528bce9b424692 100644
--- a/ZeroWeightedDistributions/test/runtests.jl
+++ b/ZeroWeightedDistributions/test/runtests.jl
@@ -4,41 +4,37 @@ using Distributions
 using Test
 using ThreadsX
 using RandomNumbers.Xorshifts
-using Plots
-# const dist_list = [Poisson,Geometric]
-# const params_list = collect(0.01:0.5:0.9)
-# const α_list = collect(0.0:0.1:1.0)
-# const RNG = Xoroshiro128Star(1)
-# const tol = 1e-1
-function test_dist(d)
-    # samples_vec = rand(RNG,d, 100)
-    # display(samples_vec)
-    # samples_map = map(x->rand(RNG,d), 1:10_000)
-    # epdf = kde(samples_vec)
-    # vec_err = all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol)
-    # epdf = kde(samples_map)
-    # map_err= all([abs(pdf(epdf,k) - pdf(d,k)) for k in 0.0:0.1:10] .< tol)
-    # # displlay((vec_err,map_err))
-    
-    # p = scatter([pdf(epdf,k) for k in 0.0:0.1:10])
-    # scatter!(p,[pdf(d,k) for k in 0.0:0.1:10])
-    # display(p)
-    # return vec_err && map_err
-end
-
+# using Plots
+const dist_list = [Poisson]
+const params_list = collect(0.1:0.2:0.9)
+const α_list = collect(0.2:0.2:1.0)
+const RNG = Xoroshiro128Star(1)
+const tol = 0.01
+const sample_length = 1_000_000
+epdf(samples,k) = count(==(k), samples)/sample_length
 
-
-function test()
-    map(Iterators.product(dist_list,params_list,α_list)) do (d,p,α)
-        # println((d,p,α))
-        test_dist(ZWDist(d,α,p))
-    end
+function test_dist_vec(d)
+    samples_vec = rand(RNG,d, sample_length)   
+    vec_err = all(abs(epdf(samples_vec,k) - pdf(d,k)) < tol for k in 0:1:10)
+    return vec_err
+end
+function test_dist_map(d)
+    samples_map = map(x->rand(RNG,d), 1:sample_length)
+    map_err= all(abs(epdf(samples_map,k) - pdf(d,k)) < tol for k in 0:1:10)
+    return map_err
 end
-test()
 
 
-# @testset "ZeroWeightedDistributions.jl" begin
+@testset "ZeroWeightedDistributions.jl" begin
     
     
-#     @test
-# end
+    @testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list)
+        println((d,p,α))
+        @test test_dist_vec(ZWDist(d,α,p))
+    end
+
+    @testset for(d,p,α) in Iterators.product(dist_list,params_list,α_list)
+        println((d,p,α))
+        @test test_dist_map(ZWDist(d,α,p))
+    end
+end