Compare commits: keren/v100...phil/mma-v (4 commits)

Author | SHA1 | Date
---|---|---
 | 58d2867fe6 |
 | fa6dbbff60 |
 | 13644e7ac4 |
 | 0d27912554 |
@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.

For MMA v1, an additional attribute `isMMAv1Row` determines whether, e.g., the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation,
section 9.7.13.4.1, for more details.
}];

  let parameters = (
    ins
    "unsigned":$opIdx,
    "Attribute":$parent
    "Attribute":$parent,
    "Attribute":$isMMAv1Row
  );

  let builders = [
    AttrBuilder<(ins "unsigned":$opIdx,
                     "Attribute":$parent), [{
      Attribute isMMAv1Row;
      if(parent.isa<MmaEncodingAttr>() &&
         parent.cast<MmaEncodingAttr>().getVersion() == 1){
        isMMAv1Row = BoolAttr::get(context, true);
      }
      return $_get(context, opIdx, parent, isMMAv1Row);
    }]>
  ];

  let extraClassDeclaration = extraBaseClassDeclaration;
}


#endif
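For illustration, a consumer of the new attribute might branch on the flavor as sketched below. Note that the builder above leaves isMMAv1Row as a null Attribute whenever the parent is not MMA v1, so any reader has to guard against that. This is a minimal sketch assuming the generated getIsMMAv1Row() accessor; pickRowVariantForA is a hypothetical helper, not code from this branch:

// Hypothetical helper: decide between mma.884.row.col and mma.884.col.col
// for operand a. getIsMMAv1Row() is null unless the parent is MMA v1.
static bool pickRowVariantForA(triton::gpu::DotOperandEncodingAttr layout) {
  if (Attribute row = layout.getIsMMAv1Row())
    return row.cast<BoolAttr>().getValue(); // true -> mma.884.row.col
  return true; // flag absent (non-v1 parent); callers should not rely on this
}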
@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
    Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
    Location loc, ConversionPatternRewriter &rewriter) const {

  // [1, 0] (isRow = True)
  // x x x x || x x x x
  // x x x x || x x x x
  // stride = [8, 1]
  // strideA0 = strideAk = 1
  // strideA1 = strideAm = 8

  // [0, 1] (isRow = False)
  // x x x x || x x x x
  // x x x x || x x x x
  // stride = [1, 2]
  // strideA0 = strideAm = 1
  // strideA1 = strideAk = 2

  auto *ctx = rewriter.getContext();
  auto tensorTy = tensor.getType().cast<RankedTensorType>();
  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1364,8 +1378,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
  SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                              sharedLayout.getOrder().end());

  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
  // Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
  Value smemBase = smemObj.base;

  bool isARow = order[0] != 0;
  AParam param(isARow);
@@ -1387,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
  Value strideA0 = isARow ? strideAK : strideAM;
  Value strideA1 = isARow ? strideAM : strideAK;

  smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
  int strideRepM = wpt[0] * fpw[0] * 8;
  int strideRepK = 1;

@@ -1401,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
  Value offA0 = isARow ? offsetAK : offsetAM;
  Value offA1 = isARow ? offsetAM : offsetAK;
  Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
  offA0 = add(offA0, cSwizzleOffset);
  // offA0 = add(offA0, smemObj.offsets[order[0]]);
  // offA1 = add(offA1, smemObj.offsets[order[1]]);

  SmallVector<Value> offA(numPtrA);
  for (int i = 0; i < numPtrA; i++) {
    Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(

  Type f16PtrTy = ptr_ty(f16_ty);


  auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
    vals[{m, k}] = {val0, val1};
  };
@@ -1451,7 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
    }
  };


  unsigned numM = getNumM(shape, order);
  llvm::outs() << "LOAD A " << numM << " " << NK << "\n";

  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      loadA(m, k);
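A standalone restatement of the two stride diagrams above, using the hypothetical 2x8 fp16 tile they depict (an illustrative sketch, not code from this branch):

#include <cassert>
// In both orders, strideA0 is the stride along the contiguous dimension.
int main() {
  for (bool isARow : {true, false}) {
    int strideAK = isARow ? 1 : 2; // K is contiguous iff A is row-major
    int strideAM = isARow ? 8 : 1; // M is contiguous iff A is col-major
    int strideA0 = isARow ? strideAK : strideAM;
    int strideA1 = isARow ? strideAM : strideAK;
    assert(strideA0 == 1);
    assert(strideA1 == (isARow ? 8 : 2));
  }
}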
@@ -3427,8 +3427,24 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
  }
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
           isHMMA) { // tensor core v1
  // vprintf("offset 0", smemObj.offsets[0]}, rewriter);
  DotOpMmaV1ConversionHelper helper(mmaLayout);
  bool isMMAv1Row =
      dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
  auto srcSharedLayout = src.getType()
                             .cast<RankedTensorType>()
                             .getEncoding()
                             .cast<SharedEncodingAttr>();

  // Can only convert [1, 0] to row or [0, 1] to col for now
  if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
      (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
    llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
    return Value();
  }

  if (dotOperandLayout.getOpIdx() == 0) { // operand $a
    // LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
    // TODO[Superjomn]: transA is not available here.
    bool transA = false;
    res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
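The guard above reduces to a single predicate: the shared layout's fastest-varying axis has to agree with the requested MMAv1 flavor. A sketch of the same check (hypothetical helper name):

// order[0] == 1 means the inner (second) dimension is contiguous, which is
// what a row-flavored MMAv1 operand expects; order[0] == 0 matches col.
static bool sharedOrderMatchesMmaV1Row(llvm::ArrayRef<unsigned> order,
                                       bool isMMAv1Row) {
  return (order[0] == 1) == isMMAv1Row;
}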
@@ -3540,6 +3556,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
                       .cast<RankedTensorType>()
                       .getEncoding()
                       .cast<MmaEncodingAttr>();
  auto ALayout = A.getType()
                     .cast<RankedTensorType>()
                     .getEncoding()
                     .cast<DotOperandEncodingAttr>();
  auto BLayout = B.getType()
                     .cast<RankedTensorType>()
                     .getEncoding()
                     .cast<DotOperandEncodingAttr>();

  auto ATensorTy = A.getType().cast<RankedTensorType>();
  auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3551,12 +3575,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
  auto DShape = DTensorTy.getShape();
  auto wpt = mmaLayout.getWarpsPerCTA();

  // TODO[Superjomn]: order cannot be accessed in DotOp.
  SmallVector<unsigned> AOrder({1, 0});
  SmallVector<unsigned> BOrder({1, 0});

  bool isARow = AOrder[0] != 0;
  bool isBRow = BOrder[0] != 0;
  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
  bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
  bool isBVec4 = isBRow && BShape[isBRow] <= 16;
  // TODO[Superjomn]: ld.v4 is not supported.
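One subtlety in the vec4 heuristics above: the bool is used as a shape index (0 or 1), and in the branch where it is actually evaluated it selects the contiguous dimension, so the test reads as "contiguous extent of at most 16 fp16 elements". A worked instance (this reading and the 128x32 shape are assumptions, not from the diff):

// Hypothetical 128x32 a-operand, col-major: contiguous dim is AShape[0].
int64_t AShape[2] = {128, 32};
bool isARow = false;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // AShape[0] = 128 -> false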
@@ -4698,6 +4718,47 @@ private:
    });
  }

  void rewriteConvertToDotOperand(ModuleOp mod) {
    mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
      OpBuilder builder(cvt);
      auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
      // order
      ArrayRef<unsigned> order;
      if(auto srcBlockedLayout =
             srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
        order = srcBlockedLayout.getOrder();
      else if(auto srcSharedLayout =
                  srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
        order = srcSharedLayout.getOrder();
      else
        return;
      // dot operand output
      auto dstDotOperandLayout =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (!dstDotOperandLayout)
        return;
      unsigned opIdx = dstDotOperandLayout.getOpIdx();
      if(!dstDotOperandLayout.getIsMMAv1Row())
        return;
      bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
      if((order[0] == 1 && isMMAv1Row) ||
         (order[0] == 0 && !isMMAv1Row))
        return;
      auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
      auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
          cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
          newIsRow);
      auto newDstType = RankedTensorType::get(
          dstType.getShape(),
          dstType.getElementType(), newDstEncoding);
      auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
          cvt.getLoc(), newDstType, cvt.getOperand());
      cvt.replaceAllUsesWith(newCvt.getResult());
      cvt.erase();
    });
  }

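In IR terms, the walk above retypes rather than rejects: a convert whose destination asks for, say, isMMAv1Row = true while its source order is [0, 1] is rebuilt with isMMAv1Row = false, so that the Shared -> DotOperand[MMAv1] guard earlier in this diff is satisfied and the mma.884 lowering picks the matching variant.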
  void decomposeInsertSliceAsyncOp(ModuleOp mod) {
    AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
    axisInfoAnalysis.run(mod);
@@ -4817,6 +4878,7 @@ public:
    // separation between 1/4 is that, step 3 is out of the scope of Dialect
    // Conversion, thus we need to make sure the smem is not revised during the
    // conversion of step 4.
    rewriteConvertToDotOperand(mod);
    decomposeMmaToDotOperand(mod, numWarps);

    decomposeBlockedToDotOperand(mod);
@@ -4827,6 +4889,7 @@ public:
    MembarAnalysis membarPass(&allocation);
    membarPass.run();

    llvm::outs() << mod << "\n";
    RewritePatternSet scf_patterns(context);
    mlir::populateLoopToStdConversionPatterns(scf_patterns);
    mlir::ConversionTarget scf_target(*context);
@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
    return {};
  unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
  Attribute parent = attrs.get("parent");

  Attribute isMMAv1Row;
  if(parent.isa<MmaEncodingAttr>() &&
     parent.cast<MmaEncodingAttr>().getVersion() == 1){
    isMMAv1Row = attrs.get("isMMAv1Row");
    if(!isMMAv1Row)
      llvm::report_fatal_error("isMMAv1Row attribute is missing");
  }
  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent);
                                                   parent, isMMAv1Row);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "opIdx = " << getOpIdx() << ", "
          << "parent = " << getParent() << "}>";
          << "parent = " << getParent();
  if(getIsMMAv1Row())
    printer << ", isMMAv1Row = " << getIsMMAv1Row();
  printer << "}>";
}

//===----------------------------------------------------------------------===//
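With this change an MMAv1 operand round-trips through the textual form exercised by the lit test at the end of this diff, e.g. #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>, while operands with non-v1 parents keep the original two-field form, since print() only emits the field when getIsMMAv1Row() is non-null.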
@@ -713,6 +713,55 @@ public:
  }
};

class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    // order
    ArrayRef<unsigned> order;
    if(auto srcBlockedLayout =
           srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
      order = srcBlockedLayout.getOrder();
    else if(auto srcSharedLayout =
                srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
      order = srcSharedLayout.getOrder();
    else
      return failure();
    // dot operand output
    auto dstDotOperandLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!dstDotOperandLayout)
      return failure();
    unsigned opIdx = dstDotOperandLayout.getOpIdx();
    if(!dstDotOperandLayout.getIsMMAv1Row())
      return failure();
    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
    if((order[0] == 1 && isMMAv1Row) ||
       (order[0] == 0 && !isMMAv1Row))
      return failure();
    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
        newIsRow);
    auto newDstType = RankedTensorType::get(
        dstType.getShape(),
        dstType.getElementType(), newDstEncoding);
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), newDstType, cvt.getOperand());
    rewriter.replaceOp(op, newCvt.getResult());
    return success();
  }
};

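This pattern is the same logic as rewriteConvertToDotOperand earlier in the diff, recast as a RewritePattern (PatternRewriter plus failure()/success() instead of a module walk); as the pattern registration further below shows, the pattern form is currently commented out in favor of the walk.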
class BlockedToMMA : public mlir::RewritePattern {
  int computeCapability;

@@ -770,14 +819,28 @@ public:
    Value b = dotOp.b();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
    Attribute isMMAv1RowA;
    Attribute isMMAv1RowB;
    if(version == 1){
      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
    }

    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
                                                 newRetType.getEncoding()));
                                                 newRetType.getEncoding(),
                                                 isMMAv1RowA));
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
                                                 newRetType.getEncoding()));
                                                 newRetType.getEncoding(),
                                                 isMMAv1RowB));

    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
    auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +871,7 @@ public:
    mlir::RewritePatternSet patterns(context);

    patterns.add<OptimizeBlockedToShared>(context);
    // patterns.add<OptimizeConvertToDotOperand>(context);
    patterns.add<SimplifyConversion>(context);
    patterns.add<DecomposeDotOperand>(context);
    patterns.add<RematerializeBackward>(context);
@@ -818,6 +882,7 @@ public:
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }

  }
};

@@ -297,15 +297,23 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
    # Non-forloop
    [16, 16, 16, 1, 16, 16, 16, False, False],
    [16, 16, 32, 1, 16, 16, 32, False, False],
    [32, 16, 32, 1, 32, 16, 32, False, False],
    [32, 32, 32, 1, 32, 32, 32, False, False],
    [128, 32, 32, 1, 128, 32, 32, False, False],
    # [16, 16, 16, 1, 16, 16, 16, False, False],
    # [16, 16, 32, 1, 16, 16, 32, False, False],
    # [32, 16, 32, 1, 32, 16, 32, False, False],
    # [32, 32, 32, 1, 32, 32, 32, False, False],
    # [128, 32, 32, 1, 128, 32, 32, False, False],

    # split-K
    [16, 16, 32, 1, 16, 16, 16, False, False],
    [64, 64, 128, 1, 64, 64, 32, False, False],
    # [128, 32, 32, 1, 128, 32, 32, True, False],
    # [128, 32, 32, 1, 128, 32, 32, True, True],

    # # split-K
    # [16, 16, 32, 1, 16, 16, 16, False, False],
    # [64, 64, 128, 1, 64, 64, 32, False, False],

    # [16, 16, 32, 1, 16, 16, 16, True, False],
    # [16, 16, 32, 1, 16, 16, 16, True, True],
    [64, 64, 64, 1, 64, 64, 32, True, False],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
    test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

@@ -1402,9 +1402,9 @@ def compile(fn, **kwargs):
    "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
             lambda src: ast_to_ttir(src, signature, configs[0], constants)),
    "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
              lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
              lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
    "llir": (lambda path: Path(path).read_bytes(),
             lambda src: ttgir_to_llir(src, extern_libs, capability)),
             lambda src: ttgir_to_llir(src, extern_libs, 70)),
    "ptx": (lambda path: Path(path).read_text(),
            lambda src: llir_to_ptx(src, capability)),
    "cubin": (lambda path: Path(path).read_bytes(),
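Note that both replaced lambdas pin the compute capability to 70 (sm_70, i.e. Volta) instead of the detected capability, forcing the MMA v1 (mma.884) path through the ttgir and llir stages regardless of the actual GPU; this reads as a temporary debugging override rather than an intended permanent change.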
@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
                                            %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {