2022-04-06 16:13:53 -07:00
from __future__ import annotations
2021-04-20 22:29:40 -04:00
import ast
2022-09-18 08:51:48 -07:00
import contextlib
2021-09-23 16:45:54 -07:00
import functools
2021-09-18 22:48:26 -07:00
import hashlib
2022-09-18 08:51:48 -07:00
import io
2022-09-22 16:44:22 -07:00
import json
2021-09-18 22:48:26 -07:00
import os
2022-12-21 01:30:50 -08:00
import re
2022-09-19 21:01:36 -07:00
import shutil
2021-09-23 16:45:54 -07:00
import subprocess
2022-01-06 14:34:17 -08:00
import sys
2022-09-18 08:51:48 -07:00
import sysconfig
2022-01-06 14:34:17 -08:00
import tempfile
2021-12-10 15:19:20 -08:00
import warnings
2022-12-21 01:30:50 -08:00
from collections import namedtuple
from pathlib import Path
2022-09-19 21:01:36 -07:00
from sysconfig import get_paths
2022-12-21 01:30:50 -08:00
from typing import Any , Callable , Dict , Tuple , Union
2022-01-06 14:34:17 -08:00
2022-09-18 08:51:48 -07:00
import setuptools
2021-08-21 06:00:54 +02:00
import torch
2021-09-21 14:10:02 -07:00
from filelock import FileLock
2021-12-22 01:56:10 +08:00
2022-01-06 14:34:17 -08:00
import triton
import triton . _C . libtriton . triton as _triton
2022-12-21 01:30:50 -08:00
from . import impl
2022-10-02 17:39:52 -07:00
from . tools . disasm import extract
2021-04-20 22:29:40 -04:00
2022-05-24 12:08:49 -07:00
2022-09-18 08:51:48 -07:00
def str_to_ty(name):
    """Translate a type-signature string into a ``triton.language`` type.

    A leading ``*`` marks a pointer: the rest of the string is resolved
    recursively and wrapped in ``pointer_type``.  Any other string must be
    one of the scalar names in the table below; unknown names raise
    ``KeyError``.
    """
    lang = triton.language
    if name[0] == "*":
        pointee = str_to_ty(name[1:])
        return lang.pointer_type(pointee)
    scalar_types = {
        "fp8": lang.float8,
        "fp16": lang.float16,
        "bf16": lang.bfloat16,
        "fp32": lang.float32,
        "fp64": lang.float64,
        "i1": lang.int1,
        "i8": lang.int8,
        "i16": lang.int16,
        "i32": lang.int32,
        "i64": lang.int64,
        "u8": lang.uint8,
        "u16": lang.uint16,
        "u32": lang.uint32,
        "u64": lang.uint64,
        "B": lang.int1,
    }
    return scalar_types[name]
2022-04-07 12:11:32 -07:00
2021-04-20 22:29:40 -04:00
2022-04-06 16:13:53 -07:00
def mangle_ty(ty):
    """Return the name-mangling fragment for a Triton type.

    Pointers become ``P<element>``, integers ``i<bitwidth>``, floating-point
    types their short name, block types ``<scalar>S<dims joined by _>S`` and
    void ``V``.  The predicates are probed in that order.

    Raises:
        AssertionError: if none of the predicates matches ``ty``.
    """
    if ty.is_ptr():
        return 'P' + mangle_ty(ty.element_ty)
    if ty.is_int():
        return 'i' + str(ty.int_bitwidth)
    float_tags = (
        ('is_fp8', 'fp8'),
        ('is_fp16', 'fp16'),
        ('is_bf16', 'bf16'),
        ('is_fp32', 'fp32'),
        ('is_fp64', 'fp64'),
    )
    for predicate, tag in float_tags:
        if getattr(ty, predicate)():
            return tag
    if ty.is_block():
        elt = mangle_ty(ty.scalar)
        shape = '_'.join(map(str, ty.shape))
        return f'{elt}S{shape}S'
    if ty.is_void():
        return 'V'
    # Raised explicitly (rather than `assert False`) so the check survives
    # `python -O`, where assert statements are stripped and this function
    # would otherwise fall through and return None.
    raise AssertionError("Unsupported type")
2022-04-03 20:58:16 -07:00
def mangle_fn(name, arg_tys, constants):
    """Build the mangled symbol name for a specialised function.

    The result is ``<name>__<arg types>__<constants>``.  Argument types are
    mangled with ``mangle_ty`` and joined by ``_``; each constant contributes
    ``<index>c<repr(value)>`` in ascending index order.
    """
    # The return type is not mangled: it must be a function of the arg types.
    type_part = '_'.join(mangle_ty(arg_ty) for arg_ty in arg_tys)
    const_part = '_'.join(f'{idx}c{constants[idx]!r}' for idx in sorted(constants))
    # Escape characters that can show up in a repr (e.g. `1.5`, `'a'`).
    for raw, escaped in (('.', '_d_'), ("'", '_sq_')):
        const_part = const_part.replace(raw, escaped)
    return f'{name}__{type_part}__{const_part}'
2022-12-21 01:30:50 -08:00
class enter_sub_region:
    """Context manager that scopes a code generator to a nested region.

    On entry it snapshots the generator's ``lscope`` and ``local_defs``,
    clears ``local_defs`` and remembers the builder's current insertion
    block.  On exit it moves the insertion point back to the end of that
    block and reinstates the two snapshots.
    """

    def __init__(self, generator: "CodeGenerator"):
        self.generator = generator

    def __enter__(self):
        gen = self.generator
        # record lscope & local_defs of the parent scope
        self.liveins = gen.lscope.copy()
        self.prev_defs = gen.local_defs.copy()
        gen.local_defs = {}
        self.insert_block = gen.builder.get_insertion_block()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        gen = self.generator
        gen.builder.set_insertion_point_to_end(self.insert_block)
        # note: lscope becomes the snapshot taken in __enter__
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
2022-04-06 16:13:53 -07:00
2022-12-21 01:30:50 -08:00
class CodeGenerator ( ast . NodeVisitor ) :
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=None):
    """Set up a code generator for one function.

    Args:
        context: handed to ``_triton.ir.builder``.
        prototype: function type of the function being generated.
        gscope: global names visible to the function body.
        attributes: per-argument attributes, keyed by argument index.
        constants: compile-time constant arguments, keyed by argument index.
        function_name: name of the emitted function.
        module: module to emit into; a new one is created when ``None``.
        is_kernel: whether the function is emitted with public visibility.
        function_types: mapping from function name to return type(s).
            Pass an existing dict to share it between generators; when
            omitted, this generator gets its own empty dict.
    """
    self.builder = _triton.ir.builder(context)
    self.module = self.builder.create_module() if module is None else module
    # `None` sentinel instead of a `dict()` default: a default dict would be
    # created once and silently shared by every generator built without the
    # argument.
    self.function_ret_types = {} if function_types is None else function_types
    self.prototype = prototype
    self.gscope = gscope
    self.lscope = dict()
    self.attributes = attributes
    self.constants = constants
    self.function_name = function_name
    self.is_kernel = is_kernel
    self.last_node = None
    self.builtins = {
        'range': range,
        'min': triton.language.minimum,
        'float': float,
        'int': int,
        'print': print,
        'isinstance': isinstance,
        'getattr': getattr,
    }
    # SSA-construction
    # name => triton.language.tensor
    self.local_defs: Dict[str, triton.language.tensor] = {}
    self.global_uses: Dict[str, triton.language.tensor] = {}
2022-04-06 16:13:53 -07:00
2021-04-20 22:29:40 -04:00
def get_value(self, name):
    """Resolve ``name`` to the value it is currently bound to.

    The local scope is searched first, then the global scope, then the
    builtins table.  A local-scope hit that was not defined in the current
    region is also recorded in ``self.global_uses``.

    Raises:
        ValueError: if ``name`` is not bound in any of the three scopes.
    """
    if name in self.lscope:
        found = self.lscope[name]
        # visible here but defined in an enclosing region
        if name not in self.local_defs:
            self.global_uses[name] = found
        return found
    for scope in (self.gscope, self.builtins):
        if name in scope:
            return scope[name]
    raise ValueError(f'{name} is not defined')
2022-04-06 16:13:53 -07:00
def set_value(self, name: str,
              value: "Union[triton.language.tensor, triton.language.constexpr]") -> None:
    """Bind ``name`` to ``value`` in the current region.

    The binding is written to both ``self.lscope`` (name lookup) and
    ``self.local_defs`` (names defined in the current region).
    FIXME: control flow is not taken into account here.
    """
    for table in (self.lscope, self.local_defs):
        table[name] = value
2022-04-06 16:13:53 -07:00
2022-12-21 01:30:50 -08:00
def is_triton_tensor(self, value):
    """Tell whether ``value`` is a ``triton.language.tensor`` instance."""
    tensor_cls = triton.language.tensor
    return isinstance(value, tensor_cls)
2022-04-06 16:13:53 -07:00
#
# AST visitor
#
def visit_compound_statement(self, stmts):
    """Visit statements in order, stopping after the first ``return``.

    The result of each visit is stored in ``self.last_ret_type``.  Returns
    ``stmts`` itself when it is empty, otherwise whether the last statement
    visited was an ``ast.Return``.
    """
    hit_return = False
    for statement in stmts:
        self.last_ret_type = self.visit(statement)
        hit_return = isinstance(statement, ast.Return)
        if hit_return:
            # anything after a top-level return is never visited
            break
    return stmts and hit_return
2022-03-24 17:16:50 -07:00
2021-04-20 22:29:40 -04:00
def visit_Module(self, node):
    """Visit every child of the module node; nothing is returned."""
    ast.NodeVisitor.generic_visit(self, node)
    return None
def visit_List(self, node):
    """Visit a list display and return the visited elements as a list."""
    # the expression context is visited too and must produce nothing
    assert self.visit(node.ctx) is None
    return [self.visit(element) for element in node.elts]
# By design, only non-kernel functions can return
def visit_Return(self, node):
    """Emit a return op and report the type(s) being returned.

    Returns ``None`` for a bare return, a tuple of types when the visited
    value is a tuple, and a single type otherwise.  Values are converted to
    tensors before their handles are handed to the builder.
    """
    ret_value = self.visit(node.value)
    if ret_value is None:
        self.builder.ret([])
        return None
    to_tensor = triton.language.core._to_tensor
    if not isinstance(ret_value, tuple):
        single = to_tensor(ret_value, self.builder)
        self.builder.ret([single.handle])
        return single.type
    tensors = [to_tensor(item, self.builder) for item in ret_value]
    types = tuple(tensor.type for tensor in tensors)
    self.builder.ret([tensor.handle for tensor in tensors])
    return types
2021-04-20 22:29:40 -04:00
2022-04-03 20:58:16 -07:00
def visit_FunctionDef(self, node):
    """Emit the IR function for a function definition.

    Steps: visit default values as assignments, create (or fetch) the
    function in the module, bind every argument name to either its constexpr
    or its IR argument, visit the body, then emit a trailing ``ret`` or
    update the function's return type from the last ``return`` visited.
    The builder's previous insertion block, if any, is restored at the end.
    """
    arg_names, kwarg_names = self.visit(node.args)
    # initialize defaults
    # `defaults` lines up with the *last* len(defaults) positional arguments,
    # in order, so pair both lists from the tail.  (Pairing defaults[i] with
    # args[-i-1] reverses the order as soon as there are two defaults.)
    tail_pairs = list(zip(reversed(node.args.args), reversed(node.args.defaults)))
    for arg_node, default_value in reversed(tail_pairs):
        annotation = arg_node.annotation
        name = arg_node.arg
        st_target = ast.Name(id=name, ctx=ast.Store())
        if annotation is None:
            init_node = ast.Assign(targets=[st_target], value=default_value)
        else:
            init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
        self.visit(init_node)
    # initialize function
    visibility = "public" if self.is_kernel else "private"
    fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
    self.module.push_back(fn)
    entry = fn.add_entry_block()
    arg_values = []
    # `idx` counts only non-constant arguments: constants get no IR argument
    idx = 0
    for i, arg_name in enumerate(arg_names):
        if i in self.constants:
            cst = self.constants[i]
            if not isinstance(cst, triton.language.constexpr):
                cst = triton.language.constexpr(self.constants[i])
            arg_values.append(cst)
            continue
        if i in self.attributes:
            fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
        arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
        idx += 1
    # remember where the builder was so it can be put back afterwards
    insert_pt = self.builder.get_insertion_block()
    for arg_name, arg_value in zip(arg_names, arg_values):
        self.set_value(arg_name, arg_value)
    self.builder.set_insertion_point_to_start(entry)
    # visit function body
    has_ret = self.visit_compound_statement(node.body)
    # finalize function
    if not has_ret:
        self.builder.ret([])
    else:
        # update return type
        if isinstance(self.last_ret_type, tuple):
            self.prototype.ret_types = list(self.last_ret_type)
        else:
            self.prototype.ret_types = [self.last_ret_type]
        fn.reset_type(self.prototype.to_ir(self.builder))
    if insert_pt:
        self.builder.set_insertion_point_to_end(insert_pt)
2021-04-20 22:29:40 -04:00
def visit_arguments(self, node):
    """Return ``(positional argument names, visited **kwargs entry)``."""
    positional = [self.visit(arg) for arg in node.args]
    keyword_catch_all = self.visit(node.kwarg)
    return positional, keyword_catch_all
2021-04-20 22:29:40 -04:00
def visit_arg(self, node):
    """Visit the children of an argument node and return its name."""
    ast.NodeVisitor.generic_visit(self, node)
    argument_name = node.arg
    return argument_name
2021-04-20 22:29:40 -04:00
2021-10-30 00:32:58 -07:00
def visit_AnnAssign(self, node):
    """Handle an annotated assignment (``target: annotation = value``).

    When the annotation is ``triton.language.constexpr`` the value is stored
    as a constexpr in the local scope and returned; re-binding such a name
    raises ``ValueError``.  Any other annotation is ignored and the
    statement is handled like a plain assignment.
    """
    annotation = self.visit(node.annotation)
    # constexpr
    if annotation == triton.language.constexpr:
        target = self.visit(node.target)
        value = self.visit(node.value)
        if target in self.lscope:
            raise ValueError(f'{target} is already defined.'
                             f' constexpr cannot be reassigned.')
        if not isinstance(value, triton.language.constexpr):
            value = triton.language.constexpr(value)
        self.lscope[target] = value
        return self.lscope[target]
    # default: call visit_Assign
    # visit_Assign reads `node.targets`, which an AnnAssign node does not
    # have, so hand it an equivalent Assign node.  Target and value are left
    # unvisited here because visit_Assign visits them itself.
    plain_assign = ast.Assign(targets=[node.target], value=node.value)
    return self.visit_Assign(plain_assign)
2021-04-20 22:29:40 -04:00
def visit_Assign(self, node):
    """Bind the value(s) of an assignment to its target name(s).

    Exactly one target is supported (it may be a tuple of names).  Each
    value is unwrapped from constexpr, converted to a tensor when it is not
    one already, and stored with ``set_value``.  Names and values are paired
    with ``zip``.
    """
    visited_targets = [self.visit(target) for target in node.targets]
    assert len(visited_targets) == 1
    names = visited_targets[0]
    values = self.visit(node.value)
    if not isinstance(names, tuple):
        names = [names]
    if not isinstance(values, tuple):
        values = [values]
    lang = triton.language
    for name, value in zip(names, values):
        # by default, constexpr are assigned into python variable
        plain = value.value if isinstance(value, lang.constexpr) else value
        if not isinstance(plain, lang.tensor):
            plain = lang.core._to_tensor(plain, self.builder)
        self.set_value(name, plain)
2021-04-20 22:29:40 -04:00
def visit_AugAssign(self, node):
    """Rewrite ``x op= value`` as ``x = x op value`` and visit that.

    Only plain-name targets are handled (``node.target.id`` is read).
    Returns the value bound to the name after the assignment.
    """
    name = node.target.id
    current = ast.Name(id=name, ctx=ast.Load())
    combined = ast.BinOp(left=current, op=node.op, right=node.value)
    self.visit(ast.Assign(targets=[node.target], value=combined))
    return self.get_value(name)
2021-04-20 22:29:40 -04:00
def visit_Name(self, node):
    """Return the identifier for a store context, else the bound value."""
    being_assigned = type(node.ctx) is ast.Store
    return node.id if being_assigned else self.get_value(node.id)
2021-04-20 22:29:40 -04:00
def visit_Store(self, node):
    """Visit the children of a ``Store`` context node; returns nothing."""
    ast.NodeVisitor.generic_visit(self, node)
    return None
def visit_Load(self, node):
    """Visit the children of a ``Load`` context node; returns nothing."""
    ast.NodeVisitor.generic_visit(self, node)
    return None
def visit_Tuple(self, node):
    """Visit each element of a tuple display and return them as a tuple."""
    return tuple(self.visit(element) for element in node.elts)
