[LANG] Added support for constexpr (#361)

Authored by Philippe Tillet on 2021-10-30 00:32:58 -07:00; committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions
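
In short, this change lets a kernel declare its compile-time constants through triton.language.constexpr annotations instead of the old **meta keyword mechanism. A minimal usage sketch in the spirit of the tutorial updates bundled with this PR (the kernel and argument names below are illustrative, not taken from the diff):

import triton
import triton.language as tl

@triton.jit
def copy(X, Y, N, BLOCK: tl.constexpr):
    # BLOCK arrives as a triton.language.constexpr: its value is folded into
    # the generated IR and becomes part of the compilation cache key.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)

# BLOCK is passed like any other argument; the runtime strips it from the
# packed kernel parameters and bakes its value into the compiled binary:
#   copy[(triton.cdiv(n, 1024),)](x, y, n, BLOCK=1024)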


@@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor):
arg_values = []
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = triton.language.core._to_ir(self.constants[i], self.builder)
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
if i in self.attributes:
@@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor):
fn.add_attr(i + 1, attr)
fn.args[i].name = arg_name
arg_values.append(fn.args[i])
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
@@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor):
ast.NodeVisitor.generic_visit(self, node)
return node.arg
def visit_AnnAssign(self, node):
# extract attributes
annotation = self.visit(node.annotation)
target = self.visit(node.target)
value = self.visit(node.value)
# constexpr
if annotation == triton.language.constexpr:
if target in self.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
self.lscope[target] = triton.language.constexpr(value)
return self.lscope[target]
# default: call visit_Assign
return self.visit_Assign(node)
def visit_Assign(self, node):
_names = []
for target in node.targets:
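
The new visit_AnnAssign also covers annotated assignments inside a kernel body: the name is bound in lscope as a constexpr, and the check above forbids rebinding it. A hedged sketch of the pattern this enables (illustrative kernel, not part of the diff; it relies on the accompanying triton.language changes in this commit):

import triton
import triton.language as tl

@triton.jit
def double(X, Y, N, BLOCK: tl.constexpr):
    # handled by visit_AnnAssign: SCALE lives in lscope as a constexpr, and
    # annotating it a second time would raise "SCALE is already defined."
    SCALE: tl.constexpr = 2
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask) * SCALE, mask=mask)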
@@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
# by default, constexpr values are assigned to plain Python variables
if isinstance(value, triton.language.constexpr):
value = value.value
if not isinstance(value, triton.language.block):
value = triton.language.core._to_ir(value, self.builder)
self.set_value(name, value)
@@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
@@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor):
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
kws = dict()
if self.is_triton_object(lhs):
kws['_builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented:
if self.is_triton_object(rhs):
kws['_builder'] = self.builder
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws)
return ret
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_If(self, node):
cond = self.visit(node.test)
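
The reflected-operator fallback above is terse; fn[:2] + 'r' + fn[2:] simply splices an 'r' after the two leading underscores so the operation is retried on the right operand:

fn = '__add__'
fn = fn[:2] + 'r' + fn[2:]
assert fn == '__radd__'  # used when only the right operand is a Triton object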
@@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
@@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
if isinstance(op, triton.language.core.constexpr):
op = op.value
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor):
return fn(*args, **kws)
def visit_Num(self, node):
return node.n
return triton.language.constexpr(node.n)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
@@ -477,6 +505,8 @@ class Kernel:
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
@@ -485,6 +515,8 @@ class Kernel:
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False
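
Everything this file needs from the wrapper is an isinstance check and a .value attribute; a minimal sketch of triton.language.constexpr under that assumption (the real class in triton/language/core.py may carry more behavior):

class constexpr:
    """Tags a Python value as a compile-time constant so the code generator
    and the runtime can tell it apart from ordinary kernel arguments."""
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f'constexpr[{self.value}]'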
@@ -537,7 +569,8 @@ class Kernel:
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
@@ -547,7 +580,7 @@ class Kernel:
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.fn.parse())
except Exception as e:
@@ -566,7 +599,19 @@ class Kernel:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
if len(wargs) != len(self.fn.arg_names):
raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# handle annotations
for name, type in self.fn.__annotations__.items():
pos = self.fn.arg_names.index(name)
assert type == triton.language.core.constexpr
wargs[pos] = type(wargs[pos])
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
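
To make the by-name handling concrete, here is a trace of the remapping above, assuming arg_names = ['X', 'Y', 'N', 'BLOCK'] and a call like kernel[grid](x, y, N=n, BLOCK=128) (all names hypothetical):

arg_names = ['X', 'Y', 'N', 'BLOCK']
wargs = ['x', 'y']                       # positional arguments
kwargs = {2: 'n', 3: 128}                # name -> position via arg_names.index
for i, pos in enumerate(sorted(kwargs)):
    wargs.insert(pos + i, kwargs[pos])
print(wargs)                             # ['x', 'y', 'n', 128]
# BLOCK is annotated as tl.constexpr, so wargs[3] is then wrapped: constexpr(128)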
@@ -601,18 +646,19 @@ class Kernel:
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
# compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = tuple(attributes.items())
meta_key = tuple(sorted(meta.items()))
const_key = tuple(constants.items())
compute_capability = torch.cuda.get_device_capability(device)
key = (
self.fn.cache_key, version_key(), compute_capability,
types_key, attr_key, num_warps, num_stages, meta_key, const_key
types_key, attr_key, num_warps, num_stages, const_key
)
key = repr(key)
@@ -644,7 +690,7 @@ class Kernel:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants, **meta
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
@@ -657,12 +703,15 @@ class Kernel:
drv_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args)
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
# enqueue cached function into stream
callable = drv_cache[key]
stream = torch.cuda.current_stream(device_idx).cuda_stream
grid = grid(meta) if hasattr(grid, '__call__') else grid
csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
grid = grid(csts) if hasattr(grid, '__call__') else grid
if isinstance(grid, int):
grid = (grid,)
callable(stream, params, *grid)
return callable
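
The grid callable now receives the dict of constexpr arguments keyed by name (csts above) instead of the old **meta dict, so the familiar launch pattern keeps working. A hedged host-side example, reusing the copy kernel sketched near the top of this diff:

import torch
import triton

x = torch.rand(4096, device='cuda')
y = torch.empty_like(x)
# META here is the csts dict built above, e.g. {'BLOCK': 1024}
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK']),)
copy[grid](x, y, x.numel(), BLOCK=1024)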
@@ -697,31 +746,31 @@ class Autotuner:
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.meta.keys()
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
current = dict(meta, **config.kwargs)
def kernel_call():
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
def __call__(self, *args, **kwargs):
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
timings = {config: self._bench(*args, config=config, **meta) \
timings = {config: self._bench(*args, config=config, **kwargs) \
for config in self.configs}
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
config = self.cache[key]
else:
config = self.configs[0]
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@functools.lru_cache()
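
With Config.meta renamed to Config.kwargs, tuned values are forwarded to the kernel as ordinary keyword arguments that land on constexpr parameters. A hedged sketch of the corresponding decorator usage (names illustrative):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 128}, num_warps=4),
        triton.Config({'BLOCK': 256}, num_warps=8),
    ],
    key=['N'],
)
@triton.jit
def tuned_copy(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)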
@@ -769,6 +818,8 @@ class JITFunction:
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
# annotations
self.__annotations__ = fn.__annotations__
# forward docs
self.__doc__ = fn.__doc__
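
Forwarding fn.__annotations__ is what lets Kernel.__call__ find the constexpr parameters by name; for a signature like the sketches above it is just the standard Python mapping:

import triton.language as tl

def copy(X, Y, N, BLOCK: tl.constexpr): ...
print(copy.__annotations__)  # e.g. {'BLOCK': <class 'triton.language.core.constexpr'>}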
@@ -839,8 +890,8 @@ class Config:
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
"""
def __init__(self, meta, num_warps=4, num_stages=2):
self.meta = meta
def __init__(self, kwargs, num_warps=4, num_stages=2):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages