[FRONTEND] Allow tl.where to select pointers (#595)
@@ -878,6 +878,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
       .def("create_downcast", &ir::builder::create_downcast, ret::reference)
       .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
+      .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
       // phi
       .def("create_phi", &ir::builder::create_phi, ret::reference)
       // Binary instructions
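Note (not part of the diff): the new create_ptr_to_int binding is consumed by the cast() change in semantic.py further down, which lowers a pointer-to-int64 cast through it. A minimal sketch of a kernel that reaches this path, assuming the usual Triton imports; the kernel name and the int64 output buffer are illustrative only:

    import triton
    import triton.language as tl

    @triton.jit
    def ptr_to_int_kernel(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        offsets = tl.arange(0, BLOCK_SIZE)
        # Pointer block -> int64 block; semantic.cast() lowers this via create_ptr_to_int.
        addrs = (in_ptr + offsets).to(tl.int64)
        # out_ptr is assumed to point at an int64 buffer of at least BLOCK_SIZE elements.
        tl.store(out_ptr + offsets, addrs)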
@@ -17,6 +17,7 @@ int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = dtypes + ['bfloat16']


 def _bitwidth(dtype: str) -> int:
@@ -46,6 +47,8 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
     elif dtype_str == 'bfloat16':
         return (rs.normal(0, 1, shape).astype('float32').view('uint32')
                 & np.uint32(0xffff0000)).view('float32')
+    elif dtype_str in ['bool', 'int1', 'bool_']:
+        return rs.normal(0, 1, shape) > 0.0
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')

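Note (not part of the diff): with this new branch, the numpy_random helper defined in this test file can generate the boolean condition array that the new test_where below feeds to np.where and tl.where, e.g.:

    from numpy.random import RandomState

    rs = RandomState(17)
    cond = numpy_random(1000, 'bool', rs)   # ndarray of dtype bool, roughly half True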
@@ -245,8 +248,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes_with_bfloat16
+    for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
@@ -296,8 +299,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes + dtypes_with_bfloat16
+    for dtype_y in dtypes + dtypes_with_bfloat16
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -363,11 +366,55 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


+# ---------------
+# test where
+# ---------------
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
+def test_where(dtype):
+    select_ptrs = False
+    if dtype == "*int32":
+        dtype = "int64"
+        select_ptrs = True
+    check_type_supported(dtype)
+
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
+                     BLOCK_SIZE: tl.constexpr,
+                     TEST_POINTERS: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        decide = tl.load(cond_ptr + offsets, mask=mask)
+        if TEST_POINTERS:
+            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
+            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
+        else:
+            a = tl.load(a_ptr + offsets, mask=mask)
+            b = tl.load(b_ptr + offsets, mask=mask)
+        output = tl.where(decide, a, b)
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    SIZE = 1_000
+    rs = RandomState(17)
+    cond = numpy_random(SIZE, 'bool', rs)
+    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    z = np.where(cond, x, y)
+
+    cond_tri = to_triton(cond, device='cuda')
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    y_tri = to_triton(y, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
+
+    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
+    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
+    assert (z == to_numpy(z_tri)).all()
+
+
 # ---------------
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
+    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
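Note (not part of the diff): the "*int32" case is the one that exercises pointer selection. dtype is switched to "int64" so the host buffers hold 64-bit integers; inside the kernel those integers are reinterpreted as int32 pointers with .to(tl.pi32_t), tl.where then selects between the two pointer blocks, and the selected pointers are stored back into the int64 output buffer (a pointer-to-int64 cast, which is what the new create_ptr_to_int path handles). The host-side np.where reference therefore checks the selected addresses bit-for-bit.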
@@ -762,7 +809,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes + ['bfloat16']
+                          for dtype in dtypes_with_bfloat16
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
@@ -810,7 +857,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):


 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
@@ -637,7 +637,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
-            return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
+            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                              dst_ty)
         if bitwidth == 1:
             return not_equal(cast(input, tl.int64, builder),
@@ -969,16 +969,8 @@ def where(condition: tl.tensor,
     x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
     y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)

-    # TODO: we need to check x's and y's shape?
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
+    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
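Note (not part of the diff): binary_op_type_checking_impl promotes x and y to a common type and, when both operands already have the same pointer type, leaves them untouched, so tl.where can now return a block of pointers (as the new test exercises). A hypothetical kernel built on top of this, beyond what the test covers (which only stores the selected pointers as integers), assuming loads through a tl.where-selected pointer block are supported:

    import triton
    import triton.language as tl

    @triton.jit
    def gather_from_two_sources(cond_ptr, a_ptr, b_ptr, out_ptr, n_elements,
                                BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        take_a = tl.load(cond_ptr + offsets, mask=mask)
        # Select between two pointer blocks element-wise, then load through the result.
        ptrs = tl.where(take_a, a_ptr + offsets, b_ptr + offsets)
        tl.store(out_ptr + offsets, tl.load(ptrs, mask=mask), mask=mask)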