[FRONTEND] Fix the implicit broadcasting rule (#663)

This PR solves the cast issue that appears in some tutorial code.
This commit is contained in:
Shintaro Iwasaki
2022-09-16 10:49:15 -07:00
committed by GitHub
parent 80e3fb5270
commit e9e1a4e682
4 changed files with 80 additions and 10 deletions

View File

@@ -752,8 +752,11 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
# create kernel prototype
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
arg_types = signature.replace(' ', '').split(',')
arg_types = [str_to_ty(x) for x in arg_types]
if signature.replace(' ', '') != '':
arg_types = signature.replace(' ', '').split(',')
arg_types = [str_to_ty(x) for x in arg_types]
else:
arg_types = []
prototype = triton.language.function_type([], arg_types)
# visit kernel AST
gscope = fn.__globals__.copy()