[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
@@ -333,6 +333,13 @@ public:
                                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                           const Allocation *allocation,
                                           Value smem,
                                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        allocation(allocation), smem(smem) {}

  Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
    auto cast = rewriter.create<UnrealizedConversionCastOp>(
@@ -585,12 +592,12 @@ public:
    return multiDimIdx;
  }

  template <typename T>
  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                            Value smem, const Allocation *allocation,
                            Operation *op) const {
                            T value) const {
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
    auto bufferId = allocation->getBufferId(op);
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
    auto bufferId = allocation->getBufferId(value);
    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
    size_t offset = allocation->getOffset(bufferId);
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
@@ -598,6 +605,10 @@ public:
    Value base = gep(ptrTy, smem, offVal);
    return base;
  }

protected:
  const Allocation *allocation;
  Value smem;
};
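The members introduced here (allocation and smem) let every TritonGPU-to-LLVM pattern resolve its shared-memory buffer without threading the allocation and arena pointer through each call: getSharedMemoryBase looks up the buffer assigned to an operation or value and offsets the shared-memory arena pointer by the byte offset chosen by the Allocation analysis. As a rough standalone illustration of that address arithmetic only (getBufferBase is a hypothetical helper, not part of this patch):

#include <cstddef>
#include <cstdint>

// Shared-memory base of a buffer = arena pointer + byte offset assigned by
// the Allocation analysis; this mirrors gep(ptrTy, smem, offVal) above.
inline int8_t *getBufferBase(int8_t *smem, std::size_t offset) {
  return smem + offset;
}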

// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -1332,6 +1343,65 @@ struct AddPtrOpConversion
  }
};

struct AllocTensorOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    auto llvmElemTy =
        getTypeConverter()->convertType(resultTy.getElementType());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    Value resultVal =
        rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
    rewriter.replaceOp(op, resultVal);
    return success();
  }
};

struct ExtractSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
    auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
    assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");

    // axis > 0 will result in non-contiguous memory access if the result tensor
    // is an alias of the source tensor.
    auto axis =
        op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
    assert(axis == 0 && "Only axis=0 is supported for now");

    // Example:
    // %dst = extract_slice %src, %index {axis = 0}
    // src.shape = [11, 2, 3, 4, 1]
    // offset = %index * 2 * 3 * 4 * 1
    auto dstTy = op.getType().dyn_cast<RankedTensorType>();
    auto base = product<int64_t>(dstTy.getShape());
    auto baseVal = createIndexAttrConstant(
        rewriter, loc, getTypeConverter()->getIndexType(), base);
    Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);

    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    Value resultVal =
        rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
    rewriter.replaceOp(op, resultVal);
    return success();
  }
};
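Following the comment in ExtractSliceOpConversion above: with axis = 0, the slice offset is the index multiplied by the number of elements in one result slice (2 * 3 * 4 * 1 = 24 in the example), and the pattern materializes it as an LLVM mul followed by a GEP on the shared-memory pointer. A minimal standalone sketch of that arithmetic (sliceOffsetInElements is a hypothetical helper, not part of this patch):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Element offset of slice `index` along axis 0: index * product(dstShape).
int64_t sliceOffsetInElements(const std::vector<int64_t> &dstShape,
                              int64_t index) {
  int64_t stride = std::accumulate(dstShape.begin(), dstShape.end(),
                                   int64_t{1}, std::multiplies<int64_t>());
  return index * stride; // e.g. dstShape = {2, 3, 4, 1} -> index * 24
}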

template <typename SourceOp, typename DestOp>
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
@@ -1379,13 +1449,6 @@ public:
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;

  ConvertLayoutOpConversion(LLVMTypeConverter &converter,
                            const Allocation *allocation, Value smem,
                            PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
                                                                      benefit),
        allocation(allocation), smem(smem) {}

  LogicalResult
  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
@@ -1399,13 +1462,10 @@ public:
    if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
        (!dstLayout.isa<BlockedEncodingAttr>())) {
      // TODO: not implemented
      llvm::errs()
          << "convert_layout except for blocked -> blocked is not implemented";
      return failure();
    }
    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    Value smemBase =
        getSharedMemoryBase(loc, rewriter, smem, allocation, op.getOperation());
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    smemBase = bit_cast(elemPtrTy, smemBase);

@@ -1587,9 +1647,6 @@ private:
      }
    }
  }

  const Allocation *allocation;
  Value smem;
};

/// ====================== dot codegen begin ==========================
@@ -1926,11 +1983,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
    NOT_APPLICABLE,
  };

  explicit DotOpConversion(LLVMTypeConverter &typeConverter,
                           const Allocation *allocation, Value smem,
                           PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit),
        allocation(allocation), smem(smem) {}
  using ConvertTritonGPUOpToLLVMPattern<
      triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -1995,15 +2049,6 @@ private:
    assert(false && "Not implemented yet.");
    return failure();
  }

  Value getSmemAddr(Value value, Location loc,
                    ConversionPatternRewriter &rewriter) const {
    return getSharedMemoryBase(loc, rewriter, smem, allocation,
                               value.getDefiningOp());
  }

  const Allocation *allocation;
  Value smem;
};

struct DotOpConversionHelper {
@@ -2340,7 +2385,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
  SmallVector<Value> ptrs(numPtrs);

  Type smemPtrTy = helper.getShemPtrTy();
  auto smemBase = getSmemAddr(tensor, loc, rewriter);
  auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
  for (int i = 0; i < numPtrs; i++) {
    ptrs[i] = bit_cast(
        smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
@@ -2479,10 +2524,12 @@ public:
      SmallVector<Type, 4> types(numElementsPerThread,
                                 convertType(type.getElementType()));
      return LLVM::LLVMStructType::getLiteral(&getContext(), types);
    } else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
      return type;
    } else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
    } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
      // TODO: Not implemented
      return type;
    } else if (auto shared_layout =
                   layout.dyn_cast_or_null<SharedEncodingAttr>()) {
      return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
    }
    return llvm::None;
  }
@@ -2493,6 +2540,9 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  AxisInfoAnalysis &axisInfoAnalysis,
                                  const Allocation *allocation, Value smem,
                                  PatternBenefit benefit = 1) {
  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
  patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
                                        benefit);
  patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
  patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
                                                               benefit);
@@ -2503,9 +2553,10 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
                                                                benefit);
  patterns.add<BroadcastOpConversion>(typeConverter, benefit);
  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                          benefit);
  patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                         benefit);
  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
  patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);