using Dictionaries
"""
    GraphEdge(a::Int, b::Int)

An undirected edge between nodes `a` and `b`.

A dedicated type is used (instead of a plain tuple) so that hashing and equality can be
made symmetric: `GraphEdge(a, b)` and `GraphEdge(b, a)` denote the same edge. Each edge
then only needs to be stored once as a key in a graph's weights dictionary.
"""
struct GraphEdge
    a::Int
    b::Int
end

"""
    hash(e::GraphEdge, h::UInt)

Hash `e` independently of endpoint order, so `hash(GraphEdge(a, b)) == hash(GraphEdge(b, a))`.

This is the two-argument form, which Julia asks new types to implement. The one-argument
`hash(e)` falls back to it, and container hashing (tuples, arrays, …) calls it directly.
Overriding only `hash(e)` would leave `hash(e, h)` on the default order-sensitive fallback.
That would break the `isequal(x, y)` ⇒ `hash(x) == hash(y)` contract.
"""
Base.hash(e::GraphEdge, h::UInt) = hash(minmax(e.a, e.b), h)

"""
    isequal(e1::GraphEdge, e2::GraphEdge)
    e1 == e2

Symmetric edge equality, consistent with `hash`: `GraphEdge(a, b)` equals `GraphEdge(b, a)`.

`==` is defined alongside `isequal` so the two comparisons agree. Without it, `==` falls
back to `===` and treats a reversed edge as a different edge.
"""
Base.isequal(e1::GraphEdge, e2::GraphEdge) =
    isequal(minmax(e1.a, e1.b), minmax(e2.a, e2.b))
Base.:(==)(e1::GraphEdge, e2::GraphEdge) = minmax(e1.a, e1.b) == minmax(e2.a, e2.b)

"""
    sample_mixing_edges!(weights_dict, sampler_matrix, demographics)

Redraw, in place, the weight of every edge already stored in `weights_dict`.

For an edge key `e` (with endpoints `e.a`, `e.b`), the new weight is drawn with `rand` from
`sampler_matrix[Int(demographics[e.b]), Int(demographics[e.a])]`. The row is selected by
the demographic of `e.b` and the column by the demographic of `e.a`. Only the values are
overwritten; the key set of `weights_dict` is unchanged. Returns `nothing`.
"""
function sample_mixing_edges!(weights_dict,sampler_matrix,demographics)
    # Loop-invariant: the loop never yields, so the same RNG is used for every draw.
    rng = Random.default_rng(Threads.threadid())
    for e in keys(weights_dict)
        i = e.a
        j = e.b
        weights_dict[e] = rand(rng, sampler_matrix[Int(demographics[j]),Int(demographics[i])])
    end
    return nothing
end

"""
    TimeDepMixingGraph(len, remade_graphs, resampled_graphs, base_graph_list)

The full time-dependent mixing graph of the model. For each of `len` days it holds the
list of graphs active on that day, plus the graphs that must be refreshed between days.

# Fields

    remade_graphs::T1

Graphs whose structure is rebuilt from scratch every day (by `remake!`).

    resampled_graphs::T2

Graphs whose structure is kept fixed. Their edge weights are redrawn on the days they are
active.

    graph_list::Vector{Vector{G}}

One list of graphs per day. Each day's list starts as its own shallow copy of
`base_graph_list`, so graphs can be pushed onto one day without affecting the others. The
graph objects themselves are shared between days, not copied.
"""
struct TimeDepMixingGraph{G,T1,T2}
    remade_graphs::T1
    resampled_graphs::T2
    graph_list::Vector{Vector{G}}
    function TimeDepMixingGraph(len,remade_graphs::T1,resampled_graphs::T2,base_graph_list::Vector{G}) where {G,T1,T2}
        # `copy` gives every day an independent list; the graphs inside remain shared.
        daily_lists = [copy(base_graph_list) for _ in 1:len]
        return new{G,T1,T2}(remade_graphs, resampled_graphs, daily_lists)
    end
end

"""
    time_dep_mixing_graphs(len, base_network, demographics, index_vectors,
                           ws_matrix_tuple, rest_matrix_tuple)

Build the model's two `TimeDepMixingGraph`s covering `len` days. Returns the tuple
`(infected_mixing_graph, soc_mixing_graph)`.

Layers:
- household / LTC-home contacts, weighted copies of the fixed `base_network`;
- `ws` and `rest` contacts, each generated from the `daily`, `twice_a_week` and
  `otherwise` mixing matrices of `ws_matrix_tuple` / `rest_matrix_tuple`.

Edge-weight distributions come from the global `contact_time_distributions`
(fields `hh`, `ws`, `rest`).

The simulation is arbitrarily assumed to begin on a Thursday, so days with
`mod(t, 7) == 3` or `== 4` are the weekend. On top of its base layers, the infection
graph gets the daily `ws` layer on non-weekend days. With probability 5/7 per day, it also
gets both `twice_a_week` layers. The social graph has the same fixed layer list every day.
"""
function time_dep_mixing_graphs(len, base_network, demographics, index_vectors, ws_matrix_tuple, rest_matrix_tuple)
    # Households and LTC homes come straight from the fixed base network.
    hh_edges = WeightedGraph(base_network, demographics, contact_time_distributions.hh)

    # Keep this construction order: each constructor draws random numbers.
    ws_daily = WeightedGraph(demographics, index_vectors, ws_matrix_tuple.daily,
                             contact_time_distributions.ws)
    ws_twice = WeightedGraph(demographics, index_vectors, ws_matrix_tuple.twice_a_week,
                             contact_time_distributions.ws)
    ws_once = WeightedGraph(demographics, index_vectors, ws_matrix_tuple.otherwise,
                            contact_time_distributions.ws)

    rest_daily = WeightedGraph(demographics, index_vectors, rest_matrix_tuple.daily,
                               contact_time_distributions.rest)
    rest_twice = WeightedGraph(demographics, index_vectors, rest_matrix_tuple.twice_a_week,
                               contact_time_distributions.rest)
    rest_once = WeightedGraph(demographics, index_vectors, rest_matrix_tuple.otherwise,
                              contact_time_distributions.rest)

    # Rebuilt from scratch every day.
    remade = (ws_once, rest_once)
    # Structure kept; weights redrawn on days the layer is active.
    resampled = (hh_edges, rest_daily, ws_daily, rest_twice, ws_twice)

    infected_mixing_graph = TimeDepMixingGraph(len, remade, resampled,
                                               [hh_edges, rest_daily, ws_once, rest_once])

    for (day, active) in enumerate(infected_mixing_graph.graph_list)
        weekday = mod(day, 7)
        on_weekend = weekday == 3 || weekday == 4   # day 1 is taken to be a Thursday
        on_weekend || push!(active, ws_daily)
        # NOTE(review): the `twice_a_week` layers are switched on with probability 5/7
        # per day (~5 days a week); confirm this matches the intended frequency.
        if rand(Random.default_rng(Threads.threadid())) < 5/7
            push!(active, ws_twice)
            push!(active, rest_twice)
        end
    end

    soc_mixing_graph = TimeDepMixingGraph(len, remade, resampled,
                                          [hh_edges, rest_daily, ws_daily])
    return infected_mixing_graph, soc_mixing_graph
end

"""
    remake!(t, time_dep_mixing_graph, index_vectors, demographics)

Refresh `time_dep_mixing_graph` for day `t`:

1. Every graph in `time_dep_mixing_graph.remade_graphs` is rebuilt from scratch (edges and
   weights) via `remake!(wg, index_vectors)`. This happens whether or not the graph is
   active on day `t`.
2. Every graph in `time_dep_mixing_graph.resampled_graphs` that is active on day `t` keeps
   its edges, but its weights are redrawn via `sample_mixing_edges!`. "Active" means it
   appears in `time_dep_mixing_graph.graph_list[t]`.

Returns `nothing`.
"""
function remake!(t,time_dep_mixing_graph,index_vectors,demographics)
    for wg in time_dep_mixing_graph.remade_graphs
        remake!(wg,index_vectors)
    end
    for wg in time_dep_mixing_graph.resampled_graphs
        # Only spend random draws on layers actually used today.
        if wg in time_dep_mixing_graph.graph_list[t]
            sample_mixing_edges!(wg.weights_dict,wg.sampler_matrix,demographics)
        end
    end
    return nothing
