@@ -17,20 +17,19 @@ class ModelExporter():
1717
1818 def __init__ (self , mlflow_client , notebook_formats = None , stages = None , versions = None , export_run = True ):
1919 """
20- :param mlflow_client: MLflow client or if None create default client.
20+ :param mlflow_client: MlflowClient
2121 :param notebook_formats: List of notebook formats to export. Values are SOURCE, HTML, JUPYTER or DBC.
2222 :param stages: Stages to export. Default is all stages. Values are Production, Staging, Archived and None.
2323 :param export_run: Export the run that generated a registered model's version.
2424 """
2525 self .mlflow_client = mlflow_client
2626 self .http_client = MlflowHttpClient ()
2727 self .run_exporter = RunExporter (self .mlflow_client , notebook_formats = notebook_formats )
28- self .stages = self ._normalize_stages (stages )
2928 self .export_run = export_run
29+ self .stages = self ._normalize_stages (stages )
3030 self .versions = versions if versions else []
3131 if len (self .stages ) > 0 and len (self .versions ) > 0 :
3232 raise MlflowExportImportException (f"Both stages { self .stages } and versions { self .versions } cannot be set" )
33- self .export_run = export_run
3433
3534
3635 def export_model (self , model_name , output_dir ):
@@ -62,7 +61,7 @@ def _export_model(self, model_name, output_dir):
6261 opath = os .path .join (output_dir ,run_id )
6362 opath = opath .replace ("dbfs:" , "/dbfs" )
6463 dct = { "version" : vr .version , "stage" : vr .current_stage , "run_id" : run_id , "description" : vr .description , "tags" : vr .tags }
65- print (f"Exporting version: { dct } " )
64+ print (f"Exporting verions { vr . version } to ' { opath } ' " )
6665 manifest .append (dct )
6766 try :
6867 if self .export_run :
@@ -85,13 +84,13 @@ def _export_model(self, model_name, output_dir):
8584 model = self .http_client .get (f"registered-models/get" , {"name" : model_name })
8685 model ["registered_model" ]["latest_versions" ] = output_versions
8786
88- custom_info = {
87+ info_attr = {
8988 "num_target_stages" : len (self .stages ),
9089 "num_target_versions" : len (self .versions ),
9190 "num_src_versions" : len (versions ),
9291 "num_dst_versions" : len (output_versions )
9392 }
94- io_utils .write_export_file (output_dir , "model.json" , model , custom_info )
93+ io_utils .write_export_file (output_dir , "model.json" , __file__ , model , info_attr )
9594
9695 print (f"Exported { exported_versions } /{ len (output_versions )} versions for model '{ model_name } '" )
9796 return manifest
0 commit comments