[FRONTEND] Fix 3d indexing (#1006)
This commit is contained in:
@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
|
|||||||
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
||||||
@pytest.mark.parametrize("expr, dtype_str", [
|
@pytest.mark.parametrize("expr, dtype_str", [
|
||||||
(f'x[{s}]', d)
|
(f'x[{s}]', d)
|
||||||
for s in ['None, :', ':, None']
|
for s in ['None, :', ':, None',
|
||||||
# FIXME: 3d indexing doesn't work
|
'None, :, :',
|
||||||
#'None, :, :',
|
':, :, None']
|
||||||
# ':, :, None']
|
|
||||||
for d in ['int32', 'uint32', 'uint16']
|
for d in ['int32', 'uint32', 'uint16']
|
||||||
])
|
])
|
||||||
def test_index1d(expr, dtype_str, device='cuda'):
|
def test_index1d(expr, dtype_str, device='cuda'):
|
||||||
|
@@ -596,11 +596,9 @@ class tensor:
|
|||||||
if isinstance(slices, slice):
|
if isinstance(slices, slice):
|
||||||
slices = [slices]
|
slices = [slices]
|
||||||
ret = self
|
ret = self
|
||||||
n_inserted = 0
|
|
||||||
for dim, sl in enumerate(slices):
|
for dim, sl in enumerate(slices):
|
||||||
if isinstance(sl, constexpr) and sl.value is None:
|
if isinstance(sl, constexpr) and sl.value is None:
|
||||||
ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
|
ret = semantic.expand_dims(ret, dim, _builder)
|
||||||
n_inserted += 1
|
|
||||||
elif sl == slice(None, None, None):
|
elif sl == slice(None, None, None):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user