end


"""
    WeightedGraph(demographics, index_vectors, mixing_matrix, sampler_matrix)
    WeightedGraph(g::SimpleGraph, demographics, sampler_matrix)

Weighted graph type. Stores the graph structure in `g` and the edge weights in
`weights_dict`, keyed by (unordered) `GraphEdge`.

The first constructor generates a random graph on `length(demographics)` nodes from
`mixing_matrix` via `assemble_graph!`. The second wraps an existing graph `g` and draws one
weight per edge. It stores `nothing` as `mixing_matrix`, so it cannot be regenerated by
`assemble_graph!`.

# Fields

    g::G

Stores the actual graph structure.

    weights_dict::Dictionary{GraphEdge,UInt8}

Stores the weight of every edge, so weights can be easily resampled.

    mixing_matrix::M1

Matrix of distributions determining node degrees (`nothing` when built from an existing
graph).

    sampler_matrix::M2

Matrix of distributions determining the edge weights.
"""
struct WeightedGraph{G,M1,M2}
    g::G
    weights_dict::Dictionary{GraphEdge,UInt8}
    mixing_matrix::M1
    sampler_matrix::M2
    # Generate a fresh random graph; `demographics` is only used for the node count.
    function WeightedGraph(demographics::AbstractVector,index_vectors,mixing_matrix::M1, sampler_matrix::M2) where {M1,M2}
        g = Graph(length(demographics))
        weights_dict = Dictionary{GraphEdge,UInt8}()
        assemble_graph!(g,weights_dict,index_vectors,mixing_matrix,sampler_matrix)
        return new{typeof(g),M1,M2}(g, weights_dict, mixing_matrix, sampler_matrix)
    end
    # Wrap an existing graph, drawing one weight per edge.
    function WeightedGraph(g::G,demographics,sampler_matrix::M2) where {G<:SimpleGraph,M2}
        weights_dict = Dictionary{GraphEdge,UInt8}(;sizehint = ne(g))
        for e in edges(g)
            a = src(e)
            b = dst(e)
            # For `GraphEdge(a, b)`, index as `sampler_matrix[group(b), group(a)]`. This is
            # the orientation `sample_mixing_edges!` uses when it later redraws these
            # weights, so initial and resampled weights come from the same distribution.
            dist = sampler_matrix[Int(demographics[b]),Int(demographics[a])]
            set!(weights_dict, GraphEdge(a, b), rand(Random.default_rng(Threads.threadid()), dist))
        end
        return new{typeof(g),Nothing,M2}(g, weights_dict, nothing, sampler_matrix)
    end
end

"""
    remake!(wg::WeightedGraph, index_vectors)

Regenerate `wg` in place. All edges are cleared from `wg.g` and all weights from
`wg.weights_dict`, keeping the node count. Both are then rebuilt with `assemble_graph!`
using `wg.mixing_matrix` and `wg.sampler_matrix`.

Reaches into the graph's internal fields (`fadjlist`, `ne`) to clear it without
reallocating.
"""
function remake!(wg::WeightedGraph, index_vectors)
    graph = wg.g
    foreach(empty!, graph.fadjlist)   # clear every adjacency list
    graph.ne = 0
    empty!(wg.weights_dict)
    return assemble_graph!(graph, wg.weights_dict, index_vectors, wg.mixing_matrix, wg.sampler_matrix)
end


