[TritonMLIR] Disallow 0D tensor (#788)

commit 0d22d2bc03
parent 4464646efb
Author: Shintaro Iwasaki
Date: 2022-10-19 10:34:32 -07:00
Committed by: GitHub
8 changed files with 203 additions and 87 deletions


@@ -108,8 +108,7 @@ def TT_LoadOp : TT_Op<"load",
                    AttrSizedOperandSegments,
                    MemoryEffects<[MemRead]>,
                    TypesMatchWith<"infer ptr type from result type",
-                                  "result", "ptr",
-                                  "getPointerTypeFromTensor($_self)">,
+                                  "result", "ptr", "getPointerTypeSameShape($_self)">,
                    TypesMatchWith<"infer mask type from result type or none",
                                   "result", "mask", "getI1SameShape($_self)",
                                   "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
@@ -118,7 +117,7 @@ def TT_LoadOp : TT_Op<"load",
                                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "load";
 
-  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);
@@ -147,13 +146,13 @@ def TT_StoreOp : TT_Op<"store",
                    MemoryEffects<[MemWrite]>,
                    TypesMatchWith<"infer ptr type from value type",
                                   "value", "ptr",
-                                  "getPointerTypeFromTensor($_self)">,
+                                  "getPointerTypeSameShape($_self)">,
                    TypesMatchWith<"infer mask type from value type",
                                   "value", "mask", "getI1SameShape($_self)",
                                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";
 
-  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
+  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
 
   let builders = [
       OpBuilder<(ins "Value":$ptr, "Value":$value)>,
@@ -318,7 +317,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
   let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
 
-  let results = (outs TT_Tensor:$result);
+  let results = (outs TT_Type:$result);
 
   let builders = [
       OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,


@@ -26,14 +26,15 @@ static Type getI32SameShape(Type type) {
   return i32Type;
 }
 
-static Type getPointerTypeFromTensor(Type type) {
+static Type getPointerTypeSameShape(Type type) {
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     Type elementType = tensorType.getElementType();
     auto shape = tensorType.getShape();
     PointerType ptrType = PointerType::get(elementType, 1);
     return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
+  } else {
+    return PointerType::get(type, 1);
   }
-  return Type();
 }
 
 // Parser & printer for assembly forms
@@ -49,7 +50,7 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
   result.addTypes(resultTypes);
 
   SmallVector<Type> operandTypes;
-  operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
+  operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
   int hasMask = 0, hasOther = 0;
   if (allOperands.size() >= 2) {
     operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
@@ -92,8 +93,8 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   SmallVector<Type> operandTypes;
-  operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
-  operandTypes.push_back(valueType);                           // value
+  operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
+  operandTypes.push_back(valueType);                          // value
   if (allOperands.size() >= 3)
     operandTypes.push_back(getI1SameShape(valueType)); // mask
@@ -194,26 +195,33 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   // infer shape
   Value arg = operands[0];
   auto argTy = arg.getType().cast<RankedTensorType>();
+  auto argEltTy = argTy.getElementType();
   auto retShape = argTy.getShape().vec();
   int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   retShape.erase(retShape.begin() + axis);
-  // infer encoding
-  Attribute argEncoding = argTy.getEncoding();
-  Attribute retEncoding;
-  if (argEncoding) {
-    Dialect &dialect = argEncoding.getDialect();
-    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
-            .failed()) {
-      llvm::report_fatal_error("failed to infer layout for ReduceOp");
-      return mlir::failure();
+  if (retShape.empty()) {
+    // 0d-tensor -> scalar
+    inferredReturnTypes.push_back(argEltTy);
+  } else {
+    // nd-tensor where n >= 1
+    // infer encoding
+    Attribute argEncoding = argTy.getEncoding();
+    Attribute retEncoding;
+    if (argEncoding) {
+      Dialect &dialect = argEncoding.getDialect();
+      auto inferLayoutInterface =
+          dyn_cast<DialectInferLayoutInterface>(&dialect);
+      if (inferLayoutInterface
+              ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
+              .failed()) {
+        llvm::report_fatal_error("failed to infer layout for ReduceOp");
+        return mlir::failure();
+      }
     }
+    // create type
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(retShape, argEltTy, retEncoding));
   }
-  // create type
-  auto argEltTy = argTy.getElementType();
-  inferredReturnTypes.push_back(
-      RankedTensorType::get(retShape, argEltTy, retEncoding));
   return mlir::success();
 }
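The inference rule, stated compactly: drop axis from the operand shape; if nothing remains, the result is the bare element type rather than a 0-d tensor. A small self-contained model of that branch (hypothetical helper names, Python for illustration):

    def infer_reduce_type(shape, element_type, axis):
        ret_shape = list(shape)
        del ret_shape[axis]          # retShape.erase(retShape.begin() + axis)
        if not ret_shape:
            return element_type      # 0d-tensor -> scalar
        return ("tensor", tuple(ret_shape), element_type)

    assert infer_reduce_type((4,), "f32", 0) == "f32"
    assert infer_reduce_type((64, 128), "f32", 1) == ("tensor", (64,), "f32")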


@@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) {
                 operand.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = inputTensorType.getShape();
             shape.erase(shape.begin() + axis);
-            auto resType = mlir::RankedTensorType::get(
-                shape, inputTensorType.getElementType());
+            mlir::Type resType = inputTensorType.getElementType();
+            if (!shape.empty()) {
+              resType = mlir::RankedTensorType::get(
+                  shape, inputTensorType.getElementType());
+            }
             return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                        operand, axis);
           })
})


@@ -1,57 +0,0 @@
import triton
import triton.language as tl
# TODO: function with no arguments don't work
@triton.jit
def cast_check(X):
zero_0d = tl.zeros([], dtype=tl.float32)
zero_1d = tl.zeros([2], dtype=tl.float32)
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
# scalar + scalar -> scalar
a0 = 0.0 + 0.0
# scalar + 0D -> 0D
a1 = 0.0 + zero_0d
a2 = zero_0d + 0.0
# scalar + 1D -> 1D
a3 = 0.0 + zero_1d
a4 = zero_1d + 0.0
# scalar + 2D -> 2D
a5 = 0.0 + zero_2d_22
a6 = zero_2d_22 + 0.0
# 0D + 0D -> 0D
b1 = zero_0d + zero_0d
# 0D + 1D -> 1D
b2 = zero_0d + zero_1d
b3 = zero_1d + zero_0d
# 0D + 2D -> 2D
b4 = zero_0d + zero_2d_22
b5 = zero_2d_22 + zero_0d
# 1D + 1D -> 1D
c1 = zero_1d + zero_1d
# 1D + 2D -> 2D
c2 = zero_1d + zero_2d_21
c3 = zero_1d + zero_2d_22
c4 = zero_2d_21 + zero_1d
c5 = zero_2d_22 + zero_1d
# 2D + 2D -> 2D
d1 = zero_2d_21 + zero_2d_21
d2 = zero_2d_22 + zero_2d_22
d3 = zero_2d_21 + zero_2d_22
d4 = zero_2d_22 + zero_2d_21
return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
def test_cast_check():
kernel = triton.compiler._compile(cast_check,
signature="*fp32",
device=0,
output="ttgir")
assert (kernel)
# TODO: Check types of the results

python/tests/test_type.py (new file, 80 lines)

@@ -0,0 +1,80 @@
+import triton
+import triton.language as tl
+
+
+# TODO: function with no arguments don't work
+@triton.jit
+def binop_type_check(X):
+    # 0d-tensor is not allowed.
+    # zero_0d = tl.zeros([], dtype=tl.float32)
+    zero_1d = tl.zeros([2], dtype=tl.float32)
+    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
+    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
+
+    # scalar + scalar -> scalar
+    a0 = 0.0 + 0.0
+    # # scalar + 0D -> 0D
+    # a1 = 0.0 + zero_0d
+    # a2 = zero_0d + 0.0
+    # scalar + 1D -> 1D
+    a3 = 0.0 + zero_1d
+    a4 = zero_1d + 0.0
+    # scalar + 2D -> 2D
+    a5 = 0.0 + zero_2d_22
+    a6 = zero_2d_22 + 0.0
+
+    # # 0D + 0D -> 0D
+    # b1 = zero_0d + zero_0d
+    # # 0D + 1D -> 1D
+    # b2 = zero_0d + zero_1d
+    # b3 = zero_1d + zero_0d
+    # # 0D + 2D -> 2D
+    # b4 = zero_0d + zero_2d_22
+    # b5 = zero_2d_22 + zero_0d
+
+    # 1D + 1D -> 1D
+    c1 = zero_1d + zero_1d
+    # 1D + 2D -> 2D
+    c2 = zero_1d + zero_2d_21
+    c3 = zero_1d + zero_2d_22
+    c4 = zero_2d_21 + zero_1d
+    c5 = zero_2d_22 + zero_1d
+    # 2D + 2D -> 2D
+    d1 = zero_2d_21 + zero_2d_21
+    d2 = zero_2d_22 + zero_2d_22
+    d3 = zero_2d_21 + zero_2d_22
+    d4 = zero_2d_22 + zero_2d_21
+
+    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
+    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
+
+
+def test_binop_type_check():
+    kernel = triton.compiler._compile(binop_type_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
+    assert (kernel)
+    # TODO: Check types of the results
+
+
+@triton.jit
+def reduce_type_check(ptr):
+    v_32 = tl.load(ptr + tl.arange(0, 32))
+    v_scalar = tl.min(v_32, axis=0)
+    tl.store(ptr, v_scalar)
+    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
+    v_64 = tl.max(v_64x128, axis=1)
+    tl.store(ptr + tl.arange(0, 64), v_64)
+    v_128 = tl.max(v_64x128, axis=0)
+    tl.store(ptr + tl.arange(0, 128), v_128)
+
+
+def test_reduce_type_check():
+    kernel = triton.compiler._compile(reduce_type_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
+    assert (kernel)
+    # TODO: Check types of the results


@@ -241,7 +241,9 @@ class block_type(dtype):
         # while tensor's shape is a list of constexpr.
-        # shape can be empty ([]) when an input is a 0D tensor.
-        if shape and isinstance(shape[0], constexpr):
+        if not shape:
+            raise TypeError('0d block_type is forbidden')
+        if isinstance(shape[0], constexpr):
             shape = [s.value for s in shape]
         self.shape = shape
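With this guard, an empty shape fails at type-construction time instead of producing a 0-d block. For example (a sketch, assuming block_type is reachable as tl.block_type, as this diff's qualified uses suggest):

    import triton.language as tl

    tl.block_type(tl.float32, [2])       # OK: 1D block
    try:
        tl.block_type(tl.float32, [])    # rejected by the new check
    except TypeError as e:
        print(e)                         # 0d block_type is forbidden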


@@ -991,7 +991,11 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     for i, s in enumerate(shape):
         if i != axis:
             ret_shape.append(s)
-    res_ty = tl.block_type(scalar_ty, ret_shape)
+    if ret_shape:
+        res_ty = tl.block_type(scalar_ty, ret_shape)
+    else:
+        # 0d-tensor -> scalar
+        res_ty = scalar_ty
     if scalar_ty.is_floating():
         return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
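At the language level, reducing away the last axis therefore yields a genuine scalar type rather than a 0-d block. A minimal kernel sketch (hypothetical, mirroring the reduce_type_check test added above):

    import triton
    import triton.language as tl

    @triton.jit
    def row_min(ptr):
        v = tl.load(ptr + tl.arange(0, 32))
        s = tl.min(v, axis=0)   # ret_shape is empty -> res_ty == scalar_ty
        tl.store(ptr, s)        # scalar store, enabled by TT_PtrLike above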


@@ -53,3 +53,80 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
   %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
   return
 }
+
+func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
+  // Test if Load/Store ops can handle scalar values
+  %other = arith.constant 0.0e+0 : f32
+
+  // load scalar
+  // CHECK: %[[L0:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+
+  // store scalar
+  // CHECK: tt.store %{{.*}}, %[[L0]] : f32
+  tt.store %ptr, %a : f32
+  // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : f32
+  tt.store %ptr, %b, %mask : f32
+  // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : f32
+  tt.store %ptr, %c, %mask : f32
+  return
+}
+
+func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
+  // Test if reduce ops infer types correctly
+
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
+  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
+  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
+  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
+  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
+  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
+  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
+
+  // Avoid optimizations for c, e, and g
+  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
+  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
+  tt.store %ptr1x2, %c : tensor<1x2xf32>
+  tt.store %ptr1, %e : tensor<1xf32>
+  tt.store %ptr, %g : f32
+  return
+}
+
+func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
+  // Test if dot ops infer types correctly
+  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
+  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
+  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
+  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>
+  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
+  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>
+
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
+  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
+  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
+
+  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
+  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
+  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
+  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
+  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
+  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
+  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
+  return
+}