[GENERAL] Merged v1.0alpha into master. New features:

- A100 support via mma.16816 (the m16n8k16 tensor-core MMA instruction)
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Kernels that spill registers now raise a compilation error
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {
void register_fn(const map_key_t& key,
const std::string& src,
const rt::function::options_space_t& opt) {
const rt::options_space_t& opt) {
if(id_fn_map.find(key) == id_fn_map.end())
id_fn_map[key].reset(new rt::function(src, opt, ""));
}
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
}
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
return id_fn_map[key]->ptx(&device, opt);
std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
triton::driver::cu_device device(key.second, false);
return id_fn_map[key]->get_asm(mode, &device, opt);
}
void cleanup() {
@@ -63,7 +63,7 @@ size_t make_op_id() {
/* Function signature */
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
const runtime::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
}
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
const runtime::function::options_space_t& opt) {
const runtime::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
return ret;
}
typedef triton::runtime::function::options_t options_t;
typedef triton::runtime::function::options_space_t options_space_t;
typedef triton::runtime::options_t options_t;
typedef triton::runtime::options_space_t options_space_t;
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
@@ -112,6 +112,10 @@ PYBIND11_MODULE(libtriton, m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("get_fn_signature", &get_fn_signature);
m.def("get_fn_ptx", &get_fn_ptx);
m.def("get_fn_asm", &get_fn_asm);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);

View File

@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
}
}
torch::Tensor raw_like(torch::Tensor x){
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
auto shape = x.sizes();
CUdeviceptr data;
triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
ret.copy_(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::raw_like", &raw_like)
.op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);