Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In `torch._inductor`, we [convert 0d CPU tensors to scalars during Triton codegen](https://github.com/pytorch/pytorch/pull/87329), so we need to add the missing Triton support for bf16/fp16/fp64 scalar types.
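To see the failure this fixes, here is a minimal sketch of the pre-patch lookup (abbreviated to the entries visible as context lines in the diff below; not the full Triton source):

```python
# Sketch of the pre-patch mapping: the scalar dtypes Triton knew how to
# declare in the generated C launcher.
def ty_to_cpp(ty):
    return {
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp32": "float",
    }[ty]

ty_to_cpp("bf16")  # KeyError: 'bf16' -- a 0d bf16 CPU tensor lowered to a
                   # scalar kernel argument therefore could not be launched
```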
```diff
@@ -913,7 +913,10 @@ def ty_to_cpp(ty):
         "i64": "int64_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
+        "fp16": "float",
+        "bf16": "float",
         "fp32": "float",
+        "fp64": "double",
     }[ty]
 
 
@@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature):
             'i64': 'int64_t',
             'u32': 'uint32_t',
             'u64': 'uint64_t',
+            'fp16': 'float',
+            'bf16': 'float',
             'fp32': 'float',
             'fp64': 'double',
         }[ty]
```
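A quick check of the patched mapping (illustrative; `ty_to_cpp` as in the diff above). fp16 and bf16 both map to C `float`, since standard C has no half-precision scalar type, while fp64 gets a genuine `double`:

```python
# Post-patch behavior of the mapping shown in the diff above.
assert ty_to_cpp("fp16") == "float"   # half scalars declared as float in the launcher
assert ty_to_cpp("bf16") == "float"   # likewise for bfloat16
assert ty_to_cpp("fp64") == "double"
```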