From 5ba5a7756145c75d77b7c09dd9e4574f1c26e4cc Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 24 Mar 2021 01:24:50 -0400 Subject: [PATCH] [BUILD] Remove compilation warnings --- include/triton/runtime/function.h | 8 +---- lib/codegen/selection/generator.cc | 2 +- lib/runtime/function.cc | 53 +++++++++--------------------- python/src/triton.cc | 7 ++-- python/triton/kernel.py | 11 ++----- python/triton/testing.py | 4 ++- 6 files changed, 26 insertions(+), 59 deletions(-) diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index ddfe78776..c33f9d0e1 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -92,12 +92,6 @@ void add_arg(std::stringstream& ss, T arg) { /* ------------------------- */ /* ------------------------- */ -enum asm_mode_t { - ASM_LLIR, - ASM_NV_PTX, - ASM_NV_SASS -}; - class kernel{ public: typedef std::vector grid_t; @@ -111,7 +105,7 @@ public: public: kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map &attrs = {}); void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const; - std::string get_asm(asm_mode_t mode); + std::string get_asm(const std::string &mode); public: const options_t opt; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8dd6864ed..c25495dd3 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -28,7 +28,7 @@ using namespace llvm; #define f16_ty builder_->getHalfTy() #define f32_ty builder_->getFloatTy() #define i32_ty builder_->getInt32Ty() -#define vec_ty(...) VectorType::get(__VA_ARGS__) +#define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) // constants #define i32(...) builder_->getInt32(__VA_ARGS__) diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 90857e3a5..91d126de5 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -212,44 +212,23 @@ void kernel::operator()(const std::string& args, driver::stream *stream, const s stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_); } -std::string kernel::get_asm(asm_mode_t mode) { - switch(mode){ - case ASM_LLIR:{ - return ((driver::cu_module*)mod_.get())->llir(); - } - case ASM_NV_PTX: - case ASM_NV_SASS:{ - std::string ptx = ((driver::cu_module*)mod_.get())->ptx(); - // SASS - std::string input = std::tmpnam(nullptr); - std::string output = std::tmpnam(nullptr); - std::ofstream ofs(input); - ofs << ptx; - ofs.close(); - if(mode == ASM_NV_PTX) - return ptx; - std::string cmd; - int err; - // compile ptx - driver::cu_device* cu_device = (driver::cu_device*)dev_; - cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o"; - err = system(cmd.c_str()); - // disassemble - cmd = "cuobjdump --dump-sass " + input + ".o >> " + output; - err = system(cmd.c_str()); - std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/"); - std::string to_delete = " /*"; - std::ifstream ifs(output); - std::string line; - std::string sass; - while(std::getline(ifs, line)) - if(!std::regex_match(line, comment)) - sass += line + "\n"; - return sass; - } - default: - return ""; +std::string kernel::get_asm(const std::string& mode) { + std::vector modes = {"llir", "ptx"}; + if(std::find(modes.begin(), modes.end(), mode) == modes.end()){ + std::string err = "Unrecognized mode. Supported values are: "; + for(std::string m: modes){ + if(m != modes[0]) + err += ", "; + err += m; } + throw std::runtime_error(err); + } + if(mode == "llir") + return ((driver::cu_module*)mod_.get())->llir(); + if(mode == "ptx") + return ((driver::cu_module*)mod_.get())->ptx(); + assert(false); + return ""; } /* --------------------------------- */ /* --------------------------------- */ diff --git a/python/src/triton.cc b/python/src/triton.cc index 05e7b9d10..3d8a06e44 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -127,10 +127,6 @@ void init_triton_runtime(py::module &&m) { .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); - // assembly mode - py::enum_(m, "asm_mode") - .value("ptx", rt::ASM_NV_PTX) - .value("sass", rt::ASM_NV_SASS); // compilation options py::class_(m, "options", py::dynamic_attr()) .def(py::init<>()) @@ -142,7 +138,8 @@ void init_triton_runtime(py::module &&m) { // kernel py::class_(m, "kernel") .def("__call__", &rt::kernel::operator()) - .def_readonly("opt", &rt::kernel::opt); + .def_readonly("opt", &rt::kernel::opt) + .def("asm", &rt::kernel::get_asm); // tune conf py::class_(m, "config") .def(py::init, int>(), diff --git a/python/triton/kernel.py b/python/triton/kernel.py index f84ceae05..c6ee76796 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -6,14 +6,9 @@ import torch import triton._C.libtriton.triton as _triton codes = { - _triton.runtime.arg_type.int1: 'B', - _triton.runtime.arg_type.int8: 'B', - _triton.runtime.arg_type.int32: 'I', - _triton.runtime.arg_type.int64: 'Q', - _triton.runtime.arg_type.half: 'H', - _triton.runtime.arg_type.float: 'f', - _triton.runtime.arg_type.double: 'd', - _triton.runtime.arg_type.buffer: 'P' + _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I', + _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f', + _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P' } diff --git a/python/triton/testing.py b/python/triton/testing.py index 99c0bc1e6..71696579e 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -3,8 +3,10 @@ import os try: import triton._C.libtriton.cutlass as _cutlass + has_cutlass = True except ImportError: _cutlass = None + has_cutlass = False def sparsify_tensor(x, mask, block): @@ -44,7 +46,7 @@ def mask_tensor(x, mask, block, value=0): return ret -def allclose(x, y): +def allclose(x, y, tol=1e-2): assert x.dtype == y.dtype diff = abs(x - y) x_max = torch.max(x)