Skip to content
Snippets Groups Projects
graphs.jl 6.59 KiB
#This defines a type for the time-dependent graph.
#The tradeoff is (partially) memory usage vs. speed, so it tries to preallocate as much as possible.
#As a result, it stores a lot of redundant information.
#Julia's LightGraphs.jl package stores its graphs as adjacency lists, but here I am experimenting with bit arrays, which seem to be 4-5x faster, even for huge problems.
#It also completely decouples the generation of the contact vectors (the number of contacts each node has in each WorkSchool/Rest layer) from the
#generation of the corresponding graphs.
#Bit matrices are much faster for graph generation, but we lose O(1) access to a node's neighbors (finding them requires scanning a row/column).
# Time-dependent contact graph: a static base graph plus mixing edges that are regenerated
# for each timestep.
# Fields:
#   graph               -- working adjacency matrix (BitArray); overwritten in place for
#                          each timestep by `advance_mixing_graph!`
#   base_graph          -- adjacency matrix (BitArray) of `static_contacts`; never mutated
#   contact_vector_ws   -- one MixingContacts per WorkSchool mixing matrix (indexed by timestep)
#   contact_vector_rest -- one MixingContacts per Rest mixing matrix (indexed by timestep)
# Constructor arguments:
#   static_contacts      -- graph accepted by `adjacency_matrix`
#   ws_mixing_matrices   -- list of WorkSchool mixing matrices of contact distributions
#   rest_mixing_matrices -- list of Rest mixing matrices; must have the same length
#   index_vectors        -- node indices of each demographic group
# Throws ArgumentError if the two mixing-matrix lists differ in length.
struct MixingGraphs{GraphVector,A}
    graph::GraphVector
    base_graph::GraphVector
    contact_vector_ws::A
    contact_vector_rest::A
    function MixingGraphs(static_contacts, ws_mixing_matrices, rest_mixing_matrices, index_vectors)
        length(ws_mixing_matrices) == length(rest_mixing_matrices) ||
            throw(ArgumentError("mixing matrix lists must be of equal length"))
        contact_vector_ws = map(mm -> MixingContacts(index_vectors, mm), ws_mixing_matrices)
        contact_vector_rest = map(mm -> MixingContacts(index_vectors, mm), rest_mixing_matrices)
        base_graph = convert(BitArray, adjacency_matrix(static_contacts))
        # `graph` needs its own storage so later in-place edits never touch `base_graph`;
        # a shallow `copy` of a BitArray already duplicates its bits.
        return new{typeof(base_graph),typeof(contact_vector_ws)}(
            copy(base_graph), base_graph, contact_vector_ws, contact_vector_rest)
    end
end


# A matrix of contact vectors for one mixing matrix.
# `contact_array[i, j]` has one entry per member of group i (length of `index_vectors[i]`);
# each entry is drawn from the distribution `mixing_matrix[i, j]`. The vectors are balanced
# so that sum(contact_array[i, j]) == sum(contact_array[j, i]).
struct MixingContacts{V}
    contact_array::V
    function MixingContacts(index_vectors, mixing_matrix)
        # Allocate correctly sized (zeroed) storage for every (i, j) cell first...
        contacts = map(CartesianIndices(mixing_matrix)) do ind
            zeros(length(index_vectors[ind[1]]))
        end
        mixing_contacts = new{typeof(contacts)}(contacts)
        # ...then fill it with the same routine used to refresh contacts later, so the
        # initial sampling and subsequent updates cannot drift apart.
        update_contacts!(mixing_contacts, mixing_matrix)
        return mixing_contacts
    end
end

# Rebuild `mixing_graph.graph` for timestep `t`: reset it to the static base graph, then
# overlay newly sampled Rest and WorkSchool mixing edges built from the contact vectors
# stored for timestep `t`.
#   t             -- index into `contact_vector_rest` / `contact_vector_ws`
#   mixing_graph  -- a MixingGraphs; only its `graph` field is modified (in place)
#   index_vectors -- node indices of each demographic group
# Returns nothing.
function advance_mixing_graph!(t, mixing_graph, index_vectors)
    # Broadcast assignment overwrites every entry, so no separate zeroing pass is needed.
    mixing_graph.graph .= mixing_graph.base_graph
    generate_mixing_graph!(mixing_graph.graph, index_vectors, mixing_graph.contact_vector_rest[t])
    generate_mixing_graph!(mixing_graph.graph, index_vectors, mixing_graph.contact_vector_ws[t])
    return nothing
