[FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py. Integer signedness now moved from C++ to Python. Cleaner frontend type. Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import builtins
|
||||
import functools
|
||||
@@ -11,7 +13,7 @@ import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -22,48 +24,13 @@ from .tools.disasm import extract
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def get_value(self, name):
    ''' Resolve `name` and return the value bound to it.

    Scopes are searched in order: local scope, global scope, builtins.
    When the bound value is a `triton.language.block`, the handle stored in
    `self.module` under `name` is fetched and wrapped in a new block, which
    is returned in place of the scope entry.

    Raises:
        ValueError: if `name` is in none of the three scopes.
    '''
    for scope in (self.lscope, self.gscope, self.builtins):
        if name in scope:
            found = scope[name]
            break
    else:
        raise ValueError(f'{name} is not defined')
    if not isinstance(found, triton.language.block):
        return found
    # block values are re-read from the module by name
    current_handle = self.module.get_value(name)
    return triton.language.block(current_handle)
|
||||
|
||||
def set_value(self, name, value):
    ''' Bind `name` to `value` in the local scope.

    A raw `_triton.ir.value` is first wrapped in a `triton.language.block`.
    For block values, the handle and the handle's type are also recorded on
    `self.module` under `name`.
    '''
    wrapped = triton.language.block(value) if isinstance(value, _triton.ir.value) else value
    if isinstance(wrapped, triton.language.block):
        self.module.set_value(name, wrapped.handle)
        self.module.set_type(name, wrapped.handle.type)
    self.lscope[name] = wrapped
|
||||
|
||||
def is_triton_object(self, value):
    ''' Return True when `value` is a `triton.language.block` instance. '''
    block_cls = triton.language.block
    return isinstance(value, block_cls)
|
||||
|
||||
def visit_compound_statement(self, stmts):
    ''' Visit `stmts` in order, stopping after the first `return` statement.

    The result of every visit is stored in `self.last_ret`.

    Returns:
        bool: True if an `ast.Return` statement was reached (statements after
        it are not visited); False otherwise, including for an empty body.
    '''
    for stmt in stmts:
        self.last_ret = self.visit(stmt)
        if isinstance(stmt, ast.Return):
            return True
    # Always a real bool: the previous `stmts and isinstance(...)` form handed
    # back the empty list itself when there were no statements.
    return False
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = _triton.ir.module('', self.builder)
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.is_arg_lscope = dict() # name => is_arg: {str: bool}
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
@@ -77,6 +44,146 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
'isinstance': isinstance,
|
||||
'getattr': getattr,
|
||||
}
|
||||
# SSA-construction
|
||||
# [name, bb] => triton.language.tensor
|
||||
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
|
||||
# bb => {name => phi}
|
||||
self.incomplete_phis = {}
|
||||
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
|
||||
|
||||
def get_value(self, name):
    ''' Resolve `name` and return the value bound to it.

    Scopes are searched in order: local scope, global scope, builtins.
    If the bound value is a triton tensor that was not registered as a
    function argument, the tensor is fetched through `self._get_tensor()`;
    any other value (including argument tensors) is returned as found.

    Raises:
        ValueError: if `name` is in none of the three scopes.
    '''
    for scope in (self.lscope, self.gscope, self.builtins):
        if name in scope:
            found = scope[name]
            break
    else:
        raise ValueError(f'{name} is not defined')
    if not self.is_triton_tensor(found):
        return found
    if self.is_arg_lscope[name]:
        # arguments are returned directly from the scope entry
        return found
    return self._get_tensor(name)
|
||||
|
||||
def set_value(self, name: str,
              value: Union[triton.language.tensor, triton.language.constexpr],
              is_arg: bool = False) -> None:
    ''' Store a left value (lvalue) under `name`.

    Called with `is_arg=True` for function arguments in visit_FunctionDef()
    and with the default for assignments in visit_Assign().

    1. record the locally defined name (FIXME: should consider control flow)
    2. remember whether the name is a function argument
    3. for non-argument tensors, also store the tensor in self.lvalues for
       the builder's current insertion block
    '''
    self.lscope[name] = value
    # if this value is an argument, we don't need to create phis for it
    self.is_arg_lscope[name] = is_arg
    if not isinstance(value, triton.language.tensor):
        return
    if is_arg:
        return
    current_bb = self.builder.get_insert_block()
    self._set_value(name, current_bb, value)
