[FRONTEND] Fix the implicit broadcasting rule (#663)
This PR fixes the cast issue that appears in some tutorial code: blocks of different ranks (including 0D blocks) are now implicitly broadcast against each other instead of raising a rank-mismatch error.
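A minimal reproducer in the style of the new test below (a hypothetical sketch, not taken from the tutorials verbatim). Before this change, mixing ranks in a binary op failed with "Cannot make_shape_compatible: blocks must have the same rank", and 0D blocks could not be constructed at all because block_type asserted a non-empty shape:

import triton
import triton.language as tl


@triton.jit
def mixed_rank_add():
    scalar_block = tl.zeros([], dtype=tl.float32)   # 0D block
    vec = tl.zeros([2], dtype=tl.float32)           # 1D block
    mat = tl.zeros([2, 2], dtype=tl.float32)        # 2D block
    # Both additions now compile: the lower-rank operand gains trailing
    # unit axes and is broadcast to the higher rank.
    a = scalar_block + vec  # []  + [2]    -> [2]
    b = vec + mat           # [2] + [2, 2] -> [2, 2]
    return a, b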
python/tests/test_cast.py (new file, 56 lines added)
@@ -0,0 +1,56 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def cast_check():
+    zero_0d = tl.zeros([], dtype=tl.float32)
+    zero_1d = tl.zeros([2], dtype=tl.float32)
+    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
+    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
+
+    # scalar + scalar -> scalar
+    a0 = 0.0 + 0.0
+    # scalar + 0D -> 0D
+    a1 = 0.0 + zero_0d
+    a2 = zero_0d + 0.0
+    # scalar + 1D -> 1D
+    a3 = 0.0 + zero_1d
+    a4 = zero_1d + 0.0
+    # scalar + 2D -> 2D
+    a5 = 0.0 + zero_2d_22
+    a6 = zero_2d_22 + 0.0
+
+    # 0D + 0D -> 0D
+    b1 = zero_0d + zero_0d
+    # 0D + 1D -> 1D
+    b2 = zero_0d + zero_1d
+    b3 = zero_1d + zero_0d
+    # 0D + 2D -> 2D
+    b4 = zero_0d + zero_2d_22
+    b5 = zero_2d_22 + zero_0d
+
+    # 1D + 1D -> 1D
+    c1 = zero_1d + zero_1d
+    # 1D + 2D -> 2D
+    c2 = zero_1d + zero_2d_21
+    c3 = zero_1d + zero_2d_22
+    c4 = zero_2d_21 + zero_1d
+    c5 = zero_2d_22 + zero_1d
+
+    # 2D + 2D -> 2D
+    d1 = zero_2d_21 + zero_2d_21
+    d2 = zero_2d_22 + zero_2d_22
+    d3 = zero_2d_21 + zero_2d_22
+    d4 = zero_2d_22 + zero_2d_21
+
+    return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
+
+
+def test_cast_check():
+    kernel = triton.compile(cast_check,
+                            signature="",
+                            device=0,
+                            output="ttir")
+    assert (kernel)
+    # TODO: Check types of the results
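The assertion only checks that compilation to TTIR succeeds; the trailing TODO covers type checks. Assuming a standard pytest setup, the test can also be exercised by hand with the same call the test makes (the import path here is hypothetical):

# Run via pytest:  python -m pytest python/tests/test_cast.py -v
# Or compile directly and inspect the generated IR:
import triton
from test_cast import cast_check  # hypothetical import path

module = triton.compile(cast_check, signature="", device=0, output="ttir")
print(module)  # dump the Triton IR produced for the kernel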
@@ -752,8 +752,11 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
     # create kernel prototype
     constants = {fn.arg_names.index(name): value for name, value in constants.items()}
     attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
-    arg_types = signature.replace(' ', '').split(',')
-    arg_types = [str_to_ty(x) for x in arg_types]
+    if signature.replace(' ', '') != '':
+        arg_types = signature.replace(' ', '').split(',')
+        arg_types = [str_to_ty(x) for x in arg_types]
+    else:
+        arg_types = []
     prototype = triton.language.function_type([], arg_types)
     # visit kernel AST
     gscope = fn.__globals__.copy()
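The guard exists because splitting an empty string does not yield an empty list: with signature="" (as in the new test), the old code would call str_to_ty(''). A quick illustration:

# Why the empty-signature branch is needed:
assert "".split(',') == ['']  # not [] -- the old code would hit str_to_ty('')
assert "*fp32, i32".replace(' ', '').split(',') == ['*fp32', 'i32']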
@@ -239,8 +239,9 @@ class block_type(dtype):

         # Note that block_type's shape is a list of int
         # while tensor's shape is a list of constexpr.
-        assert shape
-        if isinstance(shape[0], constexpr):
+
+        # shape can be empty ([]) when an input is a 0D tensor.
+        if shape and isinstance(shape[0], constexpr):
             shape = [s.value for s in shape]

         self.shape = shape
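A 0D tensor, e.g. tl.zeros([], dtype=tl.float32) from the new test, reaches this constructor with shape == []; the removed assert rejected it outright, and shape[0] would raise IndexError. The shape-and guard short-circuits instead. A sketch of the behavior:

shape = []  # shape of a 0D tensor, as produced by tl.zeros([], ...)

# Old: `assert shape` raised on an empty list; shape[0] would IndexError.
# New: the truthiness check short-circuits for [], so constexpr unwrapping
# is skipped and the empty shape is kept as-is.
if shape and isinstance(shape[0], int):  # int stands in for constexpr here
    shape = [s for s in shape]
assert shape == []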
@@ -508,8 +508,21 @@ def broadcast_impl_value(lhs: tl.tensor,
     elif lhs_ty.is_block() and rhs_ty.is_block():
         lhs_shape = lhs_ty.get_block_shapes()
         rhs_shape = rhs_ty.get_block_shapes()
-        if len(lhs_shape) != len(rhs_shape):
-            raise ValueError("Cannot make_shape_compatible: blocks must have the same rank")
+
+        if len(lhs_shape) < len(rhs_shape):
+            # Add new axes to lhs
+            for dim in range(len(lhs_shape), len(rhs_shape)):
+                lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1]))
+                lhs_ty = lhs.type
+                lhs_shape = lhs_ty.get_block_shapes()
+        elif len(rhs_shape) < len(lhs_shape):
+            # Add new axes to rhs
+            for dim in range(len(rhs_shape), len(lhs_shape)):
+                rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1]))
+                rhs_ty = rhs.type
+                rhs_shape = rhs_ty.get_block_shapes()
+        assert len(rhs_shape) == len(lhs_shape)
+
         ret_shape = []
         for i in range(len(lhs_shape)):
             left = lhs_shape[i]
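Note that the new axes are appended at the end (lhs_shape + [1]), so ranks are aligned from the leading dimensions, unlike NumPy, which prepends axes and aligns from the trailing ones. A plain-Python model of the resulting shape rule (expected_broadcast_shape is a hypothetical helper for illustration, not part of Triton):

def expected_broadcast_shape(lhs, rhs):
    # Pad the lower-rank shape with trailing 1s (mirrors create_expand_dims
    # above), then broadcast axis by axis: a 1 stretches to the other size.
    lhs, rhs = list(lhs), list(rhs)
    while len(lhs) < len(rhs):
        lhs.append(1)
    while len(rhs) < len(lhs):
        rhs.append(1)
    out = []
    for l, r in zip(lhs, rhs):
        if l != r and 1 not in (l, r):
            raise ValueError(f"incompatible dimensions {l} and {r}")
        out.append(max(l, r))
    return out

# Shapes pinned down by the test above:
assert expected_broadcast_shape([], [2]) == [2]             # b2: 0D + 1D
assert expected_broadcast_shape([2], [2, 1]) == [2, 1]      # c2: 1D + [2, 1]
assert expected_broadcast_shape([2], [2, 2]) == [2, 2]      # c3: 1D + [2, 2]
assert expected_broadcast_shape([2, 1], [2, 2]) == [2, 2]   # d3: [2, 1] + [2, 2]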
@@ -962,10 +975,7 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     for i, s in enumerate(shape):
         if i != axis:
             ret_shape.append(s)
-    if len(ret_shape) == 0:
-        res_ty = scalar_ty
-    else:
-        res_ty = tl.block_type(scalar_ty, ret_shape)
+    res_ty = tl.block_type(scalar_ty, ret_shape)

     if scalar_ty.is_floating():
         return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
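With 0D blocks available, the special case is no longer needed: reducing a 1D block over its only axis now yields a 0D block_type rather than a bare scalar type, so the result composes with the broadcasting rule above. A sketch of the shape computation:

# Reducing a 1D block of shape [8] along axis 0:
shape, axis = [8], 0
ret_shape = [s for i, s in enumerate(shape) if i != axis]
assert ret_shape == []  # res_ty = tl.block_type(scalar_ty, []) -- a 0D block,
                        # where the old code fell back to plain scalar_ty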