end


# Resample, in place, every contact vector of `mixing_contacts` from the distributions in
# `mixing_matrix`, keeping each (i, j)/(j, i) pair of vectors balanced in total.
# Only the lower triangle (j <= i) is iterated, because each call fills both the (i, j)
# and the (j, i) vector. Returns nothing.
function update_contacts!(mixing_contacts, mixing_matrix)
    # `axes` gives the row range without copying a column out of the matrix
    for i in axes(mixing_matrix, 1), j in 1:i
        generate_contact_vectors!(mixing_matrix[i, j], mixing_matrix[j, i],
                                  mixing_contacts.contact_array[i, j],
                                  mixing_contacts.contact_array[j, i])
    end
    return nothing
end
# Fill `i_to_j_contacts` (one entry per member of group i) with draws from `ij_dist`
# and `j_to_i_contacts` (one entry per member of group j) with draws from `ji_dist`.
# Then repeatedly re-draw randomly chosen entries until both vectors have the same sum.
# `random_bipartite_graph_fast_CL!` requires equal totals to pair the contacts into edges.
# Both vectors are modified in place (they may be the same array on the diagonal i == j).
# Returns nothing.
# NOTE(review): the balancing loop only stops when the sums agree exactly; this assumes
# both distributions can produce equal totals (e.g. both integer-valued) — TODO confirm.
function generate_contact_vectors!(ij_dist, ji_dist, i_to_j_contacts, j_to_i_contacts)
    rand!(RNG, ij_dist, i_to_j_contacts)
    rand!(RNG, ji_dist, j_to_i_contacts)
    l_i = length(i_to_j_contacts)
    l_j = length(j_to_i_contacts)
    # With an empty group no contacts between the groups can exist, and the only balanced
    # assignment is all zeros. (Drawing indices from an empty group below would throw.)
    if l_i == 0 || l_j == 0
        fill!(i_to_j_contacts, 0)
        fill!(j_to_i_contacts, 0)
        return nothing
    end
    contacts_sums = sum(i_to_j_contacts) - sum(j_to_i_contacts)

    # Phase 1: draw candidate indices and replacement values in bulk.
    sample_list_length = max(l_i, l_j) # better heuristic for this based on stddev of dist?
    index_list_i = sample(RNG, 1:l_i, sample_list_length)
    index_list_j = sample(RNG, 1:l_j, sample_list_length)
    sample_list_i = rand(RNG, ij_dist, sample_list_length)
    sample_list_j = rand(RNG, ji_dist, sample_list_length)
    for k in 1:sample_list_length
        contacts_sums == 0 && break
        i_index = index_list_i[k]
        j_index = index_list_j[k]
        # remove the old entries' contribution, swap in new draws, add theirs back
        contacts_sums += j_to_i_contacts[j_index] - i_to_j_contacts[i_index]
        i_to_j_contacts[i_index] = sample_list_i[k]
        j_to_i_contacts[j_index] = sample_list_j[k]
        contacts_sums += i_to_j_contacts[i_index] - j_to_i_contacts[j_index]
    end

    # Phase 2: if still unbalanced, keep re-drawing one random pair at a time.
    while contacts_sums != 0
        i_index = sample(RNG, 1:l_i)
        j_index = sample(RNG, 1:l_j)
        contacts_sums += j_to_i_contacts[j_index] - i_to_j_contacts[i_index]
        i_to_j_contacts[i_index] = rand(RNG, ij_dist)
        j_to_i_contacts[j_index] = rand(RNG, ji_dist)
        contacts_sums += i_to_j_contacts[i_index] - j_to_i_contacts[j_index]
    end
    return nothing
end
# Overlay onto `g` one random bipartite block per ordered pair of demographic groups
# (a, b). Nodes of group a use the expected degrees in `contacts.contact_array[a, b]`, and
# nodes of group b use those in `contacts.contact_array[b, a]`.
# Both (a, b) and (b, a) are visited, and each produces its own block.
# NOTE(review): with an adjacency-matrix `g`, each block only sets g[x, y] for x in group a,
# so the result may be asymmetric — confirm this is intended.
# `g` is mutated in place and also returned.
function generate_mixing_graph!(g, index_vectors, contacts)
    ngroups = length(index_vectors)
    for a in 1:ngroups
        for b in 1:ngroups
            random_bipartite_graph_fast_CL!(
                g,
                index_vectors[a],
                index_vectors[b],
                contacts.contact_array[a, b],
                contacts.contact_array[b, a],
            )
        end
    end
    return g
end

using StatsBase
# Modify g so that the nodes in `anodes` and `bnodes` are connected by a random bipartite
# graph with expected degrees `aseq` (for `anodes`) and `bseq` (for `bnodes`).
# Implemented from Aksoy, S. G., Kolda, T. G., & Pinar, A. (2017). Measuring and modeling
# bipartite graphs with community structure.
# m = sum(aseq) stubs are drawn on each side, with replacement, weighted by the degree
# sequence. The k-th a-stub is then joined to the k-th b-stub.

# Shared stub sampling for both methods below.
# Returns `(astubs, bstubs)`, each of length m = sum(aseq) (empty when m == 0).
# Throws ArgumentError if `aseq` and `bseq` have different sums.
function _sample_bipartite_stubs(anodes, bnodes, aseq, bseq)
    # check before `Int(...)`, which would otherwise fail first with an InexactError
    sum(aseq) == sum(bseq) || throw(ArgumentError("degree sequences must have equal sum"))
    m = Int(sum(aseq))
    # no stubs: nothing to sample, and the weights below would be 0/0 = NaN
    m == 0 && return (eltype(anodes)[], eltype(bnodes)[])
    astubs = sample(RNG, anodes, StatsBase.weights(aseq ./ m), m)
    bstubs = sample(RNG, bnodes, StatsBase.weights(bseq ./ m), m)
    return astubs, bstubs
end

# Simple algorithm for a SimpleGraph. A repeated stub pair adds no new edge (add_edge! on
# an existing edge is a no-op), so realized degrees can fall below the expected ones,
# mostly for small graphs.
function random_bipartite_graph_fast_CL!(g::SimpleGraph, anodes, bnodes, aseq, bseq)
    astubs, bstubs = _sample_bipartite_stubs(anodes, bnodes, aseq, bseq)
    for (a, b) in zip(astubs, bstubs)
        add_edge!(g, a, b)
    end
    return g
end

# Same algorithm, writing into an adjacency matrix (e.g. a BitArray). Setting an entry
# twice has no extra effect, so this cannot produce parallel edges.
# Only g[a, b] is set, not g[b, a].
function random_bipartite_graph_fast_CL!(g::T, anodes, bnodes, aseq, bseq) where T<:AbstractArray
    astubs, bstubs = _sample_bipartite_stubs(anodes, bnodes, aseq, bseq)
    for (a, b) in zip(astubs, bstubs)
        g[a, b] = 1
    end
    return g
end

# Generate a population of `num_households` households whose demographic compositions come
# from `sample_household_data(num_households)`.
# Returns a NamedTuple with:
#   population_list -- Vector of each agent's AgentDemographic, household by household
#   network         -- SimpleGraph in which each household is a complete subgraph, laid out
#                      block-diagonally in the same agent order as population_list
#   index_vectors   -- for each demographic i, the positions in population_list of the
#                      agents with AgentDemographic(i)
function generate_population(num_households)
    households_composition = sample_household_data(num_households)

    # presumably composition[i] is the number of household members with AgentDemographic(i)
    # — confirm against sample_household_data
    population_list = AgentDemographic[]
    for composition in households_composition, (i, n) in enumerate(composition)
        append!(population_list, fill(AgentDemographic(i), n))
    end

    household_sizes = sum.(households_composition)
    household_networks = map(LightGraphs.complete_graph, household_sizes)
    network = reduce(blockdiag, household_networks; init = SimpleGraph())

    # NOTE(review): `AgentDemographic.size` reads a field of the type object itself
    # (DataType's byte size in Julia versions that have that field), not the number of
    # demographic categories. If AgentDemographic is an @enum, confirm that
    # 1:(AgentDemographic.size-1) matches `instances(AgentDemographic)`.
    index_vectors = [findall(x -> x == AgentDemographic(i), population_list) for i in 1:(AgentDemographic.size-1)]
    return (;
        population_list,
        network,
        index_vectors
    )
end