Skip to content

Commit

Permalink
exposed new xmm training parameter "multiClass_regression_estimator"
Browse files Browse the repository at this point in the history
  • Loading branch information
josephlarralde committed Nov 6, 2017
1 parent 0a12707 commit 54b86f8
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 31 deletions.
14 changes: 9 additions & 5 deletions index.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* This library is developed by the ISMM (http://ismm.ircam.fr/) team at IRCAM,
* within the context of the RAPID-MIX (http://rapidmix.goldsmithsdigital.com/)
* project, funded by the European Union’s Horizon 2020 research and innovation programme.
* Original XMM code authored by Jules Françoise, ported to Node.js by Joseph Larralde.
* project, funded by the European Union’s Horizon 2020 research and innovation programme.
* Original XMM code authored by Jules Françoise, ported to Node.js by Joseph Larralde.
* See https://github.com/Ircam-RnD/xmm for detailed XMM credits.
*/

Expand All @@ -17,9 +17,9 @@ var XmmNative = null;
// Load the precompiled binary for windows.

// if (process.platform == "win32" && process.arch == "x64") {
// XmmNative = require('./bin/winx64/xmm');
// XmmNative = require('./bin/winx64/xmm');
// } else if(process.platform == "win32" && process.arch == "ia32") {
// XmmNative = require('./bin/winx86/xmm');
// XmmNative = require('./bin/winx86/xmm');
// } else {

var binary = require('node-pre-gyp');
Expand Down Expand Up @@ -67,6 +67,8 @@ function translateFromXmmConfigProp(prop) {
return 'transitionMode';
} else if (prop === 'regression_estimator') {
return 'regressionEstimator';
} else if (prop === 'multiClass_regression_estimator') {
return 'multiClassRegressionEstimator';
} else {
return prop;
}
Expand All @@ -83,6 +85,8 @@ function translateToXmmConfigProp(prop) {
return 'transition_mode';
} else if (prop === 'regressionEstimator') {
return 'regression_estimator';
} else if (prop === 'multiClassRegressionEstimator') {
return 'multiClass_regression_estimator';
} else {
return prop;
}
Expand All @@ -106,7 +110,7 @@ Xmm.prototype.getConfig = function(prop) {
Xmm.prototype.setConfig = function(config) {
var inConfig = {};
for (var prop in config) {
var translatedProp = translateToXmmConfigProp(prop);
var translatedProp = translateToXmmConfigProp(prop);
inConfig[translatedProp] = config[prop];
}
this._xmm.setConfig(inConfig);
Expand Down
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "xmm-node",
"version": "0.3.2",
"version": "0.3.3",
"author": "Joseph Larralde",
"license": "GPL-3.0",
"description": "node wrapper for the XMM library",
Expand Down Expand Up @@ -42,7 +42,7 @@
"jsdoc-template": "ircam-jstools/jsdoc-template",
"jsdoc-to-markdown": "^1.3.7",
"tape": "^4.6.0",
"xmm-client": "ircam-rnd/xmm-client"
"xmm-client": "github:ircam-rnd/xmm-client#v0.3.3"
},
"babel": {
"presets": [
Expand Down
4 changes: 2 additions & 2 deletions src/GmmTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include <iostream>
#include "XmmTool.h"

class GmmTool : public XmmTool<xmm::GMM> {
class GmmTool : public XmmTool<xmm::GMM, xmm::GMM> {

public:
GmmTool() {};
~GmmTool() {};
Expand Down
6 changes: 3 additions & 3 deletions src/HhmmTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include <iostream>
#include "XmmTool.h"

class HhmmTool : public XmmTool<xmm::HierarchicalHMM> {
class HhmmTool : public XmmTool<xmm::HierarchicalHMM, xmm::HMM> {

public:
HhmmTool() {};
~HhmmTool() {};

void setNbStates(std::size_t nbStates) {
if(nbStates > 0) {
model.configuration.states.set(nbStates, 1);
Expand Down
21 changes: 16 additions & 5 deletions src/XmmTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class XmmToolBase {
virtual void setAbsoluteRegularization(double absReg) = 0;
virtual xmm::GaussianDistribution::CovarianceMode getCovarianceMode() = 0;
virtual void setCovarianceMode(xmm::GaussianDistribution::CovarianceMode cm) = 0;
virtual xmm::MultiClassRegressionEstimator getMultiClassRegressionEstimator() = 0;
virtual void setMultiClassRegressionEstimator(xmm::MultiClassRegressionEstimator mre) = 0;
};

// the specializable template :
Expand All @@ -36,7 +38,7 @@ class XmmTool : public XmmToolBase {
private:
std::vector<XmmWrapTrainWorker<Model> *> workers;
// std::vector<Nan::Callback *> callbacks;

public:
Model model;

Expand All @@ -49,7 +51,7 @@ class XmmTool : public XmmToolBase {
}

~XmmTool() {}

void setBimodal(bool multimodality) {
// Model tmp = Model(model);
xmm::Configuration<ModelType> config = model.configuration;
Expand Down Expand Up @@ -101,7 +103,7 @@ class XmmTool : public XmmToolBase {

v8::Local<v8::Object> filter(std::vector<float> observation) {
v8::Local<v8::Object> outputResults = Nan::New<v8::Object>();

bool bimodal = model.shared_parameters->bimodal.get();
unsigned int nmodels = model.size();
unsigned int dimension = model.shared_parameters->dimension.get();
Expand Down Expand Up @@ -158,7 +160,7 @@ class XmmTool : public XmmToolBase {

if (bimodal) {
v8::Local<v8::Array> output_values = Nan::New<v8::Array>(dimension_output);
for (unsigned int i = 0; i < dimension_output; ++i) {
for (unsigned int i = 0; i < dimension_output; ++i) {
Nan::Set(output_values, i, Nan::New(res.output_values[i]));
}
outputResults->Set(
Expand All @@ -168,7 +170,7 @@ class XmmTool : public XmmToolBase {

unsigned int dim_out_cov = res.output_covariance.size();
v8::Local<v8::Array> output_covariance = Nan::New<v8::Array>(dim_out_cov);
for (unsigned int i = 0; i < dim_out_cov; ++i) {
for (unsigned int i = 0; i < dim_out_cov; ++i) {
Nan::Set(output_covariance, i, Nan::New(res.output_covariance[i]));
}
outputResults->Set(
Expand Down Expand Up @@ -218,6 +220,15 @@ class XmmTool : public XmmToolBase {
model.configuration.changed = true;
}

xmm::MultiClassRegressionEstimator getMultiClassRegressionEstimator() {
return model.configuration.multiClass_regression_estimator;
}

void setMultiClassRegressionEstimator(xmm::MultiClassRegressionEstimator mre) {
model.configuration.multiClass_regression_estimator = mre;
model.configuration.changed = true;
}

};

#endif /* _XMM_TOOL_H_ */
25 changes: 25 additions & 0 deletions src/XmmWrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
double relative_regularization = obj->model_->getRelativeRegularization();
double absolute_regularization = obj->model_->getAbsoluteRegularization();
xmm::GaussianDistribution::CovarianceMode covariance_mode = obj->model_->getCovarianceMode();
xmm::MultiClassRegressionEstimator multiclass_regression_estimator = obj->model_->getMultiClassRegressionEstimator();

// HierarchicalHMM-specific :
bool hierarchical = true;
Expand Down Expand Up @@ -490,6 +491,12 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
? "likeliest"
: ""));

std::string mre = (multiclass_regression_estimator == xmm::MultiClassRegressionEstimator::Likeliest)
? "likeliest"
: ((multiclass_regression_estimator == xmm::MultiClassRegressionEstimator::Mixture)
? "mixture"
: "");

std::string modelType;
switch (obj->modelType_) {
case XmmGmmE:
Expand Down Expand Up @@ -522,6 +529,8 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
args.GetReturnValue().Set(Nan::New<v8::Number>(absolute_regularization));
} else if (item == "covariance_mode") {
args.GetReturnValue().Set(Nan::New<v8::String>(cm).ToLocalChecked());
} else if (item == "multiclass_regression_etimator") {
args.GetReturnValue().Set(Nan::New<v8::String>(mre).ToLocalChecked());
}

if (obj->modelType_ == XmmHhmmE) {
Expand All @@ -547,6 +556,8 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
Nan::New<v8::Number>(absolute_regularization));
outputConfig->Set(Nan::New<v8::String>("covariance_mode").ToLocalChecked(),
Nan::New<v8::String>(cm).ToLocalChecked());
outputConfig->Set(Nan::New<v8::String>("multiClass_regression_estimator").ToLocalChecked(),
Nan::New<v8::String>(mre).ToLocalChecked());

if (obj->modelType_ == XmmHhmmE) {
outputConfig->Set(Nan::New<v8::String>("hierarchical").ToLocalChecked(),
Expand Down Expand Up @@ -577,6 +588,7 @@ void XmmWrap::setConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
double rr = obj->model_->getRelativeRegularization();
double ar = obj->model_->getAbsoluteRegularization();
xmm::GaussianDistribution::CovarianceMode cm = obj->model_->getCovarianceMode();
xmm::MultiClassRegressionEstimator mre = obj->model_->getMultiClassRegressionEstimator();

// HierarchicalHMM-specific :
bool h = true;
Expand Down Expand Up @@ -659,12 +671,25 @@ void XmmWrap::setConfig(const Nan::FunctionCallbackInfo<v8::Value> & args) {
}
}

v8::Local<v8::Value> multiclass_regression_estimator
= inputConfig->Get(Nan::New<v8::String>("multiClass_regression_estimator").ToLocalChecked());
if (multiclass_regression_estimator->IsString()) {
v8::String::Utf8Value val(multiclass_regression_estimator->ToString());
std::string smre = std::string(*val);
if (smre == "likeliest") {
mre = xmm::MultiClassRegressionEstimator::Likeliest;
} else if (smre == "mixture") {
mre = xmm::MultiClassRegressionEstimator::Mixture;
}
}

// =============== SET NEW VALUES ============== //

obj->model_->setGaussians(g);
obj->model_->setRelativeRegularization(rr);
obj->model_->setAbsoluteRegularization(ar);
obj->model_->setCovarianceMode(cm);
obj->model_->setMultiClassRegressionEstimator(mre);

if (obj->modelType_ == XmmHhmmE) {
setHierarchical(obj, h);
Expand Down
1 change: 1 addition & 0 deletions test/1_basic_tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ test('model configuration consistency', (t) => {
covarianceMode: 'full',
relativeRegularization: 0.1,
absoluteRegularization: 0.1,
multiClassRegressionEstimator: 'mixture',
// states: 1,
};

Expand Down
49 changes: 35 additions & 14 deletions test/2_training_tests.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xmm from '../index';
import { SetMaker } from 'xmm-client';
import { GmmDecoder, SetMaker } from 'xmm-client';
import test from 'tape';

test('training', (t) => {
Expand All @@ -8,14 +8,17 @@ test('training', (t) => {

const setMaker = new SetMaker();

const hhmm = new xmm('hhmm', {
hierarchical: false,
relativeRegularization: 0.1
const gmm = new xmm('gmm', {
// hierarchical: false,
relativeRegularization: 0.1,
multiClassRegressionEstimator: 'mixture',
});

const gmmClient = new GmmDecoder();

var p = {
bimodal: true,
dimension: 3,
dimension: 6,
dimension_input: 3,
column_names: [ "" ],
// data: [
Expand All @@ -24,39 +27,57 @@ test('training', (t) => {
// 3.7, 3.2
// ],
data_input: [1, 2, 3],
data_output: [1, 2, 3],
data_output: [0, 0, 0],
length: 1,
label: 'aLabel'
}
};

const trainMsgOne = 'train should return a null model when training is cancelled';

hhmm.train((err, res) => {
gmm.train((err, res) => {
t.equal(res, null, trainMsgOne);
});

hhmm.cancelTraining();
gmm.cancelTraining();

const trainMsgTwo = 'train should return an empty model when trained with empty set';

hhmm.train((err, res) => {
gmm.train((err, res) => {
t.deepEqual(res.models, [], trainMsgTwo);
});

const trainMsgThree = 'train should return a trained model';

for (let i = 0; i < 500; i++) {
setMaker.addPhrase(p);
for (let i = 0; i < 5; i++) {
setMaker.addPhrase(JSON.parse(JSON.stringify(p)));
// hhmm.addPhrase(JSON.parse(JSON.stringify(p)));
}
// hhmm.setTrainingSet(setMaker.getTrainingSet());

p.data_input = [3, 2, 1];
p.data_output = [10, 10, 10];
p.label = 'anotherLabel';

for (let i = 0; i < 5; i++) {
setMaker.addPhrase(JSON.parse(JSON.stringify(p)));
// hhmm.addPhrase(JSON.parse(JSON.stringify(p)));
}

const set = setMaker.getTrainingSet();
gmm.setTrainingSet(set);
console.log(set);
// hhmm.clearTrainingSet();
// hhmm.addTrainingSet(setMaker.getTrainingSet());

hhmm.train((err, res) => {
gmm.train((err, res) => {
console.log(JSON.stringify(res, null, 2));
gmmClient.setModel(res);

t.notEqual(res, null, trainMsgThree);
t.equal(res.models.length > 0, true, trainMsgThree);

console.log(gmmClient.filter([3, 2, 3]));
console.log(gmmClient.filter([3, 2, 1]));

// const setModelConfigMsg = 'config should not change when queried after setModel';
// let config = hhmm.getConfig();
// hhmm.setModel(res);
Expand Down

0 comments on commit 54b86f8

Please sign in to comment.