[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 (bfloat16) data type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other operations on bfloat16 values are not yet supported.
This commit is contained in:
committed by
Philippe Tillet
parent
9b4e2cae2d
commit
8cea583109
@@ -49,7 +49,7 @@ class CMakeBuild(build_ext):
|
||||
self.build_extension(ext)
|
||||
|
||||
def build_extension(self, ext):
|
||||
#self.debug = True
|
||||
# self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
|
@@ -204,9 +204,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("make_block", &ir::block_type::get, ret::reference)
|
||||
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
||||
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
|
||||
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
|
||||
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
|
||||
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
|
||||
.def("get_fp16", &ir::type::get_fp16_ty, ret::reference)
|
||||
.def("get_bf16", &ir::type::get_bf16_ty, ret::reference)
|
||||
.def("get_fp32", &ir::type::get_fp32_ty, ret::reference)
|
||||
.def("get_fp64", &ir::type::get_fp64_ty, ret::reference)
|
||||
.def("get_int1", &ir::type::get_int1_ty, ret::reference)
|
||||
.def("get_int8", &ir::type::get_int8_ty, ret::reference)
|
||||
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
||||
@@ -215,9 +216,10 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
.def("is_void", &ir::type::is_void_ty)
|
||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||
.def("is_fp16", &ir::type::is_half_ty)
|
||||
.def("is_fp32", &ir::type::is_float_ty)
|
||||
.def("is_fp64", &ir::type::is_double_ty)
|
||||
.def("is_fp16", &ir::type::is_fp16_ty)
|
||||
.def("is_bf16", &ir::type::is_bf16_ty)
|
||||
.def("is_fp32", &ir::type::is_fp32_ty)
|
||||
.def("is_fp64", &ir::type::is_fp64_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
||||
|
@@ -16,6 +16,7 @@ cvt = {
|
||||
'int16': torch.int16,
|
||||
'int32': torch.int32,
|
||||
'int64': torch.int64,
|
||||
'bfloat16': torch.bfloat16,
|
||||
'float16': torch.float16,
|
||||
'float32': torch.float32,
|
||||
'float64': torch.float64,
|
||||
@@ -292,9 +293,12 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
# test cast
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||
(dtype_x, dtype_z, False) for dtype_x in dtypes \
|
||||
for dtype_z in dtypes
|
||||
] + [
|
||||
(dtype_x, dtype_z, False) \
|
||||
for dtype_x in dtypes\
|
||||
for dtype_z in dtypes
|
||||
] + [
|
||||
('float32', 'bfloat16', False),
|
||||
('bfloat16', 'float32', False),
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
|
@@ -465,6 +465,7 @@ class Kernel:
|
||||
float: 'f',
|
||||
bool: 'B',
|
||||
triton.language.float8: 'f8',
|
||||
torch.bfloat16: 'bf16',
|
||||
torch.float16: 'f16',
|
||||
torch.float32: 'f32',
|
||||
torch.float64: 'f64',
|
||||
@@ -484,6 +485,7 @@ class Kernel:
|
||||
'B': _triton.ir.type.get_int1,
|
||||
'f8': _triton.ir.type.get_fp8,
|
||||
'f16': _triton.ir.type.get_fp16,
|
||||
'bf16': _triton.ir.type.get_bf16,
|
||||
'f32': _triton.ir.type.get_fp32,
|
||||
'f64': _triton.ir.type.get_fp64,
|
||||
'i1': _triton.ir.type.get_int1,
|
||||
@@ -555,6 +557,7 @@ class Kernel:
|
||||
if len(tensor_idxs) == 0:
|
||||
raise ValueError("No Tensor argument found.")
|
||||
device = wargs[tensor_idxs[0]].device
|
||||
torch.cuda.set_device(device.index)
|
||||
# attributes
|
||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
|
||||
|
@@ -86,6 +86,7 @@ int32 = dtype(ir.type.get_int32)
|
||||
int64 = dtype(ir.type.get_int64)
|
||||
float8 = dtype(ir.type.get_fp8)
|
||||
float16 = dtype(ir.type.get_fp16)
|
||||
bfloat16 = dtype(ir.type.get_bf16)
|
||||
float32 = dtype(ir.type.get_fp32)
|
||||
float64 = dtype(ir.type.get_fp64)
|
||||
|
||||
@@ -103,6 +104,7 @@ class block:
|
||||
if ir_type.is_int64(): return int64
|
||||
if ir_type.is_fp8(): return float8
|
||||
if ir_type.is_fp16(): return float16
|
||||
if ir_type.is_bf16(): return bfloat16
|
||||
if ir_type.is_fp32(): return float32
|
||||
if ir_type.is_fp64(): return float64
|
||||
# pointer type
|
||||
|
Reference in New Issue
Block a user