[FRONTEND] Fix the implicit broadcasting rule (#663)
This PR solves the cast issue that appears in some tutorial code.
This commit is contained in:
@@ -752,8 +752,11 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
# create kernel prototype
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
if signature.replace(' ', '') != '':
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
else:
|
||||
arg_types = []
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
# visit kernel AST
|
||||
gscope = fn.__globals__.copy()
|
||||
|
Reference in New Issue
Block a user