[GENERAL] Merged v1.0alpha into master. Added features are:
- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding (see the sketch below)
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
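For readers unfamiliar with the swizzling trick: instead of padding each row of a shared-memory tile with an extra element, the column of every access is permuted as a function of the row, so the threads of a warp land in distinct banks. The snippet below is a minimal illustrative sketch of the idea only; the index formula and tile shape are assumptions, not the exact mapping this commit generates.

```cpp
#include <cstdio>

// Illustrative XOR swizzle: permute the column by the row so that reading a
// column of a 32x32 tile touches 32 distinct banks (assuming 32 banks and
// 4-byte elements). The exact formula Triton emits may differ.
constexpr int kCols = 32;

int swizzled_index(int row, int col) {
  return row * kCols + ((col ^ row) % kCols);  // no padding column needed
}

int main() {
  // Column 0, read across rows, maps to a different bank on every row.
  for (int row = 0; row < 4; ++row)
    printf("row %2d -> element %3d (bank %2d)\n",
           row, swizzled_index(row, 0), swizzled_index(row, 0) % 32);
  return 0;
}
```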
@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {
 
 void register_fn(const map_key_t& key,
                  const std::string& src,
-                 const rt::function::options_space_t& opt) {
+                 const rt::options_space_t& opt) {
   if(id_fn_map.find(key) == id_fn_map.end())
     id_fn_map[key].reset(new rt::function(src, opt, ""));
 }
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
   id_fn_map.erase(key);
 }
 
-std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
-  triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
-  return id_fn_map[key]->ptx(&device, opt);
+std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
+  triton::driver::cu_device device(key.second, false);
+  return id_fn_map[key]->get_asm(mode, &device, opt);
 }
 
 void cleanup() {
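The hunk above folds the PTX-only accessor get_fn_ptx into a more general get_fn_asm, which selects the output format through the new rt::asm_mode_t. Exposing SASS as well as PTX is presumably what enables the spill diagnostics mentioned in the commit message, since spills show up in the hardware ISA. The sketch below is a stand-in for the dispatch this enables; the enum values mirror the bindings, but the body is mine, not Triton's implementation:

```cpp
#include <stdexcept>
#include <string>

enum asm_mode_t { ASM_NV_PTX, ASM_NV_SASS };

// Stand-in for function::get_asm: one entry point, and the mode picks whether
// the caller gets PTX (the virtual ISA) or SASS (what the hardware executes).
std::string get_asm(asm_mode_t mode, const std::string& ptx, const std::string& sass) {
  switch (mode) {
    case ASM_NV_PTX:  return ptx;
    case ASM_NV_SASS: return sass;
  }
  throw std::invalid_argument("unknown asm_mode_t");
}
```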
@@ -63,7 +63,7 @@ size_t make_op_id() {
 
 /* Function signature */
 void make_module(const std::string& src, ir::module* ir,
-                 const runtime::function::options_space_t& opt) {
+                 const runtime::options_space_t& opt) {
   std::string copy = triton::runtime::function::preheader() + src;
   // pre-process
   TokenSequence tokens;
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
 }
 
 std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::function::options_space_t& opt) {
+                                           const runtime::options_space_t& opt) {
   // triton-ir code-gen
   ir::context ctx;
   auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
   return ret;
 }
 
-typedef triton::runtime::function::options_t options_t;
-typedef triton::runtime::function::options_space_t options_space_t;
+typedef triton::runtime::options_t options_t;
+typedef triton::runtime::options_space_t options_space_t;
 
 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
@@ -112,6 +112,10 @@ PYBIND11_MODULE(libtriton, m) {
       .value("float", rt::FLOAT_T)
       .value("double", rt::DOUBLE_T)
       .value("buffer", rt::BUFFER_T);
 
+  pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
+      .value("ptx", rt::ASM_NV_PTX)
+      .value("sass", rt::ASM_NV_SASS);
+
   pybind11::class_<options_t>(m, "options")
       .def(pybind11::init<>())
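The new asm_mode enum is exposed with pybind11's standard enum_ wrapper, as added in the hunk above. A self-contained sketch of the same pattern follows; the module name "example" is illustrative:

```cpp
#include <pybind11/pybind11.h>

enum asm_mode_t { ASM_NV_PTX, ASM_NV_SASS };

PYBIND11_MODULE(example, m) {
  // After import, Python sees example.asm_mode.ptx and example.asm_mode.sass,
  // which can be passed straight back into bound functions taking asm_mode_t.
  pybind11::enum_<asm_mode_t>(m, "asm_mode")
      .value("ptx", ASM_NV_PTX)
      .value("sass", ASM_NV_SASS);
}
```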
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
 
   // hooks into triton constructs since frameworks may not use pybind11
   m.def("get_fn_signature", &get_fn_signature);
-  m.def("get_fn_ptx", &get_fn_ptx);
+  m.def("get_fn_asm", &get_fn_asm);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
   }
 }
 
-torch::Tensor raw_like(torch::Tensor x){
+torch::Tensor cuda_empty_like(torch::Tensor x){
   if(x.nbytes() == 0)
     return torch::empty_like(x);
-  C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
-  auto shape = x.sizes();
-  CUdeviceptr data;
-  triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
-  auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
-  auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
-  ret.copy_(x);
+  void* data;
+  cudaMalloc(&data, x.nbytes());
+  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
   return ret;
 }
 
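raw_like allocated through Triton's own driver dispatch and then copied x into the new buffer; cuda_empty_like drops the copy (hence the empty_like name), allocates with the CUDA runtime, and preserves strides. The key pattern is handing ownership of a raw CUDA allocation to a tensor through from_blob's deleter. Below is a standalone sketch of that pattern; the helper name and error handling are mine, not this commit's:

```cpp
#include <stdexcept>
#include <torch/torch.h>
#include <cuda_runtime.h>

// Wrap a fresh CUDA allocation in a tensor; cudaFree runs when the last
// reference to the tensor goes away. Contents are left uninitialized.
torch::Tensor wrap_cuda_buffer(torch::IntArrayRef sizes, const torch::TensorOptions& opts) {
  int64_t numel = 1;
  for (int64_t s : sizes) numel *= s;
  int64_t nbytes = numel * static_cast<int64_t>(opts.dtype().itemsize());
  void* data = nullptr;
  if (cudaMalloc(&data, nbytes) != cudaSuccess)
    throw std::runtime_error("cudaMalloc failed");
  return torch::from_blob(data, sizes, [data](void*) { cudaFree(data); }, opts);
}
```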
@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
 
 static auto registry = torch::RegisterOperators()
   .op("triton::launch_kernel", &launch_kernel)
-  .op("triton::raw_like", &raw_like)
+  .op("triton::cuda_empty_like", &cuda_empty_like)
   .op("triton::cdiv_sum", &cdiv_sum)
   .op("triton::synchronize", &synchronize);