Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DetNet.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# module StoBloDetNet

using LinearAlgebra, Combinatorics, NamedArrays, Distributions
import Base: show, rand, getindex
using LinearAlgebra, Combinatorics, NamedArrays, Distributions, StatsFuns
import Base: show, rand, getindex, length
import LinearAlgebra.eigen
import Distributions: logpdf, loglikelihood, params

Expand Down
15 changes: 11 additions & 4 deletions StoBloDetNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ struct StoBloDetNet
z::Vector{Int}
block::Vector{DetNet}
K::Int
L::Int

function StoBloDetNet(edge, z, K, ρ, q)
et = [edgetype(e, z, K) for e in edge];
L = binomial(K+1,2);
block = [DetNet(edge[et .== i], ρ[i], q[i])
for i in 1:binomial(K+1,2)];
new(edge,z,block,K)
for i in 1:L];
new(edge,z,block,K,L)
end
end

Expand All @@ -30,7 +32,9 @@ StoBloDetNet(n::Int,z,K,ρ::Float64,q::Float64) =
# get the edge-type of e from e's node-types
edgetype(e, z, K) = div(z[e[1]],K) + z[e[2]];
edgetype(e, S::StoBloDetNet) = edgetype(e, S.z, S.K);

edgetype(e::Vector{T}, S::StoBloDetNet) where T = [edgetype(i,S) for i in e];
edgetype(e::Vector{T},z,K) where T = [edgetype(i,z,K) for i in e];
edgetype(S::StoBloDetNet) = edgetype(S.edge,S);
function rand(S::StoBloDetNet, n::Int)
samp = [Vector{Tuple{Int,Int}}(undef,0) for i in 1:n];
tmp = [rand(B,n) for B in S.block];
Expand All @@ -52,7 +56,10 @@ end
function logpdf(S::StoBloDetNet, x::T,
normconst = [logdet(B.L+I) for B in S.block]) where {T <: AbstractVector}

#get edge type for each element in x
et = [edgetype(e,S) for e in x];
xiter = (x[et .== k] for k in 1:binomial(S.K+1,2));
#iterator for pulling out only edges of a given edge-type
xiter = (x[et .== k] for k in 1:S.L);
#loop over edge-types, evaluate
return sum(logpdf(B, xk, nc) for (B, xk, nc) in zip(S.block, xiter, normconst));
end
147 changes: 147 additions & 0 deletions sampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# data := a vector of T adjacency matrices: A = [a1, ..., aT]
# z := integer vector where z[i] is the node type of node i, z[i] ∈ 1,2,...,K
# ρ := float64 vector where edge of type (a,b), a<=b, has strength of repulsion ρ[(b choose 2)+a]
# q* := float64 vector where edge of type (a,b), a<=b, has quality* q*[(b choose 2)+a]


### data simulater
## simulater
# n := number of nodes
# T := number of adjacency matrices
# ρ := true vector where edge of type (a,b), a<=b, has strength of repulsion ρ[(b choose 2)+a]
# q := true vector where edge of type (a,b), a<=b, has quality* q*[(b choose 2)+a]
function simulate_samp(n::Int,T::Int,z::Vector{Int},ρ::Vector{Float64},q::Vector{Float64},K::Int)
S = StoBloDetNet(n,z,K,ρ,q);
samp = rand(S,T);
A = [a_edge(s,n) for s in samp];
end

# Transform a set of edge into an adjacency matrix
function a_edge(edge::Array{Tuple{Int64,Int64},1},n::Int)
a = rand(0:0,n,n);
for e in edge
a[CartesianIndex(e)] = 1;
end
a = Symmetric(a);
return a
end

function rand_q(K::Int)
res = rand(binomial(K+1,2));
res[res.==1] .= 0;
return res
end

function rand_ρ(K::Int)
res = .5*rand(binomial(K+1,2));
return res
end

function rand_z(n::Int,K::Int)
res = rand(1:K,n);
return res
end

function update_ρq(z::Vector{Int},A::Vector{Symmetric{Int,Matrix{Int}}},ρ::Vector{Float64},q::Vector{Float64},K::Int)
α = .5;
β = .0001;
for i in 1:length(ρ)
y = rand()*exp(postA(z,A,ρ,q,K))*(q[i]^(α-1)*exp(-β*q[i]));
# for ρ
ρ_i_old = ρ[i];
u_ρ = rand();
L_ρ = .0;
R_ρ = .5;
# for q
q_i_old = q[i];
u_q = rand();
L_q = .0;
R_q = 1.0;
# sample from H, shrinking when points are rejected
while true
u_ρ = rand();
u_q = rand();
ρ[i] = L_ρ + u_ρ*(R_ρ-L_ρ);
q[i] = L_q + u_q*(R_q-L_q);
P = exp(postA(z,A,ρ,q,K))*(q[i]^(α-1)*exp(-β*q[i]));
if (y<P)
break
end
if ρ[i]<ρ_i_old
L_ρ = ρ[i];
else
R_ρ = ρ[i];
end
if q[i]<q_i_old
L_q = q[i];
else
R_q = q[i];
end
end
end
return ρ,q
end

## Gibbs sampler to update node types z
function update_z!(z::Vector{Int},y::Vector{Tuple{Int,Int}},ρ,q,K)

logp_zi = Array{Float64,1}(undef,K);
for i in 1:length(z)
for k = 1:K
z[i] = k;
Szik = StoBloDetNet(length(z),z,K,ρ,q); #recreate complete L
logp_zi[k] = logpdf(Szik,y); #flat prior in z so only likelihood matters for posterior
end
logp_zi = logp_zi .- logsumexp(logp_zi); #normalize conditional posterior
z[i] = rand(Categorical(exp.(logp_zi))); #sample new k
end
end

