[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)
@@ -8,6 +8,29 @@
 namespace mlir {
 
+class ReduceOpHelper {
+public:
+  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
+    srcTy = op.operand().getType().cast<RankedTensorType>();
+  }
+
+  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
+
+  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+
+  bool isFastReduction();
+
+  unsigned getInterWarpSize();
+
+  unsigned getIntraWarpSize();
+
+  unsigned getThreadsReductionAxis();
+
+private:
+  triton::ReduceOp op;
+  RankedTensorType srcTy{};
+};
+
 bool isSharedEncoding(Value value);
 
 bool maybeSharedAllocationOp(Operation *op);

@@ -228,9 +228,11 @@ for
     unsigned remainingLanes = 32;
     unsigned remainingThreads = numWarps*32;
     unsigned remainingWarps = numWarps;
+    unsigned prevLanes = 1;
+    unsigned prevWarps = 1;
     SmallVector<unsigned, 4> threadsPerWarp(rank);
     SmallVector<unsigned, 4> warpsPerCTA(rank);
-    for (int _dim = 0; _dim < rank; ++_dim) {
+    for (int _dim = 0; _dim < rank - 1; ++_dim) {
       int i = order[_dim];
       unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
       threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
@@ -238,7 +240,12 @@ for
       remainingWarps /= warpsPerCTA[i];
       remainingLanes /= threadsPerWarp[i];
       remainingThreads /= threadsPerCTA;
+      prevLanes *= threadsPerWarp[i];
+      prevWarps *= warpsPerCTA[i];
     }
+    // Expand the last dimension to fill the remaining lanes and warps
+    threadsPerWarp[order[rank-1]] = 32 / prevLanes;
+    warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
 
     return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);

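For readers following the layout change above: the loop now stops one dimension short of `rank`, and whatever lanes and warps are left over get folded into the last dimension of `order`, so the products of `threadsPerWarp` and `warpsPerCTA` always come out to exactly 32 and `numWarps`, even when the tensor is smaller than a warp along some dimension. The standalone sketch below is illustrative only, not the actual tablegen builder; `pickLayout`, its hard-coded inputs, and the `main` driver are made up for this example. It reproduces the arithmetic and prints the layouts expected by the reduce_ops lit test at the end of this diff.

// Minimal sketch of the blocked-layout heuristic above (hypothetical helper,
// not Triton API). Builds with any C++17 compiler.
#include <algorithm>
#include <cstdio>
#include <vector>

static void pickLayout(const std::vector<unsigned> &shape,
                       const std::vector<unsigned> &order,
                       const std::vector<unsigned> &sizePerThread,
                       unsigned numWarps) {
  unsigned rank = shape.size();
  unsigned remainingLanes = 32;
  unsigned remainingThreads = numWarps * 32;
  unsigned remainingWarps = numWarps;
  unsigned prevLanes = 1;
  unsigned prevWarps = 1;
  std::vector<unsigned> threadsPerWarp(rank), warpsPerCTA(rank);
  // Distribute threads and warps over all but the last dimension in `order`,
  // never giving a dimension more threads than it has elements.
  for (unsigned dim = 0; dim + 1 < rank; ++dim) {
    unsigned i = order[dim];
    unsigned threadsPerCTA =
        std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
    threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
    warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1,
                                          remainingWarps);
    remainingWarps /= warpsPerCTA[i];
    remainingLanes /= threadsPerWarp[i];
    remainingThreads /= threadsPerCTA;
    prevLanes *= threadsPerWarp[i];
    prevWarps *= warpsPerCTA[i];
  }
  // Fold whatever is left into the last dimension so the totals are exact.
  threadsPerWarp[order[rank - 1]] = 32 / prevLanes;
  warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
  std::printf("%ux%u -> threadsPerWarp=[%u, %u], warpsPerCTA=[%u, %u]\n",
              shape[0], shape[1], threadsPerWarp[0], threadsPerWarp[1],
              warpsPerCTA[0], warpsPerCTA[1]);
}

int main() {
  // Shapes from the reduce_ops lit test below, with num-warps=2.
  pickLayout({4, 4}, {0, 1}, {1, 1}, 2);   // [4, 8], [1, 2]
  pickLayout({8, 2}, {0, 1}, {1, 1}, 2);   // [8, 4], [1, 2]
  pickLayout({16, 16}, {0, 1}, {1, 1}, 2); // [16, 2], [1, 2]
}

Without the final expansion step, the per-dimension clamps could leave the totals below 32 threads per warp and below `numWarps` warps for small shapes, which is what the commit title and the "total number of threadsPerWarp is 32" test comments refer to.
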
@@ -89,24 +89,19 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 }
 
 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.axis();
-
-  bool fastReduce = axis == getOrder(srcLayout)[0];
+  ReduceOpHelper helper(op);
 
   SmallVector<unsigned> smemShape;
+  auto srcShape = helper.getSrcShape();
   for (auto d : srcShape)
     smemShape.push_back(d);
 
-  if (fastReduce) {
-    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = sizeInterWarps;
+  auto axis = op.axis();
+  if (helper.isFastReduction()) {
+    smemShape[axis] = helper.getInterWarpSize();
   } else {
-    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
-                                 gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = threadsPerCTAAxis;
+    smemShape[axis] =
+        std::min(smemShape[axis], helper.getThreadsReductionAxis());
   }
 
   return smemShape;

@@ -181,8 +176,7 @@ private:
     // TODO(Keren): Reduce with index is not supported yet.
     auto value = op->getOperand(0);
     if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-      auto srcLayout = tensorType.getEncoding();
-      bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
+      bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
      auto smemShape = getScratchConfigForReduce(reduceOp);
      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                       std::multiplies{});

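To make the scratch-buffer numbers concrete, here is a small self-contained calculation, an illustration rather than code from this patch, that plugs the #blocked1 layout from the lit test (`threadsPerWarp = [8, 4]`, `warpsPerCTA = [1, 2]`) into the smemShape logic above for an 8x2 tensor. The hard-coded arrays and the `main` driver are assumptions made for the example.

// Illustrative scratch sizing for tt.reduce on an 8x2 tensor (not Triton code).
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<unsigned> shape = {8, 2};
  std::vector<unsigned> threadsPerWarp = {8, 4}; // #blocked1 from the lit test
  std::vector<unsigned> warpsPerCTA = {1, 2};
  std::vector<unsigned> order = {0, 1};

  for (unsigned axis = 0; axis < 2; ++axis) {
    std::vector<unsigned> smemShape(shape.begin(), shape.end());
    if (axis == order[0]) {
      // Fast reduction: only one partial value per warp crosses shared memory.
      unsigned intra = std::min(shape[axis], threadsPerWarp[axis]);
      unsigned inter = std::min(shape[axis] / intra, warpsPerCTA[axis]);
      smemShape[axis] = inter;
    } else {
      // Basic reduction: never reserve more rows than the tensor actually has.
      unsigned threadsAxis = threadsPerWarp[axis] * warpsPerCTA[axis];
      smemShape[axis] = std::min(smemShape[axis], threadsAxis);
    }
    unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                     std::multiplies<unsigned>());
    std::printf("axis=%u -> smemShape=[%u, %u], %u scratch elements\n", axis,
                smemShape[0], smemShape[1], elems);
  }
  // axis=0 -> [1, 2] (2 elements); axis=1 -> [8, 2] (16 elements).
}

The `std::min` on the basic path is the part that matters for small tensors: with only 2 elements along axis 1, the scratch shape stays at 2 instead of the 8 threads that cover that axis.
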
@@ -5,6 +5,38 @@
 namespace mlir {
 
+bool ReduceOpHelper::isFastReduction() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return axis == triton::gpu::getOrder(srcLayout)[0];
+}
+
+unsigned ReduceOpHelper::getInterWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  unsigned sizeIntraWarps = getIntraWarpSize();
+  return std::min(srcReduceDimSize / sizeIntraWarps,
+                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getIntraWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  return std::min(srcReduceDimSize,
+                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getThreadsReductionAxis() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
+         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {

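The size helpers above split a fast reduction into an intra-warp stage (warp shuffles) and an inter-warp stage (through shared memory), while `getThreadsReductionAxis` gives the total thread coverage of the axis used by the basic path. As a quick back-of-the-envelope check, the sketch below uses plain functions standing in for the member functions, with layout numbers again taken from the lit test; the names and the `main` driver are assumptions for this example.

// Hand-rolled stand-ins for the ReduceOpHelper size queries (illustration only).
#include <algorithm>
#include <cstdio>

// threadsPerWarpAxis / warpsPerCTAAxis are the layout values along the reduced axis.
static unsigned intraWarpSize(unsigned dimSize, unsigned threadsPerWarpAxis) {
  return std::min(dimSize, threadsPerWarpAxis);
}
static unsigned interWarpSize(unsigned dimSize, unsigned threadsPerWarpAxis,
                              unsigned warpsPerCTAAxis) {
  return std::min(dimSize / intraWarpSize(dimSize, threadsPerWarpAxis),
                  warpsPerCTAAxis);
}
static unsigned threadsReductionAxis(unsigned threadsPerWarpAxis,
                                     unsigned warpsPerCTAAxis) {
  return threadsPerWarpAxis * warpsPerCTAAxis;
}

int main() {
  // tensor<8x2xf32> with #blocked1 (threadsPerWarp=[8,4], warpsPerCTA=[1,2]),
  // reduced along axis 0, the fast axis since order[0] == 0.
  unsigned dimSize = 8, tpw = 8, wpc = 1;
  std::printf("intra=%u inter=%u threadsAxis=%u\n",
              intraWarpSize(dimSize, tpw),      // 8: one warp covers the whole axis
              interWarpSize(dimSize, tpw, wpc), // 1: no cross-warp round needed
              threadsReductionAxis(tpw, wpc));  // 8
}

Clamping both stages by the reduce-dimension size is what keeps the lowering in bounds when the tensor is smaller than the threads assigned to the axis.
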
@@ -1563,9 +1563,7 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
-  if (op.axis() == srcLayout.getOrder()[0])
+  if (ReduceOpHelper(op).isFastReduction())
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
 }

@@ -1763,10 +1761,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
-  auto order = getOrder(srcLayout);
-  unsigned sizeIntraWarps = threadsPerWarp[axis];
-  unsigned sizeInterWarps = warpsPerCTA[axis];
+  ReduceOpHelper helper(op);
+  unsigned sizeIntraWarps = helper.getIntraWarpSize();
+  unsigned sizeInterWarps = helper.getInterWarpSize();
 
+  auto order = getOrder(srcLayout);
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

@@ -648,10 +648,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     else:
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
-@pytest.mark.parametrize("axis", [0, 1])
-def test_tensor_atomic_rmw(axis, device="cuda"):
-    shape0, shape1 = 8, 8
+#TODO[dongdongl]:add more cases with size of tensor less than warp size
+@pytest.mark.parametrize("shape, axis",
+                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
+def test_tensor_atomic_rmw(shape, axis, device="cuda"):
+    shape0, shape1 = shape
     # triton kernel
 
     @triton.jit

@@ -660,14 +661,19 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
         z = tl.sum(x, axis=AXIS)
-        tl.atomic_add(Z + off0, z)
+        if AXIS == 1:
+            tl.atomic_add(Z + off0, z)
+        else:
+            tl.atomic_add(Z + off1, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    print(x)
     # reference result
-    z_ref = np.sum(x, axis=axis)
+    z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    z_shape = (shape0, ) if axis == 1 else (shape1, )
+    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
 
 func @ops() {
   // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}

@@ -9,6 +9,8 @@ func @ops() {
   return
 }
 
+// -----
+
 func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   // Test if LoadOp is lowered properly (see #771)
   %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>

@@ -25,3 +27,27 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   tt.store %ptrs, %c : tensor<128xf32>
   return
 }
+
+// -----
+
+func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+  // Test if the total number of threadsPerWarp is 32
+  // Test if the total number of warps is 2
+  // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
+  %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
+  %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
+  %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
+  // CHECK: tensor<4x4xf32, #blocked0> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
+  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>
+  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
+  // CHECK: tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
+
+  return
+}