more pass template
@@ -17,6 +17,8 @@ std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
 
 std::unique_ptr<Pass> createTritonGPUSinkConversionsFromSharedPass();
 
+std::unique_ptr<Pass> createTritonGPUDecomposeConversionsToDotOperandPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
@@ -96,6 +96,17 @@ def TritonGPUSinkConversionsFromShared: Pass<"tritongpu-sink-conversions-from-sh
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUDecomposeConversionsToDotOperand: Pass<"tritongpu-decompose-conversions-to-dot-operand", "mlir::ModuleOp"> {
+  let summary = "Decompose convert[distributed -> dotOperand] into convert[distributed -> shared -> dotOperand]";
+
+  let description = "Decomposing conversions this way makes it possible to use CSE and re-use #shared tensors";
+
+  let constructor = "mlir::createTritonGPUDecomposeConversionsToDotOperandPass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
 
   let summary = "canonicalize scf.ForOp ops";
@@ -72,24 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value staIdx1 = i32_val(0);
   Value stride0 = dstStrides[outOrd[0]];
   Value stride1 = dstStrides[outOrd[1]];
-  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-    if (auto cstRhs =
-            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-      unsigned rhsVal =
-          cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
-      unsigned key = (rhsVal / outVec) % maxPhase;
-      if (cache.find(key) == cache.end())
-        cache[key] = dynIdx0;
-      dynIdx0 = cache[key];
-      staIdx0 =
-          i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
-    }
-  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
-    if (auto cstRhs =
-            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-      dynIdx1 = addOp.getLhs();
-      staIdx1 = addOp.getRhs();
-    }
+  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+  //   if (auto cstRhs =
+  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+  //     unsigned rhsVal =
+  //         cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+  //     unsigned key = (rhsVal / outVec) % maxPhase;
+  //     if (cache.find(key) == cache.end())
+  //       cache[key] = dynIdx0;
+  //     dynIdx0 = cache[key];
+  //     staIdx0 =
+  //         i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
+  //   }
+  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
+  //   if (auto cstRhs =
+  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+  //     dynIdx1 = addOp.getLhs();
+  //     staIdx1 = addOp.getRhs();
+  //   }
 
   // offset along non-contiguous dimension
   Value off1 = mul(dynIdx1, stride1);
@@ -10,6 +10,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   Prefetch.cpp
   OptimizeLoadConvert.cpp
   SinkConversionsFromShared.cpp
+  DecomposeConversionsToDotOperand.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
@@ -19,7 +19,6 @@
 
 #include <memory>
 
-
 using namespace mlir;
 namespace {
 #include "TritonGPUCombine.inc"
@@ -483,8 +482,7 @@ public:
       return op->getBlock() == cvt->getBlock() &&
              !(isa<triton::ReduceOp>(op) &&
                !op->getResult(0).getType().isa<RankedTensorType>()) &&
-             !isa<triton::gpu::ConvertLayoutOp>(op) &&
-             !isa<scf::YieldOp>(op);
+             !isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
     };
     mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
     if (cvtSlices.empty())
@@ -0,0 +1,36 @@
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+using namespace mlir;
+
+class TritonGPUDecomposeConversionsToDotOperandPass
+    : public TritonGPUDecomposeConversionsToDotOperandBase<
+          TritonGPUDecomposeConversionsToDotOperandPass> {
+public:
+  TritonGPUDecomposeConversionsToDotOperandPass() = default;
+
+  void runOnOperation() override { return; }
+};
+
+std::unique_ptr<Pass>
+mlir::createTritonGPUDecomposeConversionsToDotOperandPass() {
+  return std::make_unique<TritonGPUDecomposeConversionsToDotOperandPass>();
+}
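Note: the new file above is only a pass template; runOnOperation() returns immediately. As a hypothetical sketch of the rewrite the Passes.td entry describes (split convert[distributed -> dotOperand] into convert[distributed -> shared] followed by convert[shared -> dotOperand], so CSE can re-use the #shared tensors), the body might eventually look like the following. pickSharedEncoding is a made-up placeholder helper, and the accessors assume the MLIR/Triton APIs of this era; none of this is part of the commit.

// Hypothetical sketch only, not part of this commit.
void runOnOperation() override {
  ModuleOp mod = getOperation();
  mod.walk([](triton::gpu::ConvertLayoutOp cvt) {
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    auto dotOpEnc = dstType.getEncoding()
                        .dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!dotOpEnc)
      return; // only rewrite conversions whose destination is #dot_operand
    OpBuilder builder(cvt);
    // pickSharedEncoding is a placeholder: choosing a #shared layout
    // compatible with the dot operand encoding is the real work here.
    Attribute sharedEnc = pickSharedEncoding(dstType, dotOpEnc);
    auto sharedType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), sharedEnc);
    // distributed -> shared, then shared -> dotOperand; identical staging
    // conversions become CSE-able.
    Value tmp = builder.create<triton::gpu::ConvertLayoutOp>(
        cvt.getLoc(), sharedType, cvt->getOperand(0));
    Value res = builder.create<triton::gpu::ConvertLayoutOp>(
        cvt.getLoc(), dstType, tmp);
    cvt.getResult().replaceAllUsesWith(res);
    cvt.erase();
  });
}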
@@ -1345,14 +1345,18 @@ void init_triton_ir(py::module &&m) {
             mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
       .def("add_tritongpu_optimize_load_convert_pass",
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
            })
       .def("add_tritongpu_sink_conversions_from_shared_pass",
            [](mlir::PassManager &self) {
-             self.addPass(
-                 mlir::createTritonGPUSinkConversionsFromSharedPass());
+             self.addPass(mlir::createTritonGPUSinkConversionsFromSharedPass());
            })
+      .def("add_tritongpu_decompose_conversions_to_dot_operand_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(
+                 mlir::createTritonGPUDecomposeConversionsToDotOperandPass());
+           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
@@ -906,6 +906,8 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
     # pm.add_tritongpu_optimize_load_convert_pass()
+    pm.add_tritongpu_sink_conversions_from_shared_pass()
+    pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
     pm.run(mod)
     return mod
 
@@ -194,8 +194,10 @@ def _bwd_kernel(
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
+
 empty = torch.empty(128, device="cuda")
 
+
 class _attention(torch.autograd.Function):
 
     @staticmethod
@@ -284,8 +286,8 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
             num_stages=1,
         )
-        print(pgm.asm["ttgir"])
-        exit(1)
+        # print(pgm.asm["ttgir"])
+        # exit()
         return dq, dk, dv, None
 
 
@@ -327,6 +329,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     triton.testing.assert_almost_equal(ref_dk, tri_dk)
     triton.testing.assert_almost_equal(ref_dq, tri_dq)
 
+
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
@@ -358,8 +361,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         do = torch.randn_like(o)
         fn = lambda: o.backward(do, retain_graph=True)
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
-    flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
-    total_flops = 2*flops_per_matmul
+    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
+    total_flops = 2 * flops_per_matmul
     # print(total_flops/ms*1e-9)
     print(ms)
     return ms
@@ -376,4 +379,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
 
-bench_flash_attention.run(save_path='.', print_data=True)
+
+# bench_flash_attention.run(save_path='.', print_data=True)