Skip to content

Commit c23bbb0

Browse files
authored
Update HF mixin (#910)
* Update mixin * Add reqs for hub lib * Add example to save load share * Add filter warning (not relevant) * Fix typo
1 parent f40b6ed commit c23bbb0

File tree

4 files changed

+282
-42
lines changed

4 files changed

+282
-42
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import segmentation_models_pytorch as smp"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Save to local directory and load back"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 2,
22+
"metadata": {},
23+
"outputs": [
24+
{
25+
"name": "stdout",
26+
"output_type": "stream",
27+
"text": [
28+
"Loading weights from local directory\n"
29+
]
30+
}
31+
],
32+
"source": [
33+
"model = smp.Unet()\n",
34+
"\n",
35+
"# save the model\n",
36+
"model.save_pretrained(\"saved-model-dir/unet/\")\n",
37+
"\n",
38+
"# load the model\n",
39+
"restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"## Save model with additional metadata"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 6,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"model = smp.Unet()\n",
56+
"\n",
57+
"# save the model\n",
58+
"model.save_pretrained(\n",
59+
" \"saved-model-dir/unet-with-metadata/\",\n",
60+
"\n",
61+
" # additional information to be saved with the model\n",
62+
" # only \"dataset\" and \"metrics\" are supported\n",
63+
" dataset=\"PASCAL VOC\", # only string name is supported\n",
64+
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
65+
" \"mIoU\": 0.95,\n",
66+
" \"accuracy\": 0.96\n",
67+
" }\n",
68+
")"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 7,
74+
"metadata": {},
75+
"outputs": [
76+
{
77+
"name": "stdout",
78+
"output_type": "stream",
79+
"text": [
80+
"---\n",
81+
"library_name: segmentation-models-pytorch\n",
82+
"license: mit\n",
83+
"pipeline_tag: image-segmentation\n",
84+
"tags:\n",
85+
"- semantic-segmentation\n",
86+
"- pytorch\n",
87+
"- segmentation-models-pytorch\n",
88+
"languages:\n",
89+
"- python\n",
90+
"---\n",
91+
"# Unet Model Card\n",
92+
"\n",
93+
"Table of Contents:\n",
94+
"- [Load trained model](#load-trained-model)\n",
95+
"- [Model init parameters](#model-init-parameters)\n",
96+
"- [Model metrics](#model-metrics)\n",
97+
"- [Dataset](#dataset)\n",
98+
"\n",
99+
"## Load trained model\n",
100+
"```python\n",
101+
"import segmentation_models_pytorch as smp\n",
102+
"\n",
103+
"model = smp.from_pretrained(\"<save-directory-or-this-repo>\")\n",
104+
"```\n",
105+
"\n",
106+
"## Model init parameters\n",
107+
"```python\n",
108+
"model_init_params = {\n",
109+
" \"encoder_name\": \"resnet34\",\n",
110+
" \"encoder_depth\": 5,\n",
111+
" \"encoder_weights\": \"imagenet\",\n",
112+
" \"decoder_use_batchnorm\": True,\n",
113+
" \"decoder_channels\": (256, 128, 64, 32, 16),\n",
114+
" \"decoder_attention_type\": None,\n",
115+
" \"in_channels\": 3,\n",
116+
" \"classes\": 1,\n",
117+
" \"activation\": None,\n",
118+
" \"aux_params\": None\n",
119+
"}\n",
120+
"```\n",
121+
"\n",
122+
"## Model metrics\n",
123+
"```json\n",
124+
"{\n",
125+
" \"mIoU\": 0.95,\n",
126+
" \"accuracy\": 0.96\n",
127+
"}\n",
128+
"```\n",
129+
"\n",
130+
"## Dataset\n",
131+
"Dataset name: PASCAL VOC\n",
132+
"\n",
133+
"## More Information\n",
134+
"- Library: https://github.com/qubvel/segmentation_models.pytorch\n",
135+
"- Docs: https://smp.readthedocs.io/en/latest/\n",
136+
"\n",
137+
"This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)"
138+
]
139+
}
140+
],
141+
"source": [
142+
"!cat \"saved-model-dir/unet-with-metadata/README.md\""
143+
]
144+
},
145+
{
146+
"cell_type": "markdown",
147+
"metadata": {},
148+
"source": [
149+
"## Share model with HF Hub"
150+
]
151+
},
152+
{
153+
"cell_type": "code",
154+
"execution_count": 5,
155+
"metadata": {},
156+
"outputs": [
157+
{
158+
"data": {
159+
"application/vnd.jupyter.widget-view+json": {
160+
"model_id": "075ae026811542bdb4030e53b943efc7",
161+
"version_major": 2,
162+
"version_minor": 0
163+
},
164+
"text/plain": [
165+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
166+
]
167+
},
168+
"metadata": {},
169+
"output_type": "display_data"
170+
}
171+
],
172+
"source": [
173+
"from huggingface_hub import notebook_login\n",
174+
"\n",
175+
"# You only need to run this once on the machine,\n",
176+
"# the token will be stored for later use\n",
177+
"notebook_login()"
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": 8,
183+
"metadata": {},
184+
"outputs": [
185+
{
186+
"data": {
187+
"application/vnd.jupyter.widget-view+json": {
188+
"model_id": "2921a81d7fd747939b4a425cc17d6104",
189+
"version_major": 2,
190+
"version_minor": 0
191+
},
192+
"text/plain": [
193+
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
194+
]
195+
},
196+
"metadata": {},
197+
"output_type": "display_data"
198+
},
199+
{
200+
"data": {
201+
"text/plain": [
202+
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
203+
]
204+
},
205+
"execution_count": 8,
206+
"metadata": {},
207+
"output_type": "execute_result"
208+
}
209+
],
210+
"source": [
211+
"model = smp.Unet()\n",
212+
"\n",
213+
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
214+
"model.save_pretrained(\n",
215+
" \"qubvel-hf/unet-with-metadata/\",\n",
216+
" push_to_hub=True, # <---------- push the model to the hub\n",
217+
" private=False, # <---------- make the model private or public\n",
218+
" dataset=\"PASCAL VOC\",\n",
219+
" metrics={\n",
220+
" \"mIoU\": 0.95,\n",
221+
" \"accuracy\": 0.96\n",
222+
" }\n",
223+
")\n",
224+
"\n",
225+
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
226+
]
227+
}
228+
],
229+
"metadata": {
230+
"kernelspec": {
231+
"display_name": ".venv",
232+
"language": "python",
233+
"name": "python3"
234+
},
235+
"language_info": {
236+
"codemirror_mode": {
237+
"name": "ipython",
238+
"version": 3
239+
},
240+
"file_extension": ".py",
241+
"mimetype": "text/x-python",
242+
"name": "python",
243+
"nbconvert_exporter": "python",
244+
"pygments_lexer": "ipython3",
245+
"version": "3.10.12"
246+
}
247+
},
248+
"nbformat": 4,
249+
"nbformat_minor": 2
250+
}

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ torchvision>=0.5.0
22
pretrainedmodels==0.7.4
33
efficientnet-pytorch==0.7.1
44
timm==0.9.7
5+
huggingface_hub>=0.24.6
56

67
tqdm
78
pillow

segmentation_models_pytorch/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from . import datasets
24
from . import encoders
35
from . import decoders
@@ -20,6 +22,9 @@
2022
from typing import Optional as _Optional
2123
import torch as _torch
2224

25+
# Suppress the specific SyntaxWarning for `pretrainedmodels`
26+
warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning)
27+
2328

2429
def create_model(
2530
arch: str,

segmentation_models_pytorch/base/hub_mixin.py

+26-42
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
```python
2727
import segmentation_models_pytorch as smp
2828
29-
model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
29+
model = smp.from_pretrained("<save-directory-or-this-repo>")
3030
```
3131
3232
## Model init parameters
@@ -61,23 +61,22 @@ def _format_parameters(parameters: dict):
6161

6262
class SMPHubMixin(PyTorchModelHubMixin):
6363
def generate_model_card(self, *args, **kwargs) -> ModelCard:
64-
model_parameters_json = _format_parameters(self._hub_mixin_config)
65-
directory = self._save_directory if hasattr(self, "_save_directory") else None
66-
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
67-
repo_or_directory = repo_id if repo_id is not None else directory
68-
69-
metrics = self._metrics if hasattr(self, "_metrics") else None
70-
dataset = self._dataset if hasattr(self, "_dataset") else None
64+
model_parameters_json = _format_parameters(self.config)
65+
metrics = kwargs.get("metrics", None)
66+
dataset = kwargs.get("dataset", None)
7167

7268
if metrics is not None:
7369
metrics = json.dumps(metrics, indent=4)
7470
metrics = f"```json\n{metrics}\n```"
7571

72+
tags = self._hub_mixin_info.model_card_data.get("tags", []) or []
73+
tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"])
74+
7675
model_card_data = ModelCardData(
7776
languages=["python"],
7877
library_name="segmentation-models-pytorch",
7978
license="mit",
80-
tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
79+
tags=tags,
8180
pipeline_tag="image-segmentation",
8281
)
8382
model_card = ModelCard.from_template(
@@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
8685
repo_url="https://github.com/qubvel/segmentation_models.pytorch",
8786
docs_url="https://smp.readthedocs.io/en/latest/",
8887
model_parameters=model_parameters_json,
89-
save_directory=repo_or_directory,
9088
model_name=self.__class__.__name__,
9189
metrics=metrics,
9290
dataset=dataset,
9391
)
9492
return model_card
9593

96-
def _set_attrs_from_kwargs(self, attrs, kwargs):
97-
for attr in attrs:
98-
if attr in kwargs:
99-
setattr(self, f"_{attr}", kwargs.pop(attr))
100-
101-
def _del_attrs(self, attrs):
102-
for attr in attrs:
103-
if hasattr(self, f"_{attr}"):
104-
delattr(self, f"_{attr}")
105-
10694
@wraps(PyTorchModelHubMixin.save_pretrained)
10795
def save_pretrained(
10896
self, save_directory: Union[str, Path], *args, **kwargs
10997
) -> Optional[str]:
110-
# set additional attributes to be used in generate_model_card
111-
self._save_directory = save_directory
112-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
98+
model_card_kwargs = kwargs.pop("model_card_kwargs", {})
99+
if "dataset" in kwargs:
100+
model_card_kwargs["dataset"] = kwargs.pop("dataset")
101+
if "metrics" in kwargs:
102+
model_card_kwargs["metrics"] = kwargs.pop("metrics")
103+
kwargs["model_card_kwargs"] = model_card_kwargs
113104

114-
# set additional attribute to be used in from_pretrained
115-
self._hub_mixin_config["_model_class"] = self.__class__.__name__
105+
# set additional attribute to be able to deserialize the model
106+
self.config["_model_class"] = self.__class__.__name__
116107

117108
try:
118109
# call the original save_pretrained
119110
result = super().save_pretrained(save_directory, *args, **kwargs)
120111
finally:
121-
# delete the additional attributes
122-
self._del_attrs(["save_directory", "metrics", "dataset"])
123-
self._hub_mixin_config.pop("_model_class", None)
112+
self.config.pop("_model_class", None)
124113

125114
return result
126115

127-
@wraps(PyTorchModelHubMixin.push_to_hub)
128-
def push_to_hub(self, repo_id: str, *args, **kwargs):
129-
self._repo_id = repo_id
130-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
131-
result = super().push_to_hub(repo_id, *args, **kwargs)
132-
self._del_attrs(["repo_id", "metrics", "dataset"])
133-
return result
134-
135116
@property
136-
def config(self):
117+
def config(self) -> dict:
137118
return self._hub_mixin_config
138119

139120

140121
@wraps(PyTorchModelHubMixin.from_pretrained)
141122
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
142-
config_path = hf_hub_download(
143-
pretrained_model_name_or_path,
144-
filename="config.json",
145-
revision=kwargs.get("revision", None),
146-
)
123+
config_path = Path(pretrained_model_name_or_path) / "config.json"
124+
if not config_path.exists():
125+
config_path = hf_hub_download(
126+
pretrained_model_name_or_path,
127+
filename="config.json",
128+
revision=kwargs.get("revision", None),
129+
)
130+
147131
with open(config_path, "r") as f:
148132
config = json.load(f)
149133
model_class_name = config.pop("_model_class")

0 commit comments

Comments
 (0)