Simple assert

This commit is contained in:
Jokeren
2023-01-05 15:04:08 -05:00
parent bc73bbb12c
commit 2920f6f50f
10 changed files with 112 additions and 7 deletions

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"

View File

@@ -270,6 +270,46 @@ struct PrintfOpConversion
}
};
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto voidTy = void_ty(ctx);
auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
Value ret;
for (auto elem : elems) {
auto type = elem.getType();
Value condition;
if (type.isIntOrFloat()) {
if (type.isSignedInteger() || type.isSignlessInteger()) {
condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
} else {
condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
}
} else {
assert(false && "Unsupported type for assert");
return failure();
}
// MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
// handled by ptxas
PTXBuilder builder;
auto &trapOp = *builder.create<PTXInstr>("trap");
trapOp().predicate(condition);
ret = builder.launch(rewriter, loc, voidTy);
}
rewriter.replaceOp(op, ret);
return success();
}
};
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +564,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}

View File

@@ -45,6 +45,9 @@
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
@@ -77,6 +80,7 @@
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)

View File

@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
};
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern;
using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
matchAndRewrite(triton::PrintfOp op,
typename triton::PrintfOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
}
};
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op,
typename triton::AssertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
op.messageAttr());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
}
//