Files
triton/python/triton/compiler.py
Rohit Santhanam 8cc448d92e Changes to eliminate the need for the MI_GPU_ARCH environment variable.
The AMDGPU arch is now parsed out of the rocminfo dump.
2022-11-18 18:51:57 +00:00

1430 lines
55 KiB
Python

from __future__ import annotations
import ast
import contextlib
import functools
import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
import sysconfig
import tempfile
import warnings
from sysconfig import get_paths
from typing import Any, Dict, Set, Tuple, Union
import setuptools
import torch
from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def static_vars(**kwargs):
    """Return a decorator that sets every keyword argument as an attribute on the decorated function.

    The attributes are assigned once, at decoration time, and the function itself is returned unchanged.
    """
    def attach(target):
        for attr_name, attr_value in kwargs.items():
            setattr(target, attr_name, attr_value)
        return target
    return attach
def str_to_ty(name):
    """Convert a Triton signature type string into a ``triton.language`` type.

    Each leading ``*`` adds one level of pointer indirection (``"*fp32"`` is a pointer to
    ``float32``). ``"B"`` is an alias for ``int1``. An unknown scalar name raises ``KeyError``.
    """
    if name[0] != "*":
        tl = triton.language
        scalar_types = {
            "i1": tl.int1,
            "fp8": tl.float8,
            "fp16": tl.float16,
            "bf16": tl.bfloat16,
            "fp32": tl.float32,
            "fp64": tl.float64,
            "i8": tl.int8,
            "i16": tl.int16,
            "i32": tl.int32,
            "i64": tl.int64,
            "u8": tl.uint8,
            "u16": tl.uint16,
            "u32": tl.uint32,
            "u64": tl.uint64,
            "B": tl.int1,
        }
        return scalar_types[name]
    pointee = str_to_ty(name[1:])
    return triton.language.pointer_type(pointee)
def mangle_ty(ty):
    """Return the short mangled spelling of a Triton type, as used in function symbol names.

    Pointers become ``P<pointee>``, integers ``i<bitwidth>``, floats their dtype name
    (``fp8``, ``fp16``, ``bf16``, ``fp32``, ``fp64``), void ``V``, and blocks
    ``<elt>S<d0_d1_...>S``.

    Raises:
        TypeError: if ``ty`` is none of the supported kinds. A real exception is raised instead of
            ``assert False``, because asserts are stripped under ``python -O`` and the function
            would then silently return ``None``.
    """
    if ty.is_ptr():
        return 'P' + mangle_ty(ty.element_ty)
    if ty.is_int():
        return 'i' + str(ty.int_bitwidth)
    if ty.is_fp8():
        return 'fp8'
    if ty.is_fp16():
        return 'fp16'
    if ty.is_bf16():
        return 'bf16'
    if ty.is_fp32():
        return 'fp32'
    if ty.is_fp64():
        return 'fp64'
    if ty.is_void():
        return 'V'
    if ty.is_block():
        elt = mangle_ty(ty.scalar)
        shape = '_'.join(map(str, ty.shape))
        return f'{elt}S{shape}S'
    raise TypeError(f"Unsupported type: {ty!r}")
def mangle_fn(name, arg_tys, constants):
    """Build a unique symbol name for ``name`` specialised on argument types and constants.

    Layout: ``<name>__<mangled arg types joined by '_'>__<constants>``. Each constant is
    encoded as ``<index>c<value>``, where JIT functions contribute their ``__name__`` and
    everything else its ``repr``. Characters that are unsafe in symbols are then rewritten:
    ``.`` -> ``_d_``, ``'`` -> ``_sq_``, ``e-`` -> ``_em_`` (applied in that order).
    The return type is not encoded; it must be a function of the argument types.
    """
    def constant_key(value):
        if isinstance(value, triton.runtime.JITFunction):
            return value.__name__
        return repr(value)

    arg_part = '_'.join([mangle_ty(arg_ty) for arg_ty in arg_tys])
    const_part = '_'.join([f'{idx}c{constant_key(constants[idx])}' for idx in sorted(constants)])
    for unsafe, replacement in (('.', '_d_'), ("'", '_sq_'), ("e-", '_em_')):
        const_part = const_part.replace(unsafe, replacement)
    return f'{name}__{arg_part}__{const_part}'
def is_triton_tensor(value):
    """Return True when ``value`` is a ``triton.language.tensor`` (i.e. an IR-backed value)."""
    tensor_cls = triton.language.tensor
    return isinstance(value, tensor_cls)
