[BUILD] Remove compilation warnings
This commit is contained in:
committed by
Philippe Tillet
parent
b352bc79e3
commit
5ba5a77561
@@ -92,12 +92,6 @@ void add_arg(std::stringstream& ss, T arg) {
|
|||||||
/* ------------------------- */
|
/* ------------------------- */
|
||||||
/* ------------------------- */
|
/* ------------------------- */
|
||||||
|
|
||||||
enum asm_mode_t {
|
|
||||||
ASM_LLIR,
|
|
||||||
ASM_NV_PTX,
|
|
||||||
ASM_NV_SASS
|
|
||||||
};
|
|
||||||
|
|
||||||
class kernel{
|
class kernel{
|
||||||
public:
|
public:
|
||||||
typedef std::vector<size_t> grid_t;
|
typedef std::vector<size_t> grid_t;
|
||||||
@@ -111,7 +105,7 @@ public:
|
|||||||
public:
|
public:
|
||||||
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
|
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
|
||||||
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
|
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
|
||||||
std::string get_asm(asm_mode_t mode);
|
std::string get_asm(const std::string &mode);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
const options_t opt;
|
const options_t opt;
|
||||||
|
@@ -28,7 +28,7 @@ using namespace llvm;
|
|||||||
#define f16_ty builder_->getHalfTy()
|
#define f16_ty builder_->getHalfTy()
|
||||||
#define f32_ty builder_->getFloatTy()
|
#define f32_ty builder_->getFloatTy()
|
||||||
#define i32_ty builder_->getInt32Ty()
|
#define i32_ty builder_->getInt32Ty()
|
||||||
#define vec_ty(...) VectorType::get(__VA_ARGS__)
|
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||||
// constants
|
// constants
|
||||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||||
|
@@ -212,44 +212,23 @@ void kernel::operator()(const std::string& args, driver::stream *stream, const s
|
|||||||
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
|
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string kernel::get_asm(asm_mode_t mode) {
|
std::string kernel::get_asm(const std::string& mode) {
|
||||||
switch(mode){
|
std::vector<std::string> modes = {"llir", "ptx"};
|
||||||
case ASM_LLIR:{
|
if(std::find(modes.begin(), modes.end(), mode) == modes.end()){
|
||||||
return ((driver::cu_module*)mod_.get())->llir();
|
std::string err = "Unrecognized mode. Supported values are: ";
|
||||||
}
|
for(std::string m: modes){
|
||||||
case ASM_NV_PTX:
|
if(m != modes[0])
|
||||||
case ASM_NV_SASS:{
|
err += ", ";
|
||||||
std::string ptx = ((driver::cu_module*)mod_.get())->ptx();
|
err += m;
|
||||||
// SASS
|
|
||||||
std::string input = std::tmpnam(nullptr);
|
|
||||||
std::string output = std::tmpnam(nullptr);
|
|
||||||
std::ofstream ofs(input);
|
|
||||||
ofs << ptx;
|
|
||||||
ofs.close();
|
|
||||||
if(mode == ASM_NV_PTX)
|
|
||||||
return ptx;
|
|
||||||
std::string cmd;
|
|
||||||
int err;
|
|
||||||
// compile ptx
|
|
||||||
driver::cu_device* cu_device = (driver::cu_device*)dev_;
|
|
||||||
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
|
|
||||||
err = system(cmd.c_str());
|
|
||||||
// disassemble
|
|
||||||
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
|
|
||||||
err = system(cmd.c_str());
|
|
||||||
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
|
|
||||||
std::string to_delete = " /*";
|
|
||||||
std::ifstream ifs(output);
|
|
||||||
std::string line;
|
|
||||||
std::string sass;
|
|
||||||
while(std::getline(ifs, line))
|
|
||||||
if(!std::regex_match(line, comment))
|
|
||||||
sass += line + "\n";
|
|
||||||
return sass;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return "";
|
|
||||||
}
|
}
|
||||||
|
throw std::runtime_error(err);
|
||||||
|
}
|
||||||
|
if(mode == "llir")
|
||||||
|
return ((driver::cu_module*)mod_.get())->llir();
|
||||||
|
if(mode == "ptx")
|
||||||
|
return ((driver::cu_module*)mod_.get())->ptx();
|
||||||
|
assert(false);
|
||||||
|
return "";
|
||||||
}
|
}
|
||||||
/* --------------------------------- */
|
/* --------------------------------- */
|
||||||
/* --------------------------------- */
|
/* --------------------------------- */
|
||||||
|
@@ -127,10 +127,6 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
.value("float", rt::FLOAT_T)
|
.value("float", rt::FLOAT_T)
|
||||||
.value("double", rt::DOUBLE_T)
|
.value("double", rt::DOUBLE_T)
|
||||||
.value("buffer", rt::BUFFER_T);
|
.value("buffer", rt::BUFFER_T);
|
||||||
// assembly mode
|
|
||||||
py::enum_<rt::asm_mode_t>(m, "asm_mode")
|
|
||||||
.value("ptx", rt::ASM_NV_PTX)
|
|
||||||
.value("sass", rt::ASM_NV_SASS);
|
|
||||||
// compilation options
|
// compilation options
|
||||||
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
|
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
@@ -142,7 +138,8 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
// kernel
|
// kernel
|
||||||
py::class_<rt::kernel>(m, "kernel")
|
py::class_<rt::kernel>(m, "kernel")
|
||||||
.def("__call__", &rt::kernel::operator())
|
.def("__call__", &rt::kernel::operator())
|
||||||
.def_readonly("opt", &rt::kernel::opt);
|
.def_readonly("opt", &rt::kernel::opt)
|
||||||
|
.def("asm", &rt::kernel::get_asm);
|
||||||
// tune conf
|
// tune conf
|
||||||
py::class_<rt::config>(m, "config")
|
py::class_<rt::config>(m, "config")
|
||||||
.def(py::init<std::map<std::string, std::string>, int>(),
|
.def(py::init<std::map<std::string, std::string>, int>(),
|
||||||
|
@@ -6,14 +6,9 @@ import torch
|
|||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
|
|
||||||
codes = {
|
codes = {
|
||||||
_triton.runtime.arg_type.int1: 'B',
|
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
|
||||||
_triton.runtime.arg_type.int8: 'B',
|
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
|
||||||
_triton.runtime.arg_type.int32: 'I',
|
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
|
||||||
_triton.runtime.arg_type.int64: 'Q',
|
|
||||||
_triton.runtime.arg_type.half: 'H',
|
|
||||||
_triton.runtime.arg_type.float: 'f',
|
|
||||||
_triton.runtime.arg_type.double: 'd',
|
|
||||||
_triton.runtime.arg_type.buffer: 'P'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -3,8 +3,10 @@ import os
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import triton._C.libtriton.cutlass as _cutlass
|
import triton._C.libtriton.cutlass as _cutlass
|
||||||
|
has_cutlass = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_cutlass = None
|
_cutlass = None
|
||||||
|
has_cutlass = False
|
||||||
|
|
||||||
|
|
||||||
def sparsify_tensor(x, mask, block):
|
def sparsify_tensor(x, mask, block):
|
||||||
@@ -44,7 +46,7 @@ def mask_tensor(x, mask, block, value=0):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def allclose(x, y):
|
def allclose(x, y, tol=1e-2):
|
||||||
assert x.dtype == y.dtype
|
assert x.dtype == y.dtype
|
||||||
diff = abs(x - y)
|
diff = abs(x - y)
|
||||||
x_max = torch.max(x)
|
x_max = torch.max(x)
|
||||||
|
Reference in New Issue
Block a user