[SPARK-34591][ML] Add decision tree pruning as a parameter #55728
WeichenXu123 wants to merge 21 commits into master
### What changes were proposed in this pull request?
This PR disables a feature created in SPARK-3159 where LearningNodes are
merged after an RF model is trained.
### Why are the changes needed?
Two reasons:
1. In addition to basic classification, another use case for decision trees is the
probabilities associated with predictions. Once pruned, these probabilities are lost,
which makes the trees and their predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn, where trees
are left unpruned by default.
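To make the first reason concrete, here is a minimal sketch in plain Python (not Spark code). It assumes that pruning merges two sibling leaves whenever they predict the same class, and that the merged leaf keeps only the pooled class counts of its parent. The `Node` class and the `finalize` function are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical tree node; `counts` holds per-class training counts."""
    counts: List[int]
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    @property
    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

    @property
    def prediction(self) -> int:
        # Predicted class is the index of the largest count.
        return max(range(len(self.counts)), key=self.counts.__getitem__)

    @property
    def probabilities(self) -> List[float]:
        total = sum(self.counts)
        return [c / total for c in self.counts]


def finalize(node: Node, prune: bool) -> Node:
    """Copy the tree; if `prune`, merge sibling leaves that predict the same class."""
    if node.is_leaf:
        return Node(list(node.counts))
    left = finalize(node.left, prune)
    right = finalize(node.right, prune)
    if prune and left.is_leaf and right.is_leaf and left.prediction == right.prediction:
        # Assumed behavior: the parent becomes a leaf with only the pooled counts.
        return Node(list(node.counts))
    return Node(list(node.counts), left, right)


# Both leaves predict class 1, but with very different confidence.
root = Node([15, 85], left=Node([1, 49]), right=Node([14, 36]))

unpruned = finalize(root, prune=False)
pruned = finalize(root, prune=True)

print(unpruned.left.probabilities, unpruned.right.probabilities)
print(pruned.is_leaf, pruned.probabilities)
```

In this toy tree the unpruned leaves keep two distinct probability vectors, while the pruned tree collapses into a single leaf with one pooled vector. The predicted class is unchanged, which is why the loss only shows up for users who rely on the probabilities.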
### Does this PR introduce _any_ user-facing change?
No, it's dev-only.
### How was this patch tested?
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
…are merged after a RF model is trained.
Two reasons:
1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
Once pruned, these probabilities are lost, which makes the trees and their predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn, where trees are left unpruned by default.
Please see Jira ticket for more explanation.
No, it's dev-only.
I modified the two tests introduced with this change to verify positive/negative use of the feature. I also added assertions for the default behavior.
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Locally ran `./dev/scalafmt` which resulted in some minor cosmetic changes
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Pull request overview
Adds a configurable switch to control post-training decision tree “pruning” (merging redundant leaf nodes) and wires it through Spark ML (Scala + PySpark) APIs down to the underlying training implementation.
Changes:
- Introduce a new `pruneTree` parameter on Spark ML tree-based classifiers (Scala + PySpark) and propagate it into the old `Strategy` used by the training code.
- Modify `ml/tree/impl/RandomForest` to use `strategy.pruneTree` when converting `LearningNode` to final `Node` trees (affecting pruning behavior).
- Update/extend RandomForest implementation tests and reformat a large portion of the suite.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| python/pyspark/ml/tree.py | Adds pruneTree param + getter to the shared Python tree classifier params. |
| python/pyspark/ml/classification.py | Exposes pruneTree in Python DecisionTreeClassifier / RandomForestClassifier constructors, setters, and docstrings. |
| mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | Adds pruneTree to Scala ML TreeClassifierParams with defaults and docs. |
| mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | Adds pruneTree to the old mllib Strategy so training code can read it. |
| mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | Uses strategy.pruneTree to decide whether to prune when finalizing trees (and for early-stop size estimation). |
| mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | Adds setPruneTree and sets strategy.pruneTree during training; logs the param. |
| mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | Same as above for RF classifier. |
| mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | Reformats tests and adds/updates pruning-related expectations (but currently contains compilation-breaking calls). |
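The plumbing summarized in the table can be sketched in a few lines of plain Python. This is only an illustration of the pattern (estimator param, copied into a strategy object, read by the tree-finalizing code), not the Spark implementation. `Strategy`, `TreeClassifier`, `set_prune_tree`, `to_strategy`, and `count_final_leaves` are hypothetical stand-ins, and the default of `True` is an assumption.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Strategy:
    """Hypothetical stand-in for the old mllib Strategy config object."""
    prune_tree: bool = True  # assumed default: keep the previous always-prune behavior


class TreeClassifier:
    """Hypothetical estimator that exposes the flag through a setter."""

    def __init__(self) -> None:
        self._prune_tree = True

    def set_prune_tree(self, value: bool) -> "TreeClassifier":
        self._prune_tree = value
        return self

    def to_strategy(self) -> Strategy:
        # The estimator copies its param into the config read by training code.
        return Strategy(prune_tree=self._prune_tree)


def count_final_leaves(sibling_leaf_predictions: List[Tuple[int, int]],
                       strategy: Strategy) -> int:
    """Count leaves after finalizing, given the predictions of sibling leaf pairs."""
    leaves = 0
    for left_pred, right_pred in sibling_leaf_predictions:
        # Siblings are merged only when pruning is on and they agree.
        merged = strategy.prune_tree and left_pred == right_pred
        leaves += 1 if merged else 2
    return leaves


pairs = [(1, 1), (0, 1), (0, 0)]
default_leaves = count_final_leaves(pairs, TreeClassifier().to_strategy())
unpruned_leaves = count_final_leaves(
    pairs, TreeClassifier().set_prune_tree(False).to_strategy()
)
print(default_leaves, unpruned_leaves)
```

With the flag left at its assumed default, callers that never touch the setter see the same merged trees as before; only callers that opt out get the larger, unpruned tree.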
    featureSubsetStrategy: String,
    seed: Long,
    instr: Option[Instrumentation],
    prune: Boolean = true, // exposed for testing only, real trees are always pruned
Should the default value be `true` to align with the previous implementation?
"default prune = false" is proposed in the jira: https://issues.apache.org/jira/browse/SPARK-34591
but to keep API compatibility, keeping it to true might be safer.
merged to master
it seems this PR broke the CI https://github.com/apache/spark/actions/runs/25548774851/job/74991497856 let me revert it for now
reverted in #55759
new PR: #55763
### What changes were proposed in this pull request?
This PR adds a parameter to enable or disable a feature where LearningNodes are merged after an RF model is trained.
This PR takes over #32813.
### Why are the changes needed?
Two reasons:
1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
Once pruned, these probabilities are lost, which makes the trees and their predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn, where trees are left unpruned by default.
Please see Jira ticket for more explanation.
### Does this PR introduce _any_ user-facing change?
New params: adds a parameter `pruneTree` that is exposed to the tree-based classifiers. Will add tests here to ensure the parameter is exposed correctly.
### How was this patch tested?
Unit tests.