Simple assert
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
@@ -270,6 +270,46 @@ struct PrintfOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct AssertOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto ctx = rewriter.getContext();
|
||||
auto voidTy = void_ty(ctx);
|
||||
auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
|
||||
Value ret;
|
||||
for (auto elem : elems) {
|
||||
auto type = elem.getType();
|
||||
Value condition;
|
||||
if (type.isIntOrFloat()) {
|
||||
if (type.isSignedInteger() || type.isSignlessInteger()) {
|
||||
condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
|
||||
loc, type, rewriter.getZeroAttr(type)));
|
||||
} else {
|
||||
condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
|
||||
loc, type, rewriter.getZeroAttr(type)));
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported type for assert");
|
||||
return failure();
|
||||
}
|
||||
// MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
|
||||
// handled by ptxas
|
||||
PTXBuilder builder;
|
||||
auto &trapOp = *builder.create<PTXInstr>("trap");
|
||||
trapOp().predicate(condition);
|
||||
ret = builder.launch(rewriter, loc, voidTy);
|
||||
}
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MakeRangeOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
||||
|
||||
@@ -524,4 +564,5 @@ void populateTritonGPUToLLVMPatterns(
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AssertOpConversion>(typeConverter, benefit);
|
||||
}
|
@@ -45,6 +45,9 @@
|
||||
#define fcmp_olt(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::olt, lhs, rhs)
|
||||
#define fcmp_eq(lhs, rhs) \
|
||||
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
|
||||
LLVM::FCmpPredicate::oeq, lhs, rhs)
|
||||
#define icmp_eq(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
|
||||
#define icmp_ne(...) \
|
||||
@@ -77,6 +80,7 @@
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
#define bf16_ty rewriter.getBF16Type()
|
||||
#define i8_ty rewriter.getIntegerType(8)
|
||||
#define i1_ty rewriter.getI1Type()
|
||||
#define f32_ty rewriter.getF32Type()
|
||||
#define f64_ty rewriter.getF64Type()
|
||||
#define vec_ty(type, num) VectorType::get(num, type)
|
||||
|
@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
};
|
||||
|
||||
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
|
||||
using OpConversionPattern<PrintfOp>::OpConversionPattern;
|
||||
using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
|
||||
matchAndRewrite(triton::PrintfOp op,
|
||||
typename triton::PrintfOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
|
||||
adaptor.getOperands());
|
||||
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
|
||||
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::AssertOp op,
|
||||
typename triton::AssertOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
|
||||
op.messageAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
||||
TritonAtomicRMWPattern>(typeConverter, context);
|
||||
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user