[FRONTEND] Fix 3d indexing (#1006)

This commit is contained in:
Keren Zhou
2022-12-21 12:52:32 -08:00
committed by GitHub
parent 20100a7254
commit b5aafb0dab
2 changed files with 4 additions and 7 deletions

View File

@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None']
# FIXME: 3d indexing doesn't work
#'None, :, :',
# ':, :, None']
for s in ['None, :', ':, None',
'None, :, :',
':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):

View File

@@ -596,11 +596,9 @@ class tensor:
if isinstance(slices, slice):
slices = [slices]
ret = self
n_inserted = 0
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
n_inserted += 1
ret = semantic.expand_dims(ret, dim, _builder)
elif sl == slice(None, None, None):
pass
else: