[CODEGEN] More work on the CPU backend
commit 840308ab5d
parent 64eaec016f
committed by Philippe Tillet
@@ -74,7 +74,7 @@ class CMakeBuild(build_ext):
                       '-DLLVM_CONFIG=' + find_llvm()]
         # configuration
         cfg = 'Debug' if self.debug else 'Release'
-        cfg = 'Release'
+        cfg = 'Debug'
         build_args = ['--config', cfg]

         if platform.system() == "Windows":
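The hunk above flips the hard-coded build type from Release to Debug (the self.debug ternary is overridden either way), presumably to ease debugging of the in-progress CPU backend. For orientation, a rough sketch of how build_args is typically consumed later in build_extension; the cmake invocation is an assumption based on the usual setuptools/CMake recipe, not something shown in this diff:

import subprocess

def build_sketch(build_temp, debug=True):
    # Mirror of the logic above; this commit pins cfg to 'Debug'.
    cfg = 'Debug' if debug else 'Release'
    build_args = ['--config', cfg]
    # Hypothetical call matching the standard CMakeBuild pattern; note that
    # '--config' only affects multi-config generators such as Visual Studio.
    subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=build_temp)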
@@ -15,7 +15,7 @@ using namespace triton;

 namespace rt = triton::runtime;

-typedef std::pair<size_t, size_t> map_key_t;
+typedef std::pair<int, int> map_key_t;
 std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
 std::map<size_t, double> fp64scalar_map;
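The map_key_t change from std::pair<size_t, size_t> to std::pair<int, int> is what makes the CPU path in the next file possible: the second key element is a device id, and the host path introduced below keys it with the sentinel -1. Assigned to an unsigned size_t, -1 would silently wrap around; a quick illustration via Python's ctypes:

import ctypes

# -1 stored in an unsigned 64-bit size_t wraps to SIZE_MAX...
print(ctypes.c_size_t(-1).value)  # 18446744073709551615 on a 64-bit build
# ...while a signed int preserves the sentinel as-is.
print(ctypes.c_int(-1).value)     # -1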
@@ -8,22 +8,31 @@
 #include "torch/script.h"
 #include "ATen/cuda/CUDAContext.h"

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x);

 namespace rt = triton::runtime;
+namespace drv = triton::driver;

-typedef std::pair<size_t, size_t> map_key_t;
+typedef std::pair<int, int> map_key_t;
 extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
+std::shared_ptr<drv::device> host_device;
+std::shared_ptr<drv::context> host_context;
+std::shared_ptr<drv::stream> host_stream;

 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
-  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
-  triton::driver::cu_stream stream(custream, false);
-  triton::driver::context* ctx = stream.context();
-  (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
+  if(dev_id == -1){
+    if(!host_stream){
+      host_device.reset(new drv::host_device());
+      host_context.reset(drv::context::create(&*host_device));
+      host_stream.reset(drv::stream::create(&*host_context));
+    }
+    (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
+  }
+  else{
+    CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
+    triton::driver::cu_stream stream(custream, false);
+    triton::driver::context* ctx = stream.context();
+    (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
+  }
 }
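The rewritten launch_kernel dispatches on the device id: -1 selects a lazily constructed host device/context/stream triple cached in the globals declared above, while any other id takes the existing CUDA-stream path. A hypothetical Python rendering of that control flow (the helper functions are stand-ins, not Triton's real API):

# Hypothetical stand-ins for the driver objects:
def make_host_stream():
    return "host-stream"

def current_cuda_stream(dev_id):
    return f"cuda-stream:{dev_id}"

_host_stream = None  # mirrors the global host_stream above: created once, reused

def launch(op_id, dev_id, args, id_fn_map, id_grid_map):
    global _host_stream
    if dev_id == -1:                       # CPU/host path
        if _host_stream is None:           # lazy one-time initialization
            _host_stream = make_host_stream()
        stream = _host_stream
    else:                                  # existing CUDA path
        stream = current_cuda_stream(dev_id)
    fn = id_fn_map[(op_id, dev_id)]
    return fn(args, id_grid_map[(op_id, dev_id)], stream)

As in the C++ version, the cached host stream is created without locking, so the first CPU launch is not thread-safe; that is fine for a single dispatching thread.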
@@ -67,6 +67,7 @@ class kernel:
     for x in args:
       if isinstance(x, torch.Tensor):
         device = x.device.index
+        device = -1 if device is None else device
         break
     # lazily register function for device
     if device not in self.registered:
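This one-line addition closes the loop with the C++ side: PyTorch reports device.index as None for CPU tensors, and the runtime above expects the sentinel -1 instead. A small check of that convention:

import torch

cpu = torch.empty(4)
assert cpu.device.index is None       # CPU tensors carry no device index

device = cpu.device.index
device = -1 if device is None else device
assert device == -1                   # matches the host key used in C++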