function La_ind(EDGE::Array{Tuple{Int64,Int64},1},edge::Array{Tuple{Int64,Int64},1})
m = length(edge);
ind = [findall(x->x==e,EDGE) for e in edge];
ind = hcat(ind...);
return ind
end

function edge_a(a::Symmetric{Int,Matrix{Int}})
# generate the edges from an adjacency matrix
edge = findall(x->x==1,a);
edge = [i2s(e) for e in edge];
edge = unique(edge);
return edge
end

function i2s(e::CartesianIndex{2})
# transform CartesianIndex to tuple{int,int}
n1 = e[1];
n2 = e[2];
if n1 < n2
t = tuple(n1,n2);
else
t = tuple(n2,n1);
end
return t
end

# function subMatrix(M::Array{Float64,2},row::Int,col::Int)
# sub = M[setdiff(1:end,row),setdiff(1:end,col)];
# return sub
# end
#
# function determinant(M::Array{Float64,2},n::Int)
# det = 0;
# if n==1
# return M[1,1];
# elseif n==2
# return M[1,1]*M[2,2]-M[2,1]*M[1,2]
# else
# for i in 1:n
# det = det+((-1)^(i+1) * M[1,i] * determinant(subMatrix(M,1,i),n-1));
# end
# end
# return det
# end


####
97 changes: 97 additions & 0 deletions test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
include(string(pwd(),"/DetNet.jl"));
include(string(pwd(),"/StoBloDetNet.jl"));
using Plots
function sametype(z::Vector{T},K) where T
return mapreduce(k -> (z.==k)*(z.==k)', +, 1:K)
end
function tril_vec(M)
return M[tril!(trues(size(M)),-1)]
end

ρ = 0.499999;
q = 0.9;
K = 3;
n = 18;
z = repeat(collect(1:K),inner=div(n,K));
# z = vcat(fill(1,div(n*2,3)),fill(2,div(n,3)))
S = StoBloDetNet(n,z,K,ρ,q);
y = rand(S)

##### inference with known K
# zsamp = rand(Categorical(fill(1/K,K)), n);
zsamp = sample(z,n,replace=false);
niter = 10001;
thin = 2;
saveiter = 1:thin:niter;
nsamp = length(saveiter);
niter = maximum(saveiter);

sout = Matrix{Int}(undef,n,nsamp);
sout[:,1] = zsamp;
for t in 2:niter
update_z!(zsamp,y,ρ,q,K);
if t ∈ saveiter
print(t,"\r")
j = findfirst(saveiter.==t);
sout[:,j] = zsamp;
end
end

###### inference on K
Kseq = 2:5;
niter = 1001;
thin = 1;
saveiter = 1:thin:niter;
nsamp = length(saveiter);
niter = maximum(saveiter);

ll = Matrix{Float64}(undef,length(Kseq),nsamp);
for (i,k) in enumerate(Kseq)
print(string(k),"\n")
zsamp = rand(Categorical(fill(1/k,k)), n);
ll[i,1] = logpdf(StoBloDetNet(n,zsamp,k,ρ,q),y);
for t in 2:niter
update_z!(zsamp,y,ρ,q,k);
if t ∈ saveiter
print(t,"\r")
j = findfirst(saveiter.==t);
ll[i,j] = logpdf(StoBloDetNet(n,zsamp,k,ρ,q),y);
end
end
print("\n")
end

plot(hcat(fill(logpdf(DetNet(n,ρ,q),y),nsamp),ll',fill(logpdf(S,y),nsamp)),
label=hcat("1",string.(Kseq)...,"truth"),legend=:bottomright,
xlab="iteration",ylab="loglik")

#### evaluate inference
typeA = [cor(tril_vec(sametype(z,K)),tril_vec(sametype(sout[:,i],K)))
for i in 1:nsamp];
ll = [logpdf(StoBloDetNet(n,sout[:,i],K,ρ,q),y) for i in 1:nsamp];
p1 = plot(
plot(saveiter,hcat(ll,fill(logpdf(S,y),nsamp)),yaxis=("loglik"),legend=false,
title=string("rho=",ρ,", q=",q,", n=",n,", K=",K)),
plot(saveiter,typeA,yaxis=("type agreement"),xaxis=("iteration"),legend=false),
layout=(2,1));
plot(p1,heatmap(mean(sametype(sout[:,i],K) for i in 500:nsamp),aspect_ratio=:equal,
xticks=div(n,K):div(n,K):n,yticks=div(n,K):div(n,K):n))


#### discriminability
n = 30;
z = repeat(collect(1:K),inner=div(n,K));
K = 3;
ρseq = [0.4, 0.45, 0.49, 0.499, 0.5-1e-5];
qseq = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99];
lldiff = Matrix{Float64}(undef,length(ρseq),length(qseq));
for (i,ρ) in enumerate(ρseq)
for (j,q) in enumerate(qseq)
S = StoBloDetNet(n,z,K,ρ,q);
Y = rand(S,25);
lldiff[i,j] = mean((logpdf(S,y)-
logpdf(StoBloDetNet(n,sample(z,n,replace=false),K,ρ,q),y)
for y in Y))
end
end
heatmap(string.(qseq),string.(ρseq),lldiff, aspect_ratio=1,xaxis=("q"),yaxis=("rho"))