[OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)

Continues the work in https://github.com/openai/triton/pull/990.

# Background
On Volta, the `versionMinor` field in `MmaEncodingAttr` encodes some state of a
DotOp's operands. Other rewrite patterns may later modify those operands,
leaving that state out of date.

This PR corrects the stale state.
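
On Volta, the stale state is visible as a disagreement between the `isMMAv1Row` flags on the dot operands' layouts and the flags encoded in the parent `MmaEncodingAttr`'s `versionMinor`. Below is a minimal sketch of that check, assuming the helpers used in the diff below (`decodeVoltaLayoutStates`, `getIsMMAv1Row`); the wrapper function itself is illustrative only, not the actual pass code.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;

// Returns true when the row-major flags recorded in a Volta MmaEncodingAttr
// (packed into versionMinor) no longer match the isMMAv1Row flags carried by
// the dot operands' layouts. Sketch only; not the actual pass code.
static bool voltaMmaStateIsStale(RankedTensorType aTy, RankedTensorType bTy,
                                 MmaEncodingAttr mmaLayout) {
  auto dotOperandA = aTy.getEncoding().cast<DotOperandEncodingAttr>();
  auto dotOperandB = bTy.getEncoding().cast<DotOperandEncodingAttr>();
  bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
  bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
  auto [isARow_, isBRow_, isAVec4, isBVec4] =
      mmaLayout.decodeVoltaLayoutStates();
  return isARow_ != isARow || isBRow_ != isBRow;
}
```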

# Implementation
It adds two new patterns (a condensed sketch of how they fit together follows the list):

1. `CollectMmaToUpdateForVolta` collects the `MmaEncodingAttr` instances with
stale state into a map and builds a corrected replacement for each.
2. `UpdateMMAVersionMinorForVolta` replaces the ops that produce the stale
`MmaEncodingAttr` instances with ops carrying the corrected layout. It
currently supports the following ops:
    a. `convert_layout[X -> mma]`
    b. `arith.constant SplatAttr : !tensor<mma>`
    c. `dot ... : !tensor<mma>`
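
The two patterns run as separate greedy-rewrite phases inside the TritonGPU combine pass. A condensed sketch of the driver, mirroring the pass body in the diff below (`m` is the module being processed and `context` its `MLIRContext`; the enclosing pass class and unrelated patterns are omitted):

```cpp
// Phase 1 is effectively an analysis: CollectMmaToUpdateForVolta only records
// stale -> corrected layout pairs in mmaToUpdate and always returns failure(),
// so the greedy driver leaves the IR untouched.
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
{
  mlir::RewritePatternSet patterns(context);
  patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
  if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
    signalPassFailure();
}
// Phase 2 rewrites every supported op whose result type still carries a stale
// layout; one pattern instance is registered per root op name.
{
  mlir::RewritePatternSet patterns(context);
  patterns.add<UpdateMMAVersionMinorForVolta>(
      context, DotOp::getOperationName(), mmaToUpdate);
  patterns.add<UpdateMMAVersionMinorForVolta>(
      context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
  patterns.add<UpdateMMAVersionMinorForVolta>(
      context, arith::ConstantOp::getOperationName(), mmaToUpdate);
  mlir::GreedyRewriteConfig config;
  config.useTopDownTraversal = true; // visit producers before their users
  if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
    signalPassFailure();
}
```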

# Limitation
This PR takes the mapping approach to avoid the IR-walking complexity caused by
the circular dependency between dot_operand[parent] and mma.
We use the `MmaEncodingAttr` instance as the map key, but multiple DotOps
holding different DotOperand(isMMAv1Row) states may share the same stale
`MmaEncodingAttr` instance.
To make each DotOp's (stale) `MmaEncodingAttr` unique, we might need to add an
ID field to `MmaEncodingAttr`.
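
Concretely, because MLIR attributes are uniqued per context, the map can hold only one correction per stale layout value. A small illustration of the collision (all names below are hypothetical; only the map type matches the one used in the pass):

```cpp
#include "llvm/ADT/DenseMap.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using mlir::triton::gpu::MmaEncodingAttr;

// Two tt.dot ops whose operands disagree on isMMAv1Row can still share the
// same (uniqued) stale MmaEncodingAttr, so their corrections compete for one
// map slot and only the first recorded one survives.
void recordCorrections(MmaEncodingAttr staleMma,
                       MmaEncodingAttr fixedForFirstDot,
                       MmaEncodingAttr fixedForSecondDot,
                       llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> &map) {
  map.try_emplace(staleMma, fixedForFirstDot);  // inserted
  map.try_emplace(staleMma, fixedForSecondDot); // ignored: key already present
  // An explicit ID field on MmaEncodingAttr would make the two keys distinct.
}
```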
Commit 2ba74d2729 (parent fd2da4aff6), authored by Yan Chunwei on 2022-12-28 12:24:01 +08:00 and committed by GitHub.
3 changed files with 281 additions and 26 deletions.


@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
}
Type f16x2Ty = vec_ty(f16_ty, 2);
Type elemX2Ty = vec_ty(f16_ty, 2);
Type elemPtrTy = ptr_ty(f16_ty);
if (tensorTy.getElementType().isBF16()) {
elemX2Ty = vec_ty(i16_ty, 2);
elemPtrTy = ptr_ty(i16_ty);
}
// prepare arguments
SmallVector<Value> ptrA(numPtrA);
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
for (int i = 0; i < numPtrA; i++)
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
auto loadA = [&](int m, int k) {
int offidx = (isARow ? k / 4 : m) % numPtrA;
Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
int stepAM = isARow ? m : m / numPtrA * numPtrA;
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
mul(i32_val(stepAK), strideAK));
Value pa = gep(f16PtrTy, thePtrA, offset);
Value pa = gep(elemPtrTy, thePtrA, offset);
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
Value ha = load(bitcast(pa, aPtrTy));
// record lds that need to be moved
Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
ld(has, m, k, ha00, ha01);
if (vecA > 4) {
Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
if (isARow)
ld(has, m, k + 4, ha10, ha11);
else
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
}
Type f16PtrTy = ptr_ty(f16_ty);
Type f16x2Ty = vec_ty(f16_ty, 2);
Type elemPtrTy = ptr_ty(f16_ty);
Type elemX2Ty = vec_ty(f16_ty, 2);
if (tensorTy.getElementType().isBF16()) {
elemPtrTy = ptr_ty(i16_ty);
elemX2Ty = vec_ty(i16_ty, 2);
}
SmallVector<Value> ptrB(numPtrB);
ValueTable hbs;
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
mul(i32_val(stepBK), strideBK));
Value pb = gep(f16PtrTy, thePtrB, offset);
Value pb = gep(elemPtrTy, thePtrB, offset);
Value hb =
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
// record lds that need to be moved
Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
ld(hbs, n, K, hb00, hb01);
if (vecB > 4) {
Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
if (isBRow)
ld(hbs, n + 1, K, hb10, hb11);
else
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
elems.push_back(item.second.first);
elems.push_back(item.second.second);
}
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}


@@ -22,6 +22,10 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
// -----------------------------------------------------------------------------
//
@@ -1019,6 +1023,7 @@ public:
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1060,7 +1065,8 @@ public:
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1170,7 +1176,8 @@ public:
for (size_t i = 0; i < newInitArgs.size(); i++) {
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
shouldRematerialize = true;
break;
}
@@ -1186,15 +1193,207 @@ public:
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->getOperations()) {
Operation *newOp = rewriter.clone(op, mapping);
rewriter.clone(op, mapping);
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};
// This pattern collects the stale MMA layouts that need updating and creates
// a corrected layout for each.
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
public:
CollectMmaToUpdateForVolta(
mlir::MLIRContext *ctx,
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
mmaToUpdate(mmaToUpdate) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return failure();
// Already processed.
if (mmaToUpdate.count(mmaLayout))
return failure();
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4, isBVec4] =
mmaLayout.decodeVoltaLayoutStates();
if (isARow_ == isARow && isBRow_ == isBRow) {
return failure(); // No need to update
}
auto newMmaLayout = MmaEncodingAttr::get(
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
AT.getShape(), BT.getShape(), isARow, isBRow);
// Record the stale MMA layout together with its corrected replacement.
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
return failure();
}
};
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
enum class Kind {
kUnk,
kCvtToMma,
kCvtToDotOp,
kDot,
kConstant,
};
mutable Kind rewriteKind{Kind::kUnk};
public:
UpdateMMAVersionMinorForVolta(
mlir::MLIRContext *ctx, llvm::StringRef opName,
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
LogicalResult match(Operation *op) const override {
MmaEncodingAttr mma;
if (mmaToUpdate.empty())
return failure();
if (op->getNumResults() != 1)
return failure();
auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return failure();
// ConvertLayoutOp
if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
// cvt X -> dot_operand
if (auto dotOperand =
tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kCvtToDotOp;
if (mma && mmaToUpdate.count(mma))
return success();
}
if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
// cvt X -> mma
rewriteKind = Kind::kCvtToMma;
if (mma && mmaToUpdate.count(mma))
return success();
}
} else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
// DotOp
mma = dot.d()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kDot;
} else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
// ConstantOp
mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kConstant;
}
return success(mma && mmaToUpdate.count(mma));
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
switch (rewriteKind) {
case Kind::kDot:
rewriteDot(op, rewriter);
break;
case Kind::kConstant:
rewriteConstant(op, rewriter);
break;
case Kind::kCvtToDotOp:
rewriteCvtDotOp(op, rewriter);
break;
case Kind::kCvtToMma:
rewriteCvtToMma(op, rewriter);
break;
default:
llvm::report_fatal_error("Not supported rewrite kind");
}
}
private:
void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr newMma =
mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
auto newDotOperand = DotOperandEncodingAttr::get(
ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
auto newTensorTy = RankedTensorType::get(
tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto dot = llvm::cast<DotOp>(op);
auto tensorTy = dot.d().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
dot.c(), dot.allowTF32());
}
void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto constant = llvm::cast<arith::ConstantOp>(op);
auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet =
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
return;
}
assert(false && "Not supported ConstantOp value type");
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -1229,6 +1428,28 @@ public:
signalPassFailure();
}
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
{
mlir::RewritePatternSet patterns(context);
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
{
mlir::RewritePatternSet patterns(context);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, DotOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, arith::ConstantOp::getOperationName(), mmaToUpdate);
mlir::GreedyRewriteConfig config;
config.useTopDownTraversal = true;
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
}
mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {


@@ -1,4 +1,4 @@
// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}
// -----
// check the UpdateMMAVersionMinorForVolta pattern
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
// Here, the isMMAv1Row flags on $a's and $b's dot_operands do not match #mma0's versionMinor,
// so the pattern should update the versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// A new MMA layout is created to match $a's and $b's dot_operand layouts.
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_mmav1
func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
%res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
return %res : tensor<16x16xf32, #blocked0>
}
}