[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
@@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
  CanonicalizeLoops.cpp
  Combine.cpp
  Pipeline.cpp
  Prefetch.cpp
  Swizzle.cpp
  TritonGPUConversion.cpp
@@ -12,21 +12,13 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static bool isSharedLayout(Value v) {
|
||||
if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
return encoding.isa<triton::gpu::SharedEncodingAttr>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "TritonGPUCombine.inc"
|
||||
|
||||
@@ -37,7 +29,7 @@ namespace {
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
// this is a hueiristics to accomodate some pattern seen in fused attention
|
||||
// this is a heuristic to accomodate some pattern seen in fused attention
|
||||
// kernels.
|
||||
// TODO: replace this by something more generic, i.e. layout-aware CSE
|
||||
class DecomposeDotOperand : public mlir::RewritePattern {
|
||||
@@ -59,9 +51,8 @@ public:
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto tmpType =
|
||||
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
|
||||
dstType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent());
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), 1, 1, 1, {1, 0}));
|
||||
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
convert.getLoc(), tmpType, convert.getOperand());
|
||||
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -87,11 +78,12 @@ public:
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristic to accommodate fused attention
|
||||
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return mlir::failure();
|
||||
// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
// return mlir::failure();
|
||||
// convert to the same layout -- we can delete
|
||||
if (op->getResultTypes() == op->getOperandTypes()) {
|
||||
rewriter.replaceOp(op, op->getOperands());
|
||||
@@ -122,8 +114,8 @@ public:
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||
op, newType, insert_slice.src(), newArg.getResult(),
|
||||
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
|
||||
insert_slice.cache(), insert_slice.evict(),
|
||||
insert_slice.isVolatile(), insert_slice.axis());
|
||||
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
|
||||
insert_slice.axis());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
@@ -133,7 +125,10 @@ public:
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
auto resType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto resType = RankedTensorType::get(
|
||||
origResType.getShape(), origResType.getElementType(),
|
||||
extract_slice.getType().cast<RankedTensorType>().getEncoding());
|
||||
// Ensure that the new extract_slice op is placed in the same place as the
|
||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
@@ -148,8 +143,21 @@ public:
|
||||
extract_slice.static_strides());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// cvt(type2, x)
|
||||
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
|
||||
auto argType = arg->getOperand(0).getType().cast<RankedTensorType>();
|
||||
if (arg->getOperand(0).getDefiningOp() &&
|
||||
!argType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
srcType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
!dstType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
auto srcShared =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (srcShared && srcShared.getVec() > 1)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, op->getResultTypes().front(), arg->getOperand(0));
|
||||
return mlir::success();
|
||||
@@ -253,8 +261,8 @@ public:
|
||||
if (!op)
|
||||
return mlir::failure();
|
||||
// we don't want to rematerialize any conversion to/from shared
|
||||
if (isSharedLayout(cvt->getResults()[0]) ||
|
||||
isSharedLayout(cvt->getOperand(0)))
|
||||
if (isSharedEncoding(cvt->getResults()[0]) ||
|
||||
isSharedEncoding(cvt->getOperand(0)))
|
||||
return mlir::failure();
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristic to accommodate fused attention
|
||||
@@ -325,7 +333,6 @@ public:
|
||||
for (Operation *op : tmp)
|
||||
sortedValues.push_back(op->getResult(0));
|
||||
|
||||
// llvm::outs() << "----\n";
|
||||
BlockAndValueMapping mapping;
|
||||
for (Value currOperand : sortedValues) {
|
||||
// unpack information
|
||||
@@ -346,7 +353,6 @@ public:
|
||||
newOperand->moveAfter(currOperation);
|
||||
mapping.map(currOperand, newOperand);
|
||||
}
|
||||
// llvm::outs() << cvt->getParentOfType<mlir::FuncOp>() << "\n";
|
||||
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -356,8 +362,6 @@ public:
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// int test = 0;
|
||||
|
||||
class MoveConvertOutOfLoop : public mlir::RewritePattern {
|
||||
public:
|
||||
MoveConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
@@ -435,9 +439,25 @@ public:
|
||||
auto users = iterArg.value().getUsers();
|
||||
// check first condition
|
||||
SetVector<Type> cvtTargetTypes;
|
||||
for (auto user : users)
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(user))
|
||||
cvtTargetTypes.insert(user->getResults()[0].getType());
|
||||
for (auto user : users) {
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
|
||||
auto newType =
|
||||
user->getResults()[0].getType().cast<RankedTensorType>();
|
||||
auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
|
||||
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
newType.getEncoding()
|
||||
.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
continue;
|
||||
}
|
||||
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
if (newType.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getVec() == 1)
|
||||
continue;
|
||||
}
|
||||
cvtTargetTypes.insert(newType);
|
||||
}
|
||||
}
|
||||
if (cvtTargetTypes.size() != 1)
|
||||
continue;
|
||||
// TODO: check second condition
|
||||
@@ -446,6 +466,7 @@ public:
|
||||
continue;
|
||||
}
|
||||
// check
|
||||
// llvm::outs() << "replacing " << iterArg.index() << "\n";
|
||||
for (auto op : iterArg.value().getUsers()) {
|
||||
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
if (!cvt)
|
||||
@@ -597,10 +618,23 @@ public:
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
// convert output
|
||||
Value a = dotOp.a();
|
||||
Value b = dotOp.b();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding()));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding()));
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
|
||||
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
|
||||
dotOp.transA(), dotOp.transB());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
@@ -623,7 +657,7 @@ public:
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
// patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
patterns.add<RematerializeForward>(context);
|
||||
patterns.add<MoveConvertOutOfLoop>(context);
|
||||
|
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -11,6 +12,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using namespace mlir;
|
||||
namespace ttg = triton::gpu;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
@@ -24,6 +26,7 @@ static Type getI1SameShape(Value v) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class LoopPipeliner {
|
||||
/// cache forOp we are working on
|
||||
scf::ForOp forOp;
|
||||
@@ -37,6 +40,8 @@ class LoopPipeliner {
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
/// load => buffer
|
||||
DenseMap<Value, Value> loadsBuffer;
|
||||
/// load => buffer type (with shared layout after swizzling)
|
||||
DenseMap<Value, RankedTensorType> loadsBufferType;
|
||||
/// load => buffer at stage N
|
||||
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
|
||||
/// load => after extract
|
||||
@@ -67,8 +72,11 @@ class LoopPipeliner {
|
||||
Value lookupOrDefault(Value origin, int stage);
|
||||
|
||||
/// returns an empty buffer of size <numStages, ...>
|
||||
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
|
||||
OpBuilder &builder);
|
||||
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
|
||||
|
||||
/// compute type of shared buffers (with swizzled shared layouts)
|
||||
RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
|
||||
RankedTensorType tensorType);
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
@@ -128,25 +136,82 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
}
|
||||
}
|
||||
|
||||
triton::gpu::AllocTensorOp
|
||||
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
|
||||
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
OpBuilder &builder) {
|
||||
// allocate a buffer for each pipelined tensor
|
||||
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
|
||||
Value convertLayout = loadsMapping[op->getResult(0)];
|
||||
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
|
||||
SmallVector<int64_t> shape(tensorType.getShape().begin(),
|
||||
tensorType.getShape().end());
|
||||
shape.insert(shape.begin(), numStages);
|
||||
Type elementType = tensorType.getElementType();
|
||||
// The encoding of the buffer is similar to the original tensor
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
auto bufferType = RankedTensorType::get(shape, elementType, encoding);
|
||||
return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
|
||||
bufferType);
|
||||
return builder.create<ttg::AllocTensorOp>(
|
||||
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
|
||||
}
|
||||
llvm_unreachable("Async copy's return should be of RankedTensorType");
|
||||
}
|
||||
|
||||
// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
|
||||
// code path.
|
||||
// Swizzle has to be performed before pipeline for now. If we do swizzle
|
||||
// after pipeline, we need to propagate the swizzled layout to all
|
||||
// operands that are aliases of the swizzled tensor. The alias analysis
// component may be helpful for this purpose.
|
||||
RankedTensorType
|
||||
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
|
||||
RankedTensorType ty) {
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
int vec = 1;
|
||||
int maxPhase = 1;
|
||||
int perPhase = 1;
|
||||
llvm::SmallVector<unsigned> order;
|
||||
if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
|
||||
// Only support row major for now
|
||||
// TODO(Keren): check why column major code crashes
|
||||
order = {1, 0};
|
||||
int version = mmaEnc.getVersion();
|
||||
auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
|
||||
// number of rows per phase
|
||||
perPhase = 128 / (ty.getShape()[order[0]] *
|
||||
(ty.getElementType().getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
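// i.e. the number of contiguous rows that fit in one 128-byte span of shared memory (32 banks * 4 bytes), clamped to at least one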
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
if (version == 1) {
|
||||
maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
// TODO: handle rep (see
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (ty.getElementType().isInteger(8) && order[0] == inner)
|
||||
perPhase = 1;
|
||||
else {
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
|
||||
maxPhase = mmaStride / perPhase;
|
||||
} else if (opIdx == 1) { // compute swizzling for B operand
|
||||
vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
|
||||
maxPhase = mmaStride / perPhase;
|
||||
} else
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
} else // version not in [1, 2]
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
} else { // If the layout of dot is not mma, we don't need to swizzle
|
||||
auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
|
||||
order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
|
||||
blockedEnc.getOrder().end());
|
||||
}
|
||||
auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
|
||||
perPhase, maxPhase, order);
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
|
||||
}
|
||||
|
||||
/// A load instruction can be pipelined if:
|
||||
/// - the load doesn't depend on any other loads (after loop peeling)
|
||||
/// - (?) this load is not a loop-invariant value (we should run LICM before
|
||||
@@ -186,19 +251,21 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
}
|
||||
}
|
||||
|
||||
// For now, we only pipeline loads that have one convert_layout (to smem) use
// We only pipeline loads that have one convert_layout (to dot_op) use
|
||||
// TODO: lift this constraint in the future
|
||||
if (isCandiate && loadOp.getResult().hasOneUse()) {
|
||||
isCandiate = false;
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout =
|
||||
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
if (auto dotOpEnc = tensorType.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
loadsBufferType[loadOp] = getSwizzleType(
|
||||
dotOpEnc, loadOp.getType().cast<RankedTensorType>());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,6 +305,9 @@ void LoopPipeliner::emitPrologue() {
|
||||
setValueMapping(arg, operand.get(), 0);
|
||||
}
|
||||
|
||||
// helper to construct int attribute
|
||||
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
|
||||
|
||||
// prologue from [0, numStage-1)
|
||||
Value iv = forOp.getLowerBound();
|
||||
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
@@ -330,14 +400,15 @@ void LoopPipeliner::emitPrologue() {
|
||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
||||
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
// async.wait & extract_slice
|
||||
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
|
||||
loads.size() * (numStages - 2));
|
||||
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
|
||||
loads.size() * (numStages - 2));
|
||||
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (Value loadOp : loads) {
|
||||
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||
sliceType =
|
||||
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
||||
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
|
||||
@@ -366,6 +437,7 @@ void LoopPipeliner::emitEpilogue() {
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
OpBuilder builder(forOp);
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
// order of new args:
|
||||
// (original args),
|
||||
@@ -477,8 +549,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
extractSliceIndex = builder.create<arith::IndexCastOp>(
|
||||
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
|
||||
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// update loading mask
|
||||
@@ -508,6 +578,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
||||
nextBuffers.push_back(insertAsyncOp);
|
||||
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||
sliceType = RankedTensorType::get(sliceType.getShape(),
|
||||
sliceType.getElementType(),
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
nextOp = builder.create<tensor::ExtractSliceOp>(
|
||||
op->getLoc(), sliceType, insertAsyncOp,
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
|
||||
@@ -534,8 +607,37 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
for (Operation &op : *newForOp.getBody()) {
|
||||
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
|
||||
builder.setInsertionPoint(&op);
|
||||
auto dotType = dotOp.getType().cast<RankedTensorType>();
|
||||
Value a = dotOp.a();
|
||||
Value b = dotOp.b();
|
||||
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
|
||||
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
|
||||
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
|
||||
auto newEncoding = ttg::DotOperandEncodingAttr::get(
|
||||
tensorType.getContext(), opIdx, dotType.getEncoding());
|
||||
auto newType =
|
||||
RankedTensorType::get(tensorType.getShape(),
|
||||
tensorType.getElementType(), newEncoding);
|
||||
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
|
||||
newType, dotOperand);
|
||||
}
|
||||
return dotOperand;
|
||||
};
|
||||
a = layoutCast(a, 0);
|
||||
b = layoutCast(b, 1);
|
||||
dotOp->setOperand(0, a);
|
||||
dotOp->setOperand(1, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// async.wait & extract_slice
|
||||
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
|
||||
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
|
||||
loads[0].getLoc(), loads.size() * (numStages - 2));
|
||||
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
|
||||
// move extract_slice after asyncWait
|
||||
|
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp (new file, 304 lines)
@@ -0,0 +1,304 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This pass tries to prefetch operands (a and b) of tt.dot.
|
||||
// Those ConvertLayoutOps will be lowered to shared memory loads.
|
||||
//
|
||||
// For example:
|
||||
// %a: tensor<128x32xf16, #enc>
|
||||
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
|
||||
// %d = tt.dot %a_arg, %b, %c
|
||||
// ...
|
||||
// scf.yield %a_next, ...
|
||||
// }
|
||||
//
|
||||
// will be translated to
|
||||
//
|
||||
// %a: tensor<128x32xf16, #enc>
|
||||
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
|
||||
// %a_prefetch = triton_gpu.convert_layout %a_tmp
|
||||
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
|
||||
// {
|
||||
// %x = tt.dot %a_arg, %b, %c
|
||||
// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
|
||||
// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
|
||||
// ...
|
||||
// scf.yield %next_a, ..., %a_prefetch_next
|
||||
// }
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
class Prefetcher {
|
||||
/// cache the ForOp we are working on
|
||||
scf::ForOp forOp;
|
||||
/// cache the YieldOp of this ForOp
|
||||
scf::YieldOp yieldOp;
|
||||
///
|
||||
// TODO: add a hook to infer prefetchWidth
|
||||
unsigned prefetchWidth = 16;
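// number of elements along K that are sliced off and converted to dot-operand layout ahead of each tt.dot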
|
||||
|
||||
/// dots to be prefetched
|
||||
SetVector<Value> dots;
|
||||
/// dot => dot operand
|
||||
DenseMap<Value, Value> dot2aLoopArg;
|
||||
DenseMap<Value, Value> dot2aHeaderDef;
|
||||
DenseMap<Value, Value> dot2bLoopArg;
|
||||
DenseMap<Value, Value> dot2bHeaderDef;
|
||||
DenseMap<Value, Value> dot2aYield;
|
||||
DenseMap<Value, Value> dot2bYield;
|
||||
/// operand => defining
|
||||
DenseMap<Value, Value> operand2headPrefetch;
|
||||
|
||||
LogicalResult isForOpOperand(Value v);
|
||||
|
||||
Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
|
||||
Attribute dotEncoding, OpBuilder &builder,
|
||||
llvm::Optional<int64_t> offsetK = llvm::None,
|
||||
llvm::Optional<int64_t> shapeK = llvm::None);
|
||||
|
||||
public:
|
||||
Prefetcher() = delete;
|
||||
|
||||
Prefetcher(scf::ForOp forOp) : forOp(forOp) {
|
||||
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
}
|
||||
|
||||
LogicalResult initialize();
|
||||
|
||||
void emitPrologue();
|
||||
|
||||
scf::ForOp createNewForOp();
|
||||
};
|
||||
|
||||
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
|
||||
Attribute dotEncoding, OpBuilder &builder,
|
||||
llvm::Optional<int64_t> offsetK,
|
||||
llvm::Optional<int64_t> shapeK) {
|
||||
// opIdx: 0 => a, 1 => b
|
||||
auto type = v.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
|
||||
SmallVector<int64_t> offset{0, 0};
|
||||
Type elementType = type.getElementType();
|
||||
|
||||
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
|
||||
|
||||
// k => (prefetchWidth, k - prefetchWidth)
|
||||
int64_t kIdx = opIdx == 0 ? 1 : 0;
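// the K dimension is dim 1 for operand a (MxK) and dim 0 for operand b (KxN)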
|
||||
|
||||
offset[kIdx] = isPrefetch ? 0 : prefetchWidth;
|
||||
shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth);
|
||||
|
||||
if (shapeK)
|
||||
shape[kIdx] = *shapeK;
|
||||
if (offsetK)
|
||||
offset[kIdx] = *offsetK;
|
||||
|
||||
Value newSmem = builder.create<tensor::ExtractSliceOp>(
|
||||
v.getLoc(),
|
||||
// TODO: encoding?
|
||||
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
|
||||
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
|
||||
|
||||
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
|
||||
builder.getContext(), opIdx, dotEncoding);
|
||||
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
|
||||
newSmem);
|
||||
|
||||
return prefetchSlice;
|
||||
}
|
||||
|
||||
LogicalResult Prefetcher::initialize() {
|
||||
Block *loop = forOp.getBody();
|
||||
|
||||
SmallVector<triton::DotOp> dotsInFor;
|
||||
for (Operation &op : *loop)
|
||||
if (auto dotOp = dyn_cast<triton::DotOp>(op))
|
||||
dotsInFor.push_back(dotOp);
|
||||
|
||||
if (dotsInFor.empty())
|
||||
return failure();
|
||||
|
||||
// returns source of cvt
|
||||
auto getPrefetchSrc = [](Value v) -> Value {
|
||||
// TODO: Check if the layout of src is SharedEncodingAttr
|
||||
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
|
||||
return cvt.src();
|
||||
return Value();
|
||||
};
|
||||
|
||||
auto getIncomingOp = [this](Value v) -> Value {
|
||||
if (auto arg = v.dyn_cast<BlockArgument>())
|
||||
if (arg.getOwner()->getParentOp() == forOp.getOperation())
|
||||
return forOp.getOpOperandForRegionIterArg(arg).get();
|
||||
return Value();
|
||||
};
|
||||
|
||||
auto getYieldOp = [this](Value v) -> Value {
|
||||
auto arg = v.cast<BlockArgument>();
|
||||
unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
|
||||
return yieldOp.getOperand(yieldIdx);
|
||||
};
|
||||
|
||||
for (triton::DotOp dot : dotsInFor) {
|
||||
Value aSmem = getPrefetchSrc(dot.a());
|
||||
Value bSmem = getPrefetchSrc(dot.b());
|
||||
if (aSmem && bSmem) {
|
||||
Value aHeaderDef = getIncomingOp(aSmem);
|
||||
Value bHeaderDef = getIncomingOp(bSmem);
|
||||
// Only prefetch loop arg
|
||||
if (aHeaderDef && bHeaderDef) {
|
||||
dots.insert(dot);
|
||||
dot2aHeaderDef[dot] = aHeaderDef;
|
||||
dot2bHeaderDef[dot] = bHeaderDef;
|
||||
dot2aLoopArg[dot] = aSmem;
|
||||
dot2bLoopArg[dot] = bSmem;
|
||||
dot2aYield[dot] = getYieldOp(aSmem);
|
||||
dot2bYield[dot] = getYieldOp(bSmem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void Prefetcher::emitPrologue() {
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
for (Value dot : dots) {
|
||||
Attribute dotEncoding =
|
||||
dot.getType().cast<RankedTensorType>().getEncoding();
|
||||
Value aPrefetched =
|
||||
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
|
||||
Value bPrefetched =
|
||||
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
|
||||
}
|
||||
}
|
||||
|
||||
scf::ForOp Prefetcher::createNewForOp() {
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
SmallVector<Value> loopArgs;
|
||||
for (auto v : forOp.getIterOperands())
|
||||
loopArgs.push_back(v);
|
||||
for (Value dot : dots) {
|
||||
loopArgs.push_back(
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
|
||||
loopArgs.push_back(
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
|
||||
}
|
||||
|
||||
auto newForOp = builder.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), loopArgs);
|
||||
|
||||
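// n & (n - 1) clears the lowest set bit; repeating until n is a power of two yields the largest power of two <= n (e.g. 48 -> 32)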
auto largestPow2 = [](int64_t n) -> int64_t {
  while ((n & (n - 1)) != 0)
    n = n & (n - 1);
  return n;
};
|
||||
|
||||
builder.setInsertionPointToStart(newForOp.getBody());
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = nullptr;
|
||||
auto dot = dyn_cast<triton::DotOp>(&op);
|
||||
if (dots.contains(dot)) {
|
||||
Attribute dotEncoding =
|
||||
dot.getType().cast<RankedTensorType>().getEncoding();
|
||||
// prefetched dot
|
||||
Operation *firstDot = builder.clone(*dot, mapping);
|
||||
if (Value a = operand2headPrefetch.lookup(dot.a()))
|
||||
firstDot->setOperand(
|
||||
0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
|
||||
if (Value b = operand2headPrefetch.lookup(dot.b()))
|
||||
firstDot->setOperand(
|
||||
1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
|
||||
|
||||
// remaining part
|
||||
int64_t kOff = prefetchWidth;
|
||||
int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
|
||||
prefetchWidth;
|
||||
Operation *prevDot = firstDot;
|
||||
while (kRem != 0) {
|
||||
int64_t kShape = largestPow2(kRem);
|
||||
Value aRem =
|
||||
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
|
||||
dotEncoding, builder, kOff, kShape);
|
||||
Value bRem =
|
||||
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
|
||||
dotEncoding, builder, kOff, kShape);
|
||||
newOp = builder.clone(*dot, mapping);
|
||||
newOp->setOperand(0, aRem);
|
||||
newOp->setOperand(1, bRem);
|
||||
newOp->setOperand(2, prevDot->getResult(0));
|
||||
prevDot = newOp;
|
||||
kOff += kShape;
|
||||
kRem -= kShape;
|
||||
}
|
||||
} else {
|
||||
newOp = builder.clone(op, mapping);
|
||||
}
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||
}
|
||||
|
||||
// prefetch next iteration
|
||||
SmallVector<Value> yieldValues;
|
||||
for (Value v : forOp.getBody()->getTerminator()->getOperands())
|
||||
yieldValues.push_back(mapping.lookup(v));
|
||||
for (Value dot : dots) {
|
||||
Attribute dotEncoding =
|
||||
dot.getType().cast<RankedTensorType>().getEncoding();
|
||||
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
|
||||
true, dotEncoding, builder));
|
||||
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
|
||||
true, dotEncoding, builder));
|
||||
}
|
||||
// Update ops of yield
|
||||
builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
|
||||
void runOnOperation() override {
|
||||
getOperation()->walk([&](scf::ForOp forOp) {
|
||||
Prefetcher prefetcher(forOp);
|
||||
|
||||
if (prefetcher.initialize().failed())
|
||||
return;
|
||||
|
||||
prefetcher.emitPrologue();
|
||||
|
||||
scf::ForOp newForOp = prefetcher.createNewForOp();
|
||||
|
||||
// replace the original loop
|
||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||
forOp->erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
|
||||
return std::make_unique<PrefetchPass>();
|
||||
}
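A minimal usage sketch, not part of this diff: the factory above would typically be wired into an mlir::PassManager by the compiler driver. The wrapper function below is illustrative only; the only symbol taken from this commit is createTritonGPUPrefetchPass().

// Illustrative only: registering the prototype prefetch pass with a pass manager.
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

static void addDotPrefetch(mlir::PassManager &pm) {
  // Walks every scf.for in the module and rewrites prefetchable tt.dot loops
  // as described in the header comment of Prefetch.cpp.
  pm.addPass(mlir::createTritonGPUPrefetchPass());
}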
|
@@ -39,23 +39,23 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> mat_shape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
bool is_int8_mma = ty.getElementType().isInteger(8);
|
||||
if (is_int8_mma && order[0] == inner)
|
||||
bool isInt8Mma = ty.getElementType().isInteger(8);
|
||||
if (isInt8Mma && order[0] == inner)
|
||||
return noSwizzling;
|
||||
// compute swizzling for A operand
|
||||
if (opIdx == 0) {
|
||||
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
}
|
||||
// compute swizzling for B operand
|
||||
else if (opIdx == 1) {
|
||||
int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else {
|
||||
@@ -67,32 +67,64 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
op->walk([&](triton::DotOp dotOp) -> void {
|
||||
OpBuilder builder(dotOp);
|
||||
auto _retEncoding =
|
||||
dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!retEncoding)
|
||||
return;
|
||||
for (int opIdx : {0, 1}) {
|
||||
Value op = dotOp.getOperand(opIdx);
|
||||
auto ty = op.getType().template cast<RankedTensorType>();
|
||||
// compute new swizzled encoding
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
ty.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
|
||||
newEncoding);
|
||||
Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op.getLoc(), newType, op);
|
||||
// bind new op to dot operand
|
||||
dotOp->replaceUsesOfWith(op, newOp->getResult(0));
|
||||
// replace blocked -> dot_op with
|
||||
// blocked -> shared -> dot_op in order to
|
||||
// expose opportunities for swizzling
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto tmpType =
|
||||
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), 1, 1, 1, {1, 0}));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
}
|
||||
});
|
||||
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto arg = cvtOp.getOperand();
|
||||
auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto argType = arg.getType().cast<RankedTensorType>();
|
||||
auto argEncoding =
|
||||
argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!argEncoding || !retEncoding)
|
||||
return;
|
||||
auto opIdx = retEncoding.getOpIdx();
|
||||
// compute new swizzled encoding
|
||||
auto parentEncoding =
|
||||
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!parentEncoding)
|
||||
return;
|
||||
auto swizzleType = argType;
|
||||
if (arg.getDefiningOp() &&
|
||||
isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
|
||||
swizzleType = arg.getDefiningOp()
|
||||
->getOperand(0)
|
||||
.getType()
|
||||
.cast<RankedTensorType>();
|
||||
}
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
argEncoding.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(
|
||||
argType.getShape(), argType.getElementType(), newEncoding);
|
||||
Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newType, arg);
|
||||
// bind new op to cvt operand
|
||||
cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
@@ -95,8 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>())
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return true;
return false;
});
|
||||
|