def visit_BinOp(self, node):
    """Apply a binary operator by calling the matching dunder method.

    If the left operand is a tensor its method is called with the builder;
    otherwise, if the right operand is a tensor, its reflected method
    (``__r<op>__``) is called with the builder; otherwise the left operand's
    method is called directly.  Unsupported operators raise ``KeyError``.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    dunder_by_op = {
        ast.Add: '__add__',
        ast.Sub: '__sub__',
        ast.Mult: '__mul__',
        ast.Div: '__truediv__',
        ast.FloorDiv: '__floordiv__',
        ast.Mod: '__mod__',
        ast.Pow: '__pow__',
        ast.LShift: '__lshift__',
        ast.RShift: '__rshift__',
        ast.BitAnd: '__and__',
        ast.BitOr: '__or__',
        ast.BitXor: '__xor__',
    }
    method = dunder_by_op[type(node.op)]
    if self.is_triton_tensor(left):
        return getattr(left, method)(right, _builder=self.builder)
    if self.is_triton_tensor(right):
        # '__add__' -> '__radd__'
        reflected = '__r' + method[2:]
        return getattr(right, reflected)(left, _builder=self.builder)
    return getattr(left, method)(right)
2021-04-20 22:29:40 -04:00
def visit_If(self, node):
    """Lower an ``if`` statement.

    A tensor condition is cast to ``int1`` and lowered to an if-op whose
    results carry the names defined in the branches.  Any other condition is
    evaluated at compile time and only the selected branch is visited.
    """
    cond = self.visit(node.test)
    if isinstance(cond, triton.language.tensor):
        cond = cond.to(triton.language.int1, _builder=self.builder)
        with enter_sub_region(self) as sr:
            liveins, ip_block = sr
            # `liveins` is later installed as lscope for the else branch, so
            # keep an untouched copy of the values that were live on entry.
            liveins_copy = liveins.copy()
            # The then-branch is built in a detached block and merged into
            # the if-op once the op exists.
            then_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(then_block)
            self.visit_compound_statement(node.body)
            then_defs = self.local_defs.copy()
            # We need an else block when:
            # 1. we have an orelse node
            #   or
            # 2. the then block defines a new variable
            else_defs = {}
            if then_defs or node.orelse:
                if node.orelse:
                    # lscope now aliases `liveins`: names assigned while
                    # visiting the else branch are written into it as well.
                    self.lscope = liveins
                    self.local_defs = {}
                    else_block = self.builder.create_block()
                    self.builder.set_insertion_point_to_end(else_block)
                    self.visit_compound_statement(node.orelse)
                    else_defs = self.local_defs.copy()
                else:
                    # collect else_defs: with no else branch, a name that was
                    # redefined in the then branch keeps its incoming value
                    for name in then_defs:
                        if name in liveins:
                            assert self.is_triton_tensor(then_defs[name])
                            assert self.is_triton_tensor(liveins[name])
                            else_defs[name] = liveins[name]
            # collect yields: names defined on both sides with equal types
            names = []
            ret_types = []
            for then_name in then_defs:
                for else_name in else_defs:
                    if then_name == else_name:
                        if then_defs[then_name].type == else_defs[else_name].type:
                            names.append(then_name)
                            ret_types.append(then_defs[then_name].type)
            # defined in else block but not in then block
            # to find in parent scope and yield them
            for else_name in else_defs:
                if else_name in liveins and else_name not in then_defs:
                    if else_defs[else_name].type == liveins[else_name].type:
                        names.append(else_name)
                        ret_types.append(else_defs[else_name].type)
                        # the then side yields the value from before the `if`
                        then_defs[else_name] = liveins_copy[else_name]
            self.builder.set_insertion_point_to_end(ip_block)
            if then_defs or node.orelse:  # with else block
                if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
                then_block.merge_block_before(if_op.get_then_block())
                self.builder.set_insertion_point_to_end(if_op.get_then_block())
                if len(names) > 0:
                    self.builder.create_yield_op([then_defs[n].handle for n in names])
                if not node.orelse:
                    else_block = if_op.get_else_block()
                else:
                    else_block.merge_block_before(if_op.get_else_block())
                self.builder.set_insertion_point_to_end(if_op.get_else_block())
                if len(names) > 0:
                    self.builder.create_yield_op([else_defs[n].handle for n in names])
            else:  # no else block
                if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
                then_block.merge_block_before(if_op.get_then_block())
        # update values yielded by IfOp
        # (done after leaving the sub-region, which reinstates the parent's
        # lscope / local_defs)
        for i, name in enumerate(names):
            new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
            self.lscope[name] = new_tensor
            self.local_defs[name] = new_tensor
    else:
        # compile-time condition: unwrap constexpr and visit one branch only
        if isinstance(cond, triton.language.constexpr):
            cond = cond.value
        if cond:
            self.visit_compound_statement(node.body)
        else:
            self.visit_compound_statement(node.orelse)
def visit_IfExp(self, node):
    """Evaluate a conditional expression at compile time.

    The visited condition's ``.value`` decides which operand is visited;
    the other operand is never visited.
    """
    cond = self.visit(node.test)
    chosen = node.body if cond.value else node.orelse
    return self.visit(chosen)
def visit_Pass(self, node):
    """``pass`` generates no code."""
    return None
def visit_Compare(self, node):
    """Evaluate a single two-operand comparison.

    Chained comparisons are rejected by assertion.  Both operands are
    unwrapped from constexpr first.  ``is`` / ``is not`` are decided
    immediately and returned as a constexpr; the other operators dispatch to
    the matching dunder method, with the builder passed along when a tensor
    is involved.
    """
    assert len(node.comparators) == 1
    assert len(node.ops) == 1
    constexpr = triton.language.constexpr
    lhs = self.visit(node.left)
    rhs = self.visit(node.comparators[0])
    lhs = lhs.value if isinstance(lhs, constexpr) else lhs
    rhs = rhs.value if isinstance(rhs, constexpr) else rhs
    op_kind = type(node.ops[0])
    if op_kind == ast.Is:
        return constexpr(lhs is rhs)
    if op_kind == ast.IsNot:
        return constexpr(lhs is not rhs)
    dunder_by_op = {
        ast.Eq: '__eq__',
        ast.NotEq: '__ne__',
        ast.Lt: '__lt__',
        ast.LtE: '__le__',
        ast.Gt: '__gt__',
        ast.GtE: '__ge__',
    }
    method = dunder_by_op[op_kind]
    if self.is_triton_tensor(lhs):
        return getattr(lhs, method)(rhs, _builder=self.builder)
    if self.is_triton_tensor(rhs):
        # same 'r' insertion as for binary operators: '__lt__' -> '__rlt__'
        reflected = '__r' + method[2:]
        return getattr(rhs, reflected)(lhs, _builder=self.builder)
    return getattr(lhs, method)(rhs)
2021-04-20 22:29:40 -04:00
def visit_UnaryOp(self, node):
    """Apply a unary operator to the visited operand.

    ``not`` is only accepted for constexpr operands and yields a constexpr.
    ``-``, ``+`` and ``~`` call the operand's dunder method, passing the
    builder when the operand is a tensor.
    """
    operand = self.visit(node.operand)
    op_kind = type(node.op)
    if op_kind == ast.Not:
        assert isinstance(operand, triton.language.constexpr), "`not` only supported for constexpr at the moment"
        return triton.language.constexpr(not operand)
    dunder_by_op = {
        ast.USub: '__neg__',
        ast.UAdd: '__pos__',
        ast.Invert: '__invert__',
    }
    method = dunder_by_op[op_kind]
    if self.is_triton_tensor(operand):
        return getattr(operand, method)(_builder=self.builder)
    return getattr(operand, method)()
2021-04-20 22:29:40 -04:00
def visit_While(self, node):
    """Lower a ``while`` loop to a while-op.

    The condition and the body are first generated into detached blocks.
    Names assigned in the body that were already live before the loop, with
    an unchanged type, become loop-carried values.  The detached blocks are
    then merged into the op's before/after regions.  ``while ... else`` is
    not implemented.
    """
    with enter_sub_region(self) as sr:
        liveins, insert_block = sr

        # condition (the before region)
        cond_block = self.builder.create_block()
        self.builder.set_insertion_point_to_start(cond_block)
        cond = self.visit(node.test)
        # loop body (the after region)
        loop_block = self.builder.create_block()
        self.builder.set_insertion_point_to_start(loop_block)
        self.visit_compound_statement(node.body)
        loop_defs = self.local_defs
        # collect loop-carried values
        names = []
        ret_types = []
        init_args = []
        yields = []
        for name in loop_defs:
            if name in liveins:
                # We should not def new constexpr
                assert self.is_triton_tensor(loop_defs[name])
                assert self.is_triton_tensor(liveins[name])
                if loop_defs[name].type == liveins[name].type:
                    # these are loop-carried values
                    names.append(name)
                    ret_types.append(loop_defs[name].type)
                    init_args.append(liveins[name])
                    yields.append(loop_defs[name])
        self.builder.set_insertion_point_to_end(insert_block)
        while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
                                                [arg.handle for arg in init_args])
        # merge the condition region
        before_block = self.builder.create_block_with_parent(while_op.get_before(),
                                                             [ty.to_ir(self.builder) for ty in ret_types])
        cond_block.merge_block_before(before_block)
        self.builder.set_insertion_point_to_end(before_block)
        # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
        self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
        # merge the loop body
        after_block = self.builder.create_block_with_parent(while_op.get_after(),
                                                            [ty.to_ir(self.builder) for ty in ret_types])
        loop_block.merge_block_before(after_block)
        self.builder.set_insertion_point_to_end(after_block)
        self.builder.create_yield_op([y.handle for y in yields])
        # update global uses in while_op: code in both regions was generated
        # against the values from before the loop, so redirect those uses to
        # the corresponding block arguments
        for i, name in enumerate(names):
            before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
            after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
    # WhileOp defines new values, update the symbol table (lscope, local_defs)
    # (done after leaving the sub-region, which reinstates the parent's tables)
    for i, name in enumerate(names):
        new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
        self.lscope[name] = new_def
        self.local_defs[name] = new_def

    for stmt in node.orelse:
        # any `else` clause is rejected; the visit below only runs when
        # assert statements are disabled
        assert False, "Not implemented"
        ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript(self, node):
    """Evaluate ``value[slice]`` in a load context.

    Tensors are indexed through ``__getitem__`` with the builder supplied;
    anything else uses ordinary Python indexing.  Store/delete contexts are
    rejected by assertion.
    """
    assert node.ctx.__class__.__name__ == "Load"
    container = self.visit(node.value)
    index = self.visit(node.slice)
    if not self.is_triton_tensor(container):
        return container[index]
    return container.__getitem__(index, _builder=self.builder)
def visit_ExtSlice(self, node):
    """Visit every dimension of an extended slice; returns them as a list."""
    return list(map(self.visit, node.dims))
def visit_For(self, node):
    """Lower ``for <name> in range(...)``.

    Only the builtin ``range`` iterator is accepted.  When all bounds are
    constexpr, ``TRITON_STATIC_LOOP_UNROLLING`` is set and the range has at
    most 10 elements, the body is unrolled at compile time.  Otherwise a
    for-op is emitted; names assigned in the body that were already live
    before the loop become loop-carried values and are cast back to their
    incoming dtype if the body changed their type.

    Raises:
        RuntimeError: if the iterator is not the builtin ``range``.
    """
    iterator = self.visit(node.iter.func)
    if iterator != self.builtins['range']:
        raise RuntimeError('Only `range` iterator currently supported')
    # visit iterator arguments
    # note: only `range` iterator is supported now
    iter_args = [self.visit(arg) for arg in node.iter.args]
    # collect lower bound (lb), upper bound (ub), and step
    # ast.Constant replaces the deprecated ast.Num for the implicit 0 and 1.
    # The single-argument form reuses the already visited bound, so the
    # bound expression is not generated a second time.
    lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(value=0))
    ub = iter_args[1] if len(iter_args) > 1 else iter_args[0]
    step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(value=1))
    # static for loops: all iterator arguments are constexpr
    if isinstance(lb, triton.language.constexpr) and \
       isinstance(ub, triton.language.constexpr) and \
       isinstance(step, triton.language.constexpr):
        sta_range = iterator(lb.value, ub.value, step.value)
        # NOTE(review): any non-empty value, including "0", enables unrolling
        static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
        if static_unrolling and len(sta_range) <= 10:
            for i in sta_range:
                self.lscope[node.target.id] = triton.language.constexpr(i)
                self.visit_compound_statement(node.body)
                for stmt in node.orelse:
                    ast.NodeVisitor.generic_visit(self, stmt)
            return
    # handle negative constant step (not supported by scf.for in MLIR)
    negative_step = False
    if isinstance(step, triton.language.constexpr) and step.value < 0:
        step = triton.language.constexpr(-step.value)
        negative_step = True
        lb, ub = ub, lb
    # lb/ub/step might be constexpr, we need to cast them to tensor
    lb = triton.language.core._to_tensor(lb, self.builder).handle
    ub = triton.language.core._to_tensor(ub, self.builder).handle
    step = triton.language.core._to_tensor(step, self.builder).handle
    # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
    lb = self.builder.create_to_index(lb)
    ub = self.builder.create_to_index(ub)
    step = self.builder.create_to_index(step)
    # Create placeholder for the loop induction variable
    iv = self.builder.create_undef(self.builder.get_int32_ty())
    self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        # create loop body block
        block = self.builder.create_block()
        self.builder.set_insertion_point_to_start(block)
        # visit loop body
        self.visit_compound_statement(node.body)
        # If a variable (name) is defined in both its parent & itself, then it's
        # a loop-carried variable. (They must be of the same type)
        init_args = []
        yields = []
        names = []
        for name in self.local_defs:
            if name in liveins:
                assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
                assert self.is_triton_tensor(liveins[name])
                if self.local_defs[name].type != liveins[name].type:
                    # cast the yielded value back to the incoming dtype
                    local_value = self.local_defs[name]
                    self.local_defs[name] = local_value.to(liveins[name].dtype, _builder=self.builder)
                names.append(name)
                init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
                yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
        # create ForOp
        self.builder.set_insertion_point_to_end(insert_block)
        for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
        block.merge_block_before(for_op.get_body(0))
        # update induction variable with actual value, and replace all uses
        self.builder.set_insertion_point_to_start(for_op.get_body(0))
        iv = self.builder.create_index_to_si(for_op.get_induction_var())
        if negative_step:
            # bounds were swapped above, so count down from the upper bound
            ub_si = self.builder.create_index_to_si(ub)
            iv = self.builder.create_sub(ub_si, iv)
        self.lscope[node.target.id].handle.replace_all_uses_with(iv)
        self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
        # create YieldOp
        self.builder.set_insertion_point_to_end(for_op.get_body(0))
        if len(yields) > 0:
            self.builder.create_yield_op([y.handle for y in yields])
        for_op_region = for_op.get_body(0).get_parent()
        assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
        # replace global uses with block arguments
        for i, name in enumerate(names):
            # arg0 is the induction variable
            for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))
    # update lscope & local_defs (ForOp defines new values)
    # (done after leaving the sub-region, which reinstates the parent's tables)
    for i, name in enumerate(names):
        self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type))

    for stmt in node.orelse:
        # any `else` clause is rejected; the visit below only runs when
        # assert statements are disabled
        assert False, "Don't know what to do with else after for"
        ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
    """Build a ``slice`` from the visited lower, upper and step parts."""
    parts = [self.visit(part) for part in (node.lower, node.upper, node.step)]
    return slice(*parts)
def visit_Index(self, node):
    """An ``Index`` wrapper evaluates to the expression it wraps."""
    wrapped = node.value
    return self.visit(wrapped)
def visit_keyword(self, node):
    """Return a one-entry dict mapping the keyword name to its visited value."""
    visited = self.visit(node.value)
    return {node.arg: visited}
def visit_Call(self, node):
    """Evaluate a call expression.

    Three kinds of callee are distinguished:

    * a ``triton.runtime.JITFunction``: a specialised callee is generated
      into the current module if it is not there yet, and a call op is
      emitted;
    * a method bound to a tensor, or anything ``impl.is_builtin`` accepts:
      called with the builder passed as ``_builder``;
    * anything else: called directly; for entries of ``self.builtins`` the
      constexpr arguments are unwrapped first.
    """
    fn = self.visit(node.func)
    if isinstance(fn, triton.language.constexpr):
        fn = fn.value
    kws = dict()
    for keyword in node.keywords:
        kws.update(self.visit(keyword))
    args = [self.visit(arg) for arg in node.args]
    if isinstance(fn, triton.runtime.JITFunction):
        from inspect import getcallargs
        # bind positional/keyword arguments to the callee's parameter names,
        # then order them like the callee's signature
        args = getcallargs(fn.fn, *args, **kws)
        args = [args[name] for name in fn.arg_names]
        # everything that is not a tensor is passed as a compile-time constant
        args = [arg if isinstance(arg, triton.language.tensor)
                else triton.language.constexpr(arg) for arg in args]
        # generate function def
        attributes = dict()
        constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
        constants = {i: args[i] for i in constexprs}
        # generate call
        args = [None if i in constexprs else arg for i, arg in enumerate(args)]
        arg_vals = [arg.handle for arg in args if arg is not None]
        arg_types = [arg.type for arg in args if arg is not None]
        # one callee per (name, argument types, constants) combination
        fn_name = mangle_fn(fn.__name__, arg_types, constants)
        # generate function def if necessary
        if not self.module.has_function(fn_name):
            prototype = triton.language.function_type([], arg_types)
            gscope = sys.modules[fn.fn.__module__].__dict__
            generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
            generator.visit(fn.parse())
            # remember the callee's return type(s) for later call sites
            callee_ret_type = generator.last_ret_type
            self.function_ret_types[fn_name] = callee_ret_type
        else:
            callee_ret_type = self.function_ret_types[fn_name]
        symbol = self.module.get_function(fn_name)
        call_op = self.builder.call(symbol, arg_vals)
        if call_op.get_num_results() == 0 or callee_ret_type is None:
            return None
        elif call_op.get_num_results() == 1:
            return triton.language.tensor(call_op.get_result(0), callee_ret_type)
        else:
            # should return a tuple of tl.tensor
            results = []
            for i in range(call_op.get_num_results()):
                results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
            return tuple(results)
    if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
            or impl.is_builtin(fn):
        return fn(*args, _builder=self.builder, **kws)
    if fn in self.builtins.values():
        # plain Python builtins receive the raw values, not constexpr wrappers
        args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
                for arg in args]
    return fn(*args, **kws)
2021-04-20 22:29:40 -04:00
2022-01-30 20:21:20 -08:00
def visit_Constant(self, node):
    """Wrap the literal carried by an ``ast.Constant`` node in a ``constexpr``."""
    literal = node.value
    return triton.language.constexpr(literal)
2022-12-21 01:30:50 -08:00
def visit_BoolOp(self, node: ast.BoolOp):
    """Lower a two-operand ``and`` / ``or`` expression.

    Both operands are visited, ``constexpr`` wrappers are unwrapped to their
    Python value, and the operation is dispatched as follows:

    * if the left operand is a triton tensor, call its ``logical_and`` /
      ``logical_or`` method with the right operand;
    * else if the right operand is a triton tensor, call the same method on
      it with the left operand;
    * otherwise both operands are plain Python values and the native
      ``and`` / ``or`` operator is applied.

    :param node: the ``ast.BoolOp`` node; must have exactly two operands.
    :raises AssertionError: if the expression chains more than two operands.
    :raises KeyError: if the operator is neither ``ast.And`` nor ``ast.Or``.
    """
    assert len(node.values) == 2
    lhs = self.visit(node.values[0])
    rhs = self.visit(node.values[1])
    if isinstance(lhs, triton.language.constexpr):
        lhs = lhs.value
    if isinstance(rhs, triton.language.constexpr):
        rhs = rhs.value
    is_and = {ast.And: True, ast.Or: False}[type(node.op)]
    fn = 'logical_and' if is_and else 'logical_or'
    if self.is_triton_tensor(lhs):
        return getattr(lhs, fn)(rhs, _builder=self.builder)
    if self.is_triton_tensor(rhs):
        # The previous code derived a "reflected" name with
        # fn[:2] + 'r' + fn[2:], which only works for dunder names
        # ('__and__' -> '__rand__') and produced 'lorgical_and' here.
        # There is no reflected variant of logical_and/logical_or, so swap the
        # operands instead. This relies on the tensor method accepting a plain
        # Python operand, exactly as the branch above does when rhs is a
        # Python value.
        return getattr(rhs, fn)(lhs, _builder=self.builder)
    # Plain Python values have no logical_and/logical_or attribute.
    return (lhs and rhs) if is_and else (lhs or rhs)
2022-01-30 20:21:20 -08:00
if sys.version_info < (3, 8):
    # These visitors are only defined on interpreters older than 3.8. Each one
    # wraps the literal held by the node in a constexpr.
    def visit_NameConstant(self, node):
        return triton.language.constexpr(node.value)

    def visit_Num(self, node):
        return triton.language.constexpr(node.n)

    def visit_Str(self, node):
        # literal_eval turns the ast.Str node back into the Python string
        return triton.language.constexpr(ast.literal_eval(node))
2021-04-20 22:29:40 -04:00
def visit_Attribute(self, node):
    """Evaluate an attribute access ``<value>.<attr>``.

    ``.T`` on a triton tensor is lowered to a transpose through
    ``triton.language.semantic.trans``; every other access is a plain
    ``getattr`` on the visited value.
    """
    base = self.visit(node.value)
    wants_transpose = isinstance(base, triton.language.tensor) and node.attr == "T"
    if wants_transpose:
        return triton.language.semantic.trans(base, builder=self.builder)
    return getattr(base, node.attr)
def visit_Expr(self, node):
    """Visit the children of an expression statement; the value is discarded."""
    # Call the base-class traversal explicitly: this class overrides
    # generic_visit to reject unsupported nodes.
    ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
    """Return ``None`` for a ``None`` node (``visit`` dispatches on the type name)."""
    return None
2021-04-20 22:29:40 -04:00
def visit(self, node):
    """Dispatch ``node`` to its ``visit_*`` method and return the result.

    Every non-``None`` node is recorded in ``self.last_node`` before
    dispatching. Deprecation warnings raised during dispatch are silenced.
    """
    if node is not None:
        self.last_node = node
    with warnings.catch_warnings():
        # The ast library added visit_Constant and deprecated some other
        # methods but we can't move to that without breaking Python 3.6 and 3.7.
        warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
        warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
        return super().visit(node)
2021-04-20 22:29:40 -04:00
def generic_visit(self, node):
    """Reject any AST node type that has no dedicated ``visit_*`` method.

    :raises NotImplementedError: always, naming the node's type.
    """
    raise NotImplementedError(f"Unsupported node: {type(node).__name__}")
class CompilationError(Exception):
    """Error raised while compiling a kernel, pointing at the offending source.

    The message consists of an ``at <line>:<col>:`` header, the source text up
    to and including the failing line, and a caret placed under the failing
    column.

    :param src: full source text of the function being compiled.
    :param node: object exposing ``lineno`` and ``col_offset`` (an AST node).
    """

    def __init__(self, src, node):
        header = f'at {node.lineno}:{node.col_offset}:\n'
        shown_source = '\n'.join(src.split('\n')[:node.lineno])
        caret_line = ' ' * node.col_offset + '^'
        self.message = header + shown_source + '\n' + caret_line
        self.src = src
        self.node = node
        super().__init__(self.message)

    def __reduce__(self):
        # this is necessary to make CompilationError picklable
        constructor_args = (self.src, self.node)
        return (type(self), constructor_args)
2021-04-20 22:29:40 -04:00
2021-08-21 06:00:54 +02:00
2021-06-21 14:25:13 +08:00
class OutOfResources(Exception):
    """Error raised when a kernel needs more of a resource than is available.

    :param required: amount of the resource the kernel needs.
    :param limit: amount the hardware provides.
    :param name: human-readable name of the resource.
    """

    def __init__(self, required, limit, name):
        summary = f'out of resource: {name}, Required: {required}, Hardware limit: {limit}'
        self.message = summary + '. Reducing block sizes or `num_stages` may help.'
        self.required = required
        self.limit = limit
        self.name = name
        super().__init__(self.message)

    def __reduce__(self):
        # this is necessary to make OutOfResources picklable
        constructor_args = (self.required, self.limit, self.name)
        return (type(self), constructor_args)
2021-06-21 14:25:13 +08:00
2021-04-20 22:29:40 -04:00
2022-09-18 08:51:48 -07:00
def kernel_suffix(signature, specialization):
    """Build the specialization suffix appended to a kernel's name.

    For each position ``i`` in ``signature`` the suffix contains ``i``,
    followed by ``c`` if ``i`` is in ``specialization.equal_to_1`` and ``d`` if
    ``i`` is in ``specialization.divisible_by_16``.

    :param signature: iterable of argument types; only its length matters.
    :param specialization: object with ``equal_to_1`` and ``divisible_by_16``
        containers of argument positions.
    :return: the suffix string, e.g. ``"0c1d2"``.
    """
    pieces = []
    for position, _ in enumerate(signature):
        pieces.append(str(position))
        if position in specialization.equal_to_1:
            pieces.append('c')
        if position in specialization.divisible_by_16:
            pieces.append('d')
    return ''.join(pieces)
2022-12-21 01:30:50 -08:00
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
2022-09-18 08:51:48 -07:00
2022-12-21 01:30:50 -08:00
def build_triton_ir(fn, signature, specialization, constants):
    """Generate the Triton IR module for ``fn`` by walking its AST.

    :param fn: the function to compile; the code reads ``arg_names``,
        ``__globals__``, ``__name__``, ``src`` and calls ``parse()`` on it.
    :param signature: either a comma-separated string of argument types or a
        mapping from argument index to type string.
    :param specialization: object with ``equal_to_1`` and ``divisible_by_16``
        collections of argument indices.
    :param constants: mapping from argument name or index to a constant value.
    :return: ``(module, generator)``; ``module.context`` is set to the context
        created here.
    :raises CompilationError: wrapping any exception raised during code
        generation, except ``NotImplementedError`` and ``CompilationError``
        which propagate unchanged.
    """
    # canonicalize signature
    if isinstance(signature, str):
        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
    context = _triton.ir.context()
    context.load_triton()
    # create kernel prototype
    # constants may be keyed by argument name; normalise every key to an index
    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
    constants = {cst_key(key): value for key, value in constants.items()}
    # visit kernel AST
    gscope = fn.__globals__.copy()
    function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
    # Arguments specialised as "equal to 1" are folded into constants. An i1
    # argument must fold to True rather than the integer 1. The previous test
    # (`k in tys` with tys a list of type strings) compared an index against
    # type names and was therefore never true; look the type up by key instead.
    new_constants = {k: True if signature.get(k) == "i1" else 1 for k in specialization.equal_to_1}
    new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
    all_constants = constants.copy()
    all_constants.update(new_constants)
    # only non-constant arguments appear in the generated prototype
    arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
    prototype = triton.language.function_type([], arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
                              function_name=function_name, attributes=new_attrs, is_kernel=True)
    try:
        generator.visit(fn.parse())
    except Exception as e:
        node = generator.last_node
        if node is None or isinstance(e, (NotImplementedError, CompilationError)):
            raise
        # attach the source location of the last visited node
        raise CompilationError(fn.src, node) from e
    ret = generator.module
    # module takes ownership of the context
    ret.context = context
    return ret, generator
2022-12-21 01:30:50 -08:00
def optimize_triton_ir(mod):
    """Run the Triton-IR level pass pipeline on ``mod`` and return ``mod``.

    The passes are registered in the order below (inliner, triton combine,
    canonicalizer, CSE, LICM) and executed by ``pm.run``.
    """
    pm = _triton.ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_inliner_pass()
    pm.add_triton_combine_pass()
    pm.add_canonicalizer_pass()
    pm.add_cse_pass()
    pm.add_licm_pass()
    pm.run(mod)
    return mod
def ast_to_ttir(fn, signature, specialization, constants):
    """Build the Triton IR for ``fn`` and return it after optimization.

    The arguments are forwarded unchanged to ``build_triton_ir``; the code
    generator it also returns is discarded.
    """
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
    return optimize_triton_ir(mod)
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
    """Convert a Triton IR module to TritonGPU IR and optimize it.

    The passes are registered in the order below and executed by ``pm.run``;
    the same ``mod`` object is returned.

    :param mod: module to convert; its ``context`` attribute is used to create
        the pass manager.
    :param num_warps: forwarded to the triton-to-tritongpu conversion pass.
    :param num_stages: forwarded to the pipeline pass.
    :param compute_capability: forwarded to each tritongpu combine pass.
    """
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_convert_triton_to_tritongpu_pass(num_warps)
    pm.enable_debug()
    pm.add_coalesce_pass()
    # The combine pass converts blocked layout to mma layout
    # for dot ops so that pipeline can get shared memory swizzled correctly.
    pm.add_tritongpu_combine_pass(compute_capability)
    pm.add_tritongpu_pipeline_pass(num_stages)
    # Prefetch must be done after pipeline pass because pipeline pass
    # extracts slices from the original tensor.
    pm.add_tritongpu_prefetch_pass()
    pm.add_canonicalizer_pass()
    pm.add_cse_pass()
    pm.add_tritongpu_combine_pass(compute_capability)
    pm.add_licm_pass()
    pm.add_tritongpu_combine_pass(compute_capability)
    pm.add_cse_pass()
    # pm.add_tritongpu_optimize_load_convert_pass()
    pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
    pm.add_cse_pass()
    pm.add_symbol_dce_pass()
    pm.add_tritongpu_sink_conversions_from_shared_pass()
    pm.run(mod)
    return mod
def add_external_libs(mod, libs):
    """Register external libraries on ``mod``.

    Nothing is registered if any entry of ``libs`` has an empty name or an
    empty path.

    :param mod: module passed through to ``_triton.add_external_libs``.
    :param libs: mapping from library name to library path.
    """
    has_blank_entry = any(len(name) == 0 or len(path) == 0 for name, path in libs.items())
    if has_blank_entry:
        return
    names = list(libs.keys())
    paths = list(libs.values())
    _triton.add_external_libs(mod, names, paths)
def ttgir_to_llir(mod, extern_libs, compute_capability):
    """Translate a TritonGPU module to LLVM IR.

    :param mod: the module to translate.
    :param extern_libs: mapping of external library names to paths; when
        non-empty it is registered on ``mod`` first via ``add_external_libs``.
    :param compute_capability: forwarded to the translation routine.
    :return: whatever ``_triton.translate_triton_gpu_to_llvmir`` returns.
    """
    if extern_libs:
        add_external_libs(mod, extern_libs)
    return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)
def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]:
    '''
    Translate an LLVM IR module to PTX code.
    :param mod: the module produced by the previous (LLVM IR) stage
    :param compute_capability: forwarded to the translation routine
    :param ptx_version: PTX version to emit; when None it is derived from the
        CUDA version reported by the ptxas found by ``path_to_ptxas``
    :return: the result of ``_triton.translate_llvmir_to_ptx``
        NOTE(review): the return annotation says (PTX code, shared memory
        size), but callers in this file treat the result as the PTX text
        itself -- confirm against the binding.
    '''
    if ptx_version is None:
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
2022-09-18 08:51:48 -07:00
2022-12-21 01:30:50 -08:00
def ptx_to_cubin(ptx: str, compute_capability: int):
    '''
    Compile PTX code to a cubin using the ptxas found by ``path_to_ptxas``.
    :param ptx: ptx code
    :param compute_capability: compute capability
    :return: the result of ``_triton.compile_ptx_to_cubin``
        NOTE(review): presumably the raw cubin bytes, since the cache reads
        this stage back with ``read_bytes`` -- confirm against the binding.
    '''
    ptxas, _ = path_to_ptxas()
    return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
2022-09-18 08:51:48 -07:00
def ptx_get_kernel_name(ptx: str) -> str:
    '''
    Extract the kernel name from PTX code.
    This name is required when launching the kernel: PTX codegen mangles
    names, so the original kernel names in Triton IR are not available in
    PTX/cubin.

    The name is the last whitespace-separated token of the first line that
    starts with ``// .globl``; ``None`` is returned if there is no such line.
    '''
    assert ptx
    stripped_lines = (raw_line.strip() for raw_line in ptx.split('\n'))
    for candidate in stripped_lines:
        if candidate.startswith('// .globl'):
            return candidate.split()[-1]
    return None
2022-12-21 01:30:50 -08:00
# Called with parentheses: the bare ``@functools.lru_cache`` form is only
# accepted from Python 3.8 on and raises TypeError at import time on 3.7.
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the given CUDA version.

    :param cuda_version: CUDA release as a ``"<major>.<minor>"`` string
    :return: the PTX version number (e.g. 74 for CUDA 11.4 and newer)
    :raises RuntimeError: if the CUDA version is older than 10.0
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    version = major * 1000 + minor * 10
    # (minimum encoded CUDA version, PTX version), newest first
    ptx_by_min_cuda = (
        (11040, 74),
        (11030, 73),
        (11020, 72),
        (11010, 71),
        (11000, 70),
        (10020, 65),
        (10010, 64),
        (10000, 63),
    )
    for min_cuda, ptx_version in ptx_by_min_cuda:
        if version >= min_cuda:
            return ptx_version
    raise RuntimeError("Triton only support CUDA 10.0 or higher")
def path_to_ptxas():
    """Locate a working ``ptxas`` binary and report its CUDA release.

    Candidates are ``<prefix>/bin/ptxas`` for, in order: the
    ``TRITON_PTXAS_PATH`` environment variable, the current directory, ``/usr``
    and ``CUDA_PATH`` (defaulting to ``default_cuda_dir()``). The first
    candidate whose ``--version`` output contains ``release <major>.<minor>``
    wins.

    :return: ``(path_to_ptxas, cuda_version)`` where ``cuda_version`` is the
        ``"<major>.<minor>"`` string parsed from the version banner.
    :raises RuntimeError: if no candidate yields a version.
    """
    prefixes = [
        os.environ.get("TRITON_PTXAS_PATH", ""),
        "",
        "/usr",
        os.environ.get('CUDA_PATH', default_cuda_dir()),
    ]
    tried = set()
    for prefix in prefixes:
        ptxas = os.path.join(prefix, "bin", "ptxas")
        # the same candidate can appear twice (e.g. TRITON_PTXAS_PATH unset)
        if ptxas in tried:
            continue
        tried.add(ptxas)
        if not os.path.exists(ptxas):
            continue
        try:
            result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
        except (OSError, subprocess.CalledProcessError):
            # a non-executable or broken candidate must not abort the search
            continue
        version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
        if version is not None:
            return ptxas, version.group(1)
    raise RuntimeError("Cannot find ptxas")
instance_descriptor = namedtuple ( " instance_descriptor " , [ " divisible_by_16 " , " equal_to_1 " ] , defaults = [ set ( ) , set ( ) ] )
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
2022-09-18 08:51:48 -07:00
def ty_to_cpp(ty):
    """Map a Triton type string to the C type used in the launcher signature.

    Any pointer type (leading ``*``) maps to ``CUdeviceptr``.

    :raises KeyError: for a non-pointer type that is not in the table.
    """
    if ty[0] == '*':
        return "CUdeviceptr"
    scalar_c_types = {
        "i1": "int32_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",
        "bf16": "float",
        "fp32": "float",
        "f32": "float",
        "fp64": "double",
    }
    return scalar_c_types[ty]
def generate_name_initializer(signature):
    # NOTE(review): this looks like an unfinished stub. The loop body is a bare
    # expression with no effect and the function returns None; `src` is never
    # extended or returned.
    src = "int i = 0;\n"
    tys = signature.split(',')
    for i, ty in enumerate(tys):
        src
2022-09-18 14:26:29 -07:00
def binary_name_to_header_name(name):
    """Return the header file name (``<name>.h``) for a binary name.

    Names longer than 128 characters are replaced by ``kernel_`` plus their
    SHA-256 hex digest.
    """
    too_long = len(name) > 128
    if too_long:
        # avoid filename too long errors (filename limit is 255)
        digest = hashlib.sha256(name.encode("utf-8")).hexdigest()
        name = "kernel_" + digest
    return f"{name}.h"
2022-12-21 01:30:50 -08:00
def generate_launcher(constants, signature):
    """Generate the C source of the CPython extension that launches a kernel.

    The generated module is named ``launcher`` and exposes a single function
    ``launch`` which parses the launch parameters and kernel arguments, calls
    the optional enter/exit hooks and launches the kernel with
    ``cuLaunchKernel``.

    :param constants: container of argument indices that are compile-time
        constants; they are parsed from Python but not passed to the kernel.
    :param signature: mapping from argument index to Triton type string.
    :return: the C source code as a string.
    :raises KeyError: if a non-pointer type has no C equivalent in the tables.
    """
    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

    def _extracted_type(ty):
        # C type of the local variable PyArg_ParseTuple writes the argument to
        if ty[0] == '*':
            return "PyObject*"
        return {
            'i1': 'int32_t',
            'i32': 'int32_t',
            'i64': 'int64_t',
            'u32': 'uint32_t',
            'u64': 'uint64_t',
            'fp16': 'float',
            'bf16': 'float',
            'fp32': 'float',
            'f32': 'float',
            'fp64': 'double',
        }[ty]

    def format_of(ty):
        # PyArg_ParseTuple format unit for a C type
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "uint32_t": "I",
            "int32_t": "i",
            "uint64_t": "K",
            "int64_t": "L",
        }[ty]

    # Fixed prefix: gridX, gridY, gridZ, num_warps, shared_memory (i), stream
    # and function handles (K), enter hook, exit hook, compiled kernel (O).
    arg_format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
    # addresses of the non-constant arguments, handed to cuLaunchKernel
    kernel_params = ', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)
    local_decls = ' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])
    parse_targets = ', '.join(f"&_arg{i}" for i, ty in signature.items())
    launch_args = ', '.join(f"getPointer(_arg{i},{i})" if ty[0] == "*" else f"_arg{i}" for i, ty in signature.items())

    # generate glue code
    src = f"""
