[LANG] Preliminary FP8 support (#96)

Philippe Tillet
2021-05-01 14:34:33 -04:00
committed by Philippe Tillet
parent 4290be1ae8
commit 7355efa745
10 changed files with 182 additions and 40 deletions

View File

@@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) {
.def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference)
.def("get_void", &ir::type::get_void_ty, ret::reference)
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
@@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) {
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
.def("is_fp16", &ir::type::is_half_ty)
.def("is_fp32", &ir::type::is_float_ty)
.def("is_fp64", &ir::type::is_double_ty)

View File

@@ -2,7 +2,7 @@
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
+from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret
 from . import language
 from . import code_gen

View File

@@ -436,20 +436,23 @@ class CompilationError(Exception):
 class Kernel:
-    type_names = {
-        int: 'I',
-        float: 'f',
-        bool: 'B',
-        torch.float16: 'f16',
-        torch.float32: 'f32',
-        torch.float64: 'f64',
-        torch.bool: 'i1',
-        torch.int8: 'i8',
-        torch.int16: 'i16',
-        torch.int32: 'i32',
-        torch.int64: 'i64',
-    }
+    @staticmethod
+    def _type_name(obj):
+        type_names = {
+            int: 'I',
+            float: 'f',
+            bool: 'B',
+            triton.language.float8: 'f8',
+            torch.float16: 'f16',
+            torch.float32: 'f32',
+            torch.float64: 'f64',
+            torch.bool: 'i1',
+            torch.int8: 'i8',
+            torch.int16: 'i16',
+            torch.int32: 'i32',
+            torch.int64: 'i64',
+        }
+        return type_names[obj]
     @staticmethod
     def _to_triton_ir(context, obj):
@@ -457,6 +460,7 @@ class Kernel:
             'I': _triton.ir.type.get_int32,
             'f': _triton.ir.type.get_fp32,
             'B': _triton.ir.type.get_int1,
+            'f8': _triton.ir.type.get_fp8,
             'f16': _triton.ir.type.get_fp16,
             'f32': _triton.ir.type.get_fp32,
             'f64': _triton.ir.type.get_fp64,
@@ -467,12 +471,12 @@ class Kernel:
             'i64': _triton.ir.type.get_int64,
         }
         # convert torch.Tensor to Triton IR pointers
-        if isinstance(obj, torch.Tensor):
-            name = Kernel.type_names[obj.dtype]
+        if hasattr(obj, 'data_ptr'):
+            name = Kernel._type_name(obj.dtype)
             elt_ty = type_map[name](context)
             return _triton.ir.type.make_ptr(elt_ty, 1)
         # default path returns triton.ir.type directly
-        name = Kernel.type_names[obj.__class__]
+        name = Kernel._type_name(obj.__class__)
         return type_map[name](context)
     @staticmethod
@@ -481,7 +485,7 @@ class Kernel:
         types_key = [None] * len(wargs)
         for i, arg in enumerate(wargs):
             prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
+            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
             types_key[i] = prefix + suffix
         return tuple(types_key)
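For concreteness, the specialization key now distinguishes fp8 pointers from other pointer arguments. With an fp16 output tensor, an fp8-reinterpreted pointer (see the TensorWrapper added at the end of this file) and an integer size as arguments, the loop above would produce entries like the following; the argument mix is purely illustrative:

# Illustrative cache-key entries for (fp16 tensor, fp8 pointer, int) arguments:
# 'P' marks a pointer-like argument (exposes data_ptr), the suffix is Kernel._type_name(...)
('Pf16', 'Pf8', 'I')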
@@ -523,7 +527,7 @@ class Kernel:
     def __call__(self, *wargs, grid, num_warps=4, **meta):
         # device inference
-        tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
+        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         device = wargs[tensor_idxs[0]].device
@@ -545,7 +549,7 @@ class Kernel:
             *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
         )
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
         binary = cache[key]
@@ -703,3 +707,19 @@ def jit(fn):
 def cdiv(x, y):
     return (x + y - 1) // y
+######
+class TensorWrapper:
+    def __init__(self, data_ptr, dtype):
+        self._data_ptr = data_ptr
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return self._data_ptr
+
+
+def reinterpret(tensor, dtype):
+    return TensorWrapper(tensor.data_ptr(), dtype)
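A hedged sketch of what the new helper produces: reinterpret only carries a raw pointer plus a Triton dtype, and because __call__ now detects tensor-like arguments via hasattr(arg, 'data_ptr'), the wrapper is picked up like a torch.Tensor. Note that device inference still reads .device from the first pointer-like argument, which TensorWrapper does not provide, so a real torch.Tensor should come before it in the argument list. The names below are illustrative:

import torch
import triton
import triton.language as tl

raw = torch.empty(1024, dtype=torch.int8, device='cuda')  # int8 buffer holding raw fp8 bit patterns
view = triton.reinterpret(raw, tl.float8)                 # same memory, now typed as fp8
assert view.data_ptr() == raw.data_ptr() and view.dtype is tl.float8
# some_kernel[grid](out_fp16, view, ...)   # passed like any tensor argument; 'some_kernel' and 'out_fp16' are hypothetical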

View File

@@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8)
 int16 = dtype(ir.type.get_int16)
 int32 = dtype(ir.type.get_int32)
 int64 = dtype(ir.type.get_int64)
+float8 = dtype(ir.type.get_fp8)
 float16 = dtype(ir.type.get_fp16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)
@@ -98,6 +99,7 @@ class block:
         if ir_type.is_int16(): return int16
         if ir_type.is_int32(): return int32
         if ir_type.is_int64(): return int64
+        if ir_type.is_fp8(): return float8
         if ir_type.is_fp16(): return float16
         if ir_type.is_fp32(): return float32
         if ir_type.is_fp64(): return float64
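Putting the pieces together, a minimal end-to-end sketch of how the preliminary support is meant to be driven from user code: an int8 buffer is reinterpreted as fp8 on the host, and a small jitted kernel reads it through an fp8 pointer and writes fp16. The kernel body, the BLOCK meta-parameter and the implicit fp8-to-fp16 conversion on store are illustrative assumptions, not something this diff establishes:

import torch
import triton
import triton.language as tl


@triton.jit
def fp8_to_fp16(Y, X, **meta):
    # Y (a real torch.Tensor) comes first so device inference in __call__ finds a .device
    offs = tl.arange(0, meta['BLOCK'])
    x = tl.load(X + offs)     # X is seen as an fp8 pointer via triton.reinterpret below
    tl.store(Y + offs, x)     # fp8 -> fp16 conversion on store is assumed here


N = 1024
x = torch.randint(-128, 128, (N,), dtype=torch.int8, device='cuda')  # raw fp8 bit patterns
y = torch.empty(N, dtype=torch.float16, device='cuda')
fp8_to_fp16[(1,)](y, triton.reinterpret(x, tl.float8), BLOCK=N)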