6 Commits

Author SHA1 Message Date
Jokeren
4037f3b921 Add comment 2023-01-05 16:09:44 -05:00
Jokeren
fcff1a6e75 Add comment 2023-01-05 16:09:03 -05:00
Jokeren
2920f6f50f Simple assert 2023-01-05 15:04:08 -05:00
Sharad Vikram
bc73bbb12c [FRONTEND] Fix argmin/max output type (#1012)
Currently, Triton returns tensors with the input type rather than i32
when doing argmax/argmin reductions.
2023-01-03 23:12:16 -08:00
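As an illustration of the fixed behavior, a minimal sketch (hypothetical kernel and buffer names) showing that the index produced by tl.argmax is now i32 regardless of the input dtype:

import torch
import triton
import triton.language as tl

@triton.jit
def argmax_kernel(X, OUT, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    idx = tl.argmax(x, axis=0)  # index of the maximum element
    tl.store(OUT, idx)

x = torch.randn(128, device="cuda", dtype=torch.float16)
out = torch.zeros(1, dtype=torch.int32, device="cuda")
argmax_kernel[(1,)](x, out, BLOCK=128)
# Before this fix the reduction result carried the input type (f16 here);
# with the fix it is an int32 index.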
Keren Zhou
8460ea3df1 [Frontend] Fix import for libdevice (#1028)
This is a hotfix for issue 1 in
https://github.com/openai/triton/issues/1017
2023-01-03 15:48:05 -08:00
Keren Zhou
678b9f53a2 [Backend] Use post-order traversal for liveness numbering (#1027)
Also add tests for `tt.trans`.
2023-01-03 15:11:54 -08:00
15 changed files with 190 additions and 16 deletions

View File

@@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";
@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
// Make PrintfOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,14 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
}];
}
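As a usage sketch, relying on the tl.printf builtin that already exists in the Python frontend, a kernel reaches this op roughly as follows:

import triton
import triton.language as tl

@triton.jit
def debug_kernel(X, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.printf("x = ", x)  # literal string prefix plus a tensor argument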
//
// Make AssertOp
//
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let summary = "Device-side assert, as in CUDA for debugging";
let description = [{}];
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}
#endif // Triton_OPS
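A rough sketch of how the new op is reached from the Python frontend, via the tl.assert2 builtin added later in this change (kernel name and shapes are illustrative):

import triton
import triton.language as tl

@triton.jit
def checked_kernel(X, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # Builds a tt.assert on the boolean tensor; the device traps if any
    # element of the condition is zero.
    tl.assert2(x >= 0, "x must be non-negative")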

View File

@@ -25,13 +25,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (maybeSharedAllocationOp(op)) {
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
// XXX(Keren): the following ops are always aliasing for now
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());

View File

@@ -298,10 +298,24 @@ private:
/// Resolves liveness of all values involved under the root operation.
void resolveLiveness() {
// In the SCF dialect, we always have a sequentially nested structure of
// blocks
// Assign an ID to each operation using post-order traversal.
// To achieve the correct liveness ranges, the parent operation's ID
// should be greater than each of its child operations' IDs.
// Example:
// ...
// %5 = triton.convert_layout %4
// %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
// %2 = triton.convert_layout %5
// ...
// scf.yield %arg0
// }
// Here, %5 is defined in the parent region and used in the child
// region without being passed in as a block argument. %6 should have an
// ID greater than each of its child operations; otherwise %5's liveness
// range would end before the child operations' liveness ranges end.
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PreOrder>(
operation->walk<WalkOrder::PostOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
// Analyze liveness of explicit buffers
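As an aside, a small self-contained sketch (plain Python with a hypothetical Op class, not the MLIR API) of why a post-order walk gives every parent an ID larger than all of its children's, which is what keeps %5 in the example above live across the whole scf.for:

class Op:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def walk_post_order(op, visit):
    # Visit nested operations before the operation that contains them.
    for child in op.children:
        walk_post_order(child, visit)
    visit(op)

ids = {}
loop = Op("scf.for", [Op("triton.convert_layout"), Op("scf.yield")])
walk_post_order(loop, lambda op: ids.setdefault(op.name, len(ids)))
print(ids)  # {'triton.convert_layout': 0, 'scf.yield': 1, 'scf.for': 2}
# A value used inside the loop body therefore stays live up to the loop's
# own (largest) ID instead of ending at the smaller ID of the body operation.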

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"

View File

@@ -270,6 +270,48 @@ struct PrintfOpConversion
}
};
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto voidTy = void_ty(ctx);
auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
Value ret;
for (auto elem : elems) {
auto type = elem.getType();
Value condition;
if (type.isIntOrFloat()) {
if (type.isSignedInteger() || type.isSignlessInteger()) {
condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
} else {
condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
}
} else {
assert(false && "Unsupported type for assert");
return failure();
}
// mlir::AssertOp is lowered to a call to llvm.abort, which ptxas cannot
// handle, so we emit a predicated PTX trap here instead. Ideally we would
// call __assertfail.
// TODO: delete the definition of triton.assert and use mlir.assert once
// that is possible.
PTXBuilder builder;
auto &trapOp = *builder.create<PTXInstr>("trap");
trapOp().predicate(condition);
ret = builder.launch(rewriter, loc, voidTy);
}
rewriter.replaceOp(op, ret);
return success();
}
};
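Conceptually, an illustrative Python model (not the actual lowering) of what this conversion implements: trap as soon as any element of the condition is false:

def device_assert(condition_elems, message):
    # Mirrors the per-element icmp_eq / fcmp_eq against zero followed by a
    # predicated PTX trap.
    for elem in condition_elems:
        if elem == 0:
            raise RuntimeError(message)  # stand-in for the device-side trap

device_assert([1, 1, 1], "all elements must be non-zero")    # passes
# device_assert([1, 0, 1], "all elements must be non-zero")  # would trap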
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +566,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}

View File

@@ -45,6 +45,9 @@
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
@@ -77,6 +80,7 @@
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)

View File

@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
};
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern;
using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
matchAndRewrite(triton::PrintfOp op,
typename triton::PrintfOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
}
};
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op,
typename triton::AssertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
op.messageAttr());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
}
//

View File

@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
llvm::StringRef(prefix)),
values);
})
.def("create_assert",
[](mlir::OpBuilder &self, mlir::Value &condition,
const std::string &message) -> void {
auto loc = self.getUnknownLoc();
auto messageAttr = mlir::StringAttr::get(self.getContext(),
llvm::StringRef(message));
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {

View File

@@ -52,5 +52,21 @@ def printf(data_type):
assert_close(y, x)
printf("float16")
printf("int8")
def assert2(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.assert2(x == 0, "x > 0")
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
#printf("float16")
#printf("int8")
assert2("float16")

View File

@@ -5,11 +5,13 @@ from ..impl import (
ir,
builtin,
)
from . import libdevice
from .core import (
abs,
arange,
argmin,
argmax,
assert2,
atomic_add,
atomic_and,
atomic_cas,
@@ -97,6 +99,7 @@ __all__ = [
"arange",
"argmin",
"argmax",
"assert2",
"atomic_add",
"atomic_and",
"atomic_cas",
@@ -130,6 +133,7 @@ __all__ = [
"int64",
"int8",
"ir",
"libdevice",
"load",
"log",
"max",

View File

@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
@builtin
def assert2(cond, msg="", _builder=None):
msg = _constexpr_to_value(msg)
return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)

View File

@@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# If we are doing an argmin or argmax we want to use an int32 output type
out_scalar_ty = scalar_ty
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
ret_shape = []
@@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if i != axis:
ret_shape.append(s)
if ret_shape:
res_ty = tl.block_type(scalar_ty, ret_shape)
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = scalar_ty
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
@@ -1163,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)

View File

@@ -52,6 +52,15 @@ func @convert(%A : !tt.ptr<f16>) {
return
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK: %0 -> %cst
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

View File

@@ -174,6 +174,14 @@ func @scratch() {
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
@@ -285,6 +293,25 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
// CHECK-NEXT: size = 24576
}
// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 16384, size = 8192
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c0 = tt.trans %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<32x128xf16, #A_SHARED>
// CHECK-NEXT: offset = 24576, size = 8192
%c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
// CHECK-NEXT: size = 32768
}
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges span
// the entire function before cst2, so they cannot be reused by cst0 and
// cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for

View File

@@ -111,6 +111,13 @@ func @extract_slice() {
return
}
// CHECK-LABEL: trans
func @trans() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>