
[SPARK-51880][ML][PYTHON][CONNECT] Fix ML cache object python client references #50707


Closed · wants to merge 15 commits
14 changes: 12 additions & 2 deletions python/pyspark/ml/classification.py
@@ -2306,7 +2306,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeClassificationModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeClassificationModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]

     @property
@@ -2805,7 +2810,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
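For context on the pattern above: on the Connect path the `trees` fetch evidently comes back as a single comma-separated string of server-side cache ids (hence the `.split(",")`), and the fix wraps each id so every tree is reference counted on its own. A stand-alone sketch of that wrapping, with a made-up payload and a minimal stand-in for `RemoteModelRef`:

```python
class TreeRef:
    """Stand-in for RemoteModelRef: one independently counted handle per tree."""

    def __init__(self, ref_id: str) -> None:
        self.ref_id = ref_id
        self.ref_count = 1


# Hypothetical reply to the "trees" fetch: one cache id per tree in the ensemble.
payload = "tree-ref-0,tree-ref-1,tree-ref-2"
tree_refs = [TreeRef(ref_id) for ref_id in payload.split(",")]
print([t.ref_id for t in tree_refs])  # ['tree-ref-0', 'tree-ref-1', 'tree-ref-2']
```

In the patch each of these handles is what `DecisionTreeClassificationModel` receives instead of a bare string, so each tree participates in the same release-on-last-reference scheme as any other remote model.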
11 changes: 8 additions & 3 deletions python/pyspark/ml/connect/readwrite.py
@@ -77,11 +77,13 @@ def saveInstance(
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
         if isinstance(instance, JavaModel):
+            from pyspark.ml.util import RemoteModelRef
+
             model = cast("JavaModel", instance)
             params = serialize_ml_params(model, session.client)
-            assert isinstance(model._java_obj, str)
+            assert isinstance(model._java_obj, RemoteModelRef)
             writer = pb2.MlCommand.Write(
-                obj_ref=pb2.ObjectRef(id=model._java_obj),
+                obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
                 params=params,
                 path=path,
                 should_overwrite=shouldOverwrite,
@@ -270,9 +272,12 @@ def _get_class() -> Type[RL]:
         py_type = _get_class()
         # It must be JavaWrapper, since we're passing the string to the _java_obj
         if issubclass(py_type, JavaWrapper):
+            from pyspark.ml.util import RemoteModelRef
+
             if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
                 session.client.add_ml_cache(result.obj_ref.id)
-                instance = py_type(result.obj_ref.id)
+                remote_model_ref = RemoteModelRef(result.obj_ref.id)
+                instance = py_type(remote_model_ref)
             else:
                 instance = py_type()
             instance._resetUid(result.uid)
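The read and write paths above stay symmetric: loading registers the new object id with the client and wraps it in a `RemoteModelRef`, while saving unwraps it again because the protobuf `ObjectRef` still carries the raw string id. A compressed, Spark-free sketch of that round trip (the client class, the `Ref` helper, and the id are all stand-ins):

```python
class FakeClient:
    """Stand-in for the Spark Connect client's per-thread ML cache bookkeeping."""

    def __init__(self) -> None:
        self.ml_caches = set()

    def add_ml_cache(self, cache_id: str) -> None:
        self.ml_caches.add(cache_id)


class Ref:
    """Stand-in for RemoteModelRef: just holds the server-side id."""

    def __init__(self, ref_id: str) -> None:
        self.ref_id = ref_id


client = FakeClient()

# Load path (loadInstance): register the id with the client, then wrap it for the model.
obj_id = "model-abc"
client.add_ml_cache(obj_id)
java_obj = Ref(obj_id)

# Save path (saveInstance): the writer unwraps the handle to build ObjectRef(id=...).
assert java_obj.ref_id == "model-abc"
```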
37 changes: 22 additions & 15 deletions python/pyspark/ml/feature.py
@@ -64,6 +64,7 @@
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import RemoteModelRef
 from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote

@@ -1224,10 +1225,12 @@ def from_vocabulary(

         if is_remote():
             model = CountVectorizerModel()
-            model._java_obj = invoke_helper_attr(
-                "countVectorizerModelFromVocabulary",
-                model.uid,
-                list(vocabulary),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "countVectorizerModelFromVocabulary",
+                    model.uid,
+                    list(vocabulary),
+                )
             )

         else:
@@ -4843,10 +4846,12 @@ def from_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabels",
-                model.uid,
-                (list(labels), ArrayType(StringType())),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabels",
+                    model.uid,
+                    (list(labels), ArrayType(StringType())),
+                )
             )

         else:
@@ -4882,13 +4887,15 @@ def from_arrays_of_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabelsArray",
-                model.uid,
-                (
-                    [list(labels) for labels in arrayOfLabels],
-                    ArrayType(ArrayType(StringType())),
-                ),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabelsArray",
+                    model.uid,
+                    (
+                        [list(labels) for labels in arrayOfLabels],
+                        ArrayType(ArrayType(StringType())),
+                    ),
+                )
             )

         else:
14 changes: 12 additions & 2 deletions python/pyspark/ml/regression.py
@@ -1614,7 +1614,12 @@ class RandomForestRegressionModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     @property
@@ -2005,7 +2010,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
        if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

     def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
1 change: 0 additions & 1 deletion python/pyspark/ml/tests/test_tuning.py
@@ -97,7 +97,6 @@ def test_train_validation_split(self):
         self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
         self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))

-    @unittest.skip("Disabled due to a Python side reference count issue in _parallelFitTasks.")
     def test_cross_validator(self):
         dataset = self.spark.createDataFrame(
             [
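Dropping the `@unittest.skip` fits the rest of the change: the CrossValidator path copies estimators and models for `_parallelFitTasks`, and before this patch every copy carried the same raw cache id, so whichever copy was garbage collected first deleted the server-side entry that the others still needed. A toy, Spark-free reproduction of that pre-patch failure mode:

```python
server_cache = {"model-1": "trained model bytes"}


class BareId:
    """Pre-patch behaviour: every Python copy deletes the shared server entry on __del__."""

    def __init__(self, ref_id: str) -> None:
        self.ref_id = ref_id

    def __del__(self) -> None:
        server_cache.pop(self.ref_id, None)


original = BareId("model-1")
worker_copy = BareId("model-1")   # a copy handed to a parallel fit task
del worker_copy                   # the worker's copy goes away first...
print("model-1" in server_cache)  # False: the original's model is already gone
```

With `RemoteModelRef`, both objects would share one counted handle and the cache entry would only be released when the last of them is gone, which is presumably why the test can be re-enabled.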
67 changes: 54 additions & 13 deletions python/pyspark/ml/util.py
@@ -17,6 +17,7 @@

 import json
 import os
+import threading
 import time
 import uuid
 import functools
@@ -75,7 +76,7 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(self: "JavaWrapper") -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            return f"{self._java_obj}.{f.__name__}"
+            return f"{str(self._java_obj)}.{f.__name__}"
         else:
             return f(self)

@@ -108,13 +109,18 @@ def invoke_remote_attribute_relation(
     from pyspark.ml.connect.proto import AttributeRelation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.ml.wrapper import JavaModel

     session = SparkSession.getActiveSession()
     assert session is not None

-    assert isinstance(instance._java_obj, str)
-
-    methods, obj_ref = _extract_id_methods(instance._java_obj)
+    if isinstance(instance, JavaModel):
+        assert isinstance(instance._java_obj, RemoteModelRef)
+        object_id = instance._java_obj.ref_id
+    else:
+        # model summary
+        object_id = instance._java_obj  # type: ignore
+    methods, obj_ref = _extract_id_methods(object_id)
     methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
     plan = AttributeRelation(obj_ref, methods)

@@ -139,6 +145,33 @@ def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)


+class RemoteModelRef:
+    def __init__(self, ref_id: str) -> None:
+        self._ref_id = ref_id
+        self._ref_count = 1
+        self._lock = threading.Lock()
+
+    @property
+    def ref_id(self) -> str:
+        return self._ref_id
+
+    def add_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count += 1
+
+    def release_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count -= 1
+            if self._ref_count == 0:
+                # Delete the model if possible
+                del_remote_cache(self.ref_id)
+
+    def __str__(self) -> str:
+        return self.ref_id
+
+
 def try_remote_fit(f: FuncT) -> FuncT:
     """Mark the function that fits a model."""

@@ -165,7 +198,8 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any:
             (_, properties, _) = client.execute_command(command)
             model_info = deserialize(properties)
             client.add_ml_cache(model_info.obj_ref.id)
-            model = self._create_model(model_info.obj_ref.id)
+            remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
+            model = self._create_model(remote_model_ref)
             if model.__class__.__name__ not in ["Bucketizer"]:
                 model._resetUid(self.uid)
             return self._copyValues(model)
@@ -192,11 +226,11 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
             if isinstance(self, Model):
                 from pyspark.ml.connect.proto import TransformerRelation

-                assert isinstance(self._java_obj, str)
+                assert isinstance(self._java_obj, RemoteModelRef)
                 params = serialize_ml_params(self, session.client)
                 plan = TransformerRelation(
                     child=dataset._plan,
-                    name=self._java_obj,
+                    name=self._java_obj.ref_id,
                     ml_params=params,
                     is_model=True,
)
@@ -246,11 +280,20 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
             from pyspark.sql.connect.session import SparkSession
             from pyspark.ml.connect.util import _extract_id_methods
             from pyspark.ml.connect.serialize import serialize, deserialize
+            from pyspark.ml.wrapper import JavaModel

             session = SparkSession.getActiveSession()
             assert session is not None
-            assert isinstance(self._java_obj, str)
-            methods, obj_ref = _extract_id_methods(self._java_obj)
+            if self._java_obj == ML_CONNECT_HELPER_ID:
+                obj_id = ML_CONNECT_HELPER_ID
+            else:
+                if isinstance(self, JavaModel):
+                    assert isinstance(self._java_obj, RemoteModelRef)
+                    obj_id = self._java_obj.ref_id
+                else:
+                    # model summary
+                    obj_id = self._java_obj  # type: ignore
+            methods, obj_ref = _extract_id_methods(obj_id)
             methods.append(pb2.Fetch.Method(method=name, args=serialize(session.client, *args)))
             command = pb2.Command()
             command.ml_command.fetch.CopyFrom(
@@ -301,10 +344,8 @@ def wrapped(self: "JavaWrapper") -> Any:
         except Exception:
             return

-        if in_remote:
-            # Delete the model if possible
-            model_id = self._java_obj
-            del_remote_cache(cast(str, model_id))
+        if in_remote and isinstance(self._java_obj, RemoteModelRef):
+            self._java_obj.release_ref()
             return
         else:
             return f(self)
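Taken together, the `util.py` changes give every remote model a single shared, thread-safe handle: the count starts at 1 when the model is fit or loaded, `copy()` bumps it, `__del__` releases it, and only the final release removes the server-side cache entry. A minimal, self-contained sketch of that lifecycle (no PySpark; `delete_from_server` is a hypothetical stand-in for `del_remote_cache`):

```python
import threading


def delete_from_server(ref_id: str) -> None:
    # Hypothetical stand-in for pyspark.ml.util.del_remote_cache.
    print(f"server cache entry {ref_id} deleted")


class CountedHandle:
    """Simplified RemoteModelRef: one handle shared by every Python copy of a model."""

    def __init__(self, ref_id: str) -> None:
        self.ref_id = ref_id
        self._count = 1                 # the object that fit/loaded the model owns one reference
        self._lock = threading.Lock()

    def add_ref(self) -> None:
        with self._lock:
            self._count += 1            # taken when the model is copied

    def release_ref(self) -> None:
        with self._lock:
            self._count -= 1
            if self._count == 0:
                delete_from_server(self.ref_id)  # only the last owner frees the server object


handle = CountedHandle("model-123")
handle.add_ref()       # a copy of the model shares the handle
handle.release_ref()   # the copy is garbage collected: nothing is deleted yet
handle.release_ref()   # the original goes away: now the server entry is freed
```

The `str(self._java_obj)` change in `try_remote_intermediate_result` is what keeps the summary and attribute paths working: the handle prints as its id, so the `"<id>.<attr>"` strings built there are unchanged.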
16 changes: 13 additions & 3 deletions python/pyspark/ml/wrapper.py
@@ -356,9 +356,15 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
         if extra is None:
             extra = dict()
         that = super(JavaParams, self).copy(extra)
-        if self._java_obj is not None and not isinstance(self._java_obj, str):
-            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-            that._transfer_params_to_java()
+        if self._java_obj is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            if isinstance(self._java_obj, RemoteModelRef):
+                that._java_obj = self._java_obj
+                self._java_obj.add_ref()
+            elif not isinstance(self._java_obj, str):
+                that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+                that._transfer_params_to_java()
         return that

     @try_remote_intercept
@@ -452,6 +458,10 @@ def __init__(self, java_model: Optional["JavaObject"] = None):
         other ML classes).
         """
         super(JavaModel, self).__init__(java_model)
+        if is_remote() and java_model is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            assert isinstance(java_model, RemoteModelRef)
         if java_model is not None and not is_remote():
             # SPARK-10931: This is a temporary fix to allow models to own params
             # from estimators. Eventually, these params should be in models through
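The `copy()` branch is worth spelling out: for a remote model nothing is cloned on the server, the copy simply shares the same `RemoteModelRef` and increments its count, while the classic py4j branch still deep-copies the Java object. A rough sketch of that dispatch with stand-in classes (only the remote branch is exercised):

```python
class Handle:
    """Stand-in for RemoteModelRef with just the pieces copy() relies on."""

    def __init__(self, ref_id: str) -> None:
        self.ref_id = ref_id
        self.count = 1

    def add_ref(self) -> None:
        self.count += 1


class Model:
    def __init__(self, java_obj) -> None:
        self._java_obj = java_obj

    def copy(self) -> "Model":
        that = Model(None)
        if isinstance(self._java_obj, Handle):
            that._java_obj = self._java_obj   # share the handle, never clone the server object
            self._java_obj.add_ref()
        # (the real method falls back to cloning a py4j JavaObject here)
        return that


m = Model(Handle("model-xyz"))
m2 = m.copy()
assert m2._java_obj is m._java_obj
assert m._java_obj.count == 2
```

The new assertion in `JavaModel.__init__` enforces the other half of the contract: in remote mode a model can only be constructed around a `RemoteModelRef`, never a bare id string.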
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -1981,9 +1981,10 @@ def add_ml_cache(self, cache_id: str) -> None:
         self.thread_local.ml_caches.add(cache_id)

     def remove_ml_cache(self, cache_id: str) -> None:
-        deleted = self._delete_ml_cache([cache_id])
+        # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
Review comment (Contributor Author): TODO
         if hasattr(self.thread_local, "ml_caches"):
             if cache_id in self.thread_local.ml_caches:
+                deleted = self._delete_ml_cache([cache_id])
                 for obj_id in deleted:
                     self.thread_local.ml_caches.remove(obj_id)

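The guard added to `remove_ml_cache` makes a release coming from a thread that never registered the id a no-op instead of a server call; the TODO on the same line flags that the bookkeeping is still per-thread, so ids created on one thread and released on another are currently skipped. A self-contained mock of the guarded behaviour (`MiniClient` and its methods are stand-ins, not the real SparkConnectClient):

```python
import threading


class MiniClient:
    """Sketch of the guarded remove_ml_cache logic; _delete_ml_cache is mocked out."""

    def __init__(self) -> None:
        self.thread_local = threading.local()
        self.deleted = []

    def add_ml_cache(self, cache_id: str) -> None:
        if not hasattr(self.thread_local, "ml_caches"):
            self.thread_local.ml_caches = set()
        self.thread_local.ml_caches.add(cache_id)

    def _delete_ml_cache(self, cache_ids):
        self.deleted.extend(cache_ids)   # pretend the server confirmed the deletion
        return list(cache_ids)

    def remove_ml_cache(self, cache_id: str) -> None:
        # Only talk to the server for ids this thread actually registered.
        if hasattr(self.thread_local, "ml_caches"):
            if cache_id in self.thread_local.ml_caches:
                for obj_id in self._delete_ml_cache([cache_id]):
                    self.thread_local.ml_caches.remove(obj_id)


client = MiniClient()
client.add_ml_cache("model-1")
client.remove_ml_cache("model-2")   # unknown id: no server call is made
client.remove_ml_cache("model-1")   # tracked id: deleted and untracked
print(client.deleted)               # ['model-1']
```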