diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 3f9e3fa37f727..a5fdaed0db2c4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -2306,7 +2306,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeClassificationModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeClassificationModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]

     @property
@@ -2805,7 +2810,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 0dc38e7275c1f..ff53eb77d0326 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -77,11 +77,13 @@ def saveInstance(
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
         if isinstance(instance, JavaModel):
+            from pyspark.ml.util import RemoteModelRef
+
             model = cast("JavaModel", instance)
             params = serialize_ml_params(model, session.client)
-            assert isinstance(model._java_obj, str)
+            assert isinstance(model._java_obj, RemoteModelRef)
             writer = pb2.MlCommand.Write(
-                obj_ref=pb2.ObjectRef(id=model._java_obj),
+                obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
                 params=params,
                 path=path,
                 should_overwrite=shouldOverwrite,
@@ -270,9 +272,12 @@ def _get_class() -> Type[RL]:
         py_type = _get_class()
         # It must be JavaWrapper, since we're passing the string to the _java_obj
         if issubclass(py_type, JavaWrapper):
+            from pyspark.ml.util import RemoteModelRef
+
             if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
                 session.client.add_ml_cache(result.obj_ref.id)
-                instance = py_type(result.obj_ref.id)
+                remote_model_ref = RemoteModelRef(result.obj_ref.id)
+                instance = py_type(remote_model_ref)
             else:
                 instance = py_type()
             instance._resetUid(result.uid)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d669fab27d505..4d1551652028a 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -64,6 +64,7 @@
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import RemoteModelRef
 from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote

@@ -1224,10 +1225,12 @@
         if is_remote():
             model = CountVectorizerModel()
-            model._java_obj = invoke_helper_attr(
-                "countVectorizerModelFromVocabulary",
-                model.uid,
-                list(vocabulary),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "countVectorizerModelFromVocabulary",
+                    model.uid,
+                    list(vocabulary),
+                )
             )

         else:
@@ -4843,10 +4846,12 @@ def from_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabels",
-                model.uid,
-                (list(labels), ArrayType(StringType())),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabels",
+                    model.uid,
+                    (list(labels), ArrayType(StringType())),
+                )
             )

         else:
@@ -4882,13 +4887,15 @@ def from_arrays_of_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabelsArray",
-                model.uid,
-                (
-                    [list(labels) for labels in arrayOfLabels],
-                    ArrayType(ArrayType(StringType())),
-                ),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabelsArray",
+                    model.uid,
+                    (
+                        [list(labels) for labels in arrayOfLabels],
+                        ArrayType(ArrayType(StringType())),
+                    ),
+                )
             )

         else:
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a7e793142233d..66d6dbd6a2678 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1614,7 +1614,12 @@ class RandomForestRegressionModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     @property
@@ -2005,7 +2010,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 947c599b3cf25..ff9a26f711975 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -97,7 +97,6 @@ def test_train_validation_split(self):
         self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
         self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))

-    @unittest.skip("Disabled due to a Python side reference count issue in _parallelFitTasks.")
     def test_cross_validator(self):
         dataset = self.spark.createDataFrame(
             [
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6abadec74e63a..a5e0c847c1732 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -17,6 +17,7 @@
 import json
 import os
+import threading
 import time
 import uuid
 import functools
@@ -75,7 +76,7 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(self: "JavaWrapper") -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            return f"{self._java_obj}.{f.__name__}"
+            return f"{str(self._java_obj)}.{f.__name__}"
         else:
             return f(self)

@@ -108,13 +109,18 @@ def invoke_remote_attribute_relation(
     from pyspark.ml.connect.proto import AttributeRelation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.ml.wrapper import JavaModel

     session = SparkSession.getActiveSession()
     assert session is not None
-    assert isinstance(instance._java_obj, str)
-
-    methods, obj_ref = _extract_id_methods(instance._java_obj)
+    if isinstance(instance, JavaModel):
+        assert isinstance(instance._java_obj, RemoteModelRef)
+        object_id = instance._java_obj.ref_id
+    else:
+        # model summary
+        object_id = instance._java_obj  # type: ignore
+    methods, obj_ref = _extract_id_methods(object_id)
     methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
     plan = AttributeRelation(obj_ref, methods)
@@ -139,6 +145,33 @@ def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)


+class RemoteModelRef:
+    def __init__(self, ref_id: str) -> None:
+        self._ref_id = ref_id
+        self._ref_count = 1
+        self._lock = threading.Lock()
+
+    @property
+    def ref_id(self) -> str:
+        return self._ref_id
+
+    def add_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count += 1
+
+    def release_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count -= 1
+            if self._ref_count == 0:
+                # Delete the model if possible
+                del_remote_cache(self.ref_id)
+
+    def __str__(self) -> str:
+        return self.ref_id
+
+
 def try_remote_fit(f: FuncT) -> FuncT:
     """Mark the function that fits a model."""

@@ -165,7 +198,8 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any:
         (_, properties, _) = client.execute_command(command)
         model_info = deserialize(properties)
         client.add_ml_cache(model_info.obj_ref.id)
-        model = self._create_model(model_info.obj_ref.id)
+        remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
+        model = self._create_model(remote_model_ref)
         if model.__class__.__name__ not in ["Bucketizer"]:
             model._resetUid(self.uid)
         return self._copyValues(model)
@@ -192,11 +226,11 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
         if isinstance(self, Model):
             from pyspark.ml.connect.proto import TransformerRelation

-            assert isinstance(self._java_obj, str)
+            assert isinstance(self._java_obj, RemoteModelRef)
             params = serialize_ml_params(self, session.client)
             plan = TransformerRelation(
                 child=dataset._plan,
-                name=self._java_obj,
+                name=self._java_obj.ref_id,
                 ml_params=params,
                 is_model=True,
             )
@@ -246,11 +280,20 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
         from pyspark.sql.connect.session import SparkSession
         from pyspark.ml.connect.util import _extract_id_methods
         from pyspark.ml.connect.serialize import serialize, deserialize
+        from pyspark.ml.wrapper import JavaModel

         session = SparkSession.getActiveSession()
         assert session is not None
-        assert isinstance(self._java_obj, str)
-        methods, obj_ref = _extract_id_methods(self._java_obj)
+        if self._java_obj == ML_CONNECT_HELPER_ID:
+            obj_id = ML_CONNECT_HELPER_ID
+        else:
+            if isinstance(self, JavaModel):
+                assert isinstance(self._java_obj, RemoteModelRef)
+                obj_id = self._java_obj.ref_id
+            else:
+                # model summary
+                obj_id = self._java_obj  # type: ignore
+        methods, obj_ref = _extract_id_methods(obj_id)
         methods.append(pb2.Fetch.Method(method=name, args=serialize(session.client, *args)))
         command = pb2.Command()
         command.ml_command.fetch.CopyFrom(
@@ -301,10 +344,8 @@ def wrapped(self: "JavaWrapper") -> Any:
         except Exception:
             return

-        if in_remote:
-            # Delete the model if possible
-            model_id = self._java_obj
-            del_remote_cache(cast(str, model_id))
+        if in_remote and isinstance(self._java_obj, RemoteModelRef):
+            self._java_obj.release_ref()
             return
         else:
             return f(self)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f88045e718a55..b8d86e9eab3b1 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -356,9 +356,15 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
         if extra is None:
             extra = dict()
         that = super(JavaParams, self).copy(extra)
-        if self._java_obj is not None and not isinstance(self._java_obj, str):
-            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-            that._transfer_params_to_java()
+        if self._java_obj is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            if isinstance(self._java_obj, RemoteModelRef):
+                that._java_obj = self._java_obj
+                self._java_obj.add_ref()
+            elif not isinstance(self._java_obj, str):
+                that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+                that._transfer_params_to_java()
         return that

     @try_remote_intercept
@@ -452,6 +458,10 @@ def __init__(self, java_model: Optional["JavaObject"] = None):
         other ML classes).
         """
         super(JavaModel, self).__init__(java_model)
+        if is_remote() and java_model is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            assert isinstance(java_model, RemoteModelRef)
         if java_model is not None and not is_remote():
             # SPARK-10931: This is a temporary fix to allow models to own params
             # from estimators. Eventually, these params should be in models through
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index db7f5a135fb08..ca9bdd9b6f0c4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1981,9 +1981,10 @@ def add_ml_cache(self, cache_id: str) -> None:
         self.thread_local.ml_caches.add(cache_id)

     def remove_ml_cache(self, cache_id: str) -> None:
+        deleted = self._delete_ml_cache([cache_id])
+        # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
         if hasattr(self.thread_local, "ml_caches"):
             if cache_id in self.thread_local.ml_caches:
-                deleted = self._delete_ml_cache([cache_id])
                 for obj_id in deleted:
                     self.thread_local.ml_caches.remove(obj_id)
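
The core of the change is the `RemoteModelRef` added to `python/pyspark/ml/util.py`: a thread-safe reference counter wrapped around the server-side model cache id. A handle starts with a count of 1 when a model is fit or loaded, `JavaParams.copy()` shares the same ref and bumps the count with `add_ref()`, and the intercepted deleter calls `release_ref()`, which only issues `del_remote_cache` once the count reaches zero. The snippet below is a minimal, self-contained sketch of that lifecycle; the `del_remote_cache` stub and the `"model-123"` id are placeholders for illustration, not the real Spark Connect client call.

```python
import threading


def del_remote_cache(ref_id: str) -> None:
    # Placeholder for illustration only: in pyspark.ml.util this asks the
    # Spark Connect server to drop the cached model object.
    print(f"deleted remote cache entry {ref_id}")


class RemoteModelRef:
    """Reference-counted handle to a model cached on the Connect server."""

    def __init__(self, ref_id: str) -> None:
        self._ref_id = ref_id
        self._ref_count = 1            # the creating handle owns one reference
        self._lock = threading.Lock()  # copies may be released from other threads

    @property
    def ref_id(self) -> str:
        return self._ref_id

    def add_ref(self) -> None:
        with self._lock:
            assert self._ref_count > 0
            self._ref_count += 1

    def release_ref(self) -> None:
        with self._lock:
            assert self._ref_count > 0
            self._ref_count -= 1
            if self._ref_count == 0:
                # Last reference gone: drop the server-side object.
                del_remote_cache(self.ref_id)


# Hypothetical id standing in for the obj_ref.id returned by a remote fit().
ref = RemoteModelRef("model-123")
ref.add_ref()      # JavaParams.copy() shares the handle instead of copying it
ref.release_ref()  # the copy is garbage collected: count drops back to 1
ref.release_ref()  # the original is deleted: count hits 0, cache entry removed
```

Because copies now share and count references instead of deleting the cache entry outright, this appears to be what allows the previously skipped `test_cross_validator` (disabled for a "Python side reference count issue in `_parallelFitTasks`") to be re-enabled in `test_tuning.py`.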