"""
    assemble_graph!(g, weights_dict, index_vectors, mixing_matrix, sampler_matrix)

Generate random edges within and between the three groups in `index_vectors`, record a
weight for each in `weights_dict`, then add every recorded edge to `g`.

Each group pair `(i, j)` with `j ≤ i` is visited once (the lower triangle of the mixing
matrix, diagonal included). For distinct groups, edges come from
`fast_chung_lu_bipartite` with `mixing_matrix[i, j]` and `mixing_matrix[j, i]`. Within a
group, edges come from `fast_chung_lu` with `mixing_matrix[i, i]`. Each edge's weight is
drawn from `sampler_matrix[j, i]`. `set!` overwrites, so an edge generated more than once
keeps only its last weight.
"""
function assemble_graph!(g, weights_dict, index_vectors, mixing_matrix, sampler_matrix)
    for grp_i in 1:3, grp_j in 1:grp_i
        new_edges = if grp_i == grp_j
            # Both endpoints in the same group need the single-population generator.
            fast_chung_lu(index_vectors[grp_i], mixing_matrix[grp_i, grp_i])
        else
            fast_chung_lu_bipartite(index_vectors[grp_i], index_vectors[grp_j],
                                    mixing_matrix[grp_i, grp_j], mixing_matrix[grp_j, grp_i])
        end
        for edge in new_edges
            w = rand(Random.default_rng(Threads.threadid()), sampler_matrix[grp_j, grp_i])
            set!(weights_dict, edge, w)
        end
    end
    # Add edges only after deduplication through the dictionary keys.
    for edge in keys(weights_dict)
        add_edge!(g, edge.a, edge.b)
    end
    return nothing
end
"""
    fast_chung_lu_bipartite(pop_i, pop_j, mixing_dist_ij, mixing_dist_ji)

Draw random edges between nodes of `pop_i` and nodes of `pop_j`, Chung–Lu style.

`generate_contact_vectors!` fills a per-node degree vector for each population (its
contract is defined elsewhere). The number of edges is the total degree of `pop_i`. Each
edge's endpoints are sampled with replacement: one from `pop_i`, one from `pop_j`, each
with probability proportional to that node's degree. Returns a vector of
`GraphEdge(node_in_pop_i, node_in_pop_j)`, which is empty when the total degree is zero.
"""
function fast_chung_lu_bipartite(pop_i, pop_j, mixing_dist_ij, mixing_dist_ji)
    deg_i = similar(pop_i)
    deg_j = similar(pop_j)
    generate_contact_vectors!(mixing_dist_ij, mixing_dist_ji, deg_i, deg_j)
    total = sum(deg_i)
    ends_i = Vector{Int}(undef, total)
    ends_j = similar(ends_i)
    if total > 0
        rng = Random.default_rng(Threads.threadid())
        sample!(rng, pop_i, Weights(deg_i ./ total), ends_i)
        sample!(rng, pop_j, Weights(deg_j ./ total), ends_j)
    end
    return GraphEdge.(ends_i, ends_j)
end

"""
    fast_chung_lu(pop_i, mixing_dist)

Draw random edges among the nodes of a single population `pop_i`, Chung–Lu style.

`generate_contact_vectors!` fills a per-node degree vector (defined elsewhere). Half the
total degree, rounded down, gives the number of edges. Both endpoints of every edge are
sampled independently, with replacement, from `pop_i` in proportion to degree.
NOTE(review): because of this, self-loops `GraphEdge(k, k)` and duplicate edges can be
produced; confirm this is intended.
"""
function fast_chung_lu(pop_i, mixing_dist)
    degrees = similar(pop_i)
    generate_contact_vectors!(mixing_dist, degrees)
    total_degree = sum(degrees)
    n_edges = div(total_degree, 2)   # each edge consumes two stubs
    first_ends = Vector{Int}(undef, n_edges)
    second_ends = similar(first_ends)
    if total_degree > 0
        rng = Random.default_rng(Threads.threadid())
        sample!(rng, pop_i, Weights(degrees ./ total_degree), first_ends)
        sample!(rng, pop_i, Weights(degrees ./ total_degree), second_ends)
    end
    return GraphEdge.(first_ends, second_ends)
end

# Neighbours of node `i`, forwarded to the underlying graph structure.
neighbors(wg::WeightedGraph, i) = LightGraphs.neighbors(wg.g, i)
# Stored weight of edge `e` (presumably a `GraphEdge`); throws if the edge is absent.
get_weight(wg::WeightedGraph, e) = wg.weights_dict[e]
# Compact display: "WG " followed by the number of edges in the underlying graph.
Base.show(io::IO, wg::WeightedGraph) = print(io, "WG $(ne(wg.g))")