From b4d7111fe39c70c9230d6597fd22a71b747b3fd9 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 3 Jun 2024 12:10:23 +0200 Subject: [PATCH 01/10] Add improved version of build_config_constructor_serializers function --- .../setups/returnn_pytorch/serialization.py | 176 +++++++++++++++++- 1 file changed, 175 insertions(+), 1 deletion(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 9f333278b..8f3942690 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -7,13 +7,14 @@ import textwrap from collections import OrderedDict from dataclasses import fields +from enum import Enum from inspect import isfunction from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING import torch from i6_core.util import instanciate_delayed from sisyphus import gs, tk -from sisyphus.delayed_ops import DelayedBase +from sisyphus.delayed_ops import DelayedBase, DelayedFormat from sisyphus.hash import sis_hash_helper if TYPE_CHECKING: @@ -251,3 +252,176 @@ def build_config_constructor_serializers( imports = list(OrderedDict.fromkeys(imports)) # remove duplications return Call(callable_name=type(cfg).__name__, kwargs=call_kwargs, return_assign_variables=variable_name), imports + +def deduplicate_list_by_hash(orig_list: list) -> list: + seen_hashes = set() + unique_objects = [] + for obj in orig_list: + obj_hash = hash(obj) + if obj_hash not in seen_hashes: + seen_hashes.add(obj_hash) + unique_objects.append(obj) + return unique_objects + + +def build_config_constructor_serializers_v2( + cfg: ModelConfiguration, variable_name: Optional[str] = None, unhashed_package_root: Optional[str] = None +) -> Tuple[Call, List[Import]]: + """ + Creates a Call object that will re-construct the given ModelConfiguration when serialized and + optionally assigns the resulting config object to a variable. Automatically generates a list of all + necessary imports in order to perform the constructor call. + + Compared to the previous version, this function can also serialize enum members and values of type + list, tuple or dict. It also fixes import deduplication. + + :param cfg: ModelConfiguration object that will be re-constructed by the Call serializer + :param variable_name: Name of the variable which the constructed ModelConfiguration + will be assigned to. If None, the result will not be assigned + to a variable. + :param unhashed_package_root: Will be passed to all generated Import objects. + :return: Call object and list of necessary imports. + """ + from i6_models.config import ModelConfiguration, ModuleFactoryV1 + + # Helper function which can call itself recursively for nested types + def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: + # Switch over serialization logic for different subtypes + + if isinstance(value, ModelConfiguration): + # Example: + # ConformerBlockConfig(mhsa_config=ConformerMHSAConfig(...)) + # -> Sub-Constructor-Call and imports for ConformerMHSAConfig + return build_config_constructor_serializers_v2(value, unhashed_package_root=unhashed_package_root) + elif isinstance(value, ModuleFactoryV1): + # Example: + # ConformerEncoderConfig( + # frontend=ModuleFactoryV1(module_class=VGGFrontend, cfg=VGGFrontendConfig(...))) + # -> Import classes ModuleFactoryV1, VGGFrontend and VGGFrontendConfig + # -> Sub-Constructor-Call for VGGFrontendConfig + subcall, subimports = build_config_constructor_serializers_v2( + value.cfg, unhashed_package_root=unhashed_package_root + ) + subimports.append( + Import( + code_object_path=f"{value.module_class.__module__}.{value.module_class.__name__}", + unhashed_package_root=unhashed_package_root, + ) + ) + subimports.append( + Import( + code_object_path=f"{ModuleFactoryV1.__module__}.{ModuleFactoryV1.__name__}", + unhashed_package_root=unhashed_package_root, + ) + ) + return ( + Call( + callable_name=ModuleFactoryV1.__name__, + kwargs=[("module_class", value.module_class.__name__), ("cfg", subcall)], + ), + subimports, + ) + elif isinstance(value, torch.nn.Module): + # Example: + # ConformerConvolutionConfig(norm=BatchNorm1d(...)) + # -> Import class BatchNorm1d + # -> Sub-serialization of BatchNorm1d object. + # The __str__ function of torch.nn.Module already does this in the way we want. + return str(value), [ + Import( + code_object_path=f"{value.__module__}.{type(value).__name__}", + unhashed_package_root=unhashed_package_root, + ) + ] + elif isfunction(value): + # Example: + # ConformerConvolutionConfig(activation=torch.nn.functional.silu) + # -> Import function silu + # Builtins (e.g. 'sum') do not need to be imported + if value.__module__ != "builtins": + subimports = [ + Import( + code_object_path=f"{value.__module__}.{value.__name__}", + unhashed_package_root=unhashed_package_root, + ) + ] + else: + subimports = [] + return value.__name__, subimports + elif isinstance(value, Enum): + # Example: + # FrontendLayerType.Conv2d + # -> Import enum class FrontendLayerType + subimports = [ + Import( + code_object_path=f"{value.__class__.__module__}.{value.__class__.__name__}", + unhashed_package_root=unhashed_package_root, + ) + ] + return f"{value.__class__.__name__}.{value.name}", subimports + elif isinstance(value, list): + # -> Serialize list values individually, collect subimports + list_items = [] + list_imports = [] + for item in value: + item_serialized, item_imports = serialize_value(item) + list_items.append(item_serialized) + list_imports += item_imports + return DelayedFormat(f"[{', '.join(['{}'] * len(list_items))}]", *list_items), list_imports + elif isinstance(value, tuple): + # -> Serialize tuple values individually, collect subimports + tuple_items = [] + tuple_imports = [] + for item in value: + item_serialized, item_imports = serialize_value(item) + tuple_items.append(item_serialized) + tuple_imports += item_imports + return DelayedFormat(f"({', '.join(['{}'] * len(tuple_items))})", *tuple_items), tuple_imports + elif isinstance(value, dict): + # -> Serialize dict values individually, collect subimports + dict_items = [] # Will alternatingly contain key and value of all dict items + dict_imports = [] + for key, val in value.items(): + val_serialized, item_imports = serialize_value(val) + dict_items += [key, val_serialized] + dict_imports += item_imports + return DelayedFormat(f"{{{', '.join(['{}: {}'] * len(dict_items))}}}", *dict_items), dict_imports + elif isinstance(value, DelayedBase): + # sisyphus variables are just given as-is and will be instanciated only when calling "get". + return value, [] + elif isinstance(value, str): + return f'"{value}"', [] + else: + # No special case (usually python primitives) + # -> Just get string representation + return str(value), [] + + # Import the class of `cfg` + imports = [ + Import( + code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}", unhashed_package_root=unhashed_package_root + ) + ] + + call_kwargs = [] + + # Iterate over all dataclass fields and apply helper function to all values + for key in fields(type(cfg)): + # Value corresponding to dataclass field name + value = getattr(cfg, key.name) + + serialized_value, value_imports = serialize_value(value) + call_kwargs.append((key.name, serialized_value)) + imports += value_imports + + # Deduplicate imports + seen_hashes = set() + unique_imports = [] + for imp in imports: + imp_hash = hash(imp) + if imp_hash not in seen_hashes: + seen_hashes.add(imp_hash) + unique_imports.append(imp) + + return Call(callable_name=type(cfg).__name__, kwargs=call_kwargs, return_assign_variables=variable_name), unique_imports + From b8750244a1d899479cbd7e78491b987af5d94400 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 3 Jun 2024 12:36:16 +0200 Subject: [PATCH 02/10] Remove unused deduplication function --- common/setups/returnn_pytorch/serialization.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 8f3942690..7029ca43d 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -253,16 +253,6 @@ def build_config_constructor_serializers( return Call(callable_name=type(cfg).__name__, kwargs=call_kwargs, return_assign_variables=variable_name), imports -def deduplicate_list_by_hash(orig_list: list) -> list: - seen_hashes = set() - unique_objects = [] - for obj in orig_list: - obj_hash = hash(obj) - if obj_hash not in seen_hashes: - seen_hashes.add(obj_hash) - unique_objects.append(obj) - return unique_objects - def build_config_constructor_serializers_v2( cfg: ModelConfiguration, variable_name: Optional[str] = None, unhashed_package_root: Optional[str] = None From acabde7b7088ca4e5717c9173df29dc58fae7806 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 3 Jun 2024 12:46:12 +0200 Subject: [PATCH 03/10] Formatting and import sorting --- .../setups/returnn_pytorch/serialization.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 7029ca43d..eb66e9927 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -255,7 +255,9 @@ def build_config_constructor_serializers( def build_config_constructor_serializers_v2( - cfg: ModelConfiguration, variable_name: Optional[str] = None, unhashed_package_root: Optional[str] = None + cfg: ModelConfiguration, + variable_name: Optional[str] = None, + unhashed_package_root: Optional[str] = None, ) -> Tuple[Call, List[Import]]: """ Creates a Call object that will re-construct the given ModelConfiguration when serialized and @@ -307,7 +309,10 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: return ( Call( callable_name=ModuleFactoryV1.__name__, - kwargs=[("module_class", value.module_class.__name__), ("cfg", subcall)], + kwargs=[ + ("module_class", value.module_class.__name__), + ("cfg", subcall), + ], ), subimports, ) @@ -389,7 +394,8 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: # Import the class of `cfg` imports = [ Import( - code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}", unhashed_package_root=unhashed_package_root + code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}", + unhashed_package_root=unhashed_package_root, ) ] @@ -413,5 +419,13 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: seen_hashes.add(imp_hash) unique_imports.append(imp) - return Call(callable_name=type(cfg).__name__, kwargs=call_kwargs, return_assign_variables=variable_name), unique_imports + unique_imports.sort(key=lambda imp: str(imp)) + return ( + Call( + callable_name=type(cfg).__name__, + kwargs=call_kwargs, + return_assign_variables=variable_name, + ), + unique_imports, + ) From b190553798f355865ff5f029558f4021a2922480 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 25 Nov 2024 12:16:45 +0100 Subject: [PATCH 04/10] Fix unhashed_package_root --- .../setups/returnn_pytorch/serialization.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index eb66e9927..dcba63295 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -297,13 +297,19 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports.append( Import( code_object_path=f"{value.module_class.__module__}.{value.module_class.__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None + and value.module_class.__module__.startswith(unhashed_package_root) + else None, ) ) subimports.append( Import( code_object_path=f"{ModuleFactoryV1.__module__}.{ModuleFactoryV1.__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None + and ModuleFactoryV1.__module__.startswith(unhashed_package_root) + else None, ) ) return ( @@ -325,7 +331,9 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: return str(value), [ Import( code_object_path=f"{value.__module__}.{type(value).__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) + else None, ) ] elif isfunction(value): @@ -337,7 +345,9 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports = [ Import( code_object_path=f"{value.__module__}.{value.__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) + else None, ) ] else: @@ -350,7 +360,10 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports = [ Import( code_object_path=f"{value.__class__.__module__}.{value.__class__.__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None + and value.__class__.__module__.startswith(unhashed_package_root) + else None, ) ] return f"{value.__class__.__name__}.{value.name}", subimports @@ -395,7 +408,9 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: imports = [ Import( code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}", - unhashed_package_root=unhashed_package_root, + unhashed_package_root=unhashed_package_root + if unhashed_package_root is not None and type(cfg).__module__.startswith(unhashed_package_root) + else None, ) ] From 4be0c356a693e5e0b7e2048ed347203a19932d59 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 25 Nov 2024 12:17:08 +0100 Subject: [PATCH 05/10] Allow tk.Path values --- common/setups/returnn_pytorch/serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index dcba63295..a324c8486 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -394,6 +394,8 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: dict_items += [key, val_serialized] dict_imports += item_imports return DelayedFormat(f"{{{', '.join(['{}: {}'] * len(dict_items))}}}", *dict_items), dict_imports + elif isinstance(value, tk.Path): + return DelayedFormat('tk.Path("{}")', value), [Import("sisyphus.tk")] elif isinstance(value, DelayedBase): # sisyphus variables are just given as-is and will be instanciated only when calling "get". return value, [] From 9f3d350ca661a6c8461c688b35235d26928751e8 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 25 Nov 2024 12:17:25 +0100 Subject: [PATCH 06/10] Generalize typing of kwargs --- common/setups/serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/setups/serialization.py b/common/setups/serialization.py index 80afeb2db..313d8dcae 100644 --- a/common/setups/serialization.py +++ b/common/setups/serialization.py @@ -367,8 +367,8 @@ class Call(SerializerObject): def __init__( self, callable_name: str, - kwargs: Optional[List[Tuple[str, Union[str, DelayedBase]]]] = None, - unhashed_kwargs: Optional[List[Tuple[str, Union[str, DelayedBase]]]] = None, + kwargs: Optional[List[Tuple[str, Any]]] = None, + unhashed_kwargs: Optional[List[Tuple[str, Any]]] = None, return_assign_variables: Optional[Union[str, List[str]]] = None, ) -> None: """ From 0d30366ad9d4e0136d40ce7fe11798eebd22c544 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 28 Nov 2024 18:29:27 +0100 Subject: [PATCH 07/10] Enable general dataclass serialization --- .../setups/returnn_pytorch/serialization.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index a324c8486..4c4e85cde 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -6,7 +6,7 @@ import string import textwrap from collections import OrderedDict -from dataclasses import fields +from dataclasses import fields, is_dataclass from enum import Enum from inspect import isfunction from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING @@ -255,7 +255,7 @@ def build_config_constructor_serializers( def build_config_constructor_serializers_v2( - cfg: ModelConfiguration, + cfg: Any, variable_name: Optional[str] = None, unhashed_package_root: Optional[str] = None, ) -> Tuple[Call, List[Import]]: @@ -267,25 +267,20 @@ def build_config_constructor_serializers_v2( Compared to the previous version, this function can also serialize enum members and values of type list, tuple or dict. It also fixes import deduplication. - :param cfg: ModelConfiguration object that will be re-constructed by the Call serializer + :param cfg: ModelConfiguration or dataclass object that will be re-constructed by the Call serializer :param variable_name: Name of the variable which the constructed ModelConfiguration will be assigned to. If None, the result will not be assigned to a variable. :param unhashed_package_root: Will be passed to all generated Import objects. :return: Call object and list of necessary imports. """ - from i6_models.config import ModelConfiguration, ModuleFactoryV1 + from i6_models.config import ModuleFactoryV1 # Helper function which can call itself recursively for nested types def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: # Switch over serialization logic for different subtypes - if isinstance(value, ModelConfiguration): - # Example: - # ConformerBlockConfig(mhsa_config=ConformerMHSAConfig(...)) - # -> Sub-Constructor-Call and imports for ConformerMHSAConfig - return build_config_constructor_serializers_v2(value, unhashed_package_root=unhashed_package_root) - elif isinstance(value, ModuleFactoryV1): + if isinstance(value, ModuleFactoryV1): # Example: # ConformerEncoderConfig( # frontend=ModuleFactoryV1(module_class=VGGFrontend, cfg=VGGFrontendConfig(...))) @@ -322,6 +317,11 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: ), subimports, ) + elif is_dataclass(value): + # Example: + # ConformerBlockConfig(mhsa_config=ConformerMHSAConfig(...)) + # -> Sub-Constructor-Call and imports for ConformerMHSAConfig + return build_config_constructor_serializers_v2(value, unhashed_package_root=unhashed_package_root) elif isinstance(value, torch.nn.Module): # Example: # ConformerConvolutionConfig(norm=BatchNorm1d(...)) From a937d3045d338d3056c9e2ff9f207a954b23ae29 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 28 Nov 2024 18:29:35 +0100 Subject: [PATCH 08/10] Fix dict serialization --- common/setups/returnn_pytorch/serialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 4c4e85cde..38b2e6049 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -393,7 +393,9 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: val_serialized, item_imports = serialize_value(val) dict_items += [key, val_serialized] dict_imports += item_imports - return DelayedFormat(f"{{{', '.join(['{}: {}'] * len(dict_items))}}}", *dict_items), dict_imports + return DelayedFormat( + "{{" + ", ".join(["{}: {}"] * (len(dict_items) // 2)) + "}}", *dict_items + ), dict_imports elif isinstance(value, tk.Path): return DelayedFormat('tk.Path("{}")', value), [Import("sisyphus.tk")] elif isinstance(value, DelayedBase): From 020f3f5a9feabeb6d39be3691798d0860157ac5e Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 28 Nov 2024 18:32:34 +0100 Subject: [PATCH 09/10] Formatting --- .../setups/returnn_pytorch/serialization.py | 61 +++++++++++-------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 38b2e6049..482f2a737 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -292,19 +292,23 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports.append( Import( code_object_path=f"{value.module_class.__module__}.{value.module_class.__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None - and value.module_class.__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None + and value.module_class.__module__.startswith(unhashed_package_root) + else None + ), ) ) subimports.append( Import( code_object_path=f"{ModuleFactoryV1.__module__}.{ModuleFactoryV1.__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None - and ModuleFactoryV1.__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None + and ModuleFactoryV1.__module__.startswith(unhashed_package_root) + else None + ), ) ) return ( @@ -331,9 +335,11 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: return str(value), [ Import( code_object_path=f"{value.__module__}.{type(value).__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) + else None + ), ) ] elif isfunction(value): @@ -345,9 +351,11 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports = [ Import( code_object_path=f"{value.__module__}.{value.__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root) + else None + ), ) ] else: @@ -360,10 +368,12 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: subimports = [ Import( code_object_path=f"{value.__class__.__module__}.{value.__class__.__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None - and value.__class__.__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None + and value.__class__.__module__.startswith(unhashed_package_root) + else None + ), ) ] return f"{value.__class__.__name__}.{value.name}", subimports @@ -393,9 +403,10 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: val_serialized, item_imports = serialize_value(val) dict_items += [key, val_serialized] dict_imports += item_imports - return DelayedFormat( - "{{" + ", ".join(["{}: {}"] * (len(dict_items) // 2)) + "}}", *dict_items - ), dict_imports + return ( + DelayedFormat("{{" + ", ".join(["{}: {}"] * (len(dict_items) // 2)) + "}}", *dict_items), + dict_imports, + ) elif isinstance(value, tk.Path): return DelayedFormat('tk.Path("{}")', value), [Import("sisyphus.tk")] elif isinstance(value, DelayedBase): @@ -412,9 +423,11 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: imports = [ Import( code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}", - unhashed_package_root=unhashed_package_root - if unhashed_package_root is not None and type(cfg).__module__.startswith(unhashed_package_root) - else None, + unhashed_package_root=( + unhashed_package_root + if unhashed_package_root is not None and type(cfg).__module__.startswith(unhashed_package_root) + else None + ), ) ] From 7b4f50c82835dc174ff48084ff002156bd5ff891 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 20 Jan 2025 12:25:37 +0100 Subject: [PATCH 10/10] Fix dict key serialization --- common/setups/returnn_pytorch/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/setups/returnn_pytorch/serialization.py b/common/setups/returnn_pytorch/serialization.py index 482f2a737..2f5b85459 100644 --- a/common/setups/returnn_pytorch/serialization.py +++ b/common/setups/returnn_pytorch/serialization.py @@ -401,7 +401,7 @@ def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]: dict_imports = [] for key, val in value.items(): val_serialized, item_imports = serialize_value(val) - dict_items += [key, val_serialized] + dict_items += [repr(key), val_serialized] dict_imports += item_imports return ( DelayedFormat("{{" + ", ".join(["{}: {}"] * (len(dict_items) // 2)) + "}}", *dict_items),