[FRONTEND] Fix 3d indexing (#1006)
This commit is contained in:
@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
|
||||
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
||||
@pytest.mark.parametrize("expr, dtype_str", [
|
||||
(f'x[{s}]', d)
|
||||
for s in ['None, :', ':, None']
|
||||
# FIXME: 3d indexing doesn't work
|
||||
#'None, :, :',
|
||||
# ':, :, None']
|
||||
for s in ['None, :', ':, None',
|
||||
'None, :, :',
|
||||
':, :, None']
|
||||
for d in ['int32', 'uint32', 'uint16']
|
||||
])
|
||||
def test_index1d(expr, dtype_str, device='cuda'):
|
||||
|
@@ -596,11 +596,9 @@ class tensor:
|
||||
if isinstance(slices, slice):
|
||||
slices = [slices]
|
||||
ret = self
|
||||
n_inserted = 0
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
|
||||
n_inserted += 1
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif sl == slice(None, None, None):
|
||||
pass
|
||||
else:
|
||||
|
Reference in New Issue
Block a user