Skip to content

Commit 5131bda

Browse files
committed
Fixing issues and splitting out into separate classes
1 parent 78e8e8c commit 5131bda

File tree

4 files changed

+48
-29
lines changed

4 files changed

+48
-29
lines changed

lib/hex-svm.rb

+19-10
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,25 @@ def _convert_to_svm_node_array(x)
7171
module SVM
7272
extend self
7373

74-
def convert_to_svm_node_array(x, max = 30)
75-
# X is an array of indicies
76-
# Make index array
77-
# iter_range = x.sort
78-
79-
data = svm_node_array(max + 1)
80-
svm_node_array_set(data, max, -1, 0)
81-
82-
x.sort.each do |k|
83-
svm_node_array_set(data, k, k, 1)
74+
def convert_to_svm_node_array(indicies, max)
75+
# Make index array
76+
# x = indexes_to_array(indicies, max)
77+
# iter_range = x.each_index.to_a
78+
79+
data = svm_node_array(indicies.length + 1)
80+
svm_node_array_set(data, indicies.length, -1, 0)
81+
82+
# max.times do |i|
83+
# # Set to zero if not in indicies
84+
# if indicies.include?(i)
85+
# svm_node_array_set(data, i, i, 1)
86+
# else
87+
# svm_node_array_set(data, i, i, 0)
88+
# end
89+
# end
90+
91+
indicies.sort.each_with_index do |idx, i|
92+
svm_node_array_set(data, i, idx, 1)
8493
end
8594

8695
data

lib/libsvm/model.rb

+5-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def initialize(arg1,arg2=nil)
1616
end
1717
msg = svm_check_parameter(prob.prob,param.param)
1818
raise ::ArgumentError, msg if msg
19-
@model = svm_train(prob.prob,param.param)
19+
@model = svm_train(prob.prob, param.param)
2020
end
2121

2222
#setup some classwide variables
@@ -29,11 +29,10 @@ def initialize(arg1,arg2=nil)
2929
delete_int(intarr)
3030
#check if valid probability model
3131
@probability = svm_check_probability_model(@model)
32-
3332
end
3433

35-
def predict(x)
36-
data = SVM.convert_to_svm_node_array(x)
34+
def predict(x, max = x.max)
35+
data = SVM.convert_to_svm_node_array(x, max)
3736
ret = svm_predict(@model,data)
3837
svm_node_array_destroy(data)
3938
return ret
@@ -54,7 +53,7 @@ def get_labels
5453
def predict_values_raw(x)
5554
#convert x into svm_node, allocate a double array for return
5655
n = (@nr_class*(@nr_class-1)/2).floor
57-
data = _convert_to_svm_node_array(x)
56+
data = SVM.convert_to_svm_node_array(x)
5857
dblarr = new_double(n)
5958
svm_predict_values(@model, data, dblarr)
6059
ret = _double_array_to_list(dblarr, n)
@@ -101,7 +100,7 @@ def predict_probability(x)
101100
end
102101

103102
#convert x into svm_node, alloc a double array to receive probabilities
104-
data = _convert_to_svm_node_array(x)
103+
data = SVM.convert_to_svm_node_array(x)
105104
dblarr = new_double(@nr_class)
106105
pred = svm_predict_probability(@model, data, dblarr)
107106
pv = _double_array_to_list(dblarr, @nr_class)

lib/libsvm/problem.rb

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module SVM
22
class Problem
33
attr_accessor :prob, :maxlen, :size
4+
attr_reader :max
45

56
def initialize(y, *x)
67
#assert y.size == x.size
8+
@max = x.flatten.max #aggh hate doing this...
79
@prob = prob = Svm_problem.new
810
@size = size = y.size
911

@@ -14,12 +16,10 @@ def initialize(y, *x)
1416

1517
@x_matrix = x_matrix = svm_node_matrix(size)
1618
@data = []
17-
@maxlen = 0
19+
@maxlen = max
1820
x.each_with_index do |row, i|
19-
#data = _convert_to_svm_node_array(row)
20-
data = SVM.convert_to_svm_node_array(row)
21-
@data << data
22-
svm_node_matrix_set(x_matrix, i, data)
21+
@data << SVM.convert_to_svm_node_array(row, max)
22+
svm_node_matrix_set(x_matrix, i, @data.last)
2323
@maxlen = [@maxlen, row.size].max
2424
end
2525

spec/problem_spec.rb

+19-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
1-
require File.join(File.dirname(__FILE__), "../lib/svm")
1+
require File.join(File.dirname(__FILE__), "../lib/hex-svm")
22

33
describe SVM::Problem do
44
before(:each) do
55
#@model = [1,-1], [1,0,1], [0,0,1]
66
@y = [1, -1]
77
@x = [[0,2], [2]]
8+
9+
prob = SVM::Problem.new(@y, *@x)
10+
param = SVM::Parameter.new(:kernel_type => LINEAR, :C => 10)
11+
@model = SVM::Model.new(prob, param)
812
end
913

1014

11-
it 'should return the right results' do
12-
prob = SVM::Problem.new(@y, *@x)
13-
param = SVM::Parameter.new(:kernel_type => RBF, :C => 10)
14-
model = SVM::Model.new(prob, param)
15-
16-
model.predict([2]).should == -1.0
17-
model.predict([0]).should == 1.0
15+
it '0,1,2 should retrieve 1' do
16+
@model.predict([0,1,2], 3).should == 1.0
17+
end
18+
19+
it '' do
20+
@model.predict([2], 3).should == -1.0
21+
end
22+
23+
it '' do
24+
@model.predict([], 3).should == -1.0
25+
end
26+
27+
it '' do
28+
@model.predict([0], 3).should == 1.0
1829
end
1930
end

0 commit comments

Comments
 (0)