Skip to content

Commit a1367ea

Browse files
authored
Set feature_names and feature_types in jvm-packages (dmlc#9364)
* 1. Add parameters to set feature names and feature types 2. Save feature names and feature types to native json model * Change serialization and deserialization format to ubj.
1 parent 3632242 commit a1367ea

File tree

12 files changed

+295
-8
lines changed

12 files changed

+295
-8
lines changed

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

+33-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ private[scala] case class XGBoostExecutionParams(
7474
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
7575
cacheTrainingSet: Boolean,
7676
treeMethod: Option[String],
77-
isLocal: Boolean) {
77+
isLocal: Boolean,
78+
featureNames: Option[Array[String]],
79+
featureTypes: Option[Array[String]]) {
7880

7981
private var rawParamMap: Map[String, Any] = _
8082

@@ -213,14 +215,24 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
213215
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
214216
.asInstanceOf[Boolean]
215217

218+
val featureNames = if (overridedParams.contains("feature_names")) {
219+
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
220+
} else None
221+
val featureTypes = if (overridedParams.contains("feature_types")){
222+
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
223+
} else None
224+
216225
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
217226
missing, allowNonZeroForMissing, trackerConf,
218227
checkpointParam,
219228
inputParams,
220229
xgbExecEarlyStoppingParams,
221230
cacheTrainingSet,
222231
treeMethod,
223-
isLocal)
232+
isLocal,
233+
featureNames,
234+
featureTypes
235+
)
224236
xgbExecParam.setRawParamMap(overridedParams)
225237
xgbExecParam
226238
}
@@ -531,6 +543,16 @@ private object Watches {
531543
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
532544
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
533545

546+
if (xgbExecutionParams.featureNames.isDefined) {
547+
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
548+
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
549+
}
550+
551+
if (xgbExecutionParams.featureTypes.isDefined) {
552+
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
553+
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
554+
}
555+
534556
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
535557
}
536558

@@ -643,6 +665,15 @@ private object Watches {
643665
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
644666
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
645667

668+
if (xgbExecutionParams.featureNames.isDefined) {
669+
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
670+
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
671+
}
672+
if (xgbExecutionParams.featureTypes.isDefined) {
673+
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
674+
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
675+
}
676+
646677
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
647678
}
648679
}

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala

+6
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ class XGBoostClassifier (
139139
def setSinglePrecisionHistogram(value: Boolean): this.type =
140140
set(singlePrecisionHistogram, value)
141141

142+
def setFeatureNames(value: Array[String]): this.type =
143+
set(featureNames, value)
144+
145+
def setFeatureTypes(value: Array[String]): this.type =
146+
set(featureTypes, value)
147+
142148
// called at the start of fit/train when 'eval_metric' is not defined
143149
private def setupDefaultEvalMetric(): String = {
144150
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala

+6
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ class XGBoostRegressor (
141141
def setSinglePrecisionHistogram(value: Boolean): this.type =
142142
set(singlePrecisionHistogram, value)
143143

144+
def setFeatureNames(value: Array[String]): this.type =
145+
set(featureNames, value)
146+
147+
def setFeatureTypes(value: Array[String]): this.type =
148+
set(featureTypes, value)
149+
144150
// called at the start of fit/train when 'eval_metric' is not defined
145151
private def setupDefaultEvalMetric(): String = {
146152
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala

+15
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,21 @@ private[spark] trait GeneralParams extends Params {
177177

178178
final def getSeed: Long = $(seed)
179179

180+
/** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
181+
* In native code, the parameter name is feature_name.
182+
* */
183+
final val featureNames = new StringArrayParam(this, "feature_names",
184+
"an array of feature names")
185+
186+
final def getFeatureNames: Array[String] = $(featureNames)
187+
188+
/** Feature types, q is numeric and c is categorical.
189+
* In native code, the parameter name is feature_type
190+
* */
191+
final val featureTypes = new StringArrayParam(this, "feature_types",
192+
"an array of feature types")
193+
194+
final def getFeatureTypes: Array[String] = $(featureTypes)
180195
}
181196

182197
trait HasLeafPredictionCol extends Params {

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala

+24
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils
2727

2828
import org.apache.spark.Partitioner
2929
import org.apache.spark.ml.feature.VectorAssembler
30+
import org.json4s.{DefaultFormats, Formats}
31+
import org.json4s.jackson.parseJson
3032

3133
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
3234

@@ -453,4 +455,26 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
453455
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
454456
nativeUbjModelPath))
455457
}
458+
459+
test("native json model file should store feature_name and feature_type") {
460+
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
461+
val featureTypes = (1 to 33).map(idx => "q").toArray
462+
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
463+
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
464+
"num_workers" -> numWorkers, "tree_method" -> treeMethod
465+
)
466+
val trainingDF = buildDataFrame(MultiClassification.train)
467+
val xgb = new XGBoostClassifier(paramMap)
468+
.setFeatureNames(featureNames)
469+
.setFeatureTypes(featureTypes)
470+
val model = xgb.fit(trainingDF)
471+
val modelStr = new String(model._booster.toByteArray("json"))
472+
System.out.println(modelStr)
473+
val jsonModel = parseJson(modelStr)
474+
implicit val formats: Formats = DefaultFormats
475+
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
476+
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
477+
assert(featureNamesInModel.length == 33)
478+
assert(featureTypesInModel.length == 33)
479+
}
456480
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java

+47-2
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,51 @@ public void setAttrs(Map<String, String> attrs) throws XGBoostError {
162162
}
163163
}
164164

