some cleaning

This commit is contained in:
Philippe Tillet
2019-03-21 23:51:47 -07:00
parent 907bbb1ad2
commit 49fd6ece99
20 changed files with 423 additions and 182 deletions

View File

@@ -45,6 +45,12 @@ protected:
driver::context* context_;
};
// CPU
class cpu_buffer: public buffer
{
};
// OpenCL
class ocl_buffer: public buffer
{

View File

@@ -48,6 +48,11 @@ protected:
std::string cache_path_;
};
// CPU
class cpu_context: public context {
};
// CUDA
class cu_context: public context {
public:

View File

@@ -40,6 +40,11 @@ public:
using polymorphic_resource::polymorphic_resource;
};
// CPU device
class cpu_device: public device {
};
// OpenCL device
class ocl_device: public device {
public:

View File

@@ -49,6 +49,7 @@ template<class T> void check(T){}
void check(CUresult err);
void check(cublasStatus_t err);
void check(cudnnStatus_t err);
void check(cl_int err);
class dispatch
{
@@ -117,6 +118,7 @@ public:
static cl_int clGetKernelInfo(cl_kernel, cl_kernel_info, size_t, void *, size_t *);
static cl_int clGetKernelWorkGroupInfo(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
static cl_kernel clCreateKernel(cl_program, const char *, cl_int *);
static cl_int clCreateKernelsInProgram(cl_program, cl_uint, cl_kernel*, cl_uint*);
static cl_mem clCreateBuffer(cl_context, cl_mem_flags, size_t, void *, cl_int *);
static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
static cl_int clReleaseKernel(cl_kernel);
@@ -233,6 +235,7 @@ private:
static void* clGetKernelInfo_;
static void* clGetKernelWorkGroupInfo_;
static void* clCreateKernel_;
static void* clCreateKernelsInProgram_;
static void* clCreateBuffer_;
static void* clCreateProgramWithSource_;
static void* clReleaseKernel_;

View File

@@ -31,8 +31,8 @@ namespace triton
namespace driver
{
// Event
class Event
// event
class event
{
public:
float elapsed_time() const;

View File

@@ -35,6 +35,12 @@ namespace triton
namespace driver
{
enum backend_t {
CUDA,
OpenCL
};
// helpers for CUDA
struct cu_event_t{
operator bool() const { return first && second; }
CUevent first;
@@ -79,18 +85,20 @@ protected:
template<class CUType, class CLType>
class polymorphic_resource {
public:
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership){}
polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership){}
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership), backend_(OpenCL){}
virtual ~polymorphic_resource() { }
handle<CUType> cu() { return cu_; }
handle<CLType> cl() { return cl_; }
const handle<CUType>& cu() const { return cu_; }
const handle<CLType>& cl() const { return cl_; }
backend_t backend() { return backend_; }
protected:
handle<CLType> cl_;
handle<CUType> cu_;
backend_t backend_;
};
}

View File

@@ -53,6 +53,11 @@ private:
driver::module* program_;
};
// CPU
class cpu_kernel: public kernel {
};
// OpenCL
class ocl_kernel: public kernel {
public:

View File

@@ -56,12 +56,17 @@ public:
driver::context* context() const;
void compile_llvm_module(llvm::Module* module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer);
llvm::SmallVectorImpl<char> &buffer, std::vector<std::string> files = {});
protected:
driver::context* ctx_;
};
// CPU
class cpu_module: public module{
};
// OpenCL
class ocl_module: public module{

View File

@@ -50,6 +50,7 @@ private:
std::string name_;
};
// CUDA
class cu_platform: public platform
{
public:
@@ -61,6 +62,7 @@ private:
handle<CUPlatform> cu_;
};
// OpenCL
class cl_platform: public platform
{
public:
@@ -72,6 +74,15 @@ private:
handle<cl_platform_id> cl_;
};
// CPU
class cpu_platform: public platform
{
public:
cpu_platform(): platform("CPU") { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
};
}
}

View File

