[TRITON-MLIR][BACKEND] AtomicRMWOp supports scalar (#903)
AtomicRMWOp now supports scalar (non-tensor) pointers and values, in addition to tensors.

Co-authored-by: dongdongl <dongdongl@nvidia.com>
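For context, the user-facing effect is that tl.atomic_* now accepts a scalar pointer and a scalar value, as exercised by the re-enabled test further down. A minimal usage sketch modeled on that test (kernel and tensor names are illustrative, not part of this change):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_into_scalar(X, Z):
        pid = tl.program_id(0)
        x = tl.load(X + pid)       # each program loads one element
        old = tl.atomic_add(Z, x)  # Z is a scalar pointer; the previous value is returned

    x = torch.arange(1, 6, dtype=torch.float32, device="cuda")
    z = torch.zeros(1, dtype=torch.float32, device="cuda")
    add_into_scalar[(5,)](x, z)
    assert z.item() == x.sum().item()  # 1 + 2 + 3 + 4 + 5 = 15, exact in float32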
@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
return smemShape;
}

// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
auto ptrTy = op.ptr().getType();
if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
// do nothing or just assert because shared memory is not used in tensor
} else {
// need only bytes for scalar
// always vec = 1 and elemsPerThread = 1 for scalar?
smemShape.push_back(1);
}
return smemShape;
}

class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -200,6 +214,23 @@ private:
elems * kPtrBitWidth / 8 :
elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (value.getType().dyn_cast<RankedTensorType>()) {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}
}
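In other words, a scalar AtomicRMWOp reserves exactly one element of scratch shared memory, sized with the same bytes = elems * bitwidth / 8 rule used for the reduce path above. A quick restatement of that arithmetic (helper name and values are illustrative):

    K_PTR_BIT_WIDTH = 64  # pointer elements are counted as 64-bit, mirroring kPtrBitWidth

    def scratch_bytes(elems, elem_bits, elem_is_pointer=False):
        bits = K_PTR_BIT_WIDTH if elem_is_pointer else elem_bits
        return elems * bits // 8

    assert scratch_bytes(1, 32) == 4                        # scalar f32 atomic -> 4 bytes
    assert scratch_bytes(1, 16) == 2                        # scalar f16 atomic -> 2 bytes
    assert scratch_bytes(1, 0, elem_is_pointer=True) == 8   # pointer pointee -> 8 bytes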
@@ -5947,10 +5947,11 @@ struct AtomicRMWOpConversion
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

AtomicRMWOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
benefit),
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}

LogicalResult
@@ -5971,30 +5972,29 @@ struct AtomicRMWOpConversion
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);

// TODO[dongdongl]: Support scalar


auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());

auto valTy = val.getType().cast<RankedTensorType>();
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
auto vec = getVectorSize(ptr);
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
Value mask = int_val(1, 1);
auto tid = tid_val();
// tensor
if (valueTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
}

auto vecTy = vec_ty(valueElemTy, vec);
auto elemsPerThread = getElemsPerThread(val.getType());
// mask
Value mask = int_val(1, 1);
auto shape = valueTy.getShape();
auto numElements = product(shape);
auto tid = tid_val();
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));

SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
@@ -6008,10 +6008,12 @@ struct AtomicRMWOpConversion
rmwMask = and_(rmwMask, mask);
std::string sTy;
PTXBuilder ptxBuilder;

auto *dstOpr = ptxBuilder.newOperand("=r");
std::string tyId = valueElemNbits * vec == 64
? "l"
: (valueElemNbits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);

auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
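The new tyId string simply maps the total operand width (element bits times vector length) to a PTX register-constraint letter; the destination and value operands then share it instead of the hard-coded "r". A small restatement of the selection (illustrative helper, not the actual code):

    def ptx_constraint(value_elem_bits, vec):
        width = value_elem_bits * vec
        if width == 64:
            return "l"   # 64-bit register
        if width == 32:
            return "r"   # 32-bit register
        return "h"       # 16-bit register

    assert ptx_constraint(32, 1) == "r"  # scalar f32
    assert ptx_constraint(16, 2) == "r"  # two packed f16 lanes -> 32 bits
    assert ptx_constraint(16, 1) == "h"  # a lone f16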
@@ -6053,18 +6055,32 @@ struct AtomicRMWOpConversion
return failure();
}
atom.o(rmwOp).o(sTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);

auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
if (valueTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
}
} else {
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(old, atomPtr);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}
};
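The tensor and scalar paths differ mainly in which threads issue the atom instruction and how the old value gets back to the program: for tensors, a thread participates while tid * elemsPerThread < numElements and the results are packed into the usual LLVM struct; for a scalar, only thread 0 fires the atomic and the old value is round-tripped through the scratch shared-memory slot (store, barrier, load, barrier) so every thread observes the same result. A sketch of the two predicates as plain Python (illustrative):

    def tensor_thread_active(tid, elems_per_thread, num_elements):
        # mirrors: mask = and_(mask, icmp_slt(mul(tid, elemsPerThread), numElements))
        return tid * elems_per_thread < num_elements

    def scalar_thread_active(tid):
        # mirrors: rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)))
        return tid == 0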
@@ -6150,7 +6166,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
@@ -1106,7 +1106,18 @@ void init_triton_ir(py::module &&m) {
mlir::Value &ptr, mlir::Value &val,
mlir::Value &mask) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType = val.getType();
mlir::Type dstType;
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
})
@@ -595,100 +595,80 @@ def test_tuples():
assert c_tri == c_ref


# # ---------------
# # test atomics
# # ---------------
# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
# [
# ('add', 'float16', mode),
# ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
# ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
# ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
# ]
# for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# n_programs = 5
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
n_programs = 5

# # triton kernel
# @triton.jit
# def kernel(X, Z):
# pid = tl.program_id(0)
# x = tl.load(X + pid)
# old = GENERATE_TEST_HERE

# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
# numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
# max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
# min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
# neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

# # triton result
# rs = RandomState(17)
# x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
# if mode == 'all_neg':
# x = -np.abs(x)
# if mode == 'all_pos':
# x = np.abs(x)
# if mode == 'min_neg':
# idx = rs.randint(n_programs, size=(1, )).item()
# x[idx] = -np.max(np.abs(x)) - 1
# if mode == 'max_pos':
# idx = rs.randint(n_programs, size=(1, )).item()
# x[idx] = np.max(np.abs(x)) + 1
# x_tri = to_triton(x, device=device)

# z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
# kernel[(n_programs, )](x_tri, z_tri)
# # torch result
# z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# # compare
# exact = op not in ['add']
# if exact:
# assert z_ref.item() == to_numpy(z_tri).item()
# else:
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


# @pytest.mark.parametrize("axis", [0, 1])
# def test_tensor_atomic_rmw(axis, device="cuda"):
# shape0, shape1 = 8, 8
# # triton kernel

# @triton.jit
# def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
# off0 = tl.arange(0, SHAPE0)
# off1 = tl.arange(0, SHAPE1)
# x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
# z = tl.sum(x, axis=AXIS)
# tl.atomic_add(Z + off0, z)
# rs = RandomState(17)
# x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# # reference result
# z_ref = np.sum(x, axis=axis)
# # triton result
# x_tri = to_triton(x, device=device)
# z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
# kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
shape0, shape1 = 2, 8
# triton kernel
@triton.jit
def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
def kernel(X, Z):
pid = tl.program_id(0)
x = tl.load(X + pid)
old = GENERATE_TEST_HERE

kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

# triton result
rs = RandomState(17)
x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
if mode == 'all_neg':
x = -np.abs(x)
if mode == 'all_pos':
x = np.abs(x)
if mode == 'min_neg':
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = -np.max(np.abs(x)) - 1
if mode == 'max_pos':
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = np.max(np.abs(x)) + 1
x_tri = to_triton(x, device=device)

z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# compare
exact = op not in ['add']
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

#TODO[dongdongl]:add more cases with size of tensor less than warp size
@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
shape0, shape1 = 8, 8
# triton kernel

@triton.jit
def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)

z = tl.sum(x, axis=AXIS)
tl.atomic_add(Z + off0, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# reference
z_ref = z + x
# reference result
z_ref = np.sum(x, axis=axis)
# triton result
x_tri = torch.from_numpy(x).to(device=device)
z_tri = torch.from_numpy(z).to(device=device)
kernel[(1,)](z_tri, x_tri, shape0, shape1)
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

# def test_atomic_cas():
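As a sanity check on the re-enabled scalar test: with op='add', each program atomically adds the element it loaded into the single cell Z, so whatever order the hardware serializes the atomics in, the final value equals the sum over all programs, which is exactly the z_ref = numpy_op(x) reference (the test allows rtol=0.01 for 'add' to absorb floating-point reordering). A small numpy restatement (values illustrative):

    import numpy as np

    n_programs = 5
    x = np.arange(1, n_programs + 1, dtype=np.float32)  # one element per program
    z = np.float32(0)                                    # 'add' neutral element
    for xi in np.random.permutation(x):                  # atomic order is arbitrary...
        z += xi                                          # ...but addition commutes
    assert z == np.sum(x)                                # matches z_ref in the test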
@@ -875,8 +875,8 @@ def atomic_max(ptr: tl.tensor,
# return atomic_umin(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
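The float max lowering above leans on the standard IEEE-754 ordering trick: for non-negative float32 values, reinterpreting the bits as int32 preserves the ordering, so a signed integer MAX works; for negative values the bit-pattern ordering is reversed when read as uint32, so an unsigned UMIN on the reinterpreted bits yields the float maximum, and the two partial results are merged with where(pos, ...). A quick numpy check of the underlying property (illustrative, not Triton code):

    import numpy as np

    def bits_i32(x):
        return np.array(x, dtype=np.float32).view(np.int32).item()

    def bits_u32(x):
        return np.array(x, dtype=np.float32).view(np.uint32).item()

    a, b = 1.5, 2.25    # non-negative: int32 bit order matches float order
    assert (bits_i32(a) < bits_i32(b)) == (a < b)

    c, d = -1.5, -2.25  # negative: the unsigned-smaller bit pattern is the float-larger value,
    assert (bits_u32(c) < bits_u32(d)) == (c > d)  # so float max maps to unsigned min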
@@ -907,8 +907,8 @@ def atomic_min(ptr: tl.tensor,
# return atomic_umax(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,