165+
/**
166+
* Get feature names from the Booster.
167+
* @return
168+
* @throws XGBoostError
169+
*/
170+
public final String[] getFeatureNames() throws XGBoostError {
171+
int numFeature = (int) getNumFeature();
172+
String[] out = new String[numFeature];
173+
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out));
174+
return out;
175+
}
176+
177+
/**
178+
* Set feature names to the Booster.
179+
*
180+
* @param featureNames
181+
* @throws XGBoostError
182+
*/
183+
public void setFeatureNames(String[] featureNames) throws XGBoostError {
184+
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
185+
handle, "feature_name", featureNames));
186+
}
187+
188+
/**
189+
* Get feature types from the Booster.
190+
* @return
191+
* @throws XGBoostError
192+
*/
193+
public final String[] getFeatureTypes() throws XGBoostError {
194+
int numFeature = (int) getNumFeature();
195+
String[] out = new String[numFeature];
196+
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out));
197+
return out;
198+
}
199+
200+
/**
201+
* Set feature types to the Booster.
202+
* @param featureTypes
203+
* @throws XGBoostError
204+
*/
205+
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
206+
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
207+
handle, "feature_type", featureTypes));
208+
}
209+
165210
/**
166211
* Update the booster for one iteration.
167212
*
@@ -744,7 +789,7 @@ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
744789
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
745790
try {
746791
out.writeInt(version);
747-
out.writeObject(this.toByteArray());
792+
out.writeObject(this.toByteArray("ubj"));
748793
} catch (XGBoostError ex) {
749794
ex.printStackTrace();
750795
logger.error(ex.getMessage());
@@ -780,7 +825,7 @@ public synchronized void dispose() {
780825
@Override
781826
public void write(Kryo kryo, Output output) {
782827
try {
783-
byte[] serObj = this.toByteArray();
828+
byte[] serObj = this.toByteArray("ubj");
784829
int serObjSize = serObj.length;
785830
output.writeInt(serObjSize);
786831
output.writeInt(version);

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java

+2
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ public static Booster trainAndSaveCheckpoint(
198198
if (booster == null) {
199199
// Start training on a new booster
200200
booster = new Booster(params, allMats);
201+
booster.setFeatureNames(dtrain.getFeatureNames());
202+
booster.setFeatureTypes(dtrain.getFeatureTypes());
201203
booster.loadRabitCheckpoint();
202204
} else {
203205
// Start training on an existing booster

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java

+4
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,8 @@ public final static native int XGQuantileDMatrixCreateFromCallback(
164164
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
165165
String featureJson, float missing, int nthread, long[] out);
166166

167+
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
168+
169+
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
170+
167171
}

jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala

+40
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
205205
jDMatrix.setBaseMargin(column)
206206
}
207207

208+
/**
209+
* set feature names
210+
* @param values feature names
211+
* @throws ml.dmlc.xgboost4j.java.XGBoostError
212+
*/
213+
@throws(classOf[XGBoostError])
214+
def setFeatureNames(values: Array[String]): Unit = {
215+
jDMatrix.setFeatureNames(values)
216+
}
217+
218+
/**
219+
* set feature types
220+
* @param values feature types
221+
* @throws ml.dmlc.xgboost4j.java.XGBoostError
222+
*/
223+
@throws(classOf[XGBoostError])
224+
def setFeatureTypes(values: Array[String]): Unit = {
225+
jDMatrix.setFeatureTypes(values)
226+
}
227+
208228
/**
209229
* Get group sizes of DMatrix (used for ranking)
210230
*/
@@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
243263
jDMatrix.getBaseMargin
244264
}
245265