@@ -36,7 +36,7 @@ namespace driver
{
class kernel;
class Event;
class event;
class Range;
class cu_buffer;
@@ -51,7 +51,7 @@ public:
driver::context* context() const;
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL) = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -64,6 +64,11 @@ protected:
driver::context *ctx_;
};
// CPU
class cpu_stream: public stream {
};
// OpenCL
class cl_stream: public stream {
public:
@@ -72,7 +77,7 @@ public:
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event *event);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -86,7 +91,7 @@ public:
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event *event);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

View File

@@ -19,6 +19,59 @@ namespace codegen{
using namespace llvm;
inline void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) {
fn->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
// module->getOrInsertNamedMetadata("opencl.ocl.version")->addOperand(llvm::MDTuple::get(ctx, {llvm::ValueAsMetadata::get(builder.getInt32(2)), llvm::ValueAsMetadata::get(builder.getInt32(0))}));
// // set metadata
// llvm::Metadata *md_args[] = {
// llvm::ValueAsMetadata::get(dst_fn),
// llvm::MDString::get(dst_ctx, "kernel"),
// llvm::ValueAsMetadata::get(dst_builder.getInt32(1))
// };
// module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(dst_ctx, md_args));
}
inline Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) {
// Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
// return builder.CreateCall(barrier, {});
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {});
}
inline Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) {
// static std::array<Intrinsic::ID, 3> ctaid = {
// Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
// Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
// Intrinsic::nvvm_read_ptx_sreg_ctaid_z
// };
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
inline Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) {
// static std::array<Intrinsic::ID, 3> ids = {
// Intrinsic::nvvm_read_ptx_sreg_tid_x,
// Intrinsic::nvvm_read_ptx_sreg_tid_y,
// Intrinsic::nvvm_read_ptx_sreg_tid_z
// };
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
@@ -264,8 +317,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
}
if(dynamic_cast<ir::barrier_inst*>(inst)){
Module *module = builder.GetInsertBlock()->getModule();
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
return builder.CreateCall(barrier, {});
return add_barrier(module, builder);
}
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type()->get_scalar_ty());
@@ -561,9 +613,8 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
// fetch linear ID
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
Function *get_thread_id = Intrinsic::getDeclaration(mod, Intrinsic::nvvm_read_ptx_sreg_tid_x);
Value *warp_size = builder.getInt32(32);
Value *u_thread_id = builder.CreateCall(get_thread_id, {});
Value* u_thread_id = get_local_id(mod, builder, 0);
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
@@ -619,14 +670,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
const auto& shapes = ins->get_type()->get_tile_shapes();
// global_range
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
static std::array<Intrinsic::ID, 3> ctaid = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
Value *group_id = builder.CreateCall(get_group_id, {});
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]->get_value()), group_id);
Value *offset = get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis());
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin, offset));
@@ -739,27 +783,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
return;
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
// ir::value *A = ins->get_operand(0);
// ir::value *B = ins->get_operand(1);
// ir::value *C = ins->get_operand(2);
// shared_tile *TA = (shared_tile*)tmap_.at(A);
// shared_tile *TB = (shared_tile*)tmap_.at(B);
// distributed_tile *TC = (distributed_tile*)tmap_.at(C);
// TA->set_vector_size(TC->axis(0).contiguous);
// TB->set_vector_size(TC->axis(1).contiguous);
// Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
// result->for_each([&](indices_t idx){
// Value *res = TC->get_value(idx);
// unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
// for(unsigned K = 0; K < NK; ++K){
// indices_t a_idx = {idx[0], builder.getInt32(K)};
// indices_t b_idx = {idx[1], builder.getInt32(K)};
// Value *a = TA->get_value(a_idx);
// Value *b = TB->get_value(b_idx);
// res = builder.CreateCall(f_mul_add, {a, b, res});
// }
// result->set_value(idx, res);
// });
}
// element-wise
else {
@@ -805,7 +849,7 @@ ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx)
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
}
void selection::run(ir::module &src, Module &dst){
void selection::run(ir::module &src, Module &dst) {
vmap_.clear();
LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx);
@@ -825,13 +869,7 @@ void selection::run(ir::module &src, Module &dst){
for(ir::attribute_t attr: attr_pair.second)
dst_fn->addAttribute(id, llvm_attr(attr));
}
// set metadata
llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(dst_fn),
llvm::MDString::get(dst_ctx, "kernel"),
llvm::ValueAsMetadata::get(dst_builder.getInt32(1))
};
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(dst_ctx, md_args));
set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
// map parameters
for(unsigned i = 0; i < fn->args().size(); i++)
@@ -842,82 +880,83 @@ void selection::run(ir::module &src, Module &dst){
vmap_[block] = dst_block;
}
dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
// allocate shared memory
Value *sh_mem_ptr = nullptr;
if(unsigned alloc_size = alloc_->get_allocated_size()){
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
std::map<ir::basic_block*, BasicBlock*> last_block;
// iterate through block
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
BasicBlock *current = dst_builder.GetInsertBlock();
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
if(phi_inserted)
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
lower_instruction(i, dst_builder);
if(phi_inserted)
dst_builder.SetInsertPoint(current);
last_block[block] = dst_builder.GetInsertBlock();
}
}
// add phi operands
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
if(buffer_info_->is_double(phi)) {
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block* inc_block = phi->get_incoming_block(n);
ir::value* inc_val = phi->get_incoming_value(n);
ir::value* terminator = inc_block->get_inst_list().back();
BasicBlock *llvm_inc_block = last_block.at(inc_block);
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
if(is_loop_latch){
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
Value *next_offset = dst_builder.CreateNeg(offset);
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
}
else {
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = last_block.at(inc_block);
if(phi->get_type()->is_tile_ty()) {
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
phi_tile->for_each([&](indices_t idx){
PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
Value *llvm_inc_val = inc_tile->get_value(idx);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
}
}
}
}
dst_builder.CreateRetVoid();
// // allocate shared memory
// Value *sh_mem_ptr = nullptr;
// if(unsigned alloc_size = alloc_->get_allocated_size()){
// Type *int_8_ty = Type::getInt8Ty(dst_ctx);
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
// Type *ptr_ty = PointerType::get(int_8_ty, 3);
// GlobalVariable *sh_mem_array =
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
// }
// // create grids
// init_grids(fn, dst_builder, sh_mem_ptr);
// std::map<ir::basic_block*, BasicBlock*> last_block;
// // iterate through block
// for(ir::basic_block *block: fn->blocks()) {
// BasicBlock *parent = (BasicBlock*)vmap_[block];
// dst_builder.SetInsertPoint(parent);
// for(ir::instruction *i: block->get_inst_list()){
// BasicBlock *current = dst_builder.GetInsertBlock();
// bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
// if(phi_inserted)
// dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
// lower_instruction(i, dst_builder);
// if(phi_inserted)
// dst_builder.SetInsertPoint(current);
// last_block[block] = dst_builder.GetInsertBlock();
// }
// }
// // add phi operands
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction *inst: block->get_inst_list())
// if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
// if(buffer_info_->is_double(phi)) {
// PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
// PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
// for(unsigned n = 0; n < phi->get_num_incoming(); n++){
// ir::basic_block* inc_block = phi->get_incoming_block(n);
// ir::value* inc_val = phi->get_incoming_value(n);
// ir::value* terminator = inc_block->get_inst_list().back();
// BasicBlock *llvm_inc_block = last_block.at(inc_block);
// shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
// bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
// if(is_loop_latch){
// dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
// Value *next_offset = dst_builder.CreateNeg(offset);
// offset->addIncoming(next_offset, llvm_inc_block);
// }
// else {
// offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
// }
// ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
// }
// }
// else {
// for(unsigned n = 0; n < phi->get_num_incoming(); n++){
// ir::value *inc_val = phi->get_incoming_value(n);
// ir::basic_block *inc_block = phi->get_incoming_block(n);
// BasicBlock *llvm_inc_block = last_block.at(inc_block);
// if(phi->get_type()->is_tile_ty()) {
// distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
// distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
// phi_tile->for_each([&](indices_t idx){
// PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
// Value *llvm_inc_val = inc_tile->get_value(idx);
// llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
// });
// }
// else {
// PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
// Value *llvm_inc_val = vmap_.at(inc_val);
// llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
// }
// }
// }
// }
}
}

