[BUILD] Remove compilation warnings

This commit is contained in:
Philippe Tillet
2021-03-24 01:24:50 -04:00
committed by Philippe Tillet
parent b352bc79e3
commit 5ba5a77561
6 changed files with 26 additions and 59 deletions

View File

@@ -127,10 +127,6 @@ void init_triton_runtime(py::module &&m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
// assembly mode
py::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
@@ -142,7 +138,8 @@ void init_triton_runtime(py::module &&m) {
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt);
.def_readonly("opt", &rt::kernel::opt)
.def("asm", &rt::kernel::get_asm);
// tune conf
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),

View File

@@ -6,14 +6,9 @@ import torch
import triton._C.libtriton.triton as _triton
codes = {
_triton.runtime.arg_type.int1: 'B',
_triton.runtime.arg_type.int8: 'B',
_triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q',
_triton.runtime.arg_type.half: 'H',
_triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd',
_triton.runtime.arg_type.buffer: 'P'
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
}

View File

@@ -3,8 +3,10 @@ import os
try:
import triton._C.libtriton.cutlass as _cutlass
has_cutlass = True
except ImportError:
_cutlass = None
has_cutlass = False
def sparsify_tensor(x, mask, block):
@@ -44,7 +46,7 @@ def mask_tensor(x, mask, block, value=0):
return ret
def allclose(x, y):
def allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
diff = abs(x - y)
x_max = torch.max(x)