#include \"cuda.h\"
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
   if (code != CUDA_SUCCESS)
   {{
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str;
      cuGetErrorString(code, &str);
      char err[1024] = {{0}};
      strcat(err, prefix);
      strcat(err, str);
      PyErr_SetString(PyExc_RuntimeError, err);
   }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
  void *params[] = {{ {kernel_params} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}

static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
  }}
  if (obj == Py_None) {{
    return (CUdeviceptr)0;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (ret == NULL) {{
      // data_ptr() raised: the Python error is already set
      return (CUdeviceptr)0;
    }}
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      Py_DECREF(ret);
      return (CUdeviceptr)0;
    }}
    CUdeviceptr dev_ptr = (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
    // PyObject_Call returned a new reference; release it
    Py_DECREF(ret);
    return dev_ptr;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (CUdeviceptr)0;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int num_warps;
  int shared_memory;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *compiled_kernel = NULL;
  PyObject *hook_ret = NULL;
  {local_decls}
  if(!PyArg_ParseTuple(args, \"{arg_format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {parse_targets})) {{
    return NULL;
  }}

  if (launch_enter_hook != Py_None) {{
    PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
    hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
    Py_DECREF(new_args);
  }}

  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {launch_args});

  if (launch_exit_hook != Py_None) {{
    PyObject *new_args = NULL;
    if (hook_ret) {{
        new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
    }} else {{
        new_args = PyTuple_Pack(1, compiled_kernel);
    }}
    hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
    Py_DECREF(new_args);
  }}

  if (hook_ret) {{
      Py_DECREF(hook_ret);
  }}
  if(PyErr_Occurred()) {{
    return NULL;
  }}
  // return None
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""

    return src
