[FRONTEND] Allow tl.where to select pointers (#595)
@@ -878,6 +878,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
       .def("create_downcast", &ir::builder::create_downcast, ret::reference)
       .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
+      .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
       // phi
       .def("create_phi", &ir::builder::create_phi, ret::reference)
       // Binary instructions
@@ -17,6 +17,7 @@ int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = dtypes + ['bfloat16']
 
 
 def _bitwidth(dtype: str) -> int:
@@ -46,6 +47,8 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
     elif dtype_str == 'bfloat16':
         return (rs.normal(0, 1, shape).astype('float32').view('uint32')
                 & np.uint32(0xffff0000)).view('float32')
+    elif dtype_str in ['bool', 'int1', 'bool_']:
+        return rs.normal(0, 1, shape) > 0.0
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')
 
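For reference, the bfloat16 branch above fakes bfloat16 data by zeroing the low 16 bits of a float32, which keeps exactly the bits a bfloat16 value can represent. A standalone numpy sketch of the same trick (illustrative only, not part of the diff):

```python
import numpy as np

# View float32 bits as uint32, drop the low 16 mantissa bits, and view back.
# The surviving values are exactly representable in bfloat16.
x = np.array([1.0, 3.14159, 1e-3], dtype=np.float32)
bf16_like = (x.view('uint32') & np.uint32(0xffff0000)).view('float32')
print(bf16_like)  # e.g. 3.14159 becomes 3.140625
```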
@@ -245,8 +248,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes_with_bfloat16
+    for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
@@ -296,8 +299,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes + dtypes_with_bfloat16
+    for dtype_y in dtypes + dtypes_with_bfloat16
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -363,11 +366,55 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
 
 
+# ---------------
+# test where
+# ---------------
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
+def test_where(dtype):
+    select_ptrs = False
+    if dtype == "*int32":
+        dtype = "int64"
+        select_ptrs = True
+    check_type_supported(dtype)
+
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
+                     BLOCK_SIZE: tl.constexpr,
+                     TEST_POINTERS: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        decide = tl.load(cond_ptr + offsets, mask=mask)
+        if TEST_POINTERS:
+            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
+            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
+        else:
+            a = tl.load(a_ptr + offsets, mask=mask)
+            b = tl.load(b_ptr + offsets, mask=mask)
+        output = tl.where(decide, a, b)
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    SIZE = 1_000
+    rs = RandomState(17)
+    cond = numpy_random(SIZE, 'bool', rs)
+    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    z = np.where(cond, x, y)
+
+    cond_tri = to_triton(cond, device='cuda')
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    y_tri = to_triton(y, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
+
+    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
+    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
+    assert (z == to_numpy(z_tri)).all()
+
+
 # ---------------
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
+    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
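In the `*int32` case above, the kernel reinterprets the loaded int64 values as int32 pointers, selects between the two pointer blocks, and writes the result back to an int64 output, so the round trip is value-preserving and the reference result is still a plain `np.where` over the raw integers. A host-only sketch of that reference computation (array contents and names here are illustrative, not taken from the test):

```python
import numpy as np
from numpy.random import RandomState

rs = RandomState(17)
cond = rs.normal(0, 1, 8) > 0.0
x = rs.randint(0, 2**31, 8, dtype=np.int64)  # stand-ins for the "addresses"
y = rs.randint(0, 2**31, 8, dtype=np.int64)
# Selecting pointers element-wise and casting them back to int64 preserves the
# values, so the expected output is just np.where on the integers themselves.
print(np.where(cond, x, y))
```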
@@ -762,7 +809,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes + ['bfloat16']
+                          for dtype in dtypes_with_bfloat16
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
@@ -810,7 +857,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
 
 
 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
@@ -637,7 +637,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
-            return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
+            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                              dst_ty)
         if bitwidth == 1:
             return not_equal(cast(input, tl.int64, builder),
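The branch above is the path taken when a pointer tensor is cast to a 64-bit integer; with the dedicated create_ptr_to_int builder call it no longer routes through the generic create_cast. A minimal kernel sketch of triggering that cast from the language level (the kernel and argument names are made up here, and calling `.to(tl.int64)` on a block of pointers is assumed rather than verified against a specific release):

```python
import triton
import triton.language as tl


@triton.jit
def addresses_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Casting a block of pointers to int64 exercises the ptr -> int path above.
    addrs = (x_ptr + offs).to(tl.int64)
    tl.store(out_ptr + offs, addrs)
```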
@@ -969,16 +969,8 @@ def where(condition: tl.tensor,
         x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
         y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
 
-    # TODO: we need to check x's and y's shape?
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
+    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
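With the simplified `where` above, `binary_op_type_checking_impl` unifies the operand types, so pointer operands are accepted the same way other binary ops accept them, and the result keeps the pointer type (`ret_ty = x.type`). A sketch of the kind of kernel this enables, selecting an address per element and then loading through it (kernel shape and names are illustrative assumptions based on this change, not taken from the repository):

```python
import triton
import triton.language as tl


@triton.jit
def gather_from_two_kernel(flag_ptr, a_ptr, b_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    flag = tl.load(flag_ptr + offs, mask=mask)
    # tl.where can now pick between two blocks of pointers ...
    src = tl.where(flag, a_ptr + offs, b_ptr + offs)
    # ... and the selected pointers can be dereferenced as usual.
    val = tl.load(src, mask=mask)
    tl.store(out_ptr + offs, val, mask=mask)
```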