[IR] Preliminary support for BF16 (#129)

This PR adds a BF16 data type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other operations on bfloat16 are not yet supported.
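For context, the user-visible effect is that torch.bfloat16 tensors can now be passed straight to a @triton.jit kernel. Below is a minimal sketch, assuming the Triton 1.x Python API of this period; the kernel, tensor, and meta-parameter names are illustrative and not part of this PR:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(X, Y, N, **meta):
        # each program instance copies one BLOCK-sized chunk of X into Y
        pid = tl.program_id(0)
        offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        mask = offsets < N
        x = tl.load(X + offsets, mask=mask)
        tl.store(Y + offsets, x, mask=mask)

    N = 4096
    x = torch.randn(N, device='cuda').to(torch.bfloat16)  # now maps to the new 'bf16' suffix
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    copy_kernel[grid](x, y, N, BLOCK=1024)

The example only moves data; per the description above, operations other than FP32 <-> BF16 conversion are not yet available on bfloat16 values.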
Author:    Philippe Tillet
Date:      2021-06-25 10:19:29 -04:00
Committer: Philippe Tillet
Parent:    9b4e2cae2d
Commit:    8cea583109
18 changed files with 173 additions and 132 deletions


@@ -465,6 +465,7 @@ class Kernel:
             float: 'f',
             bool: 'B',
             triton.language.float8: 'f8',
+            torch.bfloat16: 'bf16',
             torch.float16: 'f16',
             torch.float32: 'f32',
             torch.float64: 'f64',
@@ -484,6 +485,7 @@ class Kernel:
             'B': _triton.ir.type.get_int1,
             'f8': _triton.ir.type.get_fp8,
             'f16': _triton.ir.type.get_fp16,
+            'bf16': _triton.ir.type.get_bf16,
             'f32': _triton.ir.type.get_fp32,
             'f64': _triton.ir.type.get_fp64,
             'i1': _triton.ir.type.get_int1,
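Taken together, the two hunks above extend a two-step lookup: a torch dtype is first mapped to a short suffix ('bf16'), and that suffix is then mapped to a Triton IR type constructor. Roughly, as a sketch only (the names type_names, str_to_type, and context are illustrative, and whether the constructors take a context argument is an assumption here):

    # illustrative composition of the two tables shown above
    suffix = type_names[torch.bfloat16]        # -> 'bf16'
    ir_type = str_to_type[suffix](context)     # -> _triton.ir.type.get_bf16(context)

The suffix presumably also ends up in the kernel's specialization key, so bfloat16 and float16 arguments would compile to distinct binaries.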
@@ -555,6 +557,7 @@ class Kernel:
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         device = wargs[tensor_idxs[0]].device
+        torch.cuda.set_device(device.index)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
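The added torch.cuda.set_device(device.index) call makes the device inferred from the first tensor argument the current CUDA device before compilation and launch. The attributes dictionary then records, for each integer argument, its largest known power-of-two divisor, which the compiler can exploit for alignment-based specialization. A sketch of what a helper like Kernel.pow2_divisor typically computes (not necessarily the exact implementation):

    def pow2_divisor(n):
        # largest power of two, capped at 16, that divides n (illustrative sketch)
        for divisor in (16, 8, 4, 2):
            if n % divisor == 0:
                return divisor
        return 1

For example, a row stride of 1024 would yield 16 while an odd stride yields 1, and the two cases can compile to different specializations.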