[BACKEND] Added support for 1D conversion blocked -> slice (#831)
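This change lets `triton_gpu.convert_layout` handle 1D tensors whose destination carries a `#triton_gpu.slice` encoding. The pieces, in the order the hunks appear below:

- `getScratchConfigForCvtLayout` no longer unwraps slice layouts by hand or asserts blocked/MMA-only inputs; it queries the layout-generic `getOrder` / `getSizePerThread` helpers instead, and rank-1 scratch buffers skip the padding step.
- `BroadcastOpConversion` gets a small `std::max` spelling cleanup.
- `lowerDistributedToDistributed` accepts `SliceEncodingAttr` destinations.
- `getSizePerThread` and `getOrder` learn to answer for slice layouts by deferring to the parent layout.
- The previously commented-out `test_index1d` is re-enabled for the 2D `None`-indexing cases (3D still TODO), and two lit tests pin down the shared-memory loads generated for blocked -> slice conversions.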
@@ -11,7 +11,9 @@
 #include <numeric>
 
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
          "Unexpect layout in getScratchConfigForCvtLayout()");
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> paddedRepShape(rank);
-  if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
-    srcLayout = srcSliceLayout.getParent();
-  if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
-    dstLayout = dstSliceLayout.getParent();
   auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
   auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
   auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
   auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
-  assert((srcBlockedLayout || srcMmaLayout) &&
-         "Unexpected srcLayout in getScratchConfigForCvtLayout");
-  assert((dstBlockedLayout || dstMmaLayout) &&
-         "Unexpected dstLayout in getScratchConfigForCvtLayout");
   assert(!(srcMmaLayout && dstMmaLayout) &&
          "Unexpected mma -> mma layout conversion");
-  auto inOrd =
-      srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
-  auto outOrd =
-      dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
-  unsigned srcContigPerThread =
-      srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
-  unsigned dstContigPerThread =
-      dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
+  auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
+  auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
+  unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
+  unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
         std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
                  std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
   }
+  if (rank == 1)
+    return paddedRepShape;
   unsigned paddedDim = 1;
   if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
     paddedDim = dstBlockedLayout.getOrder()[0];
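Why the early return: the padding added on `paddedDim` staggers the rows of a 2D (or higher) tile in shared memory, presumably to avoid bank conflicts, and a rank-1 buffer has no second axis to stagger. Below is a minimal plain-C++ sketch of the resulting control flow; it is not the Triton API, and the `pad` parameter stands in for a padding amount the real code derives outside this hunk (an assumption here):

#include <algorithm>
#include <cstdint>
#include <vector>

// Per dimension, take the larger of the source and destination footprints,
// each clamped to the tensor shape; pad one dimension, but only for rank >= 2.
std::vector<unsigned> scratchShape(const std::vector<int64_t> &shape,
                                   const std::vector<unsigned> &srcShapePerCTA,
                                   const std::vector<unsigned> &dstShapePerCTA,
                                   unsigned paddedDim, unsigned pad) {
  unsigned rank = static_cast<unsigned>(shape.size());
  std::vector<unsigned> paddedRepShape(rank);
  for (unsigned d = 0; d < rank; ++d)
    paddedRepShape[d] =
        std::max(std::min<unsigned>(shape[d], srcShapePerCTA[d]),
                 std::min<unsigned>(shape[d], dstShapePerCTA[d]));
  if (rank == 1) // the new early return: 1D buffers are never padded
    return paddedRepShape;
  paddedRepShape[paddedDim] += pad; // assumed padding step, not shown above
  return paddedRepShape;
}

int main() {
  // The 1D case from the lit tests below: a 32-element tensor gets a plain
  // 32-element scratch buffer regardless of the pad value.
  return scratchShape({32}, {32}, {32}, /*paddedDim=*/0, /*pad=*/4)[0] == 32
             ? 0
             : 1;
}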
@@ -1197,7 +1197,7 @@ struct BroadcastOpConversion
         broadcastDims.push_back(d);
         srcLogicalShape[d] = 1;
         srcLogicalShape[d + rank] =
-            std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
+            std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
       } else {
         srcLogicalShape[d] = numCtas;
         srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
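The `BroadcastOpConversion` hunk is a spelling cleanup, not a behavior change: `std::max<unsigned>(1, ...)` names the template argument explicitly, matching the `std::min<unsigned>` / `std::max<unsigned>` style used in the scratch-config code above, instead of forcing the type through a `unsigned(1)` cast.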
@@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
     }
     barrier();
     if (dstLayout.isa<BlockedEncodingAttr>() ||
+        dstLayout.isa<SliceEncodingAttr>() ||
         dstLayout.isa<MmaEncodingAttr>()) {
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
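Accepting `SliceEncodingAttr` destinations here is the enabling change: without it, a blocked -> slice conversion would fall through this `if` after the barrier and the destination values would never be read back out of the shared-memory scratch buffer. The hunks above make sure that buffer is sized correctly for slice layouts in the first place.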
@@ -68,18 +68,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
-    unsigned dim = sliceLayout.getDim();
-    auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      SmallVector<unsigned> sizePerThread(
-          blockedParent.getSizePerThread().begin(),
-          blockedParent.getSizePerThread().end());
-      sizePerThread.erase(sizePerThread.begin() + dim);
-      return sizePerThread;
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
-    }
+    return getSizePerThread(sliceLayout.getParent());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     assert(mmaLayout.getVersion() == 2 &&
            "mmaLayout version = 1 is not implemented yet");
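Note the semantic shift in `getSizePerThread` for slice layouts: the old branch erased the sliced dimension, returning a vector with the slice's reduced rank, and asserted on non-blocked parents; the new branch simply recurses into the parent, so the result keeps the parent's rank and works for any parent layout the function already understands.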
@@ -144,6 +133,19 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
                                  blockedLayout.getOrder().end());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     return SmallVector<unsigned>{1, 0};
+  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
+    unsigned dim = sliceLayout.getDim();
+    SmallVector<unsigned> order;
+    for (unsigned d : parentOrder) {
+      if (d == dim)
+        continue;
+      else if (d > dim)
+        order.push_back(d - 1);
+      else
+        order.push_back(d);
+    }
+    return order;
   } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
     return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
                                  sharedLayout.getOrder().end());
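To make the index bookkeeping concrete, here is the new slice branch of `getOrder` transcribed into a standalone sketch, with plain `std::vector` instead of `llvm::SmallVector`; the example values come from the `#blocked1` layout used in the lit tests further down:

#include <cassert>
#include <vector>

// Slicing out `dim` deletes that axis from the parent order and renumbers
// the higher axes down by one.
std::vector<unsigned> sliceOrder(const std::vector<unsigned> &parentOrder,
                                 unsigned dim) {
  std::vector<unsigned> order;
  for (unsigned d : parentOrder) {
    if (d == dim)
      continue; // the sliced axis disappears from the result
    order.push_back(d > dim ? d - 1 : d);
  }
  return order;
}

int main() {
  // #blocked1 below has order = [1, 0]. Slicing either dimension of a 2D
  // parent leaves a 1D layout whose only axis is 0, but which parent axis
  // survives differs: dim = 0 keeps parent axis 1, dim = 1 keeps axis 0.
  std::vector<unsigned> expected{0u};
  assert(sliceOrder({1u, 0u}, 0) == expected);
  assert(sliceOrder({1u, 0u}, 1) == expected);
  return 0;
}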
@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
 
 
 # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
-# @pytest.mark.parametrize("expr, dtype_str", [
-#     (f'x[{s}]', d)
-#     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
-#     for d in ['int32', 'uint32', 'uint16']
-# ])
-# def test_index1d(expr, dtype_str, device='cuda'):
-#     rank_x = expr.count(':')
-#     rank_y = expr.count(',') + 1
-#     shape_x = [32 for _ in range(rank_x)]
-#     shape_z = [32 for _ in range(rank_y)]
-#     shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
-#     shape_z_dim_mismatch = [64 for _ in range(rank_y)]
+@pytest.mark.parametrize("expr, dtype_str", [
+    (f'x[{s}]', d)
+    for s in ['None, :', ':, None',
+              # TODO: 3D
+              # 'None, :, :',
+              # ':, :, None'
+              ]
+    for d in ['int32', 'uint32', 'uint16']
+])
+def test_index1d(expr, dtype_str, device='cuda'):
+    rank_x = expr.count(':')
+    rank_y = expr.count(',') + 1
+    shape_x = [32 for _ in range(rank_x)]
+    shape_z = [32 for _ in range(rank_y)]
+    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
+    shape_z_dim_mismatch = [64 for _ in range(rank_y)]
 
-#     # Triton kernel
-#     @triton.jit
-#     def kernel(Z, X, SIZE: tl.constexpr):
-#         m = tl.arange(0, SIZE)
-#         n = tl.arange(0, SIZE)
-#         x = tl.load(X_PTR_EXPR)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z_PTR_EXPR, z)
+    # Triton kernel
+    @triton.jit
+    def kernel(Z, X, SIZE: tl.constexpr):
+        m = tl.arange(0, SIZE)
+        n = tl.arange(0, SIZE)
+        x = tl.load(X_PTR_EXPR)
+        z = GENERATE_TEST_HERE
+        tl.store(Z_PTR_EXPR, z)
 
-#     def generate_kernel(shape_x, shape_z):
-#         to_replace = {
-#             'X_PTR_EXPR': make_ptr_str('X', shape_x),
-#             'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
-#             'GENERATE_TEST_HERE': expr,
-#         }
-#         return patch_kernel(kernel, to_replace)
+    def generate_kernel(shape_x, shape_z):
+        to_replace = {
+            'X_PTR_EXPR': make_ptr_str('X', shape_x),
+            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
+            'GENERATE_TEST_HERE': expr,
+        }
+        return patch_kernel(kernel, to_replace)
 
-#     kernel_match = generate_kernel(shape_x, shape_z)
-#     kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
-#     kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
+    kernel_match = generate_kernel(shape_x, shape_z)
+    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
+    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
 
-#     # torch result
-#     x = numpy_random(shape_x, dtype_str=dtype_str)
-#     y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
-#     z_ref = eval(expr) + y
-#     # triton result
-#     z_tri = to_triton(np.empty_like(z_ref), device=device)
-#     x_tri = to_triton(x)
-#     kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#     # compare
-#     assert (z_ref == to_numpy(z_tri)).all()
+    # torch result
+    x = numpy_random(shape_x, dtype_str=dtype_str)
+    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
+    z_ref = eval(expr) + y
+    # triton result
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x)
+    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+    # compare
+    assert (z_ref == to_numpy(z_tri)).all()
 
-#     def catch_compilation_error(kernel):
-#         try:
-#             kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#         except triton.CompilationError as e:
-#             np.testing.assert_(True)
-#         except BaseException:
-#             np.testing.assert_(False)
+    def catch_compilation_error(kernel):
+        try:
+            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+        except triton.CompilationError as e:
+            np.testing.assert_(True)
+        except BaseException:
+            np.testing.assert_(False)
 
-#     catch_compilation_error(kernel_dim_mismatch)
-#     catch_compilation_error(kernel_rank_mismatch)
+    catch_compilation_error(kernel_dim_mismatch)
+    catch_compilation_error(kernel_rank_mismatch)
 
 
 # # ---------------
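The re-enabled `test_index1d` is what exercises this path from the Python side: `x[None, :]` and `x[:, None]` on a 1D `tl.arange` force a conversion from a 1D blocked layout into a slice of a 2D blocked layout, i.e. exactly the `triton_gpu.convert_layout` named in the TODO above, while the 3D variants stay disabled. The lit tests that follow check the same conversion at the LLVM level.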
@@ -715,6 +715,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked1d_to_slice0
+  func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
+    // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
+    %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    return
+  }
+}
+
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked1d_to_slice1
+  func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
+    // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
+    %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    return
+  }
+}
+
+
 // -----
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
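The two `CHECK-COUNT` lines pin down how many scalar shared-memory loads the lowering emits when reading the converted tensor back. Slicing along `dim = 0` puts the 32 elements on `#blocked1`'s fast axis (8 threads times `sizePerThread` 4), so 4 loads per thread suffice; slicing along `dim = 1` puts them on the slow axis (4 threads times `sizePerThread` 1) and the lowering currently emits 32. The exact numbers matter less than their stability: a regression in how slice layouts replicate or vectorize would show up as a changed count.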