|
||||
|
||||
#
|
||||
# SSA-construction
|
||||
#
|
||||
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
    ''' Return the tensor currently bound to `name` as seen from block `bb`.

    Args:
        name: variable name to look up.
        bb: basic block to look the name up in; when None, the builder's
            current insertion block is used.

    Returns:
        The tensor stored in `self.lvalues` for `(name, bb)` if there is one
        (local value numbering); otherwise the result of
        `self._get_tensor_recursive()` (global value numbering).
    '''
    # `is None` rather than truthiness: only a missing argument selects the default
    if bb is None:
        bb = self.builder.get_insert_block()
    # local value numbering
    key = (name, bb)
    if key in self.lvalues:
        return self.lvalues[key]
    # global value numbering; the recursive lookup goes through _make_phi(),
    # which moves the builder's insertion point, so put it back afterwards
    # even when the lookup raises (e.g. on an undefined name)
    saved_insert_point = self.builder.get_insert_point()
    try:
        return self._get_tensor_recursive(name, bb)
    finally:
        self.builder.set_insert_point(saved_insert_point)
|
||||
|
||||
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
    ''' Find the tensor for `name` in `bb` by looking at `bb`'s predecessors.

    Three cases, in this order:
    - `bb` is not sealed: a phi is created as a proxy for the value and
      remembered in `self.incomplete_phis[bb][name]`.
    - `bb` is sealed with exactly one predecessor: the value is taken from
      that predecessor, no phi is needed.
    - `bb` is sealed with several predecessors: a phi is created, recorded
      for `(name, bb)`, and then its operands are filled in.

    The result is recorded for `(name, bb)` via `self._set_value()` before
    being returned.  A sealed block without predecessors fails the assertion
    (the name is undefined).
    '''
    predecessors = bb.get_predecessors()
    value_type = self.lscope[name].type
    if bb not in self.sealed_blocks:
        # some preds haven't been filled, create a phi as a proxy of the value
        value = self._make_phi(value_type, len(predecessors), bb)
        self.incomplete_phis.setdefault(bb, {})[name] = value
    elif len(predecessors) == 1:
        # one predecessor: no phi needed, try get value from pred
        value = self._get_tensor(name, predecessors[0])
    else:
        # multiple preds
        assert len(predecessors) > 1, f'{name} is an undefined name (cannot find in the entry block)'
        pending_phi = self._make_phi(value_type, len(predecessors), bb)
        # record the phi before filling its operands, so a lookup of `name`
        # that comes back to `bb` finds it in self.lvalues
        self._set_value(name, bb, pending_phi)
        value = self._add_phi_operands(name, pending_phi)
    self._set_value(name, bb, value)
    return value
|
||||
|
||||
# returns a new phi tensor, which encapsulates an ir.phi_node
|
||||
def _make_phi(self,
              type: triton.language.dtype,
              num_values: int,
              bb: _triton.ir.basic_block) -> triton.language.tensor:
    ''' Create a phi node in `bb` and return it wrapped as a tensor.

    Args:
        type: frontend type of the phi; converted with `type.to_ir()`.
        num_values: passed to `builder.create_phi()` (callers pass the
            number of predecessors of `bb`).
        bb: basic block that receives the phi.

    Note: this moves the builder's insertion point; callers that care must
    save and restore it themselves.
    '''
    first_non_phi = bb.get_first_non_phi()
    # position the builder at the block's first non-phi instruction
    self.builder.set_insert_point((bb, first_non_phi))
    phi_handle = self.builder.create_phi(type.to_ir(self.builder), num_values)
    if first_non_phi:
        # NOTE(review): presumably moves insertion back to the end of `bb`
        # when the block already had non-phi instructions -- confirm
        self.builder.set_insert_block(bb)
    return triton.language.tensor(phi_handle, type)
