trying more things

This commit is contained in:
Phil Tillet
2022-12-27 20:58:31 -08:00
parent 4182e90862
commit 0d6e6cf578
6 changed files with 63 additions and 35 deletions

View File

@@ -104,6 +104,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return {1};
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {

View File

@@ -19,6 +19,8 @@
#include <memory>
#define int_attr(num) rewriter.getI64IntegerAttr(num)
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
@@ -1153,6 +1155,60 @@ public:
}
};
class LoadConvertToInsertSlice : public mlir::RewritePattern{
public:
explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
auto shape = origRetType.getShape();
auto eltType = origRetType.getElementType();
auto dotOpEncoding = origRetType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if(!dotOpEncoding){
return failure();
}
auto loadOp = dyn_cast<triton::LoadOp>(*cvt.getOperand().getDefiningOp());
if(!loadOp){
return failure();
}
auto blockedEncoding = loadOp.getType().cast<RankedTensorType>().getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
if(!blockedEncoding)
return failure();
auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), dotOpEncoding, shape,
blockedEncoding.getOrder(), eltType);
auto srcTy = RankedTensorType::get({1, shape[0], shape[1]},
eltType,
sharedEncoding);
auto loadTensor = rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadTensor.getType(),
loadOp.ptr(),
loadTensor, rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
loadOp.mask(),
loadOp.other(), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
auto tmpType = RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
auto tmp = rewriter.create<tensor::ExtractSliceOp>(op->getLoc(), tmpType, newOp,
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(shape[0]),
int_attr(shape[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType, tmp);
return success();
}
};
class FixupLoop : public mlir::RewritePattern {
public:
@@ -1224,6 +1280,7 @@ public:
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<LoadConvertToInsertSlice>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();

View File

@@ -1,19 +0,0 @@
import triton
import triton.language as tl
# triton kernel
@triton.jit
def kernel(X, stride_xm,
Z, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -1,13 +0,0 @@
import torch
import triton
import triton.language as tl
@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass
X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)

View File

@@ -1562,7 +1562,7 @@ class CompiledKernel:
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
print(n_regs, n_spills)
print(self.shared, n_regs, n_spills)
self.cu_module = mod
self.cu_function = func

View File

@@ -194,6 +194,7 @@ def _bwd_kernel(
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
@@ -220,7 +221,7 @@ class _attention(torch.autograd.Function):
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
@@ -335,7 +336,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
total_flops = 2*flops_per_matmul
print(total_flops/ms*1e-9)
# print(total_flops/ms*1e-9)
print(ms)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)