[SPARK-34591][ML] Add decision tree pruning as a parameter #55758
Closed
WeichenXu123 wants to merge 1 commit into branch-4.x
Conversation
**What changes were proposed in this pull request?**

This PR adds a parameter to enable/disable a feature where `LearningNode`s are merged after a random forest model is trained. This PR takes over #32813.

**Why are the changes needed?**

Two reasons:
1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions. Once the tree is pruned, these probabilities are lost, making the trees and predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn, where trees are left unpruned by default.

Please see the Jira ticket for more explanation.

**Does this PR introduce any user-facing change?**

New params: adds a `pruneTree` parameter that is exposed on the tree-based classifiers. Tests ensure the parameter is exposed correctly.

**How was this patch tested?**

Unit tests.

Closes #55728 from WeichenXu123/SPARK-34591.

Lead-authored-by: WeichenXu <weichen.xu@databricks.com>
Co-authored-by: bribiescas-carlos <bribiescas.carlos@bcg.com>
Co-authored-by: Carlos Bribiescas <CBribiescas@gmail.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
(cherry picked from commit 1f46506)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Pull request overview
This PR introduces a new `pruneTree` parameter for Spark ML tree-based classifiers to control whether post-training pruning (merging sibling leaves with an identical predicted class) is applied, addressing use cases that rely on the unpruned tree structure and on finer-grained probability estimates.
Changes:
- Adds `pruneTree` as a new param in the ML tree classifier params (Scala + PySpark) with default `true`.
- Wires `pruneTree` through `DecisionTreeClassifier` and `RandomForestClassifier` into the old `Strategy`, and uses it in the training implementation when converting `LearningNode` to the final `Node`.
- Updates RandomForest implementation tests to validate pruned vs. unpruned behavior using `strategy.pruneTree`.
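The pruning this PR makes optional can be sketched in plain Python. This is a conceptual illustration, not Spark's actual `RandomForest` implementation: `TreeNode`, `prune`, and the class counts below are all hypothetical, chosen to show how merging sibling leaves with the same predicted class leaves predictions unchanged but discards the per-leaf probability estimates.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TreeNode:
    counts: List[int]                     # per-class instance counts at this node
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None

    def prediction(self) -> int:
        # predicted class = argmax over per-class counts
        return max(range(len(self.counts)), key=lambda c: self.counts[c])

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

def prune(node: TreeNode) -> TreeNode:
    """Bottom-up: merge sibling leaves that predict the same class."""
    if node.is_leaf():
        return node
    node.left, node.right = prune(node.left), prune(node.right)
    if (node.left.is_leaf() and node.right.is_leaf()
            and node.left.prediction() == node.right.prediction()):
        node.left = node.right = None     # collapse into a single leaf
    return node

# Both leaves predict class 0, but with different confidence: 9/10 vs 3/5.
tree = TreeNode(counts=[12, 3],
                left=TreeNode(counts=[9, 1]),
                right=TreeNode(counts=[3, 2]))
pruned = prune(tree)
assert pruned.is_leaf()                   # the 0.9 vs 0.6 distinction is gone
```

After pruning, only the parent's pooled counts remain, so both former regions report the same class probability; with `pruneTree=false` the two leaves, and their distinct probabilities, would survive.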
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| python/pyspark/ml/tree.py | Adds PySpark pruneTree param + getter on shared tree-classifier params. |
| python/pyspark/ml/classification.py | Exposes pruneTree in PySpark DecisionTreeClassifier / RandomForestClassifier constructors, defaults, and setters. |
| mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | Adjusts pruning-related tests to toggle pruning via strategy.pruneTree. |
| mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | Adds pruneTree to old API Strategy to carry pruning control into training. |
| mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | Defines pruneTree as an ML param for tree classifiers and sets its default. |
| mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | Uses strategy.pruneTree when materializing final models (and during early-stop size estimation); removes the testing-only prune argument. |
| mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | Adds setPruneTree and propagates the param into strategy.pruneTree. |
| mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | Adds setPruneTree and propagates the param into strategy.pruneTree. |
Review comments were attached to the following hunks (cleaned from the diff viewer):

`mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala`, lines 217–226 (the new param definition; the doc comment's opening lines are cut off in the excerpt):

```scala
 * faster predictions, but class probabilities will be lost.
 * If false, no pruning is applied after training, and class probabilities are preserved.
 * (default = true)
 * @group param
 */
final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
  "If true, the trained tree will undergo a pruning process after training, in which nodes" +
  " with the same class predictions are merged. The resulting tree will be smaller and have" +
  " faster predictions, but class probabilities will be lost." +
  " If false, no pruning is applied after training, and class probabilities are preserved.")
```

`mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala`, lines 113–116 (context hunk showing the constructor parameters `categoricalFeaturesInfo`, `minInstancesPerNode`, and `minInfoGain` around the insertion point).

`python/pyspark/ml/tree.py`, lines 421–439 (the new getter, added after `getImpurity`; the same param description string is repeated in the PySpark param definition):

```python
@since("4.3.0")
def getPruneTree(self) -> bool:
    """
    Gets the value of pruneTree or its default value.
    """
    return self.getOrDefault(self.pruneTree)
```

`python/pyspark/ml/classification.py`, lines 2178–2183 (the setter on `RandomForestClassifier`; an analogous setter is added on `DecisionTreeClassifier`):

```python
def setPruneTree(self, value: bool) -> "RandomForestClassifier":
    """
    Sets the value of :py:attr:`pruneTree`.
    """
    return self._set(pruneTree=value)
```

`mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala` (assertions comparing pruned, unpruned, and default behavior):

```scala
assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)
assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
```
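The node counts in these assertions are consistent with simple tree arithmetic (illustrative reasoning, not taken from the suite itself): a full binary tree of depth 2 has 2^3 - 1 = 7 nodes, and collapsing one pair of same-prediction sibling leaves into their parent removes two nodes, leaving 5.

```python
def full_tree_nodes(depth: int) -> int:
    # a full binary tree of depth d has 2^(d+1) - 1 nodes
    return 2 ** (depth + 1) - 1

unpruned = full_tree_nodes(2)   # 7 nodes when no pruning is applied
pruned = unpruned - 2           # one sibling-leaf pair merged into its parent
assert (unpruned, pruned) == (7, 5)
```

The third assertion then checks that the default (no `pruneTree` set) matches the pruned count, i.e. pruning remains the default behavior.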