[Triton-MLIR][FRONTEND] [BACKEND] fix atomics (#879)
Minor fixes to the backend and frontend of atomics. We can now pass one test without a mask, with the tensor shape aligned to the CTA size.

Co-authored-by: dongdongl <dongdongl@nvidia.com>
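In Triton terms, the newly passing case is an unmasked tensor-wide `tl.atomic_add` whose shape matches the CTA tile exactly. A minimal sketch of that usage (illustrative kernel and names, standard Triton API, assuming a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(Z, X, N: tl.constexpr):
    offs = tl.arange(0, N)
    # Unmasked atomic add over a tile whose shape matches the CTA size.
    tl.atomic_add(Z + offs, tl.load(X + offs))

x = torch.randn(16, device="cuda")
z = torch.zeros(16, device="cuda")
add_kernel[(1,)](z, x, N=16)
assert torch.allclose(z, x)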
@@ -5958,19 +5958,19 @@ struct AtomicRMWOpConversion
     auto elemsPerThread = getElemsPerThread(val.getType());
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
-      Value rmvVal = undef(vecTy);
+      Value rmwVal = undef(vecTy);
       for (int ii = 0; ii < vec; ++ii) {
         Value iiVal = createIndexAttrConstant(
             rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        rmvVal = insert_element(vecTy, rmvVal, valElements[i], iiVal);
+        rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
       }
-      Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
+      Value rmwPtr = ptrElements[i];
       std::string sTy;
       PTXBuilder ptxBuilder;

       auto *dstOpr = ptxBuilder.newOperand("=r");
-      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
-      auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
+      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
+      auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");

       auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
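The key fix in this hunk is the vector-packing index: the old code inserted `valElements[i]` into every lane, so all `vec` lanes of each atomic operand held the same scalar; the fix indexes `valElements[i + ii]` so each lane gets its own value. A minimal Python sketch of the corrected indexing (plain lists standing in for LLVM values; `pack_vectors` and its names are illustrative, not Triton API):

# Sketch of the lane-packing loop above, using plain Python lists.
# `val_elements` stands in for the per-thread scalars; `vec` is the
# vector width of each atomic access.
def pack_vectors(val_elements, vec):
    packed = []
    for i in range(0, len(val_elements), vec):
        # Buggy version: [val_elements[i]] * vec (same scalar in every lane).
        # Fixed version: each lane ii reads val_elements[i + ii].
        lanes = [val_elements[i + ii] for ii in range(vec)]
        packed.append(lanes)
    return packed

assert pack_vectors([10, 11, 12, 13], vec=2) == [[10, 11], [12, 13]]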
@@ -6012,12 +6012,13 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-
-      atom(dstOpr, ptrOpr, valOpr);
-      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy, false);
+      //TODO:[dongdongl] actual mask support
+      Value pred = int_val(1, 1);
+      atom(dstOpr, ptrOpr, valOpr).predicate(pred);
+      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
-            vec == 1 ? ret : extract_element(vecTy, ret, idx_val(ii));
+            vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
       }
     }
     Type structTy = getTypeConverter()->convertType(valueTy);
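Note that the mask is not yet honored here: `pred` is hardcoded to true (`int_val(1, 1)`), so every thread issues its atomic unconditionally, and the TODO marks real mask support as future work. At the user level that means only the unmasked form is expected to lower correctly after this commit. A hedged sketch of the two call shapes (standard `tl.atomic_add` API; kernel and names illustrative):

import triton
import triton.language as tl

@triton.jit
def atomic_forms(Z, X, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    # Expected to work after this commit: unmasked, CTA-aligned shape.
    tl.atomic_add(Z + offs, x)
    # Not yet supported by this lowering: the mask argument is ignored
    # because the predicate is hardcoded to true (see the TODO above).
    # tl.atomic_add(Z + offs, x, mask=offs < N)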
@@ -378,7 +378,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
       TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
       TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
-      TritonPrintfPattern>(typeConverter, context);
+      TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
 }

 //
@@ -1086,9 +1086,7 @@ void init_triton_ir(py::module &&m) {
                mlir::Value &ptr, mlir::Value &val,
                mlir::Value &mask) -> mlir::Value {
               auto loc = self.getUnknownLoc();
-              auto ptrType = mlir::getElementTypeOrSelf(ptr)
-                                 .cast<mlir::triton::PointerType>();
-              mlir::Type dstType = ptrType.getPointeeType();
+              mlir::Type dstType = val.getType();
               return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
                                                             ptr, val, mask);
             })
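This frontend change makes the result type of `AtomicRMWOp` follow the value operand (`val.getType()`) rather than the pointee type of the pointer, so a tensor-of-pointers atomic yields a tensor-shaped result instead of a scalar. A small illustration of the intended typing (illustrative kernel; whether the returned value is usable at this stage of the port is an assumption):

import triton
import triton.language as tl

@triton.jit
def result_type_demo(Z, X, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    # dstType is now taken from the value operand, so `old` has the same
    # shape and element type as `x` (a tensor), not the scalar pointee type.
    old = tl.atomic_add(Z + offs, x)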
@@ -671,6 +671,25 @@ def test_tuples():
 #     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
 #     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

+def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
+    shape0, shape1 = 16, 16
+    @triton.jit
+    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
+
+    rs = RandomState(17)
+    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    # reference
+    z_ref = z + x
+    # triton result
+    x_tri = torch.from_numpy(x).to(device=device)
+    z_tri = torch.from_numpy(z).to(device=device)
+    kernel[(1,)](z_tri, x_tri, shape0, shape1)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

 # def test_atomic_cas():
 #     # 1. make sure that atomic_cas changes the original value (Lock)