[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
@@ -43,7 +43,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
     return 0;
   }
 };
-// blocked -> blocked
 if (srcLayout.isa<BlockedEncodingAttr>() &&
     dstLayout.isa<BlockedEncodingAttr>()) {
   auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
@@ -66,14 +65,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   }
   paddedRepShape[outOrd[0]] += pad;
 }
-// blocked -> shared
-if (srcLayout.isa<BlockedEncodingAttr>() &&
-    dstLayout.isa<SharedEncodingAttr>()) {
-  auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
-  for (int v : dstTy.getShape())
-    paddedRepShape.push_back(v);
-}
-
 return paddedRepShape;
 }
 
@@ -140,8 +131,9 @@ private:
     auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
     auto srcEncoding = srcTy.getEncoding();
     auto dstEncoding = dstTy.getEncoding();
-    if (srcEncoding.isa<SharedEncodingAttr>()) {
-      // only block->block and block->shared is supported now
+    if (srcEncoding.isa<SharedEncodingAttr>() ||
+        dstEncoding.isa<SharedEncodingAttr>()) {
+      // Only blocked -> blocked conversion requires for scratch allocation
       return;
     }
     // ConvertLayoutOp with both input/output non-shared_layout
@@ -333,6 +333,13 @@ public:
                                            PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
 
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        allocation(allocation), smem(smem) {}
+
   Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     auto cast = rewriter.create<UnrealizedConversionCastOp>(
@@ -585,12 +592,12 @@ public:
     return multiDimIdx;
   }
 
+  template <typename T>
   Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
-                            Value smem, const Allocation *allocation,
-                            Operation *op) const {
+                            T value) const {
     auto ptrTy = LLVM::LLVMPointerType::get(
-        this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
-    auto bufferId = allocation->getBufferId(op);
+        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
+    auto bufferId = allocation->getBufferId(value);
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
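
Note: getSharedMemoryBase is now a template over T so that callers can resolve a buffer either from an Operation * (scratch buffers, as ConvertLayoutOpConversion does below) or from a Value (explicit buffers such as an alloc_tensor result). As a plain-C++ model of the pointer arithmetic it emits (an illustration, not the patch's code; AllocationModel and sharedMemoryBase are hypothetical stand-ins for the real Allocation class and the emitted GEP):

#include <cassert>
#include <cstdint>
#include <map>

// Hypothetical stand-in for the Allocation analysis: maps buffer ids to byte
// offsets inside the shared-memory arena.
struct AllocationModel {
  std::map<int, std::size_t> offsets;
  std::size_t getOffset(int bufferId) const { return offsets.at(bufferId); }
};

// Models `gep(ptrTy, smem, offVal)` on an i8* base: a byte-offset add.
std::int8_t *sharedMemoryBase(std::int8_t *smem, const AllocationModel &alloc,
                              int bufferId) {
  return smem + alloc.getOffset(bufferId);
}

int main() {
  static std::int8_t smem[32768];
  AllocationModel alloc{{{0, 0}, {1, 8192}}};
  assert(sharedMemoryBase(smem, alloc, 1) == smem + 8192);
  return 0;
}
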
@@ -598,6 +605,10 @@ public:
     Value base = gep(ptrTy, smem, offVal);
     return base;
   }
+
+protected:
+  const Allocation *allocation;
+  Value smem;
 };
 
 // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -1332,6 +1343,65 @@ struct AddPtrOpConversion
   }
 };
 
+struct AllocTensorOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    auto llvmElemTy =
+        getTypeConverter()->convertType(resultTy.getElementType());
+    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    Value resultVal =
+        rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
+    rewriter.replaceOp(op, resultVal);
+    return success();
+  }
+};
+
+struct ExtractSliceOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
+    assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
+
+    // axis > 0 will result in non-contiguous memory access if the result tensor
+    // is an alias of the source tensor.
+    auto axis =
+        op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
+    assert(axis == 0 && "Only axis=0 is supported for now");
+
+    // Example:
+    // %dst = extract_slice %src, %index {axis = 0}
+    // src.shape = [11, 2, 3, 4, 1]
+    // offset = %index * 2 * 3 * 4 * 1
+    auto dstTy = op.getType().dyn_cast<RankedTensorType>();
+    auto base = product<int64_t>(dstTy.getShape());
+    auto baseVal = createIndexAttrConstant(
+        rewriter, loc, getTypeConverter()->getIndexType(), base);
+    Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
+
+    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    Value resultVal =
+        rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
+    rewriter.replaceOp(op, resultVal);
+    return success();
+  }
+};
+
 template <typename SourceOp, typename DestOp>
 class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
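
Note: ExtractSliceOpConversion moves no data; it lowers the op to one multiply and one GEP off the source's shared-memory pointer. A standalone sketch of the offset arithmetic for the supported axis = 0 case (extractSliceOffset is a hypothetical helper, not part of the patch):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// The slice at `index` along axis 0 starts product(dstShape) elements into
// the source buffer, since each axis-0 slice is one contiguous block.
std::int64_t extractSliceOffset(const std::vector<std::int64_t> &srcShape,
                                std::int64_t index) {
  std::int64_t sliceElems =
      std::accumulate(srcShape.begin() + 1, srcShape.end(), std::int64_t{1},
                      std::multiplies<std::int64_t>());
  return index * sliceElems;
}

int main() {
  // Mirrors the example in the pattern's comment: src.shape = [11, 2, 3, 4, 1],
  // so slice %index starts at %index * 2 * 3 * 4 * 1 = %index * 24 elements.
  assert(extractSliceOffset({11, 2, 3, 4, 1}, 1) == 24);
  return 0;
}
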
@@ -1379,13 +1449,6 @@ public:
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  ConvertLayoutOpConversion(LLVMTypeConverter &converter,
-                            const Allocation *allocation, Value smem,
-                            PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
-                                                                      benefit),
-        allocation(allocation), smem(smem) {}
-
   LogicalResult
   matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -1399,13 +1462,10 @@ public:
     if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>())) {
       // TODO: not implemented
-      llvm::errs()
-          << "convert_layout except for blocked -> blocked is not implemented";
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    Value smemBase =
-        getSharedMemoryBase(loc, rewriter, smem, allocation, op.getOperation());
+    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
     auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
     smemBase = bit_cast(elemPtrTy, smemBase);
 
@@ -1587,9 +1647,6 @@ private:
       }
     }
   }
-
-  const Allocation *allocation;
-  Value smem;
 };
 
 /// ====================== dot codegen begin ==========================
@@ -1926,11 +1983,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     NOT_APPLICABLE,
   };
 
-  explicit DotOpConversion(LLVMTypeConverter &typeConverter,
-                           const Allocation *allocation, Value smem,
-                           PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit),
-        allocation(allocation), smem(smem) {}
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -1995,15 +2049,6 @@ private:
     assert(false && "Not implemented yet.");
     return failure();
   }
-
-  Value getSmemAddr(Value value, Location loc,
-                    ConversionPatternRewriter &rewriter) const {
-    return getSharedMemoryBase(loc, rewriter, smem, allocation,
-                               value.getDefiningOp());
-  }
-
-  const Allocation *allocation;
-  Value smem;
 };
 
 struct DotOpConversionHelper {
@@ -2340,7 +2385,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   SmallVector<Value> ptrs(numPtrs);
 
   Type smemPtrTy = helper.getShemPtrTy();
-  auto smemBase = getSmemAddr(tensor, loc, rewriter);
+  auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
   for (int i = 0; i < numPtrs; i++) {
     ptrs[i] = bit_cast(
         smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
@@ -2479,10 +2524,12 @@ public:
     SmallVector<Type, 4> types(numElementsPerThread,
                                convertType(type.getElementType()));
     return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-  } else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
-    return type;
-  } else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
+  } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
+    // TODO: Not implemented
     return type;
+  } else if (auto shared_layout =
+                 layout.dyn_cast_or_null<SharedEncodingAttr>()) {
+    return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
   }
   return llvm::None;
 }
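
Note: with this change the type converter maps a shared-layout tensor to a single element pointer in address space 3 (NVPTX shared memory) instead of a per-thread struct of register values, which is what lets AllocTensorOpConversion and ExtractSliceOpConversion above return plain pointers.
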
@@ -2493,6 +2540,9 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   AxisInfoAnalysis &axisInfoAnalysis,
                                   const Allocation *allocation, Value smem,
                                   PatternBenefit benefit = 1) {
+  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
+  patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
+                                        benefit);
   patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
   patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
                                                                benefit);
@@ -2503,9 +2553,10 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
                                                                 benefit);
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
-  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
+  patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
+                                         benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
@@ -431,9 +431,10 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
   auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   if (axis < 0 || axis > srcShape.size())
     return failure();
-  // Since we only extract a slice from a certain index on the axis,
-  // the dims before the axis can be dropped.
-  auto dstShape = srcShape.drop_front(axis + 1);
+  SmallVector<int64_t, 4> dstShape;
+  for (int i = 0; i < srcShape.size(); i++)
+    if (i != axis)
+      dstShape.push_back(srcShape[i]);
   auto returnType =
       RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
   inferredReturnTypes.assign({returnType});
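
Note: the inference change replaces drop_front(axis + 1), which discarded every dimension up to and including the axis, with a loop that removes only the sliced axis. The two rules agree for axis = 0 but differ for inner axes; a small sketch using plain C++ stand-ins for the MLIR shape types (inferDstShape is hypothetical):

#include <cassert>
#include <cstdint>
#include <vector>

// New rule: the destination shape is the source shape with only the sliced
// axis removed; dimensions before the axis survive.
std::vector<std::int64_t>
inferDstShape(const std::vector<std::int64_t> &srcShape, std::int64_t axis) {
  std::vector<std::int64_t> dstShape;
  for (std::size_t i = 0; i < srcShape.size(); i++)
    if (static_cast<std::int64_t>(i) != axis)
      dstShape.push_back(srcShape[i]);
  return dstShape;
}

int main() {
  // axis = 0: old and new rules agree, both give [2, 3, 4, 1].
  assert((inferDstShape({11, 2, 3, 4, 1}, 0) ==
          std::vector<std::int64_t>{2, 3, 4, 1}));
  // axis = 1: old drop_front(2) would give [3, 4, 1]; the new rule keeps the
  // leading dimension and gives [11, 3, 4, 1].
  assert((inferDstShape({11, 2, 3, 4, 1}, 1) ==
          std::vector<std::int64_t>{11, 3, 4, 1}));
  return 0;
}
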
@@ -22,11 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
     %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-    // CHECK: scratch offset = 8192, size = 0
-    // CHECK-NEXT: offset = 0, size = 8192
+    // CHECK: offset = 0, size = 8192
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
     %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
-    // CHECK-NEXT: scratch offset = 16384, size = 0
     // CHECK-NEXT: offset = 8192, size = 8192
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
 
@@ -52,20 +50,16 @@ func @reusable(%A : !tt.ptr<f16>) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
   %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK: scratch offset = 8192, size = 0
   // CHECK-NEXT: offset = 0, size = 8192
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-  // CHECK-NEXT: scratch offset = 16384, size = 0
   // CHECK-NEXT: offset = 8192, size = 8192
   %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK-NEXT: scratch offset = 24576, size = 0
   // CHECK-NEXT: offset = 16384, size = 8192
   %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
   %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-  // CHECK-NEXT: scratch offset = 8192, size = 0
   // CHECK-NEXT: offset = 0, size = 8192
   %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
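
Note: in the allocation tests above, the zero-sized "scratch offset = ..., size = 0" entries previously expected for each blocked -> blocked convert_layout are gone; each conversion's shared-memory buffer now appears only as a regular offset/size allocation.
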
@@ -293,6 +293,44 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK-LABEL: basic_alloc_tensor
+  func @basic_alloc_tensor() {
+    // CHECK: llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: llvm.mlir.constant
+    // CHECK-NEXT: llvm.getelementptr
+    // CHECK-NEXT: llvm.bitcast
+    %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
+    return
+  }
+}
+
+// -----
+
+#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK-LABEL: basic_extract_slice
+  func @basic_extract_slice() {
+    // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
+    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
+    // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
+    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
+    %index = arith.constant 1 : i32
+    %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
+    %1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat