[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)

Swizzling is no longer implemented as a separate pass. It is instead
done in a specialized constructor of SharedEncodingAttr, and tested via
google tests instead of triton-opt + filecheck.

In the future we may want to implement it as a pass again once we have
an additional dialect between TritonGPU and LLVM.
This commit is contained in:
Philippe Tillet
2022-11-10 12:05:46 -08:00
committed by GitHub
parent 2aa538ec2e
commit f40c63fb03
14 changed files with 170 additions and 318 deletions

View File

@@ -2707,6 +2707,7 @@ public:
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
}
@@ -5763,6 +5764,35 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
private:
void decomposeBlockedToDotOperand(ModuleOp mod) {
// replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
// because the codegen doesn't handle `blocked -> dot_op` directly
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcBlocked =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcBlocked && dstDotOp) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
srcType.getShape(),
srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
public:
ConvertTritonGPUToLLVM() = default;
@@ -5779,15 +5809,19 @@ public:
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// step 1: Allocate shared memories and insert barriers
// step 2: Convert SCF to CFG
// step 3: Convert FuncOp to LLVMFuncOp via partial conversion
// step 4: Convert the rest of ops via partial conversion
// step 1: Decompose unoptimized layout conversions to use shared memory
// step 2: Allocate shared memories and insert barriers
// step 3: Convert SCF to CFG
// step 4: Convert FuncOp to LLVMFuncOp via partial conversion
// step 5: Convert the rest of ops via partial conversion
// The reason for putting step 1 before step 2 is that the membar analysis
// currently only supports SCF but not CFG.
// The reason for a separation between 1/4 is that, step 3 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem
// is not revised during the conversion of step 4.
decomposeBlockedToDotOperand(mod);
Allocation allocation(mod);
MembarAnalysis membar(&allocation);