diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 97734c04c..82e91224f 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   return smemShape;
 }
 
+// TODO: extend beyond scalars
+SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
+  SmallVector<unsigned> smemShape;
+  auto ptrTy = op.ptr().getType();
+  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
+    // do nothing or just assert because shared memory is not used in tensor
+  } else {
+    // need only bytes for scalar
+    // always vec = 1 and elemsPerThread = 1 for scalar?
+    smemShape.push_back(1);
+  }
+  return smemShape;
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -200,6 +214,23 @@ private:
                             elems * kPtrBitWidth / 8
                           : elems * srcTy.getElementTypeBitWidth() / 8;
       allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
+      auto value = op->getOperand(0);
+      // only scalar requires scratch memory
+      // make it explicit for readability
+      if (value.getType().dyn_cast<RankedTensorType>()) {
+        // nothing to do
+      } else {
+        auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
+        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                         std::multiplies{});
+        auto elemTy =
+            value.getType().cast<triton::PointerType>().getPointeeType();
+        auto bytes = elemTy.isa<triton::PointerType>()
+                         ? elems * kPtrBitWidth / 8
+                         : elems * elemTy.getIntOrFloatBitWidth() / 8;
+        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      }
     }
   }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index d79632c55..0c4557698 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -5947,10 +5947,11 @@ struct AtomicRMWOpConversion
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
 
   AtomicRMWOpConversion(LLVMTypeConverter &converter,
+                        const Allocation *allocation, Value smem,
                         AxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
-                                                             benefit),
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
+            converter, allocation, smem, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
@@ -5971,30 +5972,29 @@ struct AtomicRMWOpConversion
     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
     auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
-
-    // TODO[dongdongl]: Support scalar
-
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
-    if (!valueTy)
-      return failure();
     Type valueElemTy =
-        getTypeConverter()->convertType(valueTy.getElementType());
-
-    auto valTy = val.getType().cast<RankedTensorType>();
+        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+                : op.getResult().getType();
     const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
+    auto elemsPerThread = getElemsPerThread(val.getType());
+    // vec = 1 for scalar
     auto vec = getVectorSize(ptr);
-    vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+    Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    // tensor
+    if (valueTy) {
+      auto valTy = val.getType().cast<RankedTensorType>();
+      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+      // mask
+      auto shape = valueTy.getShape();
+      auto numElements = product(shape);
+      mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
+                                 i32_val(numElements)));
+    }
     auto vecTy = vec_ty(valueElemTy, vec);
-    auto elemsPerThread = getElemsPerThread(val.getType());
-    // mask
-    Value mask = int_val(1, 1);
-    auto shape = valueTy.getShape();
-    auto numElements = product(shape);
-    auto tid = tid_val();
-    mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
-                               i32_val(numElements)));
-
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwVal = undef(vecTy);
@@ -6008,10 +6008,12 @@ struct AtomicRMWOpConversion
       rmwMask = and_(rmwMask, mask);
       std::string sTy;
       PTXBuilder ptxBuilder;
-
-      auto *dstOpr = ptxBuilder.newOperand("=r");
+      std::string tyId = valueElemNbits * vec == 64
+                             ? "l"
+                             : (valueElemNbits * vec == 32 ? "r" : "h");
+      auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
       auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
+      auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
       auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
@@ -6053,18 +6055,32 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-      atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-
-      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        resultVals[i * vec + ii] =
-            vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
+      if (valueTy) {
+        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
+        auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        for (int ii = 0; ii < vec; ++ii) {
+          resultVals[i * vec + ii] =
+              vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
+        }
+      } else {
+        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
+        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
+        auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+        atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+        store(old, atomPtr);
+        barrier();
+        Value ret = load(atomPtr);
+        barrier();
+        rewriter.replaceOp(op, {ret});
       }
     }
-    Type structTy = getTypeConverter()->convertType(valueTy);
-    Value resultStruct =
-        getStructFromElements(loc, resultVals, rewriter, structTy);
-    rewriter.replaceOp(op, {resultStruct});
+    if (valueTy) {
+      Type structTy = getTypeConverter()->convertType(valueTy);
+      Value resultStruct =
+          getStructFromElements(loc, resultVals, rewriter, structTy);
+      rewriter.replaceOp(op, {resultStruct});
+    }
     return success();
   }
 };
@@ -6150,7 +6166,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add(typeConverter, allocation, smem, benefit);
   patterns.add(typeConverter, allocation, smem, benefit);
-  patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, allocation, smem, benefit);
   patterns.add(typeConverter, benefit);
