2021-04-20 22:29:40 -04:00
|
|
|
import ast
|
|
|
|
import builtins
|
2021-08-21 06:00:54 +02:00
|
|
|
import inspect
|
|
|
|
import struct
|
2021-04-20 22:29:40 -04:00
|
|
|
import sys
|
2021-08-21 06:00:54 +02:00
|
|
|
import tempfile
|
2021-04-20 22:29:40 -04:00
|
|
|
import textwrap
|
2021-09-18 22:48:26 -07:00
|
|
|
import hashlib
|
|
|
|
import os
|
|
|
|
import shelve
|
2021-09-21 14:10:02 -07:00
|
|
|
import shutil
|
|
|
|
import os
|
|
|
|
from .tools.disasm import extract
|
|
|
|
import tempfile
|
2021-08-21 06:00:54 +02:00
|
|
|
import torch
|
|
|
|
import triton
|
|
|
|
import triton._C.libtriton.triton as _triton
|
2021-09-21 14:10:02 -07:00
|
|
|
from filelock import FileLock
|
|
|
|
import dbm
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
|
|
|
class CodeGenerator(ast.NodeVisitor):
    """Lowers the Python AST of a @triton.jit function to Triton-IR.

    Each ``visit_*`` method handles one AST node kind, appending instructions
    to ``self.module`` through ``self.builder`` (a ``_triton.ir.builder``).
    Name lookup order is: local scope -> global scope -> builtins (see
    ``get_value``). Control flow (``if``/``for``/``while``) is lowered to
    explicit basic blocks that are sealed once all their predecessors are known.
    """

    def get_value(self, name):
        """Resolve ``name`` and return its value; raises ValueError if unbound.

        For triton blocks, the SSA handle is re-fetched from the module so the
        returned value reflects the latest definition at the insertion point.
        """
        # search node.id in local scope
        ret = None
        if name in self.lscope:
            ret = self.lscope[name]
        # search node.id in global scope
        elif name in self.gscope:
            ret = self.gscope[name]
        # search node.id in builtins
        elif name in self.builtins:
            ret = self.builtins[name]
        else:
            raise ValueError(f'{name} is not defined')
        if isinstance(ret, triton.language.block):
            # re-materialize the current SSA value tracked by the module
            handle = self.module.get_value(name)
            return triton.language.block(handle)
        return ret

    def set_value(self, name, value):
        """Bind ``name`` to ``value`` in the local scope (and in the IR module
        when the value is a triton block / raw IR value)."""
        if isinstance(value, _triton.ir.value):
            # wrap raw IR values so downstream code sees a uniform block type
            value = triton.language.block(value)
        if isinstance(value, triton.language.block):
            self.module.set_value(name, value.handle)
            self.module.set_type(name, value.handle.type)
        self.lscope[name] = value

    def is_triton_object(self, value):
        """True if ``value`` is a triton block (i.e. lives in the IR)."""
        return isinstance(value, triton.language.block)

    def visit_compound_statement(self, stmts):
        """Visit a statement list, stopping at the first ``return``.

        Returns truthy iff the list is non-empty and the last visited
        statement was an ``ast.Return`` (i.e. the block ended in a terminator).
        """
        for stmt in stmts:
            self.last_ret = self.visit(stmt)
            if isinstance(stmt, ast.Return):
                break
        # NOTE: relies on `stmt` from the loop; short-circuits on empty lists
        return stmts and isinstance(stmt, ast.Return)

    def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
        """Create a generator for one kernel compilation.

        :param context: _triton.ir.context the module is built in
        :param prototype: IR function type of the kernel being compiled
        :param gscope: globals of the module defining the jitted function
        :param attributes: {arg index: alignment/multiple-of hint}
        :param constants: {arg index: compile-time constant value}
        :param kwargs: meta-parameters, exposed to the kernel as **kwargs
        """
        self.builder = _triton.ir.builder(context)
        self.module = _triton.ir.module('', self.builder)
        self.prototype = prototype
        self.gscope = gscope
        self.lscope = dict()
        self.attributes = attributes
        self.constants = constants
        self.kwargs = kwargs
        self.last_node = None
        # Python builtins that are re-interpreted inside kernels; note that
        # `min` maps to triton.language.minimum (element-wise), not Python min.
        self.builtins = {
            'range': range,
            'min': triton.language.minimum,
            'float': float,
            'int': int,
            'print': print,
            'isinstance': isinstance,
            'getattr': getattr,
        }

    def visit_Module(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_List(self, node):
        ctx = self.visit(node.ctx)
        assert ctx is None
        elts = [self.visit(elt) for elt in node.elts]
        return elts

    # By design, only non-kernel functions can return
    def visit_Return(self, node):
        ret = self.visit(node.value)
        if ret is None:
            return self.builder.ret_void()
        return ret

    def visit_FunctionDef(self, node, inline=False, arg_values=None):
        """Lower a function definition.

        When ``inline`` is True the body is visited in the current insertion
        point with ``arg_values`` pre-bound (used for calls to other jitted
        functions); otherwise this creates the kernel entry function.
        """
        arg_names, kwarg_names = self.visit(node.args)
        # store keyword arguments in local scope
        self.lscope[kwarg_names] = self.kwargs
        # initialize function
        if inline:
            pass
        else:
            fn = self.module.get_or_insert_function(node.name, self.prototype)
            arg_values = []
            for i, arg_name in enumerate(arg_names):
                if i in self.constants:
                    # constant args are substituted directly, not passed
                    arg_values.append(self.constants[i])
                else:
                    if i in self.attributes:
                        # attach alignment (pointers) / multiple-of (ints) hints
                        is_ptr = fn.args[i].type.is_ptr()
                        attr = 'aligned' if is_ptr else 'multiple_of'
                        attr = getattr(_triton.ir.attribute_kind, attr)
                        attr = _triton.ir.attribute(attr, self.attributes[i])
                        # attribute indices are 1-based in the IR
                        fn.add_attr(i + 1, attr)
                    fn.args[i].name = arg_name
                    arg_values.append(fn.args[i])
        for arg_name, arg_value in zip(arg_names, arg_values):
            self.set_value(arg_name, arg_value)
        if inline:
            self.visit_compound_statement(node.body)
            return self.last_ret
        else:
            entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
            self.module.seal_block(entry)
            self.builder.set_insert_block(entry)
            # visit function body
            self.visit_compound_statement(node.body)
            # finalize function
            self.builder.ret_void()

    def visit_arguments(self, node):
        """Return ([positional arg names], **kwargs name or None)."""
        arg_names = []
        for arg in node.args:
            arg_names += [self.visit(arg)]
        kwarg_names = self.visit(node.kwarg)
        return arg_names, kwarg_names

    def visit_arg(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        return node.arg

    def visit_Assign(self, node):
        """Lower ``a = v`` and tuple-unpacking ``a, b = u, v`` assignments."""
        _names = []
        for target in node.targets:
            _names += [self.visit(target)]
        # chained assignment (a = b = v) is not supported
        assert len(_names) == 1
        names = _names[0]
        values = self.visit(node.value)
        if not isinstance(names, tuple):
            names = [names]
        if not isinstance(values, tuple):
            values = [values]
        for name, value in zip(names, values):
            if not isinstance(value, triton.language.block):
                # promote plain Python values to IR constants
                value = triton.language.core._to_ir(value, self.builder)
            self.set_value(name, value)

    def visit_AugAssign(self, node):
        # desugar `x += v` into `x = x + v` and re-visit
        name = node.target.id
        lhs = ast.Name(id=name, ctx=ast.Load())
        rhs = ast.BinOp(lhs, node.op, node.value)
        assign = ast.Assign(targets=[node.target], value=rhs)
        self.visit(assign)
        return self.get_value(name)

    def visit_Name(self, node):
        # store targets just yield their name; loads resolve the value
        if type(node.ctx) == ast.Store:
            return node.id
        return self.get_value(node.id)

    def visit_Store(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Load(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Tuple(self, node):
        args = [self.visit(x) for x in node.elts]
        return tuple(args)

    def visit_BinOp(self, node):
        """Lower a binary operator via the matching dunder on the operands.

        Mirrors Python semantics: try ``lhs.__op__(rhs)`` first, and on
        ``NotImplemented`` fall back to the reflected ``rhs.__rop__(lhs)``.
        Triton-block operands additionally receive ``_builder``.
        """
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
            ast.Mult: '__mul__',
            ast.Div: '__truediv__',
            ast.FloorDiv: '__floordiv__',
            ast.Mod: '__mod__',
            ast.Pow: '__pow__',
            ast.LShift: '__lshift__',
            ast.RShift: '__rshift__',
            ast.BitAnd: '__and__',
            ast.BitOr: '__or__',
            ast.BitXor: '__xor__',
        }[type(node.op)]
        kws = dict()

        if self.is_triton_object(lhs):
            kws['_builder'] = self.builder
        ret = getattr(lhs, fn)(rhs, **kws)
        if ret is NotImplemented:
            if self.is_triton_object(rhs):
                kws['_builder'] = self.builder
            # e.g. '__add__' -> '__radd__'
            fn = fn[:2] + 'r' + fn[2:]
            ret = getattr(rhs, fn)(lhs, **kws)
        return ret

    def visit_If(self, node):
        """Lower an ``if``: IR branches for block conditions, Python control
        flow for plain (compile-time) conditions."""
        cond = self.visit(node.test)
        if self.is_triton_object(cond):
            current_bb = self.builder.get_insert_block()
            then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
            else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
            endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
            self.module.seal_block(then_bb)
            if else_bb:
                self.module.seal_block(else_bb)
                self.builder.cond_br(cond.handle, then_bb, else_bb)
            else:
                self.builder.cond_br(cond.handle, then_bb, endif_bb)
            self.builder.set_insert_block(then_bb)
            is_terminator = self.visit_compound_statement(node.body)
            # TODO: last statement is a terminator?
            if not is_terminator:
                self.builder.br(endif_bb)
            if else_bb:
                self.builder.set_insert_block(else_bb)
                is_terminator = self.visit_compound_statement(node.orelse)
                #TODO: last statement is a terminator?
                if not is_terminator:
                    self.builder.br(endif_bb)
            self.module.seal_block(endif_bb)
            self.builder.set_insert_block(endif_bb)
        else:
            # compile-time condition: only the taken branch is generated
            if cond:
                self.visit_compound_statement(node.body)
            else:
                self.visit_compound_statement(node.orelse)

    def visit_IfExp(self, node):
        # ternary with a compile-time condition; only one side is visited
        cond = self.visit(node.test)
        if cond:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)

    def visit_Pass(self, node):
        pass

    def visit_Compare(self, node):
        """Lower a single comparison (chained comparisons unsupported).

        NOTE(review): ``is`` / ``is not`` are mapped to value equality
        (__eq__ / __ne__), not identity.
        """
        assert len(node.comparators) == 1
        assert len(node.ops) == 1
        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        fn = {
            ast.Eq: '__eq__',
            ast.NotEq: '__ne__',
            ast.Lt: '__lt__',
            ast.LtE: '__le__',
            ast.Gt: '__gt__',
            ast.GtE: '__ge__',
            ast.Is: '__eq__',
            ast.IsNot: '__ne__',
        }[type(node.ops[0])]
        if self.is_triton_object(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif self.is_triton_object(rhs):
            # use the reflected comparison on the triton operand
            fn = fn[:2] + 'r' + fn[2:]
            return getattr(rhs, fn)(lhs, _builder=self.builder)
        else:
            return getattr(lhs, fn)(rhs)

    def visit_UnaryOp(self, node):
        op = self.visit(node.operand)
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
            ast.Invert: '__invert__',
        }[type(node.op)]
        if self.is_triton_object(op):
            return getattr(op, fn)(_builder=self.builder)
        return getattr(op, fn)()

    def visit_While(self, node):
        """Lower a ``while`` loop to a condition-checked back edge."""
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            # (re-)evaluate the condition and branch into/out of the loop
            cond = self.visit(node.test)
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        continue_fn()
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body)
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        # seal only after all predecessors have been emitted
        self.module.seal_block(stop_bb)
        self.module.seal_block(loop_bb)
        self.module.seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Str(self, node):
        return ast.literal_eval(node)

    def visit_Subscript(self, node):
        assert node.ctx.__class__.__name__ == "Load"
        lhs = self.visit(node.value)
        slices = self.visit(node.slice)
        if self.is_triton_object(lhs):
            return lhs.__getitem__(slices, _builder=self.builder)
        return lhs[slices]

    def visit_ExtSlice(self, node):
        return [self.visit(dim) for dim in node.dims]

    def visit_For(self, node):
        """Lower ``for i in range(...)`` (the only supported iterator).

        The loop is desugared into init / condition / step AST nodes and
        emitted as explicit basic blocks. The condition handles both positive
        and negative steps by selecting between ``<`` and ``>`` comparisons.
        """
        iterator = self.visit(node.iter.func)
        if iterator != self.builtins['range']:
            raise RuntimeError('Only `range` iterator currently supported')
        # create nodes
        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
        ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
        # range(stop) / range(start, stop[, step]) defaults
        arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
        arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
        arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
        init_node = ast.Assign(targets=[st_target], value=arg_0)
        pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
        neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
        pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
        # pick `<` vs `>` depending on the sign of the step, element-wise
        build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
                            self.visit(pos_cond_node),\
                            self.visit(neg_cond_node),\
                            _builder=self.builder)
        #cond_node = neg_cond_node
        step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
        # code generation
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            # advance the induction variable, then test the loop condition
            self.visit(step_node)
            cond = build_cond()
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        self.visit(init_node)
        cond = build_cond()
        self.builder.cond_br(cond.handle, loop_bb, next_bb)
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body)
        # TODO: handle case where body breaks control flow
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        self.module.seal_block(stop_bb)
        self.module.seal_block(loop_bb)
        self.module.seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Slice(self, node):
        lower = self.visit(node.lower)
        upper = self.visit(node.upper)
        step = self.visit(node.step)
        return slice(lower, upper, step)

    def visit_Index(self, node):
        return self.visit(node.value)

    def visit_NameConstant(self, node):
        return node.value

    def visit_keyword(self, node):
        # each keyword becomes a one-entry dict, merged by visit_Call
        return {node.arg: self.visit(node.value)}

    def visit_Call(self, node):
        """Lower a call: inline JITFunctions; pass ``_builder`` to methods on
        triton blocks and to ``triton.language.core`` functions."""
        fn = self.visit(node.func)
        kws = dict()
        for keyword in node.keywords:
            kws.update(self.visit(keyword))
        args = [self.visit(arg) for arg in node.args]
        if isinstance(fn, JITFunction):
            return fn(*args, generator=self, **kws)
        if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
            sys.modules[fn.__module__] is triton.language.core:
            return fn(*args, _builder=self.builder, **kws)
        return fn(*args, **kws)

    def visit_Num(self, node):
        return node.n

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        return getattr(lhs, node.attr)

    def visit_Expr(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_NoneType(self, node):
        return None

    def visit(self, node):
        # remember the last visited node so errors can be located in the source
        if node is not None:
            self.last_node = node
        return super().visit(node)

    def generic_visit(self, node):
        # any AST node without an explicit visit_* handler is unsupported
        typename = type(node).__name__
        raise NotImplementedError("Unsupported node: {}".format(typename))
|
|
|
|
|
|
|
|
|
|
|
|
class Binary:
    """Container for a compiled kernel image and its launch parameters."""

    def __init__(self, backend, name, asm, shared_mem, num_warps):
        # backend: runtime backend identifier (CUDA / ROCM enum value)
        self.backend = backend
        # name: symbol name of the compiled kernel
        self.name = name
        # asm: generated machine/assembly artifacts from compilation
        self.asm = asm
        # shared_mem: bytes of shared memory the kernel requires
        self.shared_mem = shared_mem
        # num_warps: warps per thread block the kernel was compiled for
        self.num_warps = num_warps
|
2021-09-18 22:48:26 -07:00
|
|
|
|
|
|
|
class LoadedBinary:
    """A Binary loaded onto a specific device, ready to be enqueued."""

    def __init__(self, device: int, bin: Binary):
        # Materialize the driver-level module/kernel objects on `device`.
        loaded_module, loaded_kernel = _triton.code_gen.load_binary(
            bin.backend, bin.name, bin.asm, bin.shared_mem, device)
        self.bin = bin
        self.asm = bin.asm
        self.device = device
        self.module = loaded_module
        self.kernel = loaded_kernel

    def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
        """Enqueue the kernel on `stream` with the given grid dimensions.

        Block dimensions are (num_warps * 32, 1, 1); `args` is the packed
        parameter buffer.
        """
        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
                                grid_0, grid_1, grid_2,
                                self.bin.num_warps * 32, 1, 1,
                                args, self.bin.shared_mem)
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
|
|
|
class CompilationError(Exception):
    """Raised when lowering a kernel's source to Triton-IR fails.

    The message shows the source up to and including the offending line,
    a caret under the failing column, and the underlying error.
    """

    def __init__(self, src, node, err):
        context_lines = src.split('\n')[:node.lineno]
        caret = ' ' * node.col_offset + '^'
        self.message = '\n'.join(context_lines) + '\n' + caret + '\n Error: ' + str(err)
        super().__init__(self.message)
