diff --git a/src/SOM.jl b/src/SOM.jl index 2b738b6..20029a3 100644 --- a/src/SOM.jl +++ b/src/SOM.jl @@ -15,6 +15,7 @@ using Distances using ProgressMeter using StatsBase using Distributions +using NearestNeighbors #using TensorToolbox using LinearAlgebra # if VERSION < v"0.7.0-DEV.5183" diff --git a/src/helpers.jl b/src/helpers.jl index 8e365eb..c954950 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -90,6 +90,14 @@ function findWinner(cod, sampl) return winner end +function findWinner(kd_tree::KDTree, sampl) + idxs, _ = knn(kd_tree, sampl, 1) + idxs[1] +end + +function buildKDTree(codes) + permutedims(codes,(2,1)) |> KDTree +end """ normTrainData(x, normParams) diff --git a/src/soms.jl b/src/soms.jl index 24acd65..fd4bb0b 100644 --- a/src/soms.jl +++ b/src/soms.jl @@ -76,16 +76,21 @@ end Return the index of the winner neuron for each training pattern in x (row-wise). """ -function visual(codes, x) + +function visualGeneric(codes, x) vis = zeros(Int, nrow(x)) - for i in 1:nrow(x) + @time for i in 1:nrow(x) vis[i] = findWinner(codes, [x[i, col] for col in 1:size(x, 2)]) end return(vis) end +function visual(codes, x) + kd_tree = buildKDTree(codes) + visualGeneric(kd_tree, x) +end """ makePopulation(nCodes, vis) diff --git a/test/testFuns.jl b/test/testFuns.jl index be94ed1..5741784 100644 --- a/test/testFuns.jl +++ b/test/testFuns.jl @@ -16,7 +16,6 @@ end - function testVisual(train, topol) xdim = 8 @@ -38,6 +37,49 @@ function testVisual(train, topol) end +function testVisualGeneric() + + nonce = (function() + count = 0.0 + function() + count += 1 + end + end)() + + for i in 1:10 + + # tests on random data + dimensions = rand(1:10) + ncodes = rand(1:10) + nrows = rand(1:10) + codes = rand(ncodes,dimensions) + samples = rand(nrows,dimensions) + + if visual(codes,samples) != visualGeneric(codes,samples) + return false + end + + # tests on sequenced data + dimensions = i + ncodes = i + nrows = i + codes = zeros(ncodes,dimensions) + codes = map(_->(nonce()),codes) + samples = zeros(nrows,dimensions) + samples = map(_->(nonce()),samples) + + if visual(codes,samples) != visualGeneric(codes,samples) + return false + end + + end + + return true + +end + + + function testFreqs(train, wClasses, classes) xdim = 8