[OPTIMIZER] Rewrite patterns for layout conversions (#64)

commit 192be76b3c
parent e0bedeb44c
Author: Philippe Tillet
Date:   2022-08-18 12:49:37 -07:00
Committed by: GitHub

19 changed files with 851 additions and 127 deletions

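The hunks below are largely mechanical: the Triton GPU dialect's encoding attributes drop their redundant TritonGPU prefix, which the mlir::triton::gpu namespace already spells out. A minimal before/after fragment distilled from the diff (the variable ty stands for any RankedTensorType value and is illustrative only):

// Before: the qualified name repeats "triton::gpu".
using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr;
auto oldLayout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();

// After: same attribute class, shorter name.
using ::mlir::triton::gpu::BlockedEncodingAttr;
auto newLayout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();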

@@ -23,9 +23,9 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr;
-using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr;
-using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr;
+using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::MmaEncodingAttr;
+using ::mlir::triton::gpu::SharedEncodingAttr;
 
 namespace mlir {
 namespace LLVM {
@@ -226,7 +226,7 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
   return linear_index;
 }
 
-static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
+static unsigned getElemsPerThread(BlockedEncodingAttr layout,
                                   ArrayRef<int64_t> shape) {
   size_t rank = shape.size();
   SmallVector<unsigned> elemsPerThreadPerDim(rank);
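For context on what this helper computes: in a blocked layout each dimension contributes ceil(shape[d] / (sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d])) * sizePerThread[d] elements, and the per-thread total is the product across dimensions. A standalone sketch of that formula follows; the parameter names mirror the fields of BlockedEncodingAttr, but the getters and the exact rounding are assumptions, not code from this diff.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone version of the per-dimension formula. The real
// getElemsPerThread reads these vectors from the BlockedEncodingAttr getters
// (getSizePerThread() and friends); treat those names as assumptions here.
static unsigned elemsPerThreadBlocked(const std::vector<int64_t> &shape,
                                      const std::vector<unsigned> &sizePerThread,
                                      const std::vector<unsigned> &threadsPerWarp,
                                      const std::vector<unsigned> &warpsPerCTA) {
  unsigned elems = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    // Elements covered along dim d by one pass of the whole CTA.
    int64_t cover =
        int64_t(sizePerThread[d]) * threadsPerWarp[d] * warpsPerCTA[d];
    // Ceil-divide: a partially covered tile still costs a full per-thread slot.
    int64_t reps = (shape[d] + cover - 1) / cover;
    elems *= unsigned(reps) * sizePerThread[d];
  }
  return elems;
}

int main() {
  // 128x128 tensor, 2x2 elements per thread, 8x4 threads per warp, 2x2 warps:
  // dim0 -> ceil(128/32) * 2 = 8, dim1 -> ceil(128/16) * 2 = 16, so 128 total.
  std::cout << elemsPerThreadBlocked({128, 128}, {2, 2}, {8, 4}, {2, 2}) << "\n";
}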
@@ -368,10 +368,10 @@ public:
   // Emit indices calculation within each ConversionPattern
   // TODO: [goostavz] Double confirm the redundant indices calculations will
   // be eliminated in the consequent MLIR/LLVM optimization
-  SmallVector<SmallVector<Value>> emitIndicesForBlockedLayout(
-      Location loc, ConversionPatternRewriter &b,
-      const TritonGPUBlockedEncodingAttr &blocked_layout,
-      ArrayRef<int64_t> shape) const {
+  SmallVector<SmallVector<Value>>
+  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
+                              const BlockedEncodingAttr &blocked_layout,
+                              ArrayRef<int64_t> shape) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     auto cast = b.create<UnrealizedConversionCastOp>(
         loc, TypeRange{llvmIndexTy},
@@ -483,7 +483,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                          ConversionPatternRewriter &rewriter, Location loc) {
   auto tensorTy = resType.cast<RankedTensorType>();
-  auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
+  auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
   auto srcType = typeConverter->convertType(elemType);
   auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
@@ -594,9 +594,9 @@ struct StoreOpConversion
     MLIRContext *ctx = rewriter.getContext();
     auto loc = op->getLoc();
-    auto getLLVMElems = [&](Value value, Value llValue,
-                            const TritonGPUBlockedEncodingAttr &layout)
-        -> SmallVector<Value, 4> {
+    auto getLLVMElems =
+        [&](Value value, Value llValue,
+            const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
       auto ty = value.getType().cast<RankedTensorType>();
       auto shape = ty.getShape();
       // Here, we assume that all inputs should have a blockedLayout
@@ -613,11 +613,11 @@ struct StoreOpConversion
     };
     auto getLayout =
-        [&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
+        [&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
       auto ty = val.getType().cast<RankedTensorType>();
       auto shape = ty.getShape();
       // Here, we assume that all inputs should have a blockedLayout
-      auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+      auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
       unsigned valueElems = getElemsPerThread(layout, shape);
@@ -633,9 +633,8 @@ struct StoreOpConversion
     auto maskElems = getLLVMElems(mask, llMask, maskLayout);
     assert(valueElems.size() == maskElems.size());
 
-    auto getAlign =
-        [this](Value val,
-               const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
+    auto getAlign = [this](Value val,
+                           const BlockedEncodingAttr &layout) -> unsigned {
       auto axisInfo = getAxisInfo(val);
       assert(axisInfo.hasValue());
@@ -648,9 +647,9 @@ struct StoreOpConversion
     };
 
     // get align
-    auto getVec = [this, &getAlign](
-                      Value val,
-                      const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
+    auto getVec = [this,
+                   &getAlign](Value val,
+                              const BlockedEncodingAttr &layout) -> unsigned {
       auto axisInfo = getAxisInfo(val);
       auto contig = axisInfo->getContiguity();
       // Here order should be ordered by contiguous first, so the first element
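The getAlign/getVec pair encodes the store's vectorization bound: a thread may emit a vector access no wider than its contiguous run along the fastest-varying axis, nor wider than the alignment it can prove from AxisInfo. A hedged sketch of that bound; plain integers stand in for AxisInfo, and the maxVec cap of 4 is an assumed illustration, not a value from this diff.

#include <algorithm>

// Hypothetical helper: the real getVec consults AxisInfo; here contiguity and
// alignment arrive as element counts, and maxVec is an assumed hardware cap.
static unsigned vecWidth(unsigned contiguity, unsigned alignInElems,
                         unsigned maxVec = 4) {
  // A vector access must stay inside one contiguous, sufficiently aligned run.
  return std::min({contiguity, alignInElems, maxVec});
}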
@@ -820,10 +819,8 @@ struct BroadcastOpConversion
     Value result = op.result();
     auto srcTy = op.src().getType().cast<RankedTensorType>();
     auto resultTy = result.getType().cast<RankedTensorType>();
-    auto srcLayout =
-        srcTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
-    auto resultLayout =
-        resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     assert(srcLayout && (srcLayout == resultLayout) &&
            "Unexpected layout of BroadcastOp");
     auto srcShape = srcTy.getShape();
@@ -894,8 +891,7 @@ struct ViewOpConversion
     // due to MLIR's restrictions
     Location loc = op->getLoc();
     auto resultTy = op.getType().cast<RankedTensorType>();
-    auto resultLayout =
-        resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     auto resultShape = resultTy.getShape();
     unsigned elems = getElemsPerThread(resultLayout, resultShape);
     Type elemTy =
@@ -921,7 +917,7 @@ struct MakeRangeOpConversion
     auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
     auto shape = rankedTy.getShape();
     auto blocked_layout =
-        rankedTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+        rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     auto elemTy = rankedTy.getElementType();
     assert(elemTy.isInteger(32));
     Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
@@ -955,8 +951,7 @@ struct LoadOpConversion
     Value mask = adaptor.mask();
     Value other = adaptor.other();
     auto resultTy = op.result().getType().cast<RankedTensorType>();
-    auto blockedLayout =
-        resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     auto shape = resultTy.getShape();
 
     // TODO: Handle AxisInfo
@@ -1166,8 +1161,7 @@ struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    auto resultLayout =
-        resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
     auto resultShape = resultTy.getShape();
     unsigned elems = getElemsPerThread(resultLayout, resultShape);
     Type elemTy =
@@ -1206,8 +1200,8 @@ public:
       return failure();
     Location loc = op->getLoc();
-    auto resultLayout = resultTy.getEncoding()
-                            .template dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto resultLayout =
+        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
     auto resultShape = resultTy.getShape();
     unsigned elems = getElemsPerThread(resultLayout, resultShape);
     Type elemTy =
@@ -1250,17 +1244,16 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     Attribute layout = type.getEncoding();
-    if (auto blocked_layout = layout.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
+    if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
       unsigned numElementsPerThread =
           getElemsPerThread(blocked_layout, type.getShape());
       SmallVector<Type, 4> types(numElementsPerThread,
                                  convertType(type.getElementType()));
       return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-    } else if (auto mma_layout = layout.dyn_cast<TritonGPUMmaEncodingAttr>()) {
+    } else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
       // TODO: Not implemented
       return llvm::None;
-    } else if (auto shared_layout =
-                   layout.dyn_cast<TritonGPUSharedEncodingAttr>()) {
+    } else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
       // TODO: Not implemented
       return llvm::None;
     }
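Spelled out, the blocked branch above converts a tensor into a literal LLVM struct with one field per register-resident element, while the MMA and shared branches still return llvm::None. A minimal, hypothetical echo of that construction, with 4 x f32 per thread chosen arbitrarily:

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical: a blocked layout granting each thread 4 x f32 yields
// !llvm.struct<(f32, f32, f32, f32)>, i.e. one struct field per element.
static mlir::Type exampleBlockedStruct(mlir::MLIRContext *ctx) {
  llvm::SmallVector<mlir::Type, 4> types(4, mlir::FloatType::getF32(ctx));
  return mlir::LLVM::LLVMStructType::getLiteral(ctx, types);
}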