cleanup
@@ -13,6 +13,8 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
 
 std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
+std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
@@ -71,6 +71,19 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
   ];
 }
 
+def TritonGPUOptimizeLoadConvert: Pass<"tritongpu-optimize-load-convert", "mlir::ModuleOp"> {
+  let summary = "Optimize load + convert into insert_slice_async + wait + extract_slice + convert";
+
+  let description = "Transform load + convert into insert_slice_async + wait + extract_slice + convert. "
+                    "This decreases register pressure on architectures with direct pathways between DRAM "
+                    "and shared memory";
+
+  let constructor = "mlir::createTritonGPUOptimizeLoadConvertPass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
   let summary = "canonicalize scf.ForOp ops";
 
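For orientation, the pass defined above is created through mlir::createTritonGPUOptimizeLoadConvertPass(), declared in the header hunk at the top of this commit. A minimal C++ sketch of how a pipeline could schedule it next to the existing TritonGPU passes follows; the buildTTGIRPipeline() helper and the include path are assumptions made for this sketch, not part of the commit:

// Illustrative only: a hand-built pipeline that enables the new pass.
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"  // assumed header location

void buildTTGIRPipeline(mlir::PassManager &pm, int computeCapability) {
  pm.addPass(mlir::createTritonGPUCoalescePass());
  pm.addPass(mlir::createTritonGPUCombineOpsPass(computeCapability));
  // ... pipelining / prefetch passes elided ...
  // New in this commit: rewrite load + convert into
  // insert_slice_async + wait + extract_slice + convert.
  pm.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
}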
@@ -8,6 +8,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   Combine.cpp
   Pipeline.cpp
   Prefetch.cpp
+  OptimizeLoadConvert.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
@@ -19,7 +19,6 @@
 
 #include <memory>
 
-#define int_attr(num) rewriter.getI64IntegerAttr(num)
 
 using namespace mlir;
 namespace {
@@ -1155,60 +1154,6 @@ public:
   }
 };
 
-class LoadConvertToInsertSlice : public mlir::RewritePattern{
-
-public:
-  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {}
-
-  mlir::LogicalResult
-  matchAndRewrite(mlir::Operation *op,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
-    auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
-    auto shape = origRetType.getShape();
-    auto eltType = origRetType.getElementType();
-    auto dotOpEncoding = origRetType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-    if(!dotOpEncoding){
-      return failure();
-    }
-    auto loadOp = dyn_cast<triton::LoadOp>(*cvt.getOperand().getDefiningOp());
-    if(!loadOp){
-      return failure();
-    }
-    auto blockedEncoding = loadOp.getType().cast<RankedTensorType>().getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
-    if(!blockedEncoding)
-      return failure();
-    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), dotOpEncoding, shape,
-                                                               blockedEncoding.getOrder(), eltType);
-    auto srcTy = RankedTensorType::get({1, shape[0], shape[1]},
-                                       eltType,
-                                       sharedEncoding);
-    auto loadTensor = rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
-
-    auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
-        op->getLoc(), loadTensor.getType(),
-        loadOp.ptr(),
-        loadTensor, rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
-        loadOp.mask(),
-        loadOp.other(), loadOp.cache(),
-        loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
-
-    rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
-    auto tmpType = RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
-    auto tmp = rewriter.create<tensor::ExtractSliceOp>(op->getLoc(), tmpType, newOp,
-        SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
-        SmallVector<OpFoldResult>{int_attr(1),
-                                  int_attr(shape[0]),
-                                  int_attr(shape[1])},
-        SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
-    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType, tmp);
-    return success();
-
-  }
-
-};
-
 class FixupLoop : public mlir::RewritePattern {
 
 public:
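The LoadConvertToInsertSlice pattern removed above presumably moves into the new OptimizeLoadConvert.cpp added to the CMake list earlier in this commit; that file is not shown in this diff, so what follows is only a hedged sketch of the usual MLIR pass driver that would host such a pattern. The base-class name is guessed from the Passes.td entry, and the includes and struct name are assumptions:

// Sketch of a plausible OptimizeLoadConvert.cpp driver (file not shown in this diff).
#include <memory>
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"  // assumed header location

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"  // assumed tablegen output

namespace {
// Assumes the LoadConvertToInsertSlice pattern removed from Combine.cpp above
// is defined (or #included) in this translation unit.
struct OptimizeLoadConvertPass
    : public TritonGPUOptimizeLoadConvertBase<OptimizeLoadConvertPass> {
  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();
    mlir::ModuleOp m = getOperation();
    mlir::RewritePatternSet patterns(context);
    patterns.add<LoadConvertToInsertSlice>(context);
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<mlir::Pass> mlir::createTritonGPUOptimizeLoadConvertPass() {
  return std::make_unique<OptimizeLoadConvertPass>();
}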
@@ -1280,7 +1225,6 @@ public:
     patterns.add<MoveConvertOutOfLoop>(context);
     patterns.add<MoveConvertOutOfIf>(context);
     patterns.add<BlockedToMMA>(context, computeCapability);
-    patterns.add<LoadConvertToInsertSlice>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -1339,11 +1339,15 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUPrefetchPass());
            })
-      .def("add_triton_gpu_combine_pass",
+      .def("add_tritongpu_combine_pass",
           [](mlir::PassManager &self, int computeCapability) {
             self.addPass(
                 mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
+      .def("add_tritongpu_optimize_load_convert_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
+           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
@@ -894,17 +894,18 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_coalesce_pass()
     # The combine pass converts blocked layout to mma layout
     # for dot ops so that pipeline can get shared memory swizzled correctly.
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
     # extracts slices from the original tensor.
     pm.add_tritongpu_prefetch_pass()
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_licm_pass()
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
+    pm.add_tritongpu_optimize_load_convert_pass()
     pm.run(mod)
     return mod
 
@@ -191,7 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
-_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
+# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
 
 empty = torch.empty(128, device="cuda")
 
@@ -210,28 +210,28 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
 
-        # _fwd_kernel[grid](
-        #     q, k, v, sm_scale,
-        #     L, m,
-        #     o,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        #     o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-        #     BLOCK_DMODEL=Lk, num_warps=num_warps,
-        #     num_stages=1,
-        # )
         _fwd_kernel[grid](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
-            L.data_ptr(), m.data_ptr(),
-            o.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            q.shape[0], q.shape[1], q.shape[2])
+            q, k, v, sm_scale,
+            L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            num_stages=1,
+        )
+        # _fwd_kernel[grid](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
+        #     L.data_ptr(), m.data_ptr(),
+        #     o.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
+        #     o.stride(0), o.stride(1), o.stride(2),
+        #     q.shape[0], q.shape[1], q.shape[2])
 
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK