Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/CO24/functions_classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ function misclassificationrate(
threshold = mean(σ)
classif = σ .>= threshold
else
distances = [abs(σ[i] - σ[j]) for i in eachindex(σ), j in eachindex(σ)]
distances = MeanFieldGraph.pairwise_abs_distances(σ)
ct = cutree(hclust(distances; linkage=clustering); k=2)
id_excitatory = ct[argmax(σ)]
classif = ct .== id_excitatory
Expand Down
2 changes: 1 addition & 1 deletion src/MeanFieldGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Clustering: Clustering, ClusteringResult, assignments, counts, kmeans, hcl
using Distributions: Distributions, Bernoulli, DiscreteUniform, fit, mean
using LinearAlgebra: LinearAlgebra, I, transpose
using Plots: Plots, heatmap, palette, plot
using KrylovKit: KrylovKit, eigsolve
using KrylovKit: KrylovKit, svdsolve

export MarkovChainModel, DiscreteTimeData, MarkovChainConnectivity
export mvw, mvw_inf
Expand Down
27 changes: 19 additions & 8 deletions src/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ function classification(
if method == :aggregated
σ̂ = covariance_vector(data)
elseif method == :spectral
# compute the leading singular vector of the covariance matrix
# compute the leading right singular vector of the covariance matrix
Σ̂ = covariance_matrix(data)
_, vecs = eigsolve(transpose(Σ̂) * Σ̂) # faster than full SVD
v̌ = vecs[1]
_, _, rvecs, = svdsolve(Σ̂, 1) # avoids forming Σ̂'Σ̂ explicitly
v̌ = rvecs[1]

# sign disambiguation
σ̂_ag = sum(Σ̂; dims=1)[1, :]
Expand All @@ -40,7 +40,7 @@ function classification(
threshold = mean(σ̂)
output = σ̂ .>= threshold
elseif clustering in (:single, :average, :complete, :ward)
distances = [abs(σ̂[i] - σ̂[j]) for i in eachindex(σ̂), j in eachindex(σ̂)]
distances = pairwise_abs_distances(σ̂)
ct = cutree(hclust(distances; linkage=clustering); k=2)
id_excitatory = ct[argmax(σ̂)]
output = ct .== id_excitatory
Expand Down Expand Up @@ -69,17 +69,28 @@ function covariance_matrix(data::DiscreteTimeData)::Matrix{Float64}
N, T = size(data)
Z = sum(X; dims=2)

s = zeros((N, N))
for t in 1:(T - 1)
s += @views(X[:, t + 1] * transpose(X[:, t]))
end
s = @views(X[:, 2:end] * transpose(X[:, 1:(end - 1)]))
output = s / (T - 1) - Z * transpose(Z) / T^2

return output
end

# Auxiliary functions

function pairwise_abs_distances(v::Vector{<:Real})::Matrix{Float64}
n = length(v)
output = Matrix{Float64}(undef, n, n)
@inbounds for i in 1:n
output[i, i] = 0.0
@inbounds for j in (i + 1):n
d = abs(v[i] - v[j])
output[i, j] = d
output[j, i] = d
end
end
return output
end

function cluster2bool(R::ClusteringResult)::Vector{Bool}
output = Vector{Bool}(undef, sum(counts(R)))
check = R.centers[1] < R.centers[2]
Expand Down
Loading