diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index ef735b670..a7f25168b 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -349,11 +349,17 @@ std::string function::preheader() {
 #define F32_INFINITY bitcast(0x7F800000)
 #define F16_INFINITY bitcast((int16)0x7C00)
+
+#define PASTER(a, b, _) a ## _ ## b
+#define EVALUATOR(a, b, _) PASTER(a, b, _)
+#define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
+extern void atomic_add_64(float*[64], float[64], bool[64]);
+extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
+extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
+
 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
 extern float f32_atomic_add(float*, float);
-extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
-extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
 extern int get_program_id(int);
 extern int get_num_programs(int);
 extern float sqrtf(float);
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 2ed4aba46..c97cf5fc0 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -21,7 +21,7 @@ std::map> id_fn_map;
 std::map fp64scalar_map;
 std::map i64scalar_map;
 
-/* Grid map */
+/* Grid utilities */
 void register_grid(const map_key_t& key,
                    const rt::function::grid_fn_ty& grid_fn) {
@@ -32,7 +32,7 @@ void delete_grid(const map_key_t& key) {
   id_grid_map.erase(key);
 }
 
-/* Function map */
+/* Function utilities */
 void register_fn(const map_key_t& key,
                  const std::string& src,
@@ -60,22 +60,7 @@ size_t make_op_id() {
   return id_fn_map.size();
 }
-
-/* TF scalar wrapper */
-size_t make_scalar_id() {
-  size_t ret = i64scalar_map.size();
-  i64scalar_map[ret] = int64_t();
-  return ret;
-}
-
-bool has_scalar(size_t id) {
-  return i64scalar_map.find(id) != i64scalar_map.end();
-}
-
-int64_t retrieve_scalar(size_t id) {
-  return i64scalar_map.at(id);
-}
-
+/* Function signature */
 void make_module(const std::string& src, ir::module* ir,
                  const runtime::function::options_space_t& opt) {
   std::string copy = triton::runtime::function::preheader() + src;
@@ -93,7 +78,6 @@ void make_module(const std::string& src, ir::module* ir,
   gen.Gen(ir);
 }
 
-/* Function signature */
 std::vector get_fn_signature(const std::string& src,
                              const runtime::function::options_space_t& opt) {
   // triton-ir code-gen
@@ -146,8 +130,6 @@ PYBIND11_MODULE(libtriton, m) {
   m.def("register_cst", &register_cst);
   m.def("delete_fn", &delete_fn);
   m.def("make_op_id", &make_op_id);
-  m.def("make_scalar_id", &make_scalar_id);
-  m.def("retrieve_scalar", &retrieve_scalar);
   m.def("cleanup", &cleanup);
   ;
 }
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 75fd0e66a..5e698b096 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -18,6 +18,15 @@ std::shared_ptr host_device;
 std::shared_ptr host_context;
 std::shared_ptr host_stream;
 
+int64_t cdiv_sum(torch::Tensor& x, int64_t div){
+  TORCH_CHECK(x.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::CPU), "Argument of cdiv_sum must be a CPU tensor");
+  auto _x = x.accessor();
+  int64_t ret = 0;
+  for(size_t i = 0; i < x.size(0); i++)
+    ret += (_x[i] + div - 1) / div;
+  return ret;
+}
+
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
   if(dev_id == -1){
     if(!host_stream){
@@ -36,4 +45,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
 }
 
-static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
+static auto registry = torch::RegisterOperators()
+  .op("triton::launch_kernel", &launch_kernel)
+  .op("triton::cdiv_sum", &cdiv_sum);
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 1cf354a3b..efd683878 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -47,6 +47,9 @@ def th_to_triton(obj):
 def cdiv(a, b):
     return (a + b - 1) // b
 
+def cdiv_sum(a, b):
+    return torch.ops.triton.cdiv_sum(a, b)
+
 class kernel:
 
     def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
@@ -73,9 +76,6 @@ class kernel:
         if device not in self.registered:
             self.registered.add(device)
             libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
-        # launch options
-        bench = kwargs['bench'] if 'bench' in kwargs else 0
-        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
         # launch grid
         if 'grid' not in kwargs:
            raise RuntimeError('Must provide grid for kernel launch')
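
A note on the two additions above. The `PASTER`/`EVALUATOR` macros in the preheader paste the tile shape into a function name, so `atomic_add(64, 64)` in Triton-C source resolves to the `atomic_add_64x64` builtin declared next to it (and `atomic_add(128, 128)` to `atomic_add_128x128`). The new `triton::cdiv_sum` operator sums `ceil(x[i] / div)` over a 1-D CPU tensor, which is convenient when a launch grid is the total number of tiles across variable-sized rows; it mirrors the existing Python `cdiv` helper in `kernel.py`. Below is a minimal usage sketch of the Python wrapper, assuming the extension is loaded and that an `int32` CPU tensor matches the accessor type expected by the C++ op; the tensor name and tile size are illustrative, not part of the patch.

```python
import torch
from triton.kernel import cdiv_sum   # wrapper added in python/triton/kernel.py

# Hypothetical per-row sizes; cdiv_sum requires a CPU tensor (enforced by the TORCH_CHECK above).
lengths = torch.tensor([100, 37, 256], dtype=torch.int32)
TILE = 64

# Sums ceil(lengths[i] / TILE): ceil(100/64) + ceil(37/64) + ceil(256/64) = 2 + 1 + 4 = 7
num_tiles = cdiv_sum(lengths, TILE)
```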