diff --git a/examples/CO24/functions_classification.jl b/examples/CO24/functions_classification.jl index 4f3de19..2e4b7d3 100644 --- a/examples/CO24/functions_classification.jl +++ b/examples/CO24/functions_classification.jl @@ -215,7 +215,7 @@ function misclassificationrate( threshold = mean(σ) classif = σ .>= threshold else - distances = [abs(σ[i] - σ[j]) for i in eachindex(σ), j in eachindex(σ)] + distances = MeanFieldGraph.pairwise_abs_distances(σ) ct = cutree(hclust(distances; linkage=clustering); k=2) id_excitatory = ct[argmax(σ)] classif = ct .== id_excitatory diff --git a/src/MeanFieldGraph.jl b/src/MeanFieldGraph.jl index 66a8072..266b708 100644 --- a/src/MeanFieldGraph.jl +++ b/src/MeanFieldGraph.jl @@ -4,7 +4,7 @@ using Clustering: Clustering, ClusteringResult, assignments, counts, kmeans, hcl using Distributions: Distributions, Bernoulli, DiscreteUniform, fit, mean using LinearAlgebra: LinearAlgebra, I, transpose using Plots: Plots, heatmap, palette, plot -using KrylovKit: KrylovKit, eigsolve +using KrylovKit: KrylovKit, svdsolve export MarkovChainModel, DiscreteTimeData, MarkovChainConnectivity export mvw, mvw_inf diff --git a/src/classification.jl b/src/classification.jl index 81aa806..9bfb1c4 100644 --- a/src/classification.jl +++ b/src/classification.jl @@ -15,10 +15,10 @@ function classification( if method == :aggregated σ̂ = covariance_vector(data) elseif method == :spectral - # compute the leading singular vector of the covariance matrix + # compute the leading right singular vector of the covariance matrix Σ̂ = covariance_matrix(data) - _, vecs = eigsolve(transpose(Σ̂) * Σ̂) # faster than full SVD - v̌ = vecs[1] + _, _, rvecs, = svdsolve(Σ̂, 1) # avoids forming Σ̂'Σ̂ explicitly + v̌ = rvecs[1] # sign disambiguation σ̂_ag = sum(Σ̂; dims=1)[1, :] @@ -40,7 +40,7 @@ function classification( threshold = mean(σ̂) output = σ̂ .>= threshold elseif clustering in (:single, :average, :complete, :ward) - distances = [abs(σ̂[i] - σ̂[j]) for i in eachindex(σ̂), j in eachindex(σ̂)] + distances = pairwise_abs_distances(σ̂) ct = cutree(hclust(distances; linkage=clustering); k=2) id_excitatory = ct[argmax(σ̂)] output = ct .== id_excitatory @@ -69,10 +69,7 @@ function covariance_matrix(data::DiscreteTimeData)::Matrix{Float64} N, T = size(data) Z = sum(X; dims=2) - s = zeros((N, N)) - for t in 1:(T - 1) - s += @views(X[:, t + 1] * transpose(X[:, t])) - end + s = @views(X[:, 2:end] * transpose(X[:, 1:(end - 1)])) output = s / (T - 1) - Z * transpose(Z) / T^2 return output @@ -80,6 +77,20 @@ end # Auxiliary functions +function pairwise_abs_distances(v::Vector{<:Real})::Matrix{Float64} + n = length(v) + output = Matrix{Float64}(undef, n, n) + @inbounds for i in 1:n + output[i, i] = 0.0 + @inbounds for j in (i + 1):n + d = abs(v[i] - v[j]) + output[i, j] = d + output[j, i] = d + end + end + return output +end + function cluster2bool(R::ClusteringResult)::Vector{Bool} output = Vector{Bool}(undef, sum(counts(R))) check = R.centers[1] < R.centers[2]