[PYTHON][TESTS][DOC] Various improvements of the API and code quality:

* Simplified `triton.kernel` API to achieve lower latency (a usage sketch
follows the commit stats below):
  > `.data_ptr()` must now be passed as a kernel argument; there is no more
implicit conversion from `torch.Tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')`
becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no
longer inferred from `torch.Tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
Philippe Tillet
2021-01-29 17:27:16 -05:00
parent a5a477c36b
commit 269ebc12e5
63 changed files with 2255 additions and 3883 deletions
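
A minimal sketch of the new triton.kernel calling convention described in the
first bullet above. The constructor keywords and helper names here are
illustrative assumptions, not the exact 1.0alpha signature:

# Illustrative sketch only: the exact triton.kernel constructor arguments
# in 1.0alpha may differ from the keywords assumed here.
import torch
import triton

src = """
__global__ void add(float* z, float* x, float* y, int N) {
  // ... Triton-C kernel body ...
}
"""

device = torch.device('cuda:0')
# the device is now passed explicitly instead of being inferred from tensors
kernel = triton.kernel(src, device=device, defines={'BLOCK': 128})

N = 1024
x, y, z = [torch.randn(N, device=device) for _ in range(3)]
# raw .data_ptr() values are passed instead of torch.Tensor objects, and
# compilation options are read as attributes (opt.BLOCK, not opt.d('BLOCK'))
grid = lambda opt: ((N + opt.BLOCK - 1) // opt.BLOCK,)
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)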


@@ -13,15 +13,19 @@
#include "triton/ir/function.h"
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
-std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
-CUstream torch_get_cuda_stream(int64_t dev_id);
-CUdevice torch_get_cuda_device(int64_t dev_id);
+std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
+std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
+std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
+extern CUstream torch_get_cuda_stream(int64_t dev_id);
+extern CUdevice torch_get_cuda_device(int64_t dev_id);
/* Grid utilities */
@@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) {
/* Function utilities */
-void register_fn(const map_key_t& key,
+void register_fn(int op_id,
+                 int dev_id,
                  const std::string& src,
                  const rt::options_space_t& opt) {
-  if(id_fn_map.find(key) == id_fn_map.end())
-    id_fn_map[key].reset(new rt::function(src, opt, ""));
+  if(tt_devices.find(dev_id) == tt_devices.end()) {
+    driver::device* device;
+    driver::stream* stream;
+    if(dev_id >= 0){
+      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
+      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
+    }
+    else{
+      device = new triton::driver::host_device();
+      stream = new triton::driver::host_stream();
+    }
+    tt_devices[dev_id].reset(device);
+    tt_streams[dev_id].reset(stream);
+  }
+  if(id_fn_map.find(op_id) == id_fn_map.end()){
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+  }
+  for(const auto& k: id_fn_map[op_id]->get_kernels()){
+    const rt::options_t* opt = &k.first;
+    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
+    for(auto x: opt->defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    opt_cache_[&k.second->opt] = obj;
+  }
}
-void delete_fn(const map_key_t& key) {
-  id_fn_map.erase(key);
+void delete_fn(int op_id) {
+  id_fn_map.erase(op_id);
}
-std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
-  triton::driver::cu_device device(key.second, false);
-  return id_fn_map[key]->get_asm(mode, &device, opt);
-}
void cleanup() {
  id_grid_map.clear();
  id_fn_map.clear();
+  opt_cache_.clear();
}
size_t make_op_id() {
  return id_fn_map.size();
}
/* Function signature */
-void make_module(const std::string& src, ir::module* ir,
-                 const runtime::options_space_t& opt) {
-  std::string copy = triton::runtime::function::preheader() + src;
-  // pre-process
-  TokenSequence tokens;
-  Preprocessor cpp(&copy, true);
-  for(auto it: opt.defines){
-    cpp.AddMacro(it.first, &it.second[0]);
-  }
-  cpp.Process(tokens);
-  // parse
-  Parser parser(tokens);
-  parser.Parse();
-  Generator gen(&parser);
-  gen.Gen(ir);
+std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
+  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
+}
-std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  // extract signature
-  std::vector<rt::arg_type> ret;
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++)
-    ret.push_back(rt::convert(ty->get_param_ty(i)));
-  return ret;
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
+  // for(size_t n = 0; n < constant_names.size(); n++){
+  //   const torch::Tensor& x = constant_vals[n];
+  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
-typedef triton::runtime::options_t options_t;
-typedef triton::runtime::options_space_t options_space_t;
+pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  auto wrapper = [&grid](const rt::options_t& opt){
+    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
+    for(auto x: opt.defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    return grid(*obj.cast<rt::options_t*>());
+  };
+  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
+  return opt_cache_.at(&kernel->opt);
+}
+void init_superblocking(pybind11::module &m);
+void init_launch(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";
  // bindings for triton classes
  pybind11::enum_<rt::arg_type>(m, "arg_type")
-      .value("int1", rt::INT1_T)
-      .value("int8", rt::INT8_T)
-      .value("int16", rt::INT16_T)
-      .value("int32", rt::INT32_T)
-      .value("int64", rt::INT64_T)
-      .value("half", rt::HALF_T)
-      .value("float", rt::FLOAT_T)
+      .value("int1" , rt::INT1_T)
+      .value("int8" , rt::INT8_T)
+      .value("int16" , rt::INT16_T)
+      .value("int32" , rt::INT32_T)
+      .value("int64" , rt::INT64_T)
+      .value("half" , rt::HALF_T)
+      .value("float" , rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);
  pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
-      .value("ptx", rt::ASM_NV_PTX)
+      .value("ptx" , rt::ASM_NV_PTX)
      .value("sass", rt::ASM_NV_SASS);
-  pybind11::class_<options_t>(m, "options")
-      .def(pybind11::init<>())
-      .def("d", &options_t::D<int>)
-      .def_readwrite("num_warps", &options_t::num_warps)
-      .def_readwrite("defines" , &options_t::defines);
+  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
+      .def_readwrite("num_warps", &rt::options_t::num_warps)
+      .def_readwrite("defines" , &rt::options_t::defines);
-  pybind11::class_<options_space_t>(m, "options_space")
+  pybind11::class_<rt::options_space_t>(m, "options_space")
      .def(pybind11::init<>())
-      .def_readwrite("defines", &options_space_t::defines)
-      .def_readwrite("num_warps", &options_space_t::num_warps);
+      .def_readwrite("num_warps", &rt::options_space_t::num_warps)
+      .def_readwrite("defines" , &rt::options_space_t::defines);
  // hooks into triton constructs since frameworks may not use pybind11
  m.def("get_fn_signature", &get_fn_signature);
-  m.def("get_fn_asm", &get_fn_asm);
+  // m.def("get_fn_asm", &get_fn_asm);
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("cleanup", &cleanup);
-  ;
+  m.def("autotune", &autotune, pybind11::return_value_policy::reference);
+  m.def("launch_kernel", &launch_kernel);
+  init_launch(m);
+  init_superblocking(m);
}
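
The launch_kernel binding above receives all kernel arguments as a single byte
string (the struct.pack trick credited to Scott Gray in the launcher files
below). A sketch of how the Python side might pack them, assuming a kernel
taking two pointers and an int, and an op_id previously obtained from
make_op_id()/register_fn():

# Sketch under the stated assumptions; pointer values are placeholders and the
# pack format must match the registered kernel's signature exactly.
import struct
import libtriton  # the pybind11 module defined in this commit

op_id, dev_id = 0, 0  # assumed to come from make_op_id()/register_fn()
x_ptr, y_ptr, N = 0x7f0000000000, 0x7f0000100000, 1024
# pack two 64-bit pointers and one 32-bit int into a contiguous buffer;
# the C++ side reinterprets these bytes as the kernel's argument values
args = struct.pack('QQi', x_ptr, y_ptr, N)
libtriton.launch_kernel(op_id, dev_id, args, 8, 1, 1)  # grid = (8, 1, 1)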


@@ -1,95 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
void init_host_stream() {
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(host_context->backend()));
}
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
if(dev_id == -1){
init_host_stream();
host_stream->synchronize();
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
stream.synchronize();
}
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){
init_host_stream();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
}
else{
C10_CUDA_CHECK(cudaSetDevice(dev_id));
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
}
}
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);


@@ -0,0 +1,83 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
int64_t cdiv(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
int64_t largest_pow2_divisor(int64_t a){
if(a % 8 == 0) return 8;
if(a % 4 == 0) return 4;
if(a % 2 == 0) return 2;
return 1;
}
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdevice torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void cuda_set_device(int64_t dev_id) {
if(dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
void init_launch(pybind11::module &m) {
m.def("cuda_set_device", &cuda_set_device);
m.def("cuda_empty_like", &cuda_empty_like);
m.def("largest_pow2_divisor", &largest_pow2_divisor);
m.def("cdiv", &cdiv);
m.def("cdiv_sum", &cdiv_sum);
m.def("synchronize", &synchronize);
}


@@ -0,0 +1,117 @@
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);
auto _tmp = tmp.accessor <int, 3>();
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor <int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(size_t h = 0; h < H; h++){
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
for(size_t m = 0; m < M; m++){
for(size_t n = 0; n < N; n++){
int v = _layout[h][m][n];
if(v == 0)
continue;
int n_left= ii_left[max_width-1];
int m_top = ii_top [max_width-1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for(int nn = n_left + 1; nn < n; nn++)
if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
width = 1;
_tmp[h][m][n] = width;
// update n_left ring buffer
for(int k = 0; k < max_width-1; k++)
ii_left[k] = ii_left[k+1];
ii_left[max_width-1] = n;
// update ii_top ring buffer
for(int k = 0; k < max_width-1; k++)
ii_top[k][n] = ii_top[k+1][n];
ii_top[max_width-1][n] = m;
// block is too small -- skip
if(width != max_width)
continue;
// retained blocks are set to zeros
for(size_t km = 0; km < max_width; km++)
for(size_t kn = 0; kn < max_width; kn++)
{
int mm = ii_top[km][n];
int nn = ii_left[kn];
if(mm < 0 || nn < 0)
continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for(size_t h = 0; h < H; h++)
if(current[h] > 0)
to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if(!to_cat.empty())
ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
}
ret_t superblock(torch::Tensor layout, int start_width) {
ret_t ret;
// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
int64_t H = layout.size(0);
int64_t M = layout.size(1);
int64_t N = layout.size(2);
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor<int, 3>();
for(int64_t h = 0; h < H; h++)
for(int64_t m = 0; m < M; m++)
for(int64_t n = 0; n < N; n++){
if(_layout[h][m][n] == 0)
continue;
_idx[h][m][n] = current++;
}
// scratch memory
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
for(int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}