[TRITON-MLIR][BACKEND] AtomicRMWOp supports scalar (#903)

AtomicRMWOp now handles scalar (non-tensor) pointers in addition to tensors of pointers: a scalar atomic reserves a one-element scratch buffer in shared memory, only one thread issues the PTX atom instruction, and the returned old value is broadcast to the other threads through that buffer.

Co-authored-by: dongdongl <dongdongl@nvidia.com>
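For context, the sketch below shows the kind of kernel this change enables at the Triton Python level: Z is a plain scalar pointer rather than a tensor of pointers, so tl.atomic_add(Z, x) lowers to a scalar AtomicRMWOp. It is a minimal illustration modeled on the test_atomic_rmw kernel added in this diff; the kernel name, sizes, and final assertion are illustrative assumptions, not part of the commit.

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_atomic_add_kernel(X, Z):
    # One program per input element; every program accumulates into the
    # single element behind Z, exercising the scalar atomic path.
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    tl.atomic_add(Z, x)


n_programs = 5
x = torch.arange(n_programs, dtype=torch.float32, device="cuda")
z = torch.zeros(1, dtype=torch.float32, device="cuda")
scalar_atomic_add_kernel[(n_programs,)](x, z)
assert z.item() == x.sum().item()  # 0 + 1 + 2 + 3 + 4 == 10.0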
@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   return smemShape;
 }
 
+// TODO: extend beyond scalars
+SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
+  SmallVector<unsigned> smemShape;
+  auto ptrTy = op.ptr().getType();
+  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
+    // do nothing or just assert because shared memory is not used in tensor
+  } else {
+    // need only bytes for scalar
+    // always vec = 1 and elemsPerThread = 1 for scalar?
+    smemShape.push_back(1);
+  }
+  return smemShape;
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -200,6 +214,23 @@ private:
                        elems * kPtrBitWidth / 8 :
                        elems * srcTy.getElementTypeBitWidth() / 8;
       allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
+      auto value = op->getOperand(0);
+      // only scalar requires scratch memory
+      // make it explicit for readability
+      if (value.getType().dyn_cast<RankedTensorType>()) {
+        // nothing to do
+      } else {
+        auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
+        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                         std::multiplies{});
+        auto elemTy =
+            value.getType().cast<triton::PointerType>().getPointeeType();
+        auto bytes = elemTy.isa<triton::PointerType>()
+                         ? elems * kPtrBitWidth / 8
+                         : elems * elemTy.getIntOrFloatBitWidth() / 8;
+        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      }
     }
   }
@@ -5947,10 +5947,11 @@ struct AtomicRMWOpConversion
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
 
   AtomicRMWOpConversion(LLVMTypeConverter &converter,
+                        const Allocation *allocation, Value smem,
                         AxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
-                                                             benefit),
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
+            converter, allocation, smem, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
@@ -5972,29 +5973,28 @@ struct AtomicRMWOpConversion
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
     auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
 
-    // TODO[dongdongl]: Support scalar
-
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
-    if (!valueTy)
-      return failure();
     Type valueElemTy =
-        getTypeConverter()->convertType(valueTy.getElementType());
-    auto valTy = val.getType().cast<RankedTensorType>();
+        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+                : op.getResult().getType();
     const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
-    auto vec = getVectorSize(ptr);
-    vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-
-    auto vecTy = vec_ty(valueElemTy, vec);
     auto elemsPerThread = getElemsPerThread(val.getType());
-    // mask
+    // vec = 1 for scalar
+    auto vec = getVectorSize(ptr);
     Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    // tensor
+    if (valueTy) {
+      auto valTy = val.getType().cast<RankedTensorType>();
+      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+      // mask
       auto shape = valueTy.getShape();
       auto numElements = product(shape);
-    auto tid = tid_val();
       mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
                                  i32_val(numElements)));
+    }
 
+    auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwVal = undef(vecTy);
@@ -6008,10 +6008,12 @@ struct AtomicRMWOpConversion
       rmwMask = and_(rmwMask, mask);
       std::string sTy;
       PTXBuilder ptxBuilder;
-
-      auto *dstOpr = ptxBuilder.newOperand("=r");
+      std::string tyId = valueElemNbits * vec == 64
+                             ? "l"
+                             : (valueElemNbits * vec == 32 ? "r" : "h");
+      auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
       auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
+      auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
 
       auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
@@ -6053,18 +6055,32 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
+      if (valueTy) {
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
 
         auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
         for (int ii = 0; ii < vec; ++ii) {
           resultVals[i * vec + ii] =
               vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
         }
+      } else {
+        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
+        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
+        auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+        atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+        store(old, atomPtr);
+        barrier();
+        Value ret = load(atomPtr);
+        barrier();
+        rewriter.replaceOp(op, {ret});
+      }
     }
+    if (valueTy) {
       Type structTy = getTypeConverter()->convertType(valueTy);
       Value resultStruct =
           getStructFromElements(loc, resultVals, rewriter, structTy);
       rewriter.replaceOp(op, {resultStruct});
+    }
     return success();
   }
 };
@@ -6150,7 +6166,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
-  patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                          benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
@@ -1106,7 +1106,18 @@ void init_triton_ir(py::module &&m) {
            mlir::Value &ptr, mlir::Value &val,
            mlir::Value &mask) -> mlir::Value {
           auto loc = self.getUnknownLoc();
-          mlir::Type dstType = val.getType();
+          mlir::Type dstType;
+          if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+            mlir::Type dstElemType = srcTensorType.getElementType()
+                                         .cast<mlir::triton::PointerType>()
+                                         .getPointeeType();
+            dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
+                                                  dstElemType);
+          } else {
+            auto ptrType = mlir::getElementTypeOrSelf(ptr)
+                               .cast<mlir::triton::PointerType>();
+            dstType = ptrType.getPointeeType();
+          }
           return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
                                                         ptr, val, mask);
         })
@@ -595,100 +595,80 @@ def test_tuples():
     assert c_tri == c_ref
 
 
-# # ---------------
-# # test atomics
-# # ---------------
-# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
-#     [
-#         ('add', 'float16', mode),
-#         ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
-#         ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
-#         ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
-#     ]
-#     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
-#     n_programs = 5
-
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z):
-#         pid = tl.program_id(0)
-#         x = tl.load(X + pid)
-#         old = GENERATE_TEST_HERE
-
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-#     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-#     max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-#     min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
-#     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
-
-#     # triton result
-#     rs = RandomState(17)
-#     x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
-#     if mode == 'all_neg':
-#         x = -np.abs(x)
-#     if mode == 'all_pos':
-#         x = np.abs(x)
-#     if mode == 'min_neg':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = -np.max(np.abs(x)) - 1
-#     if mode == 'max_pos':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = np.max(np.abs(x)) + 1
-#     x_tri = to_triton(x, device=device)
-
-#     z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
-#     kernel[(n_programs, )](x_tri, z_tri)
-#     # torch result
-#     z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-#     # compare
-#     exact = op not in ['add']
-#     if exact:
-#         assert z_ref.item() == to_numpy(z_tri).item()
-#     else:
-#         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
-
-
-# @pytest.mark.parametrize("axis", [0, 1])
-# def test_tensor_atomic_rmw(axis, device="cuda"):
-#     shape0, shape1 = 8, 8
-#     # triton kernel
-
-#     @triton.jit
-#     def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
-#         off0 = tl.arange(0, SHAPE0)
-#         off1 = tl.arange(0, SHAPE1)
-#         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-#         z = tl.sum(x, axis=AXIS)
-#         tl.atomic_add(Z + off0, z)
-#     rs = RandomState(17)
-#     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-#     # reference result
-#     z_ref = np.sum(x, axis=axis)
-#     # triton result
-#     x_tri = to_triton(x, device=device)
-#     z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
-#     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
-#     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
-
-
-def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 2, 8
-    @triton.jit
-    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
-        off0 = tl.arange(0, SHAPE0)
-        off1 = tl.arange(0, SHAPE1)
-        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
-    rs = RandomState(17)
-    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    # reference
-    z_ref = z + x
-    # triton result
-    x_tri = torch.from_numpy(x).to(device=device)
-    z_tri = torch.from_numpy(z).to(device=device)
-    kernel[(1,)](z_tri, x_tri, shape0, shape1)
-    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
+# ---------------
+# test atomics
+# ---------------
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    n_programs = 5
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z):
+        pid = tl.program_id(0)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
+
+    # triton result
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
+    if mode == 'all_neg':
+        x = -np.abs(x)
+    if mode == 'all_pos':
+        x = np.abs(x)
+    if mode == 'min_neg':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
+    if mode == 'max_pos':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)
+
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    kernel[(n_programs, )](x_tri, z_tri)
+    # torch result
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == to_numpy(z_tri).item()
+    else:
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+
+
+#TODO[dongdongl]:add more cases with size of tensor less than warp size
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel
+
+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
+    rs = RandomState(17)
+    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    # reference result
+    z_ref = np.sum(x, axis=axis)
+    # triton result
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
 # def test_atomic_cas():
@@ -875,8 +875,8 @@ def atomic_max(ptr: tl.tensor,
     # return atomic_umin(i_ptr, i_val) if val < 0
     i_val = bitcast(val, tl.int32, builder)
     i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
-    pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
-    neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
+    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
+    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
     pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
     neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
     return where(pos, pos_ret, neg_ret, builder)
@@ -907,8 +907,8 @@ def atomic_min(ptr: tl.tensor,
    # return atomic_umax(i_ptr, i_val) if val < 0
    i_val = bitcast(val, tl.int32, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
-    pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
-    neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
+    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
+    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
     pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
                                                   i_ptr.handle,
                                                   i_val.handle,