// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
|
|
|
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"

#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
|
|
|
|
|
|
|
|
namespace rt = triton::runtime;
|
|
|
|
namespace drv = triton::driver;
|
|
|
|
|
2020-09-11 11:44:34 -04:00
|
|
|
typedef std::pair<int, int> map_key_t;
|
2020-08-11 20:10:39 -04:00
|
|
|
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
|
|
|
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
2020-09-11 11:44:34 -04:00
|
|
|
std::shared_ptr<drv::device> host_device;
|
|
|
|
std::shared_ptr<drv::context> host_context;
|
|
|
|
std::shared_ptr<drv::stream> host_stream;
|
2020-08-11 20:10:39 -04:00
|
|
|
|
2020-11-03 16:02:02 -05:00
|
|
|
// Returns sum_i ceil(x[i] / div) for a 1-D int32 CPU tensor x.
// Used by the python side to size launch grids.
int64_t cdiv_sum(torch::Tensor x, int64_t div){
  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor");
  // accessor<int, 1> requires a 1-D int32 tensor; check up-front for a
  // clear error instead of an internal accessor failure.
  TORCH_CHECK(x.dim() == 1, "Argument of cdiv_sum must be 1-D");
  TORCH_CHECK(x.scalar_type() == torch::kInt, "Argument of cdiv_sum must be an int32 tensor");
  auto _x = x.accessor<int, 1>();
  int64_t ret = 0;
  // int64_t index: x.size(0) is int64_t, avoid signed/unsigned comparison
  for(int64_t i = 0; i < x.size(0); i++)
    ret += (_x[i] + div - 1) / div;  // ceiling division
  return ret;
}
|
|
|
|
|
2020-11-12 02:11:45 -05:00
|
|
|
void init_host_stream() {
|
|
|
|
if(!host_stream){
|
|
|
|
host_device.reset(new drv::host_device());
|
|
|
|
host_context.reset(drv::context::create(&*host_device));
|
|
|
|
host_stream.reset(drv::stream::create(&*host_context));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-11-07 02:55:48 -05:00
|
|
|
// Returns the CUDA driver stream handle backing PyTorch's current
// stream for device `dev_id`, so Triton launches share ordering with
// surrounding torch ops.
CUstream torch_get_cuda_stream(int64_t dev_id) {
  // reinterpret_cast instead of a C-style cast: cudaStream_t and
  // CUstream are both opaque stream handles for the same object.
  return reinterpret_cast<CUstream>(at::cuda::getCurrentCUDAStream(dev_id).stream());
}
|
|
|
|
|
2020-11-12 02:11:45 -05:00
|
|
|
void synchronize(int64_t dev_id) {
|
|
|
|
if(dev_id == -1){
|
|
|
|
init_host_stream();
|
|
|
|
host_stream->synchronize();
|
|
|
|
}
|
|
|
|
else{
|
|
|
|
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
|
|
|
stream.synchronize();
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2020-11-11 14:44:56 -05:00
|
|
|
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
|
|
|
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
|
|
|
|
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
|
|
|
|
for(size_t n = 0; n < constant_names.size(); n++){
|
|
|
|
const torch::Tensor& x = constant_vals[n];
|
2020-11-12 02:11:45 -05:00
|
|
|
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
|
2020-11-11 14:44:56 -05:00
|
|
|
}
|
2020-09-11 11:44:34 -04:00
|
|
|
if(dev_id == -1){
|
2020-11-12 02:11:45 -05:00
|
|
|
init_host_stream();
|
2020-11-11 14:44:56 -05:00
|
|
|
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
2020-09-11 11:44:34 -04:00
|
|
|
}
|
|
|
|
else{
|
2020-11-07 02:55:48 -05:00
|
|
|
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
2020-09-11 11:44:34 -04:00
|
|
|
triton::driver::context* ctx = stream.context();
|
2020-11-11 14:44:56 -05:00
|
|
|
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
2020-09-11 11:44:34 -04:00
|
|
|
}
|
2020-08-11 20:10:39 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2020-11-02 15:06:08 -05:00
|
|
|
static auto registry = torch::RegisterOperators()
|
|
|
|
.op("triton::launch_kernel", &launch_kernel)
|
2020-11-12 02:11:45 -05:00
|
|
|
.op("triton::cdiv_sum", &cdiv_sum)
|
|
|
|
.op("triton::synchronize", &synchronize);
|