trying more things
@@ -104,6 +104,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    return {1};
     return getSizePerThread(sliceLayout.getParent());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.isAmpere()) {
@@ -19,6 +19,8 @@
 
 #include <memory>
 
+#define int_attr(num) rewriter.getI64IntegerAttr(num)
+
 using namespace mlir;
 namespace {
 #include "TritonGPUCombine.inc"
@@ -1153,6 +1155,60 @@ public:
 }
 };
 
+class LoadConvertToInsertSlice : public mlir::RewritePattern{
+
+public:
+  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto shape = origRetType.getShape();
+    auto eltType = origRetType.getElementType();
+    auto dotOpEncoding = origRetType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if(!dotOpEncoding){
+      return failure();
+    }
+    auto loadOp = dyn_cast<triton::LoadOp>(*cvt.getOperand().getDefiningOp());
+    if(!loadOp){
+      return failure();
+    }
+    auto blockedEncoding = loadOp.getType().cast<RankedTensorType>().getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    if(!blockedEncoding)
+      return failure();
+    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), dotOpEncoding, shape,
+                                                               blockedEncoding.getOrder(), eltType);
+    auto srcTy = RankedTensorType::get({1, shape[0], shape[1]},
+                                       eltType,
+                                       sharedEncoding);
+    auto loadTensor = rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
+
+    auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
+        op->getLoc(), loadTensor.getType(),
+        loadOp.ptr(),
+        loadTensor, rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
+        loadOp.mask(),
+        loadOp.other(), loadOp.cache(),
+        loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+
+    rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
+    auto tmpType = RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
+    auto tmp = rewriter.create<tensor::ExtractSliceOp>(op->getLoc(), tmpType, newOp,
+        SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
+        SmallVector<OpFoldResult>{int_attr(1),
+                                  int_attr(shape[0]),
+                                  int_attr(shape[1])},
+        SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType, tmp);
+    return success();
+
+  }
+
+};
+
 class FixupLoop : public mlir::RewritePattern {
 
 public:
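Note on the new pattern: LoadConvertToInsertSlice matches a convert_layout whose result carries a dot-operand encoding and whose operand comes straight from a tt.load with a blocked encoding. Instead of converting through registers, it stages the loaded tile through shared memory: allocate a [1, M, N] staging tensor, insert_slice_async the [M, N] tile into slot 0 of the leading axis, async_wait, extract the [M, N] slice back out, and convert that shared-encoded tensor to the original dot-operand layout. A minimal NumPy sketch of just the slice bookkeeping (buf, tile, M, N are illustrative names; this is an analogy, not Triton IR):

import numpy as np

M, N = 64, 64                                # illustrative tile shape
tile = np.zeros((M, N), dtype=np.float32)    # stands in for the [M, N] tensor produced by tt.load

buf = np.empty((1, M, N), dtype=tile.dtype)  # AllocTensorOp: [1, M, N] staging buffer
buf[0, :, :] = tile                          # InsertSliceAsyncOp: write the tile at index 0 of the leading axis
out = buf[0, 0:M, 0:N]                       # ExtractSliceOp: offsets (0,0,0), sizes (1,M,N), strides (1,1,1), rank-reduced to [M, N]
assert out.shape == (M, N)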
@@ -1224,6 +1280,7 @@ public:
     patterns.add<MoveConvertOutOfLoop>(context);
     patterns.add<MoveConvertOutOfIf>(context);
     patterns.add<BlockedToMMA>(context, computeCapability);
+    patterns.add<LoadConvertToInsertSlice>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -1,19 +0,0 @@
-
-import triton
-import triton.language as tl
-
-
-# triton kernel
-@triton.jit
-def kernel(X, stride_xm,
-           Z, stride_zn,
-           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-    off_m = tl.arange(0, BLOCK_M)
-    off_n = tl.arange(0, BLOCK_N)
-    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
-    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
-    tl.store(Zs, tl.load(Xs))
-
-
-ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
-print(ret)
@@ -1,13 +0,0 @@
-import torch
-
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
-    pass
-
-
-X = torch.randn(1, device="cuda")
-pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
@@ -1562,7 +1562,7 @@ class CompiledKernel:
         if self.shared > max_shared:
             raise OutOfResources(self.shared, max_shared, "shared memory")
         mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
-        print(n_regs, n_spills)
+        print(self.shared, n_regs, n_spills)
         self.cu_module = mod
         self.cu_function = func
 
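The debug print now also reports the kernel's shared-memory footprint in bytes alongside the register count and spill count returned by load_binary. A rough occupancy sanity check one could run on those numbers; the per-SM limits below are assumptions for an A100-class GPU, not values queried from the driver:

# Hypothetical values in the shape of the new debug print: shared bytes, regs/thread, spills.
shared_bytes, n_regs, n_spills = 49152, 168, 0
threads_per_block = 128                 # assumed launch configuration (4 warps)

regs_per_sm = 65536                     # assumed A100 register file size per SM
smem_per_sm = 164 * 1024                # assumed A100 shared memory capacity per SM

blocks_by_regs = regs_per_sm // (n_regs * threads_per_block)
blocks_by_smem = smem_per_sm // max(shared_bytes, 1)
print("resident blocks per SM limited to", min(blocks_by_regs, blocks_by_smem))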
@@ -194,6 +194,7 @@ def _bwd_kernel(
 
+
 empty = torch.empty(128, device="cuda")
 
 class _attention(torch.autograd.Function):
 
     @staticmethod
@@ -220,7 +221,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
             BLOCK_DMODEL=Lk, num_warps=num_warps,
-            num_stages=2,
+            num_stages=1,
         )
 
         ctx.save_for_backward(q, k, v, o, L, m)
@@ -335,7 +336,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
         flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
         total_flops = 2*flops_per_matmul
-        print(total_flops/ms*1e-9)
+        # print(total_flops/ms*1e-9)
+        print(ms)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
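For reference, the expression that was commented out already is a TFLOP/s figure: do_bench returns milliseconds, and flops/ms * 1e-9 equals flops/s * 1e-12. The 0.5 in flops_per_matmul presumably reflects the causal mask halving the useful work, and the factor of 2 in total_flops the two matmuls per attention layer. A quick sanity check of the unit conversion (sizes and runtime below are made up):

BATCH, H, N_CTX, D_HEAD = 4, 48, 4096, 64   # made-up benchmark sizes
ms = 25.0                                   # made-up runtime in milliseconds

flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
tflops = total_flops / (ms * 1e-3) / 1e12   # FLOP per second, scaled to TFLOP/s
assert abs(tflops - total_flops / ms * 1e-9) < 1e-6
print(tflops)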