[BACKEND][FRONTEND] Fix problems with test_matmul (#973)
1. Handle the induction variable when the loop step is negative
2. Restore an async_wait that was accidentally deleted
3. Add the missing induction variable mapping in the prefetcher
4. Add device property functions

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
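Of the four fixes, only 2 and 3 are visible in the hunks below; the negative-step and device-property changes are not shown. As a rough illustration of the issue behind fix 1, a loop trip count has to account for the sign of the step, since dividing `ub - lb` by a negative step directly yields a nonsensical count. A minimal sketch, assuming a signed loop of the form `for (i = lb; step > 0 ? i < ub : i > ub; i += step)` (hypothetical helper, not the actual patch):

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper: number of iterations of
    //   for (i = lb; step > 0 ? i < ub : i > ub; i += step)
    // Dividing (ub - lb) by step without checking its sign
    // mishandles negative steps.
    int64_t tripCount(int64_t lb, int64_t ub, int64_t step) {
      assert(step != 0 && "loop step must be non-zero");
      if (step > 0)
        return lb < ub ? (ub - lb + step - 1) / step : 0;
      // Negative step: count down from lb towards ub by |step|.
      return lb > ub ? (lb - ub + (-step) - 1) / (-step) : 0;
    }

    int main() {
      assert(tripCount(0, 10, 2) == 5);  // 0, 2, 4, 6, 8
      assert(tripCount(10, 0, -2) == 5); // 10, 8, 6, 4, 2
      assert(tripCount(0, 10, -1) == 0); // empty loop
      return 0;
    }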
@@ -4783,10 +4783,15 @@ private:
       decomposed = true;
     });
 
-    // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
-          decomposed) {
+      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
+        // async wait is supported in Ampere and later
         asyncWaitOp.erase();
+      } else if (decomposed) {
+        // Wait for all previous async ops
+        OpBuilder builder(asyncWaitOp);
+        auto newAsyncWaitOp =
+            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
+        asyncWaitOp.erase();
+      }
     });
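This hunk is the restored async_wait (fix 2): the old condition erased the wait whenever async copies had been decomposed, even on hardware that supports async waits; the new code erases it only when the target cannot execute it at all, and otherwise rebuilds it as a wait on group 0, which per the in-diff comment waits for all previous async ops. A minimal sketch (not the pass itself) of that decision, with `supportedOnTarget` and `decomposed` standing in for `AsyncWaitOp::isSupported(computeCapability)` and the pass's flag:

    #include <cassert>

    // Sketch of the decision encoded in the hunk above.
    enum class WaitAction { Erase, WaitForAll, Keep };

    WaitAction classifyAsyncWait(bool supportedOnTarget, bool decomposed) {
      if (!supportedOnTarget)
        return WaitAction::Erase;      // async wait only exists on Ampere and later
      if (decomposed)
        return WaitAction::WaitForAll; // rebuild as AsyncWaitOp with num = 0
      return WaitAction::Keep;         // leave the original wait in place
    }

    int main() {
      assert(classifyAsyncWait(false, true) == WaitAction::Erase);
      assert(classifyAsyncWait(true, true) == WaitAction::WaitForAll);
      assert(classifyAsyncWait(true, false) == WaitAction::Keep);
      return 0;
    }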
@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
     // For now, this behaves like generic, but this will evolve when
     // we add support for `can_reorder=False`
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
+                                               adaptor.getOperands());
     return success();
   }
 
 };
 
 struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern<triton::IntToPtrOp>,
       TritonGenericPattern<triton::PtrToIntOp>,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
-      TritonGenericPattern<triton::AddPtrOp>,
-      TritonCatPattern,
-      TritonReducePattern,
-      TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
-      TritonDotPattern, TritonLoadPattern, TritonStorePattern,
-      TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
-      typeConverter, context);
+      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
+      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+      TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
@@ -782,8 +782,8 @@ public:
         newRetType.getEncoding()));
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
-    auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
+    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
+                                                 b, newAcc, dotOp.allowTF32());
 
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
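Fix 3 is the one added line above: `builder.clone(op, mapping)` substitutes every operand that has an entry in the mapping, so without mapping the old loop's induction variable to the new one, cloned body ops would keep referencing the induction variable of the loop that is about to be replaced. A toy model of that remap-on-clone behavior (the `Value`/`Op`/`cloneOp` names are illustrative, not MLIR's API):

    #include <cassert>
    #include <unordered_map>
    #include <vector>

    // Toy model of BlockAndValueMapping-based cloning: each operand of a
    // cloned op is looked up in the mapping and replaced when present.
    struct Value { int id; };

    struct Op { std::vector<Value *> operands; };

    Op cloneOp(const Op &op,
               const std::unordered_map<Value *, Value *> &mapping) {
      Op clone = op;
      for (Value *&operand : clone.operands) {
        auto it = mapping.find(operand);
        if (it != mapping.end())
          operand = it->second; // remap to the new loop's value
      }
      return clone;
    }

    int main() {
      Value oldIv{0}, newIv{1};
      Op useOfIv{{&oldIv}};
      std::unordered_map<Value *, Value *> mapping;
      // Without this entry (the bug fixed above), the clone would keep
      // pointing at the old loop's induction variable.
      mapping[&oldIv] = &newIv;
      Op cloned = cloneOp(useOfIv, mapping);
      assert(cloned.operands[0] == &newIv);
      return 0;
    }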