[LANG] Preliminary FP8 support (#96)
committed by Philippe Tillet
parent 4290be1ae8
commit 7355efa745
@@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) {
     .def("make_function", &ir::function_type::get, ret::reference)
     .def("make_block", &ir::block_type::get, ret::reference)
     .def("get_void", &ir::type::get_void_ty, ret::reference)
+    .def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
     .def("get_fp16", &ir::type::get_half_ty, ret::reference)
     .def("get_fp32", &ir::type::get_float_ty, ret::reference)
     .def("get_fp64", &ir::type::get_double_ty, ret::reference)
@@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) {
     .def("get_int64", &ir::type::get_int64_ty, ret::reference)

     .def("is_void", &ir::type::is_void_ty)
+    .def("is_fp8", &ir::type::is_fp8_ty)
     .def("is_fp16", &ir::type::is_half_ty)
     .def("is_fp32", &ir::type::is_float_ty)
     .def("is_fp64", &ir::type::is_double_ty)
@@ -2,7 +2,7 @@
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
+from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret

 from . import language
 from . import code_gen
@@ -436,20 +436,23 @@ class CompilationError(Exception):


 class Kernel:
-
-    type_names = {
-        int: 'I',
-        float: 'f',
-        bool: 'B',
-        torch.float16: 'f16',
-        torch.float32: 'f32',
-        torch.float64: 'f64',
-        torch.bool: 'i1',
-        torch.int8: 'i8',
-        torch.int16: 'i16',
-        torch.int32: 'i32',
-        torch.int64: 'i64',
-    }
+    @staticmethod
+    def _type_name(obj):
+        type_names = {
+            int: 'I',
+            float: 'f',
+            bool: 'B',
+            triton.language.float8: 'f8',
+            torch.float16: 'f16',
+            torch.float32: 'f32',
+            torch.float64: 'f64',
+            torch.bool: 'i1',
+            torch.int8: 'i8',
+            torch.int16: 'i16',
+            torch.int32: 'i32',
+            torch.int64: 'i64',
+        }
+        return type_names[obj]

     @staticmethod
     def _to_triton_ir(context, obj):
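
The type-code mapping now lives in a Kernel._type_name helper and gains an entry for the new triton.language.float8 marker. A minimal illustration of the lookup, assuming torch and triton are installed (_type_name is an internal helper, called here only to make the key scheme concrete):

    import torch
    import triton
    from triton.code_gen import Kernel

    # torch dtypes and the new fp8 marker map to the short type codes that
    # Kernel uses when building kernel signatures / cache keys
    print(Kernel._type_name(torch.float16))           # 'f16'
    print(Kernel._type_name(triton.language.float8))  # 'f8'
    print(Kernel._type_name(int))                     # 'I'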
@@ -457,6 +460,7 @@ class Kernel:
             'I': _triton.ir.type.get_int32,
             'f': _triton.ir.type.get_fp32,
             'B': _triton.ir.type.get_int1,
+            'f8': _triton.ir.type.get_fp8,
             'f16': _triton.ir.type.get_fp16,
             'f32': _triton.ir.type.get_fp32,
             'f64': _triton.ir.type.get_fp64,
@@ -467,12 +471,12 @@ class Kernel:
             'i64': _triton.ir.type.get_int64,
         }
         # convert torch.Tensor to Triton IR pointers
-        if isinstance(obj, torch.Tensor):
-            name = Kernel.type_names[obj.dtype]
+        if hasattr(obj, 'data_ptr'):
+            name = Kernel._type_name(obj.dtype)
             elt_ty = type_map[name](context)
             return _triton.ir.type.make_ptr(elt_ty, 1)
         # default path returns triton.ir.type directly
-        name = Kernel.type_names[obj.__class__]
+        name = Kernel._type_name(obj.__class__)
         return type_map[name](context)

     @staticmethod
@@ -481,7 +485,7 @@ class Kernel:
         types_key = [None] * len(wargs)
         for i, arg in enumerate(wargs):
             prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
+            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
             types_key[i] = prefix + suffix
         return tuple(types_key)

@@ -523,7 +527,7 @@ class Kernel:

     def __call__(self, *wargs, grid, num_warps=4, **meta):
         # device inference
-        tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
+        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         device = wargs[tensor_idxs[0]].device
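
Tensor detection in __call__ is now duck-typed: any argument that exposes a data_ptr() method is treated as a tensor, which is what lets the TensorWrapper added further down participate in kernel launches. A self-contained sketch of the check (Fake is a hypothetical stand-in, not part of this commit):

    # anything with a data_ptr() method counts as a tensor argument
    class Fake:
        def data_ptr(self):
            return 0

    wargs = (Fake(), 16, 1.5)
    tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
    print(tensor_idxs)  # [0]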
@@ -545,7 +549,7 @@ class Kernel:
             *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
         )
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
         binary = cache[key]
@@ -703,3 +707,19 @@ def jit(fn):

 def cdiv(x, y):
     return (x + y - 1) // y
+
+
+######
+
+
+class TensorWrapper:
+    def __init__(self, data_ptr, dtype):
+        self._data_ptr = data_ptr
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return self._data_ptr
+
+
+def reinterpret(tensor, dtype):
+    return TensorWrapper(tensor.data_ptr(), dtype)
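
TensorWrapper carries only a raw pointer and a Triton dtype, which is exactly what the dispatch code above reads (data_ptr() and .dtype). A hedged usage sketch, assuming torch and triton are installed and showing only the wrapping, not a kernel launch: since torch has no fp8 dtype, an int8 buffer can be re-tagged as fp8 with reinterpret.

    import torch
    import triton
    import triton.language as tl

    x = torch.randint(0, 127, (1024,), dtype=torch.int8)  # raw 8-bit storage
    x_fp8 = triton.reinterpret(x, tl.float8)               # same pointer, fp8 dtype

    # the wrapper quacks like a tensor as far as Kernel is concerned
    assert hasattr(x_fp8, 'data_ptr')
    print(triton.code_gen.Kernel._type_name(x_fp8.dtype))  # 'f8'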
@@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8)
 int16 = dtype(ir.type.get_int16)
 int32 = dtype(ir.type.get_int32)
 int64 = dtype(ir.type.get_int64)
+float8 = dtype(ir.type.get_fp8)
 float16 = dtype(ir.type.get_fp16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)
@@ -98,6 +99,7 @@ class block:
         if ir_type.is_int16(): return int16
         if ir_type.is_int32(): return int32
         if ir_type.is_int64(): return int64
+        if ir_type.is_fp8(): return float8
         if ir_type.is_fp16(): return float16
         if ir_type.is_fp32(): return float32
         if ir_type.is_fp64(): return float64
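
On the language side, float8 becomes a regular dtype alongside float16/float32/float64, and block dtype inference handles it symmetrically. A trivial check, assuming triton is installed:

    import triton.language as tl

    # float8 is constructed the same way as the other floating-point dtypes
    print(type(tl.float8) is type(tl.float16))  # True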