fixed p2p tests failing when there are no supported p2p devices (#386)
This commit is contained in:
@@ -23,13 +23,13 @@ def get_p2p_matrix():
|
|||||||
def get_p2p_devices():
|
def get_p2p_devices():
|
||||||
matrix = get_p2p_matrix()
|
matrix = get_p2p_matrix()
|
||||||
idx = np.where(matrix == "OK")
|
idx = np.where(matrix == "OK")
|
||||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||||
|
|
||||||
|
|
||||||
def get_non_p2p_devices():
|
def get_non_p2p_devices():
|
||||||
matrix = get_p2p_matrix()
|
matrix = get_p2p_matrix()
|
||||||
idx = np.where(matrix == "NS")
|
idx = np.where(matrix == "NS")
|
||||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||||
|
|
||||||
|
|
||||||
p2p_devices = get_p2p_devices()
|
p2p_devices = get_p2p_devices()
|
||||||
|
Reference in New Issue
Block a user