@@ -43,37 +43,44 @@ def batch_space(space, n=1):
43
43
44
44
45
45
@batch_space .register (Box )
46
- @batch_space .register (Discrete )
47
- @batch_space .register (MultiDiscrete )
48
- @batch_space .register (MultiBinary )
49
- def batch_space_base (space , n = 1 ):
50
- if isinstance (space , Box ):
51
- repeats = tuple ([n ] + [1 ] * space .low .ndim )
52
- low , high = np .tile (space .low , repeats ), np .tile (space .high , repeats )
53
- return Box (low = low , high = high , dtype = space .dtype )
46
+ def _batch_space_box (space , n = 1 ):
47
+ repeats = tuple ([n ] + [1 ] * space .low .ndim )
48
+ low , high = np .tile (space .low , repeats ), np .tile (space .high , repeats )
49
+ return Box (low = low , high = high , dtype = space .dtype )
54
50
55
- elif isinstance (space , Discrete ):
51
+
52
+ @batch_space .register (Discrete )
53
+ def _batch_space_discrete (space , n = 1 ):
54
+ if space .start == 0 :
56
55
return MultiDiscrete (np .full ((n ,), space .n , dtype = space .dtype ))
56
+ else :
57
+ return Box (
58
+ low = space .start ,
59
+ high = space .start + space .n - 1 ,
60
+ shape = (n ,),
61
+ dtype = space .dtype ,
62
+ )
57
63
58
- elif isinstance (space , MultiDiscrete ):
59
- repeats = tuple ([n ] + [1 ] * space .nvec .ndim )
60
- high = np .tile (space .nvec , repeats ) - 1
61
- return Box (low = np .zeros_like (high ), high = high , dtype = space .dtype )
62
64
63
- elif isinstance (space , MultiBinary ):
64
- return Box (low = 0 , high = 1 , shape = (n ,) + space .shape , dtype = space .dtype )
65
+ @batch_space .register (MultiDiscrete )
66
+ def _batch_space_multidiscrete (space , n = 1 ):
67
+ repeats = tuple ([n ] + [1 ] * space .nvec .ndim )
68
+ high = np .tile (space .nvec , repeats ) - 1
69
+ return Box (low = np .zeros_like (high ), high = high , dtype = space .dtype )
65
70
66
- else :
67
- raise ValueError (f"Space type `{ type (space )} ` is not supported." )
71
+
72
+ @batch_space .register (MultiBinary )
73
+ def _batch_space_multibinary (space , n = 1 ):
74
+ return Box (low = 0 , high = 1 , shape = (n ,) + space .shape , dtype = space .dtype )
68
75
69
76
70
77
@batch_space .register (Tuple )
71
- def batch_space_tuple (space , n = 1 ):
78
+ def _batch_space_tuple (space , n = 1 ):
72
79
return Tuple (tuple (batch_space (subspace , n = n ) for subspace in space .spaces ))
73
80
74
81
75
82
@batch_space .register (Dict )
76
- def batch_space_dict (space , n = 1 ):
83
+ def _batch_space_dict (space , n = 1 ):
77
84
return Dict (
78
85
OrderedDict (
79
86
[
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):
85
92
86
93
87
94
@batch_space .register (Space )
88
- def batch_space_custom (space , n = 1 ):
95
+ def _batch_space_custom (space , n = 1 ):
89
96
return Tuple (tuple (space for _ in range (n )))
90
97
91
98
@@ -130,22 +137,22 @@ def iterate(space, items):
130
137
131
138
132
139
@iterate .register (Discrete )
133
- def iterate_discrete (space , items ):
140
+ def _iterate_discrete (space , items ):
134
141
raise TypeError ("Unable to iterate over a space of type `Discrete`." )
135
142
136
143
137
144
@iterate .register (Box )
138
145
@iterate .register (MultiDiscrete )
139
146
@iterate .register (MultiBinary )
140
- def iterate_base (space , items ):
147
+ def _iterate_base (space , items ):
141
148
try :
142
149
return iter (items )
143
150
except TypeError :
144
151
raise TypeError (f"Unable to iterate over the following elements: { items } " )
145
152
146
153
147
154
@iterate .register (Tuple )
148
- def iterate_tuple (space , items ):
155
+ def _iterate_tuple (space , items ):
149
156
# If this is a tuple of custom subspaces only, then simply iterate over items
150
157
if all (
151
158
isinstance (subspace , Space )
@@ -160,7 +167,7 @@ def iterate_tuple(space, items):
160
167
161
168
162
169
@iterate .register (Dict )
163
- def iterate_dict (space , items ):
170
+ def _iterate_dict (space , items ):
164
171
keys , values = zip (
165
172
* [
166
173
(key , iterate (subspace , items [key ]))
@@ -172,7 +179,7 @@ def iterate_dict(space, items):
172
179
173
180
174
181
@iterate .register (Space )
175
- def iterate_custom (space , items ):
182
+ def _iterate_custom (space , items ):
176
183
raise CustomSpaceError (
177
184
f"Unable to iterate over { items } , since { space } "
178
185
"is a custom `gym.Space` instance (i.e. not one of "
0 commit comments