Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In `torch._inductor`, we [convert 0d CPU tensors to scalars during Triton codegen](https://github.com/pytorch/pytorch/pull/87329), so we need to add the missing Triton-side support for bf16/fp16/fp64 in `ty_to_cpp`.
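For context, a hedged sketch of the scenario the linked PyTorch PR describes (the variable names and the bf16 example are illustrative, not taken from either patch):

```python
import torch

# A 0d CPU tensor such as this one is passed to a Triton kernel by Inductor
# as a plain Python scalar rather than as a tensor pointer.
x = torch.tensor(1.5, dtype=torch.bfloat16)
assert x.dim() == 0 and x.device.type == "cpu"
py_scalar = x.item()  # plain Python float

# The kernel signature can therefore carry scalar type strings such as
# "bf16", "fp16", or "fp64", and the generated C launcher needs ty_to_cpp
# to map each of them to a C type.
```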
```diff
@@ -913,7 +913,10 @@ def ty_to_cpp(ty):
         "i64": "int64_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
+        "fp16": "float",
+        "bf16": "float",
         "fp32": "float",
+        "fp64": "double",
     }[ty]
@@ -943,6 +946,8 @@ def generate_launcher(identifier, constants, signature):
         'i64': 'int64_t',
         'u32': 'uint32_t',
         'u64': 'uint64_t',
+        'fp16': 'float',
+        'bf16': 'float',
         'fp32': 'float',
         'fp64': 'double',
     }[ty]
```
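Taken together, the patch maps fp16/bf16 scalars to C `float` on the host side (plain C has no portable half-precision scalar type) and fp64 to `double`. Below is a small self-contained sketch of the resulting behaviour; the `signature` dict and the `arg_decls` expression are illustrative assumptions, not Triton's actual launcher code:

```python
# Sketch of the patched scalar-type table (entries for the smaller integer
# types are elided; pointer types like "*fp32" are handled separately).
def ty_to_cpp(ty):
    return {
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",   # widened to float on the host side
        "bf16": "float",   # likewise for bfloat16
        "fp32": "float",
        "fp64": "double",
    }[ty]

# Hypothetical use: declare the C arguments of a generated launcher from a
# kernel signature that may now contain bf16/fp16/fp64 scalars.
signature = {0: "fp16", 1: "bf16", 2: "fp64", 3: "i64"}
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
print(arg_decls)  # float arg0, float arg1, double arg2, int64_t arg3
```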