@@ -74,7 +74,9 @@ private[scala] case class XGBoostExecutionParams(
74
74
earlyStoppingParams : XGBoostExecutionEarlyStoppingParams ,
75
75
cacheTrainingSet : Boolean ,
76
76
treeMethod : Option [String ],
77
- isLocal : Boolean ) {
77
+ isLocal : Boolean ,
78
+ featureNames : Option [Array [String ]],
79
+ featureTypes : Option [Array [String ]]) {
78
80
79
81
private var rawParamMap : Map [String , Any ] = _
80
82
@@ -213,14 +215,24 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
213
215
val cacheTrainingSet = overridedParams.getOrElse(" cache_training_set" , false )
214
216
.asInstanceOf [Boolean ]
215
217
218
+ val featureNames = if (overridedParams.contains(" feature_names" )) {
219
+ Some (overridedParams(" feature_names" ).asInstanceOf [Array [String ]])
220
+ } else None
221
+ val featureTypes = if (overridedParams.contains(" feature_types" )){
222
+ Some (overridedParams(" feature_types" ).asInstanceOf [Array [String ]])
223
+ } else None
224
+
216
225
val xgbExecParam = XGBoostExecutionParams (nWorkers, round, useExternalMemory, obj, eval,
217
226
missing, allowNonZeroForMissing, trackerConf,
218
227
checkpointParam,
219
228
inputParams,
220
229
xgbExecEarlyStoppingParams,
221
230
cacheTrainingSet,
222
231
treeMethod,
223
- isLocal)
232
+ isLocal,
233
+ featureNames,
234
+ featureTypes
235
+ )
224
236
xgbExecParam.setRawParamMap(overridedParams)
225
237
xgbExecParam
226
238
}
@@ -531,6 +543,16 @@ private object Watches {
531
543
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
532
544
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
533
545
546
+ if (xgbExecutionParams.featureNames.isDefined) {
547
+ trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
548
+ testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
549
+ }
550
+
551
+ if (xgbExecutionParams.featureTypes.isDefined) {
552
+ trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
553
+ testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
554
+ }
555
+
534
556
new Watches (Array (trainMatrix, testMatrix), Array (" train" , " test" ), cacheDirName)
535
557
}
536
558
@@ -643,6 +665,15 @@ private object Watches {
643
665
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
644
666
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
645
667
668
+ if (xgbExecutionParams.featureNames.isDefined) {
669
+ trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
670
+ testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
671
+ }
672
+ if (xgbExecutionParams.featureTypes.isDefined) {
673
+ trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
674
+ testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
675
+ }
676
+
646
677
new Watches (Array (trainMatrix, testMatrix), Array (" train" , " test" ), cacheDirName)
647
678
}
648
679
}
0 commit comments