[python] upgraded pybind11 ; forcing torch tensors to be contiguous()

This commit is contained in:
Philippe Tillet
2019-09-05 12:30:51 -04:00
parent 58544d0523
commit 2d6c8311e8
23 changed files with 960 additions and 531 deletions

View File

@@ -35,7 +35,6 @@ void register_grid(size_t id,
void delete_grid(size_t id) {
id_grid_map.erase(id);
std::cout << "deleted " << id_grid_map.size() << std::endl;
}
void register_fn(size_t id,
@@ -46,7 +45,6 @@ void register_fn(size_t id,
void delete_fn(size_t id) {
id_fn_map.erase(id);
std::cout << "deleted " << id_fn_map.size() << std::endl;
}
void cleanup() {
@@ -415,8 +413,10 @@ void gen_torch_make_handles(std::ostream &os,
ir::type* ty = arg->get_type();
if(!ty->is_pointer_ty())
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
else
else{
os << " CHECK_INPUT(" << name << ");" << std::endl;
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl;
}
}
}
@@ -435,6 +435,10 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
if(outputs.size() == 1){
os << " return " << outputs[0] << ";" << std::endl;
return;
}
os << " return {";
for(size_t i = 0; i < outputs.size(); i++){
if(i > 0)
@@ -467,6 +471,10 @@ std::tuple<std::string,
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/CUDAHooks.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
namespace rt = triton::runtime;
namespace drv = triton::driver;