266+
/**
267+
* get feature names
268+
* @throws ml.dmlc.xgboost4j.java.XGBoostError
269+
* @return
270+
*/
271+
@throws(classOf[XGBoostError])
272+
def getFeatureNames: Array[String] = {
273+
jDMatrix.getFeatureNames
274+
}
275+
276+
/**
277+
* get feature types
278+
* @throws ml.dmlc.xgboost4j.java.XGBoostError
279+
* @return
280+
*/
281+
@throws(classOf[XGBoostError])
282+
def getFeatureTypes: Array[String] = {
283+
jDMatrix.getFeatureTypes
284+
}
285+
246286
/**
247287
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
248288
*

jvm-packages/xgboost4j/src/native/xgboost4j.cpp

+65
Original file line numberDiff line numberDiff line change
@@ -1148,3 +1148,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
11481148
if (field) jenv->ReleaseStringUTFChars(jfield, field);
11491149
return ret;
11501150
}
1151+
1152+
/*
1153+
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
1154+
* Method: XGBoosterSetStrFeatureInfo
1155+
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
1156+
*/
1157+
JNIEXPORT jint JNICALL
1158+
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
1159+
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
1160+
jobjectArray jfeatures) {
1161+
BoosterHandle handle = (BoosterHandle)jhandle;
1162+
1163+
const char *field = jenv->GetStringUTFChars(jfield, 0);
1164+
1165+
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures);
1166+
1167+
std::vector<std::string> features;
1168+
std::vector<char const*> features_char;
1169+
1170+
for (bst_ulong i = 0; i < feature_num; ++i) {
1171+
jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
1172+
const char *s = jenv->GetStringUTFChars(jfeature, 0);
1173+
features.push_back(std::string(s, jenv->GetStringLength(jfeature)));
1174+
if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s);
1175+
}
1176+
1177+
for (size_t i = 0; i < features.size(); ++i) {
1178+
features_char.push_back(features[i].c_str());
1179+
}
1180+
1181+
int ret = XGBoosterSetStrFeatureInfo(
1182+
handle, field, dmlc::BeginPtr(features_char), feature_num);
1183+
JVM_CHECK_CALL(ret);
1184+
return ret;
1185+
}
1186+
1187+
/*
1188+
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
1189+
* Method: XGBoosterSetGtrFeatureInfo
1190+
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
1191+
*/
1192+
JNIEXPORT jint JNICALL
1193+
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
1194+
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
1195+
jobjectArray jout) {
1196+
BoosterHandle handle = (BoosterHandle)jhandle;
1197+
1198+
const char *field = jenv->GetStringUTFChars(jfield, 0);
1199+
1200+
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout);
1201+
1202+
const char **features;
1203+
std::vector<char *> features_char;
1204+
1205+
int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num,
1206+
(const char ***)&features);
1207+
JVM_CHECK_CALL(ret);
1208+
1209+
for (bst_ulong i = 0; i < feature_num; i++) {
1210+
jstring jfeature = jenv->NewStringUTF(features[i]);
1211+
jenv->SetObjectArrayElement(jout, i, jfeature);
1212+
}
1213+
1214+
return ret;
1215+
}

jvm-packages/xgboost4j/src/native/xgboost4j.h

+18
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)