diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py
index d29d6865..cc5eec16 100644
--- a/mlflow_export_import/bulk/import_models.py
+++ b/mlflow_export_import/bulk/import_models.py
@@ -40,7 +40,7 @@ def _import_experiments(client, input_dir, use_src_user_id):
     for exp in exps:
         exp_input_dir = os.path.join(input_dir, "experiments", exp["id"])
         try:
-            _run_info_map = importer.import_experiment( exp["name"], exp_input_dir)
+            _run_info_map = importer.import_experiment(f'/Shared/{exp["name"]}', exp_input_dir) # MATCHING DBX PATH
             run_info_map[exp["id"]] = _run_info_map
         except Exception as e:
             exceptions.append(str(e))
diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py
index 2aa4ae17..54743a4e 100644
--- a/mlflow_export_import/model/import_model.py
+++ b/mlflow_export_import/model/import_model.py
@@ -4,6 +4,7 @@
 import os
 import click
+from urllib.parse import urlparse
 import mlflow
 from mlflow.exceptions import RestException
@@ -54,7 +55,6 @@ def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time
         tags = src_vr["tags"]
         if self.import_source_tags:
             _set_source_tags_for_field(src_vr, tags)
-
         dst_vr = self.mlflow_client.create_model_version(
             model_name, dst_source, dst_run_id, \
@@ -190,7 +190,7 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False,
         for vr in model_dct["versions"]:
             src_run_id = vr["run_id"]
             dst_run_id = self.run_info_map[src_run_id].run_id
-            mlflow.set_experiment(vr["_experiment_name"])
+            mlflow.set_experiment(f'/Shared/{vr["_experiment_name"]}') # MATCHING DBX PATH
             self.import_version(model_name, vr, dst_run_id, sleep_time)
         if verbose:
             model_utils.dump_model_versions(self.mlflow_client, model_name)
@@ -211,6 +211,14 @@ def _extract_model_path(source, run_id):
     :param run_id: Run ID in the 'source field
     :return: relative path to the model artifact
     """
+    if source[:5] == "s3://": # check if source is s3 bucket
+        # bucket name may contain 'artifacts', this bypasses the bucket name
+        pattern = "artifacts"
+        parsed_s3 = urlparse(source)
+        s3_path = parsed_s3.path
+        idx = s3_path.find(pattern)
+        model_path = s3_path[1+idx+len(pattern):]
+        return model_path
     idx = source.find(run_id)
     if idx == -1:
         raise MlflowExportImportException(f"Cannot find run ID '{run_id}' in registered model version source field '{source}'", http_status_code=404)
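
For reviewers, here is a minimal standalone sketch of what the new s3 branch in `_extract_model_path` computes. It assumes a version source laid out as `s3://<bucket>/<...>/artifacts/<model-path>`; by taking `urlparse(source).path`, the bucket name is excluded before searching for `artifacts`, so a bucket whose name happens to contain "artifacts" does not break the lookup. The bucket name and URI below are hypothetical examples, not values from this repo.

```python
from urllib.parse import urlparse

def extract_model_path_sketch(source: str) -> str:
    """Illustration-only re-implementation of the new s3 branch in _extract_model_path."""
    if source[:5] == "s3://":
        pattern = "artifacts"
        parsed_s3 = urlparse(source)      # netloc is the bucket, path is the object key
        s3_path = parsed_s3.path          # e.g. "/1/abc123/artifacts/model"
        idx = s3_path.find(pattern)       # first "artifacts" in the key only (bucket excluded)
        return s3_path[1 + idx + len(pattern):]  # the extra +1 skips the "/" after "artifacts"
    raise ValueError("not an s3:// source")

# Hypothetical example URI -- prints "model"
print(extract_model_path_sketch("s3://my-bucket/1/abc123/artifacts/model"))
```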