[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)

## Features

- Allow taking a block slice of a tensor, as long as each dimension is
contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantics. A follow-up PR may adopt
semantics similar to `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed.
- May cause a slight performance slowdown, since strides are now passed as
values rather than constants (e.g., `int`). Passing strides as attributes
would be difficult in the presence of control flow; block arguments make it
possible to accept tensors with different strides.
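
As an illustrative sketch of the first feature (the shapes and the `#shared` encoding name below are assumptions made for the example, not taken from this PR's tests), a block of a shared-memory tensor can now be extracted at a dynamic offset as long as every dimension keeps a unit stride:

```
// Hypothetical example: take rows [%off, %off + 16) of a 32x32 shared tile.
// Offsets may be dynamic, sizes are static, and all strides must be 1.
%slice = tensor.extract_slice %buf[%off, 0] [16, 32] [1, 1]
         : tensor<32x32xf16, #shared> to tensor<16x32xf16, #shared>
```

Because the offset is an SSA value, the lowering keeps the parent tensor's strides as values in the shared-memory descriptor instead of folding them into constant attributes, which is the source of the slowdown noted above.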
Keren Zhou
2022-11-06 22:59:03 -08:00
committed by GitHub
parent a4ff0c362c
commit fdd59900f7
26 changed files with 507 additions and 339 deletions

View File

@@ -38,6 +38,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
     "mlir::gpu::GPUDialect",
     "mlir::scf::SCFDialect",
     "mlir::LLVM::LLVMDialect",
+    "mlir::tensor::TensorDialect",
     "mlir::triton::TritonDialect",
     "mlir::triton::gpu::TritonGPUDialect",
     "mlir::NVVM::NVVMDialect",

View File

@@ -1,7 +1,6 @@
 #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
 #define TRITON_DIALECT_TRITON_IR_DIALECT_H_
-#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"

View File

@@ -27,7 +27,6 @@ def Triton_Dialect : Dialect {
     "math::MathDialect",
     "StandardOpsDialect",
     "scf::SCFDialect",
-    "gpu::GPUDialect",
     // Since LLVM 15
     // "cf::ControlFlowDialect",

View File

@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
@@ -9,6 +10,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
+#include "triton/Dialect/TritonGPU/IR/Traits.h"
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"

View File

@@ -0,0 +1,31 @@
#ifndef TRITON_GPU_IR_TRAITS_H_
#define TRITON_GPU_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
class ResultsAreSharedEncoding
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyResultsAreSharedEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -16,7 +16,8 @@ def TritonGPU_Dialect : Dialect {
   let dependentDialects = [
     "triton::TritonDialect",
-    "mlir::gpu::GPUDialect"
+    "mlir::gpu::GPUDialect",
+    "tensor::TensorDialect",
   ];
   let extraClassDeclaration = [{

View File

@@ -10,6 +10,8 @@ include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
+def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
 class TTG_Op<string mnemonic, list<Trait> traits = []> :
     Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -75,7 +77,8 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
 def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
-                                    [SameVariadicOperandSize,
+                                    [AttrSizedOperandSegments,
+                                     ResultsAreSharedEncoding,
                                      // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
                                      NoSideEffect,
                                      TypesMatchWith<"infer mask type from src type",
@@ -93,6 +96,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
     It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
     This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
+    When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
+    might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
+    and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
     The insert_slice_async operation supports the following arguments:
     * src: the tensor that is inserted.
@@ -149,48 +156,9 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
   let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
   let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
-  // result needs to be of shared layout
-  let verifier = [{ return ::verify(*this); }];
 }
-def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> {
-  let summary = "extract slice";
-  let description = [{
-    The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as
-    specified by the operation's `$index` and `$axis` arguments.
-    The extract_slice operation supports the following arguments:
-    * src: the tensor that is extracted from.
-    * index: the index at the given `$axis` from which the `$src` tensor is extracted
-    Example:
-    ```
-    // Rank-reducing extract_slice.
-    %1 = tensor.extract_slice %0, %index {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32>
-    ```
-  }];
-  let arguments = (ins TT_Tensor:$src, I32:$index, I32Attr:$axis);
-  let results = (outs TT_Tensor:$result);
-  let assemblyFormat = [{$src `,` $index attr-dict `:` type($src) `->` type($result)}];
-  let extraClassDeclaration = [{
-    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
-      ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
-      ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
-      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
-  }];
-  // result needs to be of shared layout
-  let verifier = [{ return ::verify(*this); }];
-}
-def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
+def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> {
   let summary = "allocate tensor";
   let description = [{
@@ -203,9 +171,6 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
   let assemblyFormat = [{attr-dict `:` type($result)}];
   let results = (outs TT_Tensor:$result);
-  // result needs to be of shared layout
-  let verifier = [{ return ::verify(*this); }];
 }
 #endif

View File

@@ -1,4 +1,5 @@
 #include "triton/Analysis/Alias.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -24,18 +25,18 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
   if (maybeSharedAllocationOp(op)) {
     // These ops may allocate a new shared memory buffer.
     auto result = op->getResult(0);
-    if (isSharedEncoding(result)) {
-      // FIXME(Keren): extract and insert are always alias for now
-      if (auto extractSliceOp = dyn_cast<triton::gpu::ExtractSliceOp>(op)) {
-        // extract_slice %src, %index
-        aliasInfo = AliasInfo(operands[0]->getValue());
-      } else if (auto insertSliceOp =
-                     dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
-        // insert_slice_async %src, %dst, %index
-        aliasInfo = AliasInfo(operands[1]->getValue());
-      } else {
-        aliasInfo.insert(result);
-      }
-      pessimistic = false;
+    // FIXME(Keren): extract and insert are always alias for now
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+      // extract_slice %src
+      aliasInfo = AliasInfo(operands[0]->getValue());
+      pessimistic = false;
+    } else if (auto insertSliceOp =
+                   dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
+      // insert_slice_async %src, %dst, %index
+      aliasInfo = AliasInfo(operands[1]->getValue());
+      pessimistic = false;
+    } else if (isSharedEncoding(result)) {
+      aliasInfo.insert(result);
+      pessimistic = false;
     }
   }

View File

@@ -1,6 +1,7 @@
 #include "triton/Analysis/Allocation.h"
 #include "mlir/Analysis/Liveness.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "triton/Analysis/Alias.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -76,13 +77,13 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   auto srcShape = srcTy.getShape();
   auto axis = op.axis();
-  bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
+  bool fastReduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
   SmallVector<unsigned> smemShape;
   for (auto d : srcShape)
     smemShape.push_back(d);
-  if (fast_reduce) {
+  if (fastReduce) {
     unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
     smemShape[axis] = sizeInterWarps;
   } else {
@@ -123,7 +124,7 @@ private:
     // For example: %a = scf.if -> yield
    // %a must be allocated elsewhere by other operations.
    // FIXME(Keren): extract and insert are always alias for now
-    if (!maybeSharedAllocationOp(op) || isa<triton::gpu::ExtractSliceOp>(op) ||
+    if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
         isa<triton::gpu::InsertSliceAsyncOp>(op)) {
       return;
     }

View File

@@ -2,6 +2,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 namespace mlir {
@@ -43,8 +44,7 @@ void MembarAnalysis::dfsOperation(Operation *operation,
 void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
                               OpBuilder *builder) {
   if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
-      isa<triton::gpu::ExtractSliceOp>(op) ||
-      isa<triton::gpu::AllocTensorOp>(op)) {
+      isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
     // Do not insert barriers before control flow operations and
     // alloc/extract/insert
     // alloc is an allocation op without memory write.

View File

@@ -24,7 +24,8 @@ bool maybeSharedAllocationOp(Operation *op) {
              mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
          dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
          dialect->getTypeID() ==
-             mlir::TypeID::get<arith::ArithmeticDialect>());
+             mlir::TypeID::get<arith::ArithmeticDialect>() ||
+         dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
 }
 std::string getValueOperandName(Value value, AsmState &state) {

View File

@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
 namespace {
 // Create a 32-bit integer constant.
-Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
+static Value createConstantI32(Location loc, PatternRewriter &rewriter,
+                               int32_t v) {
   auto i32ty = rewriter.getIntegerType(32);
   return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
                                            IntegerAttr::get(i32ty, v));
@@ -63,7 +65,7 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
 }
 // Create a index type constant.
-Value createIndexConstant(OpBuilder &builder, Location loc,
+static Value createIndexConstant(OpBuilder &builder, Location loc,
                           TypeConverter *converter, int64_t value) {
   Type ty = converter->convertType(builder.getIndexType());
@@ -72,8 +74,8 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
 }
 // Create an integer constant of \param width bits.
-Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
-                                int64_t value) {
+static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
+                                       short width, int64_t value) {
   Type ty = builder.getIntegerType(width);
   return builder.create<LLVM::ConstantOp>(loc, ty,
                                           builder.getIntegerAttr(ty, value));
@@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   return linearIndex;
 }
-Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
-                  Value val, Value pred) {
+static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
+                         Value ptr, Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
   unsigned bits = val.getType().getIntOrFloatBitWidth();
   const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
   return builder.launch(rewriter, loc, void_ty(ctx));
 }
+struct SharedMemoryObject {
+  Value base; // i32 ptr. The start address of the shared memory object.
+  // We need to store strides as Values but not integers because the
+  // extract_slice instruction can take a slice at artibary offsets.
+  // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
+  // 32, we need to let the instruction that uses $a to be aware of that.
+  // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
+  // we store strides into an attribute array of integers, the information
+  // cannot pass through block argument assignment because attributes are
+  // associated with operations but not Values.
+  // TODO(Keren): We may need to figure out a way to store strides as integers
+  // if we want to support more optimizations.
+  SmallVector<Value>
+      strides; // i32 int. The strides of the shared memory object.
+  SharedMemoryObject(Value base, ArrayRef<Value> strides)
+      : base(base), strides(strides.begin(), strides.end()) {}
+  SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
+                     ConversionPatternRewriter &rewriter)
+      : base(base) {
+    auto stride = 1;
+    for (auto dim : llvm::reverse(shape)) {
+      this->strides.emplace_back(i32_val(stride));
+      stride *= dim;
+    }
+    this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
+  }
+  SmallVector<Value> getElems() const {
+    SmallVector<Value> elems;
+    elems.push_back(base);
+    elems.append(strides.begin(), strides.end());
+    return elems;
+  }
+  SmallVector<Type> getTypes() const {
+    SmallVector<Type> types;
+    types.push_back(base.getType());
+    types.append(strides.size(), IntegerType::get(base.getContext(), 32));
+    return types;
+  }
+};
 struct ConvertTritonGPUOpToLLVMPatternBase {
   static SmallVector<Value>
   getElementsFromStruct(Location loc, Value llvmStruct,
@@ -489,6 +535,16 @@ public:
     return linear;
   }
+  Value dot(ConversionPatternRewriter &rewriter, Location loc,
+            ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
+    assert(offsets.size() == strides.size());
+    Value ret = idx_val(0);
+    for (auto [offset, stride] : llvm::zip(offsets, strides)) {
+      ret = add(ret, mul(offset, stride));
+    }
+    return ret;
+  }
   // Get an index-base for each dimension for a \param blocked_layout.
   SmallVector<Value>
   emitBaseIndexForBlockedLayout(Location loc,
@@ -671,6 +727,25 @@ public:
     return base;
   }
+  static SharedMemoryObject
+  getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
+                                  ConversionPatternRewriter &rewriter) {
+    auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
+    return SharedMemoryObject(/*base=*/elems[0],
+                              /*strides=*/{elems.begin() + 1, elems.end()});
+  }
+  static Value
+  getStructFromSharedMemoryObject(Location loc,
+                                  const SharedMemoryObject &smemObj,
+                                  ConversionPatternRewriter &rewriter) {
+    auto elems = smemObj.getElems();
+    auto types = smemObj.getTypes();
+    auto structTy =
+        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
+    return getStructFromElements(loc, elems, rewriter, structTy);
+  }
 protected:
   const Allocation *allocation;
   Value smem;
@@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion
     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
     auto llvmElemTy =
         getTypeConverter()->convertType(resultTy.getElementType());
-    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-    Value resultVal =
-        rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
-    rewriter.replaceOp(op, resultVal);
+    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+    smemBase = bitcast(smemBase, elemPtrTy);
+    auto smemObj =
+        SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
+    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
+    rewriter.replaceOp(op, retVal);
     return success();
   }
 };
 struct ExtractSliceOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
+    : public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
   using ConvertTritonGPUOpToLLVMPattern<
-      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+      tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // %dst = extract_slice %src[%offsets]
     Location loc = op->getLoc();
-    auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
+    auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
     auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
     assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
+    assert(op.hasUnitStride() &&
+           "Only unit stride supported by ExtractSliceOpConversion");
-    // axis > 0 will result in non-contiguous memory access if the result
-    // tensor is an alias of the source tensor.
-    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
-    assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
-    // Example:
-    // %dst = extract_slice %src, %index {axis = 0}
-    // src.shape = [11, 2, 3, 4, 1]
-    // offset = %index * 2 * 3 * 4 * 1
-    auto dstTy = op.getType().dyn_cast<RankedTensorType>();
-    auto base = product<int64_t>(dstTy.getShape());
-    auto baseVal = createIndexAttrConstant(
-        rewriter, loc, getTypeConverter()->getIndexType(), base);
-    Value offset = mul(adaptor.index(), baseVal);
-    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-    Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
-    rewriter.replaceOp(op, resultVal);
+    // newBase = base + offset
+    // Triton support either static and dynamic offsets
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
+    SmallVector<Value, 4> offsetVals;
+    auto mixedOffsets = op.getMixedOffsets();
+    for (auto i = 0; i < mixedOffsets.size(); ++i) {
+      if (op.isDynamicOffset(i))
+        offsetVals.emplace_back(adaptor.offsets()[i]);
+      else
+        offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
+    }
+    // Compute the offset based on the original strides of the shared memory
+    // object
+    auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
+    // newShape = rank_reduce(shape)
+    // Triton only supports static tensor sizes
+    SmallVector<Value, 4> strideVals;
+    auto staticSizes = op.static_sizes();
+    for (auto i = 0; i < op.static_sizes().size(); ++i) {
+      if (op.getStaticSize(i) != 1) {
+        strideVals.emplace_back(smemObj.strides[i]);
+      }
+    }
+    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+    auto resTy = op.getType().dyn_cast<RankedTensorType>();
+    smemObj =
+        SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
+    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
+    rewriter.replaceOp(op, retVal);
     return success();
   }
 };
@@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
-  auto dstTy = dst.getType().cast<RankedTensorType>();
   auto srcShape = srcTy.getShape();
+  auto dstTy = dst.getType().cast<RankedTensorType>();
+  auto dstShape = dstTy.getShape();
   assert(srcShape.size() == 2 &&
          "Unexpected rank of ConvertLayout(blocked->shared)");
   auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -2309,6 +2402,8 @@
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
   smemBase = bitcast(smemBase, elemPtrTy);
+  auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
+  auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
   unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
   // TODO: We should get less barriers if it is handled by membar pass
@@ -2369,8 +2464,10 @@
       }
     }
   }
-  barrier();
-  rewriter.replaceOp(op, smemBase);
+  // Barrier is not necessary.
+  // The membar pass knows that it writes to shared memory and will handle it
+  // properly.
+  rewriter.replaceOp(op, retVal);
   return success();
 }
@@ -2380,9 +2477,10 @@
 class MMA16816SmemLoader {
 public:
   MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
-                     ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
-                     ArrayRef<int> matShape, int perPhase, int maxPhase,
-                     int elemBytes, ConversionPatternRewriter &rewriter,
+                     ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
+                     ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                     int perPhase, int maxPhase, int elemBytes,
+                     ConversionPatternRewriter &rewriter,
                      TypeConverter *typeConverter, const Location &loc)
       : order(order.begin(), order.end()), kOrder(kOrder),
         tileShape(tileShape.begin(), tileShape.end()),
@@ -2393,8 +2491,8 @@ public:
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
-    cTileStride = tileShape[order[1]];
-    sTileStride = tileShape[order[0]];
+    cTileStride = smemStrides[order[0]];
+    sTileStride = smemStrides[order[1]];
     // rule: k must be the fast-changing axis.
     needTrans = kOrder != order[0];
@@ -2497,8 +2595,7 @@ public:
     for (int i = 0; i < numPtr; ++i) {
       Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
      cMatOffI = xor_(cMatOffI, phase);
-      offs[i] = add(mul(cMatOffI, i32_val(cMatShape)),
-                    mul(sOff, i32_val(sTileStride)));
+      offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
     }
     return offs;
@@ -2534,7 +2631,7 @@ public:
         Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
         cOff = urem(cOff, i32_val(tileShape[order[0]]));
         sOff = urem(sOff, i32_val(tileShape[order[1]]));
-        offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride)));
+        offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
       }
     }
     return offs;
@@ -2574,7 +2671,7 @@ public:
         // To prevent out-of-bound access when tile is too small.
         cOff = urem(cOff, i32_val(tileShape[order[0]]));
         sOff = urem(sOff, i32_val(tileShape[order[1]]));
-        offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride)));
+        offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
       }
     }
   }
@@ -2608,14 +2705,15 @@ public:
     Value ptr = getPtr(ptrIdx);
     if (canUseLdmatrix) {
-      int sOffset =
-          matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
+      Value sOffset =
+          mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
+      Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
       PTXBuilder builder;
       // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
       // thread.
       auto resArgs = builder.newListOperand(4, "=r");
-      auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
+      auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
       auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
                           ->o("trans", needTrans /*predicate*/)
@@ -2640,26 +2738,24 @@ public:
                needTrans) { // Use lds.32 to load tf32 matrices
       Value ptr2 = getPtr(ptrIdx + 1);
       assert(sMatStride == 1);
-      int sOffsetElem =
-          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
-      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+      int sOffsetArrElem = sMatStride * sMatShape;
+      Value sOffsetArrElemVal =
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
       Value elems[4];
       Type elemTy = type::f32Ty(ctx);
       if (kOrder == 1) {
-        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-        elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
-        elems[2] =
-            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
-        elems[3] =
-            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
+        elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
+        elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
+        elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
       } else {
-        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-        elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
-        elems[1] =
-            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
-        elems[3] =
-            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
+        elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
+        elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
+        elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
       }
       return {elems[0], elems[1], elems[2], elems[3]};
@@ -2680,9 +2776,11 @@ public:
       };
       assert(sMatStride == 1);
-      int sOffsetElem =
-          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
-      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape);
+      Value sOffsetArrElemVal =
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
       std::array<Value, 4> i8v4Elems;
       std::array<Value, 4> i32Elems;
@@ -2692,16 +2790,14 @@ public:
       Value i8Elems[4][4];
       Type elemTy = type::i8Ty(ctx);
       if (kOrder == 1) {
-        Value offset = i32_val(sOffsetElem);
         for (int i = 0; i < 2; ++i)
           for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
-        offset = i32_val(sOffsetElem + sOffsetArrElem);
+            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
         for (int i = 2; i < 4; ++i)
           for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
+            i8Elems[i][j] =
+                load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
         for (int m = 0; m < 4; ++m) {
           for (int e = 0; e < 4; ++e)
@@ -2710,16 +2806,14 @@ public:
           i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
         }
       } else { // k first
-        Value offset = i32_val(sOffsetElem);
         for (int j = 0; j < 4; ++j)
-          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
+          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
         for (int j = 0; j < 4; ++j)
-          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
-        offset = i32_val(sOffsetElem + sOffsetArrElem);
+          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
        for (int j = 0; j < 4; ++j)
-          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
+          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
        for (int j = 0; j < 4; ++j)
-          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
+          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
        for (int m = 0; m < 4; ++m) {
          for (int e = 0; e < 4; ++e)
@@ -2752,8 +2846,8 @@ private:
   int cMatShape;
   int sMatShape;
-  int cTileStride;
-  int sTileStride;
+  Value cTileStride;
+  Value sTileStride;
   bool needTrans;
   bool canUseLdmatrix;
@@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper {
   }
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
-              ConversionPatternRewriter &rewriter) const;
+  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
-              ConversionPatternRewriter &rewriter) const;
+  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
   // Loading $c to registers, returns a LLVM::Struct.
   Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
@@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper {
   }
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA(Value tensor, Value llTensor) const {
+  Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
     auto shape = aTensorTy.getShape();
@@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper {
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       // load from smem
       loadFn = getLoadMatrixFn(
-          tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper {
   }
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB(Value tensor, Value llTensor) {
+  Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     auto shape = tensorTy.getShape();
@@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper {
     int numRepN = getNumRepN(tensorTy, shape[1]);
     auto loadFn = getLoadMatrixFn(
-        tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
@@ -3485,10 +3579,10 @@
 private:
   std::function<void(int, int)>
-  getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
-                  int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
-                  ArrayRef<int> matShape, Value warpId,
-                  ValueTable &vals) const {
+  getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
+                  MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
+                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                  Value warpId, ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assumes that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -3507,10 +3601,10 @@ private:
     // (a, b) is the coordinate.
     auto load = [=, &vals, &ld2](int a, int b) {
-      MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
-                                tensorTy.getShape() /*tileShape*/, instrShape,
-                                matShape, perPhase, maxPhase, elemBytes,
-                                rewriter, typeConverter, loc);
+      MMA16816SmemLoader loader(
+          wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
+          tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
+          maxPhase, elemBytes, rewriter, typeConverter, loc);
       SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
       const int numPtrs = loader.getNumPtr();
@@ -3519,8 +3613,8 @@ private:
       Type smemPtrTy = helper.getShemPtrTy();
       for (int i = 0; i < numPtrs; ++i) {
-        ptrs[i] =
-            bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
+        ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
+                          smemPtrTy);
      }
      auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
   assert(mmaLayout);
+  auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
   if (mmaLayout.getVersion() == 2) {
     MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
@@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
-      res = mmaHelper.loadA(src, adaptor.src());
+      res = mmaHelper.loadA(src, smemObj);
     } else if (dotOperandLayout.getOpIdx() == 1) {
       // operand $b
-      res = mmaHelper.loadB(src, adaptor.src());
+      res = mmaHelper.loadB(src, smemObj);
     }
   } else if (mmaLayout.getVersion() == 1) {
     DotOpMmaV1ConversionHelper helper(mmaLayout);
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
-      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
+      res =
+          helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
     } else if (dotOperandLayout.getOpIdx() == 1) {
       // operand $b
-      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
+      res =
+          helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
     }
   } else {
     assert(false && "Unsupported mma layout found");
@@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
     loadedA = adaptor.a();
     loadedB = adaptor.b();
   } else {
-    loadedA = mmaHelper.loadA(op.a(), adaptor.a());
-    loadedB = mmaHelper.loadB(op.b(), adaptor.b());
+    SharedMemoryObject smemA =
+        getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
+    SharedMemoryObject smemB =
+        getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
+    loadedA = mmaHelper.loadA(op.a(), smemA);
+    loadedB = mmaHelper.loadB(op.b(), smemB);
   }
   loadedC = mmaHelper.loadC(op.c(), adaptor.c());
@@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
 }
 Value DotOpMmaV1ConversionHelper::loadA(
-    Value tensor, Value llTensor, Value thread, Value smem, Location loc,
+    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
     ConversionPatternRewriter &rewriter) const {
+  // smem
+  Value smem = smemObj.base;
+  auto strides = smemObj.strides;
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
   auto shape = tensorTy.getShape();
@@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
   int vecA = sharedLayout.getVec();
-  int strideAM = isARow ? shape[1] : 1;
-  int strideAK = isARow ? 1 : shape[0];
-  int strideA0 = isARow ? strideAK : strideAM;
-  int strideA1 = isARow ? strideAM : strideAK;
+  Value strideAM = isARow ? strides[0] : i32_val(1);
+  Value strideAK = isARow ? i32_val(1) : strides[1];
+  Value strideA0 = isARow ? strideAK : strideAM;
+  Value strideA1 = isARow ? strideAM : strideAK;
   int strideRepM = wpt[0] * fpw[0] * 8;
   int strideRepK = 1;
@@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
     offA0I = udiv(offA0I, i32_val(vecA));
     offA0I = xor_(offA0I, phaseA);
    offA0I = xor_(offA0I, i32_val(vecA));
-    offA[i] =
-        add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+    offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
   }
   Type f16x2Ty = vec_ty(f16_ty, 2);
@@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
       int stepAM = isARow ? m : m / numPtrA * numPtrA;
       int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
-      Value pa = gep(f16PtrTy, thePtrA,
-                     i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
+      Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
+                         mul(i32_val(stepAK), strideAK));
+      Value pa = gep(f16PtrTy, thePtrA, offset);
       Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
       Value ha = load(bitcast(pa, aPtrTy));
       // record lds that needs to be moved
@@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
 }
 Value DotOpMmaV1ConversionHelper::loadB(
-    Value tensor, Value llTensor, Value thread, Value smem, Location loc,
+    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
     ConversionPatternRewriter &rewriter) const {
+  // smem
+  Value smem = smemObj.base;
+  auto strides = smemObj.strides;
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
   auto shape = tensorTy.getShape();
@@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
   SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
   SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
   int vecB = sharedLayout.getVec();
-  int strideBN = isBRow ? 1 : shape[0];
-  int strideBK = isBRow ? shape[1] : 1;
-  int strideB0 = isBRow ? strideBN : strideBK;
-  int strideB1 = isBRow ? strideBK : strideBN;
+  Value strideBN = isBRow ? i32_val(1) : strides[1];
+  Value strideBK = isBRow ? strides[0] : i32_val(1);
+  Value strideB0 = isBRow ? strideBN : strideBK;
+  Value strideB1 = isBRow ? strideBK : strideBN;
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;
@@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
     offB0I = udiv(offB0I, i32_val(vecB));
     offB0I = xor_(offB0I, phaseB);
     offB0I = mul(offB0I, i32_val(vecB));
-    offB[i] =
-        add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+    offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
   }
   Type f16PtrTy = ptr_ty(f16_ty);
@@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
     int stepBN = isBRow ? n / numPtrB * numPtrB : n;
     int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
-    Value pb = gep(f16PtrTy, thePtrB,
-                   i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
+    Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
+                       mul(i32_val(stepBK), strideBK));
+    Value pb = gep(f16PtrTy, thePtrB, offset);
     Value hb =
         load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
     // record lds that needs to be moved
@@ -4171,7 +4278,17 @@ public:
       return LLVM::LLVMStructType::getLiteral(ctx, types);
     } else if (auto shared_layout =
                    layout.dyn_cast_or_null<SharedEncodingAttr>()) {
-      return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
+      SmallVector<Type, 4> types;
+      // base ptr
+      auto ptrType =
+          LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
+      types.push_back(ptrType);
+      // shape dims
+      auto rank = type.getRank();
+      for (auto i = 0; i < rank; i++) {
+        types.push_back(IntegerType::get(ctx, 32));
+      }
+      return LLVM::LLVMStructType::getLiteral(ctx, types);
     } else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
       if (mmaLayout.getVersion() == 2) {
         auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
@@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion
     auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
     // %dst
+    auto dstTy = dst.getType().cast<RankedTensorType>();
+    auto dstShape = dstTy.getShape();
+    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
     auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
-    assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
-    auto dstBase = createIndexAttrConstant(rewriter, loc,
-                                           getTypeConverter()->getIndexType(),
-                                           product<int64_t>(srcTy.getShape()));
-    Value offset = mul(llIndex, dstBase);
-    auto dstPtrTy = LLVM::LLVMPointerType::get(
-        getTypeConverter()->convertType(resTy.getElementType()), 3);
-    Value dstPtrBase = gep(dstPtrTy, llDst, offset);
+    SmallVector<Value, 4> offsetVals;
+    SmallVector<Value, 4> srcStrides;
+    for (auto i = 0; i < dstShape.size(); ++i) {
+      if (i == axis) {
+        offsetVals.emplace_back(llIndex);
+      } else {
+        offsetVals.emplace_back(i32_val(0));
+        srcStrides.emplace_back(smemObj.strides[i]);
+      }
+    }
+    // Compute the offset based on the original dimensions of the shared memory
+    // object
+    auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
+    auto dstPtrTy =
+        ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
+    Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
     // %mask
     SmallVector<Value> maskElems;
@@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion
     unsigned maxPhase = resSharedLayout.getMaxPhase();
     auto sizePerThread = srcBlockedLayout.getSizePerThread();
     auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
     auto inOrder = srcBlockedLayout.getOrder();
     // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
-    // elements across phases. If perPhase * maxPhase == threadsPerCTA,
+    // elements across phases. If perPhase * maxPhase <= threadsPerCTA,
     // swizzle is not allowd
     auto numSwizzleRows = std::max<unsigned>(
         (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
@@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion
           vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
       auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
                           threadsPerCTA[inOrder[1]];
-      auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
      auto tileVecIdxCol = vecIdxCol % numVecCols;
      auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
@@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion
       auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
       Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
                          i32_val(maxPhase));
-      Value rowOffset =
-          mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
+      // srcShape and smemObj.shape maybe different if smemObj is a
+      // slice of the original shared memory object.
+      // So we need to use the original shape to compute the offset
+      Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
       Value colOffset =
           add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
       Value swizzleIdx = udiv(colOffset, i32_val(outVec));
@@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion
       auto numWords = vecBitWidth / bitWidth;
       auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
-      // XXX(Keren): Tune CG and CA here.
+      // Tune CG and CA here.
       auto byteWidth = bitWidth / 8;
       CacheModifier srcCacheModifier =
           byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
       assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
       auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
-      auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
+      Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
+      Value baseOffset =
+          add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
+              i32_val(baseOffsetCol));
+      Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
       for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
         PTXBuilder ptxBuilder;
         auto wordElemIdx = wordIdx * numWordElems;
         auto &copyAsyncOp =
             *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-        auto *dstOperand = ptxBuilder.newAddrOperand(
-            tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
+        auto *dstOperand =
+            ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
         auto *srcOperand =
             ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
         auto *copySize = ptxBuilder.newConstantOperand(byteWidth);

View File

@@ -1,5 +1,6 @@
 add_mlir_dialect_library(TritonGPUIR
   Dialect.cpp
+  Traits.cpp
   DEPENDS
   TritonGPUTableGen

View File

@@ -474,7 +474,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands; SmallVector<OpAsmParser::OperandType, 8> allOperands;
Type srcType, dstType; Type srcType, dstType;
SMLoc allOperandLoc = parser.getCurrentLocation(); SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) || if (parser.parseOperandList(allOperands) ||
@@ -489,14 +489,27 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
   operandTypes.push_back(dstType); // dst
   operandTypes.push_back(
       IntegerType::get(parser.getBuilder().getContext(), 32)); // index
-  if (allOperands.size() >= 4)
+  int hasMask = 0, hasOther = 0;
+  if (allOperands.size() >= 4) {
     operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
-  if (allOperands.size() >= 5)
+    hasMask = 1;
+  }
+  if (allOperands.size() >= 5) {
     operandTypes.push_back(triton::getPointeeType(srcType)); // other
+    hasOther = 1;
+  }
   if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                              result.operands))
     return failure();
+  // Deduce operand_segment_sizes from the number of the operands.
+  auto operand_segment_sizesAttrName =
+      InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
+  result.addAttribute(
+      operand_segment_sizesAttrName,
+      parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
   return success();
 }
@@ -504,39 +517,16 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
                               InsertSliceAsyncOp insertSliceAsyncOp) {
   printer << " ";
   printer << insertSliceAsyncOp.getOperation()->getOperands();
-  printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
-                                /*elidedAttrs=*/{});
+  // "operand_segment_sizes" can be deduced, so we don't print it.
+  printer.printOptionalAttrDict(
+      insertSliceAsyncOp->getAttrs(),
+      {insertSliceAsyncOp.operand_segment_sizesAttrName()});
   printer << " : ";
   printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
   printer << " -> ";
   printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
 }
 
-//===----------------------------------------------------------------------===//
-// ExtractSliceOp
-//===----------------------------------------------------------------------===//
-
-mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
-    ::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
-    ::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
-    ::mlir::RegionRange regions,
-    llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
-  auto srcType = operands[0].getType().cast<RankedTensorType>();
-  auto encoding = srcType.getEncoding();
-  auto srcShape = srcType.getShape();
-  auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-  if (axis < 0 || (size_t)axis > srcShape.size())
-    return failure();
-  SmallVector<int64_t, 4> dstShape;
-  for (size_t i = 0; i < srcShape.size(); i++)
-    if (i != (size_t)axis)
-      dstShape.push_back(srcShape[i]);
-  auto returnType =
-      RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
-  inferredReturnTypes.assign({returnType});
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
@@ -631,32 +621,6 @@ void TritonGPUDialect::initialize() {
   addInterfaces<TritonGPUInferLayoutInterface>();
 }
 
-//===----------------------------------------------------------------------===//
-// Verification
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(InsertSliceAsyncOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError(
-        "insert_slice_async should return a shared memory tensor");
-  }
-  return success();
-}
-
-static LogicalResult verify(ExtractSliceOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError("extract_slice should return a shared memory tensor");
-  }
-  return success();
-}
-
-static LogicalResult verify(AllocTensorOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError("alloc_tensor should return a shared memory tensor");
-  }
-  return success();
-}
-
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

View File

@@ -0,0 +1,14 @@
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Analysis/Utility.h"
mlir::LogicalResult
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
if (failed(verifyAtLeastNResults(op, 1)))
return failure();
for (auto result : op->getResults())
if (!isSharedEncoding(result))
return op->emitOpError() << "requires all results to be shared encoding";
return success();
};
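
This out-of-line verifier follows the standard MLIR pattern for reusable op traits: a thin trait class forwards to it, and any op that lists the trait gets the check during verification, which is what replaces the three per-op verify() functions removed from Dialect.cpp above. A minimal sketch of such a trait class (the name ResultsAreSharedEncoding is an assumption):

```cpp
// Sketch of the trait wrapper that would accompany the verifier above.
namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
class ResultsAreSharedEncoding
    : public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyResultsAreSharedEncoding(op);
  }
};
} // namespace OpTrait
} // namespace mlir
```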

View File

@@ -111,37 +111,41 @@ public:
       // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
       auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
       if (insert_slice) {
-        auto newType = op->getResult(0).getType();
+        auto newType = op->getResult(0).getType().cast<RankedTensorType>();
         // Ensure that the new insert_slice op is placed in the same place as the
         // old insert_slice op. Otherwise, the new insert_slice op may be placed
         // after the async_wait op, which is not allowed.
         OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPoint(insert_slice);
-        auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
             op->getLoc(), newType, insert_slice.dst());
         rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
-            op, newType, insert_slice.src(), new_arg.getResult(),
+            op, newType, insert_slice.src(), newArg.getResult(),
             insert_slice.index(), insert_slice.mask(), insert_slice.other(),
-            insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
-            insert_slice.axis());
+            insert_slice.cache(), insert_slice.evict(),
+            insert_slice.isVolatile(), insert_slice.axis());
         return mlir::success();
       }
       // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
-      auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
+      auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
       if (extract_slice) {
-        auto origType = extract_slice.src().getType().cast<RankedTensorType>();
+        auto origType = extract_slice.source().getType().cast<RankedTensorType>();
+        auto newType = RankedTensorType::get(
+            origType.getShape(), origType.getElementType(),
+            op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
+        auto resType = op->getResult(0).getType().cast<RankedTensorType>();
         // Ensure that the new extract_slice op is placed in the same place as the
         // old extract_slice op. Otherwise, the new extract_slice op may be placed
         // after the async_wait op, which is not allowed.
         OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPoint(extract_slice);
-        auto newType = RankedTensorType::get(
-            origType.getShape(), origType.getElementType(),
-            op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
-        auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
-            op->getLoc(), newType, extract_slice.src());
-        rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
-            op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
+        auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), newType, extract_slice.source());
+        rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+            op, resType, newArg.getResult(), extract_slice.offsets(),
+            extract_slice.sizes(), extract_slice.strides(),
+            extract_slice.static_offsets(), extract_slice.static_sizes(),
+            extract_slice.static_strides());
         return mlir::success();
       }
       // cvt(type2, x)
@@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
 inline bool expensive_to_remat(Operation *op) {
   if (!op)
     return true;
-  if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
+  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
           triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
           triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
     return true;

View File

@@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() {
                            builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
   } // for (int stage = 0; stage < numStages - 1; ++stage)
 
+  auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
+
   // async.wait & extract_slice
   builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
                                            loads.size() * (numStages - 2));
   loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (Value loadOp : loads) {
-    Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
-        loadOp.getLoc(), loadsMapping[loadOp].getType(),
-        loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
+    auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
+    Value extractSlice = builder.create<tensor::ExtractSliceOp>(
+        loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
+        SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
+        SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
+                                  intAttr(sliceType.getShape()[1])},
+        SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
     loadsExtract[loadOp] = extractSlice;
   }
   // bump up loopIterIdx, this is used for getting the correct slice for the
@@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     Value extractSliceIndex = builder.create<arith::RemSIOp>(
         nextIV.getLoc(), loopIterIdx,
         builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
+    extractSliceIndex = builder.create<arith::IndexCastOp>(
+        extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
+    auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
 
     for (Operation *op : orderedDeps) {
       Operation *nextOp = nullptr;
@@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
             nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
             loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
         nextBuffers.push_back(insertAsyncOp);
-        nextOp = builder.create<triton::gpu::ExtractSliceOp>(
-            op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-            extractSliceIndex, /*axis*/ 0);
+        auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
+        nextOp = builder.create<tensor::ExtractSliceOp>(
+            op->getLoc(), sliceType, insertAsyncOp,
+            SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
+            SmallVector<OpFoldResult>{intAttr(1),
+                                      intAttr(sliceType.getShape()[0]),
+                                      intAttr(sliceType.getShape()[1])},
+            SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
         extractSlices.push_back(nextOp->getResult(0));
       } else
         nextOp = builder.clone(*op, nextMapping);
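
Both pipeliner sites now follow the same recipe for carving a 2-D slice out of a 3-D staging buffer: supply the result type explicitly so the shared encoding is preserved, and pass offsets/sizes/strides as OpFoldResults so the slice index can stay a runtime index-typed value while all strides remain 1 (per the limitations in the summary). A condensed sketch of that recipe, with illustrative names (builder, loc, buffer, idx):

```cpp
// Sketch: extract buffer[idx, 0:M, 0:N] as a tensor<MxN..., #shared> value.
// idx must already be an index-typed Value (hence the index_cast above).
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
auto bufTy = buffer.getType().cast<RankedTensorType>();
int64_t M = bufTy.getShape()[1], N = bufTy.getShape()[2];
auto sliceTy =
    RankedTensorType::get({M, N}, bufTy.getElementType(), bufTy.getEncoding());
Value slice = builder.create<tensor::ExtractSliceOp>(
    loc, sliceTy, buffer,
    /*offsets=*/SmallVector<OpFoldResult>{idx, intAttr(0), intAttr(0)},
    /*sizes=*/SmallVector<OpFoldResult>{intAttr(1), intAttr(M), intAttr(N)},
    /*strides=*/SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
```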

View File

@@ -137,7 +137,7 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
     # reference result
     if expr == "cdiv":
-        y_ref = (x0 + x1 - 1) // x1
+        y_ref = torch.div(x0 + x1 - 1, x1, rounding_mode='trunc')
     elif expr == "umulhi":
         y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
     else:

View File

@@ -25,6 +25,7 @@ from filelock import FileLock
 import triton
 import triton._C.libtriton.triton as _triton
+from .tools.disasm import extract
 
 
 def str_to_ty(name):
@@ -875,8 +876,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.enable_debug()
-    # Get error in backend due to wrong conversion in expanding async-related instruction.
-    # TODO[Superjomn]: Open it when fixed.
     pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
@@ -1396,6 +1395,19 @@ class CompiledKernel:
         self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
         return
 
+    def get_sass(self, fun=None):
+        if 'sass' in self.asm:
+            return self.asm['sass']
+        fd, path = tempfile.mkstemp()
+        try:
+            with open(fd, 'wb') as cubin:
+                cubin.write(self.asm['cubin'])
+            self.sass = extract(path, fun)
+        finally:
+            os.remove(path)
+        self.asm['sass'] = self.sass
+        return self.sass
+
 
 class CudaUtils(object):

View File

@@ -18,10 +18,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
     // CHECK: %4 -> %4
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     // CHECK-NEXT: %6 -> %6
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
     %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
@@ -60,7 +60,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
   %index = arith.constant 0 : i32
   // CHECK: %2 -> %cst_0
-  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
+  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
   return
 }
@@ -68,9 +68,9 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
 func @extract_slice(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
   %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
-  %index = arith.constant 0 : i32
+  %index = arith.constant 0 : index
   // CHECK-NEXT: %0 -> %cst
-  %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
+  %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
   return
 }
@@ -144,9 +144,9 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t
   // CHECK-NEXT: %0#2 -> %cst,%cst_0
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
     scf.if %i1 {
-      %index = arith.constant 8 : i32
+      %index = arith.constant 8 : index
       // CHECK-NEXT: %1 -> %cst,%cst_0
-      %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
+      %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A>
       scf.yield
     }
     scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>

View File

@@ -178,7 +178,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   // CHECK: offset = 0, size = 512
   %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
   %index = arith.constant 0 : i32
-  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
+  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
   return
   // CHECK-NEXT: size = 512
 }
@@ -187,8 +187,8 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
 func @extract_slice(%A : !tt.ptr<f16>) {
   // CHECK: offset = 0, size = 512
   %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
-  %index = arith.constant 0 : i32
-  %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
+  %index = arith.constant 0 : index
+  %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
   return
   // CHECK-NEXT: size = 512
 }
@@ -271,8 +271,8 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
     scf.if %i1 {
-      %index = arith.constant 8 : i32
-      %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
+      %index = arith.constant 8 : index
+      %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A>
       scf.yield
     }
     scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>

View File

@@ -22,9 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
-    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
     %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
-    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
     // CHECK: Membar 13
     %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
@@ -41,7 +41,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) {
   %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
   %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
-  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
+  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   // CHECK: Membar 5
   %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A>
@@ -53,7 +53,7 @@ func @war_single_block(%A : !tt.ptr<f16>) {
   %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
   %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
-  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
+  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   // CHECK: Membar 5
   %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL>
@@ -98,8 +98,8 @@ func @alloc() {
 // CHECK-LABEL: extract_slice
 func @extract_slice() {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
-  %index = arith.constant 0 : i32
-  %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
+  %index = arith.constant 0 : index
+  %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
   // CHECK: Membar 3
   %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL>
   // CHECK-NEXT: Membar 5
@@ -114,7 +114,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
   %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A>
   %index = arith.constant 0 : i32
-  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
+  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
   %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A>
   // CHECK: Membar 7
   %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A>

View File

@@ -346,18 +346,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: basic_extract_slice
 func @basic_extract_slice() {
-  // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
-  // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
-  // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
-  // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
-  // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
-  // CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
-  // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
-  // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
-  // CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
-  %index = arith.constant 1 : i32
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK: llvm.extractvalue
+  // CHECK-NEXT: llvm.extractvalue
+  // CHECK-NEXT: llvm.extractvalue
+  // CHECK-NEXT: llvm.extractvalue
+  // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: llvm.mul
+  // CHECK-NEXT: llvm.add
+  // CHECK-NEXT: llvm.mul
+  // CHECK-NEXT: llvm.add
+  // CHECK-NEXT: llvm.mul
+  // CHECK-NEXT: llvm.add
+  // CHECK-NEXT: llvm.getelementptr
+  %index = arith.constant 1 : index
   %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
-  %1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
+  %1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
   return
 }
 }
@@ -488,22 +494,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
   %index = arith.constant 1 : i32
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
-  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
-  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
-  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK: llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm
-  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+  // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.commit_group
   %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>

View File

@@ -62,7 +62,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 // CHECK-LABEL: transpose
 func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
-  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
   // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
@@ -91,7 +91,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   %19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
   %20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
   %21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
-  %22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, isOtherUnspecified = false} : tensor<64x64xf32, #blocked3>
+  %22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
   %23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
   %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
   %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
@@ -133,7 +133,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
     %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
     %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
     %25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
-    %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3>
+    %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
     %27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
     %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
     %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>

View File

@@ -20,17 +20,18 @@
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
+// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
+// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]]
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]]
+// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
+// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
@@ -76,17 +77,18 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
+// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
+// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]]
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]]
+// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
+// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
@@ -130,14 +132,15 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
 // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
+// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
+// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]