diff --git a/DetNet.jl b/DetNet.jl index f300bdc..2cdec50 100644 --- a/DetNet.jl +++ b/DetNet.jl @@ -1,7 +1,7 @@ # module StoBloDetNet -using LinearAlgebra, Combinatorics, NamedArrays, Distributions -import Base: show, rand, getindex +using LinearAlgebra, Combinatorics, NamedArrays, Distributions, StatsFuns +import Base: show, rand, getindex, length import LinearAlgebra.eigen import Distributions: logpdf, loglikelihood, params diff --git a/StoBloDetNet.jl b/StoBloDetNet.jl index 1918873..a3c2bfa 100644 --- a/StoBloDetNet.jl +++ b/StoBloDetNet.jl @@ -9,12 +9,14 @@ struct StoBloDetNet z::Vector{Int} block::Vector{DetNet} K::Int + L::Int function StoBloDetNet(edge, z, K, ρ, q) et = [edgetype(e, z, K) for e in edge]; + L = binomial(K+1,2); block = [DetNet(edge[et .== i], ρ[i], q[i]) - for i in 1:binomial(K+1,2)]; - new(edge,z,block,K) + for i in 1:L]; + new(edge,z,block,K,L) end end @@ -30,7 +32,9 @@ StoBloDetNet(n::Int,z,K,ρ::Float64,q::Float64) = # get the edge-type of e from e's node-types edgetype(e, z, K) = div(z[e[1]],K) + z[e[2]]; edgetype(e, S::StoBloDetNet) = edgetype(e, S.z, S.K); - +edgetype(e::Vector{T}, S::StoBloDetNet) where T = [edgetype(i,S) for i in e]; +edgetype(e::Vector{T},z,K) where T = [edgetype(i,z,K) for i in e]; +edgetype(S::StoBloDetNet) = edgetype(S.edge,S); function rand(S::StoBloDetNet, n::Int) samp = [Vector{Tuple{Int,Int}}(undef,0) for i in 1:n]; tmp = [rand(B,n) for B in S.block]; @@ -52,7 +56,10 @@ end function logpdf(S::StoBloDetNet, x::T, normconst = [logdet(B.L+I) for B in S.block]) where {T <: AbstractVector} + #get edge type for each element in x et = [edgetype(e,S) for e in x]; - xiter = (x[et .== k] for k in 1:binomial(S.K+1,2)); + #iterator for pulling out only edges of a given edge-type + xiter = (x[et .== k] for k in 1:S.L); + #loop over edge-types, evaluate return sum(logpdf(B, xk, nc) for (B, xk, nc) in zip(S.block, xiter, normconst)); end diff --git a/sampler.jl b/sampler.jl new file mode 100644 index 0000000..79fe966 --- 
/dev/null
+++ b/sampler.jl
@@ -0,0 +1,152 @@
+# data := a vector of T adjacency matrices: A = [a1, ..., aT]
+# z := integer vector where z[i] is the node type of node i, z[i] ∈ 1,2,...,K
+# ρ := float64 vector where edge of type (a,b), a<=b, has strength of repulsion ρ[(b choose 2)+a]
+# q* := float64 vector where edge of type (a,b), a<=b, has quality* q*[(b choose 2)+a]
+
+
+### data simulator
+## simulator
+# n := number of nodes
+# T := number of adjacency matrices
+# ρ := true vector where edge of type (a,b), a<=b, has strength of repulsion ρ[(b choose 2)+a]
+# q := true vector where edge of type (a,b), a<=b, has quality* q*[(b choose 2)+a]
+# Returns one symmetric adjacency matrix (see a_edge) per element of rand(S,T).
+function simulate_samp(n::Int,T::Int,z::Vector{Int},ρ::Vector{Float64},q::Vector{Float64},K::Int)
+    S = StoBloDetNet(n,z,K,ρ,q);
+    samp = rand(S,T);
+    return [a_edge(s,n) for s in samp]
+end
+
+# Transform a set of edges into an n×n symmetric 0/1 adjacency matrix.
+# Each edge is stored at (smaller, larger): Symmetric() reads only the upper triangle
+# by default, so an edge listed as (larger, smaller) would otherwise be dropped.
+function a_edge(edge::Array{Tuple{Int64,Int64},1},n::Int)
+    a = zeros(Int,n,n);
+    for e in edge
+        a[minmax(e[1],e[2])...] = 1;
+    end
+    return Symmetric(a)
+end
+
+# Draw one quality value per edge type (binomial(K+1,2) of them), uniform on [0,1).
+function rand_q(K::Int)
+    res = rand(binomial(K+1,2));
+    # NOTE(review): rand() samples from [0,1), so no entry can equal 1 and this
+    # line never changes anything; kept pending confirmation of the intended guard.
+    res[res.==1] .= 0;
+    return res
+end
+
+# Draw one repulsion strength per edge type (binomial(K+1,2) of them), uniform on [0,0.5).
+function rand_ρ(K::Int)
+    return .5*rand(binomial(K+1,2))
+end
+
+# Draw a node type uniformly from 1:K for each of the n nodes.
+function rand_z(n::Int,K::Int)
+    return rand(1:K,n)
+end
+
+function update_ρq(z::Vector{Int},A::Vector{Symmetric{Int,Matrix{Int}}},ρ::Vector{Float64},q::Vector{Float64},K::Int)
+    α = .5;
+    β = .0001;
+    for i in 1:length(ρ)
+        y = rand()*exp(postA(z,A,ρ,q,K))*(q[i]^(α-1)*exp(-β*q[i]));
+        # for ρ
+        ρ_i_old = ρ[i];
+        u_ρ = rand();
+        L_ρ = .0;
+        R_ρ = .5;
+        # for q
+        q_i_old = q[i];
+        u_q = rand();
+        L_q = .0;
+        R_q = 1.0;
+        # sample from H, shrinking when points are rejected
+        while true
+            u_ρ = rand();
+            u_q = rand();
+            ρ[i] = L_ρ + u_ρ*(R_ρ-L_ρ);
+            q[i] = L_q + u_q*(R_q-L_q);
+            P = exp(postA(z,A,ρ,q,K))*(q[i]^(α-1)*exp(-β*q[i]));
+            if (y
x==e,EDGE) for e in edge];
+    ind = hcat(ind...);
+    return ind
+end
+
+function edge_a(a::Symmetric{Int,Matrix{Int}})
+    # list each undirected edge of `a` (entries equal to 1) once, as a (smaller,larger) pair
+    edge = findall(x->x==1,a);
+    edge = [i2s(e) for e in edge];
+    edge = unique(edge);
+    return edge
+end
+
+function i2s(e::CartesianIndex{2})
+    # convert a CartesianIndex to an (Int,Int) tuple ordered (smaller,larger)
+    n1 = e[1];
+    n2 = e[2];
+    if n1 < n2
+        t = tuple(n1,n2);
+    else
+        t = tuple(n2,n1);
+    end
+    return t
+end
+
+# function subMatrix(M::Array{Float64,2},row::Int,col::Int)
+# sub = M[setdiff(1:end,row),setdiff(1:end,col)];
+# return sub
+# end
+#
+# function determinant(M::Array{Float64,2},n::Int)
+# det = 0;
+# if n==1
+# return M[1,1];
+# elseif n==2
+# return M[1,1]*M[2,2]-M[2,1]*M[1,2]
+# else
+# for i in 1:n
+# det = det+((-1)^(i+1) * M[1,i] * determinant(subMatrix(M,1,i),n-1));
+# end
+# end
+# return det
+# end
+
+
+####
diff --git a/test.jl b/test.jl
new file mode 100644
index 0000000..475cf80
--- /dev/null
+++ b/test.jl
@@ -0,0 +1,109 @@
+# Driver script: builds a StoBloDetNet, draws one network y = rand(S), then runs
+# node-type inference with known K, a log-likelihood comparison over K, and plots.
+# NOTE(review): update_z! is not defined in this script; it is assumed to come from
+# one of the included files — confirm before running.
+include(joinpath(@__DIR__,"DetNet.jl"));
+include(joinpath(@__DIR__,"StoBloDetNet.jl"));
+using Plots
+
+# n×n co-membership matrix: sum over k in 1:K of the outer product of the indicator
+# vector (z.==k) with itself, so entry (i,j) is 1 when z[i]==z[j]∈1:K and 0 otherwise.
+function sametype(z::Vector{T},K) where T
+    return mapreduce(k -> (z.==k)*(z.==k)', +, 1:K)
+end
+
+# Entries of M strictly below the diagonal, returned as a vector.
+function tril_vec(M)
+    return M[tril!(trues(size(M)),-1)]
+end
+
+ρ = 0.499999;
+q = 0.9;
+K = 3;
+n = 18;
+z = repeat(collect(1:K),inner=div(n,K));
+# z = vcat(fill(1,div(n*2,3)),fill(2,div(n,3)))
+S = StoBloDetNet(n,z,K,ρ,q);
+y = rand(S)
+
+##### inference with known K
+# zsamp = rand(Categorical(fill(1/K,K)), n);
+zsamp = sample(z,n,replace=false);
+niter = 10001;
+thin = 2;
+saveiter = 1:thin:niter;
+nsamp = length(saveiter);
+niter = maximum(saveiter);
+
+sout = Matrix{Int}(undef,n,nsamp);
+sout[:,1] = zsamp;
+for t in 2:niter
+    update_z!(zsamp,y,ρ,q,K);
+    if t ∈ saveiter
+        print(t,"\r")
+        j = findfirst(==(t),saveiter);
+        sout[:,j] = zsamp;
+    end
+end
+
+###### inference on K
+# This section keeps its own *K-suffixed bookkeeping so that nsamp/saveiter from the
+# known-K run above stay matched to the columns of sout in "evaluate inference" below.
+Kseq = 2:5;
+niterK = 1001;
+thinK = 1;
+saveiterK = 1:thinK:niterK;
+nsampK = length(saveiterK);
+niterK = maximum(saveiterK);
+
+llK = Matrix{Float64}(undef,length(Kseq),nsampK);
+for (i,k) in enumerate(Kseq)
+    print(string(k),"\n")
+    # explicit local: this chain's labels are separate from the global zsamp above
+    local zk = rand(Categorical(fill(1/k,k)), n);
+    llK[i,1] = logpdf(StoBloDetNet(n,zk,k,ρ,q),y);
+    for t in 2:niterK
+        update_z!(zk,y,ρ,q,k);
+        if t ∈ saveiterK
+            print(t,"\r")
+            j = findfirst(==(t),saveiterK);
+            llK[i,j] = logpdf(StoBloDetNet(n,zk,k,ρ,q),y);
+        end
+    end
+    print("\n")
+end
+
+plot(hcat(fill(logpdf(DetNet(n,ρ,q),y),nsampK),llK',fill(logpdf(S,y),nsampK)),
+    label=hcat("1",string.(Kseq)...,"truth"),legend=:bottomright,
+    xlab="iteration",ylab="loglik")
+
+#### evaluate inference
+typeA = [cor(tril_vec(sametype(z,K)),tril_vec(sametype(sout[:,i],K)))
+    for i in 1:nsamp];
+ll = [logpdf(StoBloDetNet(n,sout[:,i],K,ρ,q),y) for i in 1:nsamp];
+p1 = plot(
+    plot(saveiter,hcat(ll,fill(logpdf(S,y),nsamp)),yaxis=("loglik"),legend=false,
+        title=string("rho=",ρ,", q=",q,", n=",n,", K=",K)),
+    plot(saveiter,typeA,yaxis=("type agreement"),xaxis=("iteration"),legend=false),
+    layout=(2,1));
+plot(p1,heatmap(mean(sametype(sout[:,i],K) for i in 500:nsamp),aspect_ratio=:equal,
+    xticks=div(n,K):div(n,K):n,yticks=div(n,K):div(n,K):n))
+
+
+#### discriminability
+n = 30;
+K = 3;
+z = repeat(collect(1:K),inner=div(n,K));
+ρseq = [0.4, 0.45, 0.49, 0.499, 0.5-1e-5];
+qseq = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99];
+lldiff = Matrix{Float64}(undef,length(ρseq),length(qseq));
+for (i,ρ) in enumerate(ρseq)
+    for (j,q) in enumerate(qseq)
+        local Sij = StoBloDetNet(n,z,K,ρ,q);
+        local Y = rand(Sij,25);
+        lldiff[i,j] = mean((logpdf(Sij,y)-
+            logpdf(StoBloDetNet(n,sample(z,n,replace=false),K,ρ,q),y)
+            for y in Y))
+    end
+end
+heatmap(string.(qseq),string.(ρseq),lldiff, aspect_ratio=1,xaxis=("q"),yaxis=("rho"))