Skip to content

Commit

Permalink
temp fix for symbolic shape view add [pr] (#8337)
Browse files Browse the repository at this point in the history
something is still wrong with symbolic shape shrink, but it should not recurse forever
  • Loading branch information
chenyuxyz authored Dec 19, 2024
1 parent 791a80a commit 2bf47b7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion test/test_symbolic_shapetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ def test_real_strides_2(self):
def test_merge_view_recursion_err(self):
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
vm1 = View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True)
vm2.__add__(vm1)
self.assertEqual(vm2+vm1, vm1)

def test_merge_view_recursion_err2(self):
vm2 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(0,), offset=0, mask=None, contiguous=False)
vm1 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(1,), offset=0, mask=((0, Variable('a', 1, 10).bind(4)),), contiguous=False)
# TODO: this should not be None?
self.assertEqual(vm2+vm1, None)

def test_cat_dim0_strides(self):
i = Variable("i", 1, 5).bind(3)
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def __add__(self, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
if vm1.mask:
if (merged := vm2 + vm1.shrink(vm1.mask)) is None: return None
# TODO: why is shrink no changing the view for symbolic shape
if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
if not all_int(vm1.shape): return None

Expand Down

0 comments on commit 2bf47b7

Please sign in to comment.