class ValueConstructor:
    """Track the values bound to names while lowering a function, and build SSA form for them.

    ``get_value`` looks a name up in three scopes, in order: local (``lscope``), global
    (``gscope``), then the small ``builtins`` table. Names bound to ``triton.language.tensor``
    values are also recorded per basic block in ``lvalues``. When such a name is read from a
    block that did not define it, the value is resolved through the predecessors and phi
    nodes are created where needed. The bookkeeping (sealed blocks, incomplete phis,
    trivial-phi removal) appears to follow the on-the-fly SSA construction of Braun et al.
    """
    def __init__(self, module, builder, gscope) -> None:
        self.gscope = gscope
        self.lscope = dict()
        self.builder = builder
        self.module = module
        # [name, bb] => triton.language.tensor
        self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
        # bb => {name => phi}; phis created in a block whose predecessors are not all known yet
        self.incomplete_phis = {}
        # blocks whose predecessor list is final; only these get their phi operands filled in
        self.sealed_blocks: Set[_triton.ir.basic_block] = set()
        # Python builtins usable inside a kernel (note: `min` maps to triton.language.minimum)
        self.builtins = {
            'range': range,
            'min': triton.language.minimum,
            'float': float,
            'int': int,
            'print': print,
            'isinstance': isinstance,
            'getattr': getattr,
        }

    def get_value(self, name):
        ''' Look up `name` in the local scope, then the global scope, then builtins.

        Raises ValueError if `name` is found in none of them. If the bound value is a
        triton.language.tensor, the SSA value visible in the builder's current insert
        block is returned (via `self._get_tensor()`) rather than the raw binding.
        '''
        # search node.id in local scope
        ret = None
        if name in self.lscope:
            ret = self.lscope[name]
        # search node.id in global scope
        elif name in self.gscope:
            ret = self.gscope[name]
        # search node.id in builtins
        elif name in self.builtins:
            ret = self.builtins[name]
        else:
            raise ValueError(f'{name} is not defined')
        if is_triton_tensor(ret):
            return self._get_tensor(name, self.builder.get_insert_block())
        return ret

    def set_value(self, name: str,
                  value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
        ''' Bind `name` to `value` (called by visit_Assign() and visit_FunctionDef()).

        1. record the binding in the local scope (FIXME: should consider control flow)
        2. if `value` is a tensor, also record it as the definition of `name` in the
           builder's current insert block (see `_set_value`)
        '''
        self.lscope[name] = value
        if isinstance(value, triton.language.tensor):
            self._set_value(name, self.builder.get_insert_block(), value)

    #
    # SSA-construction
    #

    def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
        # Return the tensor that `name` holds on entry to / within `bb`.
        # local value numbering: `name` was defined (or already resolved) in `bb`
        if (name, bb) in self.lvalues:
            return self.lvalues[(name, bb)]
        # global value numbering: search predecessors. That search may create phis and move
        # the builder, so the insert point is saved and restored around it.
        saved_insert_point = self.builder.get_insert_point()
        result = self._get_tensor_recursive(name, bb)
        self.builder.set_insert_point(saved_insert_point)
        return result

    def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
        preds = bb.get_predecessors()
        # the phi (if any) takes the type of the name's current local binding
        type = self.lscope[name].type
        # some preds haven't been filled, create a phi as a proxy of the value
        if bb not in self.sealed_blocks:
            result = self._make_phi(type, len(preds), bb)
            if bb in self.incomplete_phis:
                self.incomplete_phis[bb][name] = result
            else:
                self.incomplete_phis[bb] = {name: result}
        elif len(preds) == 1:
            # one predecessor: no phi needed, try get value from pred
            result = self._get_tensor(name, preds[0])
        elif len(preds) == 0:
            # NOTE(review): no predecessors -> look the name up under bb=None; presumably this
            # finds values bound before any block existed (e.g. default arguments). TODO confirm.
            result = self._get_tensor(name, None)
        else:  # multiple preds
            # record the phi before filling its operands so loops that reach back here terminate
            phi = self._make_phi(type, len(preds), bb)
            self._set_value(name, bb, phi)
            result = self._add_phi_operands(name, phi)
        # cache the resolved value as the definition of `name` in `bb`
        self._set_value(name, bb, result)
        return result

    # returns a new phi tensor, which encapsulates an ir.phi_node
    def _make_phi(self,
                  type: triton.language.dtype,
                  num_values: int,
                  bb: _triton.ir.basic_block) -> triton.language.tensor:
        # the phi is inserted before the first non-phi instruction of `bb`
        instr = bb.get_first_non_phi()
        self.builder.set_insert_point((bb, instr))
        ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
        if instr:
            # NOTE(review): presumably moves the insertion point back to the end of `bb`
            self.builder.set_insert_block(bb)
        return triton.language.tensor(ir_phi, type)

    # complete a phi node. (TODO: rename this as _complete_phis?)
    # Note: since we try to remove trivial phis, the returned tensor might not be a phi
    def _add_phi_operands(self, name: str,
                          phi: triton.language.tensor) -> triton.language.tensor:
        bb = phi.handle.get_parent()
        # one incoming value per predecessor: the value `name` holds at the end of that pred
        for pred in bb.get_predecessors():
            v = self._get_tensor(name, pred)
            phi.handle.add_incoming(v.handle, pred)
        phi = self._try_remove_trivial_phi(phi)
        return phi

    def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
        # record `value` as the definition of `name` in block `bb`
        self.lvalues[(name, bb)] = value
        # TODO: why we need this?
        # NOTE(review): attaches `name` as metadata on the IR value; presumably for readable IR dumps
        self.module.set_instr_metadata(name, value.handle)

    def _seal_block(self, bb: _triton.ir.basic_block):
        # Mark `bb`'s predecessor list as final. Phis created while it was unsealed now get
        # their operands.
        # complete all incomplete phis
        if bb in self.incomplete_phis:
            for name, phi in self.incomplete_phis[bb].items():
                result = self._add_phi_operands(name, phi)
                # it's possible that this phi is trivial; if `name` still resolves to the phi in
                # `bb`, rebind it to the simplified value
                if self._get_tensor(name, bb).handle == phi.handle:
                    self._set_value(name, bb, result)
            del self.incomplete_phis[bb]
        self.sealed_blocks.add(bb)

    def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
        # A phi is trivial when its operands, ignoring references to itself, are a single value.
        unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
        if len(unique_handles) != 1:  # non-trivial phi
            return phi
        v = unique_handles.pop()
        # redirect every use of the phi to the single value; the phi instruction itself stays in place
        phi.handle.replace_all_uses_with(v)
        # phi.handle.erase_from_parent()
        # TODO: remove trivial phis recursively
        return triton.language.tensor(v, phi.type)
class CodeGenerator(ast.NodeVisitor):
    """Lower the Python AST of one ``@triton.jit`` function to Triton IR.

    Each ``visit_<Node>`` method handles one Python construct. Unsupported constructs reach
    ``generic_visit`` and raise ``NotImplementedError``. Name bindings and SSA construction
    are delegated to a ``ValueConstructor``. A call to another JIT function is compiled by a
    child ``CodeGenerator`` that shares this generator's ``module`` and ``prototypes``.

    Values flowing through the visitor are either ``triton.language.tensor`` (values that
    live in the IR) or ``triton.language.constexpr`` (compile-time Python values, folded
    eagerly where both operands are constexpr).
    """

    def __init__(self, context, prototype, gscope, attributes, constants, function_name, spec_to_1=None, prototypes=None, module=None, is_kernel=False):
        self.spec_to_1 = set() if spec_to_1 is None else spec_to_1
        self.prototypes = dict() if prototypes is None else prototypes
        self.builder = _triton.ir.builder(context)
        self.module = _triton.ir.module('', self.builder) if module is None else module
        self.prototype = prototype
        self.attributes = attributes
        self.constants = constants
        # most recently visited node; used to point CompilationError at the failing source location
        self.last_node = None
        # result of the most recently visited statement; read after a `return` to fix the return type
        self.last_ret = None
        self.function_name = function_name
        self.is_kernel = is_kernel
        self.value_constructor = ValueConstructor(self.module, self.builder, gscope)

    #
    # AST visitor
    #

    def visit_compound_statement(self, stmts):
        """Visit ``stmts`` in order, stopping after a ``return``.

        Returns a truthy value iff the last visited statement is a ``return``.
        """
        for stmt in stmts:
            self.last_ret = self.visit(stmt)
            if isinstance(stmt, ast.Return):
                break
        return stmts and isinstance(stmt, ast.Return)

    def visit_Module(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_List(self, node):
        ctx = self.visit(node.ctx)
        assert ctx is None
        elts = [self.visit(elt) for elt in node.elts]
        return elts

    # By design, only non-kernel functions can return
    def visit_Return(self, node):
        ret = self.visit(node.value)
        if ret is None:
            return triton.language.tensor(self.builder.ret_void(), triton.language.void)
        ret = triton.language.core._to_tensor(ret, self.builder)
        ret = triton.language.tensor(self.builder.ret(ret.handle), ret.type)
        return ret

    def visit_FunctionDef(self, node):
        arg_names, arg_annotations, kwarg_names = self.visit(node.args)
        # initialize defaults
        # `defaults` lines up with the *last* len(defaults) positional arguments, so pair them
        # from the end: args[-1] <-> defaults[-1], args[-2] <-> defaults[-2], ...
        for arg_node, default_value in zip(reversed(node.args.args), reversed(node.args.defaults)):
            annotation = arg_node.annotation
            name = arg_node.arg
            st_target = ast.Name(id=name, ctx=ast.Store())
            if annotation is None:
                init_node = ast.Assign(targets=[st_target], value=default_value)
            else:
                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
            self.visit(init_node)
        # initialize function
        self.prototypes[self.function_name] = self.prototype
        fn = self.module.get_or_insert_function(self.function_name, self.prototype.to_ir(self.builder))
        fn.set_is_kernel(self.is_kernel)
        arg_values = []
        # `idx` counts only non-constant arguments: constants are not part of the IR signature
        idx = 0
        for i, (arg_name, annotation) in enumerate(zip(arg_names, arg_annotations)):
            if i in self.constants:
                cst = self.constants[i]
                if not isinstance(cst, triton.language.constexpr):
                    cst = triton.language.constexpr(self.constants[i])
                arg_values.append(cst)
                continue
            if i in self.attributes:
                is_ptr = fn.args[idx].type.is_ptr()
                attr = 'aligned' if is_ptr else 'multiple_of'
                attr = getattr(_triton.ir.attribute_kind, attr)
                attr = _triton.ir.attribute(attr, self.attributes[i][1])
                fn.add_attr(idx + 1, attr)
            fn.args[idx].name = arg_name
            arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
            idx += 1

        insert_pt = self.builder.get_insert_block()
        entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
        self.builder.set_insert_block(entry)
        self.value_constructor._seal_block(entry)
        for arg_name, arg_value in zip(arg_names, arg_values):
            self.value_constructor.set_value(arg_name, arg_value)
        # visit function body
        has_ret = self.visit_compound_statement(node.body)
        # finalize
        if not has_ret:
            self.builder.ret_void()
        else:
            # a bit hacky: we only know the return type at the last moment so we update type info here
            self.module.reset_ret_ty(self.function_name, self.last_ret.type.to_ir(self.builder))
            self.prototype.ret_type = self.last_ret.type
        self.builder.set_insert_block(insert_pt)

    def visit_arguments(self, node):
        arg_names = []
        arg_annotations = []
        for arg in node.args:
            curr = self.visit(arg)
            arg_names += [curr[0]]
            arg_annotations += [curr[1]]
        kwarg_names = self.visit(node.kwarg)
        return arg_names, arg_annotations, kwarg_names

    def visit_arg(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        return node.arg, node.annotation

    def visit_AnnAssign(self, node):
        # extract attributes
        annotation = self.visit(node.annotation)
        target = self.visit(node.target)
        value = self.visit(node.value)
        # constexpr
        if annotation == triton.language.constexpr:
            if target in self.value_constructor.lscope:
                raise ValueError(f'{target} is already defined.'
                                 f' constexpr cannot be reassigned.')
            if not isinstance(value, triton.language.constexpr):
                value = triton.language.constexpr(value)
            self.value_constructor.lscope[target] = value
            return self.value_constructor.lscope[target]
        # default: a regular assignment. Bind the already-evaluated value directly. An AnnAssign
        # node has no `.targets`, so it cannot be handed to visit_Assign, and re-visiting
        # node.value would emit its IR twice.
        return self._assign(target, value)

    def visit_Assign(self, node):
        _names = []
        for target in node.targets:
            _names += [self.visit(target)]
        assert len(_names) == 1
        self._assign(_names[0], self.visit(node.value))

    def _assign(self, names, values):
        """Bind ``values`` (a value or tuple of values) to ``names`` (a name or tuple of names).

        A tensor of tuple type on the right-hand side is unpacked element-wise. Constexpr values
        are unwrapped and materialised as tensors before binding. Binding ``None`` raises ValueError.
        """
        if not isinstance(names, tuple):
            names = [names]
        if not isinstance(values, tuple):
            values = [values]
        if isinstance(values[0], triton.language.tensor) \
                and isinstance(values[0].type, triton.language.tuple_type):
            struct = values[0].handle
            tys = values[0].type.element_types
            values = [self.builder.extract_value(struct, i) for i in range(len(tys))]
            values = [triton.language.tensor(v, ty) for v, ty in zip(values, tys)]
        assert len(values) == len(names)
        for name, value in zip(names, values):
            # TODO: can we store constexpr here to support constant folding?
            # by default, constexpr are assigned into python variable
            if isinstance(value, triton.language.constexpr):
                value = value.value
            if value is None:
                raise ValueError(f'Cannot assign None to non-constexpr `{name}`. Please annotate as `: tl.constexpr`')
            if not isinstance(value, triton.language.tensor):
                value = triton.language.core._to_tensor(value, self.builder)
            self.value_constructor.set_value(name, value)

    def visit_AugAssign(self, node):
        # `x op= v` is lowered as `x = x op v`
        name = node.target.id
        lhs = ast.Name(id=name, ctx=ast.Load())
        rhs = ast.BinOp(lhs, node.op, node.value)
        assign = ast.Assign(targets=[node.target], value=rhs)
        self.visit(assign)
        return self.value_constructor.get_value(name)

    def visit_Name(self, node):
        # store context: return the identifier itself so the caller can bind it
        if type(node.ctx) == ast.Store:
            return node.id
        return self.value_constructor.get_value(node.id)

    def visit_Store(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Load(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Tuple(self, node):
        args = [self.visit(x) for x in node.elts]
        # tuple of (at least two) tensors -- create a struct
        if len(args) > 1 and all(type(arg) == triton.language.tensor for arg in args):
            tuple_ty = triton.language.tuple_type([arg.type for arg in args])
            ret = _triton.ir.undef.get(tuple_ty.to_ir(self.builder))
            for i, arg in enumerate(args):
                ret = self.builder.insert_value(ret, arg.handle, i)
            ret = triton.language.tensor(ret, tuple_ty)
            return ret
        # anything else (including the empty tuple) stays a Python tuple
        return tuple(args)

    def visit_BinOp(self, node):
        # visit operand
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
        lhs = lhs.value if is_lhs_constexpr else lhs
        rhs = rhs.value if is_rhs_constexpr else rhs
        # get function name
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
            ast.Mult: '__mul__',
            ast.Div: '__truediv__',
            ast.FloorDiv: '__floordiv__',
            ast.Mod: '__mod__',
            ast.Pow: '__pow__',
            ast.LShift: '__lshift__',
            ast.RShift: '__rshift__',
            ast.BitAnd: '__and__',
            ast.BitOr: '__or__',
            ast.BitXor: '__xor__',
        }[type(node.op)]
        # return a new constexpr if both arg are constexprs
        if is_lhs_constexpr and is_rhs_constexpr:
            return triton.language.constexpr(getattr(lhs, fn)(rhs))
        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
            # use the reflected operator (e.g. __radd__) on the tensor operand
            fn = fn[:2] + 'r' + fn[2:]
            return getattr(rhs, fn)(lhs, _builder=self.builder)
        else:
            return getattr(lhs, fn)(rhs)

    def visit_If(self, node):
        cond = self.visit(node.test)
        if isinstance(cond, triton.language.tensor):
            # runtime condition: emit then/else/endif basic blocks
            cond = cond.to(triton.language.int1, _builder=self.builder)
            current_bb = self.builder.get_insert_block()
            then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
            else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
            endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
            self.value_constructor._seal_block(then_bb)
            if else_bb:
                self.value_constructor._seal_block(else_bb)
                self.builder.cond_br(cond.handle, then_bb, else_bb)
            else:
                self.builder.cond_br(cond.handle, then_bb, endif_bb)
            self.builder.set_insert_block(then_bb)
            is_terminator = self.visit_compound_statement(node.body)
            # TODO: last statement is a terminator?
            if not is_terminator:
                self.builder.br(endif_bb)
            if else_bb:
                self.builder.set_insert_block(else_bb)
                is_terminator = self.visit_compound_statement(node.orelse)
                # TODO: last statement is a terminator?
                if not is_terminator:
                    self.builder.br(endif_bb)
            self.value_constructor._seal_block(endif_bb)
            self.builder.set_insert_block(endif_bb)
        else:
            # compile-time condition: only the taken branch is lowered
            if isinstance(cond, triton.language.constexpr):
                cond = cond.value
            if cond:
                self.visit_compound_statement(node.body)
            else:
                self.visit_compound_statement(node.orelse)

    def visit_IfExp(self, node):
        # only compile-time (constexpr) conditions are supported
        cond = self.visit(node.test)
        if cond.value:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)

    def visit_Pass(self, node):
        pass

    def visit_Compare(self, node):
        assert len(node.comparators) == 1
        assert len(node.ops) == 1
        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
        lhs = lhs.value if is_lhs_constexpr else lhs
        rhs = rhs.value if is_rhs_constexpr else rhs
        # handle `is` and `is not`
        if type(node.ops[0]) == ast.Is:
            return triton.language.constexpr(lhs is rhs)
        if type(node.ops[0]) == ast.IsNot:
            return triton.language.constexpr(lhs is not rhs)
        # function name
        fn = {
            ast.Eq: '__eq__',
            ast.NotEq: '__ne__',
            ast.Lt: '__lt__',
            ast.LtE: '__le__',
            ast.Gt: '__gt__',
            ast.GtE: '__ge__',
        }[type(node.ops[0])]
        # return a new constexpr if both arg are constexprs
        if is_lhs_constexpr and is_rhs_constexpr:
            return triton.language.constexpr(getattr(lhs, fn)(rhs))
        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
            fn = fn[:2] + 'r' + fn[2:]
            return getattr(rhs, fn)(lhs, _builder=self.builder)
        else:
            assert False

    def visit_UnaryOp(self, node):
        op = self.visit(node.operand)
        if type(node.op) == ast.Not:
            assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
            return triton.language.constexpr(not op)
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
            ast.Invert: '__invert__',
        }[type(node.op)]
        if isinstance(op, triton.language.constexpr):
            return triton.language.constexpr(getattr(op.value, fn)())
        assert is_triton_tensor(op)
        return getattr(op, fn)(_builder=self.builder)

    def visit_While(self, node):
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            # re-evaluate the loop test and branch back into the loop or out of it
            cond = self.visit(node.test)
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        continue_fn()
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body)
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        self.value_constructor._seal_block(stop_bb)
        self.value_constructor._seal_block(loop_bb)
        self.value_constructor._seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Subscript(self, node):
        assert node.ctx.__class__.__name__ == "Load"
        lhs = self.visit(node.value)
        slices = self.visit(node.slice)
        if is_triton_tensor(lhs):
            return lhs.__getitem__(slices, _builder=self.builder)
        return lhs[slices]

    def visit_ExtSlice(self, node):
        return [self.visit(dim) for dim in node.dims]

    def visit_For(self, node):
        iterator = self.visit(node.iter.func)
        if iterator != self.value_constructor.builtins['range']:
            raise RuntimeError('Only `range` iterator currently supported')
        # static for loops: all iterator arguments are constexpr
        iter_args = [self.visit(arg) for arg in node.iter.args]
        is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
        if is_static:
            iter_args = [arg.value for arg in iter_args]
            static_range = iterator(*iter_args)
            # short constant-trip-count loops are fully unrolled at compile time
            if len(static_range) <= 10:
                for i in static_range:
                    self.value_constructor.lscope[node.target.id] = triton.language.constexpr(i)
                    self.visit_compound_statement(node.body)
                for stmt in node.orelse:
                    ast.NodeVisitor.generic_visit(self, stmt)
                return
        # create nodes
        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
        ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
        arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Constant(0)
        arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
        arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Constant(1)
        # init node
        init_node = ast.Assign(targets=[st_target], value=arg_0)
        # step node
        pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
        neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
        pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Constant(0)])
        # loop condition: `i < stop` for a positive step, `i > stop` otherwise
        build_cond = lambda: triton.language.where(self.visit(pos_step_node),
                                                   self.visit(pos_cond_node),
                                                   self.visit(neg_cond_node),
                                                   _builder=self.builder)
        step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
        # code generation
        current_bb = self.builder.get_insert_block()
        loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
        next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)

        def continue_fn():
            self.visit(step_node)
            cond = build_cond()
            return self.builder.cond_br(cond.handle, loop_bb, next_bb)

        # init loop induction variable
        self.visit(init_node)
        # promote it to right type
        init_val = self.value_constructor.get_value(node.target.id)
        promote = lambda a, b: triton.language.semantic.computation_type_impl(a, b, False)
        start_ty = triton.language.core._to_tensor(iter_args[0], self.builder).type
        stop_ty = triton.language.core._to_tensor(iter_args[1], self.builder).type if len(iter_args) > 1 else None
        ty = promote(start_ty, stop_ty) if len(iter_args) > 1 else start_ty
        casted = triton.language.semantic.cast(init_val, ty, self.builder)
        self.value_constructor.set_value(node.target.id, casted)
        # create cond
        cond = build_cond()
        self.builder.cond_br(cond.handle, loop_bb, next_bb)
        self.builder.set_insert_block(loop_bb)
        self.visit_compound_statement(node.body)
        # TODO: handle case where body breaks control flow
        continue_fn()
        stop_bb = self.builder.get_insert_block()
        self.value_constructor._seal_block(stop_bb)
        self.value_constructor._seal_block(loop_bb)
        self.value_constructor._seal_block(next_bb)
        self.builder.set_insert_block(next_bb)

        for stmt in node.orelse:
            ast.NodeVisitor.generic_visit(self, stmt)

    def visit_Slice(self, node):
        lower = self.visit(node.lower)
        upper = self.visit(node.upper)
        step = self.visit(node.step)
        return slice(lower, upper, step)

    def visit_Index(self, node):
        return self.visit(node.value)

    def visit_keyword(self, node):
        return {node.arg: self.visit(node.value)}

    def visit_Call(self, node):
        fn = self.visit(node.func)
        if isinstance(fn, triton.language.constexpr):
            fn = fn.value
        kws = dict()
        for keyword in node.keywords:
            kws.update(self.visit(keyword))
        args = [self.visit(arg) for arg in node.args]
        if isinstance(fn, triton.runtime.JITFunction):
            from inspect import getcallargs
            # normalise positional/keyword/default arguments into declaration order
            args = getcallargs(fn.fn, *args, **kws)
            args = [args[name] for name in fn.arg_names]
            args = [arg if isinstance(arg, triton.language.tensor)
                    else triton.language.constexpr(arg) for arg in args]
            # generate function def
            attributes = dict()
            constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
            constants = {i: args[i] for i in constexprs}
            # generate call
            args = [None if i in constexprs else arg for i, arg in enumerate(args)]
            arg_vals = [arg.handle for arg in args if arg is not None]
            arg_types = [arg.type for arg in args if arg is not None]
            fn_name = mangle_fn(fn.__name__, arg_types, constants)
            # generate function def if necessary
            if not self.module.has_function(fn_name):
                ret_type = triton.language.void
                prototype = triton.language.function_type(ret_type, arg_types)
                gscope = sys.modules[fn.fn.__module__].__dict__
                generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, function_name=fn_name, prototypes=self.prototypes, module=self.module)
                generator.visit(fn.parse())
            symbol = self.module.get_function(fn_name)
            ret = self.builder.call(symbol, arg_vals)
            if not ret.type.is_void():
                ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
            return ret
        # built-in function
        if sys.modules[fn.__module__] is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction):
            ret = fn(*args, _builder=self.builder, **kws)
        if fn in self.value_constructor.builtins.values():
            # Python builtins run at compile time on unwrapped constexpr values
            args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
                    for arg in args]
            ret = fn(*args, **kws)
            if isinstance(ret, (bool, int, float)):
                ret = triton.language.core.constexpr(ret)
            else:
                ret = triton.language.core._to_tensor(ret, self.builder)
        # NOTE: launching kernels from kernels (dynamic parallelism) is not supported here.
        return ret

    def visit_Constant(self, node):
        return triton.language.constexpr(node.value)

    if sys.version_info < (3, 8):
        def visit_NameConstant(self, node):
            return triton.language.constexpr(node.value)

        def visit_Num(self, node):
            return triton.language.constexpr(node.n)

        def visit_Str(self, node):
            return triton.language.constexpr(ast.literal_eval(node))

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        return getattr(lhs, node.attr)

    def visit_Expr(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_NoneType(self, node):
        return None

    def visit(self, node):
        if node is not None:
            self.last_node = node
        with warnings.catch_warnings():
            # The ast library added visit_Constant and deprecated some other
            # methods but we can't move to that without breaking Python 3.6 and 3.7.
            warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
            warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
            return super().visit(node)

    def generic_visit(self, node):
        typename = type(node).__name__
        raise NotImplementedError("Unsupported node: {}".format(typename))
class CompilationError(Exception):
    """Error raised when lowering a Triton kernel fails at a particular AST node.

    The message starts with ``at <line>:<col>:``. It then shows the kernel source up to and
    including the node's line, and ends with a caret under the node's column.
    """

    def __init__(self, src, node):
        visible_src = '\n'.join(src.split('\n')[:node.lineno])
        caret_line = ' ' * node.col_offset + '^'
        header = f'at {node.lineno}:{node.col_offset}:\n'
        self.message = header + visible_src + '\n' + caret_line
        self.src = src
        self.node = node
        super().__init__(self.message)

    def __reduce__(self):
        # By default, exceptions unpickle as cls(*self.args), and args is (message,), which does
        # not fit __init__(src, node). Rebuild the error from the original inputs instead.
        return (type(self), (self.src, self.node))
class OutOfResources(Exception):
    """Raised when a kernel requires more of a hardware resource than the device provides.

    Attributes:
        required: the amount of the resource the kernel requires.
        limit: the amount the hardware provides.
        name: human-readable name of the resource.
    """

    def __init__(self, required, limit, name):
        self.message = ', '.join((
            f'out of resource: {name}',
            f'Required: {required}',
            f'Hardware limit: {limit}',
        ))
        self.required = required
        self.limit = limit
        self.name = name
        super().__init__(self.message)

    def __reduce__(self):
        # this is necessary to make OutOfResources picklable: default exception pickling would
        # call cls(message), which does not match __init__(required, limit, name)
        return (type(self), (self.required, self.limit, self.name))
def kernel_suffix(signature, specialization):
    """Encode an argument specialisation into a kernel-name suffix.

    One group is emitted per argument position:
    ``<argid><'c' if equal to 1><'d' if divisible by 16>``. Only the number of entries in
    ``signature`` matters; their values are ignored.
    """
    def encode(position):
        flags = ''
        if position in specialization.equal_to_1:
            flags += 'c'
        if position in specialization.divisible_by_16:
            flags += 'd'
        return f'{position}{flags}'

    return ''.join(encode(position) for position, _ in enumerate(signature))
def make_triton_ir(fn, signature, specialization, constants):
    """Lower JIT function ``fn`` to a Triton IR module.

    :param fn: the JIT function; its ``arg_names``, ``__globals__``, ``__name__``, ``parse()``
        and ``src`` are used.
    :param signature: mapping of argument key -> type string (e.g. ``"*fp32"``, ``"i32"``).
        Entries whose key is in ``constants`` are left out of the kernel prototype. The
        iteration order of its values is treated as argument order.
    :param specialization: exposes ``equal_to_1`` and ``divisible_by_16`` collections of
        argument indices.
    :param constants: compile-time constants keyed by argument name or argument index.
    :return: ``(module, generator)``: the IR module, which takes ownership of the IR context,
        and the ``CodeGenerator`` used to build it.
    :raises CompilationError: wraps any other error raised while visiting the kernel AST and
        points at the last visited node. ``NotImplementedError`` and ``CompilationError``
        propagate unchanged.
    """
    context = _triton.ir.context()
    # create kernel prototype
    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
    constants = {cst_key(key): value for key, value in constants.items()}
    # visit kernel AST
    gscope = fn.__globals__.copy()
    function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
    tys = list(signature.values())
    # Arguments known to equal 1 are folded into constants: `True` for boolean (i1) arguments,
    # `1` otherwise. `k` is an argument *index*, so bounds-check it against `tys`. Testing
    # `k in tys` would compare the index with the type strings and never match.
    new_constants = {k: True if 0 <= k < len(tys) and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
    new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
    all_constants = constants.copy()
    all_constants.update(new_constants)
    arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
    prototype = triton.language.function_type(triton.language.void, arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
    try:
        generator.visit(fn.parse())
    except Exception as e:
        node = generator.last_node
        if node is None or isinstance(e, (NotImplementedError, CompilationError)):
            raise
        raise CompilationError(fn.src, node) from e
    ret = generator.module
    # module takes ownership of the context
    ret.context = context
    return ret, generator
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
    '''
    Lower a TritonGPU-dialect module to PTX for the given device.
    :param mod: a TritonGPU dialect module
    :param device: target device
    :return:
        - PTX code
        - shared memory allocation size
    '''
    ptx_and_shared_mem = _triton.translate_triton_gpu_to_ptx(mod, device)
    return ptx_and_shared_mem
def make_cubin(ptx, device):
    '''
    Assemble PTX code into a cubin.
    :param ptx: ptx code
    :param device: CUDA device
    :return: the compiled cubin (documented as str)
    '''
    cubin = _triton.compile_ptx_to_cubin(ptx, device)
    return cubin
def ptx_get_kernel_name(ptx: str) -> str:
    '''
    Return the kernel name declared in PTX code, or None if there is no ``// .globl`` line.

    This name is required when launching the kernel. PTX code generation mangles names, so
    the kernel names in Triton IR are not available in PTX/cubin. The name is instead read
    from the last token of the first ``// .globl`` line.
    '''
    assert ptx
    stripped_lines = (raw_line.strip() for raw_line in ptx.split('\n'))
    globl_line = next((ln for ln in stripped_lines if ln.startswith('// .globl')), None)
    if globl_line is None:
        return None
    return globl_line.split()[-1]
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm install root: ``$ROCM_PATH`` when set, otherwise ``/opt/rocm``.

    The result is cached after the first call, so later changes to the environment are not seen.
    """
    env_value = os.getenv("ROCM_PATH")
    return "/opt/rocm" if env_value is None else env_value
def _get_amdgpu_arch():
    """Best-effort detection of the local AMD GPU architecture name (e.g. ``"gfx90a"``).

    Runs ``<ROCm root>/bin/rocminfo`` and returns the first ``gfx...`` token found on a
    ``Name:`` line. Returns ``None`` when rocminfo is missing, fails, emits undecodable output,
    or reports no gfx name. This function is evaluated when the module is imported (it is
    the argument of the ``static_vars`` decorator on ``_compile``). It must therefore not
    raise on machines without ROCm. Only these expected failures are caught, not
    KeyboardInterrupt/SystemExit as a bare ``except:`` would be.
    """
    try:
        rocminfo = subprocess.check_output([rocm_path_dir() + '/bin/rocminfo']).decode()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        return None
    # gfx names can end in hex digits (gfx90a, gfx90c). With `gfx\d+`, "gfx90a" would be
    # truncated to "gfx90".
    gfx_arch = re.search(r'Name:\s+.*(gfx[0-9a-fA-F]+)', rocminfo)
    if gfx_arch is None:
        return None
    return gfx_arch.group(1)
@static_vars(discovered_gfx_arch=_get_amdgpu_arch())
def _compile(fn, signature: str, device: int = -1, constants=None,
             specialization=_triton.code_gen.instance_descriptor(),
             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
    """Compile JIT function ``fn`` to Triton IR or to device assembly.

    :param signature: argument types as consumed by ``make_triton_ir``. It is annotated
        ``str``, but ``make_triton_ir`` calls ``.values()``/``.items()`` on it.
    :param constants: compile-time constants keyed by argument name or index (default: none).
    :param extern_libs: external libraries passed through to the backend (default: none).
    :param output: ``"ttir"`` returns the Triton IR module. ``"cubin"`` or ``"hsaco"``
        runs the backend: AMDGCN when PyTorch is a HIP/ROCm build, PTX otherwise.
    :return: the IR module for ``"ttir"``; otherwise ``(asm, shared_mem, name)`` from the backend.
    :raises ValueError: if ``output`` is not one of the supported values.
    :raises RuntimeError: on ROCm when no gfx arch is set via ``MI_GPU_ARCH`` and none was
        detected from rocminfo.
    """
    if constants is None:
        constants = {}
    # validate before building the IR so an unsupported `output` fails fast (and survives -O)
    if output not in ("ttir", "cubin", "hsaco"):
        raise ValueError(f'output should be one of [ttir,cubin,hsaco], but got "{output}"')
    # triton-ir
    module, _ = make_triton_ir(fn, signature, specialization, constants)
    if output == "ttir":
        return module
    is_hip = torch.version.hip is not None
    backend = _triton.runtime.backend.ROCM if is_hip else _triton.runtime.backend.CUDA
    if extern_libs is None:
        extern_libs = dict()
    # compile ttir
    if is_hip:
        # MI_GPU_ARCH, when set, overrides the arch detected from rocminfo at import time
        gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
        if gfx_arch is None:
            raise RuntimeError('AMDGCN gfx arch is not defined.')
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
    else:
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
    return asm, shared_mem, name
_TRITON_SCALAR_TO_CPP = {
    "i1": "int32_t",
    "i8": "int8_t",
    "i16": "int16_t",
    "i32": "int32_t",
    "i64": "int64_t",
    "u32": "uint32_t",
    "u64": "uint64_t",
    "fp16": "float",
    "bf16": "float",
    "fp32": "float",
    "fp64": "double",
}


def ty_to_cpp(ty):
    """Return the C type used for a kernel argument of Triton type ``ty`` in the HIP launcher.

    Any pointer type (leading ``*``) maps to ``hipDeviceptr_t``. Scalars are looked up in
    ``_TRITON_SCALAR_TO_CPP``; ``i1`` is widened to ``int32_t``, and ``fp16``/``bf16`` travel
    as ``float``. Unsupported scalar names raise ``KeyError``.
    """
    is_pointer = ty[0] == '*'
    return "hipDeviceptr_t" if is_pointer else _TRITON_SCALAR_TO_CPP[ty]
def generate_name_initializer(signature):
    # NOTE(review): this looks unfinished. The loop body is the bare expression `src`, which
    # has no effect, and there is no `return`, so the function always returns None after
    # splitting `signature` on ','. TODO: confirm whether this is dead code or still needs
    # an implementation.
    src = "int i = 0;\n"
    tys = signature.split(',')
    for i, ty in enumerate(tys):
        src
def binary_name_to_header_name(name):
    """Return the C header filename for a compiled binary called ``name``.

    Names longer than 128 characters are replaced by ``kernel_<sha256 hex of name>``. This
    keeps the resulting filename well under the usual 255-character filename limit.
    """
    if len(name) <= 128:
        stem = name
    else:
        digest = hashlib.sha256(name.encode("utf-8")).hexdigest()
        stem = "kernel_" + digest
    return f"{stem}.h"
def generate_launcher(identifier, constants, signature):
    """Return C source for a Python extension module named ``launcher``.

    The module exposes ``launch(gridX, gridY, gridZ, num_warps, shared_memory,
    stream, function, launch_enter_hook, launch_exit_hook, compiled_kernel,
    *kernel_args)``. It converts the Python arguments and launches ``function``
    with ``hipModuleLaunchKernel``. It calls the optional enter/exit hooks
    around the launch; the exit hook receives the enter hook's return value
    when there is one.

    Args:
        identifier: kernel name (not used by the generated source).
        constants: collection of argument indices. Those arguments are still
            parsed from Python but are left out of the kernel parameter array.
        signature: argument index -> Triton type string; types starting
            with ``*`` are pointers (ints, ``None`` or objects with
            ``data_ptr()``).

    Raises:
        NotImplementedError: if PyTorch is not a ROCm/HIP build. Only the HIP
            launcher is implemented here.
        KeyError: if ``signature`` contains a type with no C mapping.
    """
    if torch.version.hip is None:
        # Only a HIP launcher is generated in this file; without this guard the
        # function would reach `return src` with `src` never assigned.
        raise NotImplementedError("generate_launcher: only ROCm/HIP builds of PyTorch are supported")

    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

    def _extracted_type(ty):
        # C type the argument is parsed into by PyArg_ParseTuple.
        if ty[0] == '*':
            return "PyObject*"
        return {
            'i1': 'int32_t',
            'i32': 'int32_t',
            'i64': 'int64_t',
            'u32': 'uint32_t',
            'u64': 'uint64_t',
            'fp16': 'float',
            'bf16': 'float',
            'fp32': 'float',
            'fp64': 'double',
        }[ty]

    def format_of(ty):
        # PyArg_ParseTuple format code for a C type.
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "uint32_t": "I",
            "int32_t": "i",
            "uint64_t": "K",
            "int64_t": "L",
        }[ty]

    # 5 ints (grid x/y/z, num_warps, shared_memory), 2 uint64 (stream,
    # function), 3 objects (enter hook, exit hook, compiled kernel), then the
    # kernel arguments.
    parse_format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
    # generate glue code
    src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
   if (code != HIP_SUCCESS)
   {{
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str = hipGetErrorString(code);
      char err[1024] = {{0}};
      strcat(err, prefix);
      strcat(err, str);
      PyErr_SetString(PyExc_RuntimeError, err);
   }}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  // Skip the launch if an argument conversion already raised: pointer
  // arguments would then be null.
  if(gridX*gridY*gridZ > 0 && !PyErr_Occurred()){{
    CUDA_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}
static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
  }}
  if (obj == Py_None) {{
    return (hipDeviceptr_t)0;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_XDECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!ret) {{
      // data_ptr() raised; keep its exception set for the caller.
      return (hipDeviceptr_t)0;
    }}
    if (!PyLong_Check(ret)) {{
      Py_DECREF(ret);
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      return (hipDeviceptr_t)0;
    }}
    hipDeviceptr_t dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
    Py_DECREF(ret);
    return dev_ptr;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (hipDeviceptr_t)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
  // printf("launch(PyObject* self, PyObject* args)");
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int num_warps;
  int shared_memory;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *compiled_kernel = NULL;
  PyObject *hook_ret = NULL;
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{parse_format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
    return NULL;
  }}
  if (launch_enter_hook != Py_None) {{
    PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
    hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
    Py_XDECREF(new_args);
    if (!hook_ret) {{
      // the enter hook raised: propagate without launching
      return NULL;
    }}
  }}
  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
  if (launch_exit_hook != Py_None) {{
    PyObject *new_args = NULL;
    if (hook_ret) {{
      new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
      // the tuple holds its own reference; release ours before reusing hook_ret
      Py_DECREF(hook_ret);
    }} else {{
      new_args = PyTuple_Pack(1, compiled_kernel);
    }}
    hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
    Py_XDECREF(new_args);
  }}
  if (hook_ret) {{
    Py_DECREF(hook_ret);
  }}
  if(PyErr_Occurred()) {{
    return NULL;
  }}
  // return None
  Py_INCREF(Py_None);
  return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};
PyMODINIT_FUNC PyInit_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src
def default_cache_dir():
    """Return the default on-disk cache root, ``~/.triton/cache``.

    ``os.path.expanduser`` returns ``$HOME`` when it is set, matching the
    previous lookup. When ``HOME`` is missing (e.g. minimal containers) it
    falls back to the platform's home-directory lookup instead of raising
    ``KeyError``.
    """
    return os.path.join(os.path.expanduser("~"), ".triton", "cache")
class CacheManager:
    """Per-key cache directory for compiled artifacts.

    Files for a key live in ``<root>/<key>/``. ``<root>`` is
    ``$TRITON_CACHE_DIR`` when that variable is set, otherwise
    ``default_cache_dir()``. Setting ``TRITON_CACHE_DIR`` to an empty string
    disables caching: ``has_file`` always returns False and ``put`` does
    nothing.
    """

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # Only compute the default when the variable is unset: the default is
        # derived from the home directory, which need not be available when an
        # explicit cache directory is configured.
        cache_root = os.environ.get('TRITON_CACHE_DIR')
        if cache_root is None:
            cache_root = default_cache_dir()
        self.cache_dir = cache_root
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            # create cache directory if it doesn't exist
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename):
        """Return the path of *filename* inside this key's cache directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename):
        """Return True if *filename* exists in the cache (always False when disabled)."""
        if not self.cache_dir:
            return False
        return os.path.exists(self._make_path(filename))

    def put(self, data, filename, binary=True):
        """Store *data* as *filename*, under the directory's file lock.

        Args:
            data: bytes when ``binary`` is True, otherwise str.
            filename: file name inside this key's cache directory.
            binary: selects ``"wb"`` vs ``"w"`` mode.
        """
        if not self.cache_dir:
            return
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        tmp_path = filepath + ".tmp"
        with FileLock(self.lock_path):
            # use tempfile to be robust against program interruptions
            mode = "wb" if binary else "w"
            with open(tmp_path, mode) as f:
                f.write(data)
            # os.replace overwrites an existing destination on every platform,
            # unlike os.rename, which fails on Windows if the target exists.
            os.replace(tmp_path, filepath)
# utilities for generating and compiling C wrappers
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories holding ``libcuda.so``, as located by ``whereis``.

    The first whitespace-separated token of the ``whereis`` output (its label
    for the queried name) is skipped. Every remaining token is reduced to its
    parent directory. The result is cached for the life of the process.
    """
    raw = subprocess.check_output(["whereis", "libcuda.so"])
    tokens = raw.decode().strip().split()
    return list(map(os.path.dirname, tokens[1:]))
@functools.lru_cache()
def libhip_dirs():
    """Return the directories searched when linking the HIP runtime library.

    Derived from ``hip_home_dirs()`` (``$ROCM_HOME``, default ``/opt/rocm``).
    This keeps the library directory consistent with the HIP include directory
    that ``_build`` takes from the same helper. With ``ROCM_HOME`` unset the
    result is ``["/opt/rocm/lib"]``.
    """
    return [os.path.join(hip_home_dirs(), "lib")]
@functools.lru_cache()
def cuda_home_dirs():
    """Return the CUDA toolkit root: ``$CUDA_HOME`` or ``/usr/local/cuda``.

    Despite the plural name this is a single path string. The value is cached,
    so later changes to ``CUDA_HOME`` are not observed.
    """
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home is None:
        return "/usr/local/cuda"
    return cuda_home
@functools.lru_cache()
def hip_home_dirs():
    """Return the ROCm/HIP install root: ``$ROCM_HOME`` or ``/opt/rocm``.

    Despite the plural name this is a single path string. The value is cached,
    so later changes to ``ROCM_HOME`` are not observed.
    """
    rocm_home = os.getenv("ROCM_HOME")
    if rocm_home is None:
        return "/opt/rocm"
    return rocm_home
@functools.lru_cache()
def rocm_path_dir():
    """Return ``$ROCM_PATH``, or ``/opt/rocm`` when it is unset (cached)."""
    rocm_path = os.getenv("ROCM_PATH")
    return "/opt/rocm" if rocm_path is None else rocm_path
@contextlib.contextmanager
def quiet():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of a ``with`` block.

    Writes go to throw-away ``io.StringIO`` buffers. The original streams are
    restored on exit, including when the body raises.
    """
    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
        yield
def _build(name, src, srcdir):
    """Compile the generated C launcher *src* into a Python extension module.

    The shared object is written to ``srcdir`` as ``{name}{EXT_SUFFIX}``. A C
    compiler is taken from ``$CC``, then ``gcc``, then ``clang``, and invoked
    directly. Only when none of these is found does this fall back to
    ``setuptools``, which performs its own compiler discovery.

    Args:
        name: extension module name (also the base name of the output file).
        src: path to the C source file.
        srcdir: directory used for includes, build temporaries and output.

    Returns:
        Path to the built shared object.

    Raises:
        subprocess.CalledProcessError: if the direct compiler invocation fails.
    """
    if torch.version.hip is not None:
        lib_dirs = libhip_dirs()
        gpu_include_dir = os.path.join(hip_home_dirs(), "include")
        gpu_library = "amdhip64"
    else:
        lib_dirs = libcuda_dirs()
        gpu_include_dir = os.path.join(cuda_home_dirs(), "include")
        gpu_library = "cuda"
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, f"{name}{suffix}")
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        cc = shutil.which("gcc") or shutil.which("clang")
    if cc is not None:
        py_include_dir = get_paths()["include"]
        cc_cmd = [cc, src, "-O3", f"-I{gpu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
                  "-shared", "-fPIC", f"-l{gpu_library}", "-o", so]
        cc_cmd += [f"-L{lib_dir}" for lib_dir in lib_dirs]
        # check_call raises CalledProcessError on a non-zero exit status, so
        # returning here means the build succeeded.
        subprocess.check_call(cc_cmd)
        return so
    # fallback on setuptools when no C compiler could be located above
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=[srcdir, gpu_include_dir],
        extra_compile_args=['-O3'],
        extra_link_args=[],
        # copy: the lru_cache'd list must not be mutated by setuptools
        library_dirs=list(lib_dirs),
        libraries=[gpu_library],
    )
    script_args = ['build_ext', '--build-temp=' + srcdir, '--build-lib=' + srcdir, '-q']
    with quiet():
        setuptools.setup(name=name, ext_modules=[ext], script_args=script_args)
    return so
def make_so_cache_key(version_hash, signature, constants):
    """Return an MD5 hex digest identifying a launcher shared object.

    All pointer argument types collapse to ``'ptr'``, because the generated
    launcher treats every pointer the same way. Scalar types, the
    ``version_hash`` and the ``constants`` mapping all contribute to the key.
    """
    type_parts = ['ptr' if ty[0] == '*' else ty for ty in signature.values()]
    raw_key = "{}-{}{}".format(version_hash, ''.join(type_parts), constants)
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
    """Return an MD5 hex digest identifying one compiled variant of a kernel.

    The key combines the function hash, the full argument type list, each
    config's ``divisible_by_16``/``equal_to_1`` index collections (sorted, so
    iteration order is irrelevant), the constants, ``num_warps`` and
    ``num_stages``.
    """
    def conf_key(conf):
        return sorted(conf.divisible_by_16), sorted(conf.equal_to_1)

    configs_key = list(map(conf_key, configs))
    type_str = ''.join(signature.values())
    raw_key = f"{fn_hash}-{type_str}-{configs_key}-{constants}-{num_warps}-{num_stages}"
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()
def compile(fn, signature: Dict[int, str], device: int = -1, constants=None, num_warps: int = 4,
            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
    """Compile *fn* (or reuse cached artifacts) and return a launchable kernel.

    Two cache entries are used:
      * a launcher shared object, keyed on the Triton version, ``signature``
        and ``constants`` (see ``make_so_cache_key``);
      * the kernel's IR/assembly/binary plus JSON metadata, keyed on the
        function hash, signature, config, constants, ``num_warps`` and
        ``num_stages`` (see ``make_fn_cache_key``).

    Args:
        fn: the Triton function; ``fn.__name__`` names the cached files and
            ``fn.cache_key`` identifies its source.
        signature: argument index -> Triton type string.
        constants: constant arguments (``None`` means none).
        configs: sequence holding exactly one specialization config.
        warm_cache_only: populate the cache but do not load the binary.

    Returns:
        A ``CompiledKernel``, or ``None`` when ``warm_cache_only`` is set.
    """
    if constants is None:
        constants = dict()
    # we get the kernel, i.e. the first function generated in the module
    assert len(configs) == 1, "compile() expects exactly one specialization config"
    name = fn.__name__
    # --- launcher stub -----------------------------------------------------
    so_cache_key = make_so_cache_key(triton.runtime.jit.version_key(), signature, constants)
    so_cache_manager = CacheManager(so_cache_key)
    so_name = f"{name}.so"
    # build the stub unless it is already cached
    if not so_cache_manager.has_file(so_name):
        with tempfile.TemporaryDirectory() as tmpdir:
            src = generate_launcher(name, constants, signature)
            src_path = os.path.join(tmpdir, "main.c")
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir)  # build step
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)
    # --- kernel artifacts --------------------------------------------------
    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
    fn_cache_manager = CacheManager(fn_cache_key)
    is_hip = torch.version.hip is not None
    if is_hip:
        assembly_name = f"{name}.gcn"
        binary_name = f"{name}.hsaco"
    else:
        assembly_name = f"{name}.ptx"
        binary_name = f"{name}.cubin"
    data_name = f"{name}.json"
    ttir_name = f"{name}.ttir"
    llir_name = f"{name}.llir"
    required_files = (binary_name, data_name, assembly_name, ttir_name, llir_name)
    if not all(fn_cache_manager.has_file(filename) for filename in required_files):
        if is_hip:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "hsaco", cc)
            # cache AMD assembly and binary
            # NOTE(review): asm["hsaco_path"] is written in text mode here, then
            # read back as bytes by CompiledKernel and passed to
            # load_binary_hsaco. The key name suggests it is a path rather than
            # the binary itself; confirm the cached entry stays valid.
            fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
            fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
        else:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "cubin", cc)
            # cache Nvidia assembly and binary
            fn_cache_manager.put(asm["cubin"], binary_name)
            fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)
        # cache triton and llvm ir
        fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
        fn_cache_manager.put(asm["llir"], llir_name, binary=False)
        # cache metadata
        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
    if warm_cache_only:
        return  # load_binary() requires a valid cuda context
    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
