[LANG] Added support for constexpr (#361)
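This change lets `triton.language.constexpr` values be passed and annotated as ordinary kernel arguments instead of being smuggled in through `**meta` keyword parameters. A minimal usage sketch of what the new annotation enables (the kernel and its names are illustrative, not part of this commit):

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(dst, src, N, BLOCK: tl.constexpr):
        # BLOCK is a compile-time constant, so it can size tl.arange
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

    # the grid callable now receives the constexpr arguments keyed by name:
    # copy_kernel[lambda META: (triton.cdiv(N, META['BLOCK']),)](dst, src, N, BLOCK=1024)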
@@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                cst = self.constants[i]
+                if not isinstance(cst, triton.language.constexpr):
+                    cst = triton.language.constexpr(self.constants[i])
                 arg_values.append(cst)
             else:
                 if i in self.attributes:
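Note: compile-time constants are now normalized to `triton.language.constexpr` before being bound as argument values. A minimal stand-in showing the wrap-if-needed idiom (the class below is an illustrative model, not Triton's actual implementation):

    class Constexpr:
        # illustrative model of triton.language.constexpr: a thin box around a value
        def __init__(self, value):
            self.value = value

    cst = 128                    # a raw constant captured by the JIT
    if not isinstance(cst, Constexpr):
        cst = Constexpr(cst)     # same idiom as the hunk above
    assert cst.value == 128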
@@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor):
                     fn.add_attr(i + 1, attr)
                 fn.args[i].name = arg_name
                 arg_values.append(fn.args[i])
+
         for arg_name, arg_value in zip(arg_names, arg_values):
             self.set_value(arg_name, arg_value)
         if inline:
@@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor):
         ast.NodeVisitor.generic_visit(self, node)
         return node.arg
 
+    def visit_AnnAssign(self, node):
+        # extract attributes
+        annotation = self.visit(node.annotation)
+        target = self.visit(node.target)
+        value = self.visit(node.value)
+        # constexpr
+        if annotation == triton.language.constexpr:
+            if target in self.lscope:
+                raise ValueError(f'{target} is already defined.'
+                                 f' constexpr cannot be reassigned.')
+            self.lscope[target] = triton.language.constexpr(value)
+            return self.lscope[target]
+        # default: call visit_Assign
+        return self.visit_Assign(node)
+
+
     def visit_Assign(self, node):
         _names = []
         for target in node.targets:
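Note: the new `visit_AnnAssign` intercepts annotated assignments such as `BLOCK: tl.constexpr = 128` inside a kernel body. The AST shape it relies on is plain Python:

    import ast

    node = ast.parse("BLOCK: tl.constexpr = 128").body[0]
    print(type(node).__name__)  # AnnAssign
    # node.annotation -> the tl.constexpr expression
    # node.target     -> the name BLOCK
    # node.value      -> the literal 128

Anything not annotated with `triton.language.constexpr` falls through to the ordinary `visit_Assign` path.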
@@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor):
         if not isinstance(values, tuple):
             values = [values]
         for name, value in zip(names, values):
+            # by default, constexpr are assigned into python variable
+            if isinstance(value, triton.language.constexpr):
+                value = value.value
             if not isinstance(value, triton.language.block):
                 value = triton.language.core._to_ir(value, self.builder)
             self.set_value(name, value)
@@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor):
             ast.BitOr: '__or__',
             ast.BitXor: '__xor__',
         }[type(node.op)]
-        kws = dict()
-
         if self.is_triton_object(lhs):
-            kws['_builder'] = self.builder
-        ret = getattr(lhs, fn)(rhs, **kws)
-        if ret is NotImplemented:
-            if self.is_triton_object(rhs):
-                kws['_builder'] = self.builder
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_object(rhs):
             fn = fn[:2] + 'r' + fn[2:]
-            ret = getattr(rhs, fn)(lhs, **kws)
-        return ret
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)
 
     def visit_If(self, node):
         cond = self.visit(node.test)
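Note: the rewritten `visit_BinOp` dispatches to whichever operand is a Triton object; when only the right-hand side is, the dunder name is rewritten to its reflected form, mirroring Python's own `__radd__` protocol:

    fn = '__add__'
    rfn = fn[:2] + 'r' + fn[2:]
    print(rfn)  # '__radd__'

When neither side is a Triton object, both are plain Python values (e.g. unwrapped constexprs), so the operation folds at compile time.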
@@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor):
         assert len(node.ops) == 1
         lhs = self.visit(node.left)
         rhs = self.visit(node.comparators[0])
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Eq: '__eq__',
             ast.NotEq: '__ne__',
@@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor):
 
     def visit_UnaryOp(self, node):
         op = self.visit(node.operand)
+        if isinstance(op, triton.language.core.constexpr):
+            op = op.value
         fn = {
             ast.USub: '__neg__',
             ast.UAdd: '__pos__',
@@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor):
         return fn(*args, **kws)
 
     def visit_Num(self, node):
-        return node.n
+        return triton.language.constexpr(node.n)
 
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
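Note: every numeric literal in a kernel is now wrapped in a constexpr, so literal arithmetic stays in Python until a Triton object gets involved. On the Python versions of that era, `node.n` is simply the literal's value:

    import ast

    node = ast.parse("128", mode="eval").body  # an ast.Num node (pre-3.8 style)
    print(node.n)  # 128 -- previously returned bare; now wrapped in constexpr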
@@ -477,6 +505,8 @@ class Kernel:
         }
         if hasattr(obj, 'data_ptr'):
             return type_names[obj.dtype]
+        if isinstance(obj, triton.language.core.constexpr):
+            obj = obj.value
         if isinstance(obj, int):
             if abs(obj) <= 0xffffffff:
                 return 'I'
@@ -485,6 +515,8 @@ class Kernel:
             return 'f'
         if isinstance(obj, bool):
             return 'B'
+        if isinstance(obj, str):
+            return 'str'
         assert False
 
 
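Note: the single-character returns feed `struct.pack` further down. Their sizes, for reference (standard `struct` behavior):

    import struct

    print(struct.calcsize('I'))  # 4 -- unsigned int, 32-bit integer scalars
    print(struct.calcsize('f'))  # 4 -- float
    print(struct.calcsize('B'))  # 1 -- unsigned char, used here for bool
    print(struct.calcsize('P'))  # 8 on 64-bit builds -- void*, tensor pointers

'str' is not a struct code; string-valued (constexpr) arguments never reach the pack step, since a later hunk filters constexpr arguments out of the packed parameter list.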
@@ -537,7 +569,8 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
+        wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
         # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -547,7 +580,7 @@ class Kernel:
         # generate Triton-IR
         # export symbols visible from self.fn into code-generator object
         gscope = sys.modules[self.fn.module].__dict__
-        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
+        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
         try:
             generator.visit(self.fn.parse())
         except Exception as e:
@@ -566,7 +599,19 @@ class Kernel:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
+        # handle arguments passed by name
+        kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
+        wargs = list(wargs)
+        for i, pos in enumerate(sorted(kwargs)):
+            wargs.insert(pos + i, kwargs[pos])
+        if len(wargs) != len(self.fn.arg_names):
+            raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
+        # handle annotations
+        for name, type in self.fn.__annotations__.items():
+            pos = self.fn.arg_names.index(name)
+            assert type == triton.language.core.constexpr
+            wargs[pos] = type(wargs[pos])
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
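Note: the new `__call__` first converts keyword arguments into their positional slots. A self-contained model of that merge (names are illustrative):

    arg_names = ['x', 'y', 'BLOCK']
    wargs = ['x_val', 'y_val']      # passed positionally
    kwargs = {'BLOCK': 128}         # passed by name
    by_pos = {arg_names.index(k): v for k, v in kwargs.items()}
    for i, pos in enumerate(sorted(by_pos)):
        wargs.insert(pos + i, by_pos[pos])
    print(wargs)  # ['x_val', 'y_val', 128]

Because Python callers must pass positional arguments as a prefix of the parameter list, the sorted keyword positions always land at or beyond the current end of `wargs`, so the inserts behave like ordered appends.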
@@ -601,18 +646,19 @@ class Kernel:
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
                       if isinstance(a, int) and i not in self.fn.do_not_specialize}
+
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
+        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
+
         # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = tuple(attributes.items())
-        meta_key = tuple(sorted(meta.items()))
         const_key = tuple(constants.items())
         compute_capability = torch.cuda.get_device_capability(device)
-
         key = (
             self.fn.cache_key, version_key(), compute_capability,
-            types_key, attr_key, num_warps, num_stages, meta_key, const_key
+            types_key, attr_key, num_warps, num_stages, const_key
         )
         key = repr(key)
 
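Note: specialization inputs now come from three places, all folded into the cache key: pointer-alignment attributes, integer arguments equal to one, and explicit constexpr values (while the old `meta_key` disappears). The body of `Kernel.pow2_divisor` is not shown in this diff; a plausible sketch of what such a helper computes, as an assumption rather than the commit's code:

    def pow2_divisor(n, cap=16):
        # largest power of two (capped) that divides n
        d = cap
        while n % d != 0:
            d //= 2
        return d

    print(pow2_divisor(24))  # 8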
@@ -644,7 +690,7 @@ class Kernel:
             binary = self._compile(
                 *wargs, device=device_idx, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages,
-                constants=constants, **meta
+                constants=constants,
             )
             if bin_cache_path:
                 assert bin_lock_path is not None
@@ -657,12 +703,15 @@ class Kernel:
 
             drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
-        params = struct.pack(fmt, *args)
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
+        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
         # enqueue cached function into stream
         callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
-        grid = grid(meta) if hasattr(grid, '__call__') else grid
+        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
+        grid = grid(csts) if hasattr(grid, '__call__') else grid
+        if isinstance(grid, int):
+            grid = tuple(grid)
         callable(stream, params, *grid)
         return callable
 
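Note: the grid callable used to receive `meta`; it now receives a dict of the kernel's constexpr arguments keyed by parameter name, so the usual grid lambda keeps working. A runnable model (`cdiv` mirrors what `triton.cdiv` computes; the names are illustrative):

    def cdiv(a, b):
        return (a + b - 1) // b  # ceiling division

    N = 1000
    grid = lambda META: (cdiv(N, META['BLOCK']),)
    csts = {'BLOCK': 128}  # built from constexpr args, keyed by arg name
    print(grid(csts))      # (8,)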
@@ -697,31 +746,31 @@ class Autotuner:
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
-        conflicts = meta.keys() & config.meta.keys()
+        conflicts = meta.keys() & config.kwargs.keys()
         if conflicts:
             raise ValueError(
                 f"Conflicting meta-parameters: {', '.join(conflicts)}."
                 " Make sure that you don't re-define auto-tuned symbols."
             )
         # augment meta-parameters with tunable ones
-        current = dict(meta, **config.meta)
+        current = dict(meta, **config.kwargs)
         def kernel_call():
             self.hook(args)
             self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
 
-    def __call__(self, *args, **meta):
+    def __call__(self, *args, **kwargs):
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
-                timings = {config: self._bench(*args, config=config, **meta) \
+                timings = {config: self._bench(*args, config=config, **kwargs) \
                            for config in self.configs}
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
             config = self.cache[key]
         else:
             config = self.configs[0]
-        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
 
 @functools.lru_cache()
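Note: the `meta` attribute of `Config` is renamed to `kwargs` throughout the autotuner. A standalone model of the renamed pieces, mirroring the diff:

    class Config:
        def __init__(self, kwargs, num_warps=4, num_stages=2):
            self.kwargs = kwargs
            self.num_warps = num_warps
            self.num_stages = num_stages

    cfg = Config({'BLOCK': 128}, num_warps=4)
    current = dict({'OTHER': 1}, **cfg.kwargs)  # how _bench augments meta-parameters
    print(current)  # {'OTHER': 1, 'BLOCK': 128}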
@@ -769,6 +818,8 @@ class JITFunction:
         # when called with a grid using __getitem__
         self.kernel_decorators = []
         self.kernel = None
+        # annotations
+        self.__annotations__ = fn.__annotations__
         # forward docs
         self.__doc__ = fn.__doc__
 
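Note: forwarding `fn.__annotations__` is what lets `Kernel.__call__` find constexpr-annotated parameters by name. A plain-Python illustration (the string annotation stands in for `tl.constexpr`):

    def fake_kernel(x, BLOCK: 'tl.constexpr'):
        pass

    print(fake_kernel.__annotations__)  # {'BLOCK': 'tl.constexpr'}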
@@ -839,8 +890,8 @@ class Config:
         Mostly useful for matrix multiplication workloads on SM80+ GPUs.
     :type num_stages: int
     """
-    def __init__(self, meta, num_warps=4, num_stages=2):
-        self.meta = meta
+    def __init__(self, kwargs, num_warps=4, num_stages=2):
+        self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_stages = num_stages
 