diff --git a/python/src/triton.cc b/python/src/triton.cc
index 260f83942..81e1b66fe 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -878,6 +878,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
       .def("create_downcast", &ir::builder::create_downcast, ret::reference)
       .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
+      .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
       // phi
       .def("create_phi", &ir::builder::create_phi, ret::reference)
       // Binary instructions
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 561ed6af5..93063d064 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -17,6 +17,7 @@ int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = dtypes + ['bfloat16']
 
 
 def _bitwidth(dtype: str) -> int:
@@ -46,6 +47,8 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
     elif dtype_str == 'bfloat16':
         return (rs.normal(0, 1, shape).astype('float32').view('uint32')
                 & np.uint32(0xffff0000)).view('float32')
+    elif dtype_str in ['bool', 'int1', 'bool_']:
+        return rs.normal(0, 1, shape) > 0.0
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')
 
@@ -245,8 +248,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes_with_bfloat16
+    for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
@@ -296,8 +299,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes + dtypes_with_bfloat16
+    for dtype_y in dtypes + dtypes_with_bfloat16
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -363,11 +366,55 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
 
 
+# ---------------
+# test where
+# ---------------
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
+def test_where(dtype):
+    select_ptrs = False
+    if dtype == "*int32":
+        dtype = "int64"
+        select_ptrs = True
+    check_type_supported(dtype)
+
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
+                     BLOCK_SIZE: tl.constexpr,
+                     TEST_POINTERS: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        decide = tl.load(cond_ptr + offsets, mask=mask)
+        if TEST_POINTERS:
+            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
+            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
+        else:
+            a = tl.load(a_ptr + offsets, mask=mask)
+            b = tl.load(b_ptr + offsets, mask=mask)
+        output = tl.where(decide, a, b)
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    SIZE = 1_000
+    rs = RandomState(17)
+    cond = numpy_random(SIZE, 'bool', rs)
+    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    z = np.where(cond, x, y)
+
+    cond_tri = to_triton(cond, device='cuda')
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    y_tri = to_triton(y, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
+
+    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
+    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
+    assert (z == to_numpy(z_tri)).all()
+
+
 # ---------------
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
+    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
@@ -762,7 +809,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes + ['bfloat16']
+                          for dtype in dtypes_with_bfloat16
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
@@ -810,7 +857,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
 
 
 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 8878a8195..ecd740114 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -637,7 +637,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
-            return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
+            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                              dst_ty)
         if bitwidth == 1:
             return not_equal(cast(input, tl.int64, builder),
@@ -969,16 +969,8 @@ def where(condition: tl.tensor,
         x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
         y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
 
-    # TODO: we need to check x's and y's shape?
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
+    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
 
 
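
Note (not part of the patch): with `where` in semantic.py now delegating to `binary_op_type_checking_impl(x, y, builder, True, True)`, pointer operands are accepted on both sides of `tl.where`, which is what the `TEST_POINTERS` path of `test_where` exercises. Below is a minimal sketch of the kind of kernel this enables; the kernel and its names (`gather_from_two_buffers`, `a_ptr`, `b_ptr`, `out_ptr`) are illustrative assumptions rather than code from this diff, and loading through the selected addresses goes one step beyond what the test verifies (the test only stores the selected pointers as int64).

```python
import triton
import triton.language as tl


@triton.jit
def gather_from_two_buffers(cond_ptr, a_ptr, b_ptr, out_ptr, n_elements,
                            BLOCK_SIZE: tl.constexpr):
    # One program instance handles BLOCK_SIZE contiguous elements.
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    cond = tl.load(cond_ptr + offsets, mask=mask)
    # With this change, tl.where can select between blocks of pointers,
    # so each lane picks an address rather than a value.
    src = tl.where(cond, a_ptr + offsets, b_ptr + offsets)
    val = tl.load(src, mask=mask)
    tl.store(out_ptr + offsets, val, mask=mask)


# Illustrative launch, mirroring the grid used by test_where:
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# gather_from_two_buffers[grid](cond, a, b, out, n_elements, BLOCK_SIZE=1024)
```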