Skip to content

Commit

Permalink
Fix/modelrepository get model repo (#195)
Browse files Browse the repository at this point in the history
* fix: change get_model_repo func to normal func instead of class method

* build: update version to 2.1.5

* fix: remove log
  • Loading branch information
phamhoangtuan authored Jun 29, 2023
1 parent 8f2cec0 commit 9f05a87
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
5 changes: 3 additions & 2 deletions h1st/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self):
self.stats = {}
self.metrics = {}
self.base_model = None
self.model_repo = ModelRepository()

def persist(self, version=None) -> str:
"""
Expand All @@ -58,15 +59,15 @@ def persist(self, version=None) -> str:
:param version: model version, leave blank for autogeneration
:returns: model version
"""
repo = ModelRepository.get_model_repo(self)
repo = self.model_repo.get_model_repo(self)
return repo.persist(model=self, version=version)

def load(self, version: str = None) -> Any:
"""
Load parameters from the specified `version` from the ModelRepository.
Leave version blank to load latest version.
"""
repo = ModelRepository.get_model_repo(self)
repo = self.model_repo.get_model_repo(self)
repo.load(model=self, version=version)

return self
Expand Down
10 changes: 4 additions & 6 deletions h1st/model/repository/model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ def persist(self, model, version=None):
logger.exception(f'Error persisting model {model} version {version}: {e}')
finally:
dir_util.remove_tree(tmpdir)
logger.info(f'Removed temp dir {tmpdir}')

return version

Expand Down Expand Up @@ -526,15 +525,14 @@ def _get_key(self, model, version):

return key

@classmethod
def get_model_repo(cls, ref=None):
def get_model_repo(self, ref=None):
"""
Retrieve the default model repository for the project
:param ref: target model
:returns: Model repository instance
"""
if not hasattr(cls, "MODEL_REPO"): # global ModelRepository.MODEL_REPO
if not hasattr(self, "MODEL_REPO"): # ModelRepository.MODEL_REPO
repo_path = None
if ref is not None:
# root module
Expand Down Expand Up @@ -568,9 +566,9 @@ def get_model_repo(cls, ref=None):
if not repo_path:
raise RuntimeError("Please set MODEL_REPO_PATH in config.py")

setattr(cls, "MODEL_REPO", ModelRepository(storage=repo_path))
self.MODEL_REPO = ModelRepository(storage=repo_path)

return getattr(cls, "MODEL_REPO")
return self.MODEL_REPO


def _tar_create(target, source):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "h1st"
version = "2.1.4"
version = "2.1.5"
description = "Human-First AI (H1st)"
authors = ["Aitomatic, Inc. <[email protected]>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 9f05a87

Please sign in to comment.