Skip to content

Commit e671aa1

Browse files
authored
Fix space utils for Discrete with non-zero start (#2645)
* Fix flatten utils to handle Discrete.start * Fix vector space utils to handle Discrete.start * More granular dispatch in vector utils * Fix Box including the high end of the interval
1 parent 108f32c commit e671aa1

File tree

6 files changed

+48
-28
lines changed

6 files changed

+48
-28
lines changed

gym/spaces/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _flatten_box_multibinary(space, x) -> np.ndarray:
7878
@flatten.register(Discrete)
7979
def _flatten_discrete(space, x) -> np.ndarray:
8080
onehot = np.zeros(space.n, dtype=space.dtype)
81-
onehot[x] = 1
81+
onehot[x - space.start] = 1
8282
return onehot
8383

8484

@@ -124,7 +124,7 @@ def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.nd
124124

125125
@unflatten.register(Discrete)
126126
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
127-
return int(np.nonzero(x)[0][0])
127+
return int(space.start + np.nonzero(x)[0][0])
128128

129129

130130
@unflatten.register(MultiDiscrete)

gym/vector/utils/spaces.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,44 @@ def batch_space(space, n=1):
4343

4444

4545
@batch_space.register(Box)
46-
@batch_space.register(Discrete)
47-
@batch_space.register(MultiDiscrete)
48-
@batch_space.register(MultiBinary)
49-
def batch_space_base(space, n=1):
50-
if isinstance(space, Box):
51-
repeats = tuple([n] + [1] * space.low.ndim)
52-
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
53-
return Box(low=low, high=high, dtype=space.dtype)
46+
def _batch_space_box(space, n=1):
47+
repeats = tuple([n] + [1] * space.low.ndim)
48+
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
49+
return Box(low=low, high=high, dtype=space.dtype)
5450

55-
elif isinstance(space, Discrete):
51+
52+
@batch_space.register(Discrete)
53+
def _batch_space_discrete(space, n=1):
54+
if space.start == 0:
5655
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
56+
else:
57+
return Box(
58+
low=space.start,
59+
high=space.start + space.n - 1,
60+
shape=(n,),
61+
dtype=space.dtype,
62+
)
5763

58-
elif isinstance(space, MultiDiscrete):
59-
repeats = tuple([n] + [1] * space.nvec.ndim)
60-
high = np.tile(space.nvec, repeats) - 1
61-
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
6264

63-
elif isinstance(space, MultiBinary):
64-
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
65+
@batch_space.register(MultiDiscrete)
66+
def _batch_space_multidiscrete(space, n=1):
67+
repeats = tuple([n] + [1] * space.nvec.ndim)
68+
high = np.tile(space.nvec, repeats) - 1
69+
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
6570

66-
else:
67-
raise ValueError(f"Space type `{type(space)}` is not supported.")
71+
72+
@batch_space.register(MultiBinary)
73+
def _batch_space_multibinary(space, n=1):
74+
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
6875

6976

7077
@batch_space.register(Tuple)
71-
def batch_space_tuple(space, n=1):
78+
def _batch_space_tuple(space, n=1):
7279
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
7380

7481

7582
@batch_space.register(Dict)
76-
def batch_space_dict(space, n=1):
83+
def _batch_space_dict(space, n=1):
7784
return Dict(
7885
OrderedDict(
7986
[
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):
8592

8693

8794
@batch_space.register(Space)
88-
def batch_space_custom(space, n=1):
95+
def _batch_space_custom(space, n=1):
8996
return Tuple(tuple(space for _ in range(n)))
9097

9198

@@ -130,22 +137,22 @@ def iterate(space, items):
130137

131138

132139
@iterate.register(Discrete)
133-
def iterate_discrete(space, items):
140+
def _iterate_discrete(space, items):
134141
raise TypeError("Unable to iterate over a space of type `Discrete`.")
135142

136143

137144
@iterate.register(Box)
138145
@iterate.register(MultiDiscrete)
139146
@iterate.register(MultiBinary)
140-
def iterate_base(space, items):
147+
def _iterate_base(space, items):
141148
try:
142149
return iter(items)
143150
except TypeError:
144151
raise TypeError(f"Unable to iterate over the following elements: {items}")
145152

146153

147154
@iterate.register(Tuple)
148-
def iterate_tuple(space, items):
155+
def _iterate_tuple(space, items):
149156
# If this is a tuple of custom subspaces only, then simply iterate over items
150157
if all(
151158
isinstance(subspace, Space)
@@ -160,7 +167,7 @@ def iterate_tuple(space, items):
160167

161168

162169
@iterate.register(Dict)
163-
def iterate_dict(space, items):
170+
def _iterate_dict(space, items):
164171
keys, values = zip(
165172
*[
166173
(key, iterate(subspace, items[key]))
@@ -172,7 +179,7 @@ def iterate_dict(space, items):
172179

173180

174181
@iterate.register(Space)
175-
def iterate_custom(space, items):
182+
def _iterate_custom(space, items):
176183
raise CustomSpaceError(
177184
f"Unable to iterate over {items}, since {space} "
178185
"is a custom `gym.Space` instance (i.e. not one of "

tests/spaces/test_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
),
2929
}
3030
),
31+
Discrete(3, start=2),
32+
Discrete(8, start=-5),
3133
]
3234

33-
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
35+
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
3436

3537

3638
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
@@ -123,6 +125,8 @@ def compare_nested(left, right):
123125
np.int64,
124126
np.int8,
125127
np.float64,
128+
np.int64,
129+
np.int64,
126130
]
127131

128132

@@ -187,6 +191,8 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
187191
OrderedDict(
188192
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
189193
),
194+
3,
195+
-2,
190196
]
191197

192198

@@ -200,6 +206,8 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
200206
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
201207
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
202208
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
209+
np.array([0, 1, 0], dtype=np.int64),
210+
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
203211
]
204212

205213

@@ -243,6 +251,8 @@ def test_unflatten(space, flattened_sample, expected_sample):
243251
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
244252
dtype=np.float64,
245253
),
254+
Box(low=0, high=1, shape=(3,), dtype=np.int64),
255+
Box(low=0, high=1, shape=(8,), dtype=np.int64),
246256
]
247257

248258

tests/vector/test_shared_memory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Array("B", 1),
2727
Array("B", 32 * 32 * 3),
2828
Array("i", 1),
29+
Array("i", 1),
2930
(Array("i", 1), Array("i", 1)),
3031
(Array("i", 1), Array("f", 2)),
3132
Array("B", 3),

tests/vector/test_spaces.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
3434
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
3535
MultiDiscrete([2, 2, 2, 2]),
36+
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
3637
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
3738
Tuple(
3839
(

tests/vector/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Box(low=0, high=255, shape=(), dtype=np.uint8),
1919
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
2020
Discrete(2),
21+
Discrete(5, start=-2),
2122
Tuple((Discrete(3), Discrete(5))),
2223
Tuple(
2324
(

0 commit comments

Comments
 (0)