|
||||
|
||||
# complete a phi node. (TODO: rename this as _complete_phis?)
|
||||
# Note: since we try to remove trivial phis, the returned tensor might not be a phi
|
||||
def _add_phi_operands(self, name: str,
                      phi: triton.language.tensor) -> triton.language.tensor:
    ''' Fill in the incoming values of `phi` for variable `name`.

    For every predecessor of the phi's parent block, the tensor bound to
    `name` in that predecessor is looked up and added as an incoming value.
    The phi is then passed through `self._try_remove_trivial_phi()`, so the
    returned tensor is not necessarily the phi itself.
    '''
    parent_bb = phi.handle.get_parent()
    for pred_bb in parent_bb.get_predecessors():
        incoming = self._get_tensor(name, pred_bb)
        phi.handle.add_incoming(incoming.handle, pred_bb)
    return self._try_remove_trivial_phi(phi)
|
||||
|
||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
    ''' Record `value` as the tensor bound to `name` in basic block `bb`.

    The tensor is stored in `self.lvalues` under the `(name, bb)` key, and
    its handle is registered with the module as instruction metadata for
    `name`.
    '''
    self.lvalues[name, bb] = value
    # TODO: why we need this?
    self.module.set_instr_metadata(name, value.handle)
|
||||
|
||||
def _seal_block(self, bb: _triton.ir.basic_block):
    ''' Mark `bb` as sealed after completing its incomplete phis.

    Every phi recorded in `self.incomplete_phis[bb]` gets its operands filled
    in.  If the tensor currently bound to the name in `bb` is still that phi
    (same handle), the binding is updated to the completed result, which may
    differ from the phi when the phi turned out to be trivial.  The entry
    for `bb` is then dropped from `self.incomplete_phis` and `bb` is added to
    `self.sealed_blocks`.
    '''
    if bb in self.incomplete_phis:
        # complete all incomplete phis
        for var_name, pending_phi in self.incomplete_phis[bb].items():
            completed = self._add_phi_operands(var_name, pending_phi)
            # it's possible that this phi is trivial
            bound_now = self._get_tensor(var_name, bb)
            if bound_now.handle == pending_phi.handle:
                self._set_value(var_name, bb, completed)
        del self.incomplete_phis[bb]
    self.sealed_blocks.add(bb)
|
||||
|
||||
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
    ''' Replace `phi` by its only distinct operand, if it has exactly one.

    Operands equal to the phi's own handle are ignored.  When exactly one
    distinct operand remains, all uses of the phi are redirected to it, the
    phi is erased from its parent, and a tensor wrapping that operand (with
    the phi's type) is returned.  Otherwise `phi` is returned unchanged.
    '''
    distinct_operands = set()
    for operand in phi.handle.ops():
        if operand != phi.handle:
            distinct_operands.add(operand)
    if len(distinct_operands) != 1:
        # non-trivial phi
        return phi
    (replacement,) = distinct_operands
    phi.handle.replace_all_uses_with(replacement)
    phi.handle.erase_from_parent()
    # TODO: remove trivial phis recursively
    return triton.language.tensor(replacement, phi.type)
|
||||
|
||||
def is_triton_tensor(self, value):
    ''' Return True when `value` is a `triton.language.tensor` instance. '''
    tensor_cls = triton.language.tensor
    return isinstance(value, tensor_cls)
|
||||
|
||||
#
|
||||
# AST visitor
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
    ''' Visit `stmts` in order, stopping after the first `return` statement.

    The result of every visit is stored in `self.last_ret`.

    Returns:
        bool: True if an `ast.Return` statement was reached (statements after
        it are not visited); False otherwise, including for an empty body.
    '''
    for stmt in stmts:
        self.last_ret = self.visit(stmt)
        if isinstance(stmt, ast.Return):
            return True
    # Always a real bool: the previous `stmts and isinstance(...)` form handed
    # back the empty list itself when there were no statements.
    return False
