[SPARK-34591][ML] Add decision tree pruning as a parameter#55758

Closed
WeichenXu123 wants to merge 1 commit into branch-4.x from SPARK-34591-4.x

@WeichenXu123
Contributor

What changes were proposed in this pull request?

This PR adds a parameter to enable or disable a feature in which LearningNodes are merged after an RF model is trained.

This PR takes over #32813

Why are the changes needed?

2 Reasons:

  1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
    Once a tree is pruned, these probabilities are lost, which makes the tree and its predictions challenging to work with, if not unusable.

  2. Pruning by default is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.
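
To make the first reason concrete, here is a minimal pure-Python sketch of what merging two sibling leaves does to class probabilities. The `Leaf` class and `merge` function are hypothetical stand-ins, not Spark classes, and the sketch assumes the merged leaf reports the pooled class counts of its two children.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Leaf:
    """Hypothetical leaf holding per-class training counts (not a Spark class)."""
    counts: List[float]

    def prediction(self) -> int:
        # The predicted class is the index with the largest count.
        return max(range(len(self.counts)), key=lambda i: self.counts[i])

    def probabilities(self) -> List[float]:
        # Normalize the counts into class probabilities.
        total = sum(self.counts)
        return [c / total for c in self.counts]


def merge(left: Leaf, right: Leaf) -> Leaf:
    # Assumption: the merged leaf reports the pooled counts of both children.
    return Leaf([l + r for l, r in zip(left.counts, right.counts)])


# Two sibling leaves that agree on the class but not on the confidence.
left = Leaf([9.0, 1.0])
right = Leaf([6.0, 4.0])
merged = merge(left, right)

print(left.probabilities(), right.probabilities(), merged.probabilities())
```

Both siblings predict class 0, so the merge leaves the predicted label unchanged. The per-leaf probabilities for class 0 (9/10 and 6/10) are replaced by a single pooled value (15/20), which is the information a user loses when they need probabilities rather than labels.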

Please see Jira ticket for more explanation.

Does this PR introduce any user-facing change?

New params:
Adds a parameter `pruneTree` that is exposed to the tree-based classifiers. Tests will be added to ensure the parameter is exposed correctly.

How was this patch tested?

Unit tests.

Lead-authored-by: WeichenXu <weichen.xu@databricks.com>
Co-authored-by: bribiescas-carlos <bribiescas.carlos@bcg.com>
Co-authored-by: Carlos Bribiescas <CBribiescas@gmail.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

(cherry picked from commit 1f46506)

Closes #55728 from WeichenXu123/SPARK-34591.

Copilot AI review requested due to automatic review settings May 8, 2026 09:52

Copilot AI left a comment

Pull request overview

This PR introduces a new pruneTree parameter for Spark ML tree-based classifiers to control whether post-training pruning (merging sibling leaves with identical predicted class) is applied, addressing use cases that rely on unpruned tree structure and more fine-grained probability estimates.

Changes:

  • Adds pruneTree as a new param in ML tree classifier params (Scala + PySpark) with default true.
  • Wires pruneTree through DecisionTreeClassifier and RandomForestClassifier into the old Strategy, and uses it in the training implementation when converting LearningNode to final Node.
  • Updates RandomForest implementation tests to validate pruned vs. unpruned behavior using strategy.pruneTree.
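
The wiring described in the list above can be sketched in plain Python. This is only an illustration of the pattern (a boolean param defaulting to true, a chainable setter, and a copy into the training configuration); `StrategySketch` and `TreeClassifierSketch` are hypothetical names and do not reproduce the Spark classes or their signatures.

```python
from dataclasses import dataclass


@dataclass
class StrategySketch:
    """Hypothetical stand-in for the training configuration object."""
    prune_tree: bool = True


class TreeClassifierSketch:
    """Hypothetical estimator that carries a pruning flag; not the Spark class."""

    def __init__(self, prune_tree: bool = True) -> None:
        # Default mirrors the PR: pruning stays enabled unless turned off.
        self._prune_tree = prune_tree

    def set_prune_tree(self, value: bool) -> "TreeClassifierSketch":
        # Returning self allows setter chaining.
        self._prune_tree = value
        return self

    def get_prune_tree(self) -> bool:
        return self._prune_tree

    def to_strategy(self) -> StrategySketch:
        # The estimator-level flag is copied into the training configuration.
        return StrategySketch(prune_tree=self._prune_tree)


default_strategy = TreeClassifierSketch().to_strategy()
unpruned_strategy = TreeClassifierSketch().set_prune_tree(False).to_strategy()

print(default_strategy.prune_tree, unpruned_strategy.prune_tree)
```

Keeping the default at true means existing callers see the same trained models as before, and only callers who explicitly opt out get unpruned trees.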

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.

Show a summary per file
| File | Description |
| --- | --- |
| `python/pyspark/ml/tree.py` | Adds PySpark `pruneTree` param + getter on shared tree-classifier params. |
| `python/pyspark/ml/classification.py` | Exposes `pruneTree` in PySpark `DecisionTreeClassifier` / `RandomForestClassifier` constructors, defaults, and setters. |
| `mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala` | Adjusts pruning-related tests to toggle pruning via `strategy.pruneTree`. |
| `mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala` | Adds `pruneTree` to old API `Strategy` to carry pruning control into training. |
| `mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala` | Defines `pruneTree` as an ML param for tree classifiers and sets its default. |
| `mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala` | Uses `strategy.pruneTree` when materializing final models (and during early-stop size estimation); removes the testing-only prune argument. |
| `mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala` | Adds `setPruneTree` and propagates the param into `strategy.pruneTree`. |
| `mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala` | Adds `setPruneTree` and propagates the param into `strategy.pruneTree`. |

Comment on lines +217 to +226
* faster predictions, but class probabilities will be lost.
* If false, no pruning is applied after training, and class probabilities are preserved.
* (default = true)
* @group param
*/
final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
"If true, the trained tree will undergo a pruning process after training, in which nodes" +
" with the same class predictions are merged. The resulting tree will be smaller and have" +
" faster predictions, but class probabilities will be lost." +
" If false, no pruning is applied after training, and class probabilities are preserved."
Comment on lines +222 to +226
final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
"If true, the trained tree will undergo a pruning process after training, in which nodes" +
" with the same class predictions are merged. The resulting tree will be smaller and have" +
" faster predictions, but class probabilities will be lost." +
" If false, no pruning is applied after training, and class probabilities are preserved."
@@ -113,12 +116,13 @@ class Strategy @Since("1.3.0") (
categoricalFeaturesInfo: Map[Int, Int],
minInstancesPerNode: Int,
minInfoGain: Double,
Comment thread python/pyspark/ml/tree.py
Comment on lines +421 to +422
" faster predictions, but class probabilities will be lost." +
" If false, no pruning is applied after training, and class probabilities are preserved.",
Comment thread python/pyspark/ml/tree.py
Comment on lines 418 to +439
@@ -424,6 +431,12 @@ def getImpurity(self) -> str:
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)

    @since("4.3.0")
    def getPruneTree(self) -> bool:
        """
        Gets the value of pruneTree or its default value.
        """
        return self.getOrDefault(self.pruneTree)
Sets the value of :py:attr:`pruneTree`.
"""
return self._set(pruneTree=value)

Comment on lines +2178 to +2183
    def setPruneTree(self, value: bool) -> "RandomForestClassifier":
        """
        Sets the value of :py:attr:`pruneTree`.
        """
        return self._set(pruneTree=value)

assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)

assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
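
For readers wondering how one trained tree can yield both 5 and 7 nodes, the following pure-Python sketch shows one hypothetical tree shape consistent with those counts. It is not the fixture used in `RandomForestSuite`; `Node`, `finalize`, and `num_nodes` are illustrative names, and the merge rule (collapse two sibling leaves with the same predicted class into their parent) follows the description in the review overview.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical binary tree node: either a leaf or a node with two children."""
    prediction: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


def finalize(node: Node, prune: bool = True) -> Node:
    """Copy the tree, optionally merging sibling leaves that predict the same class."""
    if node.is_leaf():
        return Node(node.prediction)
    left = finalize(node.left, prune)
    right = finalize(node.right, prune)
    if prune and left.is_leaf() and right.is_leaf() and left.prediction == right.prediction:
        # Both children agree, so the split adds nothing to the predicted label.
        return Node(left.prediction)
    return Node(node.prediction, left, right)


def num_nodes(node: Optional[Node]) -> int:
    if node is None:
        return 0
    return 1 + num_nodes(node.left) + num_nodes(node.right)


# Depth-2 tree: the left subtree's leaves agree (0, 0); the right subtree's differ (0, 1).
trained = Node(0, Node(0, Node(0), Node(0)), Node(1, Node(0), Node(1)))

pruned = finalize(trained, prune=True)
unpruned = finalize(trained, prune=False)
default_behavior = finalize(trained)

print(num_nodes(pruned), num_nodes(unpruned), num_nodes(default_behavior))
```

In this shape only the left subtree collapses, so the pruned tree loses exactly two nodes, and the default call matches the pruned result because `prune` defaults to true in the sketch.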
@zhengruifeng zhengruifeng deleted the SPARK-34591-4.x branch May 8, 2026 11:15