[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)
This commit is contained in:
@@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
||||
// These ops may allocate a new shared memory buffer.
|
||||
auto result = op->getResult(0);
|
||||
// FIXME(Keren): extract and insert are always alias for now
|
||||
if (isa<tensor::ExtractSliceOp>(op)) {
|
||||
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
|
||||
// extract_slice %src
|
||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
pessimistic = false;
|
||||
|
@@ -105,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
}
|
||||
|
||||
/// Returns true when `op` is one of the operation kinds this predicate
/// recognizes: tensor::ExtractSliceOp, triton::TransOp,
/// triton::gpu::InsertSliceAsyncOp or tensor::InsertSliceOp.
///
/// The rendered diff carried both the pre-change and the post-change `return`
/// line; only the post-change set (which includes triton::TransOp) is kept.
bool maybeAliasOp(Operation *op) {
  return isa<tensor::ExtractSliceOp, triton::TransOp,
             triton::gpu::InsertSliceAsyncOp, tensor::InsertSliceOp>(op);
}
|
||||
|
@@ -2716,6 +2716,9 @@ public:
|
||||
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto inOrd = srcBlockedLayout.getOrder();
|
||||
auto outOrd = dstSharedLayout.getOrder();
|
||||
if (inOrd != outOrd)
|
||||
llvm_unreachable(
|
||||
"blocked -> shared with different order not yet implemented");
|
||||
unsigned inVec =
|
||||
inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
|
||||
unsigned outVec = dstSharedLayout.getVec();
|
||||
@@ -2775,7 +2778,8 @@ public:
|
||||
getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
|
||||
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
|
||||
++linearWordIdx) {
|
||||
// step 1: recover the multidim_index from the index of input_elements
|
||||
// step 1: recover the multidim_index from the index of
|
||||
// input_elements
|
||||
auto multiDimWordIdx =
|
||||
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
|
||||
SmallVector<Value> multiDimIdx(2);
|
||||
@@ -3711,6 +3715,33 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
|
||||
/// ====================== mma codegen end ============================
|
||||
|
||||
/// ====================== trans codegen begin ============================
|
||||
|
||||
struct TransOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto srcSmemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
|
||||
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
|
||||
srcSmemObj.strides[0]};
|
||||
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
|
||||
srcSmemObj.offsets[0]};
|
||||
auto dstSmemObj =
|
||||
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ====================== trans codegen end ============================
|
||||
|
||||
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
|
||||
Type resType, Type elemType,
|
||||
Value constVal,
|
||||
@@ -4538,6 +4569,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<TransOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
|
@@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrites triton::TransOp so that it operates on a shared-encoded operand
/// and produces a shared-encoded result.
///
/// - Fails if the (converted) operand has no encoding.
/// - If the operand is not already shared-encoded, a
///   triton::gpu::ConvertLayoutOp to a shared encoding (vec = perPhase =
///   maxPhase = 1) is inserted first. The order is taken from the blocked
///   encoding when there is one and defaults to {1, 0} otherwise.
/// - The result type has the operand's shape and shared order both reversed.
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {

  using OpConversionPattern<triton::TransOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.src();
    auto srcType = src.getType().cast<RankedTensorType>();
    Attribute srcEncoding = srcType.getEncoding();
    if (!srcEncoding)
      return failure();
    if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
      // TODO: end-to-end correctness is broken if
      // the input is blocked and the output is shared
      // with different order. Maybe a backend issue in BlockedToShared?
      SmallVector<unsigned> order = {1, 0};
      if (auto srcBlockedEncoding =
              srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>()) {
        // assign() resizes `order` to match the blocked order, so nothing is
        // written past the vector's size when the order does not have exactly
        // two entries. For a two-entry order this equals the previous copy.
        auto blockedOrder = srcBlockedEncoding.getOrder();
        order.assign(blockedOrder.begin(), blockedOrder.end());
      }
      srcEncoding =
          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
      srcType = RankedTensorType::get(srcType.getShape(),
                                      srcType.getElementType(), srcEncoding);
      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
                                                          src);
    }
    auto srcSharedEncoding =
        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
                                   srcSharedEncoding.getOrder().end());
    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
                                   srcType.getShape().end());
    // The result's shape and order are the operand's, reversed.
    std::reverse(retOrder.begin(), retOrder.end());
    std::reverse(retShapes.begin(), retShapes.end());
    auto retEncoding =
        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
    auto retType =
        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);

    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
    return success();
  }
};
|
||||
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
@@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
|
||||
TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
@@ -178,6 +178,10 @@ public:
|
||||
!isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
if (isSharedEncoding(convert.getOperand()) &&
|
||||
isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto srcShared =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
@@ -661,6 +665,54 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
|
||||
} // namespace
|
||||
|
||||
class OptimizeBlockedToShared : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto dstSharedLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!srcBlockedLayout || !dstSharedLayout)
|
||||
return failure();
|
||||
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
|
||||
return failure();
|
||||
// For now only works if single use is transpose
|
||||
// TODO: rematerialize #shared uses
|
||||
auto users = op->getUsers();
|
||||
if (std::distance(users.begin(), users.end()) != 1 ||
|
||||
!isa<triton::TransOp>(*users.begin()))
|
||||
return failure();
|
||||
|
||||
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), dstSharedLayout.getVec(),
|
||||
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
|
||||
srcBlockedLayout.getOrder());
|
||||
auto tmpType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), tmpShared);
|
||||
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), tmpType, cvt.getOperand());
|
||||
|
||||
auto newDstType = RankedTensorType::get(
|
||||
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
|
||||
srcType.getElementType(), dstSharedLayout);
|
||||
|
||||
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
|
||||
tmpCvt.getResult());
|
||||
|
||||
rewriter.replaceOp(*users.begin(), newTrans.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -755,6 +807,7 @@ public:
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<OptimizeBlockedToShared>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
|
Reference in New Issue
Block a user