2021-11-12 00:55:00 -08:00
2022-01-06 14:34:17 -08:00
2022-05-27 16:51:05 -04:00
def default_cache_dir():
    """Return the default cache directory, ``<home>/.triton/cache``.

    The home directory is resolved with ``os.path.expanduser`` instead of
    reading ``os.environ["HOME"]`` directly, so the call no longer raises
    ``KeyError`` when ``HOME`` is unset (e.g. on Windows or in a stripped
    environment). When ``HOME`` is set the result is unchanged.
    """
    return os.path.join(os.path.expanduser("~"), ".triton", "cache")
2022-05-27 16:51:05 -04:00
2022-12-21 01:30:50 -08:00
def default_cuda_dir():
    """Return the CUDA install directory: ``$CUDA_HOME`` or ``/usr/local/cuda``."""
    return os.environ.get("CUDA_HOME", "/usr/local/cuda")
2022-09-18 08:51:48 -07:00
class CacheManager:
    """On-disk cache of compilation artifacts stored under one key.

    Files live in ``<cache root>/<key>/`` where the cache root is
    ``$TRITON_CACHE_DIR`` or, if unset, ``default_cache_dir()``. An empty cache
    root disables the cache: ``has_file`` returns False and ``put`` does
    nothing.
    """

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # create cache directory if it doesn't exist
        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename):
        """Return the path of ``filename`` inside this key's cache directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename):
        """Return True if ``filename`` exists in the cache."""
        if not self.cache_dir:
            return False
        return os.path.exists(self._make_path(filename))

    def put(self, data, filename, binary=True):
        """Write ``data`` to ``filename`` in the cache.

        ``bytes`` are written in binary mode; anything else is converted with
        ``str`` and written in text mode. The ``binary`` argument is accepted
        for compatibility but ignored: the mode is always derived from the
        type of ``data``.
        """
        if not self.cache_dir:
            return
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        with FileLock(self.lock_path):
            # use tempfile to be robust against program interruptions
            mode = "wb" if binary else "w"
            with open(filepath + ".tmp", mode) as f:
                f.write(data)
            # os.replace overwrites an existing entry on every platform,
            # whereas os.rename fails on Windows if the target exists
            os.replace(filepath + ".tmp", filepath)
2022-09-18 08:51:48 -07:00
2022-12-21 01:30:50 -08:00
# Utilities for generating and compiling C wrappers
2022-09-18 08:51:48 -07:00
2022-09-19 21:01:36 -07:00
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories in which ``whereis`` finds ``libcuda.so``.

    The result is cached after the first call.
    """
    whereis_output = subprocess.check_output(["whereis", "libcuda.so"]).decode()
    # the first token of the output is skipped; the rest are file locations
    locations = whereis_output.strip().split()[1:]
    return [os.path.dirname(location) for location in locations]
2022-09-19 21:01:36 -07:00
@contextlib.contextmanager
def quiet():
    """Context manager that discards everything written to stdout and stderr.

    Both streams are replaced by in-memory buffers for the duration of the
    block and restored afterwards, even if the block raises.
    """
    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
        yield
def _build(name, src, srcdir):
    """Compile the C source file ``src`` into a Python extension module.

    The compiler is taken from ``$CC``, else ``gcc``, else ``clang``, and is
    invoked directly. If no compiler can be found, the build falls back on
    setuptools. Previously a missing compiler put ``None`` in the command line
    and failed inside ``subprocess``, and the setuptools fallback was
    unreachable because ``check_call`` raises instead of returning a non-zero
    status.

    :param name: extension module name (without suffix).
    :param src: path of the C source file.
    :param srcdir: directory receiving the build products.
    :return: path of the shared object, ``<srcdir>/<name><EXT_SUFFIX>``.
    :raises subprocess.CalledProcessError: if the compiler invocation fails.
    """
    cuda_lib_dirs = libcuda_dirs()
    cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
    cu_include_dir = os.path.join(cuda_path, "include")
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        clang = shutil.which("clang")
        gcc = shutil.which("gcc")
        cc = gcc if gcc is not None else clang
    if cc is not None:
        py_include_dir = get_paths()["include"]
        cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
                  "-shared", "-fPIC", "-lcuda", "-o", so]
        cc_cmd += [f"-L{lib_dir}" for lib_dir in cuda_lib_dirs]
        # raises CalledProcessError on a non-zero exit status
        subprocess.check_call(cc_cmd)
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module
    script_args = ['build_ext', '--build-temp=' + srcdir, '--build-lib=' + srcdir, '-q']
    setup_kwargs = dict(
        name=name,
        ext_modules=[ext],
        script_args=script_args,
    )
    with quiet():
        setuptools.setup(**setup_kwargs)
    return so
2022-09-18 08:51:48 -07:00
2022-10-11 13:24:30 -07:00
def make_so_cache_key(version_hash, signature, constants):
    """Return the cache key (MD5 hex digest) of a launcher stub.

    Every pointer type is collapsed to ``ptr`` so that signatures differing
    only in pointee type share one stub.
    """
    collapsed_types = ['ptr' if ty[0] == '*' else ty for ty in signature.values()]
    raw_key = f"{version_hash}-{''.join(collapsed_types)}{constants}"
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
    """Return the cache key (MD5 hex digest) of a compiled kernel.

    The key combines the function hash, the concatenated signature types, the
    sorted specialization sets of every config, the constants and the launch
    parameters.
    """
    configs_key = []
    for conf in configs:
        configs_key.append((sorted(conf.divisible_by_16), sorted(conf.equal_to_1)))
    type_string = ''.join(signature.values())
    raw_key = f"{fn_hash}-{type_string}-{configs_key}-{constants}-{num_warps}-{num_stages}"
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()
2022-12-21 01:30:50 -08:00
def read_or_execute(cache_manager, force_compile, file_name, metadata,
                    run_if_found: Callable[[str], bytes] = None,
                    run_if_not_found: Callable = None):
    """Load an artifact from the cache, or produce it and store it.

    :param cache_manager: object providing ``has_file``, ``_make_path`` and
        ``put``.
    :param force_compile: when true the cache is bypassed.
    :param file_name: ``<name>.<suffix>`` of the artifact; the suffix selects
        the entry of ``metadata["md5"]`` to compare against.
    :param metadata: previously stored metadata, or a falsy value if none.
    :param run_if_found: called with the cached file path to load the artifact.
    :param run_if_not_found: called without arguments to build the artifact.
    :return: ``(module, md5, has_changed, was_cached)``. ``has_changed`` is
        always a bool; previously a falsy ``metadata`` object such as ``None``
        leaked through in its place.
    """
    def _as_bytes(module):
        # bytes are hashed/stored as-is, anything else via its str() form
        return module if isinstance(module, bytes) else str(module).encode("utf-8")

    suffix = file_name.split(".")[1]
    if not force_compile and cache_manager.has_file(file_name):
        module = run_if_found(cache_manager._make_path(file_name))
        md5 = hashlib.md5(_as_bytes(module)).hexdigest()
        has_changed = bool(metadata) and md5 != metadata["md5"][suffix]
        return module, md5, has_changed, True
    module = run_if_not_found()
    data = _as_bytes(module)
    md5 = hashlib.md5(data).hexdigest()
    # data is always bytes at this point
    cache_manager.put(data, file_name, True)
    return module, md5, True, False
#
def make_stub(name, signature, constants):
    """Return the path of the compiled launcher stub, building it if needed.

    The stub is cached under a key derived from the Triton version key, the
    signature and the constants. On a cache miss the launcher source is
    generated and compiled in a temporary directory, and the resulting shared
    object is copied into the cache.
    """
    # name of files that are cached
    version_key = triton.runtime.jit.version_key()
    so_cache_manager = CacheManager(make_so_cache_key(version_key, signature, constants))
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    if so_cache_manager.has_file(so_name):
        return so_cache_manager._make_path(so_name)
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "main.c")
        launcher_source = generate_launcher(constants, signature)
        with open(src_path, "w") as src_file:
            src_file.write(launcher_source)
        built_so = _build(name, src_path, tmpdir)
        with open(built_so, "rb") as so_file:
            so_bytes = so_file.read()
        so_cache_manager.put(so_bytes, so_name, binary=True)
    return so_cache_manager._make_path(so_name)
def convert_type_repr(x):
    """Convert an MLIR type to Triton's short notation.

    Each ``!tt.ptr<...>`` wrapper found becomes one leading ``*``; text that
    contains no pointer wrapper is returned unchanged.
    """
    stars = ''
    found = re.search(r'!tt\.ptr<(.*)>', x)
    while found is not None:
        stars += '*'
        x = found.group(1)
        found = re.search(r'!tt\.ptr<(.*)>', x)
    return stars + x
def make_hash(fn, **kwargs):
    """Return the cache key (MD5 hex digest) for compiling ``fn``.

    For a ``JITFunction`` the key is built from its ``cache_key``, the
    signature, the configs, the constants and the launch parameters
    (``num_warps`` defaults to 4, ``num_stages`` to 3). Otherwise ``fn`` must
    be a file path, and the key is the hash of the file's text followed by the
    Triton version key.
    """
    if not isinstance(fn, triton.runtime.JITFunction):
        assert isinstance(fn, str)
        file_key = Path(fn).read_text() + triton.runtime.jit.version_key()
        return hashlib.md5(file_key.encode("utf-8")).hexdigest()
    configs = kwargs["configs"]
    signature = kwargs["signature"]
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3)
    # Get unique key for the compiled code
    configs_key = [(sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) for conf in configs]
    type_string = ''.join(signature.values())
    key = f"{fn.cache_key}-{type_string}-{configs_key}-{constants}-{num_warps}-{num_stages}"
    return hashlib.md5(key.encode("utf-8")).hexdigest()
# Regular expressions that locate a function prototype in textual IR.
# MLIR prototype (used for both ttir and ttgir):
# - ^\s*func\s+ : match the start of a line (the pattern is searched with
#   re.MULTILINE), any leading whitespace, the keyword func, and any following
#   whitespace
# - (?:public\s+)? : optionally match the keyword public and any following
#   whitespace (non-capturing)
# - (@\w+) : match an @ symbol followed by one or more word characters
#   (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
#   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - \s*\{\s*$ : match the opening brace of the body at the end of the line
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
# PTX prototype: group 1 is the function name, group 2 the raw parameter list
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
# prototype pattern to use for each IR file extension
prototype_pattern = {
    "ttir": mlir_prototype_pattern,
    "ttgir": mlir_prototype_pattern,
    "ptx": ptx_prototype_pattern,
}

# Regular expressions that extract each argument's type (group 1) from the
# argument list captured by the prototype patterns above.
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
# argument-type pattern to use for each IR file extension
arg_type_pattern = {
    "ttir": mlir_arg_type_pattern,
    "ttgir": mlir_arg_type_pattern,
    "ptx": ptx_arg_type_pattern,
}
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
    """Compile ``fn`` down to a cubin and return a ``CompiledKernel``.

    ``fn`` is either a ``JITFunction`` (compilation starts from its AST) or the
    path of an IR file named ``<name>.<ext>``, in which case compilation starts
    at the stage named by the extension.

    Keyword arguments read here:

    * ``cc``: compute capability as an int (major * 10 + minor); defaults to
      the capability of the current CUDA device.
    * ``signature``: required for a ``JITFunction``; a comma-separated string
      or a mapping from argument index to type.
    * ``configs``: list with exactly one specialization descriptor; defaults
      to ``[instance_descriptor()]``.
    * ``constants``, ``num_warps`` (default 4), ``num_stages`` (default 3 for
      capability >= 75, else 2), ``extern_libs``.
    * ``shared``: shared memory size; required when starting from PTX with no
      cached metadata.

    Intermediate results are cached per stage and reused when the cached
    file's ctime matches the one recorded in the metadata.
    """
    capability = kwargs.get("cc", None)
    if capability is None:
        device = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
    # we get the kernel, i.e. the first function generated in the module
    # if fn is not a JITFunction, then it
    # has to be a path to a file
    context = _triton.ir.context()
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
    extern_libs = kwargs.get("extern_libs", dict())
    # build compilation stages
    # each stage maps to (load from file, build from previous stage's result);
    # the lambdas look up `signature` and `configs` lazily, when called
    stages = {
        "ast": (lambda path: fn, None),
        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
        "llir": (lambda path: Path(path).read_bytes(),
                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
        "ptx": (lambda path: Path(path).read_text(),
                lambda src: llir_to_ptx(src, capability)),
        "cubin": (lambda path: Path(path).read_bytes(),
                  lambda src: ptx_to_cubin(src, capability)),
    }
    # find out the signature of the function
    if isinstance(fn, triton.runtime.JITFunction):
        configs = kwargs.get("configs", None)
        signature = kwargs["signature"]
        if configs is None:
            configs = [instance_descriptor()]
        assert len(configs) == 1
        kwargs["configs"] = configs
        name = fn.__name__
        if isinstance(signature, str):
            signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
        kwargs["signature"] = signature
    else:
        assert isinstance(fn, str)
        _, ir = os.path.basename(fn).split(".")
        src = Path(fn).read_text()
        # recover the kernel name and argument types from the IR text
        match = re.search(prototype_pattern[ir], src, re.MULTILINE)
        name, signature = match.group(1), match.group(2)
        types = re.findall(arg_type_pattern[ir], signature)
        param_tys = [convert_type_repr(ty) for ty in types]
        signature = {k: v for k, v in enumerate(param_tys)}

    # cache manager
    so_path = make_stub(name, signature, constants)
    # create cache manager
    fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
    # determine name and extension type of provided function
    if isinstance(fn, triton.runtime.JITFunction):
        name, ext = fn.__name__, "ast"
    else:
        name, ext = os.path.basename(fn).split(".")

    # load metadata if any
    metadata = None
    if fn_cache_manager.has_file(f'{name}.json'):
        with open(fn_cache_manager._make_path(f"{name}.json")) as f:
            metadata = json.load(f)
    else:
        metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
        if ext == "ptx":
            assert "shared" in kwargs, "ptx compilation must provide shared memory size"
            metadata["shared"] = kwargs["shared"]

    first_stage = list(stages.keys()).index(ext)
    asm = dict()
    module = fn
    # run compilation pipeline and populate metadata
    for ir, (parse_stage, compile_stage) in list(stages.items())[first_stage:]:
        path = fn_cache_manager._make_path(f"{name}.{ir}")
        if ir == ext:
            # entry stage: load the user-provided input itself
            next_module = parse_stage(fn)
        elif os.path.exists(path) and \
                ir in metadata["ctime"] and \
                os.path.getctime(path) == metadata["ctime"][ir]:
            # the cached artifact is the one recorded in the metadata: reuse it
            next_module = parse_stage(path)
        else:
            next_module = compile_stage(module)
            fn_cache_manager.put(next_module, f"{name}.{ir}")
        if os.path.exists(path):
            metadata["ctime"][ir] = os.path.getctime(path)
        asm[ir] = next_module if ir == "cubin" else str(next_module)
        if ir == "llir" and "shared" not in metadata:
            # `module` is still the previous stage's (ttgir) module here
            metadata["shared"] = _triton.get_shared_memory_size(module)
        if ir == "ptx":
            metadata["name"] = ptx_get_kernel_name(next_module)
        module = next_module
    # write-back metadata
    fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
    # return handle to compiled kernel
    return CompiledKernel(so_path, metadata, asm)
2022-09-18 08:51:48 -07:00
class CompiledKernel:
    """Handle to one compiled kernel: launcher, metadata and per-stage assembly.

    The launcher (``launch`` from the module found at ``so_path``) is loaded
    eagerly in ``__init__``.  The CUDA module/function handles are created
    lazily by ``_init_handles`` the first time the kernel is indexed with a
    grid or ``c_wrapper`` is read.
    """

    # Hooks for external tools to monitor the execution of triton kernels
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, so_path, metadata, asm):
        """Load the launcher stub and record metadata.

        :param so_path: path of the launcher module; it must expose ``launch``.
        :param metadata: dict providing ``"shared"``, ``"num_warps"`` and
            ``"num_stages"`` (``"name"`` is read later by ``_init_handles``).
        :param asm: dict mapping a stage name to its code (``"cubin"`` is read
            by ``_init_handles`` and ``get_sass``).
        """
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.c_wrapper = mod.launch
        # initialize metadata
        self.shared = metadata["shared"]
        self.num_warps = metadata["num_warps"]
        self.num_stages = metadata["num_stages"]
        # initialize asm dict
        self.asm = asm
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.metadata = metadata
        self.cu_module = None
        self.cu_function = None
        # filled in by _init_handles once the binary has been loaded
        self.n_regs = None
        self.n_spills = None

    def _init_handles(self):
        """Load the cubin on the current device; no-op once it has been loaded.

        :raises OutOfResources: if the kernel's shared-memory requirement
            exceeds the ``max_shared_mem`` value reported for the device.
        """
        if self.cu_module is not None:
            return
        device = torch.cuda.current_device()
        # init_cuda_utils() rebinds the module-level `cuda_utils`; it is only
        # read here, so no `global` declaration is needed.
        init_cuda_utils()
        max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
        if self.shared > max_shared:
            raise OutOfResources(self.shared, max_shared, "shared memory")
        mod, func, n_regs, n_spills = cuda_utils.load_binary(
            self.metadata["name"], self.asm["cubin"], self.shared, device)
        # Keep the register / spill counts on the instance instead of printing
        # them: a library must not write to stdout on every kernel load.
        self.n_regs = n_regs
        self.n_spills = n_spills
        self.cu_module = mod
        self.cu_function = func

    def __getattribute__(self, name):
        # Reading the launcher means a launch is about to happen, so make sure
        # the CUDA handles exist before handing it out.
        if name == 'c_wrapper':
            self._init_handles()
        return super().__getattribute__(name)

    def __getitem__(self, grid):
        """Return a callable that launches the kernel on ``grid``.

        ``grid`` is indexed at positions 0, 1 and 2.  The returned callable
        takes the kernel arguments positionally plus an optional ``stream``;
        when ``stream`` is None the current torch CUDA stream is used.
        """
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                stream = torch.cuda.current_stream().cuda_stream
            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)

        return runner

    def get_sass(self, fun=None):
        """Return the SASS text for this kernel, computing it at most once.

        The cubin is written to a temporary file, passed to ``extract`` and
        the file is removed afterwards, even if extraction fails.  The result
        is stored under ``asm['sass']``.

        NOTE: once cached, the stored text is returned whatever ``fun`` is.
        """
        if 'sass' in self.asm:
            return self.asm['sass']
        fd, path = tempfile.mkstemp()
        try:
            # open(fd, ...) takes ownership of the descriptor and closes it
            with open(fd, 'wb') as cubin:
                cubin.write(self.asm['cubin'])
            self.sass = extract(path, fun)
        finally:
            os.remove(path)
        self.asm['sass'] = self.sass
        return self.sass
2022-12-21 01:30:50 -08:00
class CudaUtils ( object ) :
# Wrapper around a small C extension that is compiled from the source returned
# by _generate_src() and whose two functions, load_binary and
# get_device_properties, are re-exported as attributes of this object.
#
# Keep a single shared object: the first call stores the new object on the
# class; later calls hand back that same stored object.
# NOTE(review): because __new__ returns an instance of cls, Python still runs
# __init__ on every CudaUtils() call - callers presumably guard against that
# (init_cuda_utils below only constructs once) - TODO confirm no other call sites.
def __new__ ( cls ) :
if not hasattr ( cls , ' instance ' ) :
cls . instance = super ( CudaUtils , cls ) . __new__ ( cls )
return cls . instance
# Return the C source of the extension as one string. __init__ hashes this text
# to pick the cache entry, so any change to the text results in a different key.
# The source defines getDeviceProperties and loadBinary and lists them in the
# ModuleMethods table; nothing may be inserted inside the literal without
# changing what gets compiled.
def _generate_src ( self ) :
return """
#include <cuda.h>
#include \"cuda.h\"
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert ( CUresult code , const char * file , int line )
{
if ( code != CUDA_SUCCESS )
{
const char * prefix = " Triton Error [CUDA]: " ;
const char * str ;
cuGetErrorString ( code , & str ) ;
char err [ 1024 ] = { 0 } ;
strcat ( err , prefix ) ;
strcat ( err , str ) ;
PyErr_SetString ( PyExc_RuntimeError , err ) ;
}
}
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
static PyObject * getDeviceProperties ( PyObject * self , PyObject * args ) {
int device_id ;
if ( ! PyArg_ParseTuple ( args , " i " , & device_id ) )
return NULL ;
/ / Get device handle
CUdevice device ;
cuDeviceGet ( & device , device_id ) ;
/ / create a struct to hold device properties
int max_shared_mem ;
int multiprocessor_count ;
int sm_clock_rate ;
int mem_clock_rate ;
int mem_bus_width ;
CUDA_CHECK ( cuDeviceGetAttribute ( & max_shared_mem , CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN , device ) ) ;
CUDA_CHECK ( cuDeviceGetAttribute ( & multiprocessor_count , CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT , device ) ) ;
CUDA_CHECK ( cuDeviceGetAttribute ( & sm_clock_rate , CU_DEVICE_ATTRIBUTE_CLOCK_RATE , device ) ) ;
CUDA_CHECK ( cuDeviceGetAttribute ( & mem_clock_rate , CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE , device ) ) ;
CUDA_CHECK ( cuDeviceGetAttribute ( & mem_bus_width , CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH , device ) ) ;
return Py_BuildValue ( " { s:i, s:i, s:i, s:i, s:i} " , " max_shared_mem " , max_shared_mem ,
" multiprocessor_count " , multiprocessor_count ,
" sm_clock_rate " , sm_clock_rate ,
" mem_clock_rate " , mem_clock_rate ,
" mem_bus_width " , mem_bus_width ) ;
}
static PyObject * loadBinary ( PyObject * self , PyObject * args ) {
const char * name ;
const char * data ;
Py_ssize_t data_size ;
int shared ;
int device ;
if ( ! PyArg_ParseTuple ( args , " ss#ii " , & name , & data , & data_size , & shared , & device ) ) {
return NULL ;
}
CUfunction fun ;
CUmodule mod ;
int32_t n_regs = 0 ;
int32_t n_spills = 0 ;
/ / create driver handles
CUDA_CHECK ( cuModuleLoadData ( & mod , data ) ) ;
CUDA_CHECK ( cuModuleGetFunction ( & fun , mod , name ) ) ;
/ / get allocated registers and spilled registers from the function
CUDA_CHECK ( cuFuncGetAttribute ( & n_regs , CU_FUNC_ATTRIBUTE_NUM_REGS , fun ) ) ;
CUDA_CHECK ( cuFuncGetAttribute ( & n_spills , CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES , fun ) ) ;
n_spills / = 4 ;
/ / set dynamic shared memory if necessary
int shared_optin ;
CUDA_CHECK ( cuDeviceGetAttribute ( & shared_optin , CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN , device ) ) ;
if ( shared > 49152 & & shared_optin > 49152 ) {
CUDA_CHECK ( cuFuncSetCacheConfig ( fun , CU_FUNC_CACHE_PREFER_SHARED ) ) ;
int shared_total , shared_static ;
CUDA_CHECK ( cuDeviceGetAttribute ( & shared_total , CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR , device ) ) ;
CUDA_CHECK ( cuFuncGetAttribute ( & shared_static , CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES , fun ) ) ;
CUDA_CHECK ( cuFuncSetAttribute ( fun , CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES , shared_optin - shared_static ) ) ;
}
if ( PyErr_Occurred ( ) ) {
return NULL ;
}
return Py_BuildValue ( " (KKii) " , ( uint64_t ) mod , ( uint64_t ) fun , n_regs , n_spills ) ;
}
static PyMethodDef ModuleMethods [ ] = {
{ " load_binary " , loadBinary , METH_VARARGS , " Load provided cubin into CUDA driver " } ,
{ " get_device_properties " , getDeviceProperties , METH_VARARGS , " Get the properties for a given device " } ,
{ NULL , NULL , 0 , NULL } / / sentinel
} ;
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT ,
\" cuda_utils \" ,
NULL , / / documentation
- 1 , / / size
ModuleMethods
} ;
PyMODINIT_FUNC PyInit_cuda_utils ( void ) {
PyObject * m = PyModule_Create ( & ModuleDef ) ;
if ( m == NULL ) {
return NULL ;
}
PyModule_AddFunctions ( m , ModuleMethods ) ;
return m ;
}
"""
# Build the extension if it is not in the cache yet, then import it and
# re-export its two entry points on this object.
def __init__ ( self ) :
src = self . _generate_src ( )
# the cache key is the md5 hex digest of the C source text
key = hashlib . md5 ( src . encode ( " utf-8 " ) ) . hexdigest ( )
cache = CacheManager ( key )
fname = " cuda_utils.so "
# cache miss: write the source into a temporary directory, build it with
# _build (defined elsewhere in this file) and store the bytes of the file it
# returns in the cache under fname
if not cache . has_file ( fname ) :
with tempfile . TemporaryDirectory ( ) as tmpdir :
src_path = os . path . join ( tmpdir , " main.c " )
with open ( src_path , " w " ) as f :
f . write ( src )
so = _build ( " cuda_utils " , src_path , tmpdir )
with open ( so , " rb " ) as f :
cache . put ( f . read ( ) , fname , binary = True )
# import the cached file as a module directly from its path
import importlib . util
spec = importlib . util . spec_from_file_location ( " cuda_utils " , cache . _make_path ( fname ) )
mod = importlib . util . module_from_spec ( spec )
spec . loader . exec_module ( mod )
# expose the extension's functions as attributes of this object
self . load_binary = mod . load_binary
self . get_device_properties = mod . get_device_properties
# Module-wide CudaUtils handle; stays None until init_cuda_utils() has run.
cuda_utils = None


def init_cuda_utils():
    """Create the module-level ``cuda_utils`` object on first use.

    Does nothing when ``cuda_utils`` has already been set.
    """
    global cuda_utils
    if cuda_utils is not None:
        return
    cuda_utils = CudaUtils()