[PYTHON] Cleaning C++ bindings

Philippe Tillet
2020-11-02 15:06:08 -05:00
committed by Philippe Tillet
parent 34f1d5e565
commit 02a6e81b88
4 changed files with 26 additions and 27 deletions

View File

@@ -349,11 +349,17 @@ std::string function::preheader() {
 #define F32_INFINITY bitcast<float>(0x7F800000)
 #define F16_INFINITY bitcast<half>((int16)0x7C00)
+#define PASTER(a, b, _) a ## _ ## b
+#define EVALUATOR(a, b, _) PASTER(a, b, _)
+#define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
+extern void atomic_add_64(float*[64], float[64], bool[64]);
+extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
+extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
 extern float f32_atomic_add(float*, float);
-extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
-extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
 extern int get_program_id(int);
 extern int get_num_programs(int);
 extern float sqrtf(float);

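Note on the macros added above: the PASTER/EVALUATOR pair is the standard two-level preprocessor indirection that forces arguments to expand before ## pastes them, so a pair of tile shapes selects the matching atomic_add declaration. A minimal sketch of the expansion, assuming only standard C preprocessor semantics:

// atomic_add(64, 64)
//   -> EVALUATOR(atomic_add, EVALUATOR(64, 64, x), _)
//   -> EVALUATOR(atomic_add, 64x64, _)   // inner PASTER: 64 ## x ## 64
//   -> PASTER(atomic_add, 64x64, _)
//   -> atomic_add_64x64                  // atomic_add ## _ ## 64x64
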
View File

@@ -21,7 +21,7 @@ std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
 std::map<size_t, double> fp64scalar_map;
 std::map<size_t, int64_t> i64scalar_map;
-/* Grid map */
+/* Grid utilities */
 void register_grid(const map_key_t& key,
                    const rt::function::grid_fn_ty& grid_fn) {
@@ -32,7 +32,7 @@ void delete_grid(const map_key_t& key) {
   id_grid_map.erase(key);
 }
-/* Function map */
+/* Function utilities */
 void register_fn(const map_key_t& key,
                  const std::string& src,
@@ -60,22 +60,7 @@ size_t make_op_id() {
   return id_fn_map.size();
 }
-/* TF scalar wrapper */
-size_t make_scalar_id() {
-  size_t ret = i64scalar_map.size();
-  i64scalar_map[ret] = int64_t();
-  return ret;
-}
-bool has_scalar(size_t id) {
-  return i64scalar_map.find(id) != i64scalar_map.end();
-}
-int64_t retrieve_scalar(size_t id) {
-  return i64scalar_map.at(id);
-}
 /* Function signature */
 void make_module(const std::string& src, ir::module* ir,
                  const runtime::function::options_space_t& opt) {
   std::string copy = triton::runtime::function::preheader() + src;
@@ -93,7 +78,6 @@ void make_module(const std::string& src, ir::module* ir,
   gen.Gen(ir);
 }
-/* Function signature */
 std::vector<rt::arg_type> get_fn_signature(const std::string& src,
                                            const runtime::function::options_space_t& opt) {
   // triton-ir code-gen
@@ -146,8 +130,6 @@ PYBIND11_MODULE(libtriton, m) {
   m.def("register_cst", &register_cst);
   m.def("delete_fn", &delete_fn);
   m.def("make_op_id", &make_op_id);
-  m.def("make_scalar_id", &make_scalar_id);
-  m.def("retrieve_scalar", &retrieve_scalar);
   m.def("cleanup", &cleanup);
   ;
 }

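For reference, the bindings above follow the standard pybind11 m.def pattern; a reduced, self-contained sketch of the same registration shape (stub bodies stand in for the real implementations in this file):

#include <pybind11/pybind11.h>

size_t make_op_id() { return 0; }  // stub; the real one returns id_fn_map.size()
void cleanup() { }                 // stub

PYBIND11_MODULE(libtriton, m) {
  m.def("make_op_id", &make_op_id);  // callable from Python as libtriton.make_op_id()
  m.def("cleanup", &cleanup);
}
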
View File

@@ -18,6 +18,15 @@ std::shared_ptr<drv::device> host_device;
 std::shared_ptr<drv::context> host_context;
 std::shared_ptr<drv::stream> host_stream;
+int64_t cdiv_sum(torch::Tensor& x, int64_t div){
+  TORCH_CHECK(x.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::CPU), "Argument of cdiv_sum must be a CPU tensor")
+  auto _x = x.accessor<int, 1>();
+  int64_t ret = 0;
+  for(size_t i = 0; i < x.size(0); i++)
+    ret += (_x[i] + div - 1) / div;
+  return ret;
+}
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
   if(dev_id == -1){
     if(!host_stream){
@@ -36,4 +45,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
 }
-static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
+static auto registry = torch::RegisterOperators()
+  .op("triton::launch_kernel", &launch_kernel)
+  .op("triton::cdiv_sum", &cdiv_sum);

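The new op sums per-element ceiling divisions, e.g. to count how many fixed-size tiles a batch of rows needs. A hypothetical standalone check of the same arithmetic (the accessor<int, 1> call above means x must be a 1-D CPU int32 tensor):

#include <torch/torch.h>
#include <cstdint>
#include <iostream>

int main() {
  torch::Tensor x = torch::tensor({100, 64, 1}, torch::kInt32);
  auto acc = x.accessor<int, 1>();
  int64_t div = 64, ret = 0;
  for (int64_t i = 0; i < x.size(0); i++)
    ret += (acc[i] + div - 1) / div;  // ceil(x[i] / div)
  std::cout << ret << "\n";          // 2 + 1 + 1 = 4
  return 0;
}
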
View File

@@ -47,6 +47,9 @@ def th_to_triton(obj):
 def cdiv(a, b):
   return (a + b - 1) // b
+def cdiv_sum(a, b):
+  return torch.ops.triton.cdiv_sum(a, b)
 class kernel:
   def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
@@ -73,9 +76,6 @@ class kernel:
     if device not in self.registered:
       self.registered.add(device)
       libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
-    # launch options
-    bench = kwargs['bench'] if 'bench' in kwargs else 0
-    bench_id = libtriton.make_scalar_id() if bench > 0 else -1
     # launch grid
     if 'grid' not in kwargs:
       raise RuntimeError('Must provide grid for kernel launch')
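
Taken together, the Python side now exposes ceiling division both as a local helper and through the new torch op, and every launch must pass an explicit grid. A hypothetical usage sketch (the import path is assumed for illustration; the op requires the libtriton extension to be loaded and a 1-D CPU int32 tensor):

import torch
from triton.kernel import cdiv, cdiv_sum  # import path assumed for this sketch

sizes = torch.tensor([100, 64, 1], dtype=torch.int32)  # 1-D CPU int32, as cdiv_sum requires
assert cdiv(100, 64) == 2                 # ceiling division
assert cdiv_sum(sizes, 64) == 4           # ceil(100/64) + ceil(64/64) + ceil(1/64)
# Launching a kernel without grid= now raises
# RuntimeError('Must provide grid for kernel launch').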