2022-08-18 12:49:37 -07:00
|
|
|
#include <numeric>
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
2022-07-26 10:50:11 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
#include "triton/Analysis/Utility.h"
|
2022-10-11 18:16:41 -07:00
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
2022-08-18 12:49:37 -07:00
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "llvm/ADT/TypeSwitch.h"
|
2022-04-28 18:51:31 +08:00
|
|
|
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
|
|
|
|
2022-05-31 11:43:21 +00:00
|
|
|
using namespace mlir;
|
2022-04-28 18:51:31 +08:00
|
|
|
using namespace mlir::triton::gpu;
|
|
|
|
|
2022-08-24 12:55:49 -07:00
|
|
|
// Utility
|
|
|
|
namespace mlir {
|
|
|
|
namespace triton {
|
|
|
|
|
|
|
|
// Type inference
|
|
|
|
// Build an i1-element type with the same shape/encoding as `type`.
// Returns a null Type when `type` is not a ranked tensor.
static Type getI1SameShape(Type type) {
  auto boolTy = IntegerType::get(type.getContext(), 1);
  auto tensorTy = type.dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return Type();
  return RankedTensorType::get(tensorTy.getShape(), boolTy,
                               tensorTy.getEncoding());
}
|
|
|
|
|
|
|
|
// Strip one level of pointer indirection from `type`.
// - Tensor of pointers: returns a tensor of the pointee type with the same
//   shape and encoding.
// - Scalar pointer: returns the pointee type.
// - Anything else: returns a null Type.
static Type getPointeeType(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    // Tensor of pointers
    auto shape = tensorType.getShape();
    auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
    // Guard against a tensor whose elements are not pointers; the original
    // code dereferenced a null ptrType here.
    if (!ptrType)
      return Type();
    Type pointeeType = ptrType.getPointeeType();
    return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
  } else if (auto ptrType = type.dyn_cast<PointerType>()) {
    // scalar pointer
    Type pointeeType = ptrType.getPointeeType();
    return pointeeType;
  }
  return Type();
}
|
|
|
|
|
2022-09-18 05:58:42 +08:00
|
|
|
namespace gpu {
|
|
|
|
|
2022-11-10 09:58:07 -08:00
|
|
|
// TODO: Inheritance of layout attributes
|
|
|
|
// so that all distributed layouts implement
|
|
|
|
// these utilities
|
|
|
|
|
|
|
|
// Dispatch getElemsPerThread to the concrete encoding attribute.
// Asserts (and returns 0) for encodings that do not implement it.
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>())
    return blocked.getElemsPerThread(shape);
  if (auto slice = layout.dyn_cast<SliceEncodingAttr>())
    return slice.getElemsPerThread(shape);
  if (auto mma = layout.dyn_cast<MmaEncodingAttr>())
    return mma.getElemsPerThread(shape);
  if (auto shared = layout.dyn_cast<SharedEncodingAttr>())
    return shared.getElemsPerThread(shape);
  if (auto dotOperand = layout.dyn_cast<DotOperandEncodingAttr>())
    return dotOperand.getElemsPerThread(shape);
  assert(0 && "getElemsPerThread not implemented");
  return 0;
}
|
|
|
|
|
2022-11-10 09:58:07 -08:00
|
|
|
// Elements-per-thread for an arbitrary type: scalars (integer, index, float,
// fp8, pointer) hold exactly one element; tensors delegate to their encoding.
unsigned getElemsPerThread(Type type) {
  bool isScalar = type.isIntOrIndexOrFloat() ||
                  type.isa<triton::Float8Type>() ||
                  type.isa<triton::PointerType>();
  if (isScalar)
    return 1;
  auto tensorTy = type.cast<RankedTensorType>();
  return getElemsPerThread(tensorTy.getEncoding(), tensorTy.getShape());
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Per-dimension thread arrangement inside one warp for the given layout.
// MMA layouts use fixed arrangements per hardware version.
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    auto tpw = blocked.getThreadsPerWarp();
    return SmallVector<unsigned>(tpw.begin(), tpw.end());
  }
  if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    switch (mma.getVersion()) {
    case 1:
      return {4, 8};
    case 2:
      return {8, 4};
    default:
      break;
    }
  }
  assert(0 && "getThreadsPerWarp not implemented");
  return {};
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Per-dimension warp arrangement inside one CTA for the given layout.
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    auto wpc = blocked.getWarpsPerCTA();
    return SmallVector<unsigned>(wpc.begin(), wpc.end());
  }
  if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    auto wpc = mma.getWarpsPerCTA();
    return SmallVector<unsigned>(wpc.begin(), wpc.end());
  }
  assert(0 && "getWarpsPerCTA not implemented");
  return {};
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Per-dimension number of elements owned contiguously by one thread.
// Slice layouts delegate to their parent; MMA and dot-operand layouts use
// fixed per-version constants.
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                 blockedLayout.getSizePerThread().end());
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    return getSizePerThread(sliceLayout.getParent());
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.getVersion() == 2) {
      return {2, 2};
    } else if (mmaLayout.getVersion() == 1) {
      // Note: here the definition of sizePerThread is obscure, which doesn't
      // mean vecSize=4 can be supported in the last dimension.
      return {2, 4};
    } else {
      llvm_unreachable("Unexpected mma version");
    }
  } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
    auto parentLayout = dotLayout.getParent();
    assert(parentLayout && "DotOperandEncodingAttr must have a parent");
    if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
      assert(parentMmaLayout.getVersion() == 2 &&
             "mmaLayout version = 1 is not implemented yet");
      // (removed unused local `parentShapePerCTA` — it was computed but
      // never read in this branch)
      auto opIdx = dotLayout.getOpIdx();
      if (opIdx == 0) {
        return {2, 4}; // operand A
      } else if (opIdx == 1) {
        return {4, 1}; // operand B
      } else {
        assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
        return {};
      }
    } else {
      assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
                  "supported yet");
      return {};
    }
  } else {
    assert(0 && "getSizePerThread not implemented");
    return {};
  }
}
|
|
|
|
|
2022-11-30 17:27:26 +08:00
|
|
|
// Per-dimension count of elements a thread holds contiguously. MMA layouts
// (version 1 or 2) place 2 contiguous elements along the last dimension;
// every other layout falls back to its size-per-thread.
SmallVector<unsigned> getContigPerThread(Attribute layout) {
  auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
  if (!mmaLayout)
    return getSizePerThread(layout);
  assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
  return {1, 2};
}
|
|
|
|
|
2022-10-04 09:37:00 -07:00
|
|
|
// Per-dimension number of threads in one CTA:
// threadsPerWarp[d] * warpsPerCTA[d]. Only blocked layouts are supported.
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
  SmallVector<unsigned> threads;
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    unsigned rank = blocked.getOrder().size();
    for (unsigned d = 0; d != rank; ++d)
      threads.push_back(blocked.getThreadsPerWarp()[d] *
                        blocked.getWarpsPerCTA()[d]);
  } else if (layout.isa<MmaEncodingAttr>()) {
    assert(0 && "Unimplemented usage of MmaEncodingAttr");
  } else {
    assert(0 && "Unimplemented usage of getShapePerCTA");
  }
  return threads;
}
|
|
|
|
|
2022-09-27 14:38:34 +08:00
|
|
|
// Shape (in elements) of the tile that a single CTA covers for this layout.
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
  SmallVector<unsigned> shape;
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
      shape.push_back(blockedLayout.getSizePerThread()[d] *
                      blockedLayout.getThreadsPerWarp()[d] *
                      blockedLayout.getWarpsPerCTA()[d]);
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    // Parent's per-CTA shape with the sliced dimension removed.
    unsigned dim = sliceLayout.getDim();
    auto parent = sliceLayout.getParent();
    for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
      if (d == dim)
        continue;
      shape.push_back(getSizePerThread(parent)[d] *
                      getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
    }
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    // Fixed per-version tile shapes scaled by the warp arrangement.
    if (mmaLayout.getVersion() == 2)
      return {16 * mmaLayout.getWarpsPerCTA()[0],
              8 * mmaLayout.getWarpsPerCTA()[1]};
    if (mmaLayout.getVersion() == 1)
      return {16 * mmaLayout.getWarpsPerCTA()[0],
              16 * mmaLayout.getWarpsPerCTA()[1]};
    assert(0 && "Unexpected MMA layout version found");
  } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
    auto parentLayout = dotLayout.getParent();
    assert(parentLayout && "DotOperandEncodingAttr must have a parent");
    if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
      assert(parentMmaLayout.getVersion() == 2 &&
             "mmaLayout version = 1 is not implemented yet");
      auto parentShapePerCTA = getShapePerCTA(parentLayout);
      auto opIdx = dotLayout.getOpIdx();
      if (opIdx == 0) {
        return {parentShapePerCTA[0], 16}; // operand A: M x K tile
      } else if (opIdx == 1) {
        return {16, parentShapePerCTA[1]}; // operand B: K x N tile
      } else {
        assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
      }
    } else {
      assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
                  "supported yet");
    }
    // NOTE: a second `else if (layout.dyn_cast<MmaEncodingAttr>())` branch
    // used to follow here; it was unreachable (the Mma branch above already
    // matches every MmaEncodingAttr) and has been removed.
  } else {
    assert(0 && "Unimplemented usage of getShapePerCTA");
  }
  return shape;
}
|
2022-09-18 05:58:42 +08:00
|
|
|
|
2022-09-27 11:58:47 +08:00
|
|
|
// Dimension order of the layout (fastest-varying dimension first, as stored
// by the encodings themselves).
SmallVector<unsigned> getOrder(const Attribute &layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    auto ord = blocked.getOrder();
    return SmallVector<unsigned>(ord.begin(), ord.end());
  }
  // MMA and dot-operand layouts are always row-major over 2 dims.
  if (layout.isa<MmaEncodingAttr>() || layout.isa<DotOperandEncodingAttr>())
    return {1, 0};
  if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
    // Parent order with the sliced dimension dropped and higher dims
    // renumbered down by one.
    unsigned dim = slice.getDim();
    SmallVector<unsigned> order;
    for (unsigned d : getOrder(slice.getParent())) {
      if (d == dim)
        continue;
      order.push_back(d > dim ? d - 1 : d);
    }
    return order;
  }
  if (auto shared = layout.dyn_cast<SharedEncodingAttr>()) {
    auto ord = shared.getOrder();
    return SmallVector<unsigned>(ord.begin(), ord.end());
  }
  assert(0 && "Unimplemented usage of getOrder");
  return {};
}
|
2022-09-27 11:58:47 +08:00
|
|
|
|
2022-09-18 05:58:42 +08:00
|
|
|
} // namespace gpu
|
2022-08-24 12:55:49 -07:00
|
|
|
} // namespace triton
|
|
|
|
} // namespace mlir
|
|
|
|
|
2022-08-11 21:20:47 -07:00
|
|
|
// Extract an unsigned value from an IntegerAttr of any signedness, emitting
// a parser diagnostic (tagged with `desc`) on type mismatch or negative
// values.
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
                                       unsigned &value, StringRef desc) {
  auto intAttr = attr.dyn_cast<IntegerAttr>();
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer type in ")
        << desc;
    return failure();
  }
  bool isSigned = intAttr.getType().isSignedInteger();
  if (isSigned || intAttr.getType().isSignlessInteger()) {
    // Signed/signless attributes may carry negative payloads; reject them.
    int64_t raw = isSigned ? intAttr.getSInt() : intAttr.getInt();
    if (raw < 0) {
      parser.emitError(parser.getNameLoc(),
                       "expected an unsigned integer value in ")
          << desc;
      return failure();
    }
    value = raw;
  } else {
    value = intAttr.getUInt();
  }
  return success();
}
|
|
|
|
|
2022-05-31 11:43:21 +00:00
|
|
|
// parse an array of integers
|
|
|
|
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
|
|
|
const NamedAttribute &attr,
|
2022-07-27 01:32:10 -07:00
|
|
|
SmallVector<unsigned, 2> &res,
|
2022-07-26 17:25:03 -07:00
|
|
|
StringRef desc) {
|
2022-05-31 11:43:21 +00:00
|
|
|
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
|
|
|
if (!arrayAttr) {
|
2022-07-26 17:25:03 -07:00
|
|
|
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
|
2022-05-31 11:43:21 +00:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
for (Attribute i : arrayAttr) {
|
2022-08-11 21:20:47 -07:00
|
|
|
unsigned value;
|
|
|
|
if (parseIntAttrValue(parser, i, value, desc).failed())
|
2022-05-31 11:43:21 +00:00
|
|
|
return failure();
|
2022-08-11 21:20:47 -07:00
|
|
|
res.push_back(value);
|
2022-05-31 11:43:21 +00:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
};
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
// Convenience wrapper: parse a named attribute's value as an unsigned int.
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
                               unsigned &value, StringRef desc) {
  return parseIntAttrValue(parser, attr.getValue(), value, desc);
}
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Attribute methods
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Create a SliceEncodingAttr that represents this blocked layout with
// dimension `axis` collapsed (the layout of a tensor sliced along `axis`).
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
  return SliceEncodingAttr::get(getContext(), axis, *this);
}
|
|
|
|
|
2022-09-18 05:58:42 +08:00
|
|
|
// Total number of tensor elements each thread owns for a blocked layout:
// per dimension, the CTA tile (sizePerThread * threadsPerWarp * warpsPerCTA)
// is replicated as many times as needed to cover shape[i], and each thread
// keeps sizePerThread[i] elements per replica. The result is the product
// over all dimensions.
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  size_t rank = shape.size();
  auto sizePerThread = getSizePerThread();
  auto warpsPerCTA = getWarpsPerCTA();
  auto threadsPerWarp = getThreadsPerWarp();
  assert(rank == sizePerThread.size() &&
         "unexpected rank in BlockedEncodingAttr::getElemsPerThread");
  SmallVector<unsigned> elemsPerThread(rank);
  for (size_t i = 0; i < rank; ++i) {
    // t = elements covered by one CTA along dimension i.
    unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
    elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
  }
  return product<unsigned>(elemsPerThread);
}
|
|
|
|
|
2022-11-10 09:58:07 -08:00
|
|
|
// Re-insert the sliced dimension (with extent 1) at position `dim`, turning
// a rank-N slice shape into the rank-(N+1) shape expected by the parent
// layout.
template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
  size_t rank = shape.size();
  unsigned dim = getDim();
  SmallVector<T> retShape(rank + 1);
  for (unsigned d = 0; d < rank + 1; ++d) {
    if (d < dim)
      retShape[d] = shape[d];
    else if (d == dim)
      retShape[d] = 1; // the collapsed dimension re-appears with extent 1
    else
      retShape[d] = shape[d - 1];
  }
  return retShape;
}
// Explicit instantiations for the two element types used by callers.
template SmallVector<unsigned>
SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
|
2022-10-28 11:07:45 +08:00
|
|
|
|
2022-09-18 05:58:42 +08:00
|
|
|
// Elements-per-thread of a slice layout: delegate to the parent layout
// using the rank-(N+1) padded shape.
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // (removed unused local `rank` — it was computed but never read)
  return ::getElemsPerThread(getParent(), paddedShape(shape));
}
|
|
|
|
|
|
|
|
// Elements-per-thread for an MMA layout (rank-2 tensors only).
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  assert(shape.size() == 2 && "Unexpected rank of mma layout");
  assert((getVersion() == 1 || getVersion() == 2) &&
         "Only version 1 and 2 is supported");

  if (getVersion() == 1) {
    // Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
    // matrix as result.
    unsigned repsM = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
    unsigned repsN = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
    return repsM * repsN * (16 * 16 / 32);
  }
  if (getVersion() == 2) {
    // mma.m16n8: each thread holds 2 elements along each dimension per rep.
    unsigned elemsPerDim0 = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
    unsigned elemsPerDim1 = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
    return elemsPerDim0 * elemsPerDim1;
  }
  llvm_unreachable("Unexpected mma version");
}
|
|
|
|
|
|
|
|
// Shared layouts do not distribute elements over threads; this is
// intentionally unimplemented.
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // TODO:
  assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
  return 0;
}
|
|
|
|
|
2022-11-10 13:57:27 +08:00
|
|
|
unsigned
|
|
|
|
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
2022-11-14 16:56:30 +08:00
|
|
|
if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
|
|
|
|
return blockedLayout.getElemsPerThread(shape);
|
|
|
|
}
|
2022-11-14 10:15:53 +08:00
|
|
|
assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
|
2022-11-10 13:57:27 +08:00
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Blocked Encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Parse a blocked encoding of the form
//   <{sizePerThread = [...], threadsPerWarp = [...], warpsPerCTA = [...],
//     order = [...]}>
// Unknown keys produce a diagnostic; a null Attribute signals failure.
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned, 2> sizePerThread;
  SmallVector<unsigned, 2> threadsPerWarp;
  SmallVector<unsigned, 2> warpsPerCTA;
  SmallVector<unsigned, 2> order;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "sizePerThread") {
      if (parseIntArrayAttr(parser, attr, sizePerThread,
                            "number of elements per thread")
              .failed())
        return {};
    } else if (attr.getName() == "threadsPerWarp") {
      if (parseIntArrayAttr(parser, attr, threadsPerWarp,
                            "number of threads per warp")
              .failed())
        return {};
    } else if (attr.getName() == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, attr, warpsPerCTA,
                            "number of warps per CTA")
              .failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      // Unknown dictionary keys are rejected.
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  // getChecked runs the attribute verifier on the parsed components.
  auto ret = parser.getChecked<BlockedEncodingAttr>(
      parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
  return ret;
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Print the blocked encoding; mirror of BlockedEncodingAttr::parse.
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "sizePerThread = [" << getSizePerThread() << "]"
          << ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
          << ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
          << ", order = [" << getOrder() << "]"
          << "}>";
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MMA encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-06-18 21:16:45 +08:00
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Parse an MMA encoding of the form
//   <{version = ..., warpsPerCTA = [...]}>
// Unknown keys now produce a diagnostic, consistent with
// BlockedEncodingAttr::parse and SharedEncodingAttr::parse (they were
// previously ignored silently).
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned version = 0;
  SmallVector<unsigned, 2> warpsPerCTA;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "version") {
      if (parseUInt(parser, attr, version, "version").failed())
        return {};
    } else if (attr.getName() == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
                                            warpsPerCTA);
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Print the MMA encoding; mirror of MmaEncodingAttr::parse.
void MmaEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{"
          << "version = " << getVersion()
          << ", " << "warpsPerCTA = [" << getWarpsPerCTA() << "]"
          << "}>";
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sliced Encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Parse a slice encoding of the form <{dim = ..., parent = ...}>.
// Missing or ill-typed `dim`/`parent` now produce a diagnostic instead of
// crashing on a null-attribute cast.
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};
  // attrs.get(...) returns a null Attribute when the key is absent;
  // dyn_cast_or_null also rejects a non-integer payload.
  auto dimAttr = attrs.get("dim").dyn_cast_or_null<IntegerAttr>();
  if (!dimAttr) {
    parser.emitError(parser.getNameLoc(),
                     "expected an integer 'dim' entry in slice encoding");
    return {};
  }
  Attribute parent = attrs.get("parent");
  if (!parent) {
    parser.emitError(parser.getNameLoc(),
                     "expected a 'parent' entry in slice encoding");
    return {};
  }
  unsigned dim = dimAttr.getInt();
  return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
|
|
|
|
|
|
|
|
// Print the slice encoding; mirror of SliceEncodingAttr::parse.
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "dim = " << getDim()
          << ", " << "parent = " << getParent()
          << "}>";
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Shared encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-06-18 21:16:45 +08:00
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Parse a shared encoding of the form
//   <{vec = ..., perPhase = ..., maxPhase = ..., order = [...]}>
// Unknown keys produce a diagnostic; a null Attribute signals failure.
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned vec = 0;
  unsigned perPhase = 0;
  unsigned maxPhase = 0;
  SmallVector<unsigned, 2> order;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "vec") {
      if (parseUInt(parser, attr, vec, "vec").failed())
        return {};
    } else if (attr.getName() == "perPhase") {
      if (parseUInt(parser, attr, perPhase, "perPhase").failed())
        return {};
    } else if (attr.getName() == "maxPhase") {
      if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      // Unknown dictionary keys are rejected.
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  // getChecked runs the attribute verifier on the parsed components.
  return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
                                               perPhase, maxPhase, order);
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
// Print the shared encoding; mirror of SharedEncodingAttr::parse.
void SharedEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{"
          << "vec = " << getVec()
          << ", perPhase = " << getPerPhase()
          << ", maxPhase = " << getMaxPhase()
          << ", order = [" << getOrder() << "]"
          << "}>";
}
|
|
|
|
|
2022-11-10 13:57:27 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DotOperand Encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Parse a dot-operand encoding <{opIdx = ..., parent = ...}>; for an MMA
// version-1 parent an additional `isMMAv1Row` entry is required.
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};
  // NOTE(review): attrs.get(...) returns a null Attribute when a key is
  // absent, so a missing `opIdx`/`parent` would crash in the cast/isa below
  // rather than emit a diagnostic — confirm whether malformed input can
  // reach here.
  unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
  Attribute parent = attrs.get("parent");
  Attribute isMMAv1Row;
  if (parent.isa<MmaEncodingAttr>() &&
      parent.cast<MmaEncodingAttr>().getVersion() == 1) {
    isMMAv1Row = attrs.get("isMMAv1Row");
    if (!isMMAv1Row)
      llvm::report_fatal_error("isMMAv1Row attribute is missing");
  }
  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent, isMMAv1Row);
}
|
|
|
|
|
|
|
|
// Print the dot-operand encoding; mirror of DotOperandEncodingAttr::parse.
// The optional isMMAv1Row entry is only emitted when it is set.
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "opIdx = " << getOpIdx()
          << ", " << "parent = " << getParent();
  if (getIsMMAv1Row())
    printer << ", isMMAv1Row = " << getIsMMAv1Row();
  printer << "}>";
}
|
|
|
|
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// InsertSliceAsyncOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
|
|
|
OperationState &result) {
|
2022-11-06 22:59:03 -08:00
|
|
|
SmallVector<OpAsmParser::OperandType, 8> allOperands;
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
Type srcType, dstType;
|
|
|
|
SMLoc allOperandLoc = parser.getCurrentLocation();
|
|
|
|
if (parser.parseOperandList(allOperands) ||
|
|
|
|
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
|
|
|
parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
|
|
|
|
parser.parseCustomTypeWithFallback(dstType))
|
|
|
|
return failure();
|
|
|
|
result.addTypes(dstType);
|
|
|
|
|
|
|
|
SmallVector<Type> operandTypes;
|
|
|
|
operandTypes.push_back(srcType); // src
|
|
|
|
operandTypes.push_back(dstType); // dst
|
|
|
|
operandTypes.push_back(
|
2022-09-09 12:03:41 -07:00
|
|
|
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
|
2022-11-06 22:59:03 -08:00
|
|
|
|
|
|
|
int hasMask = 0, hasOther = 0;
|
|
|
|
if (allOperands.size() >= 4) {
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
|
2022-11-06 22:59:03 -08:00
|
|
|
hasMask = 1;
|
|
|
|
}
|
|
|
|
if (allOperands.size() >= 5) {
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
operandTypes.push_back(triton::getPointeeType(srcType)); // other
|
2022-11-06 22:59:03 -08:00
|
|
|
hasOther = 1;
|
|
|
|
}
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
|
|
|
|
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
|
|
|
result.operands))
|
|
|
|
return failure();
|
2022-11-06 22:59:03 -08:00
|
|
|
|
|
|
|
// Deduce operand_segment_sizes from the number of the operands.
|
|
|
|
auto operand_segment_sizesAttrName =
|
|
|
|
InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
|
|
|
|
result.addAttribute(
|
|
|
|
operand_segment_sizesAttrName,
|
|
|
|
parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
|
Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:
```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```
We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Printer matching parseInsertSliceAsyncOp: emits the flat operand list, the
// attribute dictionary, and `srcType -> dstType`.
void printInsertSliceAsyncOp(OpAsmPrinter &printer, InsertSliceAsyncOp op) {
  printer << " " << op.getOperation()->getOperands();
  // "operand_segment_sizes" can be deduced, so we don't print it.
  printer.printOptionalAttrDict(op->getAttrs(),
                                {op.operand_segment_sizesAttrName()});
  printer << " : ";
  printer.printStrippedAttrOrType(op.src().getType());
  printer << " -> ";
  printer.printStrippedAttrOrType(op.result().getType());
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ASM Interface (i.e.: alias)
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-26 10:50:11 -07:00
|
|
|
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
2022-07-26 17:25:03 -07:00
|
|
|
public:
|
2022-07-26 10:50:11 -07:00
|
|
|
using OpAsmDialectInterface::OpAsmDialectInterface;
|
|
|
|
|
|
|
|
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
2022-08-18 12:49:37 -07:00
|
|
|
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "mma";
|
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "shared";
|
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "blocked";
|
2022-07-27 01:32:10 -07:00
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
|
|
|
|
os << "slice";
|
|
|
|
return AliasResult::FinalAlias;
|
|
|
|
} */
|
2022-08-24 12:55:49 -07:00
|
|
|
return OpAsmDialectInterface::getAlias(attr, os);
|
2022-07-26 10:50:11 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-10-11 18:16:41 -07:00
|
|
|
struct TritonGPUInferLayoutInterface
|
|
|
|
: public triton::DialectInferLayoutInterface {
|
|
|
|
using DialectInferLayoutInterface::DialectInferLayoutInterface;
|
|
|
|
|
2022-10-28 12:36:09 -07:00
|
|
|
LogicalResult
|
|
|
|
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
|
|
|
|
Attribute &resultEncoding) const override {
|
2022-10-11 18:16:41 -07:00
|
|
|
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
|
|
|
|
operandEncoding);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-10-28 12:36:09 -07:00
|
|
|
LogicalResult
|
|
|
|
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
2022-11-10 13:57:27 +08:00
|
|
|
Attribute &resultEncoding,
|
|
|
|
Optional<Location> location) const override {
|
2022-10-11 18:16:41 -07:00
|
|
|
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
|
2022-11-10 13:57:27 +08:00
|
|
|
if (!sliceEncoding)
|
|
|
|
return emitOptionalError(
|
|
|
|
location, "ExpandDimsOp operand encoding must be SliceEncodingAttr");
|
|
|
|
if (sliceEncoding.getDim() != axis)
|
|
|
|
return emitOptionalError(
|
|
|
|
location, "Incompatible slice dimension for ExpandDimsOp operand");
|
2022-10-11 18:16:41 -07:00
|
|
|
resultEncoding = sliceEncoding.getParent();
|
|
|
|
return success();
|
|
|
|
}
|
2022-11-10 13:57:27 +08:00
|
|
|
|
|
|
|
LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
|
|
|
|
Attribute retEncoding,
|
|
|
|
Optional<Location> location) const override {
|
|
|
|
if (auto dotOpEnc = operandEncoding.dyn_cast<DotOperandEncodingAttr>()) {
|
|
|
|
if (opIdx != dotOpEnc.getOpIdx())
|
|
|
|
return emitOptionalError(location, "Wrong opIdx");
|
|
|
|
if (retEncoding != dotOpEnc.getParent())
|
|
|
|
return emitOptionalError(location, "Incompatible parent encoding");
|
|
|
|
} else
|
|
|
|
return emitOptionalError(
|
|
|
|
location, "Dot's a/b's encoding should be of DotOperandEncodingAttr");
|
|
|
|
return success();
|
|
|
|
}
|
2022-10-11 18:16:41 -07:00
|
|
|
};
|
|
|
|
|
2022-04-28 18:51:31 +08:00
|
|
|
// Registers the dialect's attributes, operations, and interfaces. Called once
// when the dialect is loaded into an MLIRContext.
void TritonGPUDialect::initialize() {
  // Layout attributes generated by tablegen (blocked, shared, mma, slice, ...).
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();
  // Operations generated by tablegen from the TritonGPU ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
  // Printer aliases for layout attributes (#mma, #shared, #blocked).
  addInterfaces<TritonGPUOpAsmInterface>();
  // Layout-inference hooks used by Triton's type inference.
  addInterfaces<TritonGPUInferLayoutInterface>();
}
|
2022-05-01 22:06:54 +08:00
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
2022-05-24 19:48:56 +08:00
|
|
|
|
|
|
|
// verify TritonGPU ops
|
2022-07-26 17:25:03 -07:00
|
|
|
// Hook invoked by MLIR to validate a TritonGPU dialect attribute attached to
// an arbitrary operation. Currently accepts every attribute unconditionally.
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                         NamedAttribute attr) {
  // TODO: fill this.
  return success();
}
|