//===----------------------------------------------------------------------===//
//
// This pass tries to prefetch operands (a and b) of tt.dot.
// The prefetched operands are produced by tensor.extract_slice +
// triton_gpu.convert_layout ops; those ConvertLayoutOps will be lowered to
// shared memory loads.
//
// For example:
//
// %a: tensor<128x32xf16, #enc>
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
//   %d = tt.dot %a_arg, %b, %c
//   ...
//   scf.yield %a_next, ...
// }
//
// will be translated to
//
// %a: tensor<128x32xf16, #enc>
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
// %a_prefetch = triton_gpu.convert_layout %a_tmp
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
// {
//   %x = tt.dot %a_prefetch_arg, %b, %c
//   %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
//   %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
//   ...
//   scf.yield %a_next, ..., %a_prefetch_next
// }
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

class Prefetcher {
  /// cache the ForOp we are working on
  scf::ForOp forOp;
  /// cache the YieldOp of this ForOp
  scf::YieldOp yieldOp;
  /// width of the prefetched slice along the K dimension (in elements)
  // TODO: add a hook to infer prefetchWidth
  unsigned prefetchWidth = 16;

  /// dots to be prefetched
  SetVector<Value> dots;
  /// dot => a/b operand in shared memory (loop arg, loop-header def, and
  /// yielded value)
  DenseMap<Value, Value> dot2aLoopArg;
  DenseMap<Value, Value> dot2aHeaderDef;
  DenseMap<Value, Value> dot2bLoopArg;
  DenseMap<Value, Value> dot2bHeaderDef;
  DenseMap<Value, Value> dot2aYield;
  DenseMap<Value, Value> dot2bYield;
  /// original dot operand => prefetched slice defined before the loop
  DenseMap<Value, Value> operand2headPrefetch;

  LogicalResult isForOpOperand(Value v);

  Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                         Attribute dotEncoding, OpBuilder &builder,
                         llvm::Optional<int64_t> offsetK = llvm::None,
                         llvm::Optional<int64_t> shapeK = llvm::None);

public:
  Prefetcher() = delete;

  Prefetcher(scf::ForOp forOp) : forOp(forOp) {
    yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  }

  LogicalResult initialize();

  void emitPrologue();

  scf::ForOp createNewForOp();
};
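
// Sketch of the intended call sequence (this mirrors PrefetchPass::
// runOnOperation at the end of this file):
//
//   Prefetcher prefetcher(forOp);
//   if (prefetcher.initialize().failed())
//     return;
//   prefetcher.emitPrologue();                           // slices before the loop
//   scf::ForOp newForOp = prefetcher.createNewForOp();   // loop with prefetching
//   // ... then replace forOp's results with newForOp's and erase forOp.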

Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                   Attribute dotEncoding, OpBuilder &builder,
                                   llvm::Optional<int64_t> offsetK,
                                   llvm::Optional<int64_t> shapeK) {
  // opIdx: 0 => a, 1 => b
  auto type = v.getType().cast<RankedTensorType>();
  SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
  SmallVector<int64_t> offset{0, 0};
  Type elementType = type.getElementType();

  auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };

  // Split the K dimension into (prefetchWidth, k - prefetchWidth): the
  // prologue slice covers the first prefetchWidth elements, the in-loop
  // slice covers the rest.
  int64_t kIdx = opIdx == 0 ? 1 : 0;

  offset[kIdx] = isPrologue ? 0 : prefetchWidth;
  shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
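  // For instance (illustrative numbers only): with a 128x32 operand a and
  // prefetchWidth = 16, the prologue slice is offset [0, 0] / shape [128, 16]
  // and the in-loop slice defaults to offset [0, 16] / shape [128, 16],
  // unless offsetK/shapeK override it below.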

  if (shapeK)
    shape[kIdx] = *shapeK;
  if (offsetK)
    offset[kIdx] = *offsetK;

  Value newSmem = builder.create<tensor::ExtractSliceOp>(
      v.getLoc(),
      // TODO: encoding?
      RankedTensorType::get(shape, elementType, type.getEncoding()), v,
      SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
      SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
      SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});

  auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
      builder.getContext(), opIdx, dotEncoding);
  Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
      v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
      newSmem);

  return prefetchSlice;
}

LogicalResult Prefetcher::initialize() {
  Block *loop = forOp.getBody();

  SmallVector<triton::DotOp> dotsInFor;
  for (Operation &op : *loop)
    if (auto dotOp = dyn_cast<triton::DotOp>(op))
      dotsInFor.push_back(dotOp);

  if (dotsInFor.empty())
    return failure();

  // TODO: the pass still segfaults (the original ForOp still has uses) when
  // applied to loops with two dots, e.g. flash attention; bail out for now.
  if (dotsInFor.size() > 1)
    return failure();

  // returns the shared-memory source of a convert_layout, or null otherwise
  auto getPrefetchSrc = [](Value v) -> Value {
    if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
      if (isSharedEncoding(cvt.getOperand()))
        return cvt.src();
    return Value();
  };

  auto getIncomingOp = [this](Value v) -> Value {
    if (auto arg = v.dyn_cast<BlockArgument>())
      if (arg.getOwner()->getParentOp() == forOp.getOperation())
        return forOp.getOpOperandForRegionIterArg(arg).get();
    return Value();
  };

  auto getYieldOp = [this](Value v) -> Value {
    auto arg = v.cast<BlockArgument>();
    unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
    return yieldOp.getOperand(yieldIdx);
  };
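
  // The pattern matched below is, schematically (names are illustrative):
  //   scf.for ... iter_args(%a_smem = %a_init, ...) {   // shared encoding
  //     %a = triton_gpu.convert_layout %a_smem          // -> dot operand
  //     %d = tt.dot %a, %b, %c
  //     ...
  //     scf.yield %a_smem_next, ...
  //   }
  // getPrefetchSrc finds %a_smem, getIncomingOp finds %a_init, and
  // getYieldOp finds %a_smem_next.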

  for (triton::DotOp dot : dotsInFor) {
    auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
    // Skip prefetching if kSize is less than prefetchWidth
    if (kSize < prefetchWidth)
      continue;
    Value aSmem = getPrefetchSrc(dot.a());
    Value bSmem = getPrefetchSrc(dot.b());
    if (aSmem && bSmem) {
      Value aHeaderDef = getIncomingOp(aSmem);
      Value bHeaderDef = getIncomingOp(bSmem);
      // Only prefetch operands that are loop-carried args
      if (aHeaderDef && bHeaderDef) {
        dots.insert(dot);
        dot2aHeaderDef[dot] = aHeaderDef;
        dot2bHeaderDef[dot] = bHeaderDef;
        dot2aLoopArg[dot] = aSmem;
        dot2bLoopArg[dot] = bSmem;
        dot2aYield[dot] = getYieldOp(aSmem);
        dot2bYield[dot] = getYieldOp(bSmem);
      }
    }
  }

  return success();
}

void Prefetcher::emitPrologue() {
  OpBuilder builder(forOp);

  for (Value dot : dots) {
    Attribute dotEncoding =
        dot.getType().cast<RankedTensorType>().getEncoding();
    Value aPrefetched =
        generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
    operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
    Value bPrefetched =
        generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
    operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
  }
}

scf::ForOp Prefetcher::createNewForOp() {
  OpBuilder builder(forOp);

  SmallVector<Value> loopArgs;
  for (auto v : forOp.getIterOperands())
    loopArgs.push_back(v);
  for (Value dot : dots) {
    loopArgs.push_back(
        operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
    loopArgs.push_back(
        operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
  }

  auto newForOp = builder.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), loopArgs);

  // Largest power of two that is <= n (clears low set bits until one remains);
  // used to split the remaining K extent into power-of-two sub-dots.
  auto largestPow2 = [](int64_t n) -> int64_t {
    while ((n & (n - 1)) != 0)
      n = n & (n - 1);
    return n;
  };

  builder.setInsertionPointToStart(newForOp.getBody());
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

  for (Operation &op : forOp.getBody()->without_terminator()) {
    Operation *newOp = builder.clone(op, mapping);
    auto dot = dyn_cast<triton::DotOp>(&op);
    if (dots.contains(dot)) {
      Attribute dotEncoding =
          dot.getType().cast<RankedTensorType>().getEncoding();
      // prefetched dot
      Operation *firstDot = builder.clone(*dot, mapping);
      if (Value a = operand2headPrefetch.lookup(dot.a()))
        firstDot->setOperand(
            0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
      if (Value b = operand2headPrefetch.lookup(dot.b()))
        firstDot->setOperand(
            1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));

      // remaining part
      int64_t kOff = prefetchWidth;
      int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
                     prefetchWidth;
      Operation *prevDot = firstDot;
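      // The loop below consumes kRem in decreasing power-of-two chunks and
      // chains the partial dots through the accumulator. For example
      // (illustrative numbers only): K = 64 and prefetchWidth = 16 leave
      // kRem = 48, which is split into sub-dots of K = 32 and then K = 16.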
      while (kRem != 0) {
        int64_t kShape = largestPow2(kRem);
        Value aRem =
            generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
                             dotEncoding, builder, kOff, kShape);
        Value bRem =
            generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
                             dotEncoding, builder, kOff, kShape);
        newOp = builder.clone(*dot, mapping);
        newOp->setOperand(0, aRem);
        newOp->setOperand(1, bRem);
        newOp->setOperand(2, prevDot->getResult(0));
        prevDot = newOp;
        kOff += kShape;
        kRem -= kShape;
      }
    }
    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
      mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
  }

  // prefetch next iteration
  SmallVector<Value> yieldValues;
  for (Value v : forOp.getBody()->getTerminator()->getOperands())
    yieldValues.push_back(mapping.lookup(v));
  for (Value dot : dots) {
    Attribute dotEncoding =
        dot.getType().cast<RankedTensorType>().getEncoding();
    yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
                                           true, dotEncoding, builder));
    yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
                                           true, dotEncoding, builder));
  }
  // Create the terminator of the new loop with the extended operand list
  builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
  return newForOp;
}

struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
  void runOnOperation() override {
    getOperation()->walk([&](scf::ForOp forOp) {
      Prefetcher prefetcher(forOp);

      if (prefetcher.initialize().failed())
        return;

      prefetcher.emitPrologue();

      scf::ForOp newForOp = prefetcher.createNewForOp();

      // replace the original loop
      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
        forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
      forOp->erase();
    });
  }
};

} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
  return std::make_unique<PrefetchPass>();
}
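
// Example of wiring the pass into a pipeline (illustrative only; the actual
// registration lives with the other TritonGPU passes):
//
//   mlir::PassManager pm(module->getContext());
//   pm.addPass(mlir::createTritonGPUPrefetchPass());
//   if (failed(pm.run(module)))
//     ...; // handle the error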