@@ -962,7 +962,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
962962 """
963963 err_msg = ""
964964 if isinstance (instance_type , str ):
965- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
965+ match = re .match (r"^ml[\._]([a-z\d\- ]+)\.?\w*$" , instance_type )
966966 if match and match [1 ].startswith ("trn" ):
967967 keys = list (distribution .keys ())
968968 if len (keys ) == 0 :
@@ -1083,7 +1083,7 @@ def _is_gpu_instance(instance_type):
10831083 bool: Whether or not the instance_type supports GPU
10841084 """
10851085 if isinstance (instance_type , str ):
1086- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1086+ match = re .match (r"^ml[\._]([a-z\d\- ]+)\.?\w*$" , instance_type )
10871087 if match :
10881088 if match [1 ].startswith ("p" ) or match [1 ].startswith ("g" ):
10891089 return True
@@ -1102,7 +1102,7 @@ def _is_trainium_instance(instance_type):
11021102 bool: Whether or not the instance_type is a Trainium instance
11031103 """
11041104 if isinstance (instance_type , str ):
1105- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1105+ match = re .match (r"^ml[\._]([a-z\d\- ]+)\.?\w*$" , instance_type )
11061106 if match and match [1 ].startswith ("trn" ):
11071107 return True
11081108 return False
@@ -1149,7 +1149,7 @@ def _instance_type_supports_profiler(instance_type):
11491149 bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
11501150 """
11511151 if isinstance (instance_type , str ):
1152- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1152+ match = re .match (r"^ml[\._]([a-z\d\- ]+)\.?\w*$" , instance_type )
11531153 if match and match [1 ].startswith ("trn" ):
11541154 return True
11551155 return False
0 commit comments