Compare commits

phil/swizz ... keren/asse

6 Commits

Author | SHA1 | Date
---|---|---
 | 4037f3b921 |
 | fcff1a6e75 |
 | 2920f6f50f |
 | bc73bbb12c |
 | 8460ea3df1 |
 | 678b9f53a2 |
@@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
 }

 def TT_TransOp : TT_Op<"trans", [NoSideEffect,
-                                 SameOperandsAndResultElementType]> {
+                                 SameOperandsAndResultElementType]> {

   let summary = "transpose a tensor";

@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
 // Make PrintfOp
 //
 def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
-                  Arguments<(ins StrAttr:$prefix,
-                                 Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
+                  Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
   let summary = "Device-side printf, as in CUDA for debugging";
   let description = [{
     `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,14 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
   }];
 }

+//
+// Make AssertOp
+//
+def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Device-side assert, as in CUDA for debugging";
+  let description = [{}];
+  let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+  let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
+}
+
 #endif // Triton_OPS
@@ -25,13 +25,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
   if (maybeSharedAllocationOp(op)) {
     // These ops may allocate a new shared memory buffer.
     auto result = op->getResult(0);
-    // FIXME(Keren): extract and insert are always alias for now
+    // XXX(Keren): the following ops are always aliasing for now
     if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
       // extract_slice %src
+      // trans %src
       aliasInfo = AliasInfo(operands[0]->getValue());
       pessimistic = false;
-    } else if (isa<tensor::InsertSliceOp>(op) ||
-               isa<triton::gpu::InsertSliceAsyncOp>(op)) {
+    } else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
+                   op)) {
       // insert_slice_async %src, %dst, %index
       // insert_slice %src into %dst[%offsets]
       aliasInfo = AliasInfo(operands[1]->getValue());
@@ -298,10 +298,24 @@ private:

   /// Resolves liveness of all values involved under the root operation.
   void resolveLiveness() {
-    // In the SCF dialect, we always have a sequentially nested structure of
-    // blocks
+    // Assign an ID to each operation using post-order traversal.
+    // To achieve the correct liveness range, the parent operation's ID
+    // should be greater than each of its child operations' IDs.
+    // Example:
+    //     ...
+    //     %5 = triton.convert_layout %4
+    //     %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
+    //       %2 = triton.convert_layout %5
+    //       ...
+    //       scf.yield %arg0
+    //     }
+    // In this example, %5 is defined in the parent region and used in
+    // the child region, but is not passed in as a block argument.
+    // %6 must have an ID greater than each of its child operations,
+    // otherwise %5's liveness range would end before the child
+    // operations' liveness ranges end.
     DenseMap<Operation *, size_t> operationId;
-    operation->walk<WalkOrder::PreOrder>(
+    operation->walk<WalkOrder::PostOrder>(
        [&](Operation *op) { operationId[op] = operationId.size(); });

    // Analyze liveness of explicit buffers
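The new comment argues that a post-order walk numbers every child before its enclosing operation, so the parent's ID is the largest in its subtree and values defined outside the region stay live across it. A minimal sketch of that numbering in plain Python (illustrative only, not Triton or MLIR code):

```python
# Illustrative only: mimic a WalkOrder::PostOrder walk assigning increasing IDs.
ids = {}

def assign_ids(op):
    # number all children first, then the op itself
    for child in op.get("children", []):
        assign_ids(child)
    ids[op["name"]] = len(ids)

scf_for = {"name": "scf.for",
           "children": [{"name": "triton.convert_layout"}, {"name": "scf.yield"}]}
assign_ids(scf_for)
print(ids)  # {'triton.convert_layout': 0, 'scf.yield': 1, 'scf.for': 2}
```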
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -270,6 +270,48 @@ struct PrintfOpConversion
   }
 };

+struct AssertOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = rewriter.getContext();
+    auto voidTy = void_ty(ctx);
+    auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
+    Value ret;
+    for (auto elem : elems) {
+      auto type = elem.getType();
+      Value condition;
+      if (type.isIntOrFloat()) {
+        if (type.isSignedInteger() || type.isSignlessInteger()) {
+          condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                        loc, type, rewriter.getZeroAttr(type)));
+        } else {
+          condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                        loc, type, rewriter.getZeroAttr(type)));
+        }
+      } else {
+        assert(false && "Unsupported type for assert");
+        return failure();
+      }
+      // MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
+      // handled by ptxas.
+      // We should call __assertfail here instead, then delete the definition
+      // of triton.assert and use mlir.assert.
+      PTXBuilder builder;
+      auto &trapOp = *builder.create<PTXInstr>("trap");
+      trapOp().predicate(condition);
+      ret = builder.launch(rewriter, loc, voidTy);
+    }
+    rewriter.replaceOp(op, ret);
+    return success();
+  }
+};
+
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +566,5 @@ void populateTritonGPUToLLVMPatterns(
   patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
+  patterns.add<AssertOpConversion>(typeConverter, benefit);
 }
@@ -45,6 +45,9 @@
 #define fcmp_olt(lhs, rhs)                                                    \
   rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                    \
                                 LLVM::FCmpPredicate::olt, lhs, rhs)
+#define fcmp_eq(lhs, rhs)                                                     \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                    \
+                                LLVM::FCmpPredicate::oeq, lhs, rhs)
 #define icmp_eq(...)                                                          \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
 #define icmp_ne(...)                                                          \
@@ -77,6 +80,7 @@
 #define f16_ty rewriter.getF16Type()
 #define bf16_ty rewriter.getBF16Type()
 #define i8_ty rewriter.getIntegerType(8)
+#define i1_ty rewriter.getI1Type()
 #define f32_ty rewriter.getF32Type()
 #define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
 };

 struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
-  using OpConversionPattern<PrintfOp>::OpConversionPattern;
+  using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;

   LogicalResult
-  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
+  matchAndRewrite(triton::PrintfOp op,
+                  typename triton::PrintfOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                   adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
   }
 };

+struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
+  using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op,
+                  typename triton::AssertOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
+                                                  op.messageAttr());
+    return success();
+  }
+};
+
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
       TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
       TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
-      TritonAtomicRMWPattern>(typeConverter, context);
+      TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
 }

 //
@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
                                              llvm::StringRef(prefix)),
                      values);
            })
+      .def("create_assert",
+           [](mlir::OpBuilder &self, mlir::Value &condition,
+              const std::string &message) -> void {
+             auto loc = self.getUnknownLoc();
+             auto messageAttr = mlir::StringAttr::get(self.getContext(),
+                                                      llvm::StringRef(message));
+             self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
+           })
       // Undef
       .def("create_undef",
            [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
@@ -52,5 +52,21 @@ def printf(data_type):
     assert_close(y, x)


-printf("float16")
-printf("int8")
+def assert2(data_type):
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.assert2(x == 0, "x > 0")
+        tl.store(Y + tl.arange(0, BLOCK), x)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x = get_tensor(shape, data_type)
+    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0])
+    assert_close(y, x)
+
+
+#printf("float16")
+#printf("int8")
+assert2("float16")
@@ -5,11 +5,13 @@ from ..impl import (
     ir,
     builtin,
 )
+from . import libdevice
 from .core import (
     abs,
     arange,
     argmin,
     argmax,
+    assert2,
     atomic_add,
     atomic_and,
     atomic_cas,
@@ -97,6 +99,7 @@ __all__ = [
     "arange",
     "argmin",
     "argmax",
+    "assert2",
     "atomic_add",
     "atomic_and",
     "atomic_cas",
@@ -130,6 +133,7 @@ __all__ = [
     "int64",
     "int8",
     "ir",
+    "libdevice",
     "load",
     "log",
     "max",
@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
     for arg in args:
         new_args.append(_to_tensor(arg, _builder))
     return semantic.printf(new_prefix, new_args, _builder)
+
+
+@builtin
+def assert2(cond, msg="", _builder=None):
+    msg = _constexpr_to_value(msg)
+    return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)
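A quick usage sketch of the new builtin, mirroring the `assert2` test added earlier in this diff (the kernel, tensor contents, and shapes here are illustrative, not part of the change; a CUDA device is assumed):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.assert2(x == 0, "x > 0")   # lowers to tt.assert, then to a predicated PTX trap
    tl.store(Y + tl.arange(0, BLOCK), x)

x = torch.zeros((128,), dtype=torch.float16, device="cuda")  # all zeros, so the assert holds
y = torch.empty_like(x)
kernel[(1,)](x, y, BLOCK=128)
```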
@@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     if INT_OP in int_op_to_unit:
         INT_OP = int_op_to_unit[INT_OP]

+    # If we are doing an argmin or argmax we want to use an int32 output type
+    out_scalar_ty = scalar_ty
+    if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
+        out_scalar_ty = tl.int32
+    elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
+        out_scalar_ty = tl.int32
+
     # get result type
     shape = input.type.shape
     ret_shape = []
@@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
         if i != axis:
             ret_shape.append(s)
     if ret_shape:
-        res_ty = tl.block_type(scalar_ty, ret_shape)
+        res_ty = tl.block_type(out_scalar_ty, ret_shape)
     else:
         # 0d-tensor -> scalar
-        res_ty = scalar_ty
+        res_ty = out_scalar_ty

     if scalar_ty.is_floating():
         return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
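The net effect of the two hunks above is that `argmin`/`argmax` reductions now produce `int32` results instead of reusing the input element type. A small sketch of what that means at the `tl` level (kernel, shapes, and the scalar store are illustrative assumptions, not part of the diff):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def argmax_kernel(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    idx = tl.argmax(x, axis=0)   # element type of idx is now tl.int32, regardless of x's dtype
    tl.store(Out, idx)

x = torch.randn((64,), dtype=torch.float16, device="cuda")
out = torch.empty((1,), dtype=torch.int32, device="cuda")
argmax_kernel[(1,)](x, out, BLOCK=64)
```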
@@ -1163,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
     for arg in args:
         new_args.append(arg.handle)
     return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
+
+
+def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
+    return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
@@ -52,6 +52,15 @@ func @convert(%A : !tt.ptr<f16>) {
   return
 }

+// CHECK-LABEL: trans
+func @trans(%A : !tt.ptr<f16>) {
+  // CHECK: %cst -> %cst
+  %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
+  // CHECK: %0 -> %cst
+  %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+  return
+}
+
 // CHECK-LABEL: insert_slice_async
 func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
@@ -174,6 +174,14 @@ func @scratch() {
   // CHECK-NEXT: size = 512
 }

+// CHECK-LABEL: trans
+func @trans(%A : !tt.ptr<f16>) {
+  // CHECK: offset = 0, size = 1024
+  %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
+  %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+  return
+}
+
 // CHECK-LABEL: insert_slice_async
 func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
@@ -285,6 +293,25 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
   // CHECK-NEXT: size = 24576
 }

+// c0 cannot be released in the loop
+// CHECK-LABEL: for_use_ancestor
+func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+  // CHECK: offset = 0, size = 8192
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: offset = 8192, size = 8192
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: offset = 16384, size = 8192
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    %c0 = tt.trans %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<32x128xf16, #A_SHARED>
+    // CHECK-NEXT: offset = 24576, size = 8192
+    %c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+    scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  return
+  // CHECK-NEXT: size = 32768
+}
+
 // a_shared_init, b_shared_init, and c_shared_init's liveness ranges span the entire function before cst2.
 // So they cannot be reused by cst0 and cst1, but can be reused by cst2.
 // CHECK-LABEL: for_if_for
@@ -111,6 +111,13 @@ func @extract_slice() {
   return
 }

+// CHECK-LABEL: trans
+func @trans() {
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
+  %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+  return
+}
+
 // CHECK-LABEL: insert_slice_async
 func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>