2021-09-09 00:04:28 -07:00
/* Copyright 2015-2017 Philippe Tillet
2022-10-12 12:02:30 -07:00
*
* Permission is hereby granted , free of charge , to any person obtaining
* a copy of this software and associated documentation files
* ( the " Software " ) , to deal in the Software without restriction ,
* including without limitation the rights to use , copy , modify , merge ,
* publish , distribute , sublicense , and / or sell copies of the Software ,
* and to permit persons to whom the Software is furnished to do so ,
* subject to the following conditions :
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software .
*
* THE SOFTWARE IS PROVIDED " AS IS " , WITHOUT WARRANTY OF ANY KIND ,
* EXPRESS OR IMPLIED , INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY , FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT .
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM , DAMAGES OR OTHER LIABILITY , WHETHER IN AN ACTION OF CONTRACT ,
* TORT OR OTHERWISE , ARISING FROM , OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE .
*/
2021-09-09 00:04:28 -07:00
# include <fstream>
2021-12-07 14:10:58 -08:00
# if __has_include(<unistd.h>)
2022-10-12 12:02:30 -07:00
# include <unistd.h>
2021-12-07 14:10:58 -08:00
# endif
2021-09-09 00:04:28 -07:00
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <iterator>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
2021-12-13 12:28:15 -08:00
# include <iomanip>
2021-09-09 00:04:28 -07:00
# include "triton/driver/llvm.h"
# include "triton/driver/dispatch.h"
# include "triton/driver/error.h"
# include "triton/tools/sha1.hpp"
# include "triton/tools/sys/getenv.hpp"
# include "triton/tools/sys/mkdir.hpp"
# include "triton/tools/sys/exec.hpp"
# include "llvm/IR/IRBuilder.h"
# include "llvm/IR/Verifier.h"
# include "llvm/IR/IRPrintingPasses.h"
# include "llvm/IR/Module.h"
# include "llvm/Support/CodeGen.h"
# include "llvm/Support/CommandLine.h"
# include "llvm/Support/SourceMgr.h"
# include "llvm/Support/raw_ostream.h"
# include "llvm/Support/TargetRegistry.h"
# include "llvm/Support/TargetSelect.h"
# include "llvm/Target/TargetMachine.h"
# include "llvm/Target/TargetOptions.h"
# include "llvm/IR/LegacyPassManager.h"
# include "llvm/ExecutionEngine/ExecutionEngine.h"
# include "llvm/ExecutionEngine/SectionMemoryManager.h"
# include "llvm/Transforms/Utils/Cloning.h"
2022-01-28 01:12:44 +08:00
# include "llvm/Transforms/Scalar.h"
2021-09-09 00:04:28 -07:00
// begin AMD stuff
# include "llvm/Support/FileSystem.h"
# include "llvm/Support/FormattedStream.h"
# include "llvm/Support/Program.h"
# include "llvm/Support/ToolOutputFile.h"
# include "llvm/ADT/StringRef.h"
# include "llvm/Analysis/TargetLibraryInfo.h"
2021-12-13 12:28:15 -08:00
# include "llvm/IR/IntrinsicsAMDGPU.h"
# include "llvm/IR/Intrinsics.h"
2021-09-09 00:04:28 -07:00
// end AMD stuff
2022-10-12 12:02:30 -07:00
// No-op definitions of the terminfo entry points set_curterm, del_curterm,
// tigetnum and setupterm. Every stub ignores its arguments and returns 0.
// NOTE(review): presumably these exist so that code referencing the terminfo
// API (likely LLVM's terminal-colour detection) links without a
// terminfo/ncurses library -- confirm before removing.
extern "C"
{
int set_curterm(char * /*nterm*/) { return 0; }
int del_curterm(char * /*nterm*/) { return 0; }
int tigetnum(char * /*capname*/) { return 0; }
int setupterm(char * /*term*/, int /*fildes*/, int * /*errret*/) { return 0; }
}
2022-10-12 12:02:30 -07:00
namespace triton
{
namespace driver
{
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
void init_llvm ( )
{
LLVMInitializeNVPTXTargetInfo ( ) ;
LLVMInitializeNVPTXTarget ( ) ;
LLVMInitializeNVPTXTargetMC ( ) ;
LLVMInitializeNVPTXAsmPrinter ( ) ;
LLVMInitializeAMDGPUTargetInfo ( ) ;
LLVMInitializeAMDGPUTarget ( ) ;
LLVMInitializeAMDGPUTargetMC ( ) ;
LLVMInitializeAMDGPUAsmPrinter ( ) ;
2022-03-30 22:45:41 -05:00
}
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace ( std : : string & str , const std : : string & begin , const std : : string & end , const std : : string & target )
{
size_t start_replace = str . find ( begin ) ;
size_t end_replace = str . find ( end , start_replace ) ;
if ( start_replace = = std : : string : : npos )
return false ;
str . replace ( start_replace , end_replace + 1 - start_replace , target ) ;
return true ;
}
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
std : : string path_to_ptxas ( int & version )
{
std : : vector < std : : string > rets ;
std : : string ret ;
// search paths for ptxas
std : : vector < std : : string > ptxas_prefixes = { " " , " /usr/local/cuda/bin/ " } ;
std : : string triton_ptxas = tools : : getenv ( " TRITON_PTXAS_PATH " ) ;
if ( ! triton_ptxas . empty ( ) )
ptxas_prefixes . insert ( ptxas_prefixes . begin ( ) , triton_ptxas ) ;
// see what path for ptxas are valid
std : : vector < std : : string > working_ptxas ;
for ( std : : string prefix : ptxas_prefixes )
{
std : : string ptxas = prefix + " ptxas " ;
bool works = tools : : exec ( ptxas + " --version 2>&1 " , ret ) = = 0 ;
if ( works )
{
working_ptxas . push_back ( ptxas ) ;
rets . push_back ( ret ) ;
}
}
// error if no working ptxas was found
if ( working_ptxas . empty ( ) )
throw std : : runtime_error ( " `ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH "
" but a working version could not be found. " ) ;
std : : string ptxas = working_ptxas . front ( ) ;
// parse version
std : : regex version_regex ( " release ( \\ d+) \\ .( \\ d+) " ) ;
std : : smatch match ;
bool found = false ;
// currently choosing the first ptxas. Other logics can be implemented in future
for ( std : : string ret : rets )
{
if ( std : : regex_search ( ret , match , version_regex ) )
{
int major = std : : stoi ( match [ 1 ] ) ;
int minor = std : : stoi ( match [ 2 ] ) ;
version = major * 1000 + minor * 10 ;
found = true ;
break ;
}
}
if ( not found )
{
throw std : : runtime_error ( " Error in parsing version " ) ;
}
return ptxas ;
}
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
int vptx ( int version )
{
if ( version > = 11040 )
return 74 ;
// if(version >= 11030) return 73;
// if(version >= 11020) return 72;
// if(version >= 11010) return 71;
// if(version >= 11000) return 70;
// if(version >= 10020) return 65;
// if(version >= 10010) return 64;
// if(version >= 10000) return 63;
throw std : : runtime_error ( " Triton requires CUDA 11.4+ " ) ;
}
2021-09-18 22:48:26 -07:00
2022-10-12 12:02:30 -07:00
std : : string llir_to_ptx ( llvm : : Module * module , int cc , int version )
{
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75 ;
int max_nvvm_ptx = 74 ;
// options
auto options = llvm : : cl : : getRegisteredOptions ( ) ;
auto * short_ptr = static_cast < llvm : : cl : : opt < bool > * > ( options [ " nvptx-short-ptr " ] ) ;
assert ( short_ptr ) ;
short_ptr - > setValue ( true ) ;
// compute capability
std : : string sm = " sm_ " + std : : to_string ( cc ) ;
// max PTX version
int ptx = vptx ( version ) ;
int ptx_major = ptx / 10 ;
int ptx_minor = ptx % 10 ;
// create
llvm : : SmallVector < char , 0 > buffer ;
std : : string triple = " nvptx64-nvidia-cuda " ;
std : : string proc = " sm_ " + std : : to_string ( std : : min ( cc , max_nvvm_cc ) ) ;
std : : string layout = " " ;
std : : string features = " " ;
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
init_llvm ( ) ;
// verify and store llvm
llvm : : legacy : : PassManager pm ;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
pm . add ( llvm : : createVerifierPass ( ) ) ;
pm . run ( * module ) ;
// module->print(llvm::outs(), nullptr);
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
// create machine
module - > setTargetTriple ( triple ) ;
std : : string error ;
llvm : : TargetMachine * machine ;
auto target = llvm : : TargetRegistry : : lookupTarget ( module - > getTargetTriple ( ) , error ) ;
llvm : : TargetOptions opt ;
opt . AllowFPOpFusion = llvm : : FPOpFusion : : Fast ;
opt . UnsafeFPMath = false ;
opt . NoInfsFPMath = false ;
opt . NoNaNsFPMath = true ;
machine = target - > createTargetMachine ( module - > getTargetTriple ( ) , proc , features , opt ,
llvm : : Reloc : : PIC_ , llvm : : None , llvm : : CodeGenOpt : : Aggressive ) ;
// set data layout
if ( layout . empty ( ) )
module - > setDataLayout ( machine - > createDataLayout ( ) ) ;
else
module - > setDataLayout ( layout ) ;
// emit machine code
for ( llvm : : Function & f : module - > functions ( ) )
f . addFnAttr ( llvm : : Attribute : : AlwaysInline ) ;
llvm : : legacy : : PassManager pass ;
llvm : : raw_svector_ostream stream ( buffer ) ;
// emit
machine - > addPassesToEmitFile ( pass , stream , nullptr , llvm : : CodeGenFileType : : CGFT_AssemblyFile ) ;
pass . run ( * module ) ;
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
// post-process
std : : string result ( buffer . begin ( ) , buffer . end ( ) ) ;
find_and_replace ( result , " .version " , " \n " , " .version " + std : : to_string ( ptx_major ) + " . " + std : : to_string ( ptx_minor ) + " \n " ) ;
find_and_replace ( result , " .target " , " \n " , " .target " + sm + " \n " ) ;
while ( find_and_replace ( result , " \t // begin inline asm " , " \n " , " " ) )
;
while ( find_and_replace ( result , " \t // end inline asm " , " \n " , " " ) )
;
return result ;
}
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
std : : string ptx_to_cubin ( const std : : string & ptx , const std : : string & ptxas , int cc )
{
2021-09-09 00:04:28 -07:00
// compile ptx with ptxas
2021-12-07 14:10:58 -08:00
char _fsrc [ L_tmpnam ] ;
char _flog [ L_tmpnam ] ;
std : : tmpnam ( _fsrc ) ;
std : : tmpnam ( _flog ) ;
2021-09-09 00:04:28 -07:00
std : : string fsrc = _fsrc ;
std : : string flog = _flog ;
2021-09-18 22:48:26 -07:00
std : : string fbin = fsrc + " .o " ;
2022-10-12 12:02:30 -07:00
const char * _fbin = fbin . c_str ( ) ;
2021-09-09 00:04:28 -07:00
std : : ofstream ofs ( fsrc ) ;
2022-10-12 12:02:30 -07:00
ofs < < ptx < < std : : endl ;
2021-09-09 00:04:28 -07:00
ofs . close ( ) ;
std : : string cmd ;
int err ;
cmd = ptxas + " -v --gpu-name=sm_ " + std : : to_string ( cc ) + " " + fsrc + " -o " + fsrc + " .o 2> " + flog ;
err = system ( cmd . c_str ( ) ) ;
2022-10-12 12:02:30 -07:00
if ( err ! = 0 )
{
std : : ifstream _log ( _flog ) ;
std : : string log ( std : : istreambuf_iterator < char > ( _log ) , { } ) ;
unlink ( _fsrc ) ;
unlink ( _flog ) ;
throw std : : runtime_error ( " Internal Triton PTX codegen error: \n " + log ) ;
}
std : : ifstream _cubin ( _fbin , std : : ios : : binary ) ;
2021-09-18 22:48:26 -07:00
std : : string cubin ( std : : istreambuf_iterator < char > ( _cubin ) , { } ) ;
_cubin . close ( ) ;
2021-09-09 00:04:28 -07:00
unlink ( _fsrc ) ;
unlink ( _flog ) ;
2021-09-18 22:48:26 -07:00
unlink ( _fbin ) ;
2022-10-12 12:02:30 -07:00
return cubin ;
2021-09-09 00:04:28 -07:00
}
2022-10-12 12:02:30 -07:00
/* ------------------------ */
// HIP //
/* ------------------------ */
2022-10-26 17:18:33 +00:00
std : : tuple < std : : string , std : : string > llir_to_amdgcn ( llvm : : Module * module , const std : : string & _proc )
2022-10-12 12:02:30 -07:00
{
2022-10-26 17:18:33 +00:00
std : : cout < < " llvm.cc: llir_to_amdgcn: " < < std : : endl ;
2022-10-12 12:02:30 -07:00
init_llvm ( ) ;
2022-10-17 18:29:15 +00:00
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
2021-09-09 00:04:28 -07:00
2022-10-17 18:29:15 +00:00
// create
llvm : : SmallVector < char , 0 > buffer ;
std : : string triple = " amdgcn-amd-amdhsa " ;
std : : string layout = " " ;
std : : string features = " +sramecc,-xnack " ;
std : : string proc = STRINGIFY ( MI_GPU_ARCH ) ;
// name kernel
auto in_time_t = std : : chrono : : system_clock : : to_time_t ( std : : chrono : : system_clock : : now ( ) ) ;
std : : stringstream cur_time ;
cur_time < < std : : put_time ( std : : localtime ( & in_time_t ) , " %Y-%m-%d--%I-%M-%S " ) ;
std : : string kernel_name = module - > getModuleIdentifier ( ) + " _ " + cur_time . str ( ) ;
// verify and store llvm
llvm : : legacy : : PassManager pm ;
pm . add ( llvm : : createVerifierPass ( ) ) ;
pm . run ( * module ) ;
// create machine
module - > setTargetTriple ( triple ) ;
std : : string error ;
auto target = llvm : : TargetRegistry : : lookupTarget ( module - > getTargetTriple ( ) , error ) ;
llvm : : TargetOptions opt ;
opt . AllowFPOpFusion = llvm : : FPOpFusion : : Fast ;
opt . UnsafeFPMath = false ;
opt . NoInfsFPMath = false ;
opt . NoNaNsFPMath = true ;
llvm : : TargetMachine * machine = target - > createTargetMachine ( module - > getTargetTriple ( ) , proc , features , opt ,
llvm : : Reloc : : PIC_ , llvm : : None ,
llvm : : CodeGenOpt : : None ) ;
// set data layout
if ( layout . empty ( ) )
module - > setDataLayout ( machine - > createDataLayout ( ) ) ;
else
module - > setDataLayout ( layout ) ;
// emit machine code
for ( llvm : : Function & f : module - > functions ( ) )
f . addFnAttr ( llvm : : Attribute : : AlwaysInline ) ;
llvm : : legacy : : PassManager pass ;
llvm : : raw_svector_ostream stream ( buffer ) ;
2021-09-09 00:04:28 -07:00
2022-10-17 18:29:15 +00:00
// create dump files
std : : error_code ec ;
2021-09-09 00:04:28 -07:00
2022-10-17 18:29:15 +00:00
// Save GCN ISA binary.
std : : string isabin_path = std : : string ( " /tmp/ " ) + kernel_name + std : : string ( " .o " ) ;
std : : unique_ptr < llvm : : raw_fd_ostream > isabin_fs (
new llvm : : raw_fd_ostream ( isabin_path , ec , llvm : : sys : : fs : : OF_Text ) ) ;
if ( ec )
{
std : : cout < < isabin_path < < " was not created. error code: " < < ec < < std : : endl ;
}
2021-09-09 00:04:28 -07:00
2022-10-17 18:29:15 +00:00
// emit
machine - > addPassesToEmitFile ( pass , * isabin_fs , nullptr , llvm : : CGFT_ObjectFile ) ;
pass . run ( * module ) ;
2021-12-13 12:28:15 -08:00
2022-10-26 17:18:33 +00:00
// Save GCN ISA.
2022-10-17 18:29:15 +00:00
llvm : : SmallVector < char , 0 > debugBuffer ;
llvm : : legacy : : PassManager debugPass ;
llvm : : raw_svector_ostream debugStream ( debugBuffer ) ;
machine - > addPassesToEmitFile ( debugPass , debugStream , nullptr , llvm : : CodeGenFileType : : CGFT_AssemblyFile ) ; // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
debugPass . run ( * module ) ;
2022-10-26 17:18:33 +00:00
std : : string amdgcn ( debugBuffer . begin ( ) , debugBuffer . end ( ) ) ;
2021-09-09 00:04:28 -07:00
2022-10-17 18:29:15 +00:00
// generate HASCO file
std : : string hsaco_path = std : : string ( " /tmp/ " ) + kernel_name + std : : string ( " .hsaco " ) ;
std : : string error_message ;
int lld_result =
llvm : : sys : : ExecuteAndWait ( " /opt/rocm/llvm/bin/ld.lld " ,
{ " /opt/rocm/llvm/bin/ld.lld " , " -flavor " , " gnu " , " -shared " , " -o " , hsaco_path , isabin_path } ,
llvm : : None , { } , 0 , 0 , & error_message ) ;
if ( lld_result )
{
std : : cout < < " ld.lld execute fail: " < < std : : endl ;
std : : cout < < error_message < < std : : endl ;
std : : cout < < lld_result < < std : : endl ;
}
2021-09-09 00:04:28 -07:00
2022-10-26 17:18:33 +00:00
return std : : make_tuple ( amdgcn , hsaco_path ) ;
2022-10-12 12:02:30 -07:00
}
2021-09-09 00:04:28 -07:00
2022-10-26 17:18:33 +00:00
hipModule_t amdgpu_to_hipmodule ( const std : : string & hsaco_path )
2022-10-12 12:02:30 -07:00
{
2022-10-26 17:18:33 +00:00
std : : cout < < " llvm.cc: amdgpu_to_hipmodule: " < < std : : endl ;
2022-10-12 12:02:30 -07:00
// Read HSACO.
2022-10-26 17:18:33 +00:00
std : : ifstream hsaco_file ( hsaco_path , std : : ios : : binary | std : : ios : : ate ) ;
2022-10-12 12:02:30 -07:00
std : : ifstream : : pos_type hsaco_file_size = hsaco_file . tellg ( ) ;
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
std : : vector < unsigned char > hsaco ( hsaco_file_size ) ;
hsaco_file . seekg ( 0 , std : : ios : : beg ) ;
hsaco_file . read ( reinterpret_cast < char * > ( & hsaco [ 0 ] ) , hsaco_file_size ) ;
hsaco_file . close ( ) ;
hipJitOption opt [ ] = { hipJitOptionErrorLogBufferSizeBytes , hipJitOptionErrorLogBuffer ,
2021-09-09 00:04:28 -07:00
hipJitOptionInfoLogBufferSizeBytes , hipJitOptionInfoLogBuffer ,
hipJitOptionLogVerbose } ;
2022-10-12 12:02:30 -07:00
const unsigned int errbufsize = 8192 ;
const unsigned int logbufsize = 8192 ;
char _err [ errbufsize ] ;
char _log [ logbufsize ] ;
void * optval [ ] = { ( void * ) ( uintptr_t ) errbufsize ,
( void * ) _err , ( void * ) ( uintptr_t ) logbufsize ,
( void * ) _log , ( void * ) 1 } ;
hipModule_t ret ;
dispatch : : hipModuleLoadDataEx ( & ret , hsaco . data ( ) , 5 , opt , optval ) ;
return ret ;
}
2021-09-09 00:04:28 -07:00
2022-10-12 12:02:30 -07:00
} // namespace driver
} // namespace triton