[FRONTEND][BACKEND] Added trans
instruction; made flash attention bwd pass work (#943)
This commit is contained in:
@@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
||||
|
||||
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value src = adaptor.src();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
Attribute srcEncoding = srcType.getEncoding();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
// TODO: end-to-end correctness is broken if
|
||||
// the input is blocked and the output is shared
|
||||
// with different order. Maybe a backend issue in BlockedToShared?
|
||||
SmallVector<unsigned> order = {1, 0};
|
||||
if (auto srcBlockedEncoding =
|
||||
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
|
||||
srcEncoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
srcType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), srcEncoding);
|
||||
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
|
||||
src);
|
||||
}
|
||||
auto srcSharedEncoding =
|
||||
srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
|
||||
SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
|
||||
srcSharedEncoding.getOrder().end());
|
||||
SmallVector<int64_t> retShapes(srcType.getShape().begin(),
|
||||
srcType.getShape().end());
|
||||
std::reverse(retOrder.begin(), retOrder.end());
|
||||
std::reverse(retShapes.begin(), retShapes.end());
|
||||
auto retEncoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
|
||||
auto retType =
|
||||
RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
@@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
|
||||
TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user