[Triton-MLIR] tt.dot
operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -39,23 +39,23 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> mat_shape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
bool is_int8_mma = ty.getElementType().isInteger(8);
|
||||
if (is_int8_mma && order[0] == inner)
|
||||
bool isInt8Mma = ty.getElementType().isInteger(8);
|
||||
if (isInt8Mma && order[0] == inner)
|
||||
return noSwizzling;
|
||||
// compute swizzling for A operand
|
||||
if (opIdx == 0) {
|
||||
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
}
|
||||
// compute swizzling for B operand
|
||||
else if (opIdx == 1) {
|
||||
int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else {
|
||||
@@ -67,32 +67,64 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
op->walk([&](triton::DotOp dotOp) -> void {
|
||||
OpBuilder builder(dotOp);
|
||||
auto _retEncoding =
|
||||
dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!retEncoding)
|
||||
return;
|
||||
for (int opIdx : {0, 1}) {
|
||||
Value op = dotOp.getOperand(opIdx);
|
||||
auto ty = op.getType().template cast<RankedTensorType>();
|
||||
// compute new swizzled encoding
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
ty.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
|
||||
newEncoding);
|
||||
Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op.getLoc(), newType, op);
|
||||
// bind new op to dot operand
|
||||
dotOp->replaceUsesOfWith(op, newOp->getResult(0));
|
||||
// replace blocked -> dot_op with
|
||||
// blocked -> shared -> dot_op in order to
|
||||
// expose opportunities for swizzling
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto tmpType =
|
||||
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), 1, 1, 1, {1, 0}));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
}
|
||||
});
|
||||
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto arg = cvtOp.getOperand();
|
||||
auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto argType = arg.getType().cast<RankedTensorType>();
|
||||
auto argEncoding =
|
||||
argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!argEncoding || !retEncoding)
|
||||
return;
|
||||
auto opIdx = retEncoding.getOpIdx();
|
||||
// compute new swizzled encoding
|
||||
auto parentEncoding =
|
||||
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!parentEncoding)
|
||||
return;
|
||||
auto swizzleType = argType;
|
||||
if (arg.getDefiningOp() &&
|
||||
isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
|
||||
swizzleType = arg.getDefiningOp()
|
||||
->getOperand(0)
|
||||
.getType()
|
||||
.cast<RankedTensorType>();
|
||||
}
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
argEncoding.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(
|
||||
argType.getShape(), argType.getElementType(), newEncoding);
|
||||
Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newType, arg);
|
||||
// bind new op to cvt operand
|
||||
cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
Reference in New Issue
Block a user