[IR] Preliminary support for BF16 (#129)
This PR adds a bfloat16 (BF16) data type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other operations on bfloat16 are not yet supported.
committed by Philippe Tillet
parent 9b4e2cae2d
commit 8cea583109
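
For context, a minimal sketch (not part of this diff) of what the new FP32 <-> BF16 conversion path enables at the Python level. The kernel name, the launch configuration, and the tl.bfloat16 spelling are assumptions based on the current Triton API rather than this commit; per the description above, only the conversion itself is expected to lower correctly, since other bfloat16 ops are not yet supported.

import torch
import triton
import triton.language as tl

@triton.jit
def fp32_to_bf16(X, Z):
    # Load 128 float32 values, down-convert them to bfloat16, and store the result.
    off = tl.arange(0, 128)
    x = tl.load(X + off)                   # float32 in
    tl.store(Z + off, x.to(tl.bfloat16))   # bfloat16 out (FP32 -> BF16 conversion)

x = torch.randn(128, device='cuda', dtype=torch.float32)
z = torch.empty(128, device='cuda', dtype=torch.bfloat16)
fp32_to_bf16[(1,)](x, z)
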
@@ -16,6 +16,7 @@ cvt = {
     'int16': torch.int16,
     'int32': torch.int32,
     'int64': torch.int64,
+    'bfloat16': torch.bfloat16,
     'float16': torch.float16,
     'float32': torch.float32,
     'float64': torch.float64,
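
The cvt table above maps Triton dtype-name strings to the corresponding torch dtypes. A hedged sketch of how such a mapping is typically used when building reference tensors (the usage itself is an assumption and is not shown in this diff; the dict below is an excerpt of the full table):

import torch
cvt = {'bfloat16': torch.bfloat16, 'float32': torch.float32}  # excerpt of the table above
# Allocate a torch tensor whose dtype matches the Triton-side name 'bfloat16'.
x = torch.empty(128, dtype=cvt['bfloat16'], device='cuda')
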
@@ -292,9 +293,12 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
 # test cast
 # ---------------
 @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-    (dtype_x, dtype_z, False) for dtype_x in dtypes \
-                              for dtype_z in dtypes
-] + [
+    (dtype_x, dtype_z, False) \
+                              for dtype_x in dtypes\
+                              for dtype_z in dtypes
+] + [
+    ('float32', 'bfloat16', False),
+    ('bfloat16', 'float32', False),
     ('float32', 'int32', True)
 ])
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
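
The body of test_cast is not shown in this hunk; the following is only a hedged sketch of how one (dtype_x, dtype_z) pair could be exercised, reusing the cvt table from the first hunk. The helper name, the pointer-dtype access Z.dtype.element_ty, and the comparison against torch's own cast are assumptions, and the bitcast=True case is omitted.

import torch
import triton
import triton.language as tl

def run_cast_case(dtype_x, dtype_z, device='cuda'):
    @triton.jit
    def kernel(X, Z):
        # Cast each element to the destination pointer's element type.
        off = tl.arange(0, 128)
        x = tl.load(X + off)
        tl.store(Z + off, x.to(Z.dtype.element_ty))

    x = torch.randn(128, device=device).to(cvt[dtype_x])
    z = torch.empty(128, dtype=cvt[dtype_z], device=device)
    kernel[(1,)](x, z)
    # Compare against torch performing the same cast on the same input.
    torch.testing.assert_close(z, x.to(cvt[dtype_z]))
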