View File

@@ -47,11 +47,11 @@ driver::context* buffer::context() {
}
buffer* buffer::create(driver::context* ctx, size_t size) {
if(dynamic_cast<driver::cu_context*>(ctx))
return new cu_buffer(ctx, size);
if(dynamic_cast<driver::ocl_context*>(ctx))
return new ocl_buffer(ctx, size);
throw std::runtime_error("unknown context");
switch(ctx->backend()){
case CUDA: return new cu_buffer(ctx, size);
case OpenCL: return new ocl_buffer(ctx, size);
default: throw std::runtime_error("unknown backend");
}
}
//
@@ -59,7 +59,8 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
ocl_buffer::ocl_buffer(driver::context* context, size_t size)
: buffer(context, cl_mem(), true){
cl_int err;
dispatch::clCreateBuffer(*context->cl(), CL_MEM_READ_WRITE, size, NULL, &err);
*cl_ = dispatch::clCreateBuffer(*context->cl(), CL_MEM_READ_WRITE, size, NULL, &err);
check(err);
}

View File

@@ -51,11 +51,11 @@ context::context(driver::device *dev, cl_context cl, bool take_ownership):
}
context* context::create(driver::device *dev){
if(dynamic_cast<driver::cu_device*>(dev))
return new cu_context(dev);
if(dynamic_cast<driver::ocl_device*>(dev))
return new ocl_context(dev);
throw std::runtime_error("unknown context");
switch(dev->backend()){
case CUDA: return new cu_context(dev);
case OpenCL: return new ocl_context(dev);
default: throw std::runtime_error("unknown backend");
}
}
@@ -99,7 +99,7 @@ cu_context::context_switcher::context_switcher(const context &ctx): ctx_((const
cu_context::context_switcher::~context_switcher() {
CUcontext tmp;
dispatch::cuCtxPopCurrent_v2(&tmp);
assert(tmp==(CUcontext)ctx_ && "Switching back to invalid context!");
assert(tmp==*ctx_.cu() && "Switching back to invalid context!");
}
// import CUdevice
@@ -129,6 +129,7 @@ cu_context::cu_context(driver::device* device): context(device, CUcontext(), tru
ocl_context::ocl_context(driver::device* dev): context(dev, cl_context(), true) {
cl_int err;
*cl_ = dispatch::clCreateContext(nullptr, 1, &*dev->cl(), nullptr, nullptr, &err);
check(err);
}

View File

@@ -286,6 +286,7 @@ OCL_DEFINE5(cl_int, clGetProgramInfo, cl_program, cl_program_info, size_t, void
OCL_DEFINE5(cl_int, clGetKernelInfo, cl_kernel, cl_kernel_info, size_t, void *, size_t *)
OCL_DEFINE6(cl_int, clGetKernelWorkGroupInfo, cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *)
OCL_DEFINE3(cl_kernel, clCreateKernel, cl_program, const char *, cl_int *)
OCL_DEFINE4(cl_int, clCreateKernelsInProgram, cl_program, cl_uint, cl_kernel*, cl_uint*)
OCL_DEFINE5(cl_mem, clCreateBuffer, cl_context, cl_mem_flags, size_t, void *, cl_int *)
OCL_DEFINE5(cl_program, clCreateProgramWithSource, cl_context, cl_uint, const char **, const size_t *, cl_int *)
OCL_DEFINE1(cl_int, clReleaseKernel, cl_kernel)
@@ -343,6 +344,7 @@ void* dispatch::clGetProgramInfo_;
void* dispatch::clGetKernelInfo_;
void* dispatch::clGetKernelWorkGroupInfo_;
void* dispatch::clCreateKernel_;
void* dispatch::clCreateKernelsInProgram_;
void* dispatch::clCreateBuffer_;
void* dispatch::clCreateProgramWithSource_;
void* dispatch::clReleaseKernel_;

View File

@@ -133,6 +133,67 @@ void check(cudnnStatus_t err){
}
}
void check(cl_int err)
{
using namespace exception::ocl;
switch(err)
{
case CL_SUCCESS: break;
case CL_DEVICE_NOT_FOUND: throw device_not_found();
case CL_DEVICE_NOT_AVAILABLE: throw device_not_available();
case CL_COMPILER_NOT_AVAILABLE: throw compiler_not_available();
case CL_MEM_OBJECT_ALLOCATION_FAILURE: throw mem_object_allocation_failure();
case CL_OUT_OF_RESOURCES: throw out_of_resources();
case CL_OUT_OF_HOST_MEMORY: throw out_of_host_memory();
case CL_PROFILING_INFO_NOT_AVAILABLE: throw profiling_info_not_available();
case CL_MEM_COPY_OVERLAP: throw mem_copy_overlap();
case CL_IMAGE_FORMAT_MISMATCH: throw image_format_mismatch();
case CL_IMAGE_FORMAT_NOT_SUPPORTED: throw image_format_not_supported();
case CL_BUILD_PROGRAM_FAILURE: throw build_program_failure();
case CL_MAP_FAILURE: throw map_failure();
case CL_INVALID_VALUE: throw invalid_value();
case CL_INVALID_DEVICE_TYPE: throw invalid_device_type();
case CL_INVALID_PLATFORM: throw invalid_platform();
case CL_INVALID_DEVICE: throw invalid_device();
case CL_INVALID_CONTEXT: throw invalid_context();
case CL_INVALID_QUEUE_PROPERTIES: throw invalid_queue_properties();
case CL_INVALID_COMMAND_QUEUE: throw invalid_command_queue();
case CL_INVALID_HOST_PTR: throw invalid_host_ptr();
case CL_INVALID_MEM_OBJECT: throw invalid_mem_object();
case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: throw invalid_image_format_descriptor();
case CL_INVALID_IMAGE_SIZE: throw invalid_image_size();
case CL_INVALID_SAMPLER: throw invalid_sampler();
case CL_INVALID_BINARY: throw invalid_binary();
case CL_INVALID_BUILD_OPTIONS: throw invalid_build_options();
case CL_INVALID_PROGRAM: throw invalid_program();
case CL_INVALID_PROGRAM_EXECUTABLE: throw invalid_program_executable();
case CL_INVALID_KERNEL_NAME: throw invalid_kernel_name();
case CL_INVALID_KERNEL_DEFINITION: throw invalid_kernel_definition();
case CL_INVALID_KERNEL: throw invalid_kernel();
case CL_INVALID_ARG_INDEX: throw invalid_arg_index();
case CL_INVALID_ARG_VALUE: throw invalid_arg_value();
case CL_INVALID_ARG_SIZE: throw invalid_arg_size();
case CL_INVALID_KERNEL_ARGS: throw invalid_kernel_args();
case CL_INVALID_WORK_DIMENSION: throw invalid_work_dimension();
case CL_INVALID_WORK_GROUP_SIZE: throw invalid_work_group_size();
case CL_INVALID_WORK_ITEM_SIZE: throw invalid_work_item_size();
case CL_INVALID_GLOBAL_OFFSET: throw invalid_global_offset();
case CL_INVALID_EVENT_WAIT_LIST: throw invalid_event_wait_list();
case CL_INVALID_EVENT: throw invalid_event();
case CL_INVALID_OPERATION: throw invalid_operation();
case CL_INVALID_GL_OBJECT: throw invalid_gl_object();
case CL_INVALID_BUFFER_SIZE: throw invalid_buffer_size();
case CL_INVALID_MIP_LEVEL: throw invalid_mip_level();
case CL_INVALID_GLOBAL_WORK_SIZE: throw invalid_global_work_size();
#ifdef CL_INVALID_PROPERTY
case CL_INVALID_PROPERTY: throw invalid_property();
#endif
default: throw;
}
}
}
}

View File

@@ -27,13 +27,13 @@ namespace triton
namespace driver
{
float Event::elapsed_time() const{
float event::elapsed_time() const{
float time;
dispatch::cuEventElapsedTime(&time, cu_->first, cu_->second);
return time;
}
handle<cu_event_t> const & Event::cu() const
handle<cu_event_t> const & event::cu() const
{ return cu_; }
}

View File

@@ -46,11 +46,11 @@ kernel::kernel(driver::module *program, cl_kernel fn, bool has_ownership):
}
kernel* kernel::create(driver::module* program, const char* name) {
if(dynamic_cast<driver::cu_module*>(program))
return new cu_kernel(program, name);
if(dynamic_cast<driver::ocl_module*>(program))
return new ocl_kernel(program, name);
throw std::runtime_error("unknown program");
switch(program->backend()){
case CUDA: return new cu_kernel(program, name);
case OpenCL: return new ocl_kernel(program, name);
default: throw std::runtime_error("unknown backend");
}
}
driver::module* kernel::module() {
@@ -62,16 +62,21 @@ driver::module* kernel::module() {
/* ------------------------ */
ocl_kernel::ocl_kernel(driver::module* program, const char* name): kernel(program, cl_kernel(), true) {
// cl_uint res;
// check(dispatch::clCreateKernelsInProgram(*program->cl(), 0, NULL, &res));
// std::cout << res << std::endl;
cl_int err;
*cl_ = dispatch::clCreateKernel(*program->cl(), name, &err);
std::cout << *program->cl() << std::endl;
*cl_ = dispatch::clCreateKernel(*program->cl(), "matmul", &err);
check(err);
}
void ocl_kernel::setArg(unsigned int index, std::size_t size, void* ptr) {
dispatch::clSetKernelArg(*cl_, index, size, ptr);
check(dispatch::clSetKernelArg(*cl_, index, size, ptr));
}
void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) {
dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl());
check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
}

View File

@@ -28,9 +28,15 @@
#include "triton/driver/error.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Linker/Linker.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
@@ -75,38 +81,60 @@ driver::context* module::context() const {
}
module* module::create(driver::context* ctx, llvm::Module *src) {
if(dynamic_cast<driver::cu_context*>(ctx))
return new cu_module(ctx, src);
if(dynamic_cast<driver::ocl_context*>(ctx))
return new ocl_module(ctx, src);
throw std::runtime_error("unknown context");
switch(ctx->backend()){
case CUDA: return new cu_module(ctx, src);
case OpenCL: return new ocl_module(ctx, src);
default: throw std::runtime_error("unknown backend");
}
}
void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer) {
llvm::SmallVectorImpl<char> &buffer,
std::vector<std::string> files) {
init_llvm();
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "",
llvm::TargetOptions(), llvm::Reloc::Model(),
llvm::None, llvm::CodeGenOpt::Aggressive);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "", opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
layout = module->getDataLayoutStr();
module->setDataLayout(layout);
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// link
for (std::string& file: files) {
std::string path = "/opt/rocm/lib/" + file;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext());
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
std::cerr << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(module->getTargetTriple());
mlib->setDataLayout(module->getDataLayout());
for (llvm::Function &f : mlib->functions()) {
f.addFnAttr(llvm::Attribute::AlwaysInline);
}
llvm::Linker::linkModules(*module, std::move(mlib));
}
std::cout << "compiling" << std::endl;
// emit machine code
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_ObjectFile);
pass.run(*module);
std::cout << "compiled" << std::endl;
// std::cout << std::string(buffer.begin(), buffer.end()) << std::endl;
}
/* ------------------------ */
@@ -114,9 +142,56 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
/* ------------------------ */
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
// const char* x = "__kernel void matmul(){ }";
// cl_int err;
// *cl_ = dispatch::clCreateProgramWithSource(*context->cl(), 1, &x, NULL, &err);
// check(err);
// return;
init_llvm();
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer);
std::vector<std::string> files = {
"oclc_daz_opt_on.amdgcn.bc",
"ocml.amdgcn.bc",
"hc.amdgcn.bc",
"ockl.amdgcn.bc",
"oclc_correctly_rounded_sqrt_off.amdgcn.bc",
"oclc_correctly_rounded_sqrt_on.amdgcn.bc",
"oclc_daz_opt_off.amdgcn.bc",
"oclc_finite_only_off.amdgcn.bc",
"oclc_finite_only_on.amdgcn.bc",
"oclc_isa_version_803.amdgcn.bc",
"oclc_isa_version_900.amdgcn.bc",
"oclc_unsafe_math_off.amdgcn.bc",
"oclc_unsafe_math_on.amdgcn.bc",
"oclc_isa_version_700.amdgcn.bc",
"opencl.amdgcn.bc"
};
module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer, files);
// llvm::BitcodeWriter writer(buffer);
// writer.writeModule(*src);
// llvm::legacy::PassManager pass;
// llvm::raw_svector_ostream stream(buffer);
// pass.add(llvm::createPrintModulePass(stream));
// pass.run(*src);
size_t sizes[] = {buffer.size()};
const unsigned char* data[] = {(unsigned char*)buffer.data()};
cl_int status;
cl_int err;
*cl_ = dispatch::clCreateProgramWithBinary(*context->cl(), 1, &*context->device()->cl(), sizes, data, &status, &err);
check(err);
check(status);
try{
dispatch::clBuildProgram(*cl_, 1, &*context->device()->cl(), NULL, NULL, NULL);
}
catch(...){
char log[2048];
dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
std::cout << log << std::endl;
}
}