|
|
|
|
|
2021-08-21 06:00:54 +02:00
|
|
|
|
2021-06-21 14:25:13 +08:00
|
|
|
class OutOfResources(Exception):
    """Raised when a kernel requires more of a hardware resource than is available.

    :param required: amount of the resource the kernel needs
    :param limit: amount the hardware provides
    :param name: human-readable resource name (e.g. "shared memory")
    """

    def __init__(self, required, limit, name):
        # Bug fix: the original implicitly-concatenated f-strings had no
        # separators, producing run-together text such as
        # "out of resource: shared memoryRequired: 100Hardware limit: 64".
        self.message = f'out of resource: {name}, '\
                       f'Required: {required}, '\
                       f'Hardware limit: {limit}'
        super().__init__(self.message)
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
class Kernel:
    """Callable wrapper that JIT-compiles, caches and launches a JITFunction.

    ``__call__`` infers argument types and devices, compiles (or fetches from
    the in-memory / on-disk cache) a ``Binary`` for that signature, and
    enqueues it on the current CUDA stream.
    """

    @staticmethod
    def _type_name(obj):
        """Map a Python scalar type or torch dtype to its one/two-letter
        signature code (also used as a ``struct`` format character)."""
        type_names = {
            int: 'I',
            float: 'f',
            bool: 'B',
            triton.language.float8: 'f8',
            torch.bfloat16: 'bf16',
            torch.float16: 'f16',
            torch.float32: 'f32',
            torch.float64: 'f64',
            torch.bool: 'i1',
            torch.int8: 'i8',
            torch.int16: 'i16',
            torch.int32: 'i32',
            torch.int64: 'i64',
        }
        return type_names[obj]

    @staticmethod
    def _to_triton_ir(context, obj):
        """Build the Triton-IR type for argument ``obj``.

        Tensors (anything with ``data_ptr``) become pointers to their element
        type; plain Python scalars map to the corresponding value type.
        """
        type_map = {
            'I': _triton.ir.type.get_int32,
            'f': _triton.ir.type.get_fp32,
            'B': _triton.ir.type.get_int1,
            'f8': _triton.ir.type.get_fp8,
            'f16': _triton.ir.type.get_fp16,
            'bf16': _triton.ir.type.get_bf16,
            'f32': _triton.ir.type.get_fp32,
            'f64': _triton.ir.type.get_fp64,
            'i1': _triton.ir.type.get_int1,
            'i8': _triton.ir.type.get_int8,
            'i16': _triton.ir.type.get_int16,
            'i32': _triton.ir.type.get_int32,
            'i64': _triton.ir.type.get_int64,
        }
        # convert torch.Tensor to Triton IR pointers
        if hasattr(obj, 'data_ptr'):
            name = Kernel._type_name(obj.dtype)
            elt_ty = type_map[name](context)
            return _triton.ir.type.make_ptr(elt_ty, 1)
        # default path returns triton.ir.type directly
        name = Kernel._type_name(obj.__class__)
        return type_map[name](context)

    @staticmethod
    def _types_key(*wargs, tensor_idxs):
        """Per-argument signature tuple, part of the compilation cache key.

        Tensor arguments are prefixed with 'P' (pointer) and keyed by dtype;
        scalars are keyed by their Python type.
        """
        # type inference
        types_key = [None] * len(wargs)
        for i, arg in enumerate(wargs):
            prefix = 'P' if i in tensor_idxs else ''
            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
            types_key[i] = prefix + suffix
        return tuple(types_key)

    @staticmethod
    def pow2_divisor(N):
        """Largest power of two (capped at 16) that divides ``N``; used as an
        alignment / multiple-of hint for the compiler."""
        if N % 16 == 0: return 16
        if N % 8 == 0: return 8
        if N % 4 == 0: return 4
        if N % 2 == 0: return 2
        return 1

    def __init__(self, fn):
        # fn: the owning JITFunction (provides src, caches and cache paths)
        self.fn = fn

    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
        """Compile the wrapped function for this argument signature.

        Returns a ``Binary``; raises ``CompilationError`` on lowering errors
        and ``OutOfResources`` if the kernel needs too much shared memory.
        """
        # create IR module
        context = _triton.ir.context()
        # get just-in-time proto-type of kernel
        arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
        ret_type = _triton.ir.type.get_void(context)
        prototype = _triton.ir.type.make_function(ret_type, arg_types)
        # generate Triton-IR
        # export symbols visible from self.fn into code-generator object
        gscope = sys.modules[self.fn.module].__dict__
        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
        try:
            generator.visit(self.fn.parse())
        except Exception as e:
            node = generator.last_node
            # re-raise as-is when we cannot point at a source location, or when
            # the error is already user-facing
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.fn.src, node, e)
        # Compile to machine code
        if torch.version.hip is None:
            backend = _triton.runtime.backend.CUDA
        else:
            backend = _triton.runtime.backend.ROCM
        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
        if shared_mem > max_shared_memory:
            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
        return Binary(backend, name, asm, shared_mem, num_warps)

    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
        """Launch the kernel on ``grid`` with arguments ``wargs``.

        ``grid`` is either a tuple or a callable mapping ``meta`` -> tuple.
        Compiled binaries are cached in memory (``fn.drv_cache``) and on disk
        (a shelve db guarded by a file lock), keyed by the full signature.
        """
        # device inference
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        if len(tensor_idxs) == 0:
            raise ValueError("No Tensor argument found.")
        invalid_args = []
        device_ids = []
        for idx in tensor_idxs:
            curr = wargs[idx]
            if not curr.is_cuda:
                invalid_args.append(idx)
            else:
                device_ids.append(curr.device.index)
        if invalid_args:
            raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
                             " Only CUDA is supported at the moment")

        device = torch.device('cuda', torch.cuda.current_device())
        device_ty = device.type
        device_idx = device.index
        if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
            # try to enable P2P communication
            # Bug fix: the original read `self.backend`, which Kernel never
            # defines (__init__ only sets self.fn), so this path always raised
            # AttributeError. Derive the backend the same way _compile does.
            backend = _triton.runtime.backend.CUDA if torch.version.hip is None else _triton.runtime.backend.ROCM
            for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
                if dst_idx != device_idx:
                    try:
                        _triton.runtime.enable_peer_access(backend, wargs[arg_idx].data_ptr())
                    except RuntimeError as e:
                        raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
                                           .format(device_idx, dst_idx, str(e)))

        # enqueue kernel on the current device
        torch.cuda.set_device(device_idx)
        # attributes: power-of-two divisibility hints for pointer/int args
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
        # compute hash for caching this kernel
        types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
        attr_key = frozenset(attributes.items())
        meta_key = frozenset(meta.items())
        const_key = frozenset(constants.items())
        key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
        key = repr(key)
        # get cached binary
        drv_cache = self.fn.drv_cache
        bin_mut_path = self.fn.bin_mut_path
        bin_cache_path = self.fn.bin_cache_path
        bin_lock_path = self.fn.bin_lock_path
        if key not in drv_cache:
            binary = None
            if bin_lock_path:
                with FileLock(bin_lock_path):
                    dbtype = dbm.whichdb(bin_cache_path)
                    # handle stale/corrupted cache if it exists
                    if dbtype is not None:
                        # some db types can create multiple files
                        exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
                                'dbm.dumb': ['.dir', '.dat']}[dbtype]
                        db_paths = [bin_cache_path + ext for ext in exts]
                        # check if the cache is stale: invalidate whenever the
                        # frontend or the compiled backend is newer than the db
                        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
                        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
                        cache_mtime = max([os.path.getmtime(db) for db in db_paths])
                        is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
                        # check if the cache is corrupted: the marker file only
                        # exists while a write is in flight
                        is_corrupted = os.path.exists(bin_mut_path)
                        # delete the cache if stale or corrupted
                        if is_stale or is_corrupted:
                            for db in db_paths:
                                os.remove(db)
                            if is_corrupted:
                                os.remove(bin_mut_path)
                    # read the cache, creating if needed
                    with shelve.open(bin_cache_path) as db:
                        binary = db.get(key, None)
            if binary is None:
                binary = self._compile(
                    *wargs, device=device_idx, attributes=attributes,
                    num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                    constants=constants, **meta
                )
                if bin_lock_path:
                    with FileLock(bin_lock_path):
                        # drop a marker so an interrupted write is detectable
                        open(bin_mut_path, 'a').close()
                        with shelve.open(bin_cache_path) as db:
                            db[key] = binary
                        os.remove(bin_mut_path)

            drv_cache[key] = LoadedBinary(device_idx, binary)
        # pack arguments
        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
        params = struct.pack(fmt, *args)
        # enqueue cached function into stream
        callable = drv_cache[key]
        stream = torch.cuda.current_stream(device_idx).cuda_stream
        grid = grid(meta) if hasattr(grid, '__call__') else grid
        callable(stream, params, *grid)
        return callable
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
|
|
|
class Launcher:
    """Binds a kernel to a fixed grid so it can be invoked like a function."""

    def __init__(self, kernel, grid):
        self.kernel = kernel
        self.grid = grid

    def __call__(self, *wargs, **kwargs):
        # Forward everything to the kernel, supplying the pre-bound grid.
        return self.kernel(*wargs, **kwargs, grid=self.grid)
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
|
|
|
class Autotuner:
    """Benchmarks a kernel over candidate configs and caches the fastest one
    per value of the tuning key."""

    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
        # fall back to a single default config when none are provided
        self.configs = configs if configs else [Config(dict(), num_warps=4, num_stages=2)]
        self.key_idx = [arg_names.index(name) for name in key]
        self.cache = dict()
        self.kernel = kernel
        # hook to reset all required tensors to zero before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(name) for name in reset_to_zero]

            def _zero_hook(args):
                for idx in self.reset_idx:
                    args[idx].zero_()

            self.hook = _zero_hook

    def _bench(self, *args, config, **meta):
        """Time a single config; raises ValueError if the config redefines a
        caller-supplied meta-parameter."""
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        overlap = meta.keys() & config.meta.keys()
        if overlap:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(overlap)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.meta)

        def kernel_call():
            self.hook(args)
            self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

        return triton.testing.do_bench(kernel_call)

    def __call__(self, *args, **meta):
        if len(self.configs) <= 1:
            best = self.configs[0]
        else:
            key = tuple([args[i] for i in self.key_idx])
            if key not in self.cache:
                # benchmark every candidate and memoize the fastest one
                timings = {cfg: self._bench(*args, config=cfg, **meta)
                           for cfg in self.configs}
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
            best = self.cache[key]
        return self.kernel(*args, num_warps=best.num_warps, num_stages=best.num_stages, **meta, **best.meta)
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
2021-09-21 14:10:02 -07:00
|
|
|
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
class JITFunction:

    def _init_cache_paths(self):
        """(Re)derive on-disk binary-cache paths from the current `self.src`.

        Called from `__init__` and again from `__setattr__` whenever `.src`
        is re-assigned, since the cache key is a hash of the source code.
        """
        # fetch cache directory path; setting TRITON_CACHE_DIR to an empty
        # string disables on-disk caching entirely
        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
        if not cache_dir:
            self.bin_cache_path = None
            self.bin_lock_path = None
            # fix: the original left `bin_mut_path` unset on this branch,
            # which would raise AttributeError on first access
            self.bin_mut_path = None
            return
        # create cache directory; exist_ok=True makes this race-free, so the
        # previous `os.path.exists` pre-check was redundant (TOCTOU)
        os.makedirs(cache_dir, exist_ok=True)
        # create md5 hash of src
        md5 = hashlib.md5()
        md5.update(self.src.encode('utf-8'))
        md5_hash = md5.hexdigest()
        # per-kernel paths: binary cache, its file-lock, and the marker file
        # used while the cache is being mutated
        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
        self.bin_lock_path = self.bin_cache_path + '.lock'
        self.bin_mut_path = self.bin_cache_path + '.mutating'

    def __init__(self, fn):
        # information of wrapped function
        self.fn = fn
        self.module = fn.__module__
        self.arg_names = inspect.getfullargspec(fn).args
        self.src = textwrap.dedent(inspect.getsource(fn))
        # cache for callable driver objects (e.g. CUkernel)
        self.drv_cache = dict()
        # on-disk paths for the binary cache and corresponding file-lock;
        # note: __setattr__ already ran _init_cache_paths when `src` was
        # assigned above, but calling it explicitly keeps __init__ robust
        # if that hook ever changes
        self._init_cache_paths()
        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel_decorators = []
        self.kernel = None
        # forward docs
        self.__doc__ = fn.__doc__

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Some unit tests do this, for example.
    def parse(self):
        """Parse `self.src`, asserting it is a module holding a single function def."""
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, generator: CodeGenerator, **meta):
        """Inline this function into the IR currently being built by `generator`.

        :raises CompilationError: when code generation fails on a known AST node.
        """
        try:
            # snapshot the generator's scopes/values so the inlined call
            # cannot leak bindings into the caller's scope
            gscope = generator.gscope.copy()
            lscope = generator.lscope.copy()
            values = generator.module.get_values().copy()
            # resolve globals against the *callee's* defining module
            generator.gscope = sys.modules[self.fn.__module__].__dict__
            ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
            # restore caller scopes
            generator.gscope = gscope
            generator.lscope = lscope
            generator.module.set_values(values)
            return ret
        except Exception as e:
            node = generator.last_node
            # re-raise as-is when no source location is known, or when the
            # error is already a compilation-level one
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.src, node, e)

    # - when `.src` attribute is set, cache path needs
    #   to be reinitialized
    # - when kernel decorators change, cached kernel
    #   needs to be cleared
    def __setattr__(self, name, value):
        if name == 'kernel_decorators':
            self.kernel = None
        super(JITFunction, self).__setattr__(name, value)
        if name == 'src':
            self._init_cache_paths()

    def _init_kernel(self):
        """Build (or return the cached) Kernel, applying registered decorators."""
        if self.kernel is None:
            self.kernel = Kernel(self)
            # decorators were registered outermost-first; apply innermost-first
            for decorator in reversed(self.kernel_decorators):
                self.kernel = decorator(self.kernel)
        return self.kernel

    def __getitem__(self, grid):
        """`fn[grid]` returns a launcher bound to the given launch grid."""
        return Launcher(self._init_kernel(), grid)
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type meta: dict[str, Any]
    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
                      `num_warps=8`, then each kernel instance will be automatically parallelized to
                      cooperatively execute using `8 * 32 = 256` threads.
    :type num_warps: int
    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
    :type num_stages: int
    """
    def __init__(self, meta, num_warps=4, num_stages=2):
        self.meta = meta
        self.num_warps = num_warps
        self.num_stages = num_stages
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
2021-08-14 10:58:38 -07:00
|
|
|
def autotune(configs, key, reset_to_zero=None):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.

    .. highlight:: python
    .. code-block:: python

        @triton.autotune(configs=[
            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size'] # the two above configs will be evaluated anytime
                         # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']

    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.

    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """
    def decorator(fn):
        # the Autotuner itself is constructed lazily, when the JITFunction
        # materializes its kernel; here we only register the wrapper
        def wrapper(kernel):
            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)

        fn.kernel_decorators.append(wrapper)
        return fn

    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
