[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead done in a specialized constructor of SharedEncodingAttr and tested via googletest instead of triton-opt + FileCheck. In the future we may want to implement it as a pass again, once we have an additional dialect between TritonGPU and LLVM.
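For context, a minimal sketch of what such a googletest unit test could look like. The SharedEncodingAttr::get overload mirrors the one used in the diff below; the MmaEncodingAttr parameters, the getVec()/getMaxPhase() accessor names, and the expected bounds are illustrative assumptions, not the commit's actual test:

// Hypothetical sketch, not the commit's actual unit test.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <gtest/gtest.h>

using namespace mlir;

TEST(SwizzleTest, SharedEncodingForDotOperand) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<triton::gpu::TritonGPUDialect>();
  // Build a dot-operand encoding (operand 0 of an mma); the specialized
  // constructor derives the swizzling parameters from this encoding, the
  // tensor shape, and the element type.
  auto mma = triton::gpu::MmaEncodingAttr::get(&ctx, /*version=*/2,
                                               /*warpsPerCTA=*/{1, 1});
  auto dotOp =
      triton::gpu::DotOperandEncodingAttr::get(&ctx, /*opIdx=*/0, mma);
  auto shared = triton::gpu::SharedEncodingAttr::get(
      &ctx, dotOp, /*shape=*/{128, 64}, FloatType::getF16(&ctx));
  // An fp16 mma operand is expected to get a non-trivial swizzle.
  EXPECT_GT(shared.getVec(), 1u);
  EXPECT_GT(shared.getMaxPhase(), 1u);
}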
@@ -2707,6 +2707,7 @@ public:
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     // TODO: to be implemented
+    llvm_unreachable("unsupported layout conversion");
     return failure();
   }
 
@@ -5763,6 +5764,35 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
 
 class ConvertTritonGPUToLLVM
     : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
+
+private:
+  void decomposeBlockedToDotOperand(ModuleOp mod) {
+    // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
+    // because the codegen doesn't handle `blocked -> dot_op` directly
+    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
+      OpBuilder builder(cvtOp);
+      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
+      auto dstType = cvtOp.getType().cast<RankedTensorType>();
+      auto srcBlocked =
+          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto dstDotOp =
+          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if (srcBlocked && dstDotOp) {
+        auto tmpType = RankedTensorType::get(
+            dstType.getShape(), dstType.getElementType(),
+            triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
+                                                 srcType.getShape(),
+                                                 srcType.getElementType()));
+        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
+        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), dstType, tmp);
+        cvtOp.replaceAllUsesWith(newConvert.getResult());
+        cvtOp.erase();
+      }
+    });
+  }
+
 public:
   ConvertTritonGPUToLLVM() = default;
 
@@ -5779,15 +5809,19 @@ public:
 
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
-    // step 1: Allocate shared memories and insert barriers
-    // step 2: Convert SCF to CFG
-    // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 4: Convert the rest of ops via partial conversion
+    // step 1: Decompose unoptimized layout conversions to use shared memory
+    // step 2: Allocate shared memories and insert barriers
+    // step 3: Convert SCF to CFG
+    // step 4: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 5: Convert the rest of ops via partial conversion
     // The reason for putting step 1 before step 2 is that the membar analysis
     // currently only supports SCF but not CFG.
     // The reason for a separation between 1/4 is that, step 3 is out of
     // the scope of Dialect Conversion, thus we need to make sure the smem
     // is not revised during the conversion of step 4.
 
+    decomposeBlockedToDotOperand(mod);
+
     Allocation allocation(mod);
     MembarAnalysis membar(&allocation);
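At the IR level, the decomposeBlockedToDotOperand rewrite above splits one unsupported conversion into two supported ones, staging the data through the swizzled shared layout. A hand-written TritonGPU IR sketch (shapes and encoding names are illustrative placeholders, not compiler output):

// Before: direct blocked -> dot_op conversion, which codegen cannot lower.
%d = triton_gpu.convert_layout %b
       : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #dot_op>

// After: staged through a swizzled shared-memory layout.
%tmp = triton_gpu.convert_layout %b
       : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
%d   = triton_gpu.convert_layout %tmp
       : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #dot_op>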