[FRONTEND] Added ExpandDimsOp
primitive (#36)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -142,6 +143,46 @@ struct TritonMakeRangePattern
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonExpandDimsPattern
|
||||
: public OpConversionPattern<triton::ExpandDimsOp> {
|
||||
using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Type retType = op.getType());
|
||||
RankedTensorType argType = adaptor.src().getType().cast<RankedTensorType>();
|
||||
Attribute _argEncoding = argType.getEncoding();
|
||||
if (!_argEncoding)
|
||||
return failure();
|
||||
auto argEncoding =
|
||||
_argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
|
||||
// return shape
|
||||
auto retShape = argType.getShape().vec();
|
||||
retShape.insert(retShape.begin() + op.axis(), 1);
|
||||
// return encoding
|
||||
auto retSizePerThread = argEncoding.getSizePerThread().vec();
|
||||
retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1);
|
||||
auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
|
||||
retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1);
|
||||
auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
|
||||
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
|
||||
SmallVector<unsigned, 4> retOrder(retShape.size());
|
||||
std::iota(retOrder.begin(), retOrder.end(), 0);
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
// construct new op
|
||||
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
|
||||
op, retType, adaptor.src(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
@@ -260,8 +301,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user