15
15
import java .io .IOException ;
16
16
import java .util .Arrays ;
17
17
import java .util .ArrayList ;
18
+ import java .util .HashMap ;
18
19
import java .util .List ;
19
20
import java .util .ListIterator ;
21
+ import java .util .Map ;
20
22
import java .util .Optional ;
21
23
22
24
public class XGBoostRawJsonParser implements LtrRankerParser {
@@ -37,9 +39,42 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
37
39
}
38
40
39
41
NaiveAdditiveDecisionTree .Node [] trees = modelDefinition .getLearner ().getTrees (set );
42
+ List <String > modelFeatures = modelDefinition .learner .featureNames ;
43
+
44
+ // remap features according to the order in the feature set
45
+ Map <Integer , Integer > modelFeaturesReordering = new HashMap <>();
46
+ for (int i = 0 ; i < modelFeatures .size (); i ++) {
47
+ modelFeaturesReordering .put (i , set .featureOrdinal (modelFeatures .get (i )));
48
+ }
49
+
50
+ // Reorder features in each tree
51
+ NaiveAdditiveDecisionTree .Node [] adjustedTrees = new NaiveAdditiveDecisionTree .Node [trees .length ];
52
+ for (int i = 0 ; i < trees .length ; i ++) {
53
+ adjustedTrees [i ] = reorderTreeFeatures (trees [i ], modelFeaturesReordering );
54
+ }
55
+
40
56
float [] weights = new float [trees .length ];
41
57
Arrays .fill (weights , 1F );
42
- return new NaiveAdditiveDecisionTree (trees , weights , set .size (), modelDefinition .getLearner ().getObjective ().getNormalizer ());
58
+ return new NaiveAdditiveDecisionTree (
59
+ adjustedTrees , weights , set .size (), modelDefinition .getLearner ().getObjective ().getNormalizer ()
60
+ );
61
+ }
62
+
63
+ private NaiveAdditiveDecisionTree .Node reorderTreeFeatures (NaiveAdditiveDecisionTree .Node node ,
64
+ Map <Integer , Integer > modelFeaturesReordering ) {
65
+ if (node instanceof NaiveAdditiveDecisionTree .Split splitNode ) {
66
+ return new NaiveAdditiveDecisionTree .Split (
67
+ reorderTreeFeatures (splitNode .getLeft (), modelFeaturesReordering ),
68
+ reorderTreeFeatures (splitNode .getRight (), modelFeaturesReordering ),
69
+ modelFeaturesReordering .get (splitNode .getFeature ()),
70
+ splitNode .getThreshold (),
71
+ splitNode .getLeftNodeId (),
72
+ splitNode .getMissingNodeId ()
73
+ );
74
+ }
75
+
76
+ // if the node is Leaf we don't do anything
77
+ return node ;
43
78
}
44
79
45
80
private static class XGBoostDefinition {
@@ -95,6 +130,7 @@ public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser
95
130
} else {
96
131
throw new ParsingException (parser .getTokenLocation (), "Expected [START_OBJECT] but got [" + startToken + "]" );
97
132
}
133
+
98
134
return definition ;
99
135
}
100
136
0 commit comments