Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mlflow_export_import/model/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Export a registered model and all the experiment runs associated with each version.
"""

import json
import os
import click
import mlflow
from mlflow.utils.proto_json_utils import message_to_json

from mlflow_export_import.common import MlflowExportImportException
from mlflow_export_import.common.http_client import MlflowHttpClient
from mlflow_export_import.common import filesystem as _filesystem
from mlflow_export_import.run.export_run import RunExporter
from mlflow_export_import import utils, click_doc
Expand All @@ -23,7 +24,6 @@ def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=No
:param export_run: Export the run that generated a registered model's version.
"""
self.mlflow_client = mlflow_client
self.http_client = MlflowHttpClient()
self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats)
self.stages = self._normalize_stages(stages)
self.export_run = export_run
Expand Down Expand Up @@ -85,7 +85,10 @@ def _export_model(self, model_name, output_dir):
traceback.print_exc()
output_versions.sort(key=lambda x: x["version"], reverse=False)

model = self.http_client.get(f"registered-models/get", {"name": model_name})
model_obj = self.mlflow_client.get_registered_model(model_name)
model_proto = model_obj.to_proto()
model = json.loads(message_to_json(model_proto))

export_info = { "export_info":
{ **utils.create_export_info(),
**{ "num_target_stages": len(self.stages),
Expand All @@ -95,7 +98,7 @@ def _export_model(self, model_name, output_dir):
}
}
}
model = { **export_info, **model }
model = {'registered_model': model, **export_info }
model["registered_model"]["latest_versions"] = output_versions

print(f"Exported {exported_versions}/{len(output_versions)} versions for model '{model_name}'")
Expand Down
7 changes: 4 additions & 3 deletions mlflow_export_import/model/import_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import os
from urllib.parse import urlparse
import click

import mlflow
Expand Down Expand Up @@ -45,8 +46,8 @@ def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time
:param sleep_time: Seconds to wait for model version crreation.
"""
src_current_stage = src_vr["current_stage"]
dst_source = dst_source.replace("file://","") # OSS MLflow
if not dst_source.startswith("dbfs:") and not os.path.exists(dst_source):
parsed_dst_source = urlparse(dst_source)
if parsed_dst_source.scheme == "file" and not os.path.exists(dst_source):
raise MlflowExportImportException(f"'source' argument for MLflowClient.create_model_version does not exist: {dst_source}")
kwargs = {"await_creation_for": self.await_creation_for } if self.await_creation_for else {}
tags = src_vr["tags"]
Expand Down Expand Up @@ -189,7 +190,7 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False,
for vr in model_dct["latest_versions"]:
src_run_id = vr["run_id"]
dst_run_id = self.run_info_map[src_run_id].run_id
mlflow.set_experiment(vr["_experiment_name"])
# mlflow.set_experiment(vr["_experiment_name"]) Is it thread-safe?
self.import_version(model_name, vr, dst_run_id, sleep_time)
if verbose:
model_utils.dump_model_versions(self.mlflow_client, model_name)
Expand Down