Simple assert

This commit is contained in:
Jokeren
2023-01-05 15:04:08 -05:00
parent bc73bbb12c
commit 2920f6f50f
10 changed files with 112 additions and 7 deletions

View File

@@ -270,6 +270,46 @@ struct PrintfOpConversion
}
};
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto voidTy = void_ty(ctx);
auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
Value ret;
for (auto elem : elems) {
auto type = elem.getType();
Value condition;
if (type.isIntOrFloat()) {
if (type.isSignedInteger() || type.isSignlessInteger()) {
condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
} else {
condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
}
} else {
assert(false && "Unsupported type for assert");
return failure();
}
// MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
// handled by ptxas
PTXBuilder builder;
auto &trapOp = *builder.create<PTXInstr>("trap");
trapOp().predicate(condition);
ret = builder.launch(rewriter, loc, voidTy);
}
rewriter.replaceOp(op, ret);
return success();
}
};
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +564,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}