[BACKEND] Added support for 1D conversion blocked -> slice (#831)

Philippe Tillet
2022-11-01 13:19:58 -07:00
committed by GitHub
parent c9d84237e8
commit 12d60cb4a3
5 changed files with 103 additions and 78 deletions
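For context, the user-visible trigger for this conversion is indexing a 1D tensor with None: expand_dims needs its 1D operand in a slice layout whose parent is the 2D result's blocked layout, so a blocked -> slice convert_layout is inserted. Below is a minimal sketch of such a kernel (hypothetical name and pointer arithmetic, mirroring the test re-enabled further down; not taken from the commit itself):

import triton
import triton.language as tl


@triton.jit
def expand_1d_kernel(Z, X, SIZE: tl.constexpr):
    # 1D blocked tensors of offsets
    m = tl.arange(0, SIZE)
    n = tl.arange(0, SIZE)
    x = tl.load(X + m)
    # `x[None, :]` expands to 2D; the 1D operand must first be converted
    # from its blocked layout to a slice of the 2D blocked layout,
    # which is the lowering this commit adds.
    z = x[None, :]
    # store broadcasts z over the 2D block of pointers
    tl.store(Z + SIZE * m[:, None] + n[None, :], z)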

View File

@@ -11,7 +11,9 @@
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
srcLayout = srcSliceLayout.getParent();
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
dstLayout = dstSliceLayout.getParent();
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
assert((srcBlockedLayout || srcMmaLayout) &&
"Unexpected srcLayout in getScratchConfigForCvtLayout");
assert((dstBlockedLayout || dstMmaLayout) &&
"Unexpected dstLayout in getScratchConfigForCvtLayout");
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
auto inOrd =
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
auto outOrd =
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
unsigned srcContigPerThread =
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
unsigned dstContigPerThread =
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
}
if (rank == 1)
return paddedRepShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
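As a rough illustration of the vectorization rule above, here is a small Python sketch (not the C++ implementation; orders and sizePerThread are passed in as plain lists, and the function name is made up):

def cvt_in_vec(in_ord, out_ord, src_size_per_thread):
    """Mirror of the inVec choice above: vectorize along the source's
    fastest-varying dimension, except that (legacy restriction) a
    fastest dimension of 0 on either side disables vectorization."""
    src_contig_per_thread = src_size_per_thread[in_ord[0]]
    if out_ord[0] == 0 or in_ord[0] == 0:
        return 1
    return src_contig_per_thread


# e.g. a blocked layout with order [1, 0] and sizePerThread [1, 4]
# vectorizes by 4, while a 1D layout (fastest dim is 0) does not
assert cvt_in_vec([1, 0], [1, 0], [1, 4]) == 4
assert cvt_in_vec([0], [0], [1]) == 1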

View File

@@ -1197,7 +1197,7 @@ struct BroadcastOpConversion
broadcastDims.push_back(d);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] =
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,

View File

@@ -68,18 +68,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<unsigned> sizePerThread(
blockedParent.getSizePerThread().begin(),
blockedParent.getSizePerThread().end());
sizePerThread.erase(sizePerThread.begin() + dim);
return sizePerThread;
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
@@ -144,6 +133,19 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
SmallVector<unsigned> order;
for (unsigned d : parentOrder) {
if (d == dim)
continue;
else if (d > dim)
order.push_back(d - 1);
else
order.push_back(d);
}
return order;
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
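The two slice-layout changes above can be summarized with a small Python sketch (hypothetical helper names; inputs are plain Python lists): getSizePerThread for a slice layout now simply defers to its parent, and getOrder returns the parent's order with the sliced dimension removed and higher dimensions renumbered down by one.

def slice_size_per_thread(parent_size_per_thread):
    # the slice layout reports the same sizePerThread as its parent
    return list(parent_size_per_thread)


def slice_order(parent_order, dim):
    # drop the sliced dimension and shift higher dims down by one
    order = []
    for d in parent_order:
        if d == dim:
            continue
        order.append(d - 1 if d > dim else d)
    return order


# parent #blocked1 has order [1, 0]; slicing either dim leaves order [0]
assert slice_order([1, 0], dim=0) == [0]
assert slice_order([1, 0], dim=1) == [0]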

View File

@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
# @pytest.mark.parametrize("expr, dtype_str", [
# (f'x[{s}]', d)
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
# for d in ['int32', 'uint32', 'uint16']
# ])
# def test_index1d(expr, dtype_str, device='cuda'):
# rank_x = expr.count(':')
# rank_y = expr.count(',') + 1
# shape_x = [32 for _ in range(rank_x)]
# shape_z = [32 for _ in range(rank_y)]
# shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
# shape_z_dim_mismatch = [64 for _ in range(rank_y)]
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None',
# TODO: 3D
# 'None, :, :',
# ':, :, None'
]
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
shape_z = [32 for _ in range(rank_y)]
shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
shape_z_dim_mismatch = [64 for _ in range(rank_y)]
# # Triton kernel
# @triton.jit
# def kernel(Z, X, SIZE: tl.constexpr):
# m = tl.arange(0, SIZE)
# n = tl.arange(0, SIZE)
# x = tl.load(X_PTR_EXPR)
# z = GENERATE_TEST_HERE
# tl.store(Z_PTR_EXPR, z)
# Triton kernel
@triton.jit
def kernel(Z, X, SIZE: tl.constexpr):
m = tl.arange(0, SIZE)
n = tl.arange(0, SIZE)
x = tl.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
tl.store(Z_PTR_EXPR, z)
# def generate_kernel(shape_x, shape_z):
# to_replace = {
# 'X_PTR_EXPR': make_ptr_str('X', shape_x),
# 'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
# 'GENERATE_TEST_HERE': expr,
# }
# return patch_kernel(kernel, to_replace)
def generate_kernel(shape_x, shape_z):
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
return patch_kernel(kernel, to_replace)
# kernel_match = generate_kernel(shape_x, shape_z)
# kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
# kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
kernel_match = generate_kernel(shape_x, shape_z)
kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
# # torch result
# x = numpy_random(shape_x, dtype_str=dtype_str)
# y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
# z_ref = eval(expr) + y
# # triton result
# z_tri = to_triton(np.empty_like(z_ref), device=device)
# x_tri = to_triton(x)
# kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
# # compare
# assert (z_ref == to_numpy(z_tri)).all()
# torch result
x = numpy_random(shape_x, dtype_str=dtype_str)
y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
z_ref = eval(expr) + y
# triton result
z_tri = to_triton(np.empty_like(z_ref), device=device)
x_tri = to_triton(x)
kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
# compare
assert (z_ref == to_numpy(z_tri)).all()
# def catch_compilation_error(kernel):
# try:
# kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
# except triton.CompilationError as e:
# np.testing.assert_(True)
# except BaseException:
# np.testing.assert_(False)
def catch_compilation_error(kernel):
try:
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
except triton.CompilationError as e:
np.testing.assert_(True)
except BaseException:
np.testing.assert_(False)
# catch_compilation_error(kernel_dim_mismatch)
# catch_compilation_error(kernel_rank_mismatch)
catch_compilation_error(kernel_dim_mismatch)
catch_compilation_error(kernel_rank_mismatch)
# # ---------------

View File

@@ -715,6 +715,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>