diff --git a/python/src/triton.cc b/python/src/triton.cc
index f133b6c14..6fd5003e6 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1106,7 +1106,18 @@ void init_triton_ir(py::module &&m) {
              mlir::Value &ptr, mlir::Value &val,
              mlir::Value &mask) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            mlir::Type dstType = val.getType();
+            mlir::Type dstType;
+            if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+              mlir::Type dstElemType = srcTensorType.getElementType()
+                                           .cast<mlir::triton::PointerType>()
+                                           .getPointeeType();
+              dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
+                                                    dstElemType);
+            } else {
+              auto ptrType = mlir::getElementTypeOrSelf(ptr)
+                                 .cast<mlir::triton::PointerType>();
+              dstType = ptrType.getPointeeType();
+            }
             return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp, ptr, val, mask);
           })
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 5f46f2517..21f9df750 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -595,100 +595,80 @@ def test_tuples():
     assert c_tri == c_ref
 
-# # ---------------
-# # test atomics
-# # ---------------
-# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
-#     [
-#         ('add', 'float16', mode),
-#         ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
-#         ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
-#         ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
-#     ]
-#     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
-#     n_programs = 5
+# ---------------
+# test atomics
+# ---------------
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    n_programs = 5
 
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z):
-#         pid = tl.program_id(0)
-#         x = tl.load(X + pid)
-#         old = GENERATE_TEST_HERE
-
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-#     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-#     max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-#     min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
-#     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
-
-#     # triton result
-#     rs = RandomState(17)
-#     x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
-#     if mode == 'all_neg':
-#         x = -np.abs(x)
-#     if mode == 'all_pos':
-#         x = np.abs(x)
-#     if mode == 'min_neg':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = -np.max(np.abs(x)) - 1
-#     if mode == 'max_pos':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = np.max(np.abs(x)) + 1
-#     x_tri = to_triton(x, device=device)
-
-#     z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
-#     kernel[(n_programs, )](x_tri, z_tri)
-#     # torch result
-#     z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-#     # compare
-#     exact = op not in ['add']
-#     if exact:
-#         assert z_ref.item() == to_numpy(z_tri).item()
-#     else:
-#         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
-
-
-# @pytest.mark.parametrize("axis", [0, 1])
-# def test_tensor_atomic_rmw(axis, device="cuda"):
-#     shape0, shape1 = 8, 8
-#     # triton kernel
-
-#     @triton.jit
-#     def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
-#         off0 = tl.arange(0, SHAPE0)
-#         off1 = tl.arange(0, SHAPE1)
-#         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-#         z = tl.sum(x, axis=AXIS)
-#         tl.atomic_add(Z + off0, z)
-#     rs = RandomState(17)
-#     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-#     # reference result
-#     z_ref = np.sum(x, axis=axis)
-#     # triton result
-#     x_tri = to_triton(x, device=device)
-#     z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
-#     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
-#     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
-def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 2, 8
+    # triton kernel
     @triton.jit
-    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(X, Z):
+        pid = tl.program_id(0)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
+
+    # triton result
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
+    if mode == 'all_neg':
+        x = -np.abs(x)
+    if mode == 'all_pos':
+        x = np.abs(x)
+    if mode == 'min_neg':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
+    if mode == 'max_pos':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)
+
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    kernel[(n_programs, )](x_tri, z_tri)
+    # torch result
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == to_numpy(z_tri).item()
+    else:
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+
+# TODO[dongdongl]: add more cases with tensor sizes smaller than the warp size
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel
+
+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
-
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    # reference
-    z_ref = z + x
+    # reference result
+    z_ref = np.sum(x, axis=axis)
     # triton result
-    x_tri = torch.from_numpy(x).to(device=device)
-    z_tri = torch.from_numpy(z).to(device=device)
-    kernel[(1,)](z_tri, x_tri, shape0, shape1)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
 # def test_atomic_cas():
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 402e4c8d1..df2d6a3af 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -875,8 +875,8 @@ def atomic_max(ptr: tl.tensor,
         # return atomic_umin(i_ptr, i_val) if val < 0
         i_val = bitcast(val, tl.int32, builder)
         i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
-        pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
-        neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
+        pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
+        neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
         pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
         neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
         return where(pos, pos_ret, neg_ret, builder)
@@ -907,8 +907,8 @@ def atomic_min(ptr: tl.tensor,
         # return atomic_umax(i_ptr, i_val) if val < 0
         i_val = bitcast(val, tl.int32, builder)
         i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
-        pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
-        neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
+        pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
+        neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
         pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,