def heuristics(values):
    """
    Decorator for specifying how the values of certain meta-parameters may be computed.
    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

    .. highlight:: python
    .. code-block:: python

        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size

    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                   each such function takes a list of positional arguments as input.
    :type values: dict[str, Callable[[list[Any]], Any]]
    """
    def decorator(fn):
        def wrapper(kernel):
            def fun(*args, **meta):
                # compute each heuristic meta-parameter; a caller-provided
                # value would be silently clobbered, hence the assert
                for v, heur in values.items():
                    assert v not in meta
                    meta[v] = heur(*args, **meta)
                return kernel(*args, **meta)
            return fun

        fn.kernel_decorators.append(wrapper)
        return fn

    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
def jit(fn):
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * objects within the triton.language package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """
    # wrap the plain Python function in Triton's JIT handle
    jitted = JITFunction(fn)
    return jitted
|
2021-04-23 17:18:14 -04:00
|
|
|
|
|
|
|
|
2021-08-18 11:15:53 -07:00
|
|
|
######
|
|
|
|
|
2021-04-23 17:18:14 -04:00
|
|
|
def cdiv(x, y):
    """Integer division of x by y, rounded toward +infinity for positive y
    (i.e. ceil(x / y) when y > 0)."""
    rounded_up = x + y - 1
    return rounded_up // y
|
2021-05-01 14:34:33 -04:00
|
|
|
|
2021-09-10 11:05:44 -07:00
|
|
|
def next_power_of_2(n):
    """Return the smallest power of 2 greater than or equal to n.

    Uses `int.bit_length`, so it is correct for arbitrarily large integers;
    the previous bit-twiddling version only propagated 32 bits and silently
    returned wrong results for n > 2**32. For n <= 0 this returns 0, which
    matches the original's behavior at n == 0 and for negative inputs.
    """
    if n <= 0:
        return 0
    return 1 << (n - 1).bit_length()
|
2021-05-01 14:34:33 -04:00
|
|
|
|
|
|
|
######
|
|
|
|
|
|
|
|
class TensorWrapper:
    """Wrap a tensor-like object so it presents a different element `dtype`.

    Mirrors the attributes the kernel launcher inspects (`is_cuda`,
    `device`) and delegates `data_ptr` to the wrapped object, so the
    wrapper shares the underlying storage.
    """

    def __init__(self, base, dtype):
        self.base = base
        self.dtype = dtype
        # mirror launch-relevant attributes of the wrapped tensor
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        # same storage as the wrapped tensor — no copy is made
        return self.base.data_ptr()
|
2021-05-01 14:34:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
def reinterpret(tensor, dtype):
    """View `tensor` as if its elements had type `dtype`; no data is copied."""
    wrapped = TensorWrapper(tensor, dtype)
    return wrapped
|