View File

@@ -58,20 +58,26 @@ void cu_platform::devices(std::vector<device *> &devices) const{
std::string cl_platform::version() const {
size_t size;
dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, 0, nullptr, &size);
check(dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, 0, nullptr, &size));
std::string result(size, 0);
dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, size, (void*)&*result.begin(), nullptr);
check(dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, size, (void*)&*result.begin(), nullptr));
return result;
}
void cl_platform::devices(std::vector<device*> &devices) const{
cl_uint num_devices;
dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices);
check(dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices));
std::vector<cl_device_id> ids(num_devices);
dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, num_devices, ids.data(), nullptr);
check(dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, num_devices, ids.data(), nullptr));
for(cl_device_id id: ids)
devices.push_back(new driver::ocl_device(id));
}
/* ------------------------ */
// Vulkan //
/* ------------------------ */
}
}

View File

@@ -44,23 +44,20 @@ namespace driver
stream::stream(driver::context *ctx, CUstream cu, bool has_ownership)
: polymorphic_resource(cu, has_ownership), ctx_(ctx) {
}
stream::stream(driver::context *ctx, cl_command_queue cl, bool has_ownership)
: polymorphic_resource(cl, has_ownership), ctx_(ctx) {
}
driver::stream* stream::create(driver::context* ctx) {
if(dynamic_cast<driver::cu_context*>(ctx))
return new cu_stream(ctx);
if(dynamic_cast<driver::ocl_context*>(ctx))
return new cl_stream(ctx);
throw std::runtime_error("unknown context");
switch(ctx->backend()){
case CUDA: return new cu_stream(ctx);
case OpenCL: return new cl_stream(ctx);
default: throw std::runtime_error("unknown backend");
}
}
driver::context* stream::context() const {
return ctx_;
}
@@ -73,22 +70,23 @@ driver::context* stream::context() const {
cl_stream::cl_stream(driver::context *ctx): stream(ctx, cl_command_queue(), true) {
cl_int err;
*cl_ = dispatch::clCreateCommandQueue(*ctx->cl(), *ctx->device()->cl(), 0, &err);
check(err);
}
void cl_stream::synchronize() {
dispatch::clFinish(*cl_);
check(dispatch::clFinish(*cl_));
}
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event) {
cl_int err = dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL);
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL));
}
void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
cl_int err = dispatch::clEnqueueWriteBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL);
check(dispatch::clEnqueueWriteBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL));
}
void cl_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
cl_int err = dispatch::clEnqueueReadBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL);
check(dispatch::clEnqueueReadBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL));
}
/* ------------------------ */
@@ -115,7 +113,7 @@ void cu_stream::synchronize() {
dispatch::cuStreamSynchronize(*cu_);
}
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event) {
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
driver::cu_kernel* cu_kernel = (driver::cu_kernel*)kernel;
cu_context::context_switcher ctx_switch(*ctx_);
if(event)