[python] upgraded pybind11 ; forcing torch tensors to be contiguous()
This commit is contained in:
@@ -35,7 +35,6 @@ void register_grid(size_t id,
|
||||
|
||||
/// Drops the entry stored under `id` in `id_grid_map`.
///
/// @param id  Key of the grid entry to erase.
///
/// No diagnostic output is produced: this is library code loaded into a host
/// process, so it must not write to the host's stdout (and must not pay for a
/// stream flush) every time an entry is released.
void delete_grid(size_t id) {
  id_grid_map.erase(id);
}
|
||||
|
||||
void register_fn(size_t id,
|
||||
@@ -46,7 +45,6 @@ void register_fn(size_t id,
|
||||
|
||||
/// Drops the entry stored under `id` in `id_fn_map`.
///
/// @param id  Key of the function entry to erase.
///
/// No diagnostic output is produced: this is library code loaded into a host
/// process, so it must not write to the host's stdout (and must not pay for a
/// stream flush) every time an entry is released.
void delete_fn(size_t id) {
  id_fn_map.erase(id);
}
|
||||
|
||||
void cleanup() {
|
||||
@@ -415,8 +413,10 @@ void gen_torch_make_handles(std::ostream &os,
|
||||
ir::type* ty = arg->get_type();
|
||||
if(!ty->is_pointer_ty())
|
||||
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
|
||||
else
|
||||
else{
|
||||
os << " CHECK_INPUT(" << name << ");" << std::endl;
|
||||
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,6 +435,10 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
|
||||
}
|
||||
|
||||
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
if(outputs.size() == 1){
|
||||
os << " return " << outputs[0] << ";" << std::endl;
|
||||
return;
|
||||
}
|
||||
os << " return {";
|
||||
for(size_t i = 0; i < outputs.size(); i++){
|
||||
if(i > 0)
|
||||
@@ -467,6 +471,10 @@ std::tuple<std::string,
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "ATen/cuda/detail/CUDAHooks.h"
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
|
Reference in New Issue
Block a user