Skip to content

Commit beb9c3e

Browse files
author
Tomasz Latkowski
committed
added pooled variance
1 parent 4af5e8b commit beb9c3e

File tree

3 files changed

+39
-16
lines changed

3 files changed

+39
-16
lines changed

methods/selection.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def fisher(data, num_instances: list, top_k_features=2):
1414
assert len(num_instances) == 2, "Fisher selection method can be performed for two-class problems."
1515

1616
data = tf.convert_to_tensor(data)
17-
_, num_features = data.get_shape().as_list()
17+
num_features = data.get_shape().as_list()[-1]
1818
if top_k_features > num_features:
1919
top_k_features = num_features
2020
class1, class2 = tf.split(data, num_instances)
@@ -35,7 +35,7 @@ def feature_correlation_with_class(data, num_instances: list, top_k_features=10)
3535
:return: the list of most significant features.
3636
"""
3737
data = tf.convert_to_tensor(data)
38-
_, num_features = data.get_shape().as_list()
38+
num_features = data.get_shape().as_list()[-1]
3939
if top_k_features > num_features:
4040
top_k_features = num_features
4141
class1, class2 = tf.split(data, num_instances)
@@ -57,7 +57,7 @@ def t_test(data, num_instances: list, top_k_features=10):
5757
:return: the list of most significant features.
5858
"""
5959
data = tf.convert_to_tensor(data)
60-
_, num_features = data.get_shape().as_list()
60+
num_features = data.get_shape().as_list()[-1]
6161
if top_k_features > num_features:
6262
top_k_features = num_features
6363
class1, class2 = tf.split(data, num_instances)
@@ -74,16 +74,10 @@ def t_test(data, num_instances: list, top_k_features=10):
7474

7575
def random(data, num_instances: list, top_k_features=10):
7676
data = tf.convert_to_tensor(data)
77-
_, num_features = data.get_shape().as_list()
77+
num_features = data.get_shape().as_list()[-1]
7878
if top_k_features > num_features:
7979
top_k_features = num_features
8080
class1, class2 = tf.split(data, num_instances)
8181

8282
with tf.name_scope('random_selection'):
83-
mean1, std1 = tf.nn.moments(class1, axes=0)
84-
mean2, std2 = tf.nn.moments(class2, axes=0)
85-
t_test_coeffs = tf.abs(mean1 - mean2) / tf.sqrt(
86-
tf.square(std1) / num_instances[0] + tf.square(std2) / num_instances[1])
87-
selected_features = tf.nn.top_k(t_test_coeffs, k=top_k_features)
88-
89-
return selected_features
83+
pass

tests/test_statistics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
4+
from utils.statistics import pooled_variance
5+
6+
7+
class TestStatistics(tf.test.TestCase):
8+
9+
def testPooledVariance(self):
10+
with self.test_session() as test_session:
11+
data = np.array([[2., 3., 4., 5.],
12+
[2., 3., 4., 5.],
13+
[2., 3., 4., 5.],
14+
[2., 3., 4., 5.]])
15+
num_instances = [2, 2]
16+
actual_pooled_variance = test_session.run(pooled_variance(data, num_instances))
17+
correct_pooled_variance = [.0, .0, .0, .0]
18+
19+
self.assertAllEqual(actual_pooled_variance, correct_pooled_variance)
20+
21+
22+
if __name__ == '__main__':
23+
tf.test.main()

utils/statistics.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@ def f_test(data, num_instances):
3737
def pooled_variance(data, num_instances):
3838
K = len(num_instances)
3939
n = sum(num_instances)
40-
41-
class1, class2 = tf.split(data, num_instances)
42-
mean1, var1 = tf.nn.moments(class1, axes=0)
43-
mean2, var2 = tf.nn.moments(class2, axes=0)
44-
return -1
40+
data = tf.convert_to_tensor(data, dtype=tf.float32)
41+
split_classes = tf.split(data, num_instances)
42+
vars = []
43+
for i in range(len(split_classes)):
44+
_, var = tf.nn.moments(split_classes[i], axes=0)
45+
vars.append(var)
46+
47+
n_k = tf.to_float(tf.reshape(num_instances, [K, -1]))
48+
stacked_var = tf.stack(vars)
49+
pooled_var = tf.reduce_sum(stacked_var * (n_k - 1), axis=0) / (n - K)
50+
return pooled_var

0 commit comments

Comments
 (0)