Skip to content

Commit 40ac075

Browse files
authored
BUGFIX: Automatically reorder features in the xgboost raw model parser (#508)
1 parent 96f3485 commit 40ac075

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java

+24
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,30 @@ else if (s.threshold > scores[s.feature]) {
136136
return n.eval(scores);
137137
}
138138

139+
public Node getLeft() {
140+
return this.left;
141+
}
142+
143+
public Node getRight() {
144+
return this.right;
145+
}
146+
147+
public int getFeature() {
148+
return this.feature;
149+
}
150+
151+
public float getThreshold() {
152+
return this.threshold;
153+
}
154+
155+
public int getLeftNodeId() {
156+
return this.leftNodeId;
157+
}
158+
159+
public int getMissingNodeId() {
160+
return this.missingNodeId;
161+
}
162+
139163
/**
140164
* Return the memory usage of this object in bytes. Negative values are illegal.
141165
*/

src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java

+37-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import java.io.IOException;
1616
import java.util.Arrays;
1717
import java.util.ArrayList;
18+
import java.util.HashMap;
1819
import java.util.List;
1920
import java.util.ListIterator;
21+
import java.util.Map;
2022
import java.util.Optional;
2123

2224
public class XGBoostRawJsonParser implements LtrRankerParser {
@@ -37,9 +39,42 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
3739
}
3840

3941
NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getLearner().getTrees(set);
42+
List<String> modelFeatures = modelDefinition.learner.featureNames;
43+
44+
// remap features according to the order in the feature set
45+
Map<Integer, Integer> modelFeaturesReordering = new HashMap<>();
46+
for (int i = 0; i < modelFeatures.size(); i++) {
47+
modelFeaturesReordering.put(i, set.featureOrdinal(modelFeatures.get(i)));
48+
}
49+
50+
// Reorder features in each tree
51+
NaiveAdditiveDecisionTree.Node[] adjustedTrees = new NaiveAdditiveDecisionTree.Node[trees.length];
52+
for (int i = 0; i < trees.length; i++) {
53+
adjustedTrees[i] = reorderTreeFeatures(trees[i], modelFeaturesReordering);
54+
}
55+
4056
float[] weights = new float[trees.length];
4157
Arrays.fill(weights, 1F);
42-
return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.getLearner().getObjective().getNormalizer());
58+
return new NaiveAdditiveDecisionTree(
59+
adjustedTrees, weights, set.size(), modelDefinition.getLearner().getObjective().getNormalizer()
60+
);
61+
}
62+
63+
private NaiveAdditiveDecisionTree.Node reorderTreeFeatures(NaiveAdditiveDecisionTree.Node node,
64+
Map<Integer, Integer> modelFeaturesReordering) {
65+
if (node instanceof NaiveAdditiveDecisionTree.Split splitNode) {
66+
return new NaiveAdditiveDecisionTree.Split(
67+
reorderTreeFeatures(splitNode.getLeft(), modelFeaturesReordering),
68+
reorderTreeFeatures(splitNode.getRight(), modelFeaturesReordering),
69+
modelFeaturesReordering.get(splitNode.getFeature()),
70+
splitNode.getThreshold(),
71+
splitNode.getLeftNodeId(),
72+
splitNode.getMissingNodeId()
73+
);
74+
}
75+
76+
// if the node is Leaf we don't do anything
77+
return node;
4378
}
4479

4580
private static class XGBoostDefinition {
@@ -95,6 +130,7 @@ public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser
95130
} else {
96131
throw new ParsingException(parser.getTokenLocation(), "Expected [START_OBJECT] but got [" + startToken + "]");
97132
}
133+
98134
return definition;
99135
}
100136

0 commit comments

Comments
 (0)