[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 data type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other operations on bfloat16 are not yet supported.
Committed by: Philippe Tillet
Parent: 9b4e2cae2d
Commit: 8cea583109
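For intuition, BF16 keeps FP32's 8-bit exponent and truncates the mantissa to 7 bits, so FP32 <-> BF16 conversion operates on the high 16 bits of the float32 bit pattern. A minimal Python sketch of that bit-level behavior, assuming round-to-nearest-even when narrowing (the rounding mode actually emitted by the codegen is not specified here):

```python
import struct

def fp32_to_bf16_bits(f):
    # Reinterpret float32 as uint32, then round-to-nearest-even into the
    # top 16 bits. NaN handling is omitted for brevity.
    u = struct.unpack('<I', struct.pack('<f', f))[0]
    rounding_bias = 0x7FFF + ((u >> 16) & 1)
    return ((u + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_fp32(b):
    # Widening is exact: place the 16 bits in the high half and reinterpret.
    return struct.unpack('<f', struct.pack('<I', (b & 0xFFFF) << 16))[0]

print(bf16_bits_to_fp32(fp32_to_bf16_bits(1.0)))      # 1.0 (exactly representable)
print(bf16_bits_to_fp32(fp32_to_bf16_bits(3.14159)))  # ~3.140625 (7 mantissa bits)
```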
```diff
@@ -465,6 +465,7 @@ class Kernel:
     float: 'f',
     bool: 'B',
     triton.language.float8: 'f8',
+    torch.bfloat16: 'bf16',
     torch.float16: 'f16',
     torch.float32: 'f32',
     torch.float64: 'f64',
```
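These lines extend the table that maps a torch/triton dtype to its one- or two-character signature suffix. A hedged sketch of how such a table might be consumed (the helper name and surrounding code are illustrative, not taken from the diff):

```python
import torch

# Hypothetical helper illustrating how the table above is used; the real
# method in Triton's code generator is not shown in this diff.
_TYPE_NAMES = {
    torch.bfloat16: 'bf16',
    torch.float16: 'f16',
    torch.float32: 'f32',
    torch.float64: 'f64',
}

def _type_name(dtype):
    # Encode a tensor dtype as its signature suffix.
    return _TYPE_NAMES[dtype]

x = torch.empty(4, dtype=torch.bfloat16)
print(_type_name(x.dtype))  # 'bf16'
```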
```diff
@@ -484,6 +485,7 @@ class Kernel:
     'B': _triton.ir.type.get_int1,
     'f8': _triton.ir.type.get_fp8,
+    'bf16': _triton.ir.type.get_bf16,
     'f16': _triton.ir.type.get_fp16,
     'f32': _triton.ir.type.get_fp32,
     'f64': _triton.ir.type.get_fp64,
     'i1': _triton.ir.type.get_int1,
```
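Going the other way, the suffix is resolved back to a Triton IR type when the kernel prototype is built. A minimal sketch, assuming each `_triton.ir.type.get_*` binding is a constructor taking an IR context (the helper below is illustrative and not part of the diff):

```python
# Hypothetical sketch: resolve a signature suffix to an IR type.
# Assumes the get_* bindings take an ir context, mirroring the table in
# the hunk above; the function name and parameters are assumptions.
def _to_ir_type(context, suffix, _triton):
    type_map = {
        'f8': _triton.ir.type.get_fp8,
        'bf16': _triton.ir.type.get_bf16,
        'f16': _triton.ir.type.get_fp16,
        'f32': _triton.ir.type.get_fp32,
        'f64': _triton.ir.type.get_fp64,
        'i1': _triton.ir.type.get_int1,
    }
    return type_map[suffix](context)
```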
```diff
@@ -555,6 +557,7 @@ class Kernel:
     if len(tensor_idxs) == 0:
         raise ValueError("No Tensor argument found.")
     device = wargs[tensor_idxs[0]].device
     torch.cuda.set_device(device.index)
     # attributes
     args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
     attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
```
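The last hunk touches the launch path, where tensor arguments are lowered to raw pointers and integer arguments are summarized by their alignment. `Kernel.pow2_divisor` itself is not shown in this diff; a plausible sketch of what it computes (the cap and exact structure are assumptions):

```python
def pow2_divisor(N):
    # Largest power of two (capped at 16) dividing N, letting the compiler
    # specialize kernels on pointer/integer alignment. Sketch only: the
    # actual Kernel.pow2_divisor may differ.
    if N % 16 == 0: return 16
    if N % 8 == 0: return 8
    if N % 4 == 0: return 4
    if N % 2 == 0: return 2
    return 1
```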