[FRONTEND] Added ExpandDimsOp primitive (#36)

This commit is contained in:
Philippe Tillet
2022-08-04 18:41:06 -07:00
committed by GitHub
parent a7b49b3227
commit 78ebbe24c7
8 changed files with 98 additions and 41 deletions

View File

@@ -5,6 +5,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
@@ -142,6 +143,46 @@ struct TritonMakeRangePattern
}
};
/// Conversion pattern for triton::ExpandDimsOp.
///
/// Recreates the op with a result type whose layout is derived from the
/// (already converted) source operand: a unit dimension is inserted at `axis`
/// into the source shape and into the sizePerThread / threadsPerWarp /
/// warpsPerCTA vectors of the source's blocked encoding. The result encoding's
/// order is the identity permutation 0..rank-1.
///
/// The pattern fails to match (instead of asserting) when the source is not a
/// ranked tensor, when it carries no blocked encoding, or when `axis` is
/// larger than the source rank.
struct TritonExpandDimsPattern
    : public OpConversionPattern<triton::ExpandDimsOp> {
  using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto argType = adaptor.src().getType().dyn_cast<RankedTensorType>();
    if (!argType)
      return rewriter.notifyMatchFailure(op, "source is not a ranked tensor");

    // A missing encoding and a non-blocked encoding are both rejected here;
    // an unchecked cast would assert (or be undefined in release builds).
    auto argEncoding =
        argType.getEncoding()
            .dyn_cast_or_null<triton::gpu::TritonGPUBlockedEncodingAttr>();
    if (!argEncoding)
      return rewriter.notifyMatchFailure(
          op, "source tensor has no blocked encoding");

    // `axis` is used as an insertion offset below, so it must lie in
    // [0, rank]; anything larger would step past the end of the vectors.
    const unsigned axis = op.axis();
    if (axis > argType.getShape().size())
      return rewriter.notifyMatchFailure(op, "axis exceeds source rank");

    // Copies `dims` into a std::vector and inserts a unit entry at `axis`.
    auto withUnitDim = [axis](auto dims) {
      auto expanded = dims.vec();
      expanded.insert(expanded.begin() + axis, 1);
      return expanded;
    };

    // return shape
    auto retShape = withUnitDim(argType.getShape());

    // return encoding
    auto retSizePerThread = withUnitDim(argEncoding.getSizePerThread());
    auto retThreadsPerWarp = withUnitDim(argEncoding.getThreadsPerWarp());
    auto retWarpsPerCTA = withUnitDim(argEncoding.getWarpsPerCTA());
    // NOTE(review): the source encoding's order is not propagated; the result
    // always gets the identity order. Confirm this is intended.
    SmallVector<unsigned, 4> retOrder(retShape.size());
    std::iota(retOrder.begin(), retOrder.end(), 0);
    auto retEncoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
        getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
        retOrder);

    // return type
    auto retType =
        RankedTensorType::get(retShape, argType.getElementType(), retEncoding);

    // construct new op
    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
        op, retType, adaptor.src(), adaptor.axis());
    return success();
  }
};
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
@@ -260,8 +301,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern>(typeConverter, context);
}
//