diff --git a/index.js b/index.js index 9a4ed19..bb0eaf4 100644 --- a/index.js +++ b/index.js @@ -1,8 +1,8 @@ /* * This library is developed by the ISMM (http://ismm.ircam.fr/) team at IRCAM, * within the context of the RAPID-MIX (http://rapidmix.goldsmithsdigital.com/) - * project, funded by the European Union’s Horizon 2020 research and innovation programme. - * Original XMM code authored by Jules Françoise, ported to Node.js by Joseph Larralde. + * project, funded by the European Union’s Horizon 2020 research and innovation programme. + * Original XMM code authored by Jules Françoise, ported to Node.js by Joseph Larralde. * See https://github.com/Ircam-RnD/xmm for detailed XMM credits. */ @@ -17,9 +17,9 @@ var XmmNative = null; // Load the precompiled binary for windows. // if (process.platform == "win32" && process.arch == "x64") { -// XmmNative = require('./bin/winx64/xmm'); +// XmmNative = require('./bin/winx64/xmm'); // } else if(process.platform == "win32" && process.arch == "ia32") { -// XmmNative = require('./bin/winx86/xmm'); +// XmmNative = require('./bin/winx86/xmm'); // } else { var binary = require('node-pre-gyp'); @@ -67,6 +67,8 @@ function translateFromXmmConfigProp(prop) { return 'transitionMode'; } else if (prop === 'regression_estimator') { return 'regressionEstimator'; + } else if (prop === 'multiClass_regression_estimator') { + return 'multiClassRegressionEstimator'; } else { return prop; } @@ -83,6 +85,8 @@ function translateToXmmConfigProp(prop) { return 'transition_mode'; } else if (prop === 'regressionEstimator') { return 'regression_estimator'; + } else if (prop === 'multiClassRegressionEstimator') { + return 'multiClass_regression_estimator'; } else { return prop; } @@ -106,7 +110,7 @@ Xmm.prototype.getConfig = function(prop) { Xmm.prototype.setConfig = function(config) { var inConfig = {}; for (var prop in config) { - var translatedProp = translateToXmmConfigProp(prop); + var translatedProp = translateToXmmConfigProp(prop); inConfig[translatedProp] = config[prop]; } this._xmm.setConfig(inConfig); diff --git a/package.json b/package.json index e592894..c89b7a2 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "xmm-node", - "version": "0.3.2", + "version": "0.3.3", "author": "Joseph Larralde", "license": "GPL-3.0", "description": "node wrapper for the XMM library", @@ -42,7 +42,7 @@ "jsdoc-template": "ircam-jstools/jsdoc-template", "jsdoc-to-markdown": "^1.3.7", "tape": "^4.6.0", - "xmm-client": "ircam-rnd/xmm-client" + "xmm-client": "github:ircam-rnd/xmm-client#v0.3.3" }, "babel": { "presets": [ diff --git a/src/GmmTool.h b/src/GmmTool.h index 8cdd551..989eaa8 100644 --- a/src/GmmTool.h +++ b/src/GmmTool.h @@ -9,8 +9,8 @@ #include #include "XmmTool.h" -class GmmTool : public XmmTool { - +class GmmTool : public XmmTool { + public: GmmTool() {}; ~GmmTool() {}; diff --git a/src/HhmmTool.h b/src/HhmmTool.h index 662022f..252097b 100644 --- a/src/HhmmTool.h +++ b/src/HhmmTool.h @@ -9,12 +9,12 @@ #include #include "XmmTool.h" -class HhmmTool : public XmmTool { - +class HhmmTool : public XmmTool { + public: HhmmTool() {}; ~HhmmTool() {}; - + void setNbStates(std::size_t nbStates) { if(nbStates > 0) { model.configuration.states.set(nbStates, 1); diff --git a/src/XmmTool.h b/src/XmmTool.h index 39e022b..31c5f81 100644 --- a/src/XmmTool.h +++ b/src/XmmTool.h @@ -26,6 +26,8 @@ class XmmToolBase { virtual void setAbsoluteRegularization(double absReg) = 0; virtual xmm::GaussianDistribution::CovarianceMode getCovarianceMode() = 0; virtual void setCovarianceMode(xmm::GaussianDistribution::CovarianceMode cm) = 0; + virtual xmm::MultiClassRegressionEstimator getMultiClassRegressionEstimator() = 0; + virtual void setMultiClassRegressionEstimator(xmm::MultiClassRegressionEstimator mre) = 0; }; // the specializable template : @@ -36,7 +38,7 @@ class XmmTool : public XmmToolBase { private: std::vector *> workers; // std::vector callbacks; - + public: Model model; @@ -49,7 +51,7 @@ class XmmTool : public XmmToolBase { } ~XmmTool() {} - + void setBimodal(bool multimodality) { // Model tmp = Model(model); xmm::Configuration config = model.configuration; @@ -101,7 +103,7 @@ class XmmTool : public XmmToolBase { v8::Local filter(std::vector observation) { v8::Local outputResults = Nan::New(); - + bool bimodal = model.shared_parameters->bimodal.get(); unsigned int nmodels = model.size(); unsigned int dimension = model.shared_parameters->dimension.get(); @@ -158,7 +160,7 @@ class XmmTool : public XmmToolBase { if (bimodal) { v8::Local output_values = Nan::New(dimension_output); - for (unsigned int i = 0; i < dimension_output; ++i) { + for (unsigned int i = 0; i < dimension_output; ++i) { Nan::Set(output_values, i, Nan::New(res.output_values[i])); } outputResults->Set( @@ -168,7 +170,7 @@ class XmmTool : public XmmToolBase { unsigned int dim_out_cov = res.output_covariance.size(); v8::Local output_covariance = Nan::New(dim_out_cov); - for (unsigned int i = 0; i < dim_out_cov; ++i) { + for (unsigned int i = 0; i < dim_out_cov; ++i) { Nan::Set(output_covariance, i, Nan::New(res.output_covariance[i])); } outputResults->Set( @@ -218,6 +220,15 @@ class XmmTool : public XmmToolBase { model.configuration.changed = true; } + xmm::MultiClassRegressionEstimator getMultiClassRegressionEstimator() { + return model.configuration.multiClass_regression_estimator; + } + + void setMultiClassRegressionEstimator(xmm::MultiClassRegressionEstimator mre) { + model.configuration.multiClass_regression_estimator = mre; + model.configuration.changed = true; + } + }; #endif /* _XMM_TOOL_H_ */ diff --git a/src/XmmWrap.cpp b/src/XmmWrap.cpp index f8a55ae..8f786d5 100644 --- a/src/XmmWrap.cpp +++ b/src/XmmWrap.cpp @@ -456,6 +456,7 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo & args) { double relative_regularization = obj->model_->getRelativeRegularization(); double absolute_regularization = obj->model_->getAbsoluteRegularization(); xmm::GaussianDistribution::CovarianceMode covariance_mode = obj->model_->getCovarianceMode(); + xmm::MultiClassRegressionEstimator multiclass_regression_estimator = obj->model_->getMultiClassRegressionEstimator(); // HierarchicalHMM-specific : bool hierarchical = true; @@ -490,6 +491,12 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo & args) { ? "likeliest" : "")); + std::string mre = (multiclass_regression_estimator == xmm::MultiClassRegressionEstimator::Likeliest) + ? "likeliest" + : ((multiclass_regression_estimator == xmm::MultiClassRegressionEstimator::Mixture) + ? "mixture" + : ""); + std::string modelType; switch (obj->modelType_) { case XmmGmmE: @@ -522,6 +529,8 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo & args) { args.GetReturnValue().Set(Nan::New(absolute_regularization)); } else if (item == "covariance_mode") { args.GetReturnValue().Set(Nan::New(cm).ToLocalChecked()); + } else if (item == "multiclass_regression_etimator") { + args.GetReturnValue().Set(Nan::New(mre).ToLocalChecked()); } if (obj->modelType_ == XmmHhmmE) { @@ -547,6 +556,8 @@ void XmmWrap::getConfig(const Nan::FunctionCallbackInfo & args) { Nan::New(absolute_regularization)); outputConfig->Set(Nan::New("covariance_mode").ToLocalChecked(), Nan::New(cm).ToLocalChecked()); + outputConfig->Set(Nan::New("multiClass_regression_estimator").ToLocalChecked(), + Nan::New(mre).ToLocalChecked()); if (obj->modelType_ == XmmHhmmE) { outputConfig->Set(Nan::New("hierarchical").ToLocalChecked(), @@ -577,6 +588,7 @@ void XmmWrap::setConfig(const Nan::FunctionCallbackInfo & args) { double rr = obj->model_->getRelativeRegularization(); double ar = obj->model_->getAbsoluteRegularization(); xmm::GaussianDistribution::CovarianceMode cm = obj->model_->getCovarianceMode(); + xmm::MultiClassRegressionEstimator mre = obj->model_->getMultiClassRegressionEstimator(); // HierarchicalHMM-specific : bool h = true; @@ -659,12 +671,25 @@ void XmmWrap::setConfig(const Nan::FunctionCallbackInfo & args) { } } + v8::Local multiclass_regression_estimator + = inputConfig->Get(Nan::New("multiClass_regression_estimator").ToLocalChecked()); + if (multiclass_regression_estimator->IsString()) { + v8::String::Utf8Value val(multiclass_regression_estimator->ToString()); + std::string smre = std::string(*val); + if (smre == "likeliest") { + mre = xmm::MultiClassRegressionEstimator::Likeliest; + } else if (smre == "mixture") { + mre = xmm::MultiClassRegressionEstimator::Mixture; + } + } + // =============== SET NEW VALUES ============== // obj->model_->setGaussians(g); obj->model_->setRelativeRegularization(rr); obj->model_->setAbsoluteRegularization(ar); obj->model_->setCovarianceMode(cm); + obj->model_->setMultiClassRegressionEstimator(mre); if (obj->modelType_ == XmmHhmmE) { setHierarchical(obj, h); diff --git a/test/1_basic_tests.js b/test/1_basic_tests.js index 43abfe2..bd6612f 100644 --- a/test/1_basic_tests.js +++ b/test/1_basic_tests.js @@ -15,6 +15,7 @@ test('model configuration consistency', (t) => { covarianceMode: 'full', relativeRegularization: 0.1, absoluteRegularization: 0.1, + multiClassRegressionEstimator: 'mixture', // states: 1, }; diff --git a/test/2_training_tests.js b/test/2_training_tests.js index 62976f1..5dc241e 100644 --- a/test/2_training_tests.js +++ b/test/2_training_tests.js @@ -1,5 +1,5 @@ import xmm from '../index'; -import { SetMaker } from 'xmm-client'; +import { GmmDecoder, SetMaker } from 'xmm-client'; import test from 'tape'; test('training', (t) => { @@ -8,14 +8,17 @@ test('training', (t) => { const setMaker = new SetMaker(); - const hhmm = new xmm('hhmm', { - hierarchical: false, - relativeRegularization: 0.1 + const gmm = new xmm('gmm', { + // hierarchical: false, + relativeRegularization: 0.1, + multiClassRegressionEstimator: 'mixture', }); + const gmmClient = new GmmDecoder(); + var p = { bimodal: true, - dimension: 3, + dimension: 6, dimension_input: 3, column_names: [ "" ], // data: [ @@ -24,39 +27,57 @@ test('training', (t) => { // 3.7, 3.2 // ], data_input: [1, 2, 3], - data_output: [1, 2, 3], + data_output: [0, 0, 0], length: 1, label: 'aLabel' - } + }; const trainMsgOne = 'train should return a null model when training is cancelled'; - hhmm.train((err, res) => { + gmm.train((err, res) => { t.equal(res, null, trainMsgOne); }); - hhmm.cancelTraining(); + gmm.cancelTraining(); const trainMsgTwo = 'train should return an empty model when trained with empty set'; - hhmm.train((err, res) => { + gmm.train((err, res) => { t.deepEqual(res.models, [], trainMsgTwo); }); const trainMsgThree = 'train should return a trained model'; - for (let i = 0; i < 500; i++) { - setMaker.addPhrase(p); + for (let i = 0; i < 5; i++) { + setMaker.addPhrase(JSON.parse(JSON.stringify(p))); // hhmm.addPhrase(JSON.parse(JSON.stringify(p))); } - // hhmm.setTrainingSet(setMaker.getTrainingSet()); + + p.data_input = [3, 2, 1]; + p.data_output = [10, 10, 10]; + p.label = 'anotherLabel'; + + for (let i = 0; i < 5; i++) { + setMaker.addPhrase(JSON.parse(JSON.stringify(p))); + // hhmm.addPhrase(JSON.parse(JSON.stringify(p))); + } + + const set = setMaker.getTrainingSet(); + gmm.setTrainingSet(set); + console.log(set); // hhmm.clearTrainingSet(); // hhmm.addTrainingSet(setMaker.getTrainingSet()); - hhmm.train((err, res) => { + gmm.train((err, res) => { + console.log(JSON.stringify(res, null, 2)); + gmmClient.setModel(res); + t.notEqual(res, null, trainMsgThree); t.equal(res.models.length > 0, true, trainMsgThree); + console.log(gmmClient.filter([3, 2, 3])); + console.log(gmmClient.filter([3, 2, 1])); + // const setModelConfigMsg = 'config should not change when queried after setModel'; // let config = hhmm.getConfig(); // hhmm.setModel(res);