[TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)

AtomicRMWOp supports scalar

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-23 15:59:09 +08:00
committed by GitHub
parent 2e33352419
commit 8925c2cd11
5 changed files with 163 additions and 125 deletions

View File

@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
return smemShape;
}
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
auto ptrTy = op.ptr().getType();
if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
// do nothing or just assert because shared memory is not used in tensor
} else {
// need only bytes for scalar
// always vec = 1 and elemsPerThread = 1 for scalar?
smemShape.push_back(1);
}
return smemShape;
}
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -200,6 +214,23 @@ private:
elems * kPtrBitWidth / 8 :
elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (value.getType().dyn_cast<RankedTensorType>()) {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}
}

View File

@@ -5947,10 +5947,11 @@ struct AtomicRMWOpConversion
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
benefit),
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
@@ -5971,30 +5972,29 @@ struct AtomicRMWOpConversion
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
// TODO[dongdongl]: Support scalar
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
auto valTy = val.getType().cast<RankedTensorType>();
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
auto vec = getVectorSize(ptr);
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
Value mask = int_val(1, 1);
auto tid = tid_val();
// tensor
if (valueTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
}
auto vecTy = vec_ty(valueElemTy, vec);
auto elemsPerThread = getElemsPerThread(val.getType());
// mask
Value mask = int_val(1, 1);
auto shape = valueTy.getShape();
auto numElements = product(shape);
auto tid = tid_val();
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
@@ -6008,10 +6008,12 @@ struct AtomicRMWOpConversion
rmwMask = and_(rmwMask, mask);
std::string sTy;
PTXBuilder ptxBuilder;
auto *dstOpr = ptxBuilder.newOperand("=r");
std::string tyId = valueElemNbits * vec == 64
? "l"
: (valueElemNbits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
@@ -6053,18 +6055,32 @@ struct AtomicRMWOpConversion
return failure();
}
atom.o(rmwOp).o(sTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
if (valueTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
}
} else {
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(old, atomPtr);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}
};
@@ -6150,7 +6166,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);

View File

@@ -1106,7 +1106,18 @@ void init_triton_ir(py::module &&m) {
mlir::Value &ptr, mlir::Value &val,
mlir::Value &mask) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType = val.getType();
mlir::Type dstType;
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
})

View File

@@ -595,100 +595,80 @@ def test_tuples():
assert c_tri == c_ref
# # ---------------
# # test atomics
# # ---------------
# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
# [
# ('add', 'float16', mode),
# ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
# ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
# ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
# ]
# for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# n_programs = 5
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
n_programs = 5
# # triton kernel
# @triton.jit
# def kernel(X, Z):
# pid = tl.program_id(0)
# x = tl.load(X + pid)
# old = GENERATE_TEST_HERE
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
# numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
# max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
# min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
# neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
# # triton result
# rs = RandomState(17)
# x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
# if mode == 'all_neg':
# x = -np.abs(x)
# if mode == 'all_pos':
# x = np.abs(x)
# if mode == 'min_neg':
# idx = rs.randint(n_programs, size=(1, )).item()
# x[idx] = -np.max(np.abs(x)) - 1
# if mode == 'max_pos':
# idx = rs.randint(n_programs, size=(1, )).item()
# x[idx] = np.max(np.abs(x)) + 1
# x_tri = to_triton(x, device=device)
# z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
# kernel[(n_programs, )](x_tri, z_tri)
# # torch result
# z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# # compare
# exact = op not in ['add']
# if exact:
# assert z_ref.item() == to_numpy(z_tri).item()
# else:
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# @pytest.mark.parametrize("axis", [0, 1])
# def test_tensor_atomic_rmw(axis, device="cuda"):
# shape0, shape1 = 8, 8
# # triton kernel
# @triton.jit
# def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
# off0 = tl.arange(0, SHAPE0)
# off1 = tl.arange(0, SHAPE1)
# x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
# z = tl.sum(x, axis=AXIS)
# tl.atomic_add(Z + off0, z)
# rs = RandomState(17)
# x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# # reference result
# z_ref = np.sum(x, axis=axis)
# # triton result
# x_tri = to_triton(x, device=device)
# z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
# kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
shape0, shape1 = 2, 8
# triton kernel
@triton.jit
def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
def kernel(X, Z):
pid = tl.program_id(0)
x = tl.load(X + pid)
old = GENERATE_TEST_HERE
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
# triton result
rs = RandomState(17)
x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
if mode == 'all_neg':
x = -np.abs(x)
if mode == 'all_pos':
x = np.abs(x)
if mode == 'min_neg':
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = -np.max(np.abs(x)) - 1
if mode == 'max_pos':
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = np.max(np.abs(x)) + 1
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# compare
exact = op not in ['add']
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
#TODO[dongdongl]:add more cases with size of tensor less than warp size
@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
shape0, shape1 = 8, 8
# triton kernel
@triton.jit
def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
z = tl.sum(x, axis=AXIS)
tl.atomic_add(Z + off0, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# reference
z_ref = z + x
# reference result
z_ref = np.sum(x, axis=axis)
# triton result
x_tri = torch.from_numpy(x).to(device=device)
z_tri = torch.from_numpy(z).to(device=device)
kernel[(1,)](z_tri, x_tri, shape0, shape1)
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
# def test_atomic_cas():

View File

@@ -875,8 +875,8 @@ def atomic_max(ptr: tl.tensor,
# return atomic_umin(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
@@ -907,8 +907,8 @@ def atomic_min(ptr: tl.tensor,
# return atomic_umax(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,