[PYTHON] Cleaning C++ bindings
This commit is contained in:
committed by
Philippe Tillet
parent
34f1d5e565
commit
02a6e81b88
@@ -349,11 +349,17 @@ std::string function::preheader() {
|
|||||||
#define F32_INFINITY bitcast<float>(0x7F800000)
|
#define F32_INFINITY bitcast<float>(0x7F800000)
|
||||||
#define F16_INFINITY bitcast<half>((int16)0x7C00)
|
#define F16_INFINITY bitcast<half>((int16)0x7C00)
|
||||||
|
|
||||||
|
|
||||||
|
#define PASTER(a, b, _) a ## _ ## b
|
||||||
|
#define EVALUATOR(a, b, _) PASTER(a, b, _)
|
||||||
|
#define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
|
||||||
|
extern void atomic_add_64(float*[64], float[64], bool[64]);
|
||||||
|
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
|
||||||
|
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
|
||||||
|
|
||||||
extern int atomic_cas(int*, int, int);
|
extern int atomic_cas(int*, int, int);
|
||||||
extern int atomic_xchg(int*, int);
|
extern int atomic_xchg(int*, int);
|
||||||
extern float f32_atomic_add(float*, float);
|
extern float f32_atomic_add(float*, float);
|
||||||
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
|
|
||||||
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
|
|
||||||
extern int get_program_id(int);
|
extern int get_program_id(int);
|
||||||
extern int get_num_programs(int);
|
extern int get_num_programs(int);
|
||||||
extern float sqrtf(float);
|
extern float sqrtf(float);
|
||||||
|
@@ -21,7 +21,7 @@ std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
|||||||
std::map<size_t, double> fp64scalar_map;
|
std::map<size_t, double> fp64scalar_map;
|
||||||
std::map<size_t, int64_t> i64scalar_map;
|
std::map<size_t, int64_t> i64scalar_map;
|
||||||
|
|
||||||
/* Grid map */
|
/* Grid utilities */
|
||||||
|
|
||||||
void register_grid(const map_key_t& key,
|
void register_grid(const map_key_t& key,
|
||||||
const rt::function::grid_fn_ty& grid_fn) {
|
const rt::function::grid_fn_ty& grid_fn) {
|
||||||
@@ -32,7 +32,7 @@ void delete_grid(const map_key_t& key) {
|
|||||||
id_grid_map.erase(key);
|
id_grid_map.erase(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Function map */
|
/* Function utilities */
|
||||||
|
|
||||||
void register_fn(const map_key_t& key,
|
void register_fn(const map_key_t& key,
|
||||||
const std::string& src,
|
const std::string& src,
|
||||||
@@ -60,22 +60,7 @@ size_t make_op_id() {
|
|||||||
return id_fn_map.size();
|
return id_fn_map.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Function signature */
|
||||||
/* TF scalar wrapper */
|
|
||||||
size_t make_scalar_id() {
|
|
||||||
size_t ret = i64scalar_map.size();
|
|
||||||
i64scalar_map[ret] = int64_t();
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_scalar(size_t id) {
|
|
||||||
return i64scalar_map.find(id) != i64scalar_map.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t retrieve_scalar(size_t id) {
|
|
||||||
return i64scalar_map.at(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
void make_module(const std::string& src, ir::module* ir,
|
void make_module(const std::string& src, ir::module* ir,
|
||||||
const runtime::function::options_space_t& opt) {
|
const runtime::function::options_space_t& opt) {
|
||||||
std::string copy = triton::runtime::function::preheader() + src;
|
std::string copy = triton::runtime::function::preheader() + src;
|
||||||
@@ -93,7 +78,6 @@ void make_module(const std::string& src, ir::module* ir,
|
|||||||
gen.Gen(ir);
|
gen.Gen(ir);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Function signature */
|
|
||||||
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
||||||
const runtime::function::options_space_t& opt) {
|
const runtime::function::options_space_t& opt) {
|
||||||
// triton-ir code-gen
|
// triton-ir code-gen
|
||||||
@@ -146,8 +130,6 @@ PYBIND11_MODULE(libtriton, m) {
|
|||||||
m.def("register_cst", ®ister_cst);
|
m.def("register_cst", ®ister_cst);
|
||||||
m.def("delete_fn", &delete_fn);
|
m.def("delete_fn", &delete_fn);
|
||||||
m.def("make_op_id", &make_op_id);
|
m.def("make_op_id", &make_op_id);
|
||||||
m.def("make_scalar_id", &make_scalar_id);
|
|
||||||
m.def("retrieve_scalar", &retrieve_scalar);
|
|
||||||
m.def("cleanup", &cleanup);
|
m.def("cleanup", &cleanup);
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
@@ -18,6 +18,15 @@ std::shared_ptr<drv::device> host_device;
|
|||||||
std::shared_ptr<drv::context> host_context;
|
std::shared_ptr<drv::context> host_context;
|
||||||
std::shared_ptr<drv::stream> host_stream;
|
std::shared_ptr<drv::stream> host_stream;
|
||||||
|
|
||||||
|
int64_t cdiv_sum(torch::Tensor& x, int64_t div){
|
||||||
|
TORCH_CHECK(x.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::CPU), "Argument of cdiv_sum must be a CPU tensor")
|
||||||
|
auto _x = x.accessor<int, 1>();
|
||||||
|
int64_t ret = 0;
|
||||||
|
for(size_t i = 0; i < x.size(0); i++)
|
||||||
|
ret += (_x[i] + div - 1) / div;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||||
if(dev_id == -1){
|
if(dev_id == -1){
|
||||||
if(!host_stream){
|
if(!host_stream){
|
||||||
@@ -36,4 +45,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
|
static auto registry = torch::RegisterOperators()
|
||||||
|
.op("triton::launch_kernel", &launch_kernel)
|
||||||
|
.op("triton::cdiv_sum", &cdiv_sum);
|
||||||
|
@@ -47,6 +47,9 @@ def th_to_triton(obj):
|
|||||||
def cdiv(a, b):
|
def cdiv(a, b):
|
||||||
return (a + b - 1) // b
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
def cdiv_sum(a, b):
|
||||||
|
return torch.ops.triton.cdiv_sum(a, b)
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
|
|
||||||
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
|
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
|
||||||
@@ -73,9 +76,6 @@ class kernel:
|
|||||||
if device not in self.registered:
|
if device not in self.registered:
|
||||||
self.registered.add(device)
|
self.registered.add(device)
|
||||||
libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
|
libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
|
||||||
# launch options
|
|
||||||
bench = kwargs['bench'] if 'bench' in kwargs else 0
|
|
||||||
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
|
||||||
# launch grid
|
# launch grid
|
||||||
if 'grid' not in kwargs:
|
if 'grid' not in kwargs:
|
||||||
raise RuntimeError('Must provide grid for kernel launch')
|
raise RuntimeError('Must provide grid for kernel launch')
|
||||||
|
Reference in New Issue
Block a user