[DRIVER] Add CUDA P2P support (#209)
This commit is contained in:
@@ -27,6 +27,7 @@
|
||||
#include <memory>
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton
|
||||
@@ -159,6 +160,14 @@ void cu_device::set_max_clock() {
|
||||
dispatch::nvmlDeviceSetApplicationsClocks(nvml_device(), max_mem_clock(), max_sm_clock());
|
||||
}
|
||||
|
||||
// Enables peer-to-peer access to the CUDA context that owns `peer_mem_ptr`.
//
// peer_mem_ptr: device pointer whose owning context is looked up through
//               CU_POINTER_ATTRIBUTE_CONTEXT.
//
// Per the CUDA driver API, cuCtxEnablePeerAccess grants the *current* context
// access to allocations of the peer context, so the caller is expected to have
// the accessing context current when calling this.
//
// A peer_access_already_enabled exception is swallowed on purpose so that
// calling this repeatedly for the same peer is harmless; any other exception
// raised by the dispatch layer propagates to the caller.
void cu_device::enable_peer_access(CUdeviceptr peer_mem_ptr) const{
  // Null-initialised so the handle is never indeterminate, even if the
  // attribute query does not write to it.
  CUcontext context = nullptr;
  dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_mem_ptr);
  try {
    // The flags argument is reserved by the driver API and must be 0.
    dispatch::cuCtxEnablePeerAccess(context, 0);
  } catch (const exception::cuda::peer_access_already_enabled&) {
    // Caught by const reference (no copy, no slicing). Access to this peer was
    // already enabled earlier: nothing left to do.
  }
}
|
||||
|
||||
// print infos
|
||||
std::string cu_device::infos() const{
|
||||
std::ostringstream oss;
|
||||
|
@@ -178,6 +178,7 @@ CUDA_DEFINE1(CUresult, cuCtxPopCurrent_v2, CUcontext*)
|
||||
// Dispatch wrapper definitions. Each invocation lists the return type, the
// wrapped symbol name, and then the argument types; the digit in the macro
// name matches the number of argument types that follow.
// NOTE(review): the macro bodies are not visible here — presumably each one
// expands to a dispatch:: member that forwards to the named driver symbol;
// confirm against the macro definitions.
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
|
||||
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
|
||||
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
|
||||
// Wrapper used by cu_device::enable_peer_access (peer context, flags).
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
|
||||
|
||||
// NVML counterparts, declared with the same (return type, symbol, argument
// types) layout as the CUDA wrappers.
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
@@ -252,6 +253,7 @@ void* dispatch::cuCtxPopCurrent_v2_;
|
||||
// Out-of-line definitions of dispatch-scoped void* slots, one per wrapped
// symbol, each named after the symbol with a trailing underscore.
// NOTE(review): presumably these hold the resolved function addresses used by
// the matching wrappers — the loader code is not visible here; confirm.
void* dispatch::cuFuncGetAttribute_;
|
||||
void* dispatch::cuFuncSetAttribute_;
|
||||
void* dispatch::cuFuncSetCacheConfig_;
|
||||
// Slot for the cuCtxEnablePeerAccess wrapper.
void* dispatch::cuCtxEnablePeerAccess_;
|
||||
|
||||
// Slots for NVML symbols.
void* dispatch::nvmlInit_v2_;
|
||||
void* dispatch::nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
|
Reference in New Issue
Block a user