diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 203995bc8..ac381d50a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -491,10 +491,9 @@ def make_ptr_str(name, shape): # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [ (f'x[{s}]', d) - for s in ['None, :', ':, None'] - # FIXME: 3d indexing doesn't work - #'None, :, :', - # ':, :, None'] + for s in ['None, :', ':, None', + 'None, :, :', + ':, :, None'] for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 2abc82b0c..15dd8462a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -596,11 +596,9 @@ class tensor: if isinstance(slices, slice): slices = [slices] ret = self - n_inserted = 0 for dim, sl in enumerate(slices): if isinstance(sl, constexpr) and sl.value is None: - ret = semantic.expand_dims(ret, dim + n_inserted, _builder) - n_inserted += 1 + ret = semantic.expand_dims(ret, dim, _builder) elif sl == slice(None, None, None): pass else: