[WIP][Triton-MLIR] Prefetch pass fixup (#873)

A (potential) problem with directly adopting `tensor.extract_slice`.

Long story short, `tensor.extract_slice` is not aware of swizzling.
Consider the following shared memory tensor and its first three slices,
where each slice includes two tiles (the loading unit of LDGSTS) of
elements. At this point the tiles haven't been swizzled yet, so slicing
appears to work.

<img width="1219" alt="image"
src="https://user-images.githubusercontent.com/2306281/201833023-a7950705-2d50-4c0a-8527-7505261c3a3c.png">

However, now consider the following figure, which shows the layout after
applying swizzling to the layout in the first figure.

<img width="1244" alt="image"
src="https://user-images.githubusercontent.com/2306281/201834824-7daae360-f5bc-4e6b-a921-20be3f294b78.png">

Note that in phase 2, all tiles have been swizzled out of their
original slices. This implies that if we use the tile index after
slicing, we can no longer locate the correct tiles. For example, T3 was
in slice 1 but got swapped to slice 0 after swizzling.

Here's a more detailed explanation. In the current `triton-mlir` branch,
we only compute the relative offset of each tile. So T3's index within
slice 1 is *1*, and it will be swizzled using *1* and the *phase id*.
The correct index of T3, however, is *3*, its offset relative to the
beginning of the shared memory tensor being swizzled, so T3 should be
swizzled using *3* and the *phase id*.
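
To make the mismatch concrete, here is a minimal standalone sketch (not Triton code; `swizzleTile` and the `maxPhase` value of 4 are illustrative assumptions) showing how an XOR-based swizzle sends the slice-relative index and the tensor-relative index to different tiles:

```cpp
#include <cstdio>

// Hypothetical helper: XOR-swizzle a tile's column index with the row phase.
int swizzleTile(int tileIdx, int phase, int maxPhase = 4) {
  return tileIdx ^ (phase % maxPhase);
}

int main() {
  int phase = 2;
  int absoluteIdx = 3;      // T3's offset from the start of the full tensor
  int sliceRelativeIdx = 1; // T3's offset from the start of slice 1
  // 3 ^ 2 = 1, but 1 ^ 2 = 3: the slice-relative index lands on a different tile.
  std::printf("swizzled with absolute index:       %d\n", swizzleTile(absoluteIdx, phase));
  std::printf("swizzled with slice-relative index: %d\n", swizzleTile(sliceRelativeIdx, phase));
  return 0;
}
```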

This PR proposes a hacky solution to this problem. We restore the
"correct" offset of each tile by **assuming that slicing along a specific
dim happens at most once on the output of insert_slice_async**. I
admit it's risky and fragile.
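
Roughly, the workaround records how far each slice sits from the originally allocated tensor and feeds that offset back in during address computation. A simplified sketch of the idea, with illustrative names rather than the PR's actual API (row-major layout and leading dimension `ldm` are assumptions):

```cpp
// All names here (SmemSlice, elementPtr, cOffset) are illustrative.
struct SmemSlice {
  int *slicedBase; // pointer to the first element of the slice
  int cOffset;     // slice offset (in elements) along the swizzled dimension
};

int *elementPtr(const SmemSlice &s, int colInSlice, int row,
                int perPhase, int maxPhase, int ldm) {
  int *baseBeforeSwizzle = s.slicedBase - s.cOffset; // undo the slicing
  int col = colInSlice + s.cOffset;                  // "correct" absolute column
  int phase = (row / perPhase) % maxPhase;
  int swizzledCol = col ^ phase;                     // swizzle as if unsliced
  return baseBeforeSwizzle + row * ldm + swizzledCol;
}
```

This mirrors what the diff below does with the `offsets` field, `getCSwizzleOffset`, and `getBaseBeforeSwizzle` on `SharedMemoryObject`.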

The other possible solution is to adopt cutlass's swizzling logic, which
limits the indices being swizzled to a "bounding box" that matches what
the mma instruction consumes. For example, in the following tensor layout,
each 4x4 submatrix is the minimum swizzling unit, and the entire tensor
represents the layout of operand A in `mma.16816`.

<img width="565" alt="image"
src="https://user-images.githubusercontent.com/2306281/201836879-4ca7824b-530c-4a06-a3d5-1e74a2de1b42.png">

Co-authored-by: Phil Tillet <phil@openai.com>
Keren Zhou authored on 2022-11-19 19:57:16 -08:00, committed by GitHub
parent e8994209f4, commit 6c5f646f4e
7 changed files with 146 additions and 70 deletions

View File

@@ -99,6 +99,7 @@ void llPrintf(StringRef msg, ValueRange args,
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
+#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
 #define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
 #define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
@@ -441,25 +442,48 @@ struct SharedMemoryObject {
   // if we want to support more optimizations.
   SmallVector<Value>
       strides; // i32 int. The strides of the shared memory object.
+  SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
+                              // objects from the originally allocated object.
 
-  SharedMemoryObject(Value base, ArrayRef<Value> strides)
-      : base(base), strides(strides.begin(), strides.end()) {}
+  SharedMemoryObject(Value base, ArrayRef<Value> strides,
+                     ArrayRef<Value> offsets)
+      : base(base), strides(strides.begin(), strides.end()),
+        offsets(offsets.begin(), offsets.end()) {}
+
+  SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
+                     ArrayRef<unsigned> order, Location loc,
+                     ConversionPatternRewriter &rewriter)
+      : base(base) {
+    auto rank = shape.size();
+    auto stride = 1;
+    strides.resize(rank);
+    for (auto idx : order) {
+      strides[idx] = i32_val(stride);
+      offsets.emplace_back(i32_val(0));
+      stride *= shape[idx];
+    }
+  }
+
+  // XXX(Keren): a special allocator for 3d tensors. It's a workaround for
+  // now since we don't have a correct way to encoding 3d tensors in the
+  // pipeline pass.
   SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
                      ConversionPatternRewriter &rewriter)
       : base(base) {
     auto stride = 1;
     for (auto dim : llvm::reverse(shape)) {
-      this->strides.emplace_back(i32_val(stride));
+      strides.emplace_back(i32_val(stride));
+      offsets.emplace_back(i32_val(0));
       stride *= dim;
     }
-    this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
+    strides = llvm::to_vector<4>(llvm::reverse(strides));
   }
 
   SmallVector<Value> getElems() const {
     SmallVector<Value> elems;
     elems.push_back(base);
     elems.append(strides.begin(), strides.end());
+    elems.append(offsets.begin(), offsets.end());
     return elems;
   }
@@ -467,8 +491,22 @@ struct SharedMemoryObject {
     SmallVector<Type> types;
     types.push_back(base.getType());
     types.append(strides.size(), IntegerType::get(base.getContext(), 32));
+    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
     return types;
   }
+
+  Value getCSwizzleOffset(int order) const {
+    assert(order >= 0 && order < strides.size());
+    return offsets[order];
+  }
+
+  Value getBaseBeforeSwizzle(int order, Location loc,
+                             ConversionPatternRewriter &rewriter) const {
+    Value cSwizzleOffset = getCSwizzleOffset(order);
+    Value offset = sub(i32_val(0), cSwizzleOffset);
+    Type type = base.getType();
+    return gep(type, base, offset);
+  }
 };
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
@@ -493,8 +531,11 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
   getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
                                   ConversionPatternRewriter &rewriter) {
     auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
-    return SharedMemoryObject(/*base=*/elems[0],
-                              /*strides=*/{elems.begin() + 1, elems.end()});
+    auto rank = (elems.size() - 1) / 2;
+    return SharedMemoryObject(
+        /*base=*/elems[0],
+        /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
+        /*offsets=*/{elems.begin() + 1 + rank, elems.end()});
   }
 
   static Value
@@ -2238,31 +2279,34 @@ struct ExtractSliceOpConversion
     // Triton support either static and dynamic offsets
     auto smemObj =
         getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
+    SmallVector<Value, 4> opOffsetVals;
     SmallVector<Value, 4> offsetVals;
     auto mixedOffsets = op.getMixedOffsets();
     for (auto i = 0; i < mixedOffsets.size(); ++i) {
       if (op.isDynamicOffset(i))
-        offsetVals.emplace_back(adaptor.offsets()[i]);
+        opOffsetVals.emplace_back(adaptor.offsets()[i]);
       else
-        offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
+        opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
+      offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
     }
     // Compute the offset based on the original strides of the shared memory
     // object
-    auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
+    auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides);
     // newShape = rank_reduce(shape)
     // Triton only supports static tensor sizes
     SmallVector<Value, 4> strideVals;
-    auto staticSizes = op.static_sizes();
     for (auto i = 0; i < op.static_sizes().size(); ++i) {
-      if (op.getStaticSize(i) != 1) {
+      if (op.getStaticSize(i) == 1) {
+        offsetVals.erase(offsetVals.begin() + i);
+      } else {
         strideVals.emplace_back(smemObj.strides[i]);
       }
     }
 
     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(llvmElemTy, 3);
     auto resTy = op.getType().dyn_cast<RankedTensorType>();
-    smemObj =
-        SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
+    smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
+                                 strideVals, offsetVals);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
@@ -3128,7 +3172,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
   smemBase = bitcast(smemBase, elemPtrTy);
-  auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
+  auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
   auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
   unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
@@ -3228,17 +3272,16 @@ public:
     if (canUseLdmatrix) {
       // Each CTA, the warps is arranged as [1xwpt] if not transposed,
       // otherwise [wptx1], and each warp will perform a mma.
-      numPtr =
+      numPtrs =
           tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
     } else {
-      numPtr = tileShape[order[0]] / wpt / matShape[order[0]];
+      numPtrs = tileShape[order[0]] / wpt / matShape[order[0]];
     }
-
-    numPtr = std::max<int>(numPtr, 2);
+    numPtrs = std::max<int>(numPtrs, 2);
 
     // Special rule for i8/u8, 4 ptrs for each matrix
     if (!canUseLdmatrix && elemBytes == 1)
-      numPtr *= 4;
+      numPtrs *= 4;
 
     int loadStrideInMat[2];
     loadStrideInMat[kOrder] =
@@ -3257,24 +3300,26 @@ public:
   // lane = thread % 32
   // warpOff = (thread/32) % wpt(0)
-  llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane) {
+  llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
+                                          Value cSwizzleOffset) {
     if (canUseLdmatrix)
-      return computeLdmatrixMatOffs(warpOff, lane);
+      return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
     else if (elemBytes == 4 && needTrans)
-      return computeB32MatOffs(warpOff, lane);
+      return computeB32MatOffs(warpOff, lane, cSwizzleOffset);
     else if (elemBytes == 1 && needTrans)
-      return computeB8MatOffs(warpOff, lane);
+      return computeB8MatOffs(warpOff, lane, cSwizzleOffset);
     else
       llvm::report_fatal_error("Invalid smem load config");
     return {};
   }
 
-  int getNumPtr() const { return numPtr; }
+  int getNumPtrs() const { return numPtrs; }
 
   // Compute the offset to the matrix this thread(indexed by warpOff and lane)
   // mapped to.
-  SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane) {
+  SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane,
+                                            Value cSwizzleOffset) {
     // 4x4 matrices
     Value c = urem(lane, i32_val(8));
     Value s = udiv(lane, i32_val(8)); // sub-warp-id
@@ -3312,14 +3357,16 @@ public:
     // Physical offset (before swizzling)
     Value cMatOff = matOff[order[0]];
     Value sMatOff = matOff[order[1]];
+    Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
+    cMatOff = add(cMatOff, cSwizzleMatOff);
 
     // row offset inside a matrix, each matrix has 8 rows.
     Value sOffInMat = c;
 
-    SmallVector<Value> offs(numPtr);
+    SmallVector<Value> offs(numPtrs);
     Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
     Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
-    for (int i = 0; i < numPtr; ++i) {
+    for (int i = 0; i < numPtrs; ++i) {
       Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
       cMatOffI = xor_(cMatOffI, phase);
       offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
@@ -3329,14 +3376,15 @@ public:
   }
 
   // Compute 32-bit matrix offsets.
-  SmallVector<Value> computeB32MatOffs(Value warpOff, Value lane) {
+  SmallVector<Value> computeB32MatOffs(Value warpOff, Value lane,
+                                       Value cSwizzleOffset) {
     assert(needTrans && "Only used in transpose mode.");
     // Load tf32 matrices with lds32
     Value cOffInMat = udiv(lane, i32_val(4));
     Value sOffInMat = urem(lane, i32_val(4));
 
     Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
-    SmallVector<Value> offs(numPtr);
+    SmallVector<Value> offs(numPtrs);
 
     for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time
       int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
@@ -3348,10 +3396,13 @@ public:
       Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)),
                           mul(nkMatArr, i32_val(matArrStride)));
+      Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
+      cMatOff = add(cMatOff, cSwizzleMatOff);
+
       Value sMatOff = kMatArr;
       Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
       // FIXME: (kOrder == 1?) is really dirty hack
-      for (int i = 0; i < numPtr / 2; ++i) {
+      for (int i = 0; i < numPtrs / 2; ++i) {
         Value cMatOffI =
             add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2)));
         cMatOffI = xor_(cMatOffI, phase);
@@ -3365,13 +3416,14 @@ public:
   }
 
   // compute 8-bit matrix offset.
-  SmallVector<Value> computeB8MatOffs(Value warpOff, Value lane) {
+  SmallVector<Value> computeB8MatOffs(Value warpOff, Value lane,
+                                      Value cSwizzleOffset) {
     assert(needTrans && "Only used in transpose mode.");
     Value cOffInMat = udiv(lane, i32_val(4));
     Value sOffInMat =
         mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols
 
-    SmallVector<Value> offs(numPtr);
+    SmallVector<Value> offs(numPtrs);
     for (int mat = 0; mat < 4; ++mat) {
       int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
       int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2;
@@ -3384,7 +3436,7 @@ public:
                           mul(nkMatArr, i32_val(matArrStride)));
       Value sMatOff = kMatArr;
 
-      for (int loadx4Off = 0; loadx4Off < numPtr / 8; ++loadx4Off) {
+      for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) {
         for (int elemOff = 0; elemOff < 4; ++elemOff) {
           int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff;
           Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat *
@@ -3587,7 +3639,7 @@ private:
   bool needTrans;
   bool canUseLdmatrix;
 
-  int numPtr;
+  int numPtrs;
 
   int pLoadStrideInMat;
   int sMatStride;
@@ -4392,14 +4444,17 @@ private:
         wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
         tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
         maxPhase, elemBytes, rewriter, typeConverter, loc);
-    SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
-    const int numPtrs = loader.getNumPtr();
+    Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+    SmallVector<Value> offs =
+        loader.computeOffsets(warpId, lane, cSwizzleOffset);
+    const int numPtrs = loader.getNumPtrs();
 
     SmallVector<Value> ptrs(numPtrs);
+    Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
     Type smemPtrTy = helper.getShemPtrTy();
     for (int i = 0; i < numPtrs; ++i) {
-      ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
-                        smemPtrTy);
+      ptrs[i] =
+          bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy);
     }
 
     auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -4432,7 +4487,7 @@ private:
 //   ...
 //   (2,0), (2,1), (3,0), (3,1), # i=1, j=0
 //   (2,2), (2,3), (3,2), (3,3), # i=1, j=1
-//   (2,4), (2,5), (2,4), (2,5), # i=1, j=2
+//   (2,4), (2,5), (3,4), (3,5), # i=1, j=2
 //   ...
 // ]
 // i \in [0, n0) and j \in [0, n1)
@@ -4811,15 +4866,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
 Value DotOpMmaV1ConversionHelper::loadA(
     Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
     ConversionPatternRewriter &rewriter) const {
-  // smem
-  Value smem = smemObj.base;
-  auto strides = smemObj.strides;
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
   auto shape = tensorTy.getShape();
   auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
   auto order = sharedLayout.getOrder();
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
 
   bool isARow = order[0] != 0;
   bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
@@ -4834,6 +4887,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
   int vecA = sharedLayout.getVec();
 
+  auto strides = smemObj.strides;
   Value strideAM = isARow ? strides[0] : i32_val(1);
   Value strideAK = isARow ? i32_val(1) : strides[1];
   Value strideA0 = isARow ? strideAK : strideAM;
@@ -4856,8 +4910,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
   Value offA0 = isARow ? offsetAK : offsetAM;
   Value offA1 = isARow ? offsetAM : offsetAK;
   Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
+  offA0 = add(offA0, cSwizzleOffset);
 
   SmallVector<Value> offA(numPtrA);
   for (int i = 0; i < numPtrA; i++) {
     Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
     offA0I = udiv(offA0I, i32_val(vecA));
@@ -4875,6 +4929,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
   SmallVector<Value> ptrA(numPtrA);
   std::map<std::pair<int, int>, std::pair<Value, Value>> has;
+  auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   for (int i = 0; i < numPtrA; i++)
     ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
@@ -4971,6 +5026,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
   Value offB0 = isBRow ? offsetBN : offsetBK;
   Value offB1 = isBRow ? offsetBK : offsetBN;
   Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+  offB0 = add(offB0, cSwizzleOffset);
 
   SmallVector<Value> offB(numPtrB);
   for (int i = 0; i < numPtrB; ++i) {
     Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
@@ -5480,7 +5537,8 @@ public:
     types.push_back(ptrType);
     // shape dims
     auto rank = type.getRank();
-    for (auto i = 0; i < rank; i++) {
+    // offsets + strides
+    for (auto i = 0; i < rank * 2; i++) {
       types.push_back(IntegerType::get(ctx, 32));
     }
     return LLVM::LLVMStructType::getLiteral(ctx, types);

View File

@@ -143,8 +143,8 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
     std::swap(srcEltType, dstEltType);
   if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
     return false;
-  return dstEltType.isF16() || dstEltType.isBF16() ||
-         dstEltType.isF32() || dstEltType.isF64();
+  return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
+         dstEltType.isF64();
 }
 
 //-- StoreOp --

View File

@@ -33,9 +33,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     SmallVector<unsigned, 4> sizePerThread(rank, 1);
     PointerType ptrType = origType.getElementType().cast<PointerType>();
     auto pointeeType = ptrType.getPointeeType();
-    unsigned numBits =
-        pointeeType.isa<triton::Float8Type>() ?
-        8 : pointeeType.getIntOrFloatBitWidth();
+    unsigned numBits = pointeeType.isa<triton::Float8Type>()
+                           ? 8
+                           : pointeeType.getIntOrFloatBitWidth();
     unsigned maxMultiple = info.getDivisibility(order[0]);
     unsigned maxContig = info.getContiguity(order[0]);
     unsigned alignment = std::min(maxMultiple, maxContig);

View File

@@ -78,8 +78,6 @@ public:
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
       return mlir::failure();
     auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
-    auto dstType = convert.getType().cast<RankedTensorType>();
     // we don't handle conversions to DotOperandEncodingAttr
     // this is a heuristics to accommodate fused attention
     // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
@@ -96,6 +94,9 @@ public:
     // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
     auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
     if (alloc_tensor) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
           op, op->getResult(0).getType());
       return mlir::success();
@@ -103,6 +104,9 @@ public:
     // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
     auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
     if (insert_slice) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       auto newType = op->getResult(0).getType().cast<RankedTensorType>();
       // Ensure that the new insert_slice op is placed in the same place as the
       // old insert_slice op. Otherwise, the new insert_slice op may be placed
@@ -121,6 +125,9 @@ public:
     // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
     auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
     if (extract_slice) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       auto origType = extract_slice.source().getType().cast<RankedTensorType>();
       auto newType = RankedTensorType::get(
           origType.getShape(), origType.getElementType(),
@@ -144,16 +151,15 @@ public:
       return mlir::success();
     }
-    // cvt(type2, x)
+    // cvt(cvt(x, type1), type2) -> cvt(x, type2)
     if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
-      auto argType = arg->getOperand(0).getType().cast<RankedTensorType>();
       if (arg->getOperand(0).getDefiningOp() &&
-          !argType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
-          srcType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
-          !dstType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
+          !isSharedEncoding(arg->getOperand(0)) &&
+          isSharedEncoding(convert.getOperand()) &&
+          !isSharedEncoding(convert.getResult())) {
         return mlir::failure();
       }
+      auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
       auto srcShared =
           srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
       if (srcShared && srcShared.getVec() > 1)

View File

@@ -27,6 +27,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -60,7 +61,7 @@ class Prefetcher {
   LogicalResult isForOpOperand(Value v);
 
-  Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
+  Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                          Attribute dotEncoding, OpBuilder &builder,
                          llvm::Optional<int64_t> offsetK = llvm::None,
                          llvm::Optional<int64_t> shapeK = llvm::None);
@@ -79,7 +80,7 @@ public:
   scf::ForOp createNewForOp();
 };
 
-Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
+Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                    Attribute dotEncoding, OpBuilder &builder,
                                    llvm::Optional<int64_t> offsetK,
                                    llvm::Optional<int64_t> shapeK) {
@@ -94,8 +95,8 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
   // k => (prefetchWidth, k - prefetchWidth)
   int64_t kIdx = opIdx == 0 ? 1 : 0;
 
-  offset[kIdx] = isPrefetch ? 0 : prefetchWidth;
-  shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth);
+  offset[kIdx] = isPrologue ? 0 : prefetchWidth;
+  shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
 
   if (shapeK)
     shape[kIdx] = *shapeK;
@@ -132,8 +133,8 @@ LogicalResult Prefetcher::initialize() {
   // returns source of cvt
   auto getPrefetchSrc = [](Value v) -> Value {
-    // TODO: Check if the layout of src is SharedEncodingAttr
     if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
+      if (isSharedEncoding(cvt.getOperand()))
         return cvt.src();
     return Value();
   };
@@ -152,6 +153,10 @@ LogicalResult Prefetcher::initialize() {
   };
 
   for (triton::DotOp dot : dotsInFor) {
+    auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
+    // Skip prefetching if kSize is less than prefetchWidth
+    if (kSize < prefetchWidth)
+      continue;
     Value aSmem = getPrefetchSrc(dot.a());
     Value bSmem = getPrefetchSrc(dot.b());
     if (aSmem && bSmem) {
@@ -217,7 +222,7 @@ scf::ForOp Prefetcher::createNewForOp() {
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
 
   for (Operation &op : forOp.getBody()->without_terminator()) {
-    Operation *newOp = nullptr;
+    Operation *newOp = builder.clone(op, mapping);
     auto dot = dyn_cast<triton::DotOp>(&op);
     if (dots.contains(dot)) {
       Attribute dotEncoding =
@@ -252,8 +257,6 @@ scf::ForOp Prefetcher::createNewForOp() {
         kOff += kShape;
         kRem -= kShape;
       }
-    } else {
-      newOp = builder.clone(op, mapping);
     }
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))

View File

@@ -857,6 +857,7 @@ def build_triton_ir(fn, signature, specialization, constants):
     ret.context = context
     return ret, generator
 
+
 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
@@ -868,10 +869,12 @@ def optimize_triton_ir(mod):
     pm.run(mod)
     return mod
 
+
 def ast_to_ttir(fn, signature, specialization, constants):
     mod, _ = build_triton_ir(fn, signature, specialization, constants)
     return optimize_triton_ir(mod)
 
+
 def ttir_to_ttgir(mod, num_warps, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
@@ -880,6 +883,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
     # can get shared memory swizzled correctly.
     pm.add_triton_gpu_combine_pass()
     pm.add_tritongpu_pipeline_pass(num_stages)
+    # Prefetch must be done after pipeline pass because pipeline pass
+    # extracts slices from the original tensor.
+    pm.add_tritongpu_prefetch_pass()
    pm.add_canonicalizer_pass()
     pm.add_cse_pass()
     pm.add_coalesce_pass()
@@ -922,7 +928,6 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None):
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
 
-
 def ptx_to_cubin(ptx: str, device: int):
     '''
     Compile TritonGPU module to cubin.
@@ -992,8 +997,6 @@ def path_to_ptxas():
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
 
-
-
 # ------------------------------------------------------------------------------
 # compiler
 # ------------------------------------------------------------------------------

View File

@@ -351,8 +351,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-NEXT: llvm.extractvalue
 // CHECK-NEXT: llvm.extractvalue
 // CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.add
 // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.add
 // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.add
 // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
 // CHECK-NEXT: llvm.mul
 // CHECK-NEXT: llvm.add