diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 1332f2c76..ab7733b60 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -913,7 +913,10 @@ def ty_to_cpp(ty): "i64": "int64_t", "u32": "uint32_t", "u64": "uint64_t", + "fp16": "float", + "bf16": "float", "fp32": "float", + "fp64": "double", }[ty] @@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature): 'i64': 'int64_t', 'u32': 'uint32_t', 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', 'fp32': 'float', 'fp64': 'double', }[ty]