[FRONTEND] Allow tl.where to select pointers (#595)

Da Yan
2022-07-22 00:54:27 +08:00
committed by GitHub
parent af85f5fa46
commit f28caddbf8
3 changed files with 58 additions and 18 deletions
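
Note for readers: the change is easiest to see from the user side. The sketch below is not part of this commit; the kernel and buffer names are illustrative, and PyTorch plus a CUDA device are assumed. It shows the kind of kernel the new behaviour enables: tl.where selecting, per element, between two blocks of pointers, with a single load reading through whichever address was chosen.

import torch
import triton
import triton.language as tl


@triton.jit
def gather_from_two_buffers(cond_ptr, a_ptr, b_ptr, out_ptr, n_elements,
                            BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    cond = tl.load(cond_ptr + offsets, mask=mask)
    # With this change, both branches of tl.where may be blocks of pointers:
    # the select happens on the addresses, and one load then reads through
    # the pointer chosen for each element.
    src = tl.where(cond, a_ptr + offsets, b_ptr + offsets)
    val = tl.load(src, mask=mask)
    tl.store(out_ptr + offsets, val, mask=mask)


n = 4096
cond = (torch.rand(n, device='cuda') > 0.5).to(torch.int32)
a = torch.randn(n, device='cuda')
b = torch.randn(n, device='cuda')
out = torch.empty_like(a)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
gather_from_two_buffers[grid](cond, a, b, out, n, BLOCK_SIZE=256)
assert torch.equal(out, torch.where(cond.bool(), a, b))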

View File

@@ -878,6 +878,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
       .def("create_downcast", &ir::builder::create_downcast, ret::reference)
       .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
+      .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
       // phi
       .def("create_phi", &ir::builder::create_phi, ret::reference)
       // Binary instructions

View File

@@ -17,6 +17,7 @@ int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = dtypes + ['bfloat16']


 def _bitwidth(dtype: str) -> int:
@@ -46,6 +47,8 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
     elif dtype_str == 'bfloat16':
         return (rs.normal(0, 1, shape).astype('float32').view('uint32')
                 & np.uint32(0xffff0000)).view('float32')
+    elif dtype_str in ['bool', 'int1', 'bool_']:
+        return rs.normal(0, 1, shape) > 0.0
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')
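
The new 'bool' / 'int1' branch in numpy_random is what lets the test added below draw a random boolean condition. It is roughly equivalent to this standalone helper (the helper name is hypothetical; only the expression comes from the diff):

import numpy as np
from numpy.random import RandomState

def random_bools(shape, rs: RandomState) -> np.ndarray:
    # Threshold standard-normal draws at zero: dtype bool, roughly 50/50 True/False
    return rs.normal(0, 1, shape) > 0.0

# random_bools(1000, RandomState(17)) mirrors numpy_random(1000, 'bool', rs)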
@@ -245,8 +248,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes_with_bfloat16
+    for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
@@ -296,8 +299,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes + ['bfloat16']
-    for dtype_y in dtypes + ['bfloat16']
+    for dtype_x in dtypes + dtypes_with_bfloat16
+    for dtype_y in dtypes + dtypes_with_bfloat16
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -363,11 +366,55 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


+# ---------------
+# test where
+# ---------------
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
+def test_where(dtype):
+    select_ptrs = False
+    if dtype == "*int32":
+        dtype = "int64"
+        select_ptrs = True
+    check_type_supported(dtype)
+
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
+                     BLOCK_SIZE: tl.constexpr,
+                     TEST_POINTERS: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        decide = tl.load(cond_ptr + offsets, mask=mask)
+        if TEST_POINTERS:
+            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
+            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
+        else:
+            a = tl.load(a_ptr + offsets, mask=mask)
+            b = tl.load(b_ptr + offsets, mask=mask)
+        output = tl.where(decide, a, b)
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    SIZE = 1_000
+    rs = RandomState(17)
+    cond = numpy_random(SIZE, 'bool', rs)
+    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
+    z = np.where(cond, x, y)
+
+    cond_tri = to_triton(cond, device='cuda')
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    y_tri = to_triton(y, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
+
+    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
+    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
+    assert (z == to_numpy(z_tri)).all()
+
+
 # ---------------
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
+    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
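
In the "*int32" case above, the kernel loads int64 values, reinterprets them as int32* pointers with .to(tl.pi32_t), lets tl.where pick between the two pointer blocks, and then stores the result into an int64 buffer, so the store exercises the ptr-to-int cast updated in semantic.py below. The NumPy reference is therefore an ordinary element-wise select on the raw int64 values, as in this sketch (helper name is illustrative, not part of the test):

import numpy as np

def expected_pointer_select(cond: np.ndarray, a_i64: np.ndarray, b_i64: np.ndarray) -> np.ndarray:
    # Selecting a pointer per element is just selecting its 64-bit address,
    # so plain np.where on the integers is the ground truth.
    return np.where(cond, a_i64, b_i64)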
@@ -762,7 +809,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes + ['bfloat16']
+                          for dtype in dtypes_with_bfloat16
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
@@ -810,7 +857,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]

View File

@@ -637,7 +637,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
         bitwidth = dst_sca_ty.int_bitwidth
         if bitwidth == 64:
-            return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)),
+            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                              dst_ty)
         if bitwidth == 1:
             return not_equal(cast(input, tl.int64, builder),
@@ -969,16 +969,8 @@ def where(condition: tl.tensor,
         x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
         y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)

-    # TODO: we need to check x's and y's shape?
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
+    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
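
For context on the rewritten type check: assuming the signature of binary_op_type_checking_impl defined earlier in this file (with allow_lhs_ptr and allow_rhs_ptr as the fourth and fifth parameters), the two positional True arguments permit pointer operands, so a block of pointers keeps its type instead of being promoted through computation_type_impl. A call-site fragment with the keywords spelled out (not runnable on its own; it relies on that assumed signature):

x, y = binary_op_type_checking_impl(x, y, builder,
                                    allow_lhs_ptr=True, allow_rhs_ptr=True)
ret_ty = x.type  # pointer operands keep their pointer type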