[general] added simple jit interface
This commit is contained in:
@@ -1,40 +1,8 @@
|
||||
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
|
||||
#include "cuda.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "triton/ast/ast.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shared_copy.h"
|
||||
#include "triton/codegen/allocation.h"
|
||||
#include "triton/codegen/liveness.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/buffer_info.h"
|
||||
#include "triton/codegen/barriers.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopPass.h"
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
using triton::ast::translation_unit;
|
||||
extern translation_unit *ast_root;
|
||||
#include "triton/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
const char* src =
|
||||
R"(
|
||||
@@ -44,7 +12,7 @@ const tunable int32 TK;
|
||||
|
||||
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 M, int32 N, int32 K, int32 bound){
|
||||
int32 rxa[TM] = get_global_range[TM](0)
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
@@ -83,81 +51,6 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
}
|
||||
)";
|
||||
|
||||
// Assembles the LLVM data-layout string handed to llvm::Module::setDataLayout.
//
// is64Bit          - when false, a generic 32-bit pointer spec ("-p:32:32")
//                    is emitted and UseShortPointers is ignored.
// UseShortPointers - when true on a 64-bit target, address spaces 3, 4 and 5
//                    are given 32-bit pointers.
//
// The string always starts with "e" and always ends with the same
// integer/vector/native-width suffix.
static std::string compute_data_layout(bool is64Bit, bool UseShortPointers) {
  // Only the pointer component varies; it stays empty for plain 64-bit.
  const char *pointer_spec = "";
  if (!is64Bit)
    pointer_spec = "-p:32:32";
  else if (UseShortPointers)
    pointer_spec = "-p3:32:32-p4:32:32-p5:32:32";

  std::string layout("e");
  layout.append(pointer_spec);
  layout.append("-i64:64-i128:128-v16:16-v32:32-n16:32:64");
  return layout;
}
|
||||
|
||||
// Lowers `module` to textual assembly for `target_triple` and returns it.
//
// module        - LLVM module to compile; its target triple and data layout
//                 are overwritten with the values given here.
// target_triple - triple used to look up the backend in the target registry.
// data_layout   - data-layout string installed on the module before codegen.
//
// The CPU name is hard-coded to "sm_52" and code generation runs at the
// aggressive optimization level.
//
// Any failure (unknown triple, no target machine, backend unable to emit
// assembly) is reported on stderr and terminates the process, instead of
// dereferencing a null pointer or silently returning an empty string.
static std::string generate_machine_code(llvm::Module &module, const std::string &target_triple, const std::string &data_layout) {
  // Register every backend that was linked in so the lookup below can succeed.
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmParsers();
  llvm::InitializeAllAsmPrinters();

  module.setTargetTriple(target_triple);
  std::string error;
  const llvm::Target *target = llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
  // lookupTarget returns null and fills `error` when the triple is unknown.
  if (target == nullptr) {
    std::cerr << "Target lookup failed for '" << target_triple << "': " << error << std::endl;
    exit(EXIT_FAILURE);
  }

  // Own the target machine so it is released on return. It is declared before
  // the pass manager and therefore outlives the passes that reference it.
  std::unique_ptr<llvm::TargetMachine> machine(
      target->createTargetMachine(module.getTargetTriple(), "sm_52", "",
                                  llvm::TargetOptions(), llvm::Reloc::Model(),
                                  llvm::None, llvm::CodeGenOpt::Aggressive));
  if (!machine) {
    std::cerr << "Could not create a target machine for '" << target_triple << "'" << std::endl;
    exit(EXIT_FAILURE);
  }
  module.setDataLayout(data_layout);

  // emit machine code
  llvm::legacy::PassManager pass;
  llvm::SmallVector<char, 0> buffer;
  llvm::raw_svector_ostream stream(buffer);
  // addPassesToEmitFile returns true when the requested file type cannot be emitted.
  if (machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile)) {
    std::cerr << "Target '" << target_triple << "' cannot emit an assembly file" << std::endl;
    exit(EXIT_FAILURE);
  }
  pass.run(module);
  return std::string(buffer.begin(), buffer.end());
}
|
||||
|
||||
// Terminates the process with a diagnostic on stderr when `err` is anything
// other than CUDA_SUCCESS; returns normally otherwise.
//
// err  - result code of a CUDA driver API call.
// file - source file of the call site, printed in the diagnostic.
// line - source line of the call site, printed in the diagnostic.
static void __checkCudaErrors( CUresult err, const char *file, const int line )
{
  if( err == CUDA_SUCCESS )
    return;
  fprintf(stderr, "CUDA Driver API error = %04d from file <%s>, line %i.\n", err, file, line );
  exit(-1);
}
// Wraps a driver call so the reported location is the call site itself.
#define checkCudaErrors(err) __checkCudaErrors (err, __FILE__, __LINE__)
|
||||
|
||||
// Initializes the CUDA driver on device 0, JIT-loads `src` as a module and
// looks up the kernel called `name`.
//
// device, context, module, function, stream - outputs; filled with the handles
//                 created or looked up here.
// major, minor  - outputs; compute capability reported for device 0.
// src           - module image passed to cuModuleLoadDataEx.
// name          - name of the kernel to fetch from the loaded module.
//
// Every driver failure terminates the process. A JIT failure prints the
// driver's error log first and then exits, so the kernel lookup never runs
// on a module handle that was not set.
static void compile_machine_code(CUdevice &device, CUcontext &context, CUmodule &module,
                                 CUfunction &function, CUstream &stream, int &major, int &minor,
                                 const std::string &src, const std::string &name) {
  int numDevices;

  // Initialize
  checkCudaErrors(cuInit(0));
  checkCudaErrors(cuDeviceGetCount(&numDevices));
  checkCudaErrors(cuDeviceGet(&device, 0));
  checkCudaErrors(cuDeviceComputeCapability(&major, &minor, device));
  checkCudaErrors(cuCtxCreate(&context, 0, device));
  checkCudaErrors(cuStreamCreate(&stream, 0));

  // Compile program
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
  unsigned int errbufsize = 8096;
  // Zero-filled so whatever the driver writes is always NUL-terminated text.
  std::string errbuf(errbufsize, 0);
  // The size option is passed by value, smuggled through the pointer slot.
  void* optval[] = {(void*)(uintptr_t)errbufsize, &errbuf[0]};
  CUresult err = cuModuleLoadDataEx(&module, src.c_str(), 2, opt, optval);
  if(err != CUDA_SUCCESS){
    std::cerr << "Compilation Failed! Log: " << std::endl;
    // c_str() stops at the first NUL instead of dumping the whole padded buffer.
    std::cerr << errbuf.c_str() << std::endl;
    exit(EXIT_FAILURE);
  }

  // Get function
  checkCudaErrors(cuModuleGetFunction(&function, module, name.c_str()));
}
|
||||
|
||||
template<class T>
|
||||
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
|
||||
@@ -170,54 +63,7 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes `f` once for every index tuple of a rectangular loop nest.
//
// ranges - extent of each loop; ranges[0] is the outermost loop and
//          ranges.back() the innermost (fastest-varying) one.
// f      - callback receiving the current index tuple; values[d] is in
//          [0, ranges[d]).
//
// Edge cases:
//  - empty `ranges` is a zero-dimensional nest: `f` is called exactly once
//    with an empty tuple;
//  - if any extent is 0 the iteration space is empty and `f` is never called.
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
  const size_t D = ranges.size();
  // A zero extent would otherwise underflow the wrap-around test below.
  for(size_t extent: ranges)
    if(extent == 0)
      return;
  std::vector<size_t> values(D, 0);
  // No counters to advance: a single call, and no `D - 1` underflow.
  if(D == 0){
    f(values);
    return;
  }
  while(true){
    //Execute function
    f(values);
    //Increment counters, starting from the innermost loop
    size_t i = D - 1;
    while(++values[i] == ranges[i]){
      // The outermost counter wrapped: every tuple has been visited.
      if(i == 0)
        return;
      values[i--] = 0;
    }
  }
}
|
||||
|
||||
int main() {
|
||||
// create AST from Triton-C source
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
translation_unit *program = ast_root;
|
||||
|
||||
// create Triton-IR from AST
|
||||
triton::ir::context context;
|
||||
triton::ir::module module("matrix", context);
|
||||
program->codegen(&module);
|
||||
llvm::LLVMContext llvm_context;
|
||||
llvm::Module llvm_module("matmul", llvm_context);
|
||||
|
||||
|
||||
|
||||
// create passes
|
||||
triton::codegen::buffer_info_pass buffer_info;
|
||||
triton::codegen::place_shared_copy shared(&buffer_info);
|
||||
triton::codegen::tune tune;
|
||||
triton::codegen::liveness liveness(&buffer_info);
|
||||
triton::codegen::allocation allocation(&liveness, &buffer_info);
|
||||
triton::codegen::barriers barriers(&allocation, &buffer_info);
|
||||
triton::codegen::vectorize vectorize(&tune);
|
||||
triton::codegen::selection selection(&allocation, &tune, &buffer_info);
|
||||
|
||||
triton::ir::print(module, std::cout);
|
||||
|
||||
// tuning parameters
|
||||
tune.run(module);
|
||||
std::vector<unsigned> params = {
|
||||
// shapes
|
||||
16, 16, 8,
|
||||
@@ -232,97 +78,49 @@ int main() {
|
||||
// b1
|
||||
1, 8, 1
|
||||
};
|
||||
unsigned TM = params[0];
|
||||
unsigned TN = params[1];
|
||||
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
|
||||
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
|
||||
// meta-parameters
|
||||
unsigned i = 0;
|
||||
context.p_impl->mp_constants_[0]->set_value(params[0]);
|
||||
context.p_impl->mp_constants_[1]->set_value(params[1]);
|
||||
context.p_impl->mp_constants_[2]->set_value(params[2]);
|
||||
for(unsigned *x: tune.get_params(module))
|
||||
*x = params[3 + i++];
|
||||
|
||||
|
||||
|
||||
// constraints
|
||||
std::map<triton::ir::value*, std::vector<std::string>> errors;
|
||||
tune.check_constraints(module, errors);
|
||||
std::cout << "errors: " << errors.size() << std::endl;
|
||||
for(auto &x: errors){
|
||||
for(auto &e: x.second)
|
||||
std::cout << x.first->get_name() << " " << e << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
|
||||
|
||||
// run passes
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
llvm::legacy::PassManager manager;
|
||||
manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createVerifierPass(true));
|
||||
manager.run(llvm_module);
|
||||
|
||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||
std::cout << src << std::endl;
|
||||
|
||||
// compile machine code
|
||||
CUdevice cu_device;
|
||||
CUcontext cu_context;
|
||||
CUmodule cu_module;
|
||||
CUfunction cu_kernel;
|
||||
CUstream cu_stream;
|
||||
int major, minor;
|
||||
compile_machine_code(cu_device, cu_context, cu_module, cu_kernel, cu_stream, major, minor, src, "matmul");
|
||||
|
||||
// execute machine code
|
||||
// Allocate buffers
|
||||
typedef float numeric_t;
|
||||
size_t M = 128, N = 128, K = 128;
|
||||
size_t bound = 8;
|
||||
std::vector<numeric_t> c(M*N);
|
||||
std::vector<numeric_t> rc(M*N);
|
||||
std::vector<numeric_t> a(M*K);
|
||||
std::vector<numeric_t> b(K*N);
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
std::vector<float> hb(K*N);
|
||||
srand(0);
|
||||
for(size_t i = 0; i < a.size(); i++)
|
||||
a[i] = 1;
|
||||
for(size_t i = 0; i < b.size(); i++)
|
||||
b[i] = 1;
|
||||
for(size_t i = 0; i < c.size(); i++)
|
||||
c[i] = 0;
|
||||
CUdeviceptr d_a, d_b, d_c;
|
||||
checkCudaErrors(cuMemAlloc(&d_a, sizeof(numeric_t) * a.size()));
|
||||
checkCudaErrors(cuMemAlloc(&d_b, sizeof(numeric_t) * b.size()));
|
||||
checkCudaErrors(cuMemAlloc(&d_c, sizeof(numeric_t) * c.size()));
|
||||
// Copy buffers
|
||||
checkCudaErrors(cuMemcpyHtoD(d_a, a.data(), sizeof(numeric_t) * a.size()));
|
||||
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
|
||||
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
|
||||
// Launch kernel
|
||||
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
|
||||
int num_regs;
|
||||
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
|
||||
unsigned TM = context.p_impl->mp_constants_[0]->get_value();
|
||||
unsigned TN = context.p_impl->mp_constants_[1]->get_value();
|
||||
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuStreamSynchronize(cu_stream));
|
||||
// Write back
|
||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||
simple_gemm(rc, a, b, M, N, K);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = 1;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = 1;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = 0;
|
||||
triton::driver::buffer dc(context, hc.size()*4);
|
||||
triton::driver::buffer da(context, ha.size()*4);
|
||||
triton::driver::buffer db(context, hb.size()*4);
|
||||
triton::driver::stream stream(context);
|
||||
stream.write(da, true, 0, ha);
|
||||
stream.write(db, true, 0, hb);
|
||||
stream.write(dc, true, 0, hc);
|
||||
kernel.setArg(0, da);
|
||||
kernel.setArg(1, db);
|
||||
kernel.setArg(2, dc);
|
||||
kernel.setArg(3, M);
|
||||
kernel.setArg(4, N);
|
||||
kernel.setArg(5, K);
|
||||
kernel.setArg(6, bound);
|
||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
||||
stream.synchronize();
|
||||
stream.read(dc, true, 0, hc);
|
||||
simple_gemm(rc, ha, hb, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
|
@@ -33,13 +33,13 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Buffer;
|
||||
class Stream;
|
||||
class Device;
|
||||
class Context;
|
||||
class Platform;
|
||||
class Module;
|
||||
class Kernel;
|
||||
class buffer;
|
||||
class stream;
|
||||
class device;
|
||||
class context;
|
||||
class platform;
|
||||
class module;
|
||||
class kernel;
|
||||
|
||||
struct backend
|
||||
{
|
||||
@@ -49,9 +49,9 @@ struct backend
|
||||
friend class backend;
|
||||
public:
|
||||
static void release();
|
||||
static Module& get(Stream const & stream, std::string const & name, std::string const &src);
|
||||
static module& get(driver::stream const & stream, std::string const & name, std::string const &src);
|
||||
private:
|
||||
static std::map<std::tuple<Stream, std::string>, Module * > cache_;
|
||||
static std::map<std::tuple<stream, std::string>, module * > cache_;
|
||||
};
|
||||
|
||||
class kernels
|
||||
@@ -59,53 +59,53 @@ struct backend
|
||||
friend class backend;
|
||||
public:
|
||||
static void release();
|
||||
static Kernel & get(Module const & program, std::string const & name);
|
||||
static kernel & get(driver::module const & program, std::string const & name);
|
||||
private:
|
||||
static std::map<std::tuple<Module, std::string>, Kernel * > cache_;
|
||||
static std::map<std::tuple<module, std::string>, kernel * > cache_;
|
||||
};
|
||||
|
||||
class contexts
|
||||
{
|
||||
friend class backend;
|
||||
private:
|
||||
static void init(std::vector<Platform> const &);
|
||||
static void init(std::vector<platform> const &);
|
||||
static void release();
|
||||
public:
|
||||
static Context const & get_default();
|
||||
static driver::context const & get_default();
|
||||
template<class T>
|
||||
static Context const & import(T context)
|
||||
static driver::context const & import(T ctx)
|
||||
{
|
||||
for(driver::Context const * x: cache_)
|
||||
if((T)*x==context)
|
||||
for(driver::context const * x: cache_)
|
||||
if((T)*x==ctx)
|
||||
return *x;
|
||||
cache_.emplace_back(new Context(context, false));
|
||||
cache_.emplace_back(new driver::context(ctx, false));
|
||||
return *cache_.back();
|
||||
}
|
||||
static void get(std::list<Context const *> &);
|
||||
static void get(std::list<context const *> &);
|
||||
private:
|
||||
static std::list<Context const *> cache_;
|
||||
static std::list<context const *> cache_;
|
||||
};
|
||||
|
||||
class streams
|
||||
{
|
||||
friend class backend;
|
||||
private:
|
||||
static void init(std::list<Context const *> const &);
|
||||
static void init(std::list<context const *> const &);
|
||||
static void release();
|
||||
public:
|
||||
static void get(Context const &, std::vector<Stream *> &streams);
|
||||
static Stream & get(Context const &, unsigned int id = 0);
|
||||
static Stream & get_default();
|
||||
static void get(driver::context const &, std::vector<stream *> &streams);
|
||||
static stream & get(driver::context const &, unsigned int id = 0);
|
||||
static stream & get_default();
|
||||
private:
|
||||
static std::map< Context, std::vector<Stream*> > cache_;
|
||||
static std::map< context, std::vector<stream*> > cache_;
|
||||
};
|
||||
|
||||
static void init();
|
||||
static void release();
|
||||
|
||||
static std::vector<Device> devices();
|
||||
static std::vector<Platform> platforms();
|
||||
static void synchronize(Context const &);
|
||||
static std::vector<device> devices();
|
||||
static std::vector<platform> platforms();
|
||||
static void synchronize(driver::context const &);
|
||||
|
||||
static unsigned int default_device;
|
||||
};
|
||||
|
@@ -31,21 +31,21 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Stream;
|
||||
class stream;
|
||||
|
||||
// Buffer
|
||||
class Buffer: public HandleInterface<Buffer, CUdeviceptr>
|
||||
class buffer: public handle_interface<buffer, CUdeviceptr>
|
||||
{
|
||||
public:
|
||||
Buffer(Context const & context, size_t size);
|
||||
Buffer(Context const & context, CUdeviceptr cu, bool take_ownership);
|
||||
void set_zero(Stream const & queue, size_t size);
|
||||
Handle<CUdeviceptr> const & cu() const;
|
||||
Handle<CUdeviceptr> & cu();
|
||||
buffer(driver::context const & context, size_t size);
|
||||
buffer(driver::context const & context, CUdeviceptr cu, bool take_ownership);
|
||||
void set_zero(stream const & queue, size_t size);
|
||||
handle<CUdeviceptr> const & cu() const;
|
||||
handle<CUdeviceptr> & cu();
|
||||
|
||||
private:
|
||||
Context context_;
|
||||
Handle<CUdeviceptr> cu_;
|
||||
context context_;
|
||||
handle<CUdeviceptr> cu_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -31,7 +31,7 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Context: public HandleInterface<Context, CUcontext>
|
||||
class context: public handle_interface<context, CUcontext>
|
||||
{
|
||||
private:
|
||||
static std::string get_cache_path();
|
||||
@@ -39,25 +39,25 @@ private:
|
||||
|
||||
public:
|
||||
//Constructors
|
||||
explicit Context(CUcontext context, bool take_ownership = true);
|
||||
explicit Context(Device const & device);
|
||||
explicit context(CUcontext context, bool take_ownership = true);
|
||||
explicit context(driver::device const & dvc);
|
||||
//Accessors
|
||||
Device const & device() const;
|
||||
driver::device const & device() const;
|
||||
std::string const & cache_path() const;
|
||||
Handle<CUcontext> const & cu() const;
|
||||
handle<CUcontext> const & cu() const;
|
||||
|
||||
private:
|
||||
Handle<CUcontext> cu_;
|
||||
Device device_;
|
||||
handle<CUcontext> cu_;
|
||||
driver::device dvc_;
|
||||
std::string cache_path_;
|
||||
};
|
||||
|
||||
class ContextSwitcher{
|
||||
public:
|
||||
ContextSwitcher(Context const & ctx);
|
||||
ContextSwitcher(driver::context const & ctx);
|
||||
~ContextSwitcher();
|
||||
private:
|
||||
Context const & ctx_;
|
||||
driver::context const & ctx_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -51,7 +51,7 @@ static const std::vector<cublasGemmAlgo_t> cublasAlgorithms = {
|
||||
static const std::map<DType, cudaDataType> cudtype = {{FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE,CUDA_R_64F}};
|
||||
static const std::map<char, cublasOperation_t> cuop = {{'N', CUBLAS_OP_N}, {'T', CUBLAS_OP_T}};
|
||||
|
||||
inline cublasGemmAlgo_t cublasGemmFastest(Stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
|
||||
inline cublasGemmAlgo_t cublasGemmFastest(stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc){
|
||||
|
||||
@@ -84,7 +84,7 @@ inline void cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperati
|
||||
|
||||
|
||||
/* Simplified API for default GEMM */
|
||||
inline void cublasGemm(DType dtype, Stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, scalar beta, Buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
|
||||
inline void cublasGemm(DType dtype, stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, buffer const & A, int32_t lda, buffer const & B, int32_t ldb, scalar beta, buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
|
||||
ContextSwitcher ctx_switch(stream.context());
|
||||
cublasHandle_t handle = dispatch::cublasHandle(stream.context());
|
||||
dispatch::cublasSetStream_v2(handle, (CUstream)stream);
|
||||
@@ -111,9 +111,9 @@ inline cudnnTensorFormat_t format(cudnnDataType_t cutype){
|
||||
}
|
||||
}
|
||||
|
||||
inline void cudnnConv(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, Buffer const & I, Buffer const & F, scalar beta, Buffer const & O){
|
||||
driver::Context const & ctx = stream.context();
|
||||
inline void cudnnConv(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, buffer const & I, buffer const & F, scalar beta, buffer const & O){
|
||||
driver::driver::context const & ctx = stream.context();
|
||||
ContextSwitcher switch_ctx(ctx);
|
||||
|
||||
std::vector<int> pad = {pad_d, pad_h, pad_w};
|
||||
@@ -154,16 +154,16 @@ inline void cudnnConv(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t
|
||||
|
||||
size_t workspace_size;
|
||||
dispatch::cudnnGetConvolutionForwardWorkspaceSize(handle, tI, tF, conv, tO, algo, &workspace_size);
|
||||
static Buffer work(ctx, 1024*1024*64);
|
||||
static buffer work(ctx, 1024*1024*64);
|
||||
CUdeviceptr twork = work;
|
||||
CUdeviceptr pI = I, pF = F, pO = O;
|
||||
dispatch::cudnnConvolutionForward(handle, alpha.data(), tI, (void*)pI, tF, (void*)pF, conv, algo, (void*)twork, workspace_size, beta.data(), tO, (void*)pO);
|
||||
}
|
||||
|
||||
|
||||
inline void cudnnPool(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, Buffer const & I, scalar beta, Buffer const & O){
|
||||
driver::Context const & ctx = stream.context();
|
||||
inline void cudnnPool(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, buffer const & I, scalar beta, buffer const & O){
|
||||
driver::driver::context const & ctx = stream.context();
|
||||
ContextSwitcher switch_ctx(ctx);
|
||||
|
||||
std::vector<int> pad = {pad_d, pad_h, pad_w};
|
||||
@@ -200,11 +200,11 @@ inline void cudnnPool(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t
|
||||
dispatch::cudnnPoolingForward(handle, desc, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
|
||||
}
|
||||
|
||||
inline void cudnnTransformTensor(driver::Stream & stream,
|
||||
inline void cudnnTransformTensor(driver::stream & stream,
|
||||
DType in_dtype, DType out_dtype,
|
||||
cudnnTensorFormat_t in_layout, cudnnTensorFormat_t out_layout,
|
||||
int32_t N, int32_t C, int32_t D, int32_t H, int32_t W,
|
||||
scalar alpha, driver::Buffer const & I, scalar beta, driver::Buffer& O)
|
||||
scalar alpha, driver::buffer const & I, scalar beta, driver::buffer& O)
|
||||
{
|
||||
cudnnHandle_t handle = dispatch::cudnnHandle(stream.context());
|
||||
dispatch::cudnnSetStream(handle, (CUstream)stream);
|
||||
|
@@ -33,7 +33,7 @@ namespace driver
|
||||
{
|
||||
|
||||
// Device
|
||||
class Device: public HandleInterface<Device, CUdevice>
|
||||
class device: public handle_interface<device, CUdevice>
|
||||
{
|
||||
public:
|
||||
//Supported architectures
|
||||
@@ -61,14 +61,14 @@ private:
|
||||
inline nvmlDevice_t nvml_device() const;
|
||||
|
||||
public:
|
||||
Device(CUdevice cu = CUdevice(), bool take_ownership = true): cu_(cu, take_ownership){}
|
||||
device(CUdevice cu = CUdevice(), bool take_ownership = true): cu_(cu, take_ownership){}
|
||||
//Accessors
|
||||
Architecture architecture() const;
|
||||
Handle<CUdevice> const & cu() const;
|
||||
handle<CUdevice> const & cu() const;
|
||||
//Informations
|
||||
std::string infos() const;
|
||||
size_t address_bits() const;
|
||||
driver::Platform platform() const;
|
||||
driver::platform platform() const;
|
||||
std::vector<size_t> max_block_dim() const;
|
||||
size_t max_threads_per_block() const;
|
||||
size_t max_shared_memory() const;
|
||||
@@ -87,7 +87,7 @@ public:
|
||||
size_t max_mem_clock() const;
|
||||
|
||||
private:
|
||||
Handle<CUdevice> cu_;
|
||||
handle<CUdevice> cu_;
|
||||
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
|
||||
};
|
||||
|
||||
|
@@ -42,7 +42,7 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Context;
|
||||
class context;
|
||||
|
||||
template<class T> void check(T){}
|
||||
void check(nvrtcResult err);
|
||||
@@ -137,7 +137,7 @@ public:
|
||||
static nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog, const char *src, const char *name, int numHeaders, const char **headers, const char **includeNames);
|
||||
static nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
|
||||
|
||||
static cublasHandle_t cublasHandle(Context const & ctx);
|
||||
static cublasHandle_t cublasHandle(driver::context const & ctx);
|
||||
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
|
||||
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
|
||||
static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId);
|
||||
@@ -146,7 +146,7 @@ public:
|
||||
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
|
||||
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);
|
||||
|
||||
static cudnnHandle_t cudnnHandle(Context const & ctx);
|
||||
static cudnnHandle_t cudnnHandle(driver::context const & ctx);
|
||||
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
||||
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
|
||||
static cudnnStatus_t cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);
|
||||
|
@@ -32,14 +32,14 @@ namespace driver
|
||||
{
|
||||
|
||||
// Event
|
||||
class Event: public HandleInterface<Event, cu_event_t>
|
||||
class Event: public handle_interface<Event, cu_event_t>
|
||||
{
|
||||
public:
|
||||
float elapsed_time() const;
|
||||
Handle<cu_event_t> const & cu() const;
|
||||
handle<cu_event_t> const & cu() const;
|
||||
|
||||
private:
|
||||
Handle<cu_event_t> cu_;
|
||||
handle<cu_event_t> cu_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -49,24 +49,24 @@ private:
|
||||
};
|
||||
|
||||
template<class T, class CUType>
|
||||
class HandleInterface{
|
||||
class handle_interface{
|
||||
public:
|
||||
//Accessors
|
||||
operator CUType() const { return *(((T*)this)->cu().h_); }
|
||||
//Comparison
|
||||
bool operator==(HandleInterface const & y) { return (CUType)(*this) == (CUType)(y); }
|
||||
bool operator!=(HandleInterface const & y) { return (CUType)(*this) != (CUType)(y); }
|
||||
bool operator<(HandleInterface const & y) { return (CUType)(*this) < (CUType)(y); }
|
||||
bool operator==(handle_interface const & y) { return (CUType)(*this) == (CUType)(y); }
|
||||
bool operator!=(handle_interface const & y) { return (CUType)(*this) != (CUType)(y); }
|
||||
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
|
||||
};
|
||||
|
||||
template<class CUType>
|
||||
class Handle{
|
||||
class handle{
|
||||
public:
|
||||
template<class, class> friend class HandleInterface;
|
||||
template<class, class> friend class handle_interface;
|
||||
public:
|
||||
//Constructors
|
||||
Handle(CUType cu = CUType(), bool take_ownership = true);
|
||||
~Handle();
|
||||
handle(CUType cu = CUType(), bool take_ownership = true);
|
||||
~handle();
|
||||
CUType& operator*() { return *h_; }
|
||||
CUType const & operator*() const { return *h_; }
|
||||
CUType* operator->() const { return h_.get(); }
|
||||
|
@@ -34,27 +34,27 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Buffer;
|
||||
class buffer;
|
||||
|
||||
// Kernel
|
||||
class Kernel: public HandleInterface<Kernel, CUfunction>
|
||||
class kernel: public handle_interface<kernel, CUfunction>
|
||||
{
|
||||
public:
|
||||
//Constructors
|
||||
Kernel(Module const & program, const char * name);
|
||||
kernel(driver::module const & program, const char * name);
|
||||
//Accessors
|
||||
Handle<CUfunction> const & cu() const;
|
||||
Module const & module() const;
|
||||
handle<CUfunction> const & cu() const;
|
||||
driver::module const & module() const;
|
||||
//Arguments setters
|
||||
void setArg(unsigned int index, std::size_t size, void* ptr);
|
||||
void setArg(unsigned int index, Buffer const &);
|
||||
void setArg(unsigned int index, buffer const &);
|
||||
template<class T> void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); }
|
||||
//Arguments getters
|
||||
void* const* cu_params() const;
|
||||
|
||||
private:
|
||||
Handle<CUfunction> cu_;
|
||||
Module program_;
|
||||
handle<CUfunction> cu_;
|
||||
driver::module program_;
|
||||
unsigned int address_bits_;
|
||||
std::vector<std::shared_ptr<void> > cu_params_store_;
|
||||
std::vector<void*> cu_params_;
|
||||
|
@@ -34,22 +34,22 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Context;
|
||||
class Device;
|
||||
class context;
|
||||
class device;
|
||||
|
||||
class Module: public HandleInterface<Module, CUmodule>
|
||||
class module: public handle_interface<module, CUmodule>
|
||||
{
|
||||
static std::string header(Device const & device);
|
||||
static std::string header(device const & device);
|
||||
|
||||
public:
|
||||
Module(Context const & context, std::string const & source);
|
||||
Context const & context() const;
|
||||
Handle<CUmodule> const & cu() const;
|
||||
Buffer symbol(const char * name) const;
|
||||
module(driver::context const & context, std::string const & source);
|
||||
driver::context const & context() const;
|
||||
handle<CUmodule> const & cu() const;
|
||||
buffer symbol(const char * name) const;
|
||||
|
||||
private:
|
||||
Handle<CUmodule> cu_;
|
||||
Context context_;
|
||||
handle<CUmodule> cu_;
|
||||
driver::context context_;
|
||||
std::string source_;
|
||||
};
|
||||
|
||||
|
@@ -34,17 +34,17 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Device;
|
||||
class device;
|
||||
|
||||
class Platform
|
||||
class platform
|
||||
{
|
||||
public:
|
||||
//Accessors
|
||||
std::string name() const { return "CUDA"; }
|
||||
std::string version() const;
|
||||
std::vector<Device> devices() const;
|
||||
std::vector<device> devices() const;
|
||||
private:
|
||||
Handle<cu_platform> cu_;
|
||||
handle<cu_platform> cu_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -35,43 +35,43 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class Kernel;
|
||||
class kernel;
|
||||
class Event;
|
||||
class Range;
|
||||
class Buffer;
|
||||
class buffer;
|
||||
|
||||
// Command Queue
|
||||
class Stream: public HandleInterface<Stream, CUstream>
|
||||
class stream: public handle_interface<stream, CUstream>
|
||||
{
|
||||
public:
|
||||
//Constructors
|
||||
Stream(CUstream stream, bool take_ownership);
|
||||
Stream(Context const & context);
|
||||
stream(CUstream stream, bool take_ownership);
|
||||
stream(driver::context const & context);
|
||||
|
||||
//Accessors
|
||||
Handle<CUstream> const & cu() const;
|
||||
Context const & context() const;
|
||||
handle<CUstream> const & cu() const;
|
||||
driver::context const & context() const;
|
||||
|
||||
//Synchronize
|
||||
void synchronize();
|
||||
|
||||
//Enqueue
|
||||
void enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL);
|
||||
void enqueue(kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL);
|
||||
|
||||
// Write
|
||||
void write(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void write(driver::buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
|
||||
template<class T> void write(Buffer const & buffer, bool blocking, std::size_t offset, std::vector<T> const & x)
|
||||
template<class T> void write(driver::buffer const & buffer, bool blocking, std::size_t offset, std::vector<T> const & x)
|
||||
{ write(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
|
||||
// Read
|
||||
void read(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
void read(driver::buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
|
||||
template<class T> void read(Buffer const & buffer, bool blocking, std::size_t offset, std::vector<T>& x)
|
||||
template<class T> void read(driver::buffer const & buffer, bool blocking, std::size_t offset, std::vector<T>& x)
|
||||
{ read(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
private:
|
||||
Context context_;
|
||||
Handle<CUstream> cu_;
|
||||
driver::context context_;
|
||||
handle<CUstream> cu_;
|
||||
};
|
||||
|
||||
|
||||
|
45
include/triton/jit.h
Normal file
45
include/triton/jit.h
Normal file
@@ -0,0 +1,45 @@
|
||||
#ifndef TDL_INCLUDE_JIT_H
|
||||
#define TDL_INCLUDE_JIT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class context;
|
||||
}
|
||||
|
||||
class jit {
|
||||
private:
|
||||
void init_llvm();
|
||||
std::string compute_data_layout(bool is64Bit = true, bool UseShortPointers = true);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, const std::vector<unsigned>& params);
|
||||
std::unique_ptr<ir::module> make_triton_module(const std::string &src);
|
||||
|
||||
public:
|
||||
jit(driver::context context);
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel get_function(const std::string &name);
|
||||
|
||||
private:
|
||||
std::vector<driver::module> modules_;
|
||||
driver::context driver_context_;
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -47,14 +47,14 @@ void backend::modules::release(){
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
Module& backend::modules::get(Stream const & stream, std::string const & name, std::string const & src){
|
||||
std::tuple<Stream, std::string> key(stream, name);
|
||||
module& backend::modules::get(driver::stream const & stream, std::string const & name, std::string const & src){
|
||||
std::tuple<driver::stream, std::string> key(stream, name);
|
||||
if(cache_.find(key)==cache_.end())
|
||||
return *cache_.insert(std::make_pair(key, new Module(stream.context(), src))).first->second;
|
||||
return *cache_.insert(std::make_pair(key, new module(stream.context(), src))).first->second;
|
||||
return *cache_.at(key);
|
||||
}
|
||||
|
||||
std::map<std::tuple<Stream, std::string>, Module * > backend::modules::cache_;
|
||||
std::map<std::tuple<stream, std::string>, module * > backend::modules::cache_;
|
||||
|
||||
/*-----------------------------------*/
|
||||
//----------- Kernels --------------*/
|
||||
@@ -66,23 +66,23 @@ void backend::kernels::release(){
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
Kernel & backend::kernels::get(Module const & program, std::string const & name){
|
||||
std::tuple<Module, std::string> key(program, name);
|
||||
kernel & backend::kernels::get(driver::module const & program, std::string const & name){
|
||||
std::tuple<module, std::string> key(program, name);
|
||||
if(cache_.find(key)==cache_.end())
|
||||
return *cache_.insert(std::make_pair(key, new Kernel(program, name.c_str()))).first->second;
|
||||
return *cache_.insert(std::make_pair(key, new kernel(program, name.c_str()))).first->second;
|
||||
return *cache_.at(key);
|
||||
}
|
||||
|
||||
std::map<std::tuple<Module, std::string>, Kernel * > backend::kernels::cache_;
|
||||
std::map<std::tuple<module, std::string>, kernel * > backend::kernels::cache_;
|
||||
|
||||
/*-----------------------------------*/
|
||||
//------------ Queues --------------*/
|
||||
/*-----------------------------------*/
|
||||
|
||||
void backend::streams::init(std::list<const Context *> const & contexts){
|
||||
for(Context const * ctx : contexts)
|
||||
void backend::streams::init(std::list<const context *> const & contexts){
|
||||
for(context const * ctx : contexts)
|
||||
if(cache_.find(*ctx)==cache_.end())
|
||||
cache_.insert(std::make_pair(*ctx, std::vector<Stream*>{new Stream(*ctx)}));
|
||||
cache_.insert(std::make_pair(*ctx, std::vector<stream*>{new stream(*ctx)}));
|
||||
}
|
||||
|
||||
void backend::streams::release(){
|
||||
@@ -92,32 +92,32 @@ void backend::streams::release(){
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
Stream & backend::streams::get_default()
|
||||
stream & backend::streams::get_default()
|
||||
{ return get(contexts::get_default(), 0); }
|
||||
|
||||
Stream & backend::streams::get(Context const & context, unsigned int id){
|
||||
init(std::list<Context const *>(1,&context));
|
||||
// Returns the `id`-th cached stream of `context`, creating the context's
// default stream first if the cache has no entry for it yet.
// Throws std::out_of_range when `id` is past the end of the context's streams
// and std::runtime_error when no cache entry compares equal to `context`.
stream & backend::streams::get(driver::context const & context, unsigned int id){
  init(std::list<driver::context const *>(1,&context));
  for(auto & x : cache_)
    if(x.first==context)
      return *x.second.at(id);   // bounds-checked; the original indexed blindly
  // A bare `throw;` here has no active exception and would call std::terminate.
  throw std::runtime_error("ISAAC: no stream cached for the requested context");
}
|
||||
|
||||
void backend::streams::get(Context const & context, std::vector<Stream*> & queues){
|
||||
init(std::list<Context const *>(1,&context));
|
||||
void backend::streams::get(driver::context const & context, std::vector<stream*> & queues){
|
||||
init(std::list<driver::context const *>(1,&context));
|
||||
queues = cache_.at(context);
|
||||
}
|
||||
|
||||
std::map<Context, std::vector<Stream*> > backend::streams::cache_;
|
||||
std::map<context, std::vector<stream*> > backend::streams::cache_;
|
||||
|
||||
/*-----------------------------------*/
|
||||
//------------ Contexts ------------*/
|
||||
/*-----------------------------------*/
|
||||
|
||||
void backend::contexts::init(std::vector<Platform> const & platforms){
|
||||
for(Platform const & platform: platforms){
|
||||
for(Device const & device: platform.devices())
|
||||
cache_.push_back(new Context(device));
|
||||
void backend::contexts::init(std::vector<platform> const & platforms){
|
||||
for(platform const & platform: platforms){
|
||||
for(device const & device: platform.devices())
|
||||
cache_.push_back(new context(device));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,19 +127,19 @@ void backend::contexts::release(){
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
Context const & backend::contexts::get_default(){
|
||||
driver::context const & backend::contexts::get_default(){
|
||||
backend::init();
|
||||
std::list<Context const *>::const_iterator it = cache_.begin();
|
||||
std::list<context const *>::const_iterator it = cache_.begin();
|
||||
std::advance(it, default_device);
|
||||
return **it;
|
||||
}
|
||||
|
||||
void backend::contexts::get(std::list<Context const *> & contexts){
|
||||
void backend::contexts::get(std::list<context const *> & contexts){
|
||||
backend::init();
|
||||
contexts = cache_;
|
||||
}
|
||||
|
||||
std::list<Context const *> backend::contexts::cache_;
|
||||
std::list<context const *> backend::contexts::cache_;
|
||||
|
||||
|
||||
|
||||
@@ -147,28 +147,28 @@ std::list<Context const *> backend::contexts::cache_;
|
||||
//------------ General -------------*/
|
||||
/*-----------------------------------*/
|
||||
|
||||
std::vector<Device> backend::devices(){
|
||||
std::vector<Platform> platforms = backend::platforms();
|
||||
std::vector<Device> result;
|
||||
for(Platform const & platform: platforms){
|
||||
std::vector<device> backend::devices(){
|
||||
std::vector<platform> platforms = backend::platforms();
|
||||
std::vector<device> result;
|
||||
for(platform const & platform: platforms){
|
||||
auto devices = platform.devices();
|
||||
result.insert(result.end(), devices.begin(), devices.end());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<Platform> backend::platforms(){
|
||||
std::vector<Platform> platforms;
|
||||
std::vector<platform> backend::platforms(){
|
||||
std::vector<platform> platforms;
|
||||
//if CUDA is here
|
||||
if(dispatch::cuinit())
|
||||
platforms.push_back(Platform());
|
||||
platforms.push_back(platform());
|
||||
if(platforms.empty())
|
||||
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
|
||||
return platforms;
|
||||
}
|
||||
|
||||
void backend::synchronize(Context const & context){
|
||||
for(Stream * queue: streams::cache_.at(context))
|
||||
void backend::synchronize(driver::context const & context){
|
||||
for(stream * queue: streams::cache_.at(context))
|
||||
queue->synchronize();
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ void backend::release(){
|
||||
void backend::init(){
|
||||
if(!contexts::cache_.empty())
|
||||
return;
|
||||
std::vector<Platform> platforms = backend::platforms();
|
||||
std::vector<platform> platforms = backend::platforms();
|
||||
contexts::init(platforms);
|
||||
streams::init(contexts::cache_);
|
||||
}
|
||||
|
@@ -33,26 +33,26 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
Buffer::Buffer(Context const & context, size_t size) : context_(context)
|
||||
buffer::buffer(driver::context const & context, size_t size) : context_(context)
|
||||
{
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
dispatch::cuMemAlloc(&*cu_, size);
|
||||
}
|
||||
|
||||
Buffer::Buffer(Context const & context, CUdeviceptr cu, bool take_ownership):
|
||||
buffer::buffer(driver::context const & context, CUdeviceptr cu, bool take_ownership):
|
||||
context_(context), cu_(cu, take_ownership)
|
||||
{ }
|
||||
|
||||
void Buffer::set_zero(Stream const & queue, size_t size)
|
||||
void buffer::set_zero(stream const & queue, size_t size)
|
||||
{
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
dispatch::cuMemsetD8Async(*cu_, 0, size, queue);
|
||||
}
|
||||
|
||||
Handle<CUdeviceptr> const & Buffer::cu() const
|
||||
handle<CUdeviceptr> const & buffer::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
Handle<CUdeviceptr> & Buffer::cu()
|
||||
handle<CUdeviceptr> & buffer::cu()
|
||||
{ return cu_; }
|
||||
|
||||
}
|
||||
|
@@ -35,7 +35,7 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
std::string Context::get_cache_path(){
|
||||
std::string context::get_cache_path(){
|
||||
//user-specified cache path
|
||||
std::string result = tools::getenv("ISAAC_CACHE_PATH");
|
||||
if(!result.empty()){
|
||||
@@ -54,7 +54,7 @@ std::string Context::get_cache_path(){
|
||||
return "";
|
||||
}
|
||||
|
||||
CUdevice Context::device(CUcontext context){
|
||||
CUdevice context::device(CUcontext context){
|
||||
dispatch::cuCtxPushCurrent_v2(context);
|
||||
CUdevice res;
|
||||
dispatch::cuCtxGetDevice(&res);
|
||||
@@ -62,26 +62,26 @@ CUdevice Context::device(CUcontext context){
|
||||
return res;
|
||||
}
|
||||
|
||||
Context::Context(CUcontext context, bool take_ownership): cu_(context, take_ownership), device_(device(context), false), cache_path_(get_cache_path())
|
||||
context::context(CUcontext context, bool take_ownership): cu_(context, take_ownership), dvc_(device(context), false), cache_path_(get_cache_path())
|
||||
{ }
|
||||
|
||||
Context::Context(Device const & device): device_(device), cache_path_(get_cache_path())
|
||||
context::context(driver::device const & device): dvc_(device), cache_path_(get_cache_path())
|
||||
{
|
||||
dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, (CUdevice)device);
|
||||
dispatch::cuCtxPopCurrent_v2(NULL);
|
||||
}
|
||||
|
||||
Device const & Context::device() const
|
||||
{ return device_; }
|
||||
device const & context::device() const
|
||||
{ return dvc_; }
|
||||
|
||||
std::string const & Context::cache_path() const
|
||||
std::string const & context::cache_path() const
|
||||
{ return cache_path_; }
|
||||
|
||||
Handle<CUcontext> const & Context::cu() const
|
||||
handle<CUcontext> const & context::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
/* Context Switcher */
|
||||
ContextSwitcher::ContextSwitcher(Context const & ctx): ctx_(ctx)
|
||||
ContextSwitcher::ContextSwitcher(driver::context const & ctx): ctx_(ctx)
|
||||
{
|
||||
dispatch::cuCtxPushCurrent_v2(ctx_);
|
||||
}
|
||||
|
@@ -35,7 +35,7 @@ namespace driver
|
||||
{
|
||||
|
||||
/* Architecture [NVidia] */
|
||||
Device::Architecture Device::nv_arch(std::pair<unsigned int, unsigned int> sm) const{
|
||||
device::Architecture device::nv_arch(std::pair<unsigned int, unsigned int> sm) const{
|
||||
switch(sm.first)
|
||||
{
|
||||
case 7:
|
||||
@@ -81,13 +81,13 @@ Device::Architecture Device::nv_arch(std::pair<unsigned int, unsigned int> sm) c
|
||||
}
|
||||
|
||||
template<CUdevice_attribute attr>
|
||||
int Device::cuGetInfo() const{
|
||||
int device::cuGetInfo() const{
|
||||
int res;
|
||||
dispatch::cuDeviceGetAttribute(&res, attr, *cu_);
|
||||
return res;
|
||||
}
|
||||
|
||||
nvmlDevice_t Device::nvml_device() const{
|
||||
nvmlDevice_t device::nvml_device() const{
|
||||
std::map<std::string, nvmlDevice_t> map;
|
||||
std::string key = pci_bus_id();
|
||||
if(map.find(key)==map.end()){
|
||||
@@ -99,33 +99,33 @@ nvmlDevice_t Device::nvml_device() const{
|
||||
}
|
||||
|
||||
/* Architecture */
|
||||
Device::Architecture Device::architecture() const
|
||||
device::Architecture device::architecture() const
|
||||
{ return nv_arch(compute_capability()); }
|
||||
|
||||
/* Attributes */
|
||||
size_t Device::address_bits() const
|
||||
size_t device::address_bits() const
|
||||
{ return sizeof(size_t)*8; }
|
||||
|
||||
driver::Platform Device::platform() const
|
||||
{ return Platform(); }
|
||||
// Returns the platform this device belongs to.
// The type must be namespace-qualified: inside this member function an
// unqualified `platform()` names device::platform itself and would recurse forever.
driver::platform device::platform() const
{ return driver::platform(); }
|
||||
|
||||
std::string Device::name() const{
|
||||
std::string device::name() const{
|
||||
char tmp[128];
|
||||
dispatch::cuDeviceGetName(tmp, 128, *cu_);
|
||||
return std::string(tmp);
|
||||
}
|
||||
|
||||
std::string Device::pci_bus_id() const{
|
||||
std::string device::pci_bus_id() const{
|
||||
char tmp[128];
|
||||
dispatch::cuDeviceGetPCIBusId(tmp, 128, *cu_);
|
||||
return std::string(tmp);
|
||||
}
|
||||
|
||||
void Device::interpret_as(std::pair<size_t, size_t> cc){
|
||||
void device::interpret_as(std::pair<size_t, size_t> cc){
|
||||
interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> Device::compute_capability() const{
|
||||
std::pair<size_t, size_t> device::compute_capability() const{
|
||||
if(interpreted_as_)
|
||||
return *interpreted_as_;
|
||||
size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
|
||||
@@ -133,17 +133,17 @@ std::pair<size_t, size_t> Device::compute_capability() const{
|
||||
return std::make_pair(_major, _minor);
|
||||
}
|
||||
|
||||
size_t Device::max_threads_per_block() const
|
||||
size_t device::max_threads_per_block() const
|
||||
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK>(); }
|
||||
|
||||
size_t Device::max_shared_memory() const
|
||||
size_t device::max_shared_memory() const
|
||||
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK>(); }
|
||||
|
||||
size_t Device::warp_size() const
|
||||
size_t device::warp_size() const
|
||||
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_WARP_SIZE>(); }
|
||||
|
||||
|
||||
std::vector<size_t> Device::max_block_dim() const{
|
||||
std::vector<size_t> device::max_block_dim() const{
|
||||
std::vector<size_t> result(3);
|
||||
result[0] = cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X>();
|
||||
result[1] = cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y>();
|
||||
@@ -151,33 +151,33 @@ std::vector<size_t> Device::max_block_dim() const{
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t Device::current_sm_clock() const{
|
||||
size_t device::current_sm_clock() const{
|
||||
unsigned int result;
|
||||
dispatch::nvmlDeviceGetClockInfo(nvml_device(), NVML_CLOCK_SM, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t Device::max_sm_clock() const{
|
||||
size_t device::max_sm_clock() const{
|
||||
unsigned int result;
|
||||
dispatch::nvmlDeviceGetMaxClockInfo(nvml_device(), NVML_CLOCK_SM, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
size_t Device::current_mem_clock() const{
|
||||
size_t device::current_mem_clock() const{
|
||||
unsigned int result;
|
||||
dispatch::nvmlDeviceGetClockInfo(nvml_device(), NVML_CLOCK_MEM, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t Device::max_mem_clock() const{
|
||||
size_t device::max_mem_clock() const{
|
||||
unsigned int result;
|
||||
dispatch::nvmlDeviceGetMaxClockInfo(nvml_device(), NVML_CLOCK_MEM, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Infos */
|
||||
std::string Device::infos() const{
|
||||
std::string device::infos() const{
|
||||
std::ostringstream oss;
|
||||
std::vector<size_t> max_wi_sizes = max_block_dim();
|
||||
oss << "Platform: " << platform().name() << std::endl;
|
||||
@@ -188,7 +188,7 @@ std::string Device::infos() const{
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
Handle<CUdevice> const & Device::cu() const
|
||||
handle<CUdevice> const & device::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
}
|
||||
|
@@ -180,16 +180,16 @@ NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlD
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
|
||||
|
||||
cublasHandle_t dispatch::cublasHandle(Context const & ctx){
|
||||
static std::map<Context, cublasHandle_t> handles;
|
||||
cublasHandle_t dispatch::cublasHandle(driver::context const & ctx){
|
||||
static std::map<context, cublasHandle_t> handles;
|
||||
auto pr = handles.insert({ctx, cublasHandle_t()});
|
||||
if(pr.second)
|
||||
cublasCreate_v2(&pr.first->second);
|
||||
return pr.first->second;
|
||||
}
|
||||
|
||||
cudnnHandle_t dispatch::cudnnHandle(Context const & ctx){
|
||||
static std::map<Context, cudnnHandle_t> handles;
|
||||
cudnnHandle_t dispatch::cudnnHandle(driver::context const & ctx){
|
||||
static std::map<context, cudnnHandle_t> handles;
|
||||
auto pr = handles.insert({ctx, cudnnHandle_t()});
|
||||
if(pr.second)
|
||||
cudnnCreate(&pr.first->second);
|
||||
|
@@ -33,7 +33,7 @@ float Event::elapsed_time() const{
|
||||
return time;
|
||||
}
|
||||
|
||||
Handle<cu_event_t> const & Event::cu() const
|
||||
handle<cu_event_t> const & Event::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
}
|
||||
|
@@ -43,24 +43,24 @@ inline void _delete(cu_platform){}
|
||||
|
||||
//Constructor
|
||||
template<class CUType>
|
||||
Handle<CUType>::Handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
|
||||
handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
|
||||
{ }
|
||||
|
||||
|
||||
template<class CUType>
|
||||
Handle<CUType>::~Handle(){
|
||||
handle<CUType>::~handle(){
|
||||
if(has_ownership_ && h_ && h_.unique() && *h_)
|
||||
_delete(*h_);
|
||||
}
|
||||
|
||||
template class Handle<CUdeviceptr>;
|
||||
template class Handle<CUstream>;
|
||||
template class Handle<CUcontext>;
|
||||
template class Handle<CUdevice>;
|
||||
template class Handle<cu_event_t>;
|
||||
template class Handle<CUfunction>;
|
||||
template class Handle<CUmodule>;
|
||||
template class Handle<cu_platform>;
|
||||
template class handle<CUdeviceptr>;
|
||||
template class handle<CUstream>;
|
||||
template class handle<CUcontext>;
|
||||
template class handle<CUdevice>;
|
||||
template class handle<cu_event_t>;
|
||||
template class handle<CUfunction>;
|
||||
template class handle<CUmodule>;
|
||||
template class handle<cu_platform>;
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -32,13 +32,13 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
Kernel::Kernel(Module const & program, const char * name) : program_(program), address_bits_(program.context().device().address_bits()){
|
||||
kernel::kernel(driver::module const & program, const char * name) : program_(program), address_bits_(program.context().device().address_bits()){
|
||||
cu_params_store_.reserve(64);
|
||||
cu_params_.reserve(64);
|
||||
dispatch::cuModuleGetFunction(&*cu_, program, name);
|
||||
}
|
||||
|
||||
void Kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
||||
void kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
||||
if(index + 1> cu_params_store_.size()){
|
||||
cu_params_store_.resize(index+1);
|
||||
cu_params_.resize(index+1);
|
||||
@@ -48,16 +48,16 @@ void Kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
||||
cu_params_[index] = cu_params_store_[index].get();
|
||||
}
|
||||
|
||||
void Kernel::setArg(unsigned int index, Buffer const & data)
|
||||
void kernel::setArg(unsigned int index, buffer const & data)
|
||||
{ return setArg(index, (CUdeviceptr)data);}
|
||||
|
||||
void* const* Kernel::cu_params() const
|
||||
void* const* kernel::cu_params() const
|
||||
{ return cu_params_.data(); }
|
||||
|
||||
Handle<CUfunction> const & Kernel::cu() const
|
||||
handle<CUfunction> const & kernel::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
Module const & Kernel::module() const
|
||||
driver::module const & kernel::module() const
|
||||
{ return program_; }
|
||||
|
||||
|
||||
|
@@ -34,7 +34,7 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
Module::Module(Context const & context, std::string const & source) : context_(context), source_(source){
|
||||
module::module(driver::context const & context, std::string const & source) : context_(context), source_(source){
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
@@ -50,17 +50,17 @@ Module::Module(Context const & context, std::string const & source) : context_(c
|
||||
}
|
||||
}
|
||||
|
||||
Context const & Module::context() const
|
||||
driver::context const & module::context() const
|
||||
{ return context_; }
|
||||
|
||||
Handle<CUmodule> const & Module::cu() const
|
||||
handle<CUmodule> const & module::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
Buffer Module::symbol(const char *name) const{
|
||||
buffer module::symbol(const char *name) const{
|
||||
CUdeviceptr handle;
|
||||
size_t size;
|
||||
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
|
||||
return Buffer(context_, handle, false);
|
||||
return buffer(context_, handle, false);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -31,20 +31,20 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
std::string Platform::version() const{
|
||||
std::string platform::version() const{
|
||||
int version;
|
||||
dispatch::cuDriverGetVersion(&version);
|
||||
return std::to_string(version);
|
||||
}
|
||||
|
||||
std::vector<Device> Platform::devices() const{
|
||||
std::vector<Device> devices;
|
||||
std::vector<device> platform::devices() const{
|
||||
std::vector<device> devices;
|
||||
int N;
|
||||
dispatch::cuDeviceGetCount(&N);
|
||||
for(int i = 0 ; i < N ; ++i){
|
||||
CUdevice device;
|
||||
dispatch::cuDeviceGet(&device, i);
|
||||
devices.push_back(Device(device));
|
||||
CUdevice dvc;
|
||||
dispatch::cuDeviceGet(&dvc, i);
|
||||
devices.push_back(driver::device(dvc));
|
||||
}
|
||||
return devices;
|
||||
}
|
||||
|
@@ -44,25 +44,25 @@ inline CUcontext cucontext(){
|
||||
return result;
|
||||
}
|
||||
|
||||
Stream::Stream(CUstream stream, bool take_ownership): context_(cucontext(), take_ownership), cu_(stream, take_ownership)
|
||||
stream::stream(CUstream stream, bool take_ownership): context_(cucontext(), take_ownership), cu_(stream, take_ownership)
|
||||
{}
|
||||
|
||||
Stream::Stream(Context const & context): context_(context), cu_(CUstream(), true)
|
||||
stream::stream(driver::context const & context): context_(context), cu_(CUstream(), true)
|
||||
{
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
dispatch::cuStreamCreate(&*cu_, 0);
|
||||
}
|
||||
|
||||
void Stream::synchronize()
|
||||
void stream::synchronize()
|
||||
{
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
dispatch::cuStreamSynchronize(*cu_);
|
||||
}
|
||||
|
||||
Context const & Stream::context() const
|
||||
driver::context const & stream::context() const
|
||||
{ return context_; }
|
||||
|
||||
void Stream::enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event){
|
||||
void stream::enqueue(kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event){
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
if(event)
|
||||
dispatch::cuEventRecord(((cu_event_t)*event).first, *cu_);
|
||||
@@ -71,7 +71,7 @@ void Stream::enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::arr
|
||||
dispatch::cuEventRecord(((cu_event_t)*event).second, *cu_);
|
||||
}
|
||||
|
||||
void Stream::write(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr){
|
||||
void stream::write(buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr){
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
if(blocking)
|
||||
dispatch::cuMemcpyHtoD(buffer + offset, ptr, size);
|
||||
@@ -79,7 +79,7 @@ void Stream::write(Buffer const & buffer, bool blocking, std::size_t offset, std
|
||||
dispatch::cuMemcpyHtoDAsync(buffer + offset, ptr, size, *cu_);
|
||||
}
|
||||
|
||||
void Stream::read(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr){
|
||||
void stream::read(buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr){
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
if(blocking)
|
||||
dispatch::cuMemcpyDtoH(ptr, buffer + offset, size);
|
||||
@@ -87,7 +87,7 @@ void Stream::read(Buffer const & buffer, bool blocking, std::size_t offset, std:
|
||||
dispatch::cuMemcpyDtoHAsync(ptr, buffer + offset, size, *cu_);
|
||||
}
|
||||
|
||||
Handle<CUstream> const & Stream::cu() const
|
||||
handle<CUstream> const & stream::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
}
|
||||
|
151
lib/jit.cpp
Normal file
151
lib/jit.cpp
Normal file
@@ -0,0 +1,151 @@
|
||||
#include "triton/jit.h"
|
||||
#include <string>
|
||||
#include "triton/ast/ast.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shared_copy.h"
|
||||
#include "triton/codegen/allocation.h"
|
||||
#include "triton/codegen/liveness.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/buffer_info.h"
|
||||
#include "triton/codegen/barriers.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopPass.h"
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
using triton::ast::translation_unit;
|
||||
extern translation_unit *ast_root;
|
||||
|
||||
namespace triton {
|
||||
|
||||
void jit::init_llvm() {
|
||||
static bool init = false;
|
||||
if(!init){
|
||||
llvm::InitializeAllTargetInfos();
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
llvm::InitializeAllAsmParsers();
|
||||
llvm::InitializeAllAsmPrinters();
|
||||
init = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Lowers a Triton-IR module to an LLVM module.
//   module: Triton-IR module to lower; the passes below run on it in place.
//   params: tuning values. The first three set the context's meta-parameter
//           constants; the rest are assigned, in order, to the parameters
//           reported by the tuner.
// Returns the generated LLVM module.
// Terminates the process with EXIT_FAILURE when too few values are supplied
// or when the tuner reports constraint violations.
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const std::vector<unsigned>& params) {
  // params defaults to {} in add_module, so the fixed reads below must be guarded
  const std::size_t num_constants = 3;
  if(params.size() < num_constants){
    std::cout << "errors: expected at least " << num_constants
              << " tuning parameters, got " << params.size() << std::endl;
    exit(EXIT_FAILURE);
  }

  // owned from the start so an early exit path cannot leak it
  std::unique_ptr<llvm::Module> result(new llvm::Module("matmul", llvm_context_));

  // create passes
  codegen::buffer_info_pass buffer_info;
  codegen::place_shared_copy shared(&buffer_info);
  codegen::tune tune;
  codegen::liveness liveness(&buffer_info);
  codegen::allocation allocation(&liveness, &buffer_info);
  codegen::barriers barriers(&allocation, &buffer_info);
  codegen::vectorize vectorize(&tune);
  codegen::selection selection(&allocation, &tune, &buffer_info);

  // tuning parameters
  tune.run(module);
  triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
  triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
  triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
  std::size_t next = num_constants;
  for(unsigned *x: tune.get_params(module)){
    if(next >= params.size()){
      std::cout << "errors: not enough tuning parameters (got " << params.size() << ")" << std::endl;
      exit(EXIT_FAILURE);
    }
    *x = params[next++];
  }

  // constraints
  std::map<ir::value*, std::vector<std::string>> errors;
  tune.check_constraints(module, errors);
  std::cout << "errors: " << errors.size() << std::endl;
  for(auto &x: errors){
    for(auto &e: x.second)
      std::cout << x.first->get_name() << " " << e << std::endl;
  }
  if(!errors.empty())
    exit(EXIT_FAILURE);

  // generate ptx
  buffer_info.run(module);
  shared.run(module);
  liveness.run(module);
  allocation.run();
  barriers.run(module);
  vectorize.run(module);
  selection.run(module, *result);

  return result;
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
|
||||
// create AST from Triton-C source
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
translation_unit *program = ast_root;
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("matrix", triton_context_);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
|
||||
/* Constructs a jit whose driver_context_ member is initialised from `context`. */
jit::jit(driver::context context)
  : driver_context_(context)
{ }
|
||||
|
||||
/* Assembles the data-layout string applied to generated LLVM modules.
 *
 * is_64bit:           when false, "-p:32:32" is added to the layout.
 * use_short_pointers: only consulted when is_64bit is true; adds the
 *                     "-p3:32:32-p4:32:32-p5:32:32" component.
 *
 * Returns "e", followed by the optional pointer component, followed by the
 * fixed integer/vector/native-width suffix.
 */
std::string jit::compute_data_layout(bool is_64bit, bool use_short_pointers) {
  std::string layout("e");
  if (is_64bit) {
    if (use_short_pointers)
      layout.append("-p3:32:32-p4:32:32-p5:32:32");
  }
  else {
    layout.append("-p:32:32");
  }
  layout.append("-i64:64-i128:128-v16:16-v32:32-n16:32:64");
  return layout;
}
|
||||
|
||||
/* Compiles a Triton-IR module to assembly text and appends the resulting
 * driver module to modules_.
 *
 * tt_module: Triton-IR module to compile.
 * params:    tuning values forwarded to make_llvm_module.
 *
 * The LLVM module is targeted at the "nvptx64-nvidia-cuda" triple with the
 * CPU string "sm_52" (both hard-coded here). If the target cannot be found,
 * the target machine cannot be created, or the target cannot emit an assembly
 * file, a message is printed to std::cout and the process exits with
 * EXIT_FAILURE.
 */
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
  init_llvm();
  auto ll_module = make_llvm_module(tt_module, params);
  ll_module->setTargetTriple("nvptx64-nvidia-cuda");

  // lookupTarget returns null and fills `error` when the triple is unknown
  std::string error;
  auto target = llvm::TargetRegistry::lookupTarget(ll_module->getTargetTriple(), error);
  if(target == nullptr){
    std::cout << "error: " << error << std::endl;
    exit(EXIT_FAILURE);
  }

  // createTargetMachine hands back an owning pointer; hold it so it is freed on exit from this function
  std::unique_ptr<llvm::TargetMachine> machine(
      target->createTargetMachine(ll_module->getTargetTriple(), "sm_52", "",
                                  llvm::TargetOptions(), llvm::Reloc::Model(),
                                  llvm::None, llvm::CodeGenOpt::Aggressive));
  if(!machine){
    std::cout << "error: could not create target machine" << std::endl;
    exit(EXIT_FAILURE);
  }
  ll_module->setDataLayout(compute_data_layout());

  // emit machine code
  llvm::legacy::PassManager pass;
  llvm::SmallVector<char, 0> buffer;
  llvm::raw_svector_ostream stream(buffer);
  // addPassesToEmitFile returns true when the requested file type is unsupported
  if(machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile)){
    std::cout << "error: target cannot emit assembly" << std::endl;
    exit(EXIT_FAILURE);
  }
  pass.run(*ll_module);
  std::string src(buffer.begin(), buffer.end());

  modules_.push_back(driver::module(driver_context_, src));
}
|
||||
|
||||
void jit::add_module(const std::string &src, const std::vector<unsigned> ¶ms) {
|
||||
auto ptt_module = make_triton_module(src);
|
||||
add_module(*ptt_module, params);
|
||||
}
|
||||
|
||||
/* Creates a driver kernel handle for the function called `name`.
 *
 * The kernel is always built from the first entry of modules_, whichever
 * module that is.
 * NOTE(review): front() is called without checking that modules_ is non-empty;
 * presumably callers must add a module first - confirm.
 */
driver::kernel jit::get_function(const std::string &name) {
  auto &first_module = modules_.front();
  return driver::kernel(first_module, name.c_str());
}
|
||||
|
||||
|
||||
}
|
Reference in New Issue
Block a user