[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features - Allow taking a block of tensor slice, as long as each dimension is contiguous (unit stride). - Fix some problems in `insert_slice_async`'s semantic. - More general verification for ops that return shared layout encoding. ## Known Limitations - `insert_slice_async` still uses the old semantic. May submit another PR later to support similar semantic like `tensor.extract_slice`. - No encoding verification for `tensor.extract_slice`. - 3d tensor ops are broken. - Strided accesses are not allowed. - May cause a little performance slowdown since we are passing strides as values but not constants (e.g., int). It would be difficult to pass strides as attributes when we have control flows. A block argument is possible to accept tensors with different strides.
This commit is contained in:
@@ -38,6 +38,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
|||||||
"mlir::gpu::GPUDialect",
|
"mlir::gpu::GPUDialect",
|
||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
"mlir::LLVM::LLVMDialect",
|
"mlir::LLVM::LLVMDialect",
|
||||||
|
"mlir::tensor::TensorDialect",
|
||||||
"mlir::triton::TritonDialect",
|
"mlir::triton::TritonDialect",
|
||||||
"mlir::triton::gpu::TritonGPUDialect",
|
"mlir::triton::gpu::TritonGPUDialect",
|
||||||
"mlir::NVVM::NVVMDialect",
|
"mlir::NVVM::NVVMDialect",
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@@ -27,7 +27,6 @@ def Triton_Dialect : Dialect {
|
|||||||
"math::MathDialect",
|
"math::MathDialect",
|
||||||
"StandardOpsDialect",
|
"StandardOpsDialect",
|
||||||
"scf::SCFDialect",
|
"scf::SCFDialect",
|
||||||
"gpu::GPUDialect",
|
|
||||||
|
|
||||||
// Since LLVM 15
|
// Since LLVM 15
|
||||||
// "cf::ControlFlowDialect",
|
// "cf::ControlFlowDialect",
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||||
|
|
||||||
#define GET_ATTRDEF_CLASSES
|
#define GET_ATTRDEF_CLASSES
|
||||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||||
|
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#ifndef TRITON_GPU_IR_TRAITS_H_
|
||||||
|
#define TRITON_GPU_IR_TRAITS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace OpTrait {
|
||||||
|
|
||||||
|
// These functions are out-of-line implementations of the methods in the
|
||||||
|
// corresponding trait classes. This avoids them being template
|
||||||
|
// instantiated/duplicated.
|
||||||
|
namespace impl {
|
||||||
|
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
|
||||||
|
} // namespace impl
|
||||||
|
|
||||||
|
template <typename ConcreteType>
|
||||||
|
class ResultsAreSharedEncoding
|
||||||
|
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
|
||||||
|
public:
|
||||||
|
static LogicalResult verifyTrait(Operation *op) {
|
||||||
|
return impl::verifyResultsAreSharedEncoding(op);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace OpTrait
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif
|
@@ -16,7 +16,8 @@ def TritonGPU_Dialect : Dialect {
|
|||||||
|
|
||||||
let dependentDialects = [
|
let dependentDialects = [
|
||||||
"triton::TritonDialect",
|
"triton::TritonDialect",
|
||||||
"mlir::gpu::GPUDialect"
|
"mlir::gpu::GPUDialect",
|
||||||
|
"tensor::TensorDialect",
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
@@ -10,6 +10,8 @@ include "mlir/IR/OpBase.td"
|
|||||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||||
|
|
||||||
|
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
|
||||||
|
|
||||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
@@ -75,7 +77,8 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
|
|||||||
|
|
||||||
|
|
||||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||||
[SameVariadicOperandSize,
|
[AttrSizedOperandSegments,
|
||||||
|
ResultsAreSharedEncoding,
|
||||||
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
|
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
TypesMatchWith<"infer mask type from src type",
|
TypesMatchWith<"infer mask type from src type",
|
||||||
@@ -93,6 +96,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
|||||||
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
|
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
|
||||||
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
|
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
|
||||||
|
|
||||||
|
When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||||
|
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||||
|
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||||
|
|
||||||
The insert_slice_async operation supports the following arguments:
|
The insert_slice_async operation supports the following arguments:
|
||||||
|
|
||||||
* src: the tensor that is inserted.
|
* src: the tensor that is inserted.
|
||||||
@@ -149,48 +156,9 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
|||||||
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
||||||
|
|
||||||
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
||||||
|
|
||||||
// result needs to be of shared layout
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> {
|
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> {
|
||||||
let summary = "extract slice";
|
|
||||||
let description = [{
|
|
||||||
The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as
|
|
||||||
specified by the operation's `$index` and `$axis` arguments.
|
|
||||||
|
|
||||||
The extract_slice operation supports the following arguments:
|
|
||||||
|
|
||||||
* src: the tensor that is extracted from.
|
|
||||||
* index: the index at the given `$axis` from which the `$src` tensor is extracted
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```
|
|
||||||
// Rank-reducing extract_slice.
|
|
||||||
%1 = tensor.extract_slice %0, %index {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32>
|
|
||||||
```
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins TT_Tensor:$src, I32:$index, I32Attr:$axis);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = [{$src `,` $index attr-dict `:` type($src) `->` type($result)}];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
|
|
||||||
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
|
|
||||||
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
|
|
||||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
|
|
||||||
}];
|
|
||||||
|
|
||||||
// result needs to be of shared layout
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
|
||||||
}
|
|
||||||
|
|
||||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
|
|
||||||
let summary = "allocate tensor";
|
let summary = "allocate tensor";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -203,9 +171,6 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
|
|||||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
// result needs to be of shared layout
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include "triton/Analysis/Alias.h"
|
#include "triton/Analysis/Alias.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "triton/Analysis/Utility.h"
|
#include "triton/Analysis/Utility.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
@@ -24,18 +25,18 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
|||||||
if (maybeSharedAllocationOp(op)) {
|
if (maybeSharedAllocationOp(op)) {
|
||||||
// These ops may allocate a new shared memory buffer.
|
// These ops may allocate a new shared memory buffer.
|
||||||
auto result = op->getResult(0);
|
auto result = op->getResult(0);
|
||||||
if (isSharedEncoding(result)) {
|
|
||||||
// FIXME(Keren): extract and insert are always alias for now
|
// FIXME(Keren): extract and insert are always alias for now
|
||||||
if (auto extractSliceOp = dyn_cast<triton::gpu::ExtractSliceOp>(op)) {
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
// extract_slice %src, %index
|
// extract_slice %src
|
||||||
aliasInfo = AliasInfo(operands[0]->getValue());
|
aliasInfo = AliasInfo(operands[0]->getValue());
|
||||||
|
pessimistic = false;
|
||||||
} else if (auto insertSliceOp =
|
} else if (auto insertSliceOp =
|
||||||
dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
|
dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
|
||||||
// insert_slice_async %src, %dst, %index
|
// insert_slice_async %src, %dst, %index
|
||||||
aliasInfo = AliasInfo(operands[1]->getValue());
|
aliasInfo = AliasInfo(operands[1]->getValue());
|
||||||
} else {
|
pessimistic = false;
|
||||||
|
} else if (isSharedEncoding(result)) {
|
||||||
aliasInfo.insert(result);
|
aliasInfo.insert(result);
|
||||||
}
|
|
||||||
pessimistic = false;
|
pessimistic = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
#include "triton/Analysis/Allocation.h"
|
#include "triton/Analysis/Allocation.h"
|
||||||
#include "mlir/Analysis/Liveness.h"
|
#include "mlir/Analysis/Liveness.h"
|
||||||
#include "mlir/Analysis/SliceAnalysis.h"
|
#include "mlir/Analysis/SliceAnalysis.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "triton/Analysis/Alias.h"
|
#include "triton/Analysis/Alias.h"
|
||||||
#include "triton/Analysis/Utility.h"
|
#include "triton/Analysis/Utility.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
@@ -76,13 +77,13 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
|
|||||||
auto srcShape = srcTy.getShape();
|
auto srcShape = srcTy.getShape();
|
||||||
auto axis = op.axis();
|
auto axis = op.axis();
|
||||||
|
|
||||||
bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
|
bool fastReduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
|
||||||
|
|
||||||
SmallVector<unsigned> smemShape;
|
SmallVector<unsigned> smemShape;
|
||||||
for (auto d : srcShape)
|
for (auto d : srcShape)
|
||||||
smemShape.push_back(d);
|
smemShape.push_back(d);
|
||||||
|
|
||||||
if (fast_reduce) {
|
if (fastReduce) {
|
||||||
unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
|
unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
|
||||||
smemShape[axis] = sizeInterWarps;
|
smemShape[axis] = sizeInterWarps;
|
||||||
} else {
|
} else {
|
||||||
@@ -123,7 +124,7 @@ private:
|
|||||||
// For example: %a = scf.if -> yield
|
// For example: %a = scf.if -> yield
|
||||||
// %a must be allocated elsewhere by other operations.
|
// %a must be allocated elsewhere by other operations.
|
||||||
// FIXME(Keren): extract and insert are always alias for now
|
// FIXME(Keren): extract and insert are always alias for now
|
||||||
if (!maybeSharedAllocationOp(op) || isa<triton::gpu::ExtractSliceOp>(op) ||
|
if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
|
||||||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
|
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
@@ -43,8 +44,7 @@ void MembarAnalysis::dfsOperation(Operation *operation,
|
|||||||
void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||||
OpBuilder *builder) {
|
OpBuilder *builder) {
|
||||||
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
|
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
|
||||||
isa<triton::gpu::ExtractSliceOp>(op) ||
|
isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
|
||||||
isa<triton::gpu::AllocTensorOp>(op)) {
|
|
||||||
// Do not insert barriers before control flow operations and
|
// Do not insert barriers before control flow operations and
|
||||||
// alloc/extract/insert
|
// alloc/extract/insert
|
||||||
// alloc is an allocation op without memory write.
|
// alloc is an allocation op without memory write.
|
||||||
|
@@ -24,7 +24,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
|||||||
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
||||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||||
dialect->getTypeID() ==
|
dialect->getTypeID() ==
|
||||||
mlir::TypeID::get<arith::ArithmeticDialect>());
|
mlir::TypeID::get<arith::ArithmeticDialect>() ||
|
||||||
|
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string getValueOperandName(Value value, AsmState &state) {
|
std::string getValueOperandName(Value value, AsmState &state) {
|
||||||
|
@@ -11,6 +11,7 @@
|
|||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
@@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Create a 32-bit integer constant.
|
// Create a 32-bit integer constant.
|
||||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
|
||||||
|
int32_t v) {
|
||||||
auto i32ty = rewriter.getIntegerType(32);
|
auto i32ty = rewriter.getIntegerType(32);
|
||||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||||
IntegerAttr::get(i32ty, v));
|
IntegerAttr::get(i32ty, v));
|
||||||
@@ -63,7 +65,7 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a index type constant.
|
// Create a index type constant.
|
||||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||||
|
|
||||||
TypeConverter *converter, int64_t value) {
|
TypeConverter *converter, int64_t value) {
|
||||||
Type ty = converter->convertType(builder.getIndexType());
|
Type ty = converter->convertType(builder.getIndexType());
|
||||||
@@ -72,8 +74,8 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create an integer constant of \param width bits.
|
// Create an integer constant of \param width bits.
|
||||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
||||||
int64_t value) {
|
short width, int64_t value) {
|
||||||
Type ty = builder.getIntegerType(width);
|
Type ty = builder.getIntegerType(width);
|
||||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||||
builder.getIntegerAttr(ty, value));
|
builder.getIntegerAttr(ty, value));
|
||||||
@@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
|||||||
return linearIndex;
|
return linearIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
Value val, Value pred) {
|
Value ptr, Value val, Value pred) {
|
||||||
MLIRContext *ctx = rewriter.getContext();
|
MLIRContext *ctx = rewriter.getContext();
|
||||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||||
@@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
|||||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SharedMemoryObject {
|
||||||
|
Value base; // i32 ptr. The start address of the shared memory object.
|
||||||
|
// We need to store strides as Values but not integers because the
|
||||||
|
// extract_slice instruction can take a slice at artibary offsets.
|
||||||
|
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
|
||||||
|
// 32, we need to let the instruction that uses $a to be aware of that.
|
||||||
|
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
|
||||||
|
// we store strides into an attribute array of integers, the information
|
||||||
|
// cannot pass through block argument assignment because attributes are
|
||||||
|
// associated with operations but not Values.
|
||||||
|
// TODO(Keren): We may need to figure out a way to store strides as integers
|
||||||
|
// if we want to support more optimizations.
|
||||||
|
SmallVector<Value>
|
||||||
|
strides; // i32 int. The strides of the shared memory object.
|
||||||
|
|
||||||
|
SharedMemoryObject(Value base, ArrayRef<Value> strides)
|
||||||
|
: base(base), strides(strides.begin(), strides.end()) {}
|
||||||
|
|
||||||
|
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
|
||||||
|
ConversionPatternRewriter &rewriter)
|
||||||
|
: base(base) {
|
||||||
|
auto stride = 1;
|
||||||
|
for (auto dim : llvm::reverse(shape)) {
|
||||||
|
this->strides.emplace_back(i32_val(stride));
|
||||||
|
stride *= dim;
|
||||||
|
}
|
||||||
|
this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> getElems() const {
|
||||||
|
SmallVector<Value> elems;
|
||||||
|
elems.push_back(base);
|
||||||
|
elems.append(strides.begin(), strides.end());
|
||||||
|
return elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Type> getTypes() const {
|
||||||
|
SmallVector<Type> types;
|
||||||
|
types.push_back(base.getType());
|
||||||
|
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
|
||||||
|
return types;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||||
static SmallVector<Value>
|
static SmallVector<Value>
|
||||||
getElementsFromStruct(Location loc, Value llvmStruct,
|
getElementsFromStruct(Location loc, Value llvmStruct,
|
||||||
@@ -489,6 +535,16 @@ public:
|
|||||||
return linear;
|
return linear;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value dot(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
|
||||||
|
assert(offsets.size() == strides.size());
|
||||||
|
Value ret = idx_val(0);
|
||||||
|
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
|
||||||
|
ret = add(ret, mul(offset, stride));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
// Get an index-base for each dimension for a \param blocked_layout.
|
// Get an index-base for each dimension for a \param blocked_layout.
|
||||||
SmallVector<Value>
|
SmallVector<Value>
|
||||||
emitBaseIndexForBlockedLayout(Location loc,
|
emitBaseIndexForBlockedLayout(Location loc,
|
||||||
@@ -671,6 +727,25 @@ public:
|
|||||||
return base;
|
return base;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SharedMemoryObject
|
||||||
|
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
||||||
|
return SharedMemoryObject(/*base=*/elems[0],
|
||||||
|
/*strides=*/{elems.begin() + 1, elems.end()});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
getStructFromSharedMemoryObject(Location loc,
|
||||||
|
const SharedMemoryObject &smemObj,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
auto elems = smemObj.getElems();
|
||||||
|
auto types = smemObj.getTypes();
|
||||||
|
auto structTy =
|
||||||
|
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||||
|
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const Allocation *allocation;
|
const Allocation *allocation;
|
||||||
Value smem;
|
Value smem;
|
||||||
@@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion
|
|||||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||||
auto llvmElemTy =
|
auto llvmElemTy =
|
||||||
getTypeConverter()->convertType(resultTy.getElementType());
|
getTypeConverter()->convertType(resultTy.getElementType());
|
||||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||||
Value resultVal =
|
smemBase = bitcast(smemBase, elemPtrTy);
|
||||||
rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
|
auto smemObj =
|
||||||
rewriter.replaceOp(op, resultVal);
|
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
|
||||||
|
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||||
|
rewriter.replaceOp(op, retVal);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ExtractSliceOpConversion
|
struct ExtractSliceOpConversion
|
||||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
|
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
|
||||||
using ConvertTritonGPUOpToLLVMPattern<
|
using ConvertTritonGPUOpToLLVMPattern<
|
||||||
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
|
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
|
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// %dst = extract_slice %src[%offsets]
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
|
auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
|
||||||
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||||
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
|
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
|
||||||
|
assert(op.hasUnitStride() &&
|
||||||
|
"Only unit stride supported by ExtractSliceOpConversion");
|
||||||
|
|
||||||
// axis > 0 will result in non-contiguous memory access if the result
|
// newBase = base + offset
|
||||||
// tensor is an alias of the source tensor.
|
// Triton support either static and dynamic offsets
|
||||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
auto smemObj =
|
||||||
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
|
getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
|
||||||
|
SmallVector<Value, 4> offsetVals;
|
||||||
// Example:
|
auto mixedOffsets = op.getMixedOffsets();
|
||||||
// %dst = extract_slice %src, %index {axis = 0}
|
for (auto i = 0; i < mixedOffsets.size(); ++i) {
|
||||||
// src.shape = [11, 2, 3, 4, 1]
|
if (op.isDynamicOffset(i))
|
||||||
// offset = %index * 2 * 3 * 4 * 1
|
offsetVals.emplace_back(adaptor.offsets()[i]);
|
||||||
auto dstTy = op.getType().dyn_cast<RankedTensorType>();
|
else
|
||||||
auto base = product<int64_t>(dstTy.getShape());
|
offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
|
||||||
auto baseVal = createIndexAttrConstant(
|
}
|
||||||
rewriter, loc, getTypeConverter()->getIndexType(), base);
|
// Compute the offset based on the original strides of the shared memory
|
||||||
Value offset = mul(adaptor.index(), baseVal);
|
// object
|
||||||
|
auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||||
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
// newShape = rank_reduce(shape)
|
||||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
// Triton only supports static tensor sizes
|
||||||
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
|
SmallVector<Value, 4> strideVals;
|
||||||
rewriter.replaceOp(op, resultVal);
|
auto staticSizes = op.static_sizes();
|
||||||
|
for (auto i = 0; i < op.static_sizes().size(); ++i) {
|
||||||
|
if (op.getStaticSize(i) != 1) {
|
||||||
|
strideVals.emplace_back(smemObj.strides[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||||
|
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||||
|
auto resTy = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
smemObj =
|
||||||
|
SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
|
||||||
|
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||||
|
rewriter.replaceOp(op, retVal);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
Value src = op.src();
|
Value src = op.src();
|
||||||
Value dst = op.result();
|
Value dst = op.result();
|
||||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
|
||||||
auto srcShape = srcTy.getShape();
|
auto srcShape = srcTy.getShape();
|
||||||
|
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||||
|
auto dstShape = dstTy.getShape();
|
||||||
assert(srcShape.size() == 2 &&
|
assert(srcShape.size() == 2 &&
|
||||||
"Unexpected rank of ConvertLayout(blocked->shared)");
|
"Unexpected rank of ConvertLayout(blocked->shared)");
|
||||||
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||||
@@ -2309,6 +2402,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
|
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
|
||||||
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
||||||
smemBase = bitcast(smemBase, elemPtrTy);
|
smemBase = bitcast(smemBase, elemPtrTy);
|
||||||
|
auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
|
||||||
|
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||||
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
|
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
|
||||||
SmallVector<Value> wordVecs(numWordsEachRep);
|
SmallVector<Value> wordVecs(numWordsEachRep);
|
||||||
// TODO: We should get less barriers if it is handled by membar pass
|
// TODO: We should get less barriers if it is handled by membar pass
|
||||||
@@ -2369,8 +2464,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
barrier();
|
// Barrier is not necessary.
|
||||||
rewriter.replaceOp(op, smemBase);
|
// The membar pass knows that it writes to shared memory and will handle it
|
||||||
|
// properly.
|
||||||
|
rewriter.replaceOp(op, retVal);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2380,9 +2477,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
class MMA16816SmemLoader {
|
class MMA16816SmemLoader {
|
||||||
public:
|
public:
|
||||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
ArrayRef<int> instrShape, ArrayRef<int> matShape,
|
||||||
int elemBytes, ConversionPatternRewriter &rewriter,
|
int perPhase, int maxPhase, int elemBytes,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
TypeConverter *typeConverter, const Location &loc)
|
TypeConverter *typeConverter, const Location &loc)
|
||||||
: order(order.begin(), order.end()), kOrder(kOrder),
|
: order(order.begin(), order.end()), kOrder(kOrder),
|
||||||
tileShape(tileShape.begin(), tileShape.end()),
|
tileShape(tileShape.begin(), tileShape.end()),
|
||||||
@@ -2393,8 +2491,8 @@ public:
|
|||||||
cMatShape = matShape[order[0]];
|
cMatShape = matShape[order[0]];
|
||||||
sMatShape = matShape[order[1]];
|
sMatShape = matShape[order[1]];
|
||||||
|
|
||||||
cTileStride = tileShape[order[1]];
|
cTileStride = smemStrides[order[0]];
|
||||||
sTileStride = tileShape[order[0]];
|
sTileStride = smemStrides[order[1]];
|
||||||
|
|
||||||
// rule: k must be the fast-changing axis.
|
// rule: k must be the fast-changing axis.
|
||||||
needTrans = kOrder != order[0];
|
needTrans = kOrder != order[0];
|
||||||
@@ -2497,8 +2595,7 @@ public:
|
|||||||
for (int i = 0; i < numPtr; ++i) {
|
for (int i = 0; i < numPtr; ++i) {
|
||||||
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
|
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
|
||||||
cMatOffI = xor_(cMatOffI, phase);
|
cMatOffI = xor_(cMatOffI, phase);
|
||||||
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)),
|
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
|
||||||
mul(sOff, i32_val(sTileStride)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return offs;
|
return offs;
|
||||||
@@ -2534,7 +2631,7 @@ public:
|
|||||||
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
|
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
|
||||||
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
||||||
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
||||||
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride)));
|
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return offs;
|
return offs;
|
||||||
@@ -2574,7 +2671,7 @@ public:
|
|||||||
// To prevent out-of-bound access when tile is too small.
|
// To prevent out-of-bound access when tile is too small.
|
||||||
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
||||||
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
||||||
offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride)));
|
offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2608,14 +2705,15 @@ public:
|
|||||||
Value ptr = getPtr(ptrIdx);
|
Value ptr = getPtr(ptrIdx);
|
||||||
|
|
||||||
if (canUseLdmatrix) {
|
if (canUseLdmatrix) {
|
||||||
int sOffset =
|
Value sOffset =
|
||||||
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
|
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
|
||||||
|
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
|
||||||
PTXBuilder builder;
|
PTXBuilder builder;
|
||||||
|
|
||||||
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
|
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
|
||||||
// thread.
|
// thread.
|
||||||
auto resArgs = builder.newListOperand(4, "=r");
|
auto resArgs = builder.newListOperand(4, "=r");
|
||||||
auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
|
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
|
||||||
|
|
||||||
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
|
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
|
||||||
->o("trans", needTrans /*predicate*/)
|
->o("trans", needTrans /*predicate*/)
|
||||||
@@ -2640,26 +2738,24 @@ public:
|
|||||||
needTrans) { // Use lds.32 to load tf32 matrices
|
needTrans) { // Use lds.32 to load tf32 matrices
|
||||||
Value ptr2 = getPtr(ptrIdx + 1);
|
Value ptr2 = getPtr(ptrIdx + 1);
|
||||||
assert(sMatStride == 1);
|
assert(sMatStride == 1);
|
||||||
int sOffsetElem =
|
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
|
||||||
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
|
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
|
||||||
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
|
int sOffsetArrElem = sMatStride * sMatShape;
|
||||||
|
Value sOffsetArrElemVal =
|
||||||
|
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
|
||||||
|
|
||||||
Value elems[4];
|
Value elems[4];
|
||||||
Type elemTy = type::f32Ty(ctx);
|
Type elemTy = type::f32Ty(ctx);
|
||||||
if (kOrder == 1) {
|
if (kOrder == 1) {
|
||||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
|
||||||
elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
|
||||||
elems[2] =
|
elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
|
||||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
|
||||||
elems[3] =
|
|
||||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
|
||||||
} else {
|
} else {
|
||||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
|
||||||
elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
|
||||||
elems[1] =
|
elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
|
||||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
|
||||||
elems[3] =
|
|
||||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return {elems[0], elems[1], elems[2], elems[3]};
|
return {elems[0], elems[1], elems[2], elems[3]};
|
||||||
@@ -2680,9 +2776,11 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
assert(sMatStride == 1);
|
assert(sMatStride == 1);
|
||||||
int sOffsetElem =
|
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
|
||||||
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
|
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
|
||||||
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
|
int sOffsetArrElem = 1 * (sMatStride * sMatShape);
|
||||||
|
Value sOffsetArrElemVal =
|
||||||
|
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
|
||||||
|
|
||||||
std::array<Value, 4> i8v4Elems;
|
std::array<Value, 4> i8v4Elems;
|
||||||
std::array<Value, 4> i32Elems;
|
std::array<Value, 4> i32Elems;
|
||||||
@@ -2692,16 +2790,14 @@ public:
|
|||||||
Value i8Elems[4][4];
|
Value i8Elems[4][4];
|
||||||
Type elemTy = type::i8Ty(ctx);
|
Type elemTy = type::i8Ty(ctx);
|
||||||
if (kOrder == 1) {
|
if (kOrder == 1) {
|
||||||
Value offset = i32_val(sOffsetElem);
|
|
||||||
|
|
||||||
for (int i = 0; i < 2; ++i)
|
for (int i = 0; i < 2; ++i)
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
|
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
|
||||||
|
|
||||||
offset = i32_val(sOffsetElem + sOffsetArrElem);
|
|
||||||
for (int i = 2; i < 4; ++i)
|
for (int i = 2; i < 4; ++i)
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
|
i8Elems[i][j] =
|
||||||
|
load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
|
||||||
|
|
||||||
for (int m = 0; m < 4; ++m) {
|
for (int m = 0; m < 4; ++m) {
|
||||||
for (int e = 0; e < 4; ++e)
|
for (int e = 0; e < 4; ++e)
|
||||||
@@ -2710,16 +2806,14 @@ public:
|
|||||||
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
|
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
|
||||||
}
|
}
|
||||||
} else { // k first
|
} else { // k first
|
||||||
Value offset = i32_val(sOffsetElem);
|
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
|
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
|
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
|
||||||
offset = i32_val(sOffsetElem + sOffsetArrElem);
|
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
|
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
|
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
|
||||||
|
|
||||||
for (int m = 0; m < 4; ++m) {
|
for (int m = 0; m < 4; ++m) {
|
||||||
for (int e = 0; e < 4; ++e)
|
for (int e = 0; e < 4; ++e)
|
||||||
@@ -2752,8 +2846,8 @@ private:
|
|||||||
int cMatShape;
|
int cMatShape;
|
||||||
int sMatShape;
|
int sMatShape;
|
||||||
|
|
||||||
int cTileStride;
|
Value cTileStride;
|
||||||
int sTileStride;
|
Value sTileStride;
|
||||||
|
|
||||||
bool needTrans;
|
bool needTrans;
|
||||||
bool canUseLdmatrix;
|
bool canUseLdmatrix;
|
||||||
@@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||||
Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
|
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
|
||||||
ConversionPatternRewriter &rewriter) const;
|
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||||
|
|
||||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||||
Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
|
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
||||||
ConversionPatternRewriter &rewriter) const;
|
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||||
|
|
||||||
// Loading $c to registers, returns a LLVM::Struct.
|
// Loading $c to registers, returns a LLVM::Struct.
|
||||||
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
|
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
|
||||||
@@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||||
Value loadA(Value tensor, Value llTensor) const {
|
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
|
||||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
auto shape = aTensorTy.getShape();
|
auto shape = aTensorTy.getShape();
|
||||||
|
|
||||||
@@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper {
|
|||||||
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
||||||
// load from smem
|
// load from smem
|
||||||
loadFn = getLoadMatrixFn(
|
loadFn = getLoadMatrixFn(
|
||||||
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||||
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
||||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
||||||
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||||
@@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||||
Value loadB(Value tensor, Value llTensor) {
|
Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
|
||||||
ValueTable hb;
|
ValueTable hb;
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
auto shape = tensorTy.getShape();
|
auto shape = tensorTy.getShape();
|
||||||
@@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper {
|
|||||||
int numRepN = getNumRepN(tensorTy, shape[1]);
|
int numRepN = getNumRepN(tensorTy, shape[1]);
|
||||||
|
|
||||||
auto loadFn = getLoadMatrixFn(
|
auto loadFn = getLoadMatrixFn(
|
||||||
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
||||||
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||||
|
|
||||||
@@ -3485,10 +3579,10 @@ struct MMA16816ConversionHelper {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<void(int, int)>
|
std::function<void(int, int)>
|
||||||
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
|
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||||
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
|
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
|
||||||
ArrayRef<int> matShape, Value warpId,
|
ArrayRef<int> instrShape, ArrayRef<int> matShape,
|
||||||
ValueTable &vals) const {
|
Value warpId, ValueTable &vals) const {
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
// We assumes that the input operand of Dot should be from shared layout.
|
// We assumes that the input operand of Dot should be from shared layout.
|
||||||
// TODO(Superjomn) Consider other layouts if needed later.
|
// TODO(Superjomn) Consider other layouts if needed later.
|
||||||
@@ -3507,10 +3601,10 @@ private:
|
|||||||
|
|
||||||
// (a, b) is the coordinate.
|
// (a, b) is the coordinate.
|
||||||
auto load = [=, &vals, &ld2](int a, int b) {
|
auto load = [=, &vals, &ld2](int a, int b) {
|
||||||
MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
|
MMA16816SmemLoader loader(
|
||||||
tensorTy.getShape() /*tileShape*/, instrShape,
|
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
|
||||||
matShape, perPhase, maxPhase, elemBytes,
|
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
|
||||||
rewriter, typeConverter, loc);
|
maxPhase, elemBytes, rewriter, typeConverter, loc);
|
||||||
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
|
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
|
||||||
|
|
||||||
const int numPtrs = loader.getNumPtr();
|
const int numPtrs = loader.getNumPtr();
|
||||||
@@ -3519,8 +3613,8 @@ private:
|
|||||||
|
|
||||||
Type smemPtrTy = helper.getShemPtrTy();
|
Type smemPtrTy = helper.getShemPtrTy();
|
||||||
for (int i = 0; i < numPtrs; ++i) {
|
for (int i = 0; i < numPtrs; ++i) {
|
||||||
ptrs[i] =
|
ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
|
||||||
bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
|
smemPtrTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
||||||
@@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
|||||||
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
|
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
|
||||||
assert(mmaLayout);
|
assert(mmaLayout);
|
||||||
|
|
||||||
|
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
|
||||||
Value res;
|
Value res;
|
||||||
if (mmaLayout.getVersion() == 2) {
|
if (mmaLayout.getVersion() == 2) {
|
||||||
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
||||||
@@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
|||||||
|
|
||||||
if (dotOperandLayout.getOpIdx() == 0) {
|
if (dotOperandLayout.getOpIdx() == 0) {
|
||||||
// operand $a
|
// operand $a
|
||||||
res = mmaHelper.loadA(src, adaptor.src());
|
res = mmaHelper.loadA(src, smemObj);
|
||||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||||
// operand $b
|
// operand $b
|
||||||
res = mmaHelper.loadB(src, adaptor.src());
|
res = mmaHelper.loadB(src, smemObj);
|
||||||
}
|
}
|
||||||
} else if (mmaLayout.getVersion() == 1) {
|
} else if (mmaLayout.getVersion() == 1) {
|
||||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||||
if (dotOperandLayout.getOpIdx() == 0) {
|
if (dotOperandLayout.getOpIdx() == 0) {
|
||||||
// operand $a
|
// operand $a
|
||||||
res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
|
res =
|
||||||
adaptor.src(), loc, rewriter);
|
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||||
// operand $b
|
// operand $b
|
||||||
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
|
res =
|
||||||
adaptor.src(), loc, rewriter);
|
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
assert(false && "Unsupported mma layout found");
|
assert(false && "Unsupported mma layout found");
|
||||||
@@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
|
|||||||
loadedA = adaptor.a();
|
loadedA = adaptor.a();
|
||||||
loadedB = adaptor.b();
|
loadedB = adaptor.b();
|
||||||
} else {
|
} else {
|
||||||
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
|
SharedMemoryObject smemA =
|
||||||
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
|
getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
|
||||||
|
SharedMemoryObject smemB =
|
||||||
|
getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
|
||||||
|
loadedA = mmaHelper.loadA(op.a(), smemA);
|
||||||
|
loadedB = mmaHelper.loadB(op.b(), smemB);
|
||||||
}
|
}
|
||||||
|
|
||||||
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
|
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
|
||||||
@@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value DotOpMmaV1ConversionHelper::loadA(
|
Value DotOpMmaV1ConversionHelper::loadA(
|
||||||
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
|
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// smem
|
||||||
|
Value smem = smemObj.base;
|
||||||
|
auto strides = smemObj.strides;
|
||||||
|
|
||||||
auto *ctx = rewriter.getContext();
|
auto *ctx = rewriter.getContext();
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
auto shape = tensorTy.getShape();
|
auto shape = tensorTy.getShape();
|
||||||
@@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
|
|
||||||
int vecA = sharedLayout.getVec();
|
int vecA = sharedLayout.getVec();
|
||||||
|
|
||||||
int strideAM = isARow ? shape[1] : 1;
|
Value strideAM = isARow ? strides[0] : i32_val(1);
|
||||||
int strideAK = isARow ? 1 : shape[0];
|
Value strideAK = isARow ? i32_val(1) : strides[1];
|
||||||
int strideA0 = isARow ? strideAK : strideAM;
|
Value strideA0 = isARow ? strideAK : strideAM;
|
||||||
int strideA1 = isARow ? strideAM : strideAK;
|
Value strideA1 = isARow ? strideAM : strideAK;
|
||||||
|
|
||||||
int strideRepM = wpt[0] * fpw[0] * 8;
|
int strideRepM = wpt[0] * fpw[0] * 8;
|
||||||
int strideRepK = 1;
|
int strideRepK = 1;
|
||||||
@@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
offA0I = udiv(offA0I, i32_val(vecA));
|
offA0I = udiv(offA0I, i32_val(vecA));
|
||||||
offA0I = xor_(offA0I, phaseA);
|
offA0I = xor_(offA0I, phaseA);
|
||||||
offA0I = xor_(offA0I, i32_val(vecA));
|
offA0I = xor_(offA0I, i32_val(vecA));
|
||||||
offA[i] =
|
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
|
||||||
add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type f16x2Ty = vec_ty(f16_ty, 2);
|
Type f16x2Ty = vec_ty(f16_ty, 2);
|
||||||
@@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
|
|
||||||
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
||||||
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
||||||
Value pa = gep(f16PtrTy, thePtrA,
|
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
|
||||||
i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
|
mul(i32_val(stepAK), strideAK));
|
||||||
|
Value pa = gep(f16PtrTy, thePtrA, offset);
|
||||||
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
|
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
|
||||||
Value ha = load(bitcast(pa, aPtrTy));
|
Value ha = load(bitcast(pa, aPtrTy));
|
||||||
// record lds that needs to be moved
|
// record lds that needs to be moved
|
||||||
@@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value DotOpMmaV1ConversionHelper::loadB(
|
Value DotOpMmaV1ConversionHelper::loadB(
|
||||||
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
|
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// smem
|
||||||
|
Value smem = smemObj.base;
|
||||||
|
auto strides = smemObj.strides;
|
||||||
|
|
||||||
auto *ctx = rewriter.getContext();
|
auto *ctx = rewriter.getContext();
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
auto shape = tensorTy.getShape();
|
auto shape = tensorTy.getShape();
|
||||||
@@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|||||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||||
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
||||||
int vecB = sharedLayout.getVec();
|
int vecB = sharedLayout.getVec();
|
||||||
int strideBN = isBRow ? 1 : shape[0];
|
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
||||||
int strideBK = isBRow ? shape[1] : 1;
|
Value strideBK = isBRow ? strides[0] : i32_val(1);
|
||||||
int strideB0 = isBRow ? strideBN : strideBK;
|
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||||
int strideB1 = isBRow ? strideBK : strideBN;
|
Value strideB1 = isBRow ? strideBK : strideBN;
|
||||||
int strideRepN = wpt[1] * fpw[1] * 8;
|
int strideRepN = wpt[1] * fpw[1] * 8;
|
||||||
int strideRepK = 1;
|
int strideRepK = 1;
|
||||||
|
|
||||||
@@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|||||||
offB0I = udiv(offB0I, i32_val(vecB));
|
offB0I = udiv(offB0I, i32_val(vecB));
|
||||||
offB0I = xor_(offB0I, phaseB);
|
offB0I = xor_(offB0I, phaseB);
|
||||||
offB0I = mul(offB0I, i32_val(vecB));
|
offB0I = mul(offB0I, i32_val(vecB));
|
||||||
offB[i] =
|
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
|
||||||
add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type f16PtrTy = ptr_ty(f16_ty);
|
Type f16PtrTy = ptr_ty(f16_ty);
|
||||||
@@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|||||||
|
|
||||||
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
|
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
|
||||||
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
|
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
|
||||||
Value pb = gep(f16PtrTy, thePtrB,
|
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
|
||||||
i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
|
mul(i32_val(stepBK), strideBK));
|
||||||
|
Value pb = gep(f16PtrTy, thePtrB, offset);
|
||||||
Value hb =
|
Value hb =
|
||||||
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
||||||
// record lds that needs to be moved
|
// record lds that needs to be moved
|
||||||
@@ -4171,7 +4278,17 @@ public:
|
|||||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||||
} else if (auto shared_layout =
|
} else if (auto shared_layout =
|
||||||
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
|
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
|
||||||
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
SmallVector<Type, 4> types;
|
||||||
|
// base ptr
|
||||||
|
auto ptrType =
|
||||||
|
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
||||||
|
types.push_back(ptrType);
|
||||||
|
// shape dims
|
||||||
|
auto rank = type.getRank();
|
||||||
|
for (auto i = 0; i < rank; i++) {
|
||||||
|
types.push_back(IntegerType::get(ctx, 32));
|
||||||
|
}
|
||||||
|
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||||
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
|
||||||
if (mmaLayout.getVersion() == 2) {
|
if (mmaLayout.getVersion() == 2) {
|
||||||
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
|
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
|
||||||
@@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion
|
|||||||
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
|
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
|
||||||
|
|
||||||
// %dst
|
// %dst
|
||||||
|
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||||
|
auto dstShape = dstTy.getShape();
|
||||||
|
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
|
||||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||||
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
|
SmallVector<Value, 4> offsetVals;
|
||||||
auto dstBase = createIndexAttrConstant(rewriter, loc,
|
SmallVector<Value, 4> srcStrides;
|
||||||
getTypeConverter()->getIndexType(),
|
for (auto i = 0; i < dstShape.size(); ++i) {
|
||||||
product<int64_t>(srcTy.getShape()));
|
if (i == axis) {
|
||||||
Value offset = mul(llIndex, dstBase);
|
offsetVals.emplace_back(llIndex);
|
||||||
auto dstPtrTy = LLVM::LLVMPointerType::get(
|
} else {
|
||||||
getTypeConverter()->convertType(resTy.getElementType()), 3);
|
offsetVals.emplace_back(i32_val(0));
|
||||||
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
|
srcStrides.emplace_back(smemObj.strides[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Compute the offset based on the original dimensions of the shared memory
|
||||||
|
// object
|
||||||
|
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||||
|
auto dstPtrTy =
|
||||||
|
ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
|
||||||
|
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
|
||||||
|
|
||||||
// %mask
|
// %mask
|
||||||
SmallVector<Value> maskElems;
|
SmallVector<Value> maskElems;
|
||||||
@@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion
|
|||||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||||
|
|
||||||
auto inOrder = srcBlockedLayout.getOrder();
|
auto inOrder = srcBlockedLayout.getOrder();
|
||||||
|
|
||||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
||||||
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
|
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
|
||||||
// swizzle is not allowd
|
// swizzle is not allowd
|
||||||
auto numSwizzleRows = std::max<unsigned>(
|
auto numSwizzleRows = std::max<unsigned>(
|
||||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||||
@@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion
|
|||||||
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
|
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
|
||||||
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
|
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
|
||||||
threadsPerCTA[inOrder[1]];
|
threadsPerCTA[inOrder[1]];
|
||||||
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
|
|
||||||
auto tileVecIdxCol = vecIdxCol % numVecCols;
|
auto tileVecIdxCol = vecIdxCol % numVecCols;
|
||||||
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
|
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
|
||||||
|
|
||||||
@@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion
|
|||||||
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
||||||
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
||||||
i32_val(maxPhase));
|
i32_val(maxPhase));
|
||||||
Value rowOffset =
|
// srcShape and smemObj.shape maybe different if smemObj is a
|
||||||
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
|
// slice of the original shared memory object.
|
||||||
|
// So we need to use the original shape to compute the offset
|
||||||
|
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
|
||||||
Value colOffset =
|
Value colOffset =
|
||||||
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
|
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
|
||||||
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
|
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
|
||||||
@@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion
|
|||||||
auto numWords = vecBitWidth / bitWidth;
|
auto numWords = vecBitWidth / bitWidth;
|
||||||
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
|
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
|
||||||
|
|
||||||
// XXX(Keren): Tune CG and CA here.
|
// Tune CG and CA here.
|
||||||
auto byteWidth = bitWidth / 8;
|
auto byteWidth = bitWidth / 8;
|
||||||
CacheModifier srcCacheModifier =
|
CacheModifier srcCacheModifier =
|
||||||
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
|
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
|
||||||
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
|
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
|
||||||
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
|
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
|
||||||
|
|
||||||
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
||||||
|
Value baseOffset =
|
||||||
|
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
|
||||||
|
i32_val(baseOffsetCol));
|
||||||
|
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
|
||||||
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
auto wordElemIdx = wordIdx * numWordElems;
|
auto wordElemIdx = wordIdx * numWordElems;
|
||||||
auto ©AsyncOp =
|
auto ©AsyncOp =
|
||||||
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
|
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
|
||||||
auto *dstOperand = ptxBuilder.newAddrOperand(
|
auto *dstOperand =
|
||||||
tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
|
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
|
||||||
auto *srcOperand =
|
auto *srcOperand =
|
||||||
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
|
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
|
||||||
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
|
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
add_mlir_dialect_library(TritonGPUIR
|
add_mlir_dialect_library(TritonGPUIR
|
||||||
Dialect.cpp
|
Dialect.cpp
|
||||||
|
Traits.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TritonGPUTableGen
|
TritonGPUTableGen
|
||||||
|
@@ -474,7 +474,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
|||||||
|
|
||||||
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
SmallVector<OpAsmParser::OperandType, 8> allOperands;
|
||||||
Type srcType, dstType;
|
Type srcType, dstType;
|
||||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||||
if (parser.parseOperandList(allOperands) ||
|
if (parser.parseOperandList(allOperands) ||
|
||||||
@@ -489,14 +489,27 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
|||||||
operandTypes.push_back(dstType); // dst
|
operandTypes.push_back(dstType); // dst
|
||||||
operandTypes.push_back(
|
operandTypes.push_back(
|
||||||
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
|
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
|
||||||
if (allOperands.size() >= 4)
|
|
||||||
|
int hasMask = 0, hasOther = 0;
|
||||||
|
if (allOperands.size() >= 4) {
|
||||||
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
|
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
|
||||||
if (allOperands.size() >= 5)
|
hasMask = 1;
|
||||||
|
}
|
||||||
|
if (allOperands.size() >= 5) {
|
||||||
operandTypes.push_back(triton::getPointeeType(srcType)); // other
|
operandTypes.push_back(triton::getPointeeType(srcType)); // other
|
||||||
|
hasOther = 1;
|
||||||
|
}
|
||||||
|
|
||||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||||
result.operands))
|
result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// Deduce operand_segment_sizes from the number of the operands.
|
||||||
|
auto operand_segment_sizesAttrName =
|
||||||
|
InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
|
||||||
|
result.addAttribute(
|
||||||
|
operand_segment_sizesAttrName,
|
||||||
|
parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,39 +517,16 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
|||||||
InsertSliceAsyncOp insertSliceAsyncOp) {
|
InsertSliceAsyncOp insertSliceAsyncOp) {
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printer << insertSliceAsyncOp.getOperation()->getOperands();
|
printer << insertSliceAsyncOp.getOperation()->getOperands();
|
||||||
printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
|
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||||
/*elidedAttrs=*/{});
|
printer.printOptionalAttrDict(
|
||||||
|
insertSliceAsyncOp->getAttrs(),
|
||||||
|
{insertSliceAsyncOp.operand_segment_sizesAttrName()});
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
|
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
|
||||||
printer << " -> ";
|
printer << " -> ";
|
||||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ExtractSliceOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
|
|
||||||
::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
|
|
||||||
::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
|
|
||||||
::mlir::RegionRange regions,
|
|
||||||
llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
|
||||||
auto srcType = operands[0].getType().cast<RankedTensorType>();
|
|
||||||
auto encoding = srcType.getEncoding();
|
|
||||||
auto srcShape = srcType.getShape();
|
|
||||||
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
|
||||||
if (axis < 0 || (size_t)axis > srcShape.size())
|
|
||||||
return failure();
|
|
||||||
SmallVector<int64_t, 4> dstShape;
|
|
||||||
for (size_t i = 0; i < srcShape.size(); i++)
|
|
||||||
if (i != (size_t)axis)
|
|
||||||
dstShape.push_back(srcShape[i]);
|
|
||||||
auto returnType =
|
|
||||||
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
|
|
||||||
inferredReturnTypes.assign({returnType});
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DotOperand Encoding
|
// DotOperand Encoding
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -631,32 +621,6 @@ void TritonGPUDialect::initialize() {
|
|||||||
addInterfaces<TritonGPUInferLayoutInterface>();
|
addInterfaces<TritonGPUInferLayoutInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Verification
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static LogicalResult verify(InsertSliceAsyncOp op) {
|
|
||||||
if (!isSharedEncoding(op.getResult())) {
|
|
||||||
return op.emitOpError(
|
|
||||||
"insert_slice_async should return a shared memory tensor");
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(ExtractSliceOp op) {
|
|
||||||
if (!isSharedEncoding(op.getResult())) {
|
|
||||||
return op.emitOpError("extract_slice should return a shared memory tensor");
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(AllocTensorOp op) {
|
|
||||||
if (!isSharedEncoding(op.getResult())) {
|
|
||||||
return op.emitOpError("alloc_tensor should return a shared memory tensor");
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
|
|
||||||
|
14
lib/Dialect/TritonGPU/IR/Traits.cpp
Normal file
14
lib/Dialect/TritonGPU/IR/Traits.cpp
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||||
|
#include "triton/Analysis/Utility.h"
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
|
||||||
|
if (failed(verifyAtLeastNResults(op, 1)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (auto result : op->getResults())
|
||||||
|
if (!isSharedEncoding(result))
|
||||||
|
return op->emitOpError() << "requires all results to be shared encoding";
|
||||||
|
|
||||||
|
return success();
|
||||||
|
};
|
@@ -111,37 +111,41 @@ public:
|
|||||||
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
||||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||||
if (insert_slice) {
|
if (insert_slice) {
|
||||||
auto newType = op->getResult(0).getType();
|
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||||
// Ensure that the new insert_slice op is placed in the same place as the
|
// Ensure that the new insert_slice op is placed in the same place as the
|
||||||
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
||||||
// after the async_wait op, which is not allowed.
|
// after the async_wait op, which is not allowed.
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPoint(insert_slice);
|
rewriter.setInsertionPoint(insert_slice);
|
||||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||||
op->getLoc(), newType, insert_slice.dst());
|
op->getLoc(), newType, insert_slice.dst());
|
||||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||||
op, newType, insert_slice.src(), new_arg.getResult(),
|
op, newType, insert_slice.src(), newArg.getResult(),
|
||||||
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
|
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
|
||||||
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
|
insert_slice.cache(), insert_slice.evict(),
|
||||||
insert_slice.axis());
|
insert_slice.isVolatile(), insert_slice.axis());
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||||
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
|
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
|
||||||
if (extract_slice) {
|
if (extract_slice) {
|
||||||
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
|
auto origType = extract_slice.source().getType().cast<RankedTensorType>();
|
||||||
|
auto newType = RankedTensorType::get(
|
||||||
|
origType.getShape(), origType.getElementType(),
|
||||||
|
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||||
|
auto resType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||||
// Ensure that the new extract_slice op is placed in the same place as the
|
// Ensure that the new extract_slice op is placed in the same place as the
|
||||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||||
// after the async_wait op, which is not allowed.
|
// after the async_wait op, which is not allowed.
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPoint(extract_slice);
|
rewriter.setInsertionPoint(extract_slice);
|
||||||
auto newType = RankedTensorType::get(
|
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||||
origType.getShape(), origType.getElementType(),
|
op->getLoc(), newType, extract_slice.source());
|
||||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
||||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
op, resType, newArg.getResult(), extract_slice.offsets(),
|
||||||
op->getLoc(), newType, extract_slice.src());
|
extract_slice.sizes(), extract_slice.strides(),
|
||||||
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
|
extract_slice.static_offsets(), extract_slice.static_sizes(),
|
||||||
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
|
extract_slice.static_strides());
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
// cvt(type2, x)
|
// cvt(type2, x)
|
||||||
@@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
|||||||
inline bool expensive_to_remat(Operation *op) {
|
inline bool expensive_to_remat(Operation *op) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return true;
|
return true;
|
||||||
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||||
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
||||||
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
|
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
|
||||||
return true;
|
return true;
|
||||||
|
@@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() {
|
|||||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||||
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
||||||
|
|
||||||
|
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||||
|
|
||||||
// async.wait & extract_slice
|
// async.wait & extract_slice
|
||||||
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
|
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
|
||||||
loads.size() * (numStages - 2));
|
loads.size() * (numStages - 2));
|
||||||
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||||
for (Value loadOp : loads) {
|
for (Value loadOp : loads) {
|
||||||
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
|
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||||
loadOp.getLoc(), loadsMapping[loadOp].getType(),
|
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
||||||
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
|
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||||
|
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
|
||||||
|
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
|
||||||
|
intAttr(sliceType.getShape()[1])},
|
||||||
|
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||||
loadsExtract[loadOp] = extractSlice;
|
loadsExtract[loadOp] = extractSlice;
|
||||||
}
|
}
|
||||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||||
@@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
||||||
nextIV.getLoc(), loopIterIdx,
|
nextIV.getLoc(), loopIterIdx,
|
||||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||||
|
extractSliceIndex = builder.create<arith::IndexCastOp>(
|
||||||
|
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
|
||||||
|
|
||||||
|
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||||
|
|
||||||
for (Operation *op : orderedDeps) {
|
for (Operation *op : orderedDeps) {
|
||||||
Operation *nextOp = nullptr;
|
Operation *nextOp = nullptr;
|
||||||
@@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||||
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
||||||
nextBuffers.push_back(insertAsyncOp);
|
nextBuffers.push_back(insertAsyncOp);
|
||||||
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
|
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||||
op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
|
nextOp = builder.create<tensor::ExtractSliceOp>(
|
||||||
extractSliceIndex, /*axis*/ 0);
|
op->getLoc(), sliceType, insertAsyncOp,
|
||||||
|
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
|
||||||
|
SmallVector<OpFoldResult>{intAttr(1),
|
||||||
|
intAttr(sliceType.getShape()[0]),
|
||||||
|
intAttr(sliceType.getShape()[1])},
|
||||||
|
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||||
extractSlices.push_back(nextOp->getResult(0));
|
extractSlices.push_back(nextOp->getResult(0));
|
||||||
} else
|
} else
|
||||||
nextOp = builder.clone(*op, nextMapping);
|
nextOp = builder.clone(*op, nextMapping);
|
||||||
|
@@ -137,7 +137,7 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
|
|||||||
# reference result
|
# reference result
|
||||||
|
|
||||||
if expr == "cdiv":
|
if expr == "cdiv":
|
||||||
y_ref = (x0 + x1 - 1) // x1
|
y_ref = torch.div(x0 + x1 - 1, x1, rounding_mode='trunc')
|
||||||
elif expr == "umulhi":
|
elif expr == "umulhi":
|
||||||
y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
|
y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
|
||||||
else:
|
else:
|
||||||
|
@@ -25,6 +25,7 @@ from filelock import FileLock
|
|||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
|
||||||
def str_to_ty(name):
|
def str_to_ty(name):
|
||||||
@@ -875,8 +876,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
|
|||||||
pm = _triton.ir.pass_manager(mod.context)
|
pm = _triton.ir.pass_manager(mod.context)
|
||||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||||
pm.enable_debug()
|
pm.enable_debug()
|
||||||
# Get error in backend due to wrong conversion in expanding async-related instruction.
|
|
||||||
# TODO[Superjomn]: Open it when fixed.
|
|
||||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
@@ -1396,6 +1395,19 @@ class CompiledKernel:
|
|||||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def get_sass(self, fun=None):
|
||||||
|
if 'sass' in self.asm:
|
||||||
|
return self.asm['sass']
|
||||||
|
fd, path = tempfile.mkstemp()
|
||||||
|
try:
|
||||||
|
with open(fd, 'wb') as cubin:
|
||||||
|
cubin.write(self.asm['cubin'])
|
||||||
|
self.sass = extract(path, fun)
|
||||||
|
finally:
|
||||||
|
os.remove(path)
|
||||||
|
self.asm['sass'] = self.sass
|
||||||
|
return self.sass
|
||||||
|
|
||||||
|
|
||||||
class CudaUtils(object):
|
class CudaUtils(object):
|
||||||
|
|
||||||
|
@@ -18,10 +18,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
|
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
|
||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
// CHECK: %4 -> %4
|
// CHECK: %4 -> %4
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
// CHECK-NEXT: %6 -> %6
|
// CHECK-NEXT: %6 -> %6
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
@@ -60,7 +60,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|||||||
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : i32
|
||||||
// CHECK: %2 -> %cst_0
|
// CHECK: %2 -> %cst_0
|
||||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,9 +68,9 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|||||||
func @extract_slice(%A : !tt.ptr<f16>) {
|
func @extract_slice(%A : !tt.ptr<f16>) {
|
||||||
// CHECK: %cst -> %cst
|
// CHECK: %cst -> %cst
|
||||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : index
|
||||||
// CHECK-NEXT: %0 -> %cst
|
// CHECK-NEXT: %0 -> %cst
|
||||||
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
|
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,9 +144,9 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t
|
|||||||
// CHECK-NEXT: %0#2 -> %cst,%cst_0
|
// CHECK-NEXT: %0#2 -> %cst,%cst_0
|
||||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
|
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
|
||||||
scf.if %i1 {
|
scf.if %i1 {
|
||||||
%index = arith.constant 8 : i32
|
%index = arith.constant 8 : index
|
||||||
// CHECK-NEXT: %1 -> %cst,%cst_0
|
// CHECK-NEXT: %1 -> %cst,%cst_0
|
||||||
%cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
|
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A>
|
||||||
scf.yield
|
scf.yield
|
||||||
}
|
}
|
||||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
|
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
|
||||||
|
@@ -178,7 +178,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|||||||
// CHECK: offset = 0, size = 512
|
// CHECK: offset = 0, size = 512
|
||||||
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : i32
|
||||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
||||||
return
|
return
|
||||||
// CHECK-NEXT: size = 512
|
// CHECK-NEXT: size = 512
|
||||||
}
|
}
|
||||||
@@ -187,8 +187,8 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|||||||
func @extract_slice(%A : !tt.ptr<f16>) {
|
func @extract_slice(%A : !tt.ptr<f16>) {
|
||||||
// CHECK: offset = 0, size = 512
|
// CHECK: offset = 0, size = 512
|
||||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : index
|
||||||
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
|
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
|
||||||
return
|
return
|
||||||
// CHECK-NEXT: size = 512
|
// CHECK-NEXT: size = 512
|
||||||
}
|
}
|
||||||
@@ -271,8 +271,8 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
|
|||||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
|
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
|
||||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
|
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
|
||||||
scf.if %i1 {
|
scf.if %i1 {
|
||||||
%index = arith.constant 8 : i32
|
%index = arith.constant 8 : index
|
||||||
%cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
|
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A>
|
||||||
scf.yield
|
scf.yield
|
||||||
}
|
}
|
||||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
|
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
|
||||||
|
@@ -22,9 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
|
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
// CHECK: Membar 13
|
// CHECK: Membar 13
|
||||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
@@ -41,7 +41,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) {
|
|||||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
// CHECK: Membar 5
|
// CHECK: Membar 5
|
||||||
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A>
|
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A>
|
||||||
@@ -53,7 +53,7 @@ func @war_single_block(%A : !tt.ptr<f16>) {
|
|||||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
// CHECK: Membar 5
|
// CHECK: Membar 5
|
||||||
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL>
|
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL>
|
||||||
@@ -98,8 +98,8 @@ func @alloc() {
|
|||||||
// CHECK-LABEL: extract_slice
|
// CHECK-LABEL: extract_slice
|
||||||
func @extract_slice() {
|
func @extract_slice() {
|
||||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : index
|
||||||
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
|
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A>
|
||||||
// CHECK: Membar 3
|
// CHECK: Membar 3
|
||||||
%cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL>
|
%cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL>
|
||||||
// CHECK-NEXT: Membar 5
|
// CHECK-NEXT: Membar 5
|
||||||
@@ -114,7 +114,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|||||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||||
%tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A>
|
%tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A>
|
||||||
%index = arith.constant 0 : i32
|
%index = arith.constant 0 : i32
|
||||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
|
||||||
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A>
|
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A>
|
||||||
// CHECK: Membar 7
|
// CHECK: Membar 7
|
||||||
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A>
|
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A>
|
||||||
|
@@ -346,18 +346,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK: llvm.mlir.global external @global_smem
|
// CHECK: llvm.mlir.global external @global_smem
|
||||||
// CHECK-LABEL: basic_extract_slice
|
// CHECK-LABEL: basic_extract_slice
|
||||||
func @basic_extract_slice() {
|
func @basic_extract_slice() {
|
||||||
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
|
// CHECK: llvm.mlir.addressof @global_smem
|
||||||
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
|
// CHECK: llvm.extractvalue
|
||||||
// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
|
// CHECK-NEXT: llvm.extractvalue
|
||||||
// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
|
// CHECK-NEXT: llvm.extractvalue
|
||||||
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
|
// CHECK-NEXT: llvm.extractvalue
|
||||||
// CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
|
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
|
||||||
// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
|
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
|
||||||
// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
|
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
|
||||||
// CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
|
// CHECK-NEXT: llvm.mul
|
||||||
%index = arith.constant 1 : i32
|
// CHECK-NEXT: llvm.add
|
||||||
|
// CHECK-NEXT: llvm.mul
|
||||||
|
// CHECK-NEXT: llvm.add
|
||||||
|
// CHECK-NEXT: llvm.mul
|
||||||
|
// CHECK-NEXT: llvm.add
|
||||||
|
// CHECK-NEXT: llvm.getelementptr
|
||||||
|
%index = arith.constant 1 : index
|
||||||
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
||||||
%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
|
%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -488,22 +494,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
%tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
|
%tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
|
||||||
%index = arith.constant 1 : i32
|
%index = arith.constant 1 : i32
|
||||||
|
|
||||||
|
// CHECK: llvm.mlir.constant(0 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(0 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(0 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(0 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(16 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(16 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(16 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
|
// CHECK: llvm.mlir.constant(16 : i32) : i32
|
||||||
|
// CHECK: llvm.add
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: cp.async.commit_group
|
// CHECK-SAME: cp.async.commit_group
|
||||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
|
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
|
||||||
|
@@ -62,7 +62,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|||||||
// CHECK-LABEL: transpose
|
// CHECK-LABEL: transpose
|
||||||
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||||
// CHECK-NOT: triton_gpu.convert_layout
|
// CHECK-NOT: triton_gpu.convert_layout
|
||||||
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
|
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
|
||||||
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
|
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
|
||||||
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
|
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
@@ -91,7 +91,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
|||||||
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||||
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||||
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||||
%22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, isOtherUnspecified = false} : tensor<64x64xf32, #blocked3>
|
%22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||||
%23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
%23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||||
%24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
|
%24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
|
||||||
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
|
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
|
||||||
@@ -133,7 +133,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
|||||||
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||||
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||||
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||||
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||||
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||||
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
|
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
|
||||||
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||||
|
@@ -20,17 +20,18 @@
|
|||||||
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||||
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||||
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
|
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
|
||||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
|
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
||||||
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
||||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||||
|
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
|
||||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]]
|
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]]
|
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||||
@@ -76,17 +77,18 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||||
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||||
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
|
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
|
||||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
|
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
||||||
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
||||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||||
|
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
|
||||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]]
|
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]]
|
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||||
@@ -130,14 +132,15 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
|
|||||||
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
|
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
|
||||||
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||||
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
|
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
||||||
// CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
|
||||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||||
|
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
|
||||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
|
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||||
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
||||||
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
||||||
|
Reference in New Issue
Block a user