class CompiledKernel:
    """A cached, compiled Triton kernel; launch it with ``kernel[grid](*args)``.

    The constructor does the following:
      * imports the launcher extension at ``so_path``;
      * reads ``<fn_name>.json`` metadata and the cached IR/assembly/binary
        files from ``cache_dir``;
      * loads the binary onto ``device`` through ``_triton.code_gen``.
    """

    # Hooks for external tools to monitor the execution of triton kernels
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, fn_name, so_path, cache_dir, device):
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")
        # initialize metadata
        with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
            metadata = json.load(f)
        self.shared = metadata["shared"]
        self.num_warps = metadata["num_warps"]
        self.num_stages = metadata["num_stages"]
        # initialize asm dict
        self.asm = dict()
        is_hip = torch.version.hip is not None
        if is_hip:
            with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
                self.asm["hsaco_path"] = f.read()
            with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
                self.asm["amdgcn"] = f.read()
        else:
            with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
                self.asm["cubin"] = f.read()
            with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
                self.asm["ptx"] = f.read()
        with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
            self.asm["llir"] = f.read()
        with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
            self.asm["ttir"] = f.read()
        if is_hip:
            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
        else:
            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
        self.fn_name = fn_name
        self.cu_module = mod
        self.cu_function = func
        self.n_regs = n_regs
        self.n_spills = n_spills

    def __getitem__(self, grid):
        """Return a callable that launches this kernel over *grid*.

        *grid* may have 1, 2 or 3 entries; missing trailing dimensions default
        to 1, and entries beyond the third are ignored. The returned
        ``runner(*args, stream=None)`` uses PyTorch's current stream when
        ``stream`` is None.
        """
        def runner(*args, stream=None):
            if stream is None:
                stream = torch.cuda.current_stream().cuda_stream
            # Pad (x,) / (x, y) grids with 1s so they launch like a 3-D grid.
            grid_x, grid_y, grid_z = (tuple(grid) + (1, 1))[:3]
            self.c_wrapper(grid_x, grid_y, grid_z, self.num_warps, self.shared, stream, self.cu_function,
                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
        return runner

    def get_sass(self, fun=None):
        """Disassemble the cached cubin to SASS via ``extract`` and memoize it in ``self.asm['sass']``."""
        if 'sass' in self.asm:
            return self.asm['sass']
        fd, path = tempfile.mkstemp()
        try:
            with open(fd, 'wb') as cubin:
                cubin.write(self.asm['cubin'])
            self.sass = extract(path, fun)
        finally:
            os.remove(path)
        self.asm['sass'] = self.sass
        return self.sass