[BACKEND] MMA->DotOperand conversion for chain dot of float32 tensors (#962)
Co-authored-by: Philippe Tillet <phil@openai.com>
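
When the fp32 result of one dot feeds the first operand of a second dot (a
"chain dot"), the mma -> dot_op layout conversion can now stay in registers
instead of round-tripping through shared memory. The shortcut applies only
when the destination is dot_op<opIdx=0, parent=#mma> and the parent #mma has
warpsPerCTA = [..., 1]; all other mma -> dot_op conversions are decomposed
into mma -> blocked -> dot_op before the dialect conversion runs. Roughly,
the pattern this targets, in the notation of the comments below (shapes and
attribute syntax illustrative only):

    %c  = ... : tensor<128x128xf32, #mma>           // result of the first dot
    %cc = triton_gpu.convert_layout %c
          : #mma -> #dot_op<opIdx=0, parent=#mma>   // now register-to-register
    %d  = tt.dot %cc, %b, ...                       // second dot consumes it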
@@ -2682,63 +2682,24 @@ public:
               dstLayout.isa<SliceEncodingAttr>())) {
    return lowerDistributedToDistributed(op, adaptor, rewriter);
  }
  // dot_op<opIdx=0, parent=#mma> = #mma
  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
  if (srcLayout.isa<MmaEncodingAttr>() &&
      dstLayout.isa<DotOperandEncodingAttr>()) {
    auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
    auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
    if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
        dstDotLayout.getOpIdx() == 0 &&
        dstDotLayout.getParent() == srcMmaLayout) {
      // get source values
      Location loc = op->getLoc();
      auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
      unsigned elems = getElemsPerThread(srcTy);
      Type elemTy =
          this->getTypeConverter()->convertType(srcTy.getElementType());
      // for the destination type, we need to pack values together
      // so they can be consumed by tensor core operations
      unsigned vecSize =
          std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
      Type vecTy = vec_ty(elemTy, vecSize);
      SmallVector<Type> types(elems / vecSize, vecTy);
      SmallVector<Value> vecVals;
      for (unsigned i = 0; i < elems; i += vecSize) {
        Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
        for (unsigned j = 0; j < vecSize; j++)
          packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
        vecVals.push_back(packed);
      }

      // This needs to be ordered the same way that
      // ldmatrix.x4 would order it
      // TODO: this needs to be refactored so we don't
      // implicitly depend on how emitOffsetsForMMAV2
      // is implemented
      SmallVector<Value> reorderedVals;
      for (unsigned i = 0; i < vecVals.size(); i += 4) {
        reorderedVals.push_back(vecVals[i]);
        reorderedVals.push_back(vecVals[i + 2]);
        reorderedVals.push_back(vecVals[i + 1]);
        reorderedVals.push_back(vecVals[i + 3]);
      }

      // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);

      Type structTy =
          LLVM::LLVMStructType::getLiteral(this->getContext(), types);
      Value view =
          getStructFromElements(loc, reorderedVals, rewriter, structTy);
      rewriter.replaceOp(op, view);
      return success();
    }
    return lowerMmaToDotOperand(op, adaptor, rewriter);
  }
  // TODO: to be implemented
  llvm_unreachable("unsupported layout conversion");
  return failure();
}

static bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                               DotOperandEncodingAttr &dotOperandLayout) {
  // dot_op<opIdx=0, parent=#mma> = #mma
  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
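  // Roughly: warpsPerCTA[1] == 1 means warps are not split along the N axis
  // of the accumulator, so each warp presumably already holds the full slice
  // it needs when the result is reused as operand A (opIdx=0) of a next dot.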
  return mmaLayout.getWarpsPerCTA()[1] == 1 &&
         dotOperandLayout.getOpIdx() == 0 &&
         dotOperandLayout.getParent() == mmaLayout;
}

static void storeBlockedToShared(Value src, Value llSrc,
                                 ArrayRef<Value> srcStrides,
                                 ArrayRef<Value> srcIndices, Value dst,
@@ -3003,6 +2964,11 @@ private:
  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const;

  // mma -> dot_operand
  LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
                                     OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const;

  // shared -> dot_operand if the result layout is mma
  Value lowerSharedToDotOperandMMA(
      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3209,6 +3175,58 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
  return success();
}

LogicalResult ConvertLayoutOpConversion::lowerMmaToDotOperand(
    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto srcTy = op.src().getType().cast<RankedTensorType>();
  auto dstTy = op.result().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding();
  auto dstLayout = dstTy.getEncoding();
  auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
  auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
  if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
    // get source values
    auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
    unsigned elems = getElemsPerThread(srcTy);
    Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
    // for the destination type, we need to pack values together
    // so they can be consumed by tensor core operations
    unsigned vecSize =
        std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
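    // e.g. two f16 values pack into one <2 x f16> vector; in the f32
    // chain-dot case, 32 / 32 == 1, so each value stays scalar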
    Type vecTy = vec_ty(elemTy, vecSize);
    SmallVector<Type> types(elems / vecSize, vecTy);
    SmallVector<Value> vecVals;
    for (unsigned i = 0; i < elems; i += vecSize) {
      Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
      for (unsigned j = 0; j < vecSize; j++)
        packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
      vecVals.push_back(packed);
    }

    // This needs to be ordered the same way that
    // ldmatrix.x4 would order it
    // TODO: this needs to be refactored so we don't
    // implicitly depend on how emitOffsetsForMMAV2
    // is implemented
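    // Concretely: each group of four vectors is re-emitted as (0, 2, 1, 3),
    // i.e. the middle pair is swapped to match ldmatrix.x4's register order.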
    SmallVector<Value> reorderedVals;
    for (unsigned i = 0; i < vecVals.size(); i += 4) {
      reorderedVals.push_back(vecVals[i]);
      reorderedVals.push_back(vecVals[i + 2]);
      reorderedVals.push_back(vecVals[i + 1]);
      reorderedVals.push_back(vecVals[i + 3]);
    }

    // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);

    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
    Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
  }
  return failure();
}

struct InsertSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
@@ -4625,6 +4643,34 @@ class ConvertTritonGPUToLLVM
    : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {

private:
  void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) {
    // replace `mma -> dot_op` with `mma -> blocked -> dot_op`
    // unless the direct shortcut in isMmaToDotShortcut applies
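    // (the blocked -> dot_op conversion introduced here is itself decomposed
    // into blocked -> shared -> dot_op by decomposeBlockedToDotOperand below)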
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      auto srcMma =
          srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (srcMma && dstDotOp &&
          !ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::BlockedEncodingAttr::get(
                mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
                getOrder(srcMma), numWarps));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
        cvtOp.erase();
      }
    });
  }

  void decomposeBlockedToDotOperand(ModuleOp mod) {
    // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
    // because the codegen doesn't handle `blocked -> dot_op` directly
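    // (going through shared memory presumably lets the existing
    // shared -> dot_operand lowering, e.g. lowerSharedToDotOperandMMA,
    // handle the final step)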
@@ -4771,6 +4817,8 @@ public:
    // separation between steps 1 and 4 is that step 3 is out of the scope
    // of Dialect Conversion; thus we need to make sure the smem is not
    // revised during the conversion of step 4.
    decomposeMmaToDotOperand(mod, numWarps);

    decomposeBlockedToDotOperand(mod);

    decomposeInsertSliceAsyncOp(mod);