[BUILD] Remove compilation warnings
This commit is contained in:
committed by
Philippe Tillet
parent
b352bc79e3
commit
5ba5a77561
@@ -127,10 +127,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
.value("float", rt::FLOAT_T)
|
||||
.value("double", rt::DOUBLE_T)
|
||||
.value("buffer", rt::BUFFER_T);
|
||||
// assembly mode
|
||||
py::enum_<rt::asm_mode_t>(m, "asm_mode")
|
||||
.value("ptx", rt::ASM_NV_PTX)
|
||||
.value("sass", rt::ASM_NV_SASS);
|
||||
// compilation options
|
||||
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
|
||||
.def(py::init<>())
|
||||
@@ -142,7 +138,8 @@ void init_triton_runtime(py::module &&m) {
|
||||
// kernel
|
||||
py::class_<rt::kernel>(m, "kernel")
|
||||
.def("__call__", &rt::kernel::operator())
|
||||
.def_readonly("opt", &rt::kernel::opt);
|
||||
.def_readonly("opt", &rt::kernel::opt)
|
||||
.def("asm", &rt::kernel::get_asm);
|
||||
// tune conf
|
||||
py::class_<rt::config>(m, "config")
|
||||
.def(py::init<std::map<std::string, std::string>, int>(),
|
||||
|
@@ -6,14 +6,9 @@ import torch
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
codes = {
|
||||
_triton.runtime.arg_type.int1: 'B',
|
||||
_triton.runtime.arg_type.int8: 'B',
|
||||
_triton.runtime.arg_type.int32: 'I',
|
||||
_triton.runtime.arg_type.int64: 'Q',
|
||||
_triton.runtime.arg_type.half: 'H',
|
||||
_triton.runtime.arg_type.float: 'f',
|
||||
_triton.runtime.arg_type.double: 'd',
|
||||
_triton.runtime.arg_type.buffer: 'P'
|
||||
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
|
||||
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
|
||||
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
|
||||
}
|
||||
|
||||
|
||||
|
@@ -3,8 +3,10 @@ import os
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
has_cutlass = True
|
||||
except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
@@ -44,7 +46,7 @@ def mask_tensor(x, mask, block, value=0):
|
||||
return ret
|
||||
|
||||
|
||||
def allclose(x, y):
|
||||
def allclose(x, y, tol=1e-2):
|
||||
assert x.dtype == y.dtype
|
||||
diff = abs(x - y)
|
||||
x_max = torch.max(x)
|
||||
|
Reference in New Issue
Block a user