From 0d22d2bc03e87fcd29071d5344987132e1e9528e Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Wed, 19 Oct 2022 10:34:32 -0700 Subject: [PATCH] [TritonMLIR] Disallow 0D tensor (#788) --- include/triton/Dialect/Triton/IR/TritonOps.td | 11 ++- lib/Dialect/Triton/IR/Ops.cpp | 48 ++++++----- python/src/triton.cc | 7 +- python/tests/test_cast.py | 57 ------------- python/tests/test_type.py | 80 +++++++++++++++++++ python/triton/language/core.py | 4 +- python/triton/language/semantic.py | 6 +- test/Conversion/triton_ops.mlir | 77 ++++++++++++++++++ 8 files changed, 203 insertions(+), 87 deletions(-) delete mode 100644 python/tests/test_cast.py create mode 100644 python/tests/test_type.py diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index b3f5f82ec..ce32d3914 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -108,8 +108,7 @@ def TT_LoadOp : TT_Op<"load", AttrSizedOperandSegments, MemoryEffects<[MemRead]>, TypesMatchWith<"infer ptr type from result type", - "result", "ptr", - "getPointerTypeFromTensor($_self)">, + "result", "ptr", "getPointerTypeSameShape($_self)">, TypesMatchWith<"infer mask type from result type or none", "result", "mask", "getI1SameShape($_self)", "($_op.getOperands().size() <= 1) || std::equal_to<>()">, @@ -118,7 +117,7 @@ def TT_LoadOp : TT_Op<"load", "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> { let summary = "load"; - let arguments = (ins TT_PtrTensor:$ptr, Optional:$mask, Optional:$other, + let arguments = (ins TT_PtrLike:$ptr, Optional:$mask, Optional:$other, TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile); @@ -147,13 +146,13 @@ def TT_StoreOp : TT_Op<"store", MemoryEffects<[MemWrite]>, TypesMatchWith<"infer ptr type from value type", "value", "ptr", - "getPointerTypeFromTensor($_self)">, + "getPointerTypeSameShape($_self)">, TypesMatchWith<"infer mask type from value type", "value", "mask", "getI1SameShape($_self)", "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> { let summary = "store"; - let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional:$mask); + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional:$mask); let builders = [ OpBuilder<(ins "Value":$ptr, "Value":$value)>, @@ -318,7 +317,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect, let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); - let results = (outs TT_Tensor:$result); + let results = (outs TT_Type:$result); let builders = [ OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 630260988..32b56fb0d 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -26,14 +26,15 @@ static Type getI32SameShape(Type type) { return i32Type; } -static Type getPointerTypeFromTensor(Type type) { +static Type getPointerTypeSameShape(Type type) { if (auto tensorType = type.dyn_cast()) { Type elementType = tensorType.getElementType(); auto shape = tensorType.getShape(); PointerType ptrType = PointerType::get(elementType, 1); return RankedTensorType::get(shape, ptrType, tensorType.getEncoding()); + } else { + return PointerType::get(type, 1); } - return Type(); } // Parser & printer for assembly forms @@ -49,7 +50,7 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { result.addTypes(resultTypes); SmallVector operandTypes; - operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr + operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr int hasMask = 0, hasOther = 0; if (allOperands.size() >= 2) { operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask @@ -92,8 +93,8 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { return failure(); SmallVector operandTypes; - operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr - operandTypes.push_back(valueType); // value + operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr + operandTypes.push_back(valueType); // value if (allOperands.size() >= 3) operandTypes.push_back(getI1SameShape(valueType)); // mask @@ -194,26 +195,33 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( // infer shape Value arg = operands[0]; auto argTy = arg.getType().cast(); + auto argEltTy = argTy.getElementType(); auto retShape = argTy.getShape().vec(); int axis = attributes.get("axis").cast().getInt(); retShape.erase(retShape.begin() + axis); - // infer encoding - Attribute argEncoding = argTy.getEncoding(); - Attribute retEncoding; - if (argEncoding) { - Dialect &dialect = argEncoding.getDialect(); - auto inferLayoutInterface = dyn_cast(&dialect); - if (inferLayoutInterface - ->inferReduceOpEncoding(argEncoding, axis, retEncoding) - .failed()) { - llvm::report_fatal_error("failed to infer layout for ReduceOp"); - return mlir::failure(); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(argEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return mlir::failure(); + } } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); } - // create type - auto argEltTy = argTy.getElementType(); - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, argEltTy, retEncoding)); return mlir::success(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index e3947a3f8..b0c7d828c 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) { operand.getType().dyn_cast(); std::vector shape = inputTensorType.getShape(); shape.erase(shape.begin() + axis); - auto resType = mlir::RankedTensorType::get( - shape, inputTensorType.getElementType()); + mlir::Type resType = inputTensorType.getElementType(); + if (!shape.empty()) { + resType = mlir::RankedTensorType::get( + shape, inputTensorType.getElementType()); + } return self.create(loc, resType, redOp, operand, axis); }) diff --git a/python/tests/test_cast.py b/python/tests/test_cast.py deleted file mode 100644 index 9b5513aa3..000000000 --- a/python/tests/test_cast.py +++ /dev/null @@ -1,57 +0,0 @@ -import triton -import triton.language as tl - - -# TODO: function with no arguments don't work -@triton.jit -def cast_check(X): - zero_0d = tl.zeros([], dtype=tl.float32) - zero_1d = tl.zeros([2], dtype=tl.float32) - zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32) - zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32) - - # scalar + scalar -> scalar - a0 = 0.0 + 0.0 - # scalar + 0D -> 0D - a1 = 0.0 + zero_0d - a2 = zero_0d + 0.0 - # scalar + 1D -> 1D - a3 = 0.0 + zero_1d - a4 = zero_1d + 0.0 - # scalar + 2D -> 2D - a5 = 0.0 + zero_2d_22 - a6 = zero_2d_22 + 0.0 - - # 0D + 0D -> 0D - b1 = zero_0d + zero_0d - # 0D + 1D -> 1D - b2 = zero_0d + zero_1d - b3 = zero_1d + zero_0d - # 0D + 2D -> 2D - b4 = zero_0d + zero_2d_22 - b5 = zero_2d_22 + zero_0d - - # 1D + 1D -> 1D - c1 = zero_1d + zero_1d - # 1D + 2D -> 2D - c2 = zero_1d + zero_2d_21 - c3 = zero_1d + zero_2d_22 - c4 = zero_2d_21 + zero_1d - c5 = zero_2d_22 + zero_1d - - # 2D + 2D -> 2D - d1 = zero_2d_21 + zero_2d_21 - d2 = zero_2d_22 + zero_2d_22 - d3 = zero_2d_21 + zero_2d_22 - d4 = zero_2d_22 + zero_2d_21 - - return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4 - - -def test_cast_check(): - kernel = triton.compiler._compile(cast_check, - signature="*fp32", - device=0, - output="ttgir") - assert (kernel) - # TODO: Check types of the results diff --git a/python/tests/test_type.py b/python/tests/test_type.py new file mode 100644 index 000000000..07de3ce27 --- /dev/null +++ b/python/tests/test_type.py @@ -0,0 +1,80 @@ +import triton +import triton.language as tl + + +# TODO: function with no arguments don't work +@triton.jit +def binop_type_check(X): + # 0d-tensor is not allowed. + # zero_0d = tl.zeros([], dtype=tl.float32) + zero_1d = tl.zeros([2], dtype=tl.float32) + zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32) + zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32) + + # scalar + scalar -> scalar + a0 = 0.0 + 0.0 + # # scalar + 0D -> 0D + # a1 = 0.0 + zero_0d + # a2 = zero_0d + 0.0 + # scalar + 1D -> 1D + a3 = 0.0 + zero_1d + a4 = zero_1d + 0.0 + # scalar + 2D -> 2D + a5 = 0.0 + zero_2d_22 + a6 = zero_2d_22 + 0.0 + + # # 0D + 0D -> 0D + # b1 = zero_0d + zero_0d + # # 0D + 1D -> 1D + # b2 = zero_0d + zero_1d + # b3 = zero_1d + zero_0d + # # 0D + 2D -> 2D + # b4 = zero_0d + zero_2d_22 + # b5 = zero_2d_22 + zero_0d + + # 1D + 1D -> 1D + c1 = zero_1d + zero_1d + # 1D + 2D -> 2D + c2 = zero_1d + zero_2d_21 + c3 = zero_1d + zero_2d_22 + c4 = zero_2d_21 + zero_1d + c5 = zero_2d_22 + zero_1d + + # 2D + 2D -> 2D + d1 = zero_2d_21 + zero_2d_21 + d2 = zero_2d_22 + zero_2d_22 + d3 = zero_2d_21 + zero_2d_22 + d4 = zero_2d_22 + zero_2d_21 + + # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4 + return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4 + + +def test_binop_type_check(): + kernel = triton.compiler._compile(binop_type_check, + signature="*fp32", + device=0, + output="ttgir") + assert (kernel) + # TODO: Check types of the results + + +@triton.jit +def reduce_type_check(ptr): + v_32 = tl.load(ptr + tl.arange(0, 32)) + v_scalar = tl.min(v_32, axis=0) + tl.store(ptr, v_scalar) + v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :]) + v_64 = tl.max(v_64x128, axis=1) + tl.store(ptr + tl.arange(0, 64), v_64) + v_128 = tl.max(v_64x128, axis=0) + tl.store(ptr + tl.arange(0, 128), v_128) + + +def test_reduce_type_check(): + kernel = triton.compiler._compile(reduce_type_check, + signature="*fp32", + device=0, + output="ttgir") + assert (kernel) + # TODO: Check types of the results diff --git a/python/triton/language/core.py b/python/triton/language/core.py index bf1d57ba4..b253a3491 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -241,7 +241,9 @@ class block_type(dtype): # while tensor's shape is a list of constexpr. # shape can be empty ([]) when an input is a 0D tensor. - if shape and isinstance(shape[0], constexpr): + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): shape = [s.value for s in shape] self.shape = shape diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 1c8caaefe..9bf1261fb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -991,7 +991,11 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, for i, s in enumerate(shape): if i != axis: ret_shape.append(s) - res_ty = tl.block_type(scalar_ty, ret_shape) + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty if scalar_ty.is_floating(): return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 20744ecaa..15e3fc28c 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -53,3 +53,80 @@ func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr> return } + +func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { + // Test if Load/Store ops can handle scalar values + %other = arith.constant 0.0e+0 : f32 + + // load scalar + // CHECK: %[[L0:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32 + + // store scalar + // CHECK: tt.store %{{.*}}, %[[L0]] : f32 + tt.store %ptr, %a : f32 + // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : f32 + tt.store %ptr, %b, %mask : f32 + // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : f32 + tt.store %ptr, %c, %mask : f32 + return +} + +func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { + // Test if reduce ops infer types correctly + + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> + %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32> + %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32> + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32> + %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32> + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32> + %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32> + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32> + %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32> + // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32 + %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32 + + // Avoid optimizations for c, e, and g + %ptr1x2 = tt.splat %ptr : (!tt.ptr) -> tensor<1x2x!tt.ptr> + %ptr1 = tt.splat %ptr : (!tt.ptr) -> tensor<1x!tt.ptr> + tt.store %ptr1x2, %c : tensor<1x2xf32> + tt.store %ptr1, %e : tensor<1xf32> + tt.store %ptr, %g : f32 + return +} + +func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { + // Test if reduce ops infer types correctly + %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32> + %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32> + %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32> + %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32> + + %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32> + %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32> + %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> + + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + + %ptr128x128 = tt.splat %ptr : (!tt.ptr) -> tensor<128x128x!tt.ptr> + %ptr32x32 = tt.splat %ptr : (!tt.ptr) -> tensor<32x32x!tt.ptr> + %ptr1x1 = tt.splat %ptr : (!tt.ptr) -> tensor<1x1x!tt.ptr> + tt.store %ptr128x128, %r1 : tensor<128x128xf32> + tt.store %ptr32x32, %r2 : tensor<32x32xf32> + tt.store %ptr128x128, %r3 : tensor<128x128xf32> + tt.store %ptr1x1, %r4 : tensor<1x1xf32> + return +}