[CODEGEN] More work on the CPU backend

This commit is contained in:
Philippe Tillet
2020-09-11 11:44:34 -04:00
committed by Philippe Tillet
parent 64eaec016f
commit 840308ab5d
17 changed files with 258 additions and 185 deletions

View File

@@ -74,7 +74,7 @@ class CMakeBuild(build_ext):
'-DLLVM_CONFIG=' + find_llvm()]
# configuration
cfg = 'Debug' if self.debug else 'Release'
cfg = 'Release'
cfg = 'Debug'
build_args = ['--config', cfg]
if platform.system() == "Windows":

View File

@@ -15,7 +15,7 @@ using namespace triton;
namespace rt = triton::runtime;
typedef std::pair<size_t, size_t> map_key_t;
typedef std::pair<int, int> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, double> fp64scalar_map;

View File

@@ -8,22 +8,31 @@
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<size_t, size_t> map_key_t;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
triton::driver::cu_stream stream(custream, false);
triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
if(dev_id == -1){
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context));
}
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
}
else{
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
triton::driver::cu_stream stream(custream, false);
triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
}
}

View File

@@ -67,6 +67,7 @@ class kernel:
for x in args:
if isinstance(x, torch.Tensor):
device = x.device.index
device = -1 if device is None else device
break
# lazily register function for device
if device not in self.registered: