@@ -1,4 +1,4 @@
# include < string >
# include < string>
# include <mutex>
# include <regex>
# include <functional>
@@ -45,360 +45,25 @@ std::mutex mut;
namespace triton {
namespace runtime {
/* --------------------- */
/* HELPERS */
/* --------------------- */
/* --------------------------------- */
/* --------------------------------- */
/* --------------------------------- */
void _loop_nest ( std : : vector < size_t > const & ranges ,
std : : function < void ( std : : vector < size_t > const & ) > const & f ) {
size_t D = ranges . size ( ) ;
std : : vector < size_t > values ( D , 0 ) ;
size_t i = D - 1 ;
while ( true ) {
f ( values ) ;
while ( values [ i ] + + = = ranges [ i ] - 1 ) {
if ( i = = 0 )
return ;
values [ i - - ] = 0 ;
}
i = D - 1 ;
}
}
/* --------------------- */
/* OPTIONS */
/* --------------------- */
std : : string options_t : : to_str ( ) const {
std : : string ret = " nw- " + std : : to_string ( num_warps ) ;
for ( const auto & x : defines ) {
ret + = ' - ' ;
ret + = x . first ;
ret + = ' - ' ;
ret + = x . second ;
}
// legalize
for ( char & x : ret ) {
if ( x = = ' ' | | x = = ' ^ ' | | x = = ' , ' | | x = = ' : ' )
x = ' _ ' ;
}
return ret ;
}
/* --------------------- */
/* CALLER OBJECT */
/* --------------------- */
arg_type convert ( ir : : type * ty ) {
if ( ty - > is_integer_ty ( 1 ) )
return INT1_T ;
if ( ty - > is_integer_ty ( 8 ) )
return INT8_T ;
if ( ty - > is_integer_ty ( 16 ) )
return INT16_T ;
if ( ty - > is_integer_ty ( 32 ) )
return INT32_T ;
if ( ty - > is_integer_ty ( 64 ) )
return INT64_T ;
if ( ty - > is_half_ty ( ) )
return HALF_T ;
if ( ty - > is_float_ty ( ) )
return FLOAT_T ;
if ( ty - > is_double_ty ( ) )
return DOUBLE_T ;
if ( ty - > is_pointer_ty ( ) )
return BUFFER_T ;
arg_type kernel : : convert ( ir : : type * ty ) {
if ( ty - > is_integer_ty ( 1 ) ) return INT1_T ;
if ( ty - > is_integer_ty ( 8 ) ) return INT8_T ;
if ( ty - > is_integer_ty ( 16 ) ) return INT16_T ;
if ( ty - > is_integer_ty ( 32 ) ) return INT32_T ;
if ( ty - > is_integer_ty ( 64 ) ) return INT64_T ;
i f( ty - > is_half_ty ( ) ) return HALF_T ;
if ( ty - > is_float_ty ( ) ) return FLOAT_T ;
if ( ty - > is_double_ty ( ) ) return DOUBLE_T ;
if ( ty - > is_pointer_ty ( ) ) return BUFFER_T ;
throw std : : runtime_error ( " unknown type " ) ;
}
//void function::caller::write(std::ofstream &ofs) {
// // write name
// ofs << name_ << std::endl;
// // write signature
// for(size_t i = 0; i < param_tys_.size(); i++)
// ofs << param_tys_[i] << " ";
// ofs << std::endl;
// // write module
// std::string source = ((driver::cu_module*)(&*parent_))->ptx();
// ofs << source;
//}
//void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
// // read name
// std::getline(ifs, name_);
// // read signature
// std::string line;
// std::getline(ifs, line);
// std::istringstream current(line);
// int param;
// param_tys_.clear();
// while(current >> param)
// param_tys_.push_back((arg_type)param);
// // read module
// std::string src((std::istreambuf_iterator<char>(ifs)),
// std::istreambuf_iterator<char>());
// parent_.reset(new driver::cu_module(ctx, src));
// bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
//}
//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
// : opt_(opt) {
// read(ctx, ifs);
//}
function : : caller : : caller ( ir : : function * ir ,
std : : shared_ptr < driver : : module > parent , const options_t & opt )
: parent_ ( parent ) , opt_ ( opt ) , name_ ( ir - > get_name ( ) ) {
bin_ . reset ( driver : : kernel : : create ( & * parent , name_ . c_str ( ) ) ) ;
// extract signature
ir : : function_type * ty = ir - > get_fn_type ( ) ;
for ( size_t i = 0 ; i < ty - > get_num_params ( ) ; i + + ) {
param_tys_ . push_back ( convert ( ty - > get_param_ty ( i ) ) ) ;
if ( ! ir - > has_attr ( i + 1 ) )
continue ;
for ( ir : : attribute attr : ir - > attrs ( ) . at ( i + 1 ) )
if ( attr . get_kind ( ) = = ir : : retune )
retune_ . push_back ( i ) ;
}
}
void function : : caller : : operator ( ) ( driver : : stream * stream , const grid_t & _grid , void * * args , size_t args_size , const std : : map < std : : string , std : : vector < char > > & csts ) const {
// copy constants
for ( const auto & cst : csts ) {
std : : unique_ptr < driver : : buffer > buffer = parent ( ) - > symbol ( cst . first . c_str ( ) ) ;
stream - > write ( & * buffer , true , 0 , cst . second ) ;
}
// set grid
if ( _grid . size ( ) > 3 )
throw std : : runtime_error ( " grid size must be no greater than 3 " ) ;
std : : array < size_t , 3 > grid ;
for ( size_t i = 0 ; i < 3 ; i + + )
grid [ i ] = ( i < _grid . size ( ) ) ? _grid [ i ] : 1 ;
// enqueue
stream - > enqueue ( & * bin_ , grid , { opt_ . num_warps * 32 , 1 , 1 } , args , args_size ) ;
}
/* --------------------- */
/* FUNCTION */
/* --------------------- */
// create Triton-IR from AST
std : : unique_ptr < ir : : module > function : : make_ir ( Parser & parser ) {
ir : : module * module = new ir : : module ( " " , ctx_ ) ;
Generator gen ( & parser ) ;
gen . Gen ( module ) ;
return std : : unique_ptr < ir : : module > ( module ) ;
}
// create Binary from Triton-IR
std : : unique_ptr < driver : : module > function : : make_bin ( ir : : module & module , driver : : device * device , const options_t & opt ) {
std : : unique_ptr < codegen : : target > target = device - > make_target ( ) ;
// generate llvm code
llvm : : LLVMContext ctx ;
std : : unique_ptr < llvm : : Module > llvm ( new llvm : : Module ( module . get_name ( ) , ctx ) ) ;
// optimizations
bool cts_use_async = target - > as_nvidia ( ) - > sm ( ) > = 80 ;
// create passes
codegen : : analysis : : align align ;
codegen : : analysis : : axes axes ;
codegen : : transform : : cts cts ( cts_use_async ) ;
codegen : : transform : : disassociate disassociate ;
codegen : : analysis : : layouts layouts ( & axes , & align , opt . num_warps , target . get ( ) ) ;
codegen : : analysis : : liveness liveness ( & layouts ) ;
codegen : : analysis : : swizzle swizzle ( & layouts , target . get ( ) ) ;
codegen : : analysis : : allocation allocation ( & liveness ) ;
codegen : : transform : : membar barriers ( & liveness , & layouts , & allocation ) ;
codegen : : transform : : dce dce ;
codegen : : transform : : peephole peephole ( target . get ( ) ) ;
codegen : : transform : : reassociate reassociate ;
codegen : : transform : : coalesce coalesce ( & align , & layouts ) ;
codegen : : generator isel ( & axes , & layouts , & align , & allocation , & swizzle , target . get ( ) , opt . num_warps ) ;
// run passes
dce . run ( module ) ;
disassociate . run ( module ) ;
dce . run ( module ) ;
peephole . run ( module ) ;
dce . run ( module ) ;
align . run ( module ) ;
if ( target - > is_gpu ( ) )
cts . run ( module ) ;
axes . run ( module ) ;
layouts . run ( module ) ;
coalesce . run ( module ) ;
dce . run ( module ) ;
align . run ( module ) ;
dce . run ( module ) ;
if ( target - > is_gpu ( ) ) {
reassociate . run ( module ) ;
cts . run ( module ) ;
}
peephole . run ( module ) ;
dce . run ( module ) ;
align . run ( module ) ;
axes . run ( module ) ;
// ir::print(module, std::cout);
layouts . run ( module ) ;
swizzle . run ( module ) ;
liveness . run ( module ) ;
allocation . run ( module ) ;
if ( allocation . allocated_size ( ) > device - > max_shared_memory ( ) )
throw exception : : out_of_shared_memory ( ) ;
barriers . run ( module ) ;
isel . visit ( module , * llvm ) ;
std : : unique_ptr < driver : : module > res ( driver : : module : : create ( device , std : : move ( llvm ) ) ) ;
// if(res->spilled() > 256)
// throw exception::out_of_registers();
return res ;
}
// create Binary from options
void function : : make ( driver : : device * device , options_t opt ) {
if ( callers_ . find ( opt ) ! = callers_ . end ( ) )
return ;
// pre-process
TokenSequence tokens ;
Preprocessor cpp ( & src_ , true ) ;
for ( auto it : opt . defines )
cpp . AddMacro ( it . first , & it . second ) ;
cpp . Process ( tokens ) ;
// src -> ast
Parser parser ( tokens ) ;
parser . Parse ( ) ;
// ast -> triton-ir
auto ir = make_ir ( parser ) ;
// triton-ir -> binary
std : : unique_ptr < driver : : module > bin ;
try {
bin = make_bin ( * ir , device , opt ) ;
} catch ( const exception : : base & ) {
throw ;
}
// create callable
ir : : function * tmp = ir - > get_function_list ( ) [ 0 ] ;
callers_ [ opt ] . reset ( new caller ( tmp , std : : move ( bin ) , opt ) ) ;
}
// precompile all kernels spanned by given options space
void function : : precompile ( driver : : device * device , const options_space_t & space ) {
// all ranges
std : : vector < size_t > ranges ;
ranges . push_back ( space . num_warps . size ( ) ) ;
for ( const auto & x : space . defines )
ranges . push_back ( x . second . size ( ) ) ;
// functor for source with given option
std : : map < options_t , std : : string > err ;
auto do_make = [ & ] ( std : : vector < size_t > params ) {
// compilation options
unsigned i = 0 ;
options_t opt ;
opt . num_warps = space . num_warps [ params [ i + + ] ] ;
for ( auto D : space . defines )
opt . defines [ D . first ] = D . second [ params [ i + + ] ] ;
// compile
try {
make ( device , opt ) ;
} catch ( const exception : : base & e ) {
err [ opt ] = e . what ( ) ;
}
} ;
// multi-threaded compilation
_loop_nest ( ranges , do_make ) ;
if ( callers_ . empty ( ) ) {
std : : ostringstream dbg ;
dbg < < " Auto-Tuner could not find any valid configuration: " < < std : : endl ;
for ( auto x : err ) {
dbg < < " [ " ;
dbg < < x . first . num_warps < < " , " ;
dbg < < " { " ;
for ( const auto & y : x . first . defines )
dbg < < ' " ' < < y . first < < " \" = \" " < < y . second < < " \" , " ;
dbg < < " } ] -> " < < x . second < < std : : endl ;
}
throw exception : : no_valid_configuration ( dbg . str ( ) ) ;
}
}
std : : string function : : get_asm ( asm_mode_t mode , driver : : device * device , const options_t & opt ) {
make ( device , opt ) ;
const auto & fn = callers_ . at ( opt ) ;
if ( ! fn )
return " " ;
switch ( mode ) {
case ASM_LLIR : {
return fn - > parent ( ) - > llir ( ) ;
}
case ASM_NV_PTX :
case ASM_NV_SASS : {
std : : string ptx = ( ( driver : : cu_module * ) fn - > parent ( ) ) - > ptx ( ) ;
// SASS
std : : string input = std : : tmpnam ( nullptr ) ;
std : : string output = std : : tmpnam ( nullptr ) ;
std : : ofstream ofs ( input ) ;
ofs < < ptx ;
ofs . close ( ) ;
if ( mode = = ASM_NV_PTX )
return ptx ;
std : : string cmd ;
int err ;
// compile ptx
driver : : cu_device * cu_device = ( driver : : cu_device * ) device ;
cmd = " ptxas --gpu-name=sm_ " + std : : to_string ( cu_device - > compute_capability ( ) ) + " " + input + " -o " + input + " .o " ;
err = system ( cmd . c_str ( ) ) ;
// disassemble
cmd = " cuobjdump --dump-sass " + input + " .o >> " + output ;
err = system ( cmd . c_str ( ) ) ;
std : : regex comment ( " * \\ / \\ * 0x[0-9a-f]+ \\ * \\ / " ) ;
std : : string to_delete = " /* " ;
std : : ifstream ifs ( output ) ;
std : : string line ;
std : : string sass ;
while ( std : : getline ( ifs , line ) )
if ( ! std : : regex_match ( line , comment ) )
sass + = line + " \n " ;
return sass ;
}
default :
return " " ;
}
}
// returns program with best compilation options for given parameter
function : : caller * function : : autotune ( driver : : stream * stream , const grid_fn_ty & grid_fn ,
void * * args , size_t args_size ) {
// fast path -- no autotuning necessary
if ( callers_ . size ( ) = = 1 )
return & * callers_ . begin ( ) - > second ;
// run auto-tuner
double best_ts = INFINITY ;
caller * ret = nullptr ;
for ( auto & x : callers_ ) {
if ( x . second = = nullptr )
throw std : : runtime_error ( " configuration not compiled " ) ;
caller * current = & * x . second ;
double ts = tools : : bench ( [ & ] ( ) { ( * current ) ( stream , grid_fn ( x . first ) , args , args_size , cst_ ) ; } ,
stream , true ) ;
ret = ( ts < best_ts ) ? current : ret ;
best_ts = std : : min ( ts , best_ts ) ;
}
stream - > synchronize ( ) ;
return ret ;
}
// set copy host buffer "data" into constant memory buffer "name"
void function : : set_cst ( const char * name , void * data , size_t n_bytes ) {
cst_ [ std : : string ( name ) ] = std : : vector < char > ( ( char * ) data , ( char * ) data + n_bytes ) ;
}
std : : string function : : preheader ( ) {
std : : string kernel : : preheader ( ) {
return R " (
# define bool _Bool
# define true 1
@@ -452,67 +117,212 @@ typedef long int64;
) " ;
}
std : : string function : : get_cache_prefix ( ) {
//user-specified cache path
std : : string result = tools : : getenv ( " TRITON_CACHE_PATH " ) ;
if ( ! result . empty ( ) ) {
i f( tools : : mkpath ( result ) = = 0 )
return result ;
}
//create in home
result = tools : : getenv ( " HOME " ) ;
if ( ! result . empty ( ) )
{
res ult = result + " /.triton/cache/ " ;
if ( tools : : mkpath ( result ) = = 0 )
return result ;
}
return " " ;
void kernel : : init_ir ( const std : : string & src ) {
// pre-process
TokenSequence tokens ;
Preprocessor cpp ( & src , true ) ;
for ( auto it : opt . defines )
cpp . AddMacro ( it . first , & it . second ) ;
cpp . Process ( tokens ) ;
// src -> ast
Parser parser ( tokens ) ;
parser . Parse ( ) ;
// ast -> triton-ir
ir : : module * mod ule = new ir : : module ( " " , ctx_ ) ;
Generator gen ( & parser ) ;
gen . Gen ( module ) ;
ir_ . reset ( module ) ;
}
function : : function ( const std : : string & src ,
const options_space_t & opt ,
const std : : string & cache_ref ) :
src_ ( src ) , opt_ ( opt ) , cache_ref_ ( cache_ref ) {
// hash source code
unsigned char hash [ 20 ] ;
sha1 : : calc ( ( void * ) src_ . data ( ) , src_ . size ( ) , hash ) ;
// create cache path
char _hex [ 40 ] ;
sha1 : : toHexString ( h ash , _hex ) ;
std : : string hex ( _hex , _hex + 40 ) ;
cache_path_ = get_cache_prefix ( ) + hex + " / " ;
tools : : mkpath ( cache_path_ ) ;
// append pre-header to source
src_ = preheader ( ) + src_ ;
void kernel : : init_ker ( ) {
// triton-ir -> binary
std : : unique_ptr < driver : : module > bin ;
std : : unique_ptr < codegen : : target > target = dev_ - > make_target ( ) ;
// generate llvm code
llvm : : LLVMContext ctx ;
std : : string name = ir_ - > get_function_list ( ) [ 0 ] - > get_name ( ) ;
std : : unique_ptr < llvm : : Module > llvm ( new llvm : : Module ( name , ctx ) ) ;
// optimizations
bool cts_use_async = target - > as_nvidia ( ) - > sm ( ) > = 80 ;
// create passes
codegen : : analysis : : align align ;
codegen : : analysis : : axes axes ;
codegen : : transform : : cts cts ( cts_use_async ) ;
codegen : : transform : : disassociate disassociate ;
codegen : : analysis : : layouts layouts ( & axes , & align , opt . num_warps , target . get ( ) ) ;
codegen : : analysis : : liveness liveness ( & layouts ) ;
codegen : : analysis : : swizzle swizzle ( & layouts , target . get ( ) ) ;
codegen : : analysis : : allocation allocation ( & liveness ) ;
codegen : : transform : : membar barriers ( & liveness , & layouts , & allocation ) ;
codegen : : transform : : dce dce ;
codegen : : transform : : peephole peephole ( target . get ( ) ) ;
codegen : : transform : : reassociate reassociate ;
codegen : : transform : : coalesce coalesce ( & align , & layouts ) ;
codegen : : generator isel ( & axes , & layouts , & align , & allocation , & swizzle , target . get ( ) , opt . num_warps ) ;
// run passes
dce . run ( * ir_ ) ;
disassociate . run ( * ir_ ) ;
dce . run ( * ir_ ) ;
peephole . run ( * ir_ ) ;
dce . run ( * ir_ ) ;
align . run ( * ir_ ) ;
if ( target - > is_gpu ( ) )
cts . run ( * ir_ ) ;
axes . run ( * ir_ ) ;
layouts . run ( * ir_ ) ;
coalesce . run ( * ir_ ) ;
dce . run ( * ir_ ) ;
align . run ( * ir_ ) ;
dce . run ( * ir_ ) ;
if ( target - > is_gpu ( ) ) {
reassociate . run ( * ir_ ) ;
cts . run ( * ir_ ) ;
}
peephole . run ( * ir_ ) ;
dce . run ( * ir_ ) ;
align . run ( * ir_ ) ;
axes . run ( * ir_ ) ;
layouts . run ( * ir_ ) ;
swizzle . run ( * ir_ ) ;
liveness . run ( * ir_ ) ;
allocation . run ( * ir_ ) ;
if ( allocation . allocated_size ( ) > dev_ - > max_shared_memory ( ) )
throw exception : : out_of_shared_memory ( ) ;
barriers . run ( * ir_ ) ;
isel . visit ( * ir_ , * llvm ) ;
//if(res->spilled() > 256)
// throw exception::out_of_registers();
mod_ . reset ( driver : : module : : create ( dev_ , std : : move ( llvm ) ) ) ;
ker_ . reset ( driver : : kernel : : create ( & * mod_ , name . c_str ( ) ) ) ;
}
void function : : operator ( ) ( void * * args , size_t args_size , const grid_fn_ty & grid_fn , driver : : stream * stream , driver : : device * device ) {
// pre-compile kernels
if ( callers_ . empty ( ) ) {
precompile ( device , opt_ ) ;
void kernel : : init_sig ( ) {
ir : : function * fn = ir_ - > get_function_list ( ) [ 0 ] ;
ir : : function_type * ty = fn - > get_fn_type ( ) ;
for ( size_t i = 0 ; i < ty - > get_num_params ( ) ; i + + ) {
sig_ . push_back ( convert ( ty - > get_param_ty ( i ) ) ) ;
if ( ! fn - > has_attr ( i + 1 ) )
continue ;
}
// re-tuning key
cache_key_t key ;
key . first = device ;
key . second = callers_ . begin ( ) - > second - > retune ( ) ;
// auto-tune if necessary
}
kernel : : kernel ( const std : : string & src , const options_t & opt , driver : : device * dev ) :
opt ( opt ) , dev_ ( dev ) {
init_ir ( preheader ( ) + src ) ;
init_ker ( ) ;
init_sig ( ) ;
}
void kernel : : operator ( ) ( void * args , size_t args_size , driver : : stream * stream , const std : : vector < size_t > & _grid ) const {
// set grid
if ( _grid . size ( ) > 3 )
throw std : : runtime_error ( " grid size must be no greater than 3 " ) ;
std : : array < size_t , 3 > grid ;
for ( size_t i = 0 ; i < 3 ; i + + )
grid [ i ] = ( i < _grid . size ( ) ) ? _grid [ i ] : 1 ;
// enqueue
stream - > enqueue ( & * ker_ , grid , { opt . num_warps * 32 , 1 , 1 } , args , args_size ) ;
}
/* --------------------------------- */
/* --------------------------------- */
/* --------------------------------- */
void function : : do_loop_nest ( std : : vector < size_t > const & ranges ,
std : : function < void ( std : : vector < size_t > const & ) > const & f ) {
size_t D = ranges . size ( ) ;
std : : vector < size_t > values ( D , 0 ) ;
size_t i = D - 1 ;
while ( true ) {
f ( values ) ;
while ( values [ i ] + + = = ranges [ i ] - 1 ) {
if ( i = = 0 )
return ;
values [ i - - ] = 0 ;
}
i = D - 1 ;
}
}
void function : : init_kernels ( const std : : string & src , const options_space_t & opts , driver : : device * device ) {
// all ranges
std : : vector < size_t > ranges ;
ranges . push_back ( opts . num_warps . size ( ) ) ;
for ( const auto & x : opts . defines )
ranges . push_back ( x . second . size ( ) ) ;
// functor for source with given option
std : : vector < std : : pair < options_t , std : : string > > err ;
auto do_make = [ & ] ( std : : vector < size_t > params ) {
// compilation options
unsigned i = 0 ;
options_t opt ;
opt . num_warps = opts . num_warps [ params [ i + + ] ] ;
for ( auto D : opts . defines )
opt . defines [ D . first ] = D . second [ params [ i + + ] ] ;
// compile
try {
kernels_ . push_back ( { opt , std : : make_shared < kernel > ( src , opt , device ) } ) ;
} catch ( const exception : : base & e ) {
err . push_back ( { opt , e . what ( ) } ) ;
}
} ;
// multi-threaded compilation
do_loop_nest ( ranges , do_make ) ;
if ( kernels_ . empty ( ) ) {
std : : ostringstream dbg ;
dbg < < " Auto-Tuner could not find any valid configuration: " < < std : : endl ;
for ( auto x : err ) {
dbg < < " [ " ;
dbg < < x . first . num_warps < < " , " ;
dbg < < " { " ;
for ( const auto & y : x . first . defines )
dbg < < ' " ' < < y . first < < " \" = \" " < < y . second < < " \" , " ;
dbg < < " } ] -> " < < x . second < < std : : endl ;
}
throw exception : : no_valid_configuration ( dbg . str ( ) ) ;
}
}
kernel * function : : autotune ( void * args , size_t args_size , const grid_fn_ty & grid_fn , driver : : stream * stream ) {
// fast path -- no autotuning necessary
if ( kernels_ . size ( ) = = 1 )
return & * kernels_ . begin ( ) - > second ;
// auto-tuning key
std : : vector < uint64_t > key ;
auto it = cache_ . find ( key ) ;
if ( it = = cache_ . end ( ) ) {
auto best = autotune ( stream , grid_fn , args , args_size ) ;
it = cache_ . insert ( { key , best } ) . first ;
if ( it ! = cache_ . end ( ) )
return it - > second ;
// run auto-tuner
double best_ts = INFINITY ;
kernel * ret = nullptr ;
for ( auto & x : kernels_ ) {
kernel * current = & * x . second ;
auto grid = grid_fn ( x . first ) ;
while ( grid . size ( ) < 3 )
grid . push_back ( 1 ) ;
double ts = tools : : bench ( [ & ] ( ) { ( * current ) ( args , args_size , stream , grid ) ; } ,
stream , true ) ;
ret = ( ts < best_ts ) ? current : ret ;
best_ts = std : : min ( ts , best_ts ) ;
}
// run
( * it - > second ) ( stream , grid_fn ( it - > second - > opt ( ) ) , args , args_size , cst_ ) ;
stream - > synchronize ( ) ;
it = cache_ . insert ( { key , ret } ) . first ;
return it - > second ;
}
void function : : operator ( ) ( void * * args ,
size_t args_size ,
const grid_t & grid ,
driver : : stream * stream , driver : : device * device ) {
return this - > operator ( ) ( args , args_size , [ & grid ] ( const options_t & ) { return grid ; } , stream , device ) ;
function : : function ( const std : : string & src , const options_space_t & opt , driver : : device * device ) {
init_kernels ( src , opt , device ) ;
}
void function : : operator ( ) ( void * args , size_t args_size , const grid_fn_ty & grid_fn , driver : : stream * stream ) {
runtime : : kernel * fn = autotune ( args , args_size , grid_fn , stream ) ;
( * fn ) ( args , args_size , stream , grid_fn ( fn - > opt ) ) ;
}
void function : : operator ( ) ( void * args , size_t args_size , const grid_t & grid , driver : : stream * stream ) {
return this - > operator ( ) ( args , args_size , [ & grid ] ( const options_t & ) { return grid ; } , stream ) ;
}
}