diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index d5564f6a3fbda..887d8277d3117 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -74,10 +74,6 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group setParam */ - @Since("4.3.0") - def setPruneTree(value: Boolean): this.type = set(pruneTree, value) - /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -138,11 +134,9 @@ class DecisionTreeClassifier @Since("1.4.0") ( val strategy = getOldStrategy(categoricalFeatures, numClasses) require(!strategy.bootstrap, "DecisionTreeClassifier does not need bootstrap sampling") - strategy.pruneTree = $(pruneTree) - instr.logNumClasses(numClasses) instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, - probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, pruneTree, + probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds) val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 2c22ca5b42302..fb61358536d0c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -76,10 +76,6 @@ class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group setParam */ - @Since("4.3.0") - def setPruneTree(value: Boolean): this.type = set(pruneTree, value) - /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -163,11 +159,10 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) strategy.bootstrap = $(bootstrap) - strategy.pruneTree = $(pruneTree) instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol, rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, - maxMemoryInMB, minInfoGain, pruneTree, minInstancesPerNode, minWeightFractionPerNode, seed, + maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap) val trees = RandomForest diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index c3a16ab3dddd3..cabbc497571b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -41,6 +41,7 @@ import org.apache.spark.util.SizeEstimator import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} + /** * ALGORITHM * @@ -96,9 +97,8 @@ private[spark] object RandomForest extends Logging with Serializable { numTrees: Int, featureSubsetStrategy: String, seed: Long): Array[DecisionTreeModel] = { - val instances = input.map { - case LabeledPoint(label, features) => - Instance(label, 1.0, features.asML) + val instances = input.map { case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) } run(instances, strategy, numTrees, featureSubsetStrategy, seed, None) } @@ -124,6 +124,7 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None, earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel] = { lastEarlyStoppedModelSize = 0 @@ -150,8 +151,7 @@ private[spark] object RandomForest extends Logging with Serializable { // depth of the decision tree val maxDepth = strategy.maxDepth - require( - maxDepth <= 30, + require(maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") // Max memory usage for aggregates @@ -203,10 +203,9 @@ private[spark] object RandomForest extends Logging with Serializable { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) // Sanity check (should never occur): - assert( - nodesForGroup.nonEmpty, + assert(nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Only send trees to worker if they contain nodes being split this iteration. @@ -215,16 +214,8 @@ private[spark] object RandomForest extends Logging with Serializable { // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") - val bestSplit = RandomForest.findBestSplits( - baggedInput, - metadata, - topNodesForGroup, - nodesForGroup, - treeToNodeToIndexInfo, - bcSplits, - nodeStack, - timer, - nodeIds, + val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, + nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds, outputBestSplits = strategy.useNodeIdCache) if (strategy.useNodeIdCache) { nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit) @@ -234,7 +225,7 @@ private[spark] object RandomForest extends Logging with Serializable { timer.stop("findBestSplits") if (earlyStopModelSizeThresholdInBytes > 0) { - val nodes = topNodes.map(_.toNode(strategy.pruneTree)) + val nodes = topNodes.map(_.toNode(prune)) val estimatedSize = SizeEstimator.estimate(nodes) if (estimatedSize > earlyStopModelSizeThresholdInBytes){ earlyStop = true @@ -267,28 +258,23 @@ private[spark] object RandomForest extends Logging with Serializable { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel( - uid, - rootNode.toNode(strategy.pruneTree), - numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, + strategy.getNumClasses()) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(strategy.pruneTree), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel( - rootNode.toNode(strategy.pruneTree), - numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + strategy.getNumClasses()) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(strategy.pruneTree), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } @@ -307,6 +293,7 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes val timer = new TimeTracker() @@ -324,12 +311,9 @@ private[spark] object RandomForest extends Logging with Serializable { val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplits") logDebug("numBins: feature: number of bins") - logDebug( - Range(0, metadata.numFeatures) - .map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - } - .mkString("\n")) + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. @@ -337,26 +321,14 @@ private[spark] object RandomForest extends Logging with Serializable { val bcSplits = input.sparkContext.broadcast(splits) val baggedInput = BaggedPoint - .convertToBaggedRDD( - treeInput, - strategy.subsamplingRate, - numTrees, - strategy.bootstrap, - (tp: TreePoint) => tp.weight, - seed = seed) + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap, + (tp: TreePoint) => tp.weight, seed = seed) .persist(StorageLevel.MEMORY_AND_DISK) .setName("bagged tree points") - val trees = runBagged( - baggedInput = baggedInput, - metadata = metadata, - bcSplits = bcSplits, - strategy = strategy, - numTrees = numTrees, - featureSubsetStrategy = featureSubsetStrategy, - seed = seed, - instr = instr, - parentUID = parentUID, + val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits, + strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy, + seed = seed, instr = instr, prune = prune, parentUID = parentUID, earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes) baggedInput.unpersist() @@ -374,27 +346,26 @@ private[spark] object RandomForest extends Logging with Serializable { bcSplits: Broadcast[Array[Array[Split]]], bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = { require(nodeIds != null && bestSplits != null) - input.zip(nodeIds).map { - case (point, ids) => - var treeId = 0 - while (treeId < bestSplits.length) { - val bestSplitsInTree = bestSplits(treeId) - if (bestSplitsInTree != null) { - val nodeId = ids(treeId) - bestSplitsInTree.get(nodeId).foreach { bestSplit => - val featureId = bestSplit.featureIndex - val bin = point.datum.binnedFeatures(featureId) - val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { - LearningNode.leftChildIndex(nodeId) - } else { - LearningNode.rightChildIndex(nodeId) - } - ids(treeId) = newNodeId + input.zip(nodeIds).map { case (point, ids) => + var treeId = 0 + while (treeId < bestSplits.length) { + val bestSplitsInTree = bestSplits(treeId) + if (bestSplitsInTree != null) { + val nodeId = ids(treeId) + bestSplitsInTree.get(nodeId).foreach { bestSplit => + val featureId = bestSplit.featureIndex + val bin = point.datum.binnedFeatures(featureId) + val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { + LearningNode.leftChildIndex(nodeId) + } else { + LearningNode.rightChildIndex(nodeId) } + ids(treeId) = newNodeId } - treeId += 1 } - ids + treeId += 1 + } + ids } } @@ -446,11 +417,7 @@ private[spark] object RandomForest extends Logging with Serializable { var splitIndex = 0 while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate( - leftNodeFeatureOffset, - splitIndex, - treePoint.label, - numSamples, + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples, sampleWeight) } splitIndex += 1 @@ -565,9 +532,8 @@ private[spark] object RandomForest extends Logging with Serializable { logDebug(s"numFeatures = ${metadata.numFeatures}") logDebug(s"numClasses = ${metadata.numClasses}") logDebug(s"isMulticlass = ${metadata.isMulticlass}") - logDebug( - s"isMulticlassWithCategoricalFeatures = " + - s"${metadata.isMulticlassWithCategoricalFeatures}") + logDebug(s"isMulticlassWithCategoricalFeatures = " + + s"${metadata.isMulticlassWithCategoricalFeatures}") logDebug(s"using nodeIdCache = $useNodeIdCache") /* @@ -594,21 +560,11 @@ private[spark] object RandomForest extends Logging with Serializable { val numSamples = baggedPoint.subsampleCounts(treeIndex) val sampleWeight = baggedPoint.sampleWeight if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp( - agg(aggNodeIndex), - baggedPoint.datum, - numSamples, - sampleWeight, + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight, featuresForNode) } else { - mixedBinSeqOp( - agg(aggNodeIndex), - baggedPoint.datum, - splits, - metadata.unorderedFeatures, - numSamples, - sampleWeight, - featuresForNode) + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode) } agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight) } @@ -629,16 +585,11 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint], splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { - case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp( - treeIndex, - nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, - baggedPoint, - splits) + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, baggedPoint, splits) } agg } @@ -650,17 +601,12 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], dataPoint: (BaggedPoint[TreePoint], Array[Int]), splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { - case (treeIndex, nodeIndexToInfo) => - val baggedPoint = dataPoint._1 - val nodeIdCache = dataPoint._2 - val nodeIndex = nodeIdCache(treeIndex) - nodeBinSeqOp( - treeIndex, - nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, - baggedPoint, - splits) + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, baggedPoint, splits) } agg } @@ -669,8 +615,8 @@ private[spark] object RandomForest extends Logging with Serializable { * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group. */ - def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) - : Option[Map[Int, Array[Int]]] = { + def getNodeToFeatures( + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { if (!metadata.subsamplingFeatures) { None } else { @@ -678,8 +624,7 @@ private[spark] object RandomForest extends Logging with Serializable { treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => nodeIdToNodeInfo.values.foreach { nodeIndexInfo => assert(nodeIndexInfo.featureSubset.isDefined) - mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = - nodeIndexInfo.featureSubset.get + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get } } Some(mutableNodeToFeatures.toMap) @@ -688,11 +633,10 @@ private[spark] object RandomForest extends Logging with Serializable { // array of nodes to train indexed by node index in group val nodes = new Array[LearningNode](numNodes) - nodesForGroup.foreach { - case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node - } + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } } // Calculate best splits for all nodes in the group @@ -746,20 +690,17 @@ private[spark] object RandomForest extends Logging with Serializable { } } - val nodeToBestSplits = partitionAggregates - .reduceByKey((a, b) => a.merge(b)) - .map { - case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } - // find best split for each node - val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats)) - } - .collectAsMap() + // find best split for each node + val (split: Split, stats: ImpurityStats) = + binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + }.collectAsMap() nodeToFeaturesBc.destroy() timer.stop("chooseSplits") @@ -771,64 +712,55 @@ private[spark] object RandomForest extends Logging with Serializable { } // Iterate over all nodes in this group. - nodesForGroup.foreach { - case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - val nodeIndex = node.id - val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: ImpurityStats) = - nodeToBestSplits(aggNodeIndex) - logDebug(s"best split = $split") - - // Extract info for this node. Create children if not leaf. - val isLeaf = - (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.isLeaf = isLeaf - node.stats = stats - logDebug(s"Node = $node") - - if (!isLeaf) { - node.split = Some(split) - val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) - val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) - node.leftChild = Some( - LearningNode( - LearningNode.leftChildIndex(nodeIndex), - leftChildIsLeaf, - ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) - node.rightChild = Some( - LearningNode( - LearningNode.rightChildIndex(nodeIndex), - rightChildIsLeaf, - ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) - - if (outputBestSplits) { - val bestSplitsInTree = bestSplits(treeIndex) - if (bestSplitsInTree == null) { - bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) - } else { - bestSplitsInTree.update(nodeIndex, split) - } - } - - // enqueue left child and right child if they are not leaves - if (!leftChildIsLeaf) { - nodeStack.prepend((treeIndex, node.leftChild.get)) - } - if (!rightChildIsLeaf) { - nodeStack.prepend((treeIndex, node.rightChild.get)) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug(s"best split = $split") + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug(s"Node = $node") + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) + val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) + node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (outputBestSplits) { + val bestSplitsInTree = bestSplits(treeIndex) + if (bestSplitsInTree == null) { + bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) + } else { + bestSplitsInTree.update(nodeIndex, split) } + } - logDebug( - s"leftChildIndex = ${node.leftChild.get.id}" + - s", impurity = ${stats.leftImpurity}") - logDebug( - s"rightChildIndex = ${node.rightChild.get.id}" + - s", impurity = ${stats.rightImpurity}") + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeStack.prepend((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeStack.prepend((treeIndex, node.rightChild.get)) } + + logDebug(s"leftChildIndex = ${node.leftChild.get.id}" + + s", impurity = ${stats.leftImpurity}") + logDebug(s"rightChildIndex = ${node.rightChild.get.id}" + + s", impurity = ${stats.rightImpurity}") } + } } if (outputBestSplits) { @@ -898,12 +830,8 @@ private[spark] object RandomForest extends Logging with Serializable { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - new ImpurityStats( - gain, - impurity, - parentImpurityCalculator, - leftImpurityCalculator, - rightImpurityCalculator) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -927,156 +855,130 @@ private[spark] object RandomForest extends Logging with Serializable { } val validFeatureSplits = - Iterator - .range(0, binAggregates.metadata.numFeaturesPerNode) - .map { featureIndexIdx => - featuresForNode - .map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - } - .withFilter { - case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } + Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } // For each (feature, split), calculate the gain, and select the best (feature, split). val splitsAndImpurityInfo = - validFeatureSplits.map { - case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits) - .map { splitIdx => - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats( - gainAndImpurityStats, - leftChildStats, - rightChildStats, - binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - } - .maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits) - .map { splitIndex => - val leftChildStats = - binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates - .getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats( - gainAndImpurityStats, - leftChildStats, - rightChildStats, - binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - } - .maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines - * which splits are considered. (With K categories, we - * consider K - 1 possible splits.) - * + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIdx => + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) - } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. - categoryStats.predict - } + */ + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) } else { - Double.MaxValue + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. + categoryStats.predict } - (featureValue, centroid) + } else { + Double.MaxValue } + (featureValue, centroid) + } - logDebug( - s"Centroids for categorical variable: " + - s"${centroidForCategories.mkString(",")}") - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug( - s"Sorted centroids for categorical variable = " + - s"${categoriesSortedByCentroid.mkString(",")}") - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 - } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits) - .map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats( - gainAndImpurityStats, - leftChildStats, - rightChildStats, - binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - } - .maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) + logDebug(s"Centroids for categorical variable: " + + s"${centroidForCategories.mkString(",")}") + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug(s"Sorted centroids for categorical variable = " + + s"${categoriesSortedByCentroid.mkString(",")}") + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } } val (bestSplit, bestSplitStats) = @@ -1087,13 +989,11 @@ private[spark] object RandomForest extends Logging with Serializable { val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { - ( - new ContinuousSplit(dummyFeatureIndex, 0), + (new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) - ( - new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), + (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } } else { @@ -1166,34 +1066,27 @@ private[spark] object RandomForest extends Logging with Serializable { // being spun up that will definitely do no work. val numPartitions = math.min(continuousFeatures.length, input.partitions.length) - input - .flatMap { point => - continuousFeatures.iterator - .map(idx => (idx, (point.features(idx), point.weight))) - .filter(_._2._1 != 0.0) - } - .aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( - seqOp = { - case ((map, c), (v, w)) => - map.changeValue(v, w, _ + w) - (map, c + 1L) - }, - combOp = { - case ((map1, c1), (map2, c2)) => - map2.foreach { - case (v, w) => - map1.changeValue(v, w, _ + w) - } - (map1, c1 + c2) - }) - .map { - case (idx, (map, c)) => - val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) - val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) - logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") - (idx, splits) + input.flatMap { point => + continuousFeatures.iterator + .map(idx => (idx, (point.features(idx), point.weight))) + .filter(_._2._1 != 0.0) + }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( + seqOp = { case ((map, c), (v, w)) => + map.changeValue(v, w, _ + w) + (map, c + 1L) + }, + combOp = { case ((map1, c1), (map2, c2)) => + map2.foreach { case (v, w) => + map1.changeValue(v, w, _ + w) + } + (map1, c1 + c2) } - .collectAsMap() + ).map { case (idx, (map, c)) => + val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + }.collectAsMap() } else Map.empty[Int, Array[Split]] val numFeatures = metadata.numFeatures @@ -1264,10 +1157,9 @@ private[spark] object RandomForest extends Logging with Serializable { featureIndex: Int): Array[Double] = { val valueWeights = new OpenHashMap[Double, Double] var count = 0L - featureSamples.foreach { - case (weight, value) => - valueWeights.changeValue(value, weight, _ + weight) - count += 1L + featureSamples.foreach { case (weight, value) => + valueWeights.changeValue(value, weight, _ + weight) + count += 1L } findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex) } @@ -1290,8 +1182,7 @@ private[spark] object RandomForest extends Logging with Serializable { count: Long, metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { - require( - metadata.isContinuous(featureIndex), + require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") val splits = if (partValueWeights.isEmpty) { @@ -1365,8 +1256,7 @@ private[spark] object RandomForest extends Logging with Serializable { private[tree] class NodeIndexInfo( val nodeIndexInGroup: Int, - val featureSubset: Option[Array[Int]]) - extends Serializable + val featureSubset: Option[Array[Int]]) extends Serializable /** * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. @@ -1404,13 +1294,8 @@ private[spark] object RandomForest extends Logging with Serializable { val (treeIndex, node) = nodeStack.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some( - SamplingUtils - .reservoirSampleAndCount( - Range(0, metadata.numFeatures).iterator, - metadata.numFeaturesPerNode, - rng.nextLong()) - ._1) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) } else { None } @@ -1418,13 +1303,11 @@ private[spark] object RandomForest extends Logging with Serializable { val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { nodeStack.remove(0) - mutableNodesForGroup.getOrElseUpdate( - treeIndex, - new mutable.ArrayBuffer[LearningNode]()) += + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo - .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) = - new NodeIndexInfo(numNodesInGroup, featureSubset) + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) numNodesInGroup += 1 memUsage += nodeMemUsage } else { @@ -1472,7 +1355,8 @@ private[spark] object RandomForest extends Logging with Serializable { * @param metadata decision tree metadata * @return subsample fraction */ - private def samplesFractionForFindSplits(metadata: DecisionTreeMetadata): Double = { + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) if (requiredSamples < metadata.numExamples) { @@ -1481,5 +1365,4 @@ private[spark] object RandomForest extends Logging with Serializable { 1.0 } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index e5f542366be75..768e14f4b74e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -211,27 +211,10 @@ private[ml] trait TreeClassifierParams extends Params { (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) - /** - * If true, the trained tree will undergo a pruning process after training, in which nodes - * with the same class predictions are merged. The resulting tree will be smaller and have - * faster predictions, but class probabilities will be lost. - * If false, no pruning is applied after training, and class probabilities are preserved. - * (default = true) - * @group param - */ - final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" + - "If true, the trained tree will undergo a pruning process after training, in which nodes" + - " with the same class predictions are merged. The resulting tree will be smaller and have" + - " faster predictions, but class probabilities will be lost." + - " If false, no pruning is applied after training, and class probabilities are preserved." - ) - - setDefault(impurity -> "gini", pruneTree -> true) + setDefault(impurity -> "gini") /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) - /** @group getParam */ - final def getPruneTree: Boolean = $(pruneTree) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 85f4bcc642677..200d10130eed7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -55,8 +55,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. * If a split has less information gain than minInfoGain, * this split will not be considered as a valid split. - * @param pruneTree If this is true, the final training tree will undergo a pruning in which - * nodes with the same classifications are merged. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. If too small, then 1 node will be split per iteration, and * its aggregates may exceed this size. @@ -79,7 +77,6 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, - @Since("4.3.0") @BeanProperty var pruneTree: Boolean = true, @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @@ -116,13 +113,12 @@ class Strategy @Since("1.3.0") ( categoricalFeaturesInfo: Map[Int, Int], minInstancesPerNode: Int, minInfoGain: Double, - pruneTree: Boolean, maxMemoryInMB: Int, subsamplingRate: Double, useNodeIdCache: Boolean, checkpointInterval: Int) = { this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, 0.0) } // scalastyle:on argcount @@ -204,7 +200,7 @@ class Strategy @Since("1.3.0") ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, - minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache, + minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, minWeightFractionPerNode) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 0c60441813159..62f25474e9476 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -72,9 +72,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits(0).length === 0) } - test( - "Binary classification with 3-ary (ordered) categorical features," + - " with no samples for one category: split calculation") { + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr.toImmutableArraySeq) @@ -109,29 +108,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata( - 1, - 8, - 8.0, - 0, - 0, - Map(), - Set(), - Array(3), - Gini, - QuantileStrategy.Sort, - 0, - 0, - 0.0, - 0.0, - 0, - 0) + val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) // possibleSplits <= numSplits { val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1) - .map(x => (1.0, x.toDouble)) - .filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -140,8 +126,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3) - .map(x => (1.0, x.toDouble)) - .filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -151,23 +136,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata( - 1, - 12, - 12.0, - 0, - 0, - Map(), - Set(), - Array(5), - Gini, - QuantileStrategy.Sort, - 0, - 0, - 0.0, - 0.0, - 0, - 0) + val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) @@ -178,23 +151,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata( - 1, - 18, - 18.0, - 0, - 0, - Map(), - Set(), - Array(3), - Gini, - QuantileStrategy.Sort, - 0, - 0, - 0.0, - 0.0, - 0, - 0) + val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -204,23 +165,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata( - 1, - 17, - 17.0, - 0, - 0, - Map(), - Set(), - Array(2), - Gini, - QuantileStrategy.Sort, - 0, - 0, - 0.0, - 0.0, - 0, - 0) + val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0, + Map(), Set(), + Array(2), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -250,23 +199,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most weight is close to the minimum { - val fakeMetadata = new DecisionTreeMetadata( - 1, - 0, - 0.0, - 0, - 0, - Map(), - Set(), - Array(3), - Gini, - QuantileStrategy.Sort, - 0, - 0, - 0.0, - 0.0, - 0, - 0) + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map { case (w, x) => (w.toDouble, x.toDouble) } @@ -280,10 +217,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(lp) val rdd = sc.parallelize(data.toImmutableArraySeq) - val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, maxBins = 5) - withClue( - "DecisionTree requires number of features > 0," + - " but was given an empty features vector") { + val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, + maxBins = 5) + withClue("DecisionTree requires number of features > 0," + + " but was given an empty features vector") { intercept[IllegalArgumentException] { RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) } @@ -295,19 +232,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(instance) val rdd = sc.parallelize(data.toImmutableArraySeq) val strategy = new OldStrategy( - OldAlgo.Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 5, - categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 5, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) assert(tree.rootNode.impurity === -1.0) assert(tree.depth === 0) assert(tree.rootNode.prediction === instance.label) // Test with no categorical features - val strategy2 = new OldStrategy(OldAlgo.Regression, Variance, maxDepth = 2, maxBins = 5) + val strategy2 = new OldStrategy( + OldAlgo.Regression, + Variance, + maxDepth = 2, + maxBins = 5) val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None) assert(tree2.rootNode.impurity === -1.0) assert(tree2.depth === 0) @@ -338,15 +279,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metadata.numBins(1) === 3) // Expecting 2^2 - 1 = 3 splits per feature - def checkCategoricalSplit( - s: Split, - featureIndex: Int, - leftCategories: Array[Double]): Unit = { + def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { assert(s.featureIndex === featureIndex) assert(s.isInstanceOf[CategoricalSplit]) val s0 = s.asInstanceOf[CategoricalSplit] assert(s0.leftCategories === leftCategories) - assert(s0.numCategories === 3) // for this unit test + assert(s0.numCategories === 3) // for this unit test } // Feature 0 checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) @@ -359,8 +297,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with ordered categorical features: split calculations") { - val arr = OldDTSuite - .generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() .map(_.asML.toInstance) assert(arr.length === 3000) val rdd = sc.parallelize(arr.toImmutableArraySeq) @@ -395,12 +332,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 1, - numClasses = 2, - categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -413,17 +346,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = - Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits( - baggedInput, - metadata, - Map(0 -> topNode), - nodesForGroup, - treeToNodeToIndexInfo, - bcSplits, - nodeStack) + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) bcSplits.destroy() // don't enqueue leaf nodes into node queue @@ -448,12 +376,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 5, - numClasses = 2, - categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -466,17 +390,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = - Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits( - baggedInput, - metadata, - Map(0 -> topNode), - nodesForGroup, - treeToNodeToIndexInfo, - bcSplits, - nodeStack) + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) bcSplits.destroy() // don't enqueue a node into node queue if its impurity is 0.0 @@ -512,32 +431,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 1, - numClasses = 2, - categoricalFeaturesInfo = Map(0 -> 3), - maxBins = 3) - - strategy.pruneTree = false - val model = RandomForest - .run( - input, - strategy, - numTrees = 1, - featureSubsetStrategy = "all", - seed = 42, - instr = None) - .head + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None, prune = false).head model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case _ => fail("model.rootNode.split was not a CategoricalSplit") - } + case n: InternalNode => n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case _ => fail("model.rootNode.split was not a CategoricalSplit") + } case _ => fail("model.rootNode was not an InternalNode") } } @@ -553,21 +458,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy2 = new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) - val tree1 = RandomForest - .run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) - .head - val tree2 = RandomForest - .run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) - .head - - def getChildren(rootNode: Node): Array[InternalNode] = - rootNode match { - case n: InternalNode => - assert(n.leftChild.isInstanceOf[InternalNode]) - assert(n.rightChild.isInstanceOf[InternalNode]) - Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => fail("rootNode was not an InternalNode") - } + val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None).head + val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None).head + + def getChildren(rootNode: Node): Array[InternalNode] = rootNode match { + case n: InternalNode => + assert(n.leftChild.isInstanceOf[InternalNode]) + assert(n.rightChild.isInstanceOf[InternalNode]) + Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => fail("rootNode was not an InternalNode") + } // Single group second level tree construction. val children1 = getChildren(tree1.rootNode) @@ -613,9 +515,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { nodeStack.prepend((treeIndex, topNodes(treeIndex))) } val rng = new scala.util.Random(seed = seed) - val ( - nodesForGroup: Map[Int, Array[LearningNode]], - treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + val (nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) assert(nodesForGroup.size === numTrees, failString) @@ -623,15 +524,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { if (numFeaturesPerNode == numFeatures) { // featureSubset values should all be None - assert( - treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), failString) } else { // Check number of features. - assert( - treeToNodeToIndexInfo.values.forall( - _.values.forall(_.featureSubset.get.length === numFeaturesPerNode)), - failString) + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.length === numFeaturesPerNode)), failString) } } } @@ -639,9 +537,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy( - numTrees = 1, - "log2", + checkFeatureSubsetStrategy(numTrees = 1, "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -659,7 +555,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException] { + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -668,9 +564,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy( - numTrees = 2, - "log2", + checkFeatureSubsetStrategy(numTrees = 2, "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -684,7 +578,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException] { + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) } @@ -693,23 +587,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 2, - numClasses = 2, - categoricalFeaturesInfo = categoricalFeaturesInfo) + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } test("Binary classification with continuous features and node Id cache: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 2, - numClasses = 2, - categoricalFeaturesInfo = categoricalFeaturesInfo, + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } @@ -762,8 +648,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance - val expected = Vectors.dense( - (1.0 + feature0importance / tree2norm) / 2.0, + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) assert(importances ~== expected relTol 0.01) } @@ -797,45 +682,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr.toImmutableArraySeq) val numClasses = 2 - val strategy = new OldStrategy( - algo = OldAlgo.Classification, - impurity = Gini, - maxDepth = 4, - numClasses = numClasses, - maxBins = 32) - - strategy.pruneTree = true - val prunedTree = RandomForest - .run( - rdd, - strategy, - numTrees = 1, - featureSubsetStrategy = "auto", - seed = 42, - instr = None) - .head - - strategy.pruneTree = false - val unprunedTree = RandomForest - .run( - rdd, - strategy, - numTrees = 1, - featureSubsetStrategy = "auto", - seed = 42, - instr = None) - .head - - strategy.pruneTree = true - val defaultBehaviorTree = RandomForest - .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) - .head + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head assert(prunedTree.numNodes === 5) assert(unprunedTree.numNodes === 7) - assert(defaultBehaviorTree.numNodes == prunedTree.numNodes) - assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.length) } @@ -854,45 +712,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val rdd = sc.parallelize(arr.toImmutableArraySeq) - val strategy = new OldStrategy( - algo = OldAlgo.Regression, - impurity = Variance, - maxDepth = 4, - numClasses = 0, - maxBins = 32) - - strategy.pruneTree = true - val prunedTree = RandomForest - .run( - rdd, - strategy, - numTrees = 1, - featureSubsetStrategy = "auto", - seed = 42, - instr = None) - .head - - strategy.pruneTree = false - val unprunedTree = RandomForest - .run( - rdd, - strategy, - numTrees = 1, - featureSubsetStrategy = "auto", - seed = 42, - instr = None) - .head - - strategy.pruneTree = true - val defaultBehaviorTree = RandomForest - .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) - .head + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) - assert(prunedTree.numNodes === 3) - assert(unprunedTree.numNodes === 5) + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head - assert(defaultBehaviorTree.numNodes == prunedTree.numNodes) + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.length) } @@ -909,15 +739,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None) val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(smallWeightTrees).foreach { - case (unitTree, smallWeightTree) => - TreeTests.checkEqual(unitTree, smallWeightTree) + unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) => + TreeTests.checkEqual(unitTree, smallWeightTree) } val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(bigWeightTrees).foreach { - case (unitTree, bigWeightTree) => - TreeTests.checkEqual(unitTree, bigWeightTree) + unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) => + TreeTests.checkEqual(unitTree, bigWeightTree) } } @@ -950,7 +778,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip @@ -961,12 +788,12 @@ private object RandomForestSuite { private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { if (nodes.isEmpty) { acc - } else { + } + else { nodes.head match { case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount) } } } - } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8f0646e2b24d0..f69ecf115f5ab 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1678,7 +1678,6 @@ def __init__(self, *args: Any): maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -1790,7 +1789,6 @@ def __init__( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, - pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -1803,7 +1801,7 @@ def __init__( """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) """ @@ -1828,7 +1826,6 @@ def setParams( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, - pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -1841,7 +1838,7 @@ def setParams( """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) Sets params for the DecisionTreeClassifier. @@ -1864,12 +1861,6 @@ def setMaxBins(self, value: int) -> "DecisionTreeClassifier": """ return self._set(maxBins=value) - def setPruneTree(self, value: bool) -> "DecisionTreeClassifier": - """ - Sets the value of :py:attr:`pruneTree`. - """ - return self._set(pruneTree=value) - def setMinInstancesPerNode(self, value: int) -> "DecisionTreeClassifier": """ Sets the value of :py:attr:`minInstancesPerNode`. @@ -1981,7 +1972,6 @@ def __init__(self, *args: Any): maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -2091,7 +2081,6 @@ def __init__( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, - pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -2108,7 +2097,7 @@ def __init__( """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -2134,7 +2123,6 @@ def setParams( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, - pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -2151,7 +2139,7 @@ def setParams( """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -2175,12 +2163,6 @@ def setMaxBins(self, value: int) -> "RandomForestClassifier": """ return self._set(maxBins=value) - def setPruneTree(self, value: bool) -> "RandomForestClassifier": - """ - Sets the value of :py:attr:`pruneTree`. - """ - return self._set(pruneTree=value) - def setMinInstancesPerNode(self, value: int) -> "RandomForestClassifier": """ Sets the value of :py:attr:`minInstancesPerNode`. diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index 41b8bdc600c56..63f58272aeefb 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -415,13 +415,6 @@ class _TreeClassifierParams(Params): typeConverter=TypeConverters.toString, ) - pruneTree = Param(Params._dummy(), "pruneTree", "" + - "If true, the trained tree will undergo a pruning process after training, in which nodes" + - " with the same class predictions are merged. The resulting tree will be smaller and have" + - " faster predictions, but class probabilities will be lost." + - " If false, no pruning is applied after training, and class probabilities are preserved.", - typeConverter=TypeConverters.toBoolean) - def __init__(self) -> None: super().__init__() @@ -431,12 +424,6 @@ def getImpurity(self) -> str: Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) - @since("4.3.0") - def getPruneTree(self) -> bool: - """ - Gets the value of pruneTree or its default value. - """ - return self.getOrDefault(self.pruneTree) class _TreeRegressorParams(_HasVarianceImpurity):