|
||||
|
||||
def visit_Module(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@@ -113,7 +220,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if inline:
|
||||
pass
|
||||
else:
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder))
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -130,17 +237,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.set_value(arg_name, arg_value, is_arg=True)
|
||||
if inline:
|
||||
self.visit_compound_statement(node.body)
|
||||
return self.last_ret
|
||||
else:
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self.module.seal_block(entry)
|
||||
self._seal_block(entry)
|
||||
self.builder.set_insert_block(entry)
|
||||
# visit function body
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -187,11 +294,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
for name, value in zip(names, values):
|
||||
# TODO: can we store constexpr here to support constant folding?
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language.core._to_ir(value, self.builder)
|
||||
if not isinstance(value, triton.language.tensor):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
@@ -220,9 +328,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
@@ -238,9 +346,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_object(lhs):
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -248,15 +356,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.block):
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
current_bb = self.builder.get_insert_block()
|
||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
self.module.seal_block(then_bb)
|
||||
self._seal_block(then_bb)
|
||||
if else_bb:
|
||||
self.module.seal_block(else_bb)
|
||||
self._seal_block(else_bb)
|
||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
else:
|
||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
@@ -271,7 +379,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: last statement is a terminator?
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self.module.seal_block(endif_bb)
|
||||
self._seal_block(endif_bb)
|
||||
self.builder.set_insert_block(endif_bb)
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
@@ -296,9 +404,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(node.ops) == 1
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
@@ -312,9 +420,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_object(lhs):
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -325,21 +433,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.core.constexpr):
|
||||
if isinstance(op, triton.language.constexpr):
|
||||
op = op.value
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_object(op):
|
||||
if self.is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
cond = self.visit(node.test)
|
||||
@@ -350,9 +458,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -362,7 +470,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_object(lhs):
|
||||
if self.is_triton_tensor(lhs):
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
@@ -405,8 +513,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
self.visit(step_node)
|
||||
@@ -421,9 +529,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: handle case where body breaks control flow
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -451,7 +559,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
@@ -581,7 +689,7 @@ class Kernel:
|
||||
}
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
return type_names[obj.dtype]
|
||||
if isinstance(obj, triton.language.core.constexpr):
|
||||
if isinstance(obj, triton.language.constexpr):
|
||||
obj = obj.value
|
||||
if isinstance(obj, int):
|
||||
if -2**31 <= obj < 2**31:
|
||||
@@ -613,34 +721,34 @@ class Kernel:
|
||||
return 'scalar', name
|
||||
|
||||
@staticmethod
def _to_triton_ir(obj):
    ''' Map a `(which, name)` argument descriptor to a frontend type.

    Args:
        obj: a pair `(which, name)`; `name` is a type code that must be a key
            of the table below.

    Returns:
        For `which == 'ptr'`, a `triton.language.pointer_type` whose element
        type is the one named by `name`, built with `1` as its second
        argument; otherwise the `triton.language` dtype named by `name`.

    Raises:
        KeyError: if `name` is not a known type code.
    '''
    which, name = obj
    type_map = {
        'I': triton.language.int32,
        'L': triton.language.int64,
        'f': triton.language.float32,
        'B': triton.language.int1,
        'f8': triton.language.float8,
        'f16': triton.language.float16,
        'bf16': triton.language.bfloat16,
        'f32': triton.language.float32,
        'f64': triton.language.float64,
        'i1': triton.language.int1,
        'i8': triton.language.int8,
        'i16': triton.language.int16,
        'i32': triton.language.int32,
        'i64': triton.language.int64,
        'u8': triton.language.uint8,
        'u16': triton.language.uint16,
        'u32': triton.language.uint32,
        'u64': triton.language.uint64,
    }
    elt_ty = type_map[name]
    # torch.Tensor arguments become pointers to their element type
    if which == 'ptr':
        return triton.language.pointer_type(elt_ty, 1)
    # default path returns the triton.language dtype directly; conversion to
    # an IR type no longer happens in this function
    return elt_ty
|
||||
|
||||
@staticmethod
|
||||
def pow2_divisor(N):
|
||||
@@ -920,25 +1028,31 @@ class JITFunction:
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
# Called by CodeGenerator.visit_Call()
|
||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||
try:
|
||||
from inspect import getcallargs
|
||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||
arg_values = [arg_values[name] for name in self.arg_names]
|
||||
arg_values = [arg if isinstance(arg, triton.language.block)
|
||||
arg_values = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in arg_values]
|
||||
|
||||
# Record values in the caller (parent scope)
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
types = generator.module.get_types().copy()
|
||||
|
||||
# TODO: clear values other than args
|
||||
lvalues = generator.lvalues.copy()
|
||||
# types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
generator.module.set_values(values)
|
||||
generator.module.set_types(types)
|
||||
|
||||
generator.lvalues = lvalues
|
||||
# generator.module.set_types(types)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
@@ -1023,9 +1137,9 @@ class JITFunction:
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
|
||||
ret_type = _triton.ir.type.get_void(context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
|
Reference in New Issue
Block a user