@@ -72,24 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value staIdx1 = i32_val(0);
   Value stride0 = dstStrides[outOrd[0]];
   Value stride1 = dstStrides[outOrd[1]];
-  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-  //   if (auto cstRhs =
-  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-  //     unsigned rhsVal =
-  //         cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
-  //     unsigned key = (rhsVal / outVec) % maxPhase;
-  //     if (cache.find(key) == cache.end())
-  //       cache[key] = dynIdx0;
-  //     dynIdx0 = cache[key];
-  //     staIdx0 =
-  //         i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
-  //   }
-  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
-  //   if (auto cstRhs =
-  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-  //     dynIdx1 = addOp.getLhs();
-  //     staIdx1 = addOp.getRhs();
-  //   }
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+      unsigned rhsVal =
+          cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+      unsigned key = (rhsVal / outVec) % maxPhase;
+      if (cache.find(key) == cache.end())
+        cache[key] = dynIdx0;
+      dynIdx0 = cache[key];
+      staIdx0 =
+          i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
+    }
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+      dynIdx1 = addOp.getLhs();
+      staIdx1 = addOp.getRhs();
+    }
 
   // offset along non-contiguous dimension
   Value off1 = mul(dynIdx1, stride1);
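The block uncommented above adds a small cache for swizzled shared-memory indices: each constant row offset rhsVal is reduced modulo the swizzle period outVec * maxPhase, so rows whose offsets share the same key reuse one already-computed dynamic index, and the remainder of the offset becomes a static immediate (staIdx0). Below is a minimal standalone sketch of that arithmetic, assuming illustrative values outVec = 4 and maxPhase = 8 (the real values come from the shared encoding in scope); it is not the kernel code itself.

#include <cstdio>
#include <initializer_list>

// Sketch of the index split enabled above. Offsets with the same key can
// reuse one cached dynamic index; the rest folds into a static immediate.
int main() {
  const unsigned outVec = 4, maxPhase = 8; // illustrative values only
  for (unsigned rhsVal : {0u, 4u, 32u, 36u}) {
    unsigned key = (rhsVal / outVec) % maxPhase;
    unsigned staIdx = rhsVal / (outVec * maxPhase) * (outVec * maxPhase);
    std::printf("rhsVal=%2u -> key=%u, staIdx=%u\n", rhsVal, key, staIdx);
  }
  return 0;
}

With these values, offsets 0 and 32 map to the same key, so the second of the pair would hit the cache and only adjust the static part of the address.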
@@ -1266,14 +1266,16 @@ public:
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              1, context) {}
 
-  LogicalResult matchAndRewrite(mlir::Operation* op,
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
     auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
     auto tmpOp = dyn_cast_or_null<triton::TransOp>(dstOp.src().getDefiningOp());
-    if(!tmpOp)
+    if (!tmpOp)
       return mlir::failure();
-    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(tmpOp.src().getDefiningOp());
-    if(!srcOp)
+    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
+        tmpOp.src().getDefiningOp());
+    if (!srcOp)
       return mlir::failure();
     auto arg = srcOp.src();
     auto X = tmpOp.src();
@@ -1285,25 +1287,74 @@ public:
     auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
     // encodings
     auto argEncoding = argType.getEncoding();
-    auto XEncoding = XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
-    auto YEncoding = YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
-    auto ZEncoding = ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-    if(!ZEncoding)
+    auto XEncoding =
+        XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto YEncoding =
+        YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto ZEncoding =
+        ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!ZEncoding)
       return mlir::failure();
     // new X encoding
     auto newXOrder = triton::gpu::getOrder(argEncoding);
     auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
         getContext(), ZEncoding, XType.getShape(), newXOrder,
         XType.getElementType());
-    auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(),
-                                          newXEncoding);
-    if(XEncoding == newXEncoding)
+    auto newXType = RankedTensorType::get(XType.getShape(),
+                                          XType.getElementType(), newXEncoding);
+    if (XEncoding == newXEncoding)
       return mlir::failure();
 
-    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(), newXType, arg);
+    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
+                                                              newXType, arg);
     auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
-    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType, newY);
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
+                                                              newY);
+    return mlir::success();
+  }
+};
 
+//
+class ConvertDotConvert : public mlir::RewritePattern {
+public:
+  ConvertDotConvert(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
+                             1, context) {}
+
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto dotOp = dyn_cast_or_null<triton::DotOp>(dstOp.src().getDefiningOp());
+    if (!dotOp)
+      return mlir::failure();
+    if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 ||
+        std::distance(dotOp->user_begin(), dotOp->user_end()) != 1)
+      return mlir::failure();
+    auto cvtOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
+        dotOp.getOperand(2).getDefiningOp());
+    if (!cvtOp)
+      return mlir::failure();
+    auto loadOp = dyn_cast_or_null<triton::LoadOp>(cvtOp.src().getDefiningOp());
+    if (!loadOp)
+      return mlir::failure();
+    auto dstTy = dstOp.getResult().getType().cast<RankedTensorType>();
+    auto srcTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
+    if (dstTy != srcTy)
+      return mlir::failure();
+
+    // TODO: int tensor cores
+    auto _0f = rewriter.create<arith::ConstantFloatOp>(
+        op->getLoc(), APFloat(0.0f), dstTy.getElementType().cast<FloatType>());
+    auto _0 = rewriter.create<triton::SplatOp>(
+        op->getLoc(), dotOp.getResult().getType(), _0f);
+    auto newDot = rewriter.create<triton::DotOp>(
+        op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
+        dotOp.getOperand(1), _0, dotOp.allowTF32());
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), dstTy, newDot.getResult());
+    auto newAdd = rewriter.replaceOpWithNewOp<arith::AddFOp>(
+        op, newCvt, cvtOp.getOperand());
     return mlir::success();
   }
 };
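The new ConvertDotConvert pattern relies on the dot accumulator being linear: dot(a, b, acc) == dot(a, b, 0) + acc. That lets it feed the dot a zero accumulator and add the loaded C operand back after the layout conversion, so the load result never has to be converted into the MMA accumulator layout. A scalar stand-in for the identity follows; it is purely illustrative (the real operands are tensors and the add is the arith::AddFOp created above).

#include <cstdio>

// dot(a, b, acc) == dot(a, b, 0) + acc: linearity in the accumulator.
static float dot(float a, float b, float acc) { return a * b + acc; }

int main() {
  const float a = 2.0f, b = 3.0f, acc = 5.0f;
  // Both expressions print 11: fusing the accumulator into the dot and
  // adding it after the (converted) dot result are equivalent.
  std::printf("fused: %g, split: %g\n", dot(a, b, acc), dot(a, b, 0.0f) + acc);
  return 0;
}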
@@ -1477,6 +1528,7 @@ public:
     patterns.add<MoveConvertOutOfIf>(context);
     patterns.add<BlockedToMMA>(context, computeCapability);
     patterns.add<ConvertTransConvert>(context);
+    patterns.add<ConvertDotConvert>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
     %137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
     %138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
-    tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
+    tt.store %arg29, %133 : tensor<128x64xf32, #blocked2>
     %139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
     %140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
     %141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
@@ -191,6 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
+
 # _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        # _bwd_kernel[(ctx.grid[1],1,1)](
+        # _bwd_kernel[(ctx.grid[1], 1, 1)](
         #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
         #     o.data_ptr(), do_scaled.data_ptr(),
        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),