diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py index ae843a15f..520462870 100644 --- a/python/test/unit/runtime/test_comm.py +++ b/python/test/unit/runtime/test_comm.py @@ -23,13 +23,13 @@ def get_p2p_matrix(): def get_p2p_devices(): matrix = get_p2p_matrix() idx = np.where(matrix == "OK") - return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}" + return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] def get_non_p2p_devices(): matrix = get_p2p_matrix() idx = np.where(matrix == "NS") - return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}" + return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] p2p_devices = get_p2p_devices()