[general] added simple jit interface

This commit is contained in:
Philippe Tillet
2019-03-08 23:58:42 -05:00
parent c5073a5af6
commit d049679aa2
26 changed files with 458 additions and 464 deletions

View File

@@ -1,40 +1,8 @@
#include <cstring>
#include <cstdio>
#include "cuda.h"
#include "llvm/IR/Verifier.h"
#include "triton/ast/ast.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "triton/ir/context_impl.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
extern YY_BUFFER_STATE yy_scan_string(const char * str);
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
using triton::ast::translation_unit;
extern translation_unit *ast_root;
#include "triton/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
const char* src =
R"(
@@ -44,7 +12,7 @@ const tunable int32 TK;
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K, int32 bound){
int32 rxa[TM] = get_global_range[TM](0)
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
@@ -83,81 +51,6 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
}
)";
static std::string compute_data_layout(bool is64Bit, bool UseShortPointers) {
std::string Ret = "e";
if (!is64Bit)
Ret += "-p:32:32";
else if (UseShortPointers)
Ret += "-p3:32:32-p4:32:32-p5:32:32";
Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
return Ret;
}
static std::string generate_machine_code(llvm::Module &module, const std::string &target_triple, const std::string &data_layout) {
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
module.setTargetTriple(target_triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetMachine *machine = target->createTargetMachine(module.getTargetTriple(), "sm_52", "",
llvm::TargetOptions(), llvm::Reloc::Model(),
llvm::None, llvm::CodeGenOpt::Aggressive);
module.setDataLayout(data_layout);
// emit machine code
llvm::legacy::PassManager pass;
llvm::SmallVector<char, 0> buffer;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
pass.run(module);
std::string src(buffer.begin(), buffer.end());
return src;
}
static void __checkCudaErrors( CUresult err, const char *file, const int line )
{
if( CUDA_SUCCESS != err) {
fprintf(stderr,
"CUDA Driver API error = %04d from file <%s>, line %i.\n",
err, file, line );
exit(-1);
}
}
#define checkCudaErrors(err) __checkCudaErrors (err, __FILE__, __LINE__)
static void compile_machine_code(CUdevice &device, CUcontext &context, CUmodule &module,
CUfunction &function, CUstream &stream, int &major, int &minor,
const std::string &src, const std::string &name) {
int numDevices;
// Initialize
checkCudaErrors(cuInit(0));
checkCudaErrors(cuDeviceGetCount(&numDevices));
checkCudaErrors(cuDeviceGet(&device, 0));
checkCudaErrors(cuDeviceComputeCapability(&major, &minor, device));
checkCudaErrors(cuCtxCreate(&context, 0, device));
checkCudaErrors(cuStreamCreate(&stream, 0));
// Compile program
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
std::string errbuf(errbufsize, 0);
const void *cpterr = static_cast<const void*>(errbuf.data());
void *pterr = const_cast<void*>(cpterr);
void* optval[] = {(void*)(uintptr_t)errbufsize, pterr};
int err = cuModuleLoadDataEx(&module, src.data(), 2, opt, optval);
if(err != CUDA_SUCCESS){
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
}
// Get function
checkCudaErrors(cuModuleGetFunction(&function, module, name.c_str()));
}
template<class T>
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
@@ -170,54 +63,7 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
}
}
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// Start with innermost loop
size_t i = D - 1;
while(true){
//Execute function
f(values);
//Increment counters
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
int main() {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
// create Triton-IR from AST
triton::ir::context context;
triton::ir::module module("matrix", context);
program->codegen(&module);
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("matmul", llvm_context);
// create passes
triton::codegen::buffer_info_pass buffer_info;
triton::codegen::place_shared_copy shared(&buffer_info);
triton::codegen::tune tune;
triton::codegen::liveness liveness(&buffer_info);
triton::codegen::allocation allocation(&liveness, &buffer_info);
triton::codegen::barriers barriers(&allocation, &buffer_info);
triton::codegen::vectorize vectorize(&tune);
triton::codegen::selection selection(&allocation, &tune, &buffer_info);
triton::ir::print(module, std::cout);
// tuning parameters
tune.run(module);
std::vector<unsigned> params = {
// shapes
16, 16, 8,
@@ -232,97 +78,49 @@ int main() {
// b1
1, 8, 1
};
unsigned TM = params[0];
unsigned TN = params[1];
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul");
// meta-parameters
unsigned i = 0;
context.p_impl->mp_constants_[0]->set_value(params[0]);
context.p_impl->mp_constants_[1]->set_value(params[1]);
context.p_impl->mp_constants_[2]->set_value(params[2]);
for(unsigned *x: tune.get_params(module))
*x = params[3 + i++];
// constraints
std::map<triton::ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);
std::cout << "errors: " << errors.size() << std::endl;
for(auto &x: errors){
for(auto &e: x.second)
std::cout << x.first->get_name() << " " << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// run passes
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
selection.run(module, llvm_module);
// llvm source
llvm::legacy::PassManager manager;
manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createVerifierPass(true));
manager.run(llvm_module);
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
std::cout << src << std::endl;
// compile machine code
CUdevice cu_device;
CUcontext cu_context;
CUmodule cu_module;
CUfunction cu_kernel;
CUstream cu_stream;
int major, minor;
compile_machine_code(cu_device, cu_context, cu_module, cu_kernel, cu_stream, major, minor, src, "matmul");
// execute machine code
// Allocate buffers
typedef float numeric_t;
size_t M = 128, N = 128, K = 128;
size_t bound = 8;
std::vector<numeric_t> c(M*N);
std::vector<numeric_t> rc(M*N);
std::vector<numeric_t> a(M*K);
std::vector<numeric_t> b(K*N);
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
srand(0);
for(size_t i = 0; i < a.size(); i++)
a[i] = 1;
for(size_t i = 0; i < b.size(); i++)
b[i] = 1;
for(size_t i = 0; i < c.size(); i++)
c[i] = 0;
CUdeviceptr d_a, d_b, d_c;
checkCudaErrors(cuMemAlloc(&d_a, sizeof(numeric_t) * a.size()));
checkCudaErrors(cuMemAlloc(&d_b, sizeof(numeric_t) * b.size()));
checkCudaErrors(cuMemAlloc(&d_c, sizeof(numeric_t) * c.size()));
// Copy buffers
checkCudaErrors(cuMemcpyHtoD(d_a, a.data(), sizeof(numeric_t) * a.size()));
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
// Launch kernel
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
int num_regs;
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
unsigned TM = context.p_impl->mp_constants_[0]->get_value();
unsigned TN = context.p_impl->mp_constants_[1]->get_value();
unsigned nthreads = params[10]*params[13]*params[11]*params[14];
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
checkCudaErrors(cuStreamSynchronize(cu_stream));
// Write back
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
simple_gemm(rc, a, b, M, N, K);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = 1;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = 1;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer dc(context, hc.size()*4);
triton::driver::buffer da(context, ha.size()*4);
triton::driver::buffer db(context, hb.size()*4);
triton::driver::stream stream(context);
stream.write(da, true, 0, ha);
stream.write(db, true, 0, hb);
stream.write(dc, true, 0, hc);
kernel.setArg(0, da);
kernel.setArg(1, db);
kernel.setArg(2, dc);
kernel.setArg(3, M);
kernel.setArg(4, N);
kernel.setArg(5, K);
kernel.setArg(6, bound);
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
stream.synchronize();
stream.read(dc, true, 0, hc);
simple_gemm(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;

View File

@@ -33,13 +33,13 @@ namespace triton
namespace driver
{
class Buffer;
class Stream;
class Device;
class Context;
class Platform;
class Module;
class Kernel;
class buffer;
class stream;
class device;
class context;
class platform;
class module;
class kernel;
struct backend
{
@@ -49,9 +49,9 @@ struct backend
friend class backend;
public:
static void release();
static Module& get(Stream const & stream, std::string const & name, std::string const &src);
static module& get(driver::stream const & stream, std::string const & name, std::string const &src);
private:
static std::map<std::tuple<Stream, std::string>, Module * > cache_;
static std::map<std::tuple<stream, std::string>, module * > cache_;
};
class kernels
@@ -59,53 +59,53 @@ struct backend
friend class backend;
public:
static void release();
static Kernel & get(Module const & program, std::string const & name);
static kernel & get(driver::module const & program, std::string const & name);
private:
static std::map<std::tuple<Module, std::string>, Kernel * > cache_;
static std::map<std::tuple<module, std::string>, kernel * > cache_;
};
class contexts
{
friend class backend;
private:
static void init(std::vector<Platform> const &);
static void init(std::vector<platform> const &);
static void release();
public:
static Context const & get_default();
static driver::context const & get_default();
template<class T>
static Context const & import(T context)
static driver::context const & import(T ctx)
{
for(driver::Context const * x: cache_)
if((T)*x==context)
for(driver::context const * x: cache_)
if((T)*x==ctx)
return *x;
cache_.emplace_back(new Context(context, false));
cache_.emplace_back(new driver::context(ctx, false));
return *cache_.back();
}
static void get(std::list<Context const *> &);
static void get(std::list<context const *> &);
private:
static std::list<Context const *> cache_;
static std::list<context const *> cache_;
};
class streams
{
friend class backend;
private:
static void init(std::list<Context const *> const &);
static void init(std::list<context const *> const &);
static void release();
public:
static void get(Context const &, std::vector<Stream *> &streams);
static Stream & get(Context const &, unsigned int id = 0);
static Stream & get_default();
static void get(driver::context const &, std::vector<stream *> &streams);
static stream & get(driver::context const &, unsigned int id = 0);
static stream & get_default();
private:
static std::map< Context, std::vector<Stream*> > cache_;
static std::map< context, std::vector<stream*> > cache_;
};
static void init();
static void release();
static std::vector<Device> devices();
static std::vector<Platform> platforms();
static void synchronize(Context const &);
static std::vector<device> devices();
static std::vector<platform> platforms();
static void synchronize(driver::context const &);
static unsigned int default_device;
};

View File

@@ -31,21 +31,21 @@ namespace triton
namespace driver
{
class Stream;
class stream;
// Buffer
class Buffer: public HandleInterface<Buffer, CUdeviceptr>
class buffer: public handle_interface<buffer, CUdeviceptr>
{
public:
Buffer(Context const & context, size_t size);
Buffer(Context const & context, CUdeviceptr cu, bool take_ownership);
void set_zero(Stream const & queue, size_t size);
Handle<CUdeviceptr> const & cu() const;
Handle<CUdeviceptr> & cu();
buffer(driver::context const & context, size_t size);
buffer(driver::context const & context, CUdeviceptr cu, bool take_ownership);
void set_zero(stream const & queue, size_t size);
handle<CUdeviceptr> const & cu() const;
handle<CUdeviceptr> & cu();
private:
Context context_;
Handle<CUdeviceptr> cu_;
context context_;
handle<CUdeviceptr> cu_;
};
}

View File

@@ -31,7 +31,7 @@ namespace triton
namespace driver
{
class Context: public HandleInterface<Context, CUcontext>
class context: public handle_interface<context, CUcontext>
{
private:
static std::string get_cache_path();
@@ -39,25 +39,25 @@ private:
public:
//Constructors
explicit Context(CUcontext context, bool take_ownership = true);
explicit Context(Device const & device);
explicit context(CUcontext context, bool take_ownership = true);
explicit context(driver::device const & dvc);
//Accessors
Device const & device() const;
driver::device const & device() const;
std::string const & cache_path() const;
Handle<CUcontext> const & cu() const;
handle<CUcontext> const & cu() const;
private:
Handle<CUcontext> cu_;
Device device_;
handle<CUcontext> cu_;
driver::device dvc_;
std::string cache_path_;
};
class ContextSwitcher{
public:
ContextSwitcher(Context const & ctx);
ContextSwitcher(driver::context const & ctx);
~ContextSwitcher();
private:
Context const & ctx_;
driver::context const & ctx_;
};
}

View File

@@ -51,7 +51,7 @@ static const std::vector<cublasGemmAlgo_t> cublasAlgorithms = {
static const std::map<DType, cudaDataType> cudtype = {{FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE,CUDA_R_64F}};
static const std::map<char, cublasOperation_t> cuop = {{'N', CUBLAS_OP_N}, {'T', CUBLAS_OP_T}};
inline cublasGemmAlgo_t cublasGemmFastest(Stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
inline cublasGemmAlgo_t cublasGemmFastest(stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc){
@@ -84,7 +84,7 @@ inline void cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperati
/* Simplified API for default GEMM */
inline void cublasGemm(DType dtype, Stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, scalar beta, Buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
inline void cublasGemm(DType dtype, stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, buffer const & A, int32_t lda, buffer const & B, int32_t ldb, scalar beta, buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
ContextSwitcher ctx_switch(stream.context());
cublasHandle_t handle = dispatch::cublasHandle(stream.context());
dispatch::cublasSetStream_v2(handle, (CUstream)stream);
@@ -111,9 +111,9 @@ inline cudnnTensorFormat_t format(cudnnDataType_t cutype){
}
}
inline void cudnnConv(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, Buffer const & I, Buffer const & F, scalar beta, Buffer const & O){
driver::Context const & ctx = stream.context();
inline void cudnnConv(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, buffer const & I, buffer const & F, scalar beta, buffer const & O){
driver::driver::context const & ctx = stream.context();
ContextSwitcher switch_ctx(ctx);
std::vector<int> pad = {pad_d, pad_h, pad_w};
@@ -154,16 +154,16 @@ inline void cudnnConv(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t
size_t workspace_size;
dispatch::cudnnGetConvolutionForwardWorkspaceSize(handle, tI, tF, conv, tO, algo, &workspace_size);
static Buffer work(ctx, 1024*1024*64);
static buffer work(ctx, 1024*1024*64);
CUdeviceptr twork = work;
CUdeviceptr pI = I, pF = F, pO = O;
dispatch::cudnnConvolutionForward(handle, alpha.data(), tI, (void*)pI, tF, (void*)pF, conv, algo, (void*)twork, workspace_size, beta.data(), tO, (void*)pO);
}
inline void cudnnPool(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, Buffer const & I, scalar beta, Buffer const & O){
driver::Context const & ctx = stream.context();
inline void cudnnPool(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, buffer const & I, scalar beta, buffer const & O){
driver::driver::context const & ctx = stream.context();
ContextSwitcher switch_ctx(ctx);
std::vector<int> pad = {pad_d, pad_h, pad_w};
@@ -200,11 +200,11 @@ inline void cudnnPool(DType dtype, Stream& stream, int32_t D, int32_t H, int32_t
dispatch::cudnnPoolingForward(handle, desc, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
}
inline void cudnnTransformTensor(driver::Stream & stream,
inline void cudnnTransformTensor(driver::stream & stream,
DType in_dtype, DType out_dtype,
cudnnTensorFormat_t in_layout, cudnnTensorFormat_t out_layout,
int32_t N, int32_t C, int32_t D, int32_t H, int32_t W,
scalar alpha, driver::Buffer const & I, scalar beta, driver::Buffer& O)
scalar alpha, driver::buffer const & I, scalar beta, driver::buffer& O)
{
cudnnHandle_t handle = dispatch::cudnnHandle(stream.context());
dispatch::cudnnSetStream(handle, (CUstream)stream);

View File

@@ -33,7 +33,7 @@ namespace driver
{
// Device
class Device: public HandleInterface<Device, CUdevice>
class device: public handle_interface<device, CUdevice>
{
public:
//Supported architectures
@@ -61,14 +61,14 @@ private:
inline nvmlDevice_t nvml_device() const;
public:
Device(CUdevice cu = CUdevice(), bool take_ownership = true): cu_(cu, take_ownership){}
device(CUdevice cu = CUdevice(), bool take_ownership = true): cu_(cu, take_ownership){}
//Accessors
Architecture architecture() const;
Handle<CUdevice> const & cu() const;
handle<CUdevice> const & cu() const;
//Informations
std::string infos() const;
size_t address_bits() const;
driver::Platform platform() const;
driver::platform platform() const;
std::vector<size_t> max_block_dim() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
@@ -87,7 +87,7 @@ public:
size_t max_mem_clock() const;
private:
Handle<CUdevice> cu_;
handle<CUdevice> cu_;
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
};

View File

@@ -42,7 +42,7 @@ namespace triton
namespace driver
{
class Context;
class context;
template<class T> void check(T){}
void check(nvrtcResult err);
@@ -137,7 +137,7 @@ public:
static nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog, const char *src, const char *name, int numHeaders, const char **headers, const char **includeNames);
static nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
static cublasHandle_t cublasHandle(Context const & ctx);
static cublasHandle_t cublasHandle(driver::context const & ctx);
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId);
@@ -146,7 +146,7 @@ public:
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);
static cudnnHandle_t cudnnHandle(Context const & ctx);
static cudnnHandle_t cudnnHandle(driver::context const & ctx);
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
static cudnnStatus_t cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);

View File

@@ -32,14 +32,14 @@ namespace driver
{
// Event
class Event: public HandleInterface<Event, cu_event_t>
class Event: public handle_interface<Event, cu_event_t>
{
public:
float elapsed_time() const;
Handle<cu_event_t> const & cu() const;
handle<cu_event_t> const & cu() const;
private:
Handle<cu_event_t> cu_;
handle<cu_event_t> cu_;
};
}

View File

@@ -49,24 +49,24 @@ private:
};
template<class T, class CUType>
class HandleInterface{
class handle_interface{
public:
//Accessors
operator CUType() const { return *(((T*)this)->cu().h_); }
//Comparison
bool operator==(HandleInterface const & y) { return (CUType)(*this) == (CUType)(y); }
bool operator!=(HandleInterface const & y) { return (CUType)(*this) != (CUType)(y); }
bool operator<(HandleInterface const & y) { return (CUType)(*this) < (CUType)(y); }
bool operator==(handle_interface const & y) { return (CUType)(*this) == (CUType)(y); }
bool operator!=(handle_interface const & y) { return (CUType)(*this) != (CUType)(y); }
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
};
template<class CUType>
class Handle{
class handle{
public:
template<class, class> friend class HandleInterface;
template<class, class> friend class handle_interface;
public:
//Constructors
Handle(CUType cu = CUType(), bool take_ownership = true);
~Handle();
handle(CUType cu = CUType(), bool take_ownership = true);
~handle();
CUType& operator*() { return *h_; }
CUType const & operator*() const { return *h_; }
CUType* operator->() const { return h_.get(); }

View File

@@ -34,27 +34,27 @@ namespace triton
namespace driver
{
class Buffer;
class buffer;
// Kernel
class Kernel: public HandleInterface<Kernel, CUfunction>
class kernel: public handle_interface<kernel, CUfunction>
{
public:
//Constructors
Kernel(Module const & program, const char * name);
kernel(driver::module const & program, const char * name);
//Accessors
Handle<CUfunction> const & cu() const;
Module const & module() const;
handle<CUfunction> const & cu() const;
driver::module const & module() const;
//Arguments setters
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, Buffer const &);
void setArg(unsigned int index, buffer const &);
template<class T> void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); }
//Arguments getters
void* const* cu_params() const;
private:
Handle<CUfunction> cu_;
Module program_;
handle<CUfunction> cu_;
driver::module program_;
unsigned int address_bits_;
std::vector<std::shared_ptr<void> > cu_params_store_;
std::vector<void*> cu_params_;

View File

@@ -34,22 +34,22 @@ namespace triton
namespace driver
{
class Context;
class Device;
class context;
class device;
class Module: public HandleInterface<Module, CUmodule>
class module: public handle_interface<module, CUmodule>
{
static std::string header(Device const & device);
static std::string header(device const & device);
public:
Module(Context const & context, std::string const & source);
Context const & context() const;
Handle<CUmodule> const & cu() const;
Buffer symbol(const char * name) const;
module(driver::context const & context, std::string const & source);
driver::context const & context() const;
handle<CUmodule> const & cu() const;
buffer symbol(const char * name) const;
private:
Handle<CUmodule> cu_;
Context context_;
handle<CUmodule> cu_;
driver::context context_;
std::string source_;
};

View File

@@ -34,17 +34,17 @@ namespace triton
namespace driver
{
class Device;
class device;
class Platform
class platform
{
public:
//Accessors
std::string name() const { return "CUDA"; }
std::string version() const;
std::vector<Device> devices() const;
std::vector<device> devices() const;
private:
Handle<cu_platform> cu_;
handle<cu_platform> cu_;
};
}

View File

@@ -35,43 +35,43 @@ namespace triton
namespace driver
{
class Kernel;
class kernel;
class Event;
class Range;
class Buffer;
class buffer;
// Command Queue
class Stream: public HandleInterface<Stream, CUstream>
class stream: public handle_interface<stream, CUstream>
{
public:
//Constructors
Stream(CUstream stream, bool take_ownership);
Stream(Context const & context);
stream(CUstream stream, bool take_ownership);
stream(driver::context const & context);
//Accessors
Handle<CUstream> const & cu() const;
Context const & context() const;
handle<CUstream> const & cu() const;
driver::context const & context() const;
//Synchronize
void synchronize();
//Enqueue
void enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL);
void enqueue(kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL);
// Write
void write(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void write(driver::buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
template<class T> void write(Buffer const & buffer, bool blocking, std::size_t offset, std::vector<T> const & x)
template<class T> void write(driver::buffer const & buffer, bool blocking, std::size_t offset, std::vector<T> const & x)
{ write(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
// Read
void read(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr);
void read(driver::buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr);
template<class T> void read(Buffer const & buffer, bool blocking, std::size_t offset, std::vector<T>& x)
template<class T> void read(driver::buffer const & buffer, bool blocking, std::size_t offset, std::vector<T>& x)
{ read(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
private:
Context context_;
Handle<CUstream> cu_;
driver::context context_;
handle<CUstream> cu_;
};

45
include/triton/jit.h Normal file
View File

@@ -0,0 +1,45 @@
#ifndef TDL_INCLUDE_JIT_H
#define TDL_INCLUDE_JIT_H
#include <string>
#include <memory>
#include "llvm/IR/LLVMContext.h"
#include "triton/ir/context.h"
#include "triton/driver/module.h"
#include "triton/driver/kernel.h"
namespace llvm {
class Module;
}
namespace triton {
namespace ir {
class module;
class context;
}
class jit {
private:
void init_llvm();
std::string compute_data_layout(bool is64Bit = true, bool UseShortPointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, const std::vector<unsigned>& params);
std::unique_ptr<ir::module> make_triton_module(const std::string &src);
public:
jit(driver::context context);
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
driver::kernel get_function(const std::string &name);
private:
std::vector<driver::module> modules_;
driver::context driver_context_;
llvm::LLVMContext llvm_context_;
ir::context triton_context_;
};
}
#endif

View File

@@ -47,14 +47,14 @@ void backend::modules::release(){
cache_.clear();
}
Module& backend::modules::get(Stream const & stream, std::string const & name, std::string const & src){
std::tuple<Stream, std::string> key(stream, name);
module& backend::modules::get(driver::stream const & stream, std::string const & name, std::string const & src){
std::tuple<driver::stream, std::string> key(stream, name);
if(cache_.find(key)==cache_.end())
return *cache_.insert(std::make_pair(key, new Module(stream.context(), src))).first->second;
return *cache_.insert(std::make_pair(key, new module(stream.context(), src))).first->second;
return *cache_.at(key);
}
std::map<std::tuple<Stream, std::string>, Module * > backend::modules::cache_;
std::map<std::tuple<stream, std::string>, module * > backend::modules::cache_;
/*-----------------------------------*/
//----------- Kernels --------------*/
@@ -66,23 +66,23 @@ void backend::kernels::release(){
cache_.clear();
}
Kernel & backend::kernels::get(Module const & program, std::string const & name){
std::tuple<Module, std::string> key(program, name);
kernel & backend::kernels::get(driver::module const & program, std::string const & name){
std::tuple<module, std::string> key(program, name);
if(cache_.find(key)==cache_.end())
return *cache_.insert(std::make_pair(key, new Kernel(program, name.c_str()))).first->second;
return *cache_.insert(std::make_pair(key, new kernel(program, name.c_str()))).first->second;
return *cache_.at(key);
}
std::map<std::tuple<Module, std::string>, Kernel * > backend::kernels::cache_;
std::map<std::tuple<module, std::string>, kernel * > backend::kernels::cache_;
/*-----------------------------------*/
//------------ Queues --------------*/
/*-----------------------------------*/
void backend::streams::init(std::list<const Context *> const & contexts){
for(Context const * ctx : contexts)
void backend::streams::init(std::list<const context *> const & contexts){
for(context const * ctx : contexts)
if(cache_.find(*ctx)==cache_.end())
cache_.insert(std::make_pair(*ctx, std::vector<Stream*>{new Stream(*ctx)}));
cache_.insert(std::make_pair(*ctx, std::vector<stream*>{new stream(*ctx)}));
}
void backend::streams::release(){
@@ -92,32 +92,32 @@ void backend::streams::release(){
cache_.clear();
}
Stream & backend::streams::get_default()
stream & backend::streams::get_default()
{ return get(contexts::get_default(), 0); }
Stream & backend::streams::get(Context const & context, unsigned int id){
init(std::list<Context const *>(1,&context));
stream & backend::streams::get(driver::context const & context, unsigned int id){
init(std::list<driver::context const *>(1,&context));
for(auto & x : cache_)
if(x.first==context)
return *x.second[id];
throw;
}
void backend::streams::get(Context const & context, std::vector<Stream*> & queues){
init(std::list<Context const *>(1,&context));
void backend::streams::get(driver::context const & context, std::vector<stream*> & queues){
init(std::list<driver::context const *>(1,&context));
queues = cache_.at(context);
}
std::map<Context, std::vector<Stream*> > backend::streams::cache_;
std::map<context, std::vector<stream*> > backend::streams::cache_;
/*-----------------------------------*/
//------------ Contexts ------------*/
/*-----------------------------------*/
void backend::contexts::init(std::vector<Platform> const & platforms){
for(Platform const & platform: platforms){
for(Device const & device: platform.devices())
cache_.push_back(new Context(device));
void backend::contexts::init(std::vector<platform> const & platforms){
for(platform const & platform: platforms){
for(device const & device: platform.devices())
cache_.push_back(new context(device));
}
}
@@ -127,19 +127,19 @@ void backend::contexts::release(){
cache_.clear();
}
Context const & backend::contexts::get_default(){
driver::context const & backend::contexts::get_default(){
backend::init();
std::list<Context const *>::const_iterator it = cache_.begin();
std::list<context const *>::const_iterator it = cache_.begin();
std::advance(it, default_device);
return **it;
}
void backend::contexts::get(std::list<Context const *> & contexts){
void backend::contexts::get(std::list<context const *> & contexts){
backend::init();
contexts = cache_;
}
std::list<Context const *> backend::contexts::cache_;
std::list<context const *> backend::contexts::cache_;
@@ -147,28 +147,28 @@ std::list<Context const *> backend::contexts::cache_;
//------------ General -------------*/
/*-----------------------------------*/
std::vector<Device> backend::devices(){
std::vector<Platform> platforms = backend::platforms();
std::vector<Device> result;
for(Platform const & platform: platforms){
std::vector<device> backend::devices(){
std::vector<platform> platforms = backend::platforms();
std::vector<device> result;
for(platform const & platform: platforms){
auto devices = platform.devices();
result.insert(result.end(), devices.begin(), devices.end());
}
return result;
}
std::vector<Platform> backend::platforms(){
std::vector<Platform> platforms;
std::vector<platform> backend::platforms(){
std::vector<platform> platforms;
//if CUDA is here
if(dispatch::cuinit())
platforms.push_back(Platform());
platforms.push_back(platform());
if(platforms.empty())
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
return platforms;
}
void backend::synchronize(Context const & context){
for(Stream * queue: streams::cache_.at(context))
void backend::synchronize(driver::context const & context){
for(stream * queue: streams::cache_.at(context))
queue->synchronize();
}
@@ -184,7 +184,7 @@ void backend::release(){
void backend::init(){
if(!contexts::cache_.empty())
return;
std::vector<Platform> platforms = backend::platforms();
std::vector<platform> platforms = backend::platforms();
contexts::init(platforms);
streams::init(contexts::cache_);
}

View File

@@ -33,26 +33,26 @@ namespace triton
namespace driver
{
Buffer::Buffer(Context const & context, size_t size) : context_(context)
buffer::buffer(driver::context const & context, size_t size) : context_(context)
{
ContextSwitcher ctx_switch(context_);
dispatch::cuMemAlloc(&*cu_, size);
}
Buffer::Buffer(Context const & context, CUdeviceptr cu, bool take_ownership):
buffer::buffer(driver::context const & context, CUdeviceptr cu, bool take_ownership):
context_(context), cu_(cu, take_ownership)
{ }
void Buffer::set_zero(Stream const & queue, size_t size)
void buffer::set_zero(stream const & queue, size_t size)
{
ContextSwitcher ctx_switch(context_);
dispatch::cuMemsetD8Async(*cu_, 0, size, queue);
}
Handle<CUdeviceptr> const & Buffer::cu() const
handle<CUdeviceptr> const & buffer::cu() const
{ return cu_; }
Handle<CUdeviceptr> & Buffer::cu()
handle<CUdeviceptr> & buffer::cu()
{ return cu_; }
}

View File

@@ -35,7 +35,7 @@ namespace triton
namespace driver
{
std::string Context::get_cache_path(){
std::string context::get_cache_path(){
//user-specified cache path
std::string result = tools::getenv("ISAAC_CACHE_PATH");
if(!result.empty()){
@@ -54,7 +54,7 @@ std::string Context::get_cache_path(){
return "";
}
CUdevice Context::device(CUcontext context){
CUdevice context::device(CUcontext context){
dispatch::cuCtxPushCurrent_v2(context);
CUdevice res;
dispatch::cuCtxGetDevice(&res);
@@ -62,26 +62,26 @@ CUdevice Context::device(CUcontext context){
return res;
}
Context::Context(CUcontext context, bool take_ownership): cu_(context, take_ownership), device_(device(context), false), cache_path_(get_cache_path())
context::context(CUcontext context, bool take_ownership): cu_(context, take_ownership), dvc_(device(context), false), cache_path_(get_cache_path())
{ }
Context::Context(Device const & device): device_(device), cache_path_(get_cache_path())
context::context(driver::device const & device): dvc_(device), cache_path_(get_cache_path())
{
dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, (CUdevice)device);
dispatch::cuCtxPopCurrent_v2(NULL);
}
Device const & Context::device() const
{ return device_; }
device const & context::device() const
{ return dvc_; }
std::string const & Context::cache_path() const
std::string const & context::cache_path() const
{ return cache_path_; }
Handle<CUcontext> const & Context::cu() const
handle<CUcontext> const & context::cu() const
{ return cu_; }
/* Context Switcher */
ContextSwitcher::ContextSwitcher(Context const & ctx): ctx_(ctx)
ContextSwitcher::ContextSwitcher(driver::context const & ctx): ctx_(ctx)
{
dispatch::cuCtxPushCurrent_v2(ctx_);
}

View File

@@ -35,7 +35,7 @@ namespace driver
{
/* Architecture [NVidia] */
Device::Architecture Device::nv_arch(std::pair<unsigned int, unsigned int> sm) const{
device::Architecture device::nv_arch(std::pair<unsigned int, unsigned int> sm) const{
switch(sm.first)
{
case 7:
@@ -81,13 +81,13 @@ Device::Architecture Device::nv_arch(std::pair<unsigned int, unsigned int> sm) c
}
template<CUdevice_attribute attr>
int Device::cuGetInfo() const{
int device::cuGetInfo() const{
int res;
dispatch::cuDeviceGetAttribute(&res, attr, *cu_);
return res;
}
nvmlDevice_t Device::nvml_device() const{
nvmlDevice_t device::nvml_device() const{
std::map<std::string, nvmlDevice_t> map;
std::string key = pci_bus_id();
if(map.find(key)==map.end()){
@@ -99,33 +99,33 @@ nvmlDevice_t Device::nvml_device() const{
}
/* Architecture */
Device::Architecture Device::architecture() const
device::Architecture device::architecture() const
{ return nv_arch(compute_capability()); }
/* Attributes */
size_t Device::address_bits() const
size_t device::address_bits() const
{ return sizeof(size_t)*8; }
driver::Platform Device::platform() const
{ return Platform(); }
driver::platform device::platform() const
{ return platform(); }
std::string Device::name() const{
std::string device::name() const{
char tmp[128];
dispatch::cuDeviceGetName(tmp, 128, *cu_);
return std::string(tmp);
}
std::string Device::pci_bus_id() const{
std::string device::pci_bus_id() const{
char tmp[128];
dispatch::cuDeviceGetPCIBusId(tmp, 128, *cu_);
return std::string(tmp);
}
void Device::interpret_as(std::pair<size_t, size_t> cc){
void device::interpret_as(std::pair<size_t, size_t> cc){
interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
}
std::pair<size_t, size_t> Device::compute_capability() const{
std::pair<size_t, size_t> device::compute_capability() const{
if(interpreted_as_)
return *interpreted_as_;
size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
@@ -133,17 +133,17 @@ std::pair<size_t, size_t> Device::compute_capability() const{
return std::make_pair(_major, _minor);
}
size_t Device::max_threads_per_block() const
size_t device::max_threads_per_block() const
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK>(); }
size_t Device::max_shared_memory() const
size_t device::max_shared_memory() const
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK>(); }
size_t Device::warp_size() const
size_t device::warp_size() const
{ return cuGetInfo<CU_DEVICE_ATTRIBUTE_WARP_SIZE>(); }
std::vector<size_t> Device::max_block_dim() const{
std::vector<size_t> device::max_block_dim() const{
std::vector<size_t> result(3);
result[0] = cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X>();
result[1] = cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y>();
@@ -151,33 +151,33 @@ std::vector<size_t> Device::max_block_dim() const{
return result;
}
size_t Device::current_sm_clock() const{
size_t device::current_sm_clock() const{
unsigned int result;
dispatch::nvmlDeviceGetClockInfo(nvml_device(), NVML_CLOCK_SM, &result);
return result;
}
size_t Device::max_sm_clock() const{
size_t device::max_sm_clock() const{
unsigned int result;
dispatch::nvmlDeviceGetMaxClockInfo(nvml_device(), NVML_CLOCK_SM, &result);
return result;
}
size_t Device::current_mem_clock() const{
size_t device::current_mem_clock() const{
unsigned int result;
dispatch::nvmlDeviceGetClockInfo(nvml_device(), NVML_CLOCK_MEM, &result);
return result;
}
size_t Device::max_mem_clock() const{
size_t device::max_mem_clock() const{
unsigned int result;
dispatch::nvmlDeviceGetMaxClockInfo(nvml_device(), NVML_CLOCK_MEM, &result);
return result;
}
/* Infos */
std::string Device::infos() const{
std::string device::infos() const{
std::ostringstream oss;
std::vector<size_t> max_wi_sizes = max_block_dim();
oss << "Platform: " << platform().name() << std::endl;
@@ -188,7 +188,7 @@ std::string Device::infos() const{
return oss.str();
}
Handle<CUdevice> const & Device::cu() const
handle<CUdevice> const & device::cu() const
{ return cu_; }
}

View File

@@ -180,16 +180,16 @@ NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlD
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
cublasHandle_t dispatch::cublasHandle(Context const & ctx){
static std::map<Context, cublasHandle_t> handles;
cublasHandle_t dispatch::cublasHandle(driver::context const & ctx){
static std::map<context, cublasHandle_t> handles;
auto pr = handles.insert({ctx, cublasHandle_t()});
if(pr.second)
cublasCreate_v2(&pr.first->second);
return pr.first->second;
}
cudnnHandle_t dispatch::cudnnHandle(Context const & ctx){
static std::map<Context, cudnnHandle_t> handles;
cudnnHandle_t dispatch::cudnnHandle(driver::context const & ctx){
static std::map<context, cudnnHandle_t> handles;
auto pr = handles.insert({ctx, cudnnHandle_t()});
if(pr.second)
cudnnCreate(&pr.first->second);

View File

@@ -33,7 +33,7 @@ float Event::elapsed_time() const{
return time;
}
Handle<cu_event_t> const & Event::cu() const
handle<cu_event_t> const & Event::cu() const
{ return cu_; }
}

View File

@@ -43,24 +43,24 @@ inline void _delete(cu_platform){}
//Constructor
template<class CUType>
Handle<CUType>::Handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
{ }
template<class CUType>
Handle<CUType>::~Handle(){
handle<CUType>::~handle(){
if(has_ownership_ && h_ && h_.unique() && *h_)
_delete(*h_);
}
template class Handle<CUdeviceptr>;
template class Handle<CUstream>;
template class Handle<CUcontext>;
template class Handle<CUdevice>;
template class Handle<cu_event_t>;
template class Handle<CUfunction>;
template class Handle<CUmodule>;
template class Handle<cu_platform>;
template class handle<CUdeviceptr>;
template class handle<CUstream>;
template class handle<CUcontext>;
template class handle<CUdevice>;
template class handle<cu_event_t>;
template class handle<CUfunction>;
template class handle<CUmodule>;
template class handle<cu_platform>;
}
}

View File

@@ -32,13 +32,13 @@ namespace triton
namespace driver
{
Kernel::Kernel(Module const & program, const char * name) : program_(program), address_bits_(program.context().device().address_bits()){
kernel::kernel(driver::module const & program, const char * name) : program_(program), address_bits_(program.context().device().address_bits()){
cu_params_store_.reserve(64);
cu_params_.reserve(64);
dispatch::cuModuleGetFunction(&*cu_, program, name);
}
void Kernel::setArg(unsigned int index, std::size_t size, void* ptr){
void kernel::setArg(unsigned int index, std::size_t size, void* ptr){
if(index + 1> cu_params_store_.size()){
cu_params_store_.resize(index+1);
cu_params_.resize(index+1);
@@ -48,16 +48,16 @@ void Kernel::setArg(unsigned int index, std::size_t size, void* ptr){
cu_params_[index] = cu_params_store_[index].get();
}
void Kernel::setArg(unsigned int index, Buffer const & data)
void kernel::setArg(unsigned int index, buffer const & data)
{ return setArg(index, (CUdeviceptr)data);}
void* const* Kernel::cu_params() const
void* const* kernel::cu_params() const
{ return cu_params_.data(); }
Handle<CUfunction> const & Kernel::cu() const
handle<CUfunction> const & kernel::cu() const
{ return cu_; }
Module const & Kernel::module() const
driver::module const & kernel::module() const
{ return program_; }

View File

@@ -34,7 +34,7 @@ namespace triton
namespace driver
{
Module::Module(Context const & context, std::string const & source) : context_(context), source_(source){
module::module(driver::context const & context, std::string const & source) : context_(context), source_(source){
ContextSwitcher ctx_switch(context_);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -50,17 +50,17 @@ Module::Module(Context const & context, std::string const & source) : context_(c
}
}
Context const & Module::context() const
driver::context const & module::context() const
{ return context_; }
Handle<CUmodule> const & Module::cu() const
handle<CUmodule> const & module::cu() const
{ return cu_; }
Buffer Module::symbol(const char *name) const{
buffer module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
return Buffer(context_, handle, false);
return buffer(context_, handle, false);
}

View File

@@ -31,20 +31,20 @@ namespace triton
namespace driver
{
std::string Platform::version() const{
std::string platform::version() const{
int version;
dispatch::cuDriverGetVersion(&version);
return std::to_string(version);
}
std::vector<Device> Platform::devices() const{
std::vector<Device> devices;
std::vector<device> platform::devices() const{
std::vector<device> devices;
int N;
dispatch::cuDeviceGetCount(&N);
for(int i = 0 ; i < N ; ++i){
CUdevice device;
dispatch::cuDeviceGet(&device, i);
devices.push_back(Device(device));
CUdevice dvc;
dispatch::cuDeviceGet(&dvc, i);
devices.push_back(driver::device(dvc));
}
return devices;
}

View File

@@ -44,25 +44,25 @@ inline CUcontext cucontext(){
return result;
}
Stream::Stream(CUstream stream, bool take_ownership): context_(cucontext(), take_ownership), cu_(stream, take_ownership)
stream::stream(CUstream stream, bool take_ownership): context_(cucontext(), take_ownership), cu_(stream, take_ownership)
{}
Stream::Stream(Context const & context): context_(context), cu_(CUstream(), true)
stream::stream(driver::context const & context): context_(context), cu_(CUstream(), true)
{
ContextSwitcher ctx_switch(context_);
dispatch::cuStreamCreate(&*cu_, 0);
}
void Stream::synchronize()
void stream::synchronize()
{
ContextSwitcher ctx_switch(context_);
dispatch::cuStreamSynchronize(*cu_);
}
Context const & Stream::context() const
driver::context const & stream::context() const
{ return context_; }
void Stream::enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event){
void stream::enqueue(kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event){
ContextSwitcher ctx_switch(context_);
if(event)
dispatch::cuEventRecord(((cu_event_t)*event).first, *cu_);
@@ -71,7 +71,7 @@ void Stream::enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::arr
dispatch::cuEventRecord(((cu_event_t)*event).second, *cu_);
}
void Stream::write(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr){
void stream::write(buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr){
ContextSwitcher ctx_switch(context_);
if(blocking)
dispatch::cuMemcpyHtoD(buffer + offset, ptr, size);
@@ -79,7 +79,7 @@ void Stream::write(Buffer const & buffer, bool blocking, std::size_t offset, std
dispatch::cuMemcpyHtoDAsync(buffer + offset, ptr, size, *cu_);
}
void Stream::read(Buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr){
void stream::read(buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr){
ContextSwitcher ctx_switch(context_);
if(blocking)
dispatch::cuMemcpyDtoH(ptr, buffer + offset, size);
@@ -87,7 +87,7 @@ void Stream::read(Buffer const & buffer, bool blocking, std::size_t offset, std:
dispatch::cuMemcpyDtoHAsync(ptr, buffer + offset, size, *cu_);
}
Handle<CUstream> const & Stream::cu() const
handle<CUstream> const & stream::cu() const
{ return cu_; }
}

151
lib/jit.cpp Normal file
View File

@@ -0,0 +1,151 @@
#include "triton/jit.h"
#include <string>
#include "triton/ast/ast.h"
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
extern YY_BUFFER_STATE yy_scan_string(const char * str);
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
using triton::ast::translation_unit;
extern translation_unit *ast_root;
namespace triton {
void jit::init_llvm() {
static bool init = false;
if(!init){
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
init = true;
}
}
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const std::vector<unsigned>& params) {
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
// create passes
codegen::buffer_info_pass buffer_info;
codegen::place_shared_copy shared(&buffer_info);
codegen::tune tune;
codegen::liveness liveness(&buffer_info);
codegen::allocation allocation(&liveness, &buffer_info);
codegen::barriers barriers(&allocation, &buffer_info);
codegen::vectorize vectorize(&tune);
codegen::selection selection(&allocation, &tune, &buffer_info);
// tuning parameters
tune.run(module);
unsigned i = 0;
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
for(unsigned *x: tune.get_params(module))
*x = params[3 + i++];
// constraints
std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);
std::cout << "errors: " << errors.size() << std::endl;
for(auto &x: errors){
for(auto &e: x.second)
std::cout << x.first->get_name() << " " << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// generate ptx
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
selection.run(module, *result);
return std::unique_ptr<llvm::Module>(result);
}
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
// create Triton-IR from AST
ir::module* module = new ir::module("matrix", triton_context_);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
jit::jit(driver::context context): driver_context_(context) {
}
std::string jit::compute_data_layout(bool is_64bit, bool use_short_pointers) {
std::string ret = "e";
if (!is_64bit)
ret += "-p:32:32";
else if (use_short_pointers)
ret += "-p3:32:32-p4:32:32-p5:32:32";
ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
return ret;
}
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
init_llvm();
auto ll_module = make_llvm_module(tt_module, params);
ll_module->setTargetTriple("nvptx64-nvidia-cuda");
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(ll_module->getTargetTriple(), error);
llvm::TargetMachine *machine = target->createTargetMachine(ll_module->getTargetTriple(), "sm_52", "",
llvm::TargetOptions(), llvm::Reloc::Model(),
llvm::None, llvm::CodeGenOpt::Aggressive);
ll_module->setDataLayout(compute_data_layout());
// emit machine code
llvm::legacy::PassManager pass;
llvm::SmallVector<char, 0> buffer;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
pass.run(*ll_module);
std::string src(buffer.begin(), buffer.end());
modules_.push_back(driver::module(driver_context_, src));
}
void jit::add_module(const std::string &src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(src);
add_module(*ptt_module, params);
}
driver::kernel jit::get_function(const std::string &name) {
return driver::kernel(modules_.front(), name.c_str());
}
}