[TritonMLIR] Disallow 0D tensor (#788)
This commit is contained in:
@@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) {
|
||||
operand.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = inputTensorType.getShape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto resType = mlir::RankedTensorType::get(
|
||||
shape, inputTensorType.getElementType());
|
||||
mlir::Type resType = inputTensorType.getElementType();
|
||||
if (!shape.empty()) {
|
||||
resType = mlir::RankedTensorType::get(
|
||||
shape, inputTensorType.getElementType());
|
||||
}
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
||||
operand, axis);
|
||||
})
|
||||
|
@@ -1,57 +0,0 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# TODO: function with no arguments don't work
|
||||
@triton.jit
|
||||
def cast_check(X):
|
||||
zero_0d = tl.zeros([], dtype=tl.float32)
|
||||
zero_1d = tl.zeros([2], dtype=tl.float32)
|
||||
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
|
||||
zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
|
||||
|
||||
# scalar + scalar -> scalar
|
||||
a0 = 0.0 + 0.0
|
||||
# scalar + 0D -> 0D
|
||||
a1 = 0.0 + zero_0d
|
||||
a2 = zero_0d + 0.0
|
||||
# scalar + 1D -> 1D
|
||||
a3 = 0.0 + zero_1d
|
||||
a4 = zero_1d + 0.0
|
||||
# scalar + 2D -> 2D
|
||||
a5 = 0.0 + zero_2d_22
|
||||
a6 = zero_2d_22 + 0.0
|
||||
|
||||
# 0D + 0D -> 0D
|
||||
b1 = zero_0d + zero_0d
|
||||
# 0D + 1D -> 1D
|
||||
b2 = zero_0d + zero_1d
|
||||
b3 = zero_1d + zero_0d
|
||||
# 0D + 2D -> 2D
|
||||
b4 = zero_0d + zero_2d_22
|
||||
b5 = zero_2d_22 + zero_0d
|
||||
|
||||
# 1D + 1D -> 1D
|
||||
c1 = zero_1d + zero_1d
|
||||
# 1D + 2D -> 2D
|
||||
c2 = zero_1d + zero_2d_21
|
||||
c3 = zero_1d + zero_2d_22
|
||||
c4 = zero_2d_21 + zero_1d
|
||||
c5 = zero_2d_22 + zero_1d
|
||||
|
||||
# 2D + 2D -> 2D
|
||||
d1 = zero_2d_21 + zero_2d_21
|
||||
d2 = zero_2d_22 + zero_2d_22
|
||||
d3 = zero_2d_21 + zero_2d_22
|
||||
d4 = zero_2d_22 + zero_2d_21
|
||||
|
||||
return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
|
||||
|
||||
|
||||
def test_cast_check():
|
||||
kernel = triton.compiler._compile(cast_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttgir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
80
python/tests/test_type.py
Normal file
80
python/tests/test_type.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# TODO: function with no arguments don't work
|
||||
@triton.jit
|
||||
def binop_type_check(X):
|
||||
# 0d-tensor is not allowed.
|
||||
# zero_0d = tl.zeros([], dtype=tl.float32)
|
||||
zero_1d = tl.zeros([2], dtype=tl.float32)
|
||||
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
|
||||
zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
|
||||
|
||||
# scalar + scalar -> scalar
|
||||
a0 = 0.0 + 0.0
|
||||
# # scalar + 0D -> 0D
|
||||
# a1 = 0.0 + zero_0d
|
||||
# a2 = zero_0d + 0.0
|
||||
# scalar + 1D -> 1D
|
||||
a3 = 0.0 + zero_1d
|
||||
a4 = zero_1d + 0.0
|
||||
# scalar + 2D -> 2D
|
||||
a5 = 0.0 + zero_2d_22
|
||||
a6 = zero_2d_22 + 0.0
|
||||
|
||||
# # 0D + 0D -> 0D
|
||||
# b1 = zero_0d + zero_0d
|
||||
# # 0D + 1D -> 1D
|
||||
# b2 = zero_0d + zero_1d
|
||||
# b3 = zero_1d + zero_0d
|
||||
# # 0D + 2D -> 2D
|
||||
# b4 = zero_0d + zero_2d_22
|
||||
# b5 = zero_2d_22 + zero_0d
|
||||
|
||||
# 1D + 1D -> 1D
|
||||
c1 = zero_1d + zero_1d
|
||||
# 1D + 2D -> 2D
|
||||
c2 = zero_1d + zero_2d_21
|
||||
c3 = zero_1d + zero_2d_22
|
||||
c4 = zero_2d_21 + zero_1d
|
||||
c5 = zero_2d_22 + zero_1d
|
||||
|
||||
# 2D + 2D -> 2D
|
||||
d1 = zero_2d_21 + zero_2d_21
|
||||
d2 = zero_2d_22 + zero_2d_22
|
||||
d3 = zero_2d_21 + zero_2d_22
|
||||
d4 = zero_2d_22 + zero_2d_21
|
||||
|
||||
# return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
|
||||
return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
|
||||
|
||||
|
||||
def test_binop_type_check():
|
||||
kernel = triton.compiler._compile(binop_type_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttgir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reduce_type_check(ptr):
|
||||
v_32 = tl.load(ptr + tl.arange(0, 32))
|
||||
v_scalar = tl.min(v_32, axis=0)
|
||||
tl.store(ptr, v_scalar)
|
||||
v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
|
||||
v_64 = tl.max(v_64x128, axis=1)
|
||||
tl.store(ptr + tl.arange(0, 64), v_64)
|
||||
v_128 = tl.max(v_64x128, axis=0)
|
||||
tl.store(ptr + tl.arange(0, 128), v_128)
|
||||
|
||||
|
||||
def test_reduce_type_check():
|
||||
kernel = triton.compiler._compile(reduce_type_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttgir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
@@ -241,7 +241,9 @@ class block_type(dtype):
|
||||
# while tensor's shape is a list of constexpr.
|
||||
|
||||
# shape can be empty ([]) when an input is a 0D tensor.
|
||||
if shape and isinstance(shape[0], constexpr):
|
||||
if not shape:
|
||||
raise TypeError('0d block_type is forbidden')
|
||||
if isinstance(shape[0], constexpr):
|
||||
shape = [s.value for s in shape]
|
||||
|
||||
self.shape = shape
|
||||
|
@@ -991,7 +991,11 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||
for i, s in enumerate(shape):
|
||||
if i != axis:
|
||||
ret_shape.append(s)
|
||||
res_ty = tl.block_type(scalar_ty, ret_shape)
|
||||
if ret_shape:
|
||||
res_ty = tl.block_type(scalar_ty, ret_shape)
|
||||
else:
|
||||
# 0d-tensor -> scalar
|
||||
res_ty = scalar_ty
|
||||
|
||||
if scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
|
||||
|
Reference in New Issue
Block a user