[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)

Author: Philippe Tillet
Date: 2022-12-03 09:58:24 -08:00
Committed by: GitHub
parent 4d64589b22
commit 8edfe813a5
12 changed files with 310 additions and 143 deletions
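
The frontend half of this change exposes a transpose instruction so that the flash attention backward pass can feed transposed K and S tiles into tl.dot instead of reloading tiles in transposed order. Below is a minimal sketch of how the instruction might be used from the Python frontend; the tl.trans spelling and the toy tile computation are assumptions for illustration, not code from this commit:

import triton
import triton.language as tl

@triton.jit
def _bwd_tile(q_ptr, k_ptr, do_ptr, dk_ptr,
              BLOCK: tl.constexpr, D: tl.constexpr):
    # Load one [BLOCK, D] tile of Q, K and dO.
    offs_m = tl.arange(0, BLOCK)
    offs_d = tl.arange(0, D)
    q = tl.load(q_ptr + offs_m[:, None] * D + offs_d[None, :])
    k = tl.load(k_ptr + offs_m[:, None] * D + offs_d[None, :])
    do = tl.load(do_ptr + offs_m[:, None] * D + offs_d[None, :])
    # Toy backward-style tile computation: tl.trans lets tl.dot consume
    # transposed views of K and S directly.
    s = tl.dot(q, tl.trans(k))    # [BLOCK, BLOCK]
    dk = tl.dot(tl.trans(s), do)  # [BLOCK, D]
    tl.store(dk_ptr + offs_m[:, None] * D + offs_d[None, :], dk)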


@@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
- if (isa<tensor::ExtractSliceOp>(op)) {
+ if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;


@@ -105,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) {
}
bool maybeAliasOp(Operation *op) {
- return isa<tensor::ExtractSliceOp>(op) ||
+ return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}


@@ -2716,6 +2716,9 @@ public:
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstSharedLayout.getOrder();
if (inOrd != outOrd)
llvm_unreachable(
"blocked -> shared with different order not yet implemented");
unsigned inVec =
inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
unsigned outVec = dstSharedLayout.getVec();
@@ -2775,7 +2778,8 @@ public:
getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
++linearWordIdx) {
- // step 1: recover the multidim_index from the index of input_elements
+ // step 1: recover the multidim_index from the index of
+ // input_elements
auto multiDimWordIdx =
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
SmallVector<Value> multiDimIdx(2);
@@ -3711,6 +3715,33 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
/// ====================== mma codegen end ============================
/// ====================== trans codegen begin ============================
struct TransOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto srcSmemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
srcSmemObj.strides[0]};
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
srcSmemObj.offsets[0]};
auto dstSmemObj =
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
/// ====================== trans codegen end ============================
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
Type resType, Type elemType,
Value constVal,
@@ -4538,6 +4569,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
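
The TransOpConversion above lowers the transpose of a shared-memory tensor without moving any data: it keeps the same base pointer and simply swaps the two strides and the two offsets of the SharedMemoryObject. The same zero-copy idea, sketched in NumPy purely as an analogy rather than the actual codegen path:

import numpy as np

# A 3x4 row-major array; its transpose can be expressed as a view over the
# same buffer with the two strides swapped, so no element is copied.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
xt = np.lib.stride_tricks.as_strided(
    x, shape=(4, 3), strides=(x.strides[1], x.strides[0]))
assert np.array_equal(xt, x.T)
x[0, 1] = 42.0
assert xt[1, 0] == 42.0  # same memory, seen through swapped strides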


@@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
Attribute srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
// TODO: end-to-end correctness is broken if
// the input is blocked and the output is shared
// with different order. Maybe a backend issue in BlockedToShared?
SmallVector<unsigned> order = {1, 0};
if (auto srcBlockedEncoding =
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
srcEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
srcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), srcEncoding);
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
src);
}
auto srcSharedEncoding =
srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
srcSharedEncoding.getOrder().end());
SmallVector<int64_t> retShapes(srcType.getShape().begin(),
srcType.getShape().end());
std::reverse(retOrder.begin(), retOrder.end());
std::reverse(retShapes.begin(), retShapes.end());
auto retEncoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
auto retType =
RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
return success();
}
};
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
@@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
- TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
- TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
- TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
+ TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
+ TritonDotPattern, TritonLoadPattern, TritonStorePattern,
+ TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
+ typeConverter, context);
}
//


@@ -178,6 +178,10 @@ public:
!isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(convert.getOperand()) &&
isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
@@ -661,6 +665,54 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
} // namespace
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstSharedLayout =
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!srcBlockedLayout || !dstSharedLayout)
return failure();
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
return failure();
// For now only works if single use is transpose
// TODO: rematerialize #shared uses
auto users = op->getUsers();
if (std::distance(users.begin(), users.end()) != 1 ||
!isa<triton::TransOp>(*users.begin()))
return failure();
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
op->getContext(), dstSharedLayout.getVec(),
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
srcBlockedLayout.getOrder());
auto tmpType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), tmpShared);
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), tmpType, cvt.getOperand());
auto newDstType = RankedTensorType::get(
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
srcType.getElementType(), dstSharedLayout);
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
tmpCvt.getResult());
rewriter.replaceOp(*users.begin(), newTrans.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -755,6 +807,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);