[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)

Author: Keren Zhou (committed by GitHub)
Date: 2022-11-26 12:30:38 -08:00
parent f63be0e9b5
commit 35c9ec1103
7 changed files with 116 additions and 29 deletions

View File

@@ -8,6 +8,29 @@
 namespace mlir {

+class ReduceOpHelper {
+public:
+  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
+    srcTy = op.operand().getType().cast<RankedTensorType>();
+  }
+
+  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
+
+  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+
+  bool isFastReduction();
+
+  unsigned getInterWarpSize();
+
+  unsigned getIntraWarpSize();
+
+  unsigned getThreadsReductionAxis();
+
+private:
+  triton::ReduceOp op;
+  RankedTensorType srcTy{};
+};
+
 bool isSharedEncoding(Value value);

 bool maybeSharedAllocationOp(Operation *op);
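For orientation, here is a minimal usage sketch of the new helper (not part of this commit; computeReduceSmemSize is a hypothetical caller, and its logic mirrors getScratchConfigForReduce further down; it assumes the includes and namespaces of the surrounding header):

// Sketch: how a caller is expected to consult ReduceOpHelper.
unsigned computeReduceSmemSize(triton::ReduceOp reduceOp) {
  mlir::ReduceOpHelper helper(reduceOp);
  auto axis = reduceOp.axis();
  // Fast path: the reduction axis is the most minor dimension of the
  // layout, so only one partial value per warp goes through shared memory.
  if (helper.isFastReduction())
    return helper.getInterWarpSize();
  // Slow path: one slot per participating thread on the axis, capped by
  // the actual dimension size.
  return std::min<unsigned>(helper.getSrcShape()[axis],
                            helper.getThreadsReductionAxis());
}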

View File

@@ -228,9 +228,11 @@ for
   unsigned remainingLanes = 32;
   unsigned remainingThreads = numWarps*32;
   unsigned remainingWarps = numWarps;
+  unsigned prevLanes = 1;
+  unsigned prevWarps = 1;
   SmallVector<unsigned, 4> threadsPerWarp(rank);
   SmallVector<unsigned, 4> warpsPerCTA(rank);
-  for (int _dim = 0; _dim < rank; ++_dim) {
+  for (int _dim = 0; _dim < rank - 1; ++_dim) {
     int i = order[_dim];
     unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
     threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
@@ -238,7 +240,12 @@ for
     remainingWarps /= warpsPerCTA[i];
     remainingLanes /= threadsPerWarp[i];
     remainingThreads /= threadsPerCTA;
+    prevLanes *= threadsPerWarp[i];
+    prevWarps *= warpsPerCTA[i];
   }
+  // Expand the last dimension to fill the remaining lanes and warps
+  threadsPerWarp[order[rank-1]] = 32 / prevLanes;
+  warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
   return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
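The effect of the rewritten loop is easiest to see on a worked example. The sketch below replays the algorithm as standalone C++ (the warpsPerCTA clamp between the two hunks is assumed from context; warp size is fixed at 32). For shape = {16, 16}, sizePerThread = {1, 1}, order = {0, 1}, numWarps = 2 it prints threadsPerWarp = [16, 2] and warpsPerCTA = [1, 2], the #blocked2 layout checked in the lit test below. Before the fix, a small shape such as 4x4 would stop at threadsPerWarp = [4, 4], leaving a warp only 16 lanes wide.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> shape = {16, 16}, sizePerThread = {1, 1}, order = {0, 1};
  unsigned numWarps = 2;
  unsigned rank = static_cast<unsigned>(shape.size());
  unsigned remainingLanes = 32, remainingThreads = numWarps * 32,
           remainingWarps = numWarps, prevLanes = 1, prevWarps = 1;
  std::vector<unsigned> threadsPerWarp(rank), warpsPerCTA(rank);
  // Distribute lanes/warps over every dimension except the last one in
  // `order`, never exceeding the number of elements along that dimension.
  for (unsigned d = 0; d + 1 < rank; ++d) {
    unsigned i = order[d];
    unsigned threadsPerCTA =
        std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
    threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
    warpsPerCTA[i] =
        std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
    remainingWarps /= warpsPerCTA[i];
    remainingLanes /= threadsPerWarp[i];
    remainingThreads /= threadsPerCTA;
    prevLanes *= threadsPerWarp[i];
    prevWarps *= warpsPerCTA[i];
  }
  // The last dimension soaks up whatever is left, so the totals are always
  // exactly 32 lanes per warp and `numWarps` warps, even for tiny tensors.
  threadsPerWarp[order[rank - 1]] = 32 / prevLanes;
  warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
  std::printf("threadsPerWarp = [%u, %u], warpsPerCTA = [%u, %u]\n",
              threadsPerWarp[0], threadsPerWarp[1], warpsPerCTA[0], warpsPerCTA[1]);
}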

View File

@@ -89,24 +89,19 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 }

 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.axis();
-  bool fastReduce = axis == getOrder(srcLayout)[0];
+  ReduceOpHelper helper(op);

   SmallVector<unsigned> smemShape;
+  auto srcShape = helper.getSrcShape();
   for (auto d : srcShape)
     smemShape.push_back(d);

-  if (fastReduce) {
-    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = sizeInterWarps;
+  auto axis = op.axis();
+  if (helper.isFastReduction()) {
+    smemShape[axis] = helper.getInterWarpSize();
   } else {
-    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
-                                 gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = threadsPerCTAAxis;
+    smemShape[axis] =
+        std::min(smemShape[axis], helper.getThreadsReductionAxis());
   }

   return smemShape;
@@ -181,8 +176,7 @@ private:
   // TODO(Keren): Reduce with index is not supported yet.
   auto value = op->getOperand(0);
   if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-    auto srcLayout = tensorType.getEncoding();
-    bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
+    bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
     auto smemShape = getScratchConfigForReduce(reduceOp);
     unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                      std::multiplies{});

View File

@@ -5,6 +5,38 @@
 namespace mlir {

+bool ReduceOpHelper::isFastReduction() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return axis == triton::gpu::getOrder(srcLayout)[0];
+}
+
+unsigned ReduceOpHelper::getInterWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  unsigned sizeIntraWarps = getIntraWarpSize();
+  return std::min(srcReduceDimSize / sizeIntraWarps,
+                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getIntraWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  return std::min(srcReduceDimSize,
+                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getThreadsReductionAxis() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
+         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
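Plugging the lit-test layouts into these helpers shows what the min() caps buy. The numbers below are worked by hand (not code from this commit) from the #blocked1 layout in the test at the end of this commit:

#include <algorithm>
#include <cstdio>

int main() {
  // Layout #blocked1: threadsPerWarp = [8, 4], warpsPerCTA = [1, 2],
  // order = [0, 1]; source tensor is 8x2.
  unsigned shape[2] = {8, 2};
  unsigned tpw[2] = {8, 4};
  unsigned wpc[2] = {1, 2};

  // axis = 0 is a fast reduction (axis == order[0]).
  unsigned intra0 = std::min(shape[0], tpw[0]);          // getIntraWarpSize: 8
  unsigned inter0 = std::min(shape[0] / intra0, wpc[0]); // getInterWarpSize: 1
  // -> getScratchConfigForReduce shrinks smemShape from [8, 2] to [1, 2].

  // axis = 1 takes the slow path; threadsPerWarp[1] = 4 exceeds the
  // dimension size 2, which is exactly the small-matrix case being fixed:
  unsigned threadsAxis1 = tpw[1] * wpc[1];         // getThreadsReductionAxis: 8
  unsigned smem1 = std::min(shape[1], threadsAxis1); // smemShape[1] = 2, not 8
  std::printf("fast: smemShape = [%u, %u]; slow: smemShape[1] = %u\n",
              inter0, shape[1], smem1);
}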

View File

@@ -1563,9 +1563,7 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
-  if (op.axis() == srcLayout.getOrder()[0])
+  if (ReduceOpHelper(op).isFastReduction())
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
 }
@@ -1763,10 +1761,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);

-  auto order = getOrder(srcLayout);
-  unsigned sizeIntraWarps = threadsPerWarp[axis];
-  unsigned sizeInterWarps = warpsPerCTA[axis];
+  ReduceOpHelper helper(op);
+  unsigned sizeIntraWarps = helper.getIntraWarpSize();
+  unsigned sizeInterWarps = helper.getInterWarpSize();

+  auto order = getOrder(srcLayout);
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

View File

@@ -648,10 +648,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     else:
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

-#TODO[dongdongl]:add more cases with size of tensor less than warp size
-@pytest.mark.parametrize("axis", [0, 1])
-def test_tensor_atomic_rmw(axis, device="cuda"):
-    shape0, shape1 = 8, 8
+@pytest.mark.parametrize("shape, axis",
+                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
+def test_tensor_atomic_rmw(shape, axis, device="cuda"):
+    shape0, shape1 = shape
     # triton kernel
     @triton.jit
@@ -660,14 +661,19 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
         z = tl.sum(x, axis=AXIS)
-        tl.atomic_add(Z + off0, z)
+        if AXIS == 1:
+            tl.atomic_add(Z + off0, z)
+        else:
+            tl.atomic_add(Z + off1, z)

     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    print(x)
     # reference result
-    z_ref = np.sum(x, axis=axis)
+    z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    z_shape = (shape0, ) if axis == 1 else (shape1, )
+    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

View File

@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s

 func @ops() {
   // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
@@ -9,6 +9,8 @@ func @ops() {
   return
 }

+// -----
+
 func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   // Test if LoadOp is lowered properly (see #771)
   %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
@@ -25,3 +27,27 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   tt.store %ptrs, %c : tensor<128xf32>
   return
 }
+
+// -----
+
+func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+  // Test if the total number of threadsPerWarp is 32
+  // Test if the total number of warps is 2
+  // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
+  %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
+  %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
+  %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
+  // CHECK: tensor<4x4xf32, #blocked0> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
+  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
+  // CHECK: tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
+  return
+}