[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)
@@ -2716,6 +2716,9 @@ public:
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
     auto inOrd = srcBlockedLayout.getOrder();
     auto outOrd = dstSharedLayout.getOrder();
+    if (inOrd != outOrd)
+      llvm_unreachable(
+          "blocked -> shared with different order not yet implemented");
     unsigned inVec =
         inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
     unsigned outVec = dstSharedLayout.getVec();
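For readers skimming the hunk above: the new guard only allows the blocked -> shared conversion when both layouts agree on the dimension order, and inVec/outVec describe how many contiguous elements each side can move at once. Below is a minimal standalone sketch of that bookkeeping, assuming (for illustration only, this is not the Triton code itself) that the effective store width is the smaller of the two widths.

// Standalone sketch, not the Triton implementation: computes an assumed
// effective vector width for a blocked -> shared store.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

unsigned effectiveVec(const std::vector<unsigned> &inOrd,
                      const std::vector<unsigned> &outOrd,
                      const std::vector<unsigned> &sizePerThread,
                      unsigned outVec) {
  // Mirrors the guard above: differing orders are not handled here.
  assert(inOrd == outOrd && "blocked -> shared with different order");
  unsigned inVec = sizePerThread[inOrd[0]]; // contiguity along the fastest dim
  return std::min(inVec, outVec);           // assumed combination rule
}

int main() {
  // order = {1, 0} means dim 1 varies fastest; 4 contiguous elements per
  // thread along dim 1; shared layout vectorized by 8 -> effective width 4.
  std::vector<unsigned> ord = {1, 0};
  std::vector<unsigned> sizePerThread = {1, 4};
  std::printf("vec = %u\n", effectiveVec(ord, ord, sizePerThread, 8));
}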
@@ -2775,7 +2778,8 @@ public:
           getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
       for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
            ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of input_elements
+        // step 1: recover the multidim_index from the index of
+        // input_elements
         auto multiDimWordIdx =
             getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
         SmallVector<Value> multiDimIdx(2);
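The loop above recovers a multi-dimensional index from a linear one via getMultiDimIndex, following inOrd (the layout order, fastest-varying dimension first). A standalone sketch of that delinearization under the assumed contract follows; it is a plain C++ illustration, not the Triton helper itself.

// Sketch of delinearization: peel the linear index off one dimension at a
// time, visiting dimensions from fastest- to slowest-varying.
#include <cstdio>
#include <vector>

std::vector<unsigned> delinearize(unsigned linear,
                                  const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> multi(shape.size(), 0);
  for (unsigned d : order) { // fastest-varying dimension first
    multi[d] = linear % shape[d];
    linear /= shape[d];
  }
  return multi;
}

int main() {
  // A 2x3 iteration space where dim 1 varies fastest (order = {1, 0}):
  // linear index 4 -> (row 1, col 1).
  auto idx = delinearize(4, {2, 3}, {1, 0});
  std::printf("(%u, %u)\n", idx[0], idx[1]);
}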
@@ -3711,6 +3715,33 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
 
 /// ====================== mma codegen end ============================
 
+/// ====================== trans codegen begin ============================
+
+struct TransOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto srcSmemObj =
+        getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
+    SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
+                                     srcSmemObj.strides[0]};
+    SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
+                                     srcSmemObj.offsets[0]};
+    auto dstSmemObj =
+        SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
+    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
+    rewriter.replaceOp(op, retVal);
+    return success();
+  }
+};
+
+/// ====================== trans codegen end ============================
+
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
                                       Type resType, Type elemType,
                                       Value constVal,
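The lowering above never moves data: the transposed shared-memory object reuses the same base pointer and simply swaps the two strides and the two offsets. A host-side analogue of the same idea is sketched below, as plain C++ rather than MLIR/LLVM IR, to show why a stride/offset swap yields a transposed view of the same buffer.

// Conceptual sketch of the zero-copy transpose used by TransOpConversion.
#include <cstdio>

struct View2D {
  const float *base;
  long strides[2];
  long offsets[2];
  float at(long i, long j) const {
    return base[(offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1]];
  }
};

View2D transpose(const View2D &v) {
  // Swap strides and offsets; the underlying buffer is untouched.
  return {v.base, {v.strides[1], v.strides[0]}, {v.offsets[1], v.offsets[0]}};
}

int main() {
  float buf[6] = {0, 1, 2, 3, 4, 5}; // a 2x3 row-major tile
  View2D a{buf, {3, 1}, {0, 0}};
  View2D at = transpose(a);          // a 3x2 view of the same data
  std::printf("%g %g\n", a.at(1, 2), at.at(2, 1)); // both print 5
}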
@@ -4538,6 +4569,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
                                                            benefit);
   patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
+  patterns.add<TransOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
 }
 
@@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   }
 };
 
+struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
+
+  using OpConversionPattern<triton::TransOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value src = adaptor.src();
+    auto srcType = src.getType().cast<RankedTensorType>();
+    Attribute srcEncoding = srcType.getEncoding();
+    if (!srcEncoding)
+      return failure();
+    if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      // TODO: end-to-end correctness is broken if
+      // the input is blocked and the output is shared
+      // with different order. Maybe a backend issue in BlockedToShared?
+      SmallVector<unsigned> order = {1, 0};
+      if (auto srcBlockedEncoding =
+              srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
+        llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
+      srcEncoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      srcType = RankedTensorType::get(srcType.getShape(),
+                                      srcType.getElementType(), srcEncoding);
+      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
+                                                          src);
+    }
+    auto srcSharedEncoding =
+        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
+    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
+                                   srcSharedEncoding.getOrder().end());
+    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
+                                   srcType.getShape().end());
+    std::reverse(retOrder.begin(), retOrder.end());
+    std::reverse(retShapes.begin(), retShapes.end());
+    auto retEncoding =
+        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
+    auto retType =
+        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
+
+    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
+    return success();
+  }
+};
+
 struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
 
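The frontend pattern above builds the transposed result type by reversing both the tensor shape and the shared layout's order. A small plain-C++ sketch of that bookkeeping follows, with the encodings reduced to just a shape and a dimension order for illustration; it is not the MLIR type machinery itself.

// Sketch: a 2-D transpose turns an MxN tile with order {1, 0} into an
// NxM tile with order {0, 1}.
#include <algorithm>
#include <cstdio>
#include <vector>

struct TileType {
  std::vector<long> shape;
  std::vector<unsigned> order; // fastest-varying dimension first
};

TileType transposedType(TileType t) {
  std::reverse(t.shape.begin(), t.shape.end());
  std::reverse(t.order.begin(), t.order.end());
  return t;
}

int main() {
  TileType src{{128, 64}, {1, 0}};
  TileType dst = transposedType(src);
  std::printf("shape = %ldx%ld, order = {%u, %u}\n", dst.shape[0],
              dst.shape[1], dst.order[0], dst.order[1]); // 64x128, {0, 1}
}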
@@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                TritonGenericPattern<triton::PtrToIntOp>,
                TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
                TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
-               TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
-               TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
-               TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
+               TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
+               TritonDotPattern, TritonLoadPattern, TritonStorePattern,
+               TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
+      typeConverter, context);
 }
 
 //