Skip to content

Commit a2041e3

Browse files
committed
Add XGBoost test files (#1452)
1 parent dedd92d commit a2041e3

File tree

5 files changed

+88
-5
lines changed

5 files changed

+88
-5
lines changed

source/gguf.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ gguf.Reader = class {
191191
static open(context) {
192192
const stream = context.stream;
193193
if (stream && stream.length > 4) {
194-
const signature = String.fromCharCode.apply(null, stream.peek(4));
194+
const buffer = stream.peek(4);
195+
const signature = String.fromCharCode.apply(null, buffer);
195196
if (signature === 'GGUF') {
196197
return new gguf.Reader(context);
197198
}

source/mlir.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ mlir.ModelFactory = class {
4040
const reader = await context.read('binary');
4141
const parser = new mlir.BytecodeReader(reader);
4242
parser.read();
43-
throw new mlir.Error('Invalid file content. File contains MLIR bytecode data.');
43+
throw new mlir.Error('File contains unsupported MLIR bytecode data.');
4444
}
4545
default: {
4646
throw new mlir.Error(`Unsupported MLIR format '${context.type}'.`);

source/view.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6151,6 +6151,7 @@ view.ModelFactoryService = class {
61516151
this.register('./weka', ['.model']);
61526152
this.register('./qnn', ['.json', '.bin', '.serialized', '.dlc']);
61536153
this.register('./kann', ['.kann', '.bin', '.kgraph'], [], [/^....KaNN/]);
6154+
this.register('./xgboost', ['.xgb', '.xgboost', '.json', '.model', '.bin', '.txt'], [], [/^{L\x00\x00/, /^binf/, /^bs64/, /^\s*booster\[0\]:/]);
61546155
this.register('', ['.cambricon', '.vnnmodel', '.nnc']);
61556156
/* eslint-enable no-control-regex */
61566157
}
@@ -6695,7 +6696,6 @@ view.ModelFactoryService = class {
66956696
{ name: 'Cambricon model', value: /^\x7fMEF/ },
66966697
{ name: 'Cambricon model', value: /^cambricon_offline/ },
66976698
{ name: 'VNN model', value: /^\x2F\x4E\x00\x00.\x00\x00\x00/, identifier: /.vnnmodel$/ },
6698-
{ name: 'XGBoost model', value: /^(binf|bs64)/ }, // https://github.com/dmlc/xgboost/blob/master/src/learner.cc
66996699
{ name: 'SQLite data', value: /^SQLite format/ },
67006700
{ name: 'Optimium model', value: /^EZMODEL/ }, // https://github.com/EZ-Optimium/Optimium,
67016701
{ name: 'undocumented NNC data', value: /^(\xC0|\xBC)\x0F\x00\x00ENNC/ },

source/xgboost.js

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
// Experimental
3+
4+
const xgboost = {};
5+
6+
xgboost.ModelFactory = class {
7+
8+
async match(context) {
9+
const obj = await context.peek('json');
10+
if (obj && obj.learner && obj.version) {
11+
return context.set('xgboost.json', obj);
12+
}
13+
const stream = context.stream;
14+
if (stream && stream.length > 4) {
15+
const buffer = stream.peek(4);
16+
if (buffer[0] === 0x7B && buffer[1] === 0x4C && buffer[2] === 0x00 && buffer[3] === 0x00) {
17+
return context.set('xgboost.ubj', stream);
18+
}
19+
const signature = String.fromCharCode.apply(null, buffer);
20+
if (signature.startsWith('binf')) {
21+
return context.set('xgboost.binf', stream);
22+
}
23+
if (signature.startsWith('bs64')) {
24+
return context.set('xgboost.bs64', stream);
25+
}
26+
const reader = await context.read('text', 0x100);
27+
const line = reader.read('\n');
28+
if (line !== undefined && line.trim() === 'booster[0]:') {
29+
return context.set('xgboost.text', stream);
30+
}
31+
}
32+
return null;
33+
}
34+
35+
async open(context) {
36+
if (context.type === 'xgboost.json') {
37+
throw new xgboost.Error('File contains unsupported XGBoost JSON data.');
38+
}
39+
if (context.type === 'xgboost.text') {
40+
throw new xgboost.Error('File contains unsupported XGBoost text data.');
41+
}
42+
throw new xgboost.Error('File contains unsupported XGBoost data.');
43+
}
44+
};
45+
46+
xgboost.Error = class extends Error {
47+
48+
constructor(message) {
49+
super(message);
50+
this.name = 'Error loading XGBoost model.';
51+
}
52+
};
53+
54+
export const ModelFactory = xgboost.ModelFactory;

test/models.json

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3312,7 +3312,7 @@
33123312
"target": "model.mlirbc",
33133313
"source": "https://github.com/user-attachments/files/17179955/model.mlirbc.zip[model.mlirbc]",
33143314
"format": "MLIR",
3315-
"error": "Invalid file content. File contains MLIR bytecode data.",
3315+
"error": "File contains unsupported MLIR bytecode data.",
33163316
"link": "https://github.com/lutzroeder/netron/issues/1044"
33173317
},
33183318
{
@@ -3363,7 +3363,7 @@
33633363
"target": "versioned-op-2.0.mlirbc",
33643364
"source": "https://github.com/user-attachments/files/17174958/versioned-op-2.0.mlirbc.zip[versioned-op-2.0.mlirbc]",
33653365
"format": "MLIR",
3366-
"error": "Invalid file content. File contains MLIR bytecode data.",
3366+
"error": "File contains unsupported MLIR bytecode data.",
33673367
"link": "https://github.com/lutzroeder/netron/issues/1044"
33683368
},
33693369
{
@@ -8451,6 +8451,34 @@
84518451
"source": "https://raw.githubusercontent.com/PTaati/wekaTree2python/master/j48model.model",
84528452
"error": "Unsupported type 'weka.classifiers.trees.J48'."
84538453
},
8454+
{
8455+
"type": "xgboost",
8456+
"target": "xgb_classifier.json",
8457+
"source": "https://github.com/user-attachments/files/20028544/xgb_classifier.zip[xgb_classifier.json]",
8458+
"error": "File contains unsupported XGBoost JSON data.",
8459+
"link": "https://github.com/lutzroeder/netron/issues/1452"
8460+
},
8461+
{
8462+
"type": "xgboost",
8463+
"target": "xgb_classifier.model",
8464+
"source": "https://github.com/user-attachments/files/20028544/xgb_classifier.zip[xgb_classifier.model]",
8465+
"error": "File contains unsupported XGBoost data.",
8466+
"link": "https://github.com/lutzroeder/netron/issues/1452"
8467+
},
8468+
{
8469+
"type": "xgboost",
8470+
"target": "xgb_classifier.pkl",
8471+
"source": "https://github.com/user-attachments/files/20028544/xgb_classifier.zip[xgb_classifier.pkl]",
8472+
"format": "scikit-learn",
8473+
"link": "https://github.com/lutzroeder/netron/issues/1452"
8474+
},
8475+
{
8476+
"type": "xgboost",
8477+
"target": "xgb_classifier.txt",
8478+
"source": "https://github.com/user-attachments/files/20028544/xgb_classifier.zip[xgb_classifier.txt]",
8479+
"error": "File contains unsupported XGBoost text data.",
8480+
"link": "https://github.com/lutzroeder/netron/issues/1452"
8481+
},
84548482
{
84558483
"type": "xmodel",
84568484
"target": "face-quality_pt.xmodel",

0 commit comments

Comments
 (0)