From e9e1a4e6822cc7a4769556212590c0a8b72fc152 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Fri, 16 Sep 2022 10:49:15 -0700 Subject: [PATCH] [FRONTEND] Fix the implicit broadcasting rule (#663) This PR solves the cast issue that appears in some tutorial code. --- python/tests/test_cast.py | 56 ++++++++++++++++++++++++++++++ python/triton/compiler.py | 7 ++-- python/triton/language/core.py | 5 +-- python/triton/language/semantic.py | 22 ++++++++---- 4 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 python/tests/test_cast.py diff --git a/python/tests/test_cast.py b/python/tests/test_cast.py new file mode 100644 index 000000000..cc7793aa0 --- /dev/null +++ b/python/tests/test_cast.py @@ -0,0 +1,56 @@ +import triton +import triton.language as tl + + +@triton.jit +def cast_check(): + zero_0d = tl.zeros([], dtype=tl.float32) + zero_1d = tl.zeros([2], dtype=tl.float32) + zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32) + zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32) + + # scalar + scalar -> scalar + a0 = 0.0 + 0.0 + # scalar + 0D -> 0D + a1 = 0.0 + zero_0d + a2 = zero_0d + 0.0 + # scalar + 1D -> 1D + a3 = 0.0 + zero_1d + a4 = zero_1d + 0.0 + # scalar + 2D -> 2D + a5 = 0.0 + zero_2d_22 + a6 = zero_2d_22 + 0.0 + + # 0D + 0D -> 0D + b1 = zero_0d + zero_0d + # 0D + 1D -> 1D + b2 = zero_0d + zero_1d + b3 = zero_1d + zero_0d + # 0D + 2D -> 2D + b4 = zero_0d + zero_2d_22 + b5 = zero_2d_22 + zero_0d + + # 1D + 1D -> 1D + c1 = zero_1d + zero_1d + # 1D + 2D -> 2D + c2 = zero_1d + zero_2d_21 + c3 = zero_1d + zero_2d_22 + c4 = zero_2d_21 + zero_1d + c5 = zero_2d_22 + zero_1d + + # 2D + 2D -> 2D + d1 = zero_2d_21 + zero_2d_21 + d2 = zero_2d_22 + zero_2d_22 + d3 = zero_2d_21 + zero_2d_22 + d4 = zero_2d_22 + zero_2d_21 + + return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4 + + +def test_cast_check(): + kernel = triton.compile(cast_check, + signature="", + device=0, + output="ttir") + assert (kernel) + # TODO: Check types of the results diff --git a/python/triton/compiler.py b/python/triton/compiler.py index f9f9a53c5..2d871f523 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -752,8 +752,11 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): # create kernel prototype constants = {fn.arg_names.index(name): value for name, value in constants.items()} attributes = {fn.arg_names.index(name): value for name, value in attributes.items()} - arg_types = signature.replace(' ', '').split(',') - arg_types = [str_to_ty(x) for x in arg_types] + if signature.replace(' ', '') != '': + arg_types = signature.replace(' ', '').split(',') + arg_types = [str_to_ty(x) for x in arg_types] + else: + arg_types = [] prototype = triton.language.function_type([], arg_types) # visit kernel AST gscope = fn.__globals__.copy() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 26277257b..39245676b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -239,8 +239,9 @@ class block_type(dtype): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - assert shape - if isinstance(shape[0], constexpr): + + # shape can be empty ([]) when an input is a 0D tensor. 
+ if shape and isinstance(shape[0], constexpr): shape = [s.value for s in shape] self.shape = shape diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index dd3dc0aa5..4084a2cc5 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -508,8 +508,21 @@ def broadcast_impl_value(lhs: tl.tensor, elif lhs_ty.is_block() and rhs_ty.is_block(): lhs_shape = lhs_ty.get_block_shapes() rhs_shape = rhs_ty.get_block_shapes() - if len(lhs_shape) != len(rhs_shape): - raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for dim in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1])) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for dim in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1])) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + ret_shape = [] for i in range(len(lhs_shape)): left = lhs_shape[i] @@ -962,10 +975,7 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, for i, s in enumerate(shape): if i != axis: ret_shape.append(s) - if len(ret_shape) == 0: - res_ty = scalar_ty - else: - res_ty = tl.block_type(scalar_ty, ret_shape) + res_ty = tl.block_type(scalar_ty, ret_shape) if scalar_ty.is_floating(): return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
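
Note for reviewers: the pure-Python sketch below restates the shape rule that broadcast_impl_value follows after this change, for readers who do not want to trace the IR-building code. It is an illustration only, not Triton code: the helper name broadcast_shapes is invented for this note, and it assumes the pre-existing per-dimension handling of size-1 axes that runs after the new rank-alignment loop. It pads the lower-rank operand with trailing unit axes because, as I read the diff, the create_expand_dims loop appends new axes after the existing ones.

    # Hypothetical helper mirroring the shape arithmetic of
    # broadcast_impl_value after this patch (no IR construction).
    def broadcast_shapes(lhs_shape, rhs_shape):
        lhs_shape = list(lhs_shape)
        rhs_shape = list(rhs_shape)
        # Rank alignment: the lower-rank operand gains trailing size-1
        # axes, mirroring the create_expand_dims loop in semantic.py.
        while len(lhs_shape) < len(rhs_shape):
            lhs_shape.append(1)
        while len(rhs_shape) < len(lhs_shape):
            rhs_shape.append(1)
        # Per-dimension broadcasting: a size-1 axis stretches to match
        # the other operand; any other mismatch is an error.
        ret_shape = []
        for left, right in zip(lhs_shape, rhs_shape):
            if left == 1:
                ret_shape.append(right)
            elif right == 1 or left == right:
                ret_shape.append(left)
            else:
                raise ValueError(f"Cannot broadcast {left} and {right}")
        return ret_shape

    if __name__ == "__main__":
        # 0D + 2D -> 2D, as exercised by b4/b5 in test_cast.py
        assert broadcast_shapes([], [2, 2]) == [2, 2]
        # 1D + 2D -> 2D, as exercised by c2/c3
        assert broadcast_shapes([2], [2, 1]) == [2, 1]
        assert broadcast_shapes([2], [2, 2]) == [2, 2]
        # 2D + 2D with a unit axis, as exercised by d3/d4
        assert broadcast_shapes([2, 1], [2, 2]) == [2, 2]

The assertions mirror a few of the cases covered by test_cast.py (0D + 2D, 1D + 2D, and 2D + 2D with a unit axis); before this patch, mixing ranks in this way raised "Cannot make_shape_compatible: blocks must have the same rank".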