[jit] added runtime for host but compilation still needs to be implemented

This commit is contained in:
Philippe Tillet
2019-03-23 13:40:42 -07:00
parent 49fd6ece99
commit 9de9feff4a
21 changed files with 389 additions and 234 deletions

View File

@@ -24,6 +24,7 @@ namespace codegen{
class allocation; class allocation;
class tune; class tune;
class buffer_info_pass; class buffer_info_pass;
class target;
typedef std::vector<llvm::Value*> indices_t; typedef std::vector<llvm::Value*> indices_t;
@@ -128,7 +129,9 @@ private:
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder); void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
public: public:
selection(allocation *alloc, tune *params, buffer_info_pass *buffer_info): alloc_(alloc), params_(params), buffer_info_(buffer_info){ } selection(allocation *alloc, tune *params, buffer_info_pass *buffer_info, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ }
void run(ir::module &src, llvm::Module &dst); void run(ir::module &src, llvm::Module &dst);
private: private:
@@ -138,6 +141,7 @@ private:
pmap_t last_block_; pmap_t last_block_;
allocation *alloc_; allocation *alloc_;
tune *params_; tune *params_;
target *tgt_;
buffer_info_pass *buffer_info_; buffer_info_pass *buffer_info_;
std::map<ir::metaparameter*, distributed_axis> axes_; std::map<ir::metaparameter*, distributed_axis> axes_;
}; };

View File

@@ -17,8 +17,6 @@ namespace ir{
namespace codegen{ namespace codegen{
class place_shared_copy;
class tune { class tune {
typedef std::pair<ir::value*, unsigned> node_t; typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t; typedef std::map <node_t, std::set<node_t>> graph_t;

View File

@@ -34,10 +34,11 @@ namespace driver
class cu_stream; class cu_stream;
// Base // Base
class buffer : public polymorphic_resource<CUdeviceptr, cl_mem> { class buffer : public polymorphic_resource<CUdeviceptr, cl_mem, host_buffer_t> {
public: public:
buffer(driver::context* ctx, CUdeviceptr cl, bool take_ownership); buffer(driver::context* ctx, CUdeviceptr cl, bool take_ownership);
buffer(driver::context* ctx, cl_mem cl, bool take_ownership); buffer(driver::context* ctx, cl_mem cl, bool take_ownership);
buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership);
static buffer* create(driver::context* ctx, size_t size); static buffer* create(driver::context* ctx, size_t size);
driver::context* context(); driver::context* context();
@@ -46,9 +47,10 @@ protected:
}; };
// CPU // CPU
class cpu_buffer: public buffer class host_buffer: public buffer
{ {
public:
host_buffer(driver::context* context, size_t size);
}; };
// OpenCL // OpenCL

View File

@@ -31,13 +31,14 @@ namespace triton
namespace driver namespace driver
{ {
class context: public polymorphic_resource<CUcontext, cl_context>{ class context: public polymorphic_resource<CUcontext, cl_context, host_context_t>{
protected: protected:
static std::string get_cache_path(); static std::string get_cache_path();
public: public:
context(driver::device *dev, CUcontext cu, bool take_ownership); context(driver::device *dev, CUcontext cu, bool take_ownership);
context(driver::device *dev, cl_context cl, bool take_ownership); context(driver::device *dev, cl_context cl, bool take_ownership);
context(driver::device *dev, host_context_t hst, bool take_ownership);
driver::device* device() const; driver::device* device() const;
std::string const & cache_path() const; std::string const & cache_path() const;
// factory methods // factory methods
@@ -48,9 +49,10 @@ protected:
std::string cache_path_; std::string cache_path_;
}; };
// CPU // Host
class cpu_context: public context { class host_context: public context {
public:
host_context(driver::device* dev);
}; };
// CUDA // CUDA

View File

@@ -35,14 +35,15 @@ namespace driver
class context; class context;
// Base device // Base device
class device: public polymorphic_resource<CUdevice, cl_device_id>{ class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
public: public:
using polymorphic_resource::polymorphic_resource; using polymorphic_resource::polymorphic_resource;
}; };
// CPU device // Host device
class cpu_device: public device { class host_device: public device {
public:
host_device(): device(host_device_t(), true){ }
}; };
// OpenCL device // OpenCL device

View File

@@ -24,11 +24,18 @@
#define TDL_INCLUDE_DRIVER_HANDLE_H #define TDL_INCLUDE_DRIVER_HANDLE_H
#include <memory> #include <memory>
#include <map>
#include <iostream> #include <iostream>
#include <functional> #include <functional>
#include <type_traits> #include <type_traits>
#include "triton/driver/dispatch.h" #include "triton/driver/dispatch.h"
namespace llvm
{
class ExecutionEngine;
class Function;
}
namespace triton namespace triton
{ {
@@ -37,10 +44,43 @@ namespace driver
enum backend_t { enum backend_t {
CUDA, CUDA,
OpenCL OpenCL,
Host
}; };
// helpers for CUDA // Host handles
struct host_platform_t{
};
struct host_device_t{
};
struct host_context_t{
};
struct host_stream_t{
};
struct host_module_t{
std::string error;
llvm::ExecutionEngine* engine;
std::map<std::string, llvm::Function*> functions;
};
struct host_function_t{
llvm::Function* fn;
};
struct host_buffer_t{
char* data;
};
// Extra CUDA handles
struct cu_event_t{ struct cu_event_t{
operator bool() const { return first && second; } operator bool() const { return first && second; }
CUevent first; CUevent first;
@@ -82,22 +122,26 @@ protected:
bool has_ownership_; bool has_ownership_;
}; };
template<class CUType, class CLType> template<class CUType, class CLType, class HostType>
class polymorphic_resource { class polymorphic_resource {
public: public:
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){} polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership), backend_(OpenCL){} polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership), backend_(OpenCL){}
polymorphic_resource(HostType hst, bool take_ownership): hst_(hst, take_ownership), backend_(Host){}
virtual ~polymorphic_resource() { } virtual ~polymorphic_resource() { }
handle<CUType> cu() { return cu_; } handle<CUType> cu() { return cu_; }
handle<CLType> cl() { return cl_; } handle<CLType> cl() { return cl_; }
handle<HostType> hst() { return hst_; }
const handle<CUType>& cu() const { return cu_; } const handle<CUType>& cu() const { return cu_; }
const handle<CLType>& cl() const { return cl_; } const handle<CLType>& cl() const { return cl_; }
const handle<HostType>& hst() const { return hst_; }
backend_t backend() { return backend_; } backend_t backend() { return backend_; }
protected: protected:
handle<CLType> cl_; handle<CLType> cl_;
handle<CUType> cu_; handle<CUType> cu_;
handle<HostType> hst_;
backend_t backend_; backend_t backend_;
}; };

View File

@@ -28,6 +28,11 @@
#include <memory> #include <memory>
namespace llvm
{
class GenericValue;
}
namespace triton namespace triton
{ {
@@ -37,10 +42,11 @@ namespace driver
class cu_buffer; class cu_buffer;
// Base // Base
class kernel: public polymorphic_resource<CUfunction, cl_kernel> { class kernel: public polymorphic_resource<CUfunction, cl_kernel, host_function_t> {
public: public:
kernel(driver::module* program, CUfunction fn, bool has_ownership); kernel(driver::module* program, CUfunction fn, bool has_ownership);
kernel(driver::module* program, cl_kernel fn, bool has_ownership); kernel(driver::module* program, cl_kernel fn, bool has_ownership);
kernel(driver::module* program, host_function_t fn, bool has_ownership);
// Getters // Getters
driver::module* module(); driver::module* module();
// Factory methods // Factory methods
@@ -53,9 +59,19 @@ private:
driver::module* program_; driver::module* program_;
}; };
// CPU // Host
class cpu_kernel: public kernel { class host_kernel: public kernel {
public:
//Constructors
host_kernel(driver::module* program, const char* name);
// Arguments setters
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, driver::buffer* buffer);
// Params
const std::vector<llvm::GenericValue>& params();
private:
std::vector<std::shared_ptr<void> > params_store_;
std::vector<llvm::GenericValue> params_;
}; };
// OpenCL // OpenCL
@@ -81,8 +97,6 @@ public:
void* const* cu_params() const; void* const* cu_params() const;
private: private:
handle<CUfunction> cu_;
driver::cu_module* program_;
std::vector<std::shared_ptr<void> > cu_params_store_; std::vector<std::shared_ptr<void> > cu_params_store_;
std::vector<void*> cu_params_; std::vector<void*> cu_params_;
}; };

View File

@@ -45,13 +45,14 @@ class cu_context;
class cu_device; class cu_device;
// Base // Base
class module: public polymorphic_resource<CUmodule, cl_program> { class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
protected: protected:
void init_llvm(); void init_llvm();
public: public:
module(driver::context* ctx, CUmodule mod, bool has_ownership); module(driver::context* ctx, CUmodule mod, bool has_ownership);
module(driver::context* ctx, cl_program mod, bool has_ownership); module(driver::context* ctx, cl_program mod, bool has_ownership);
module(driver::context* ctx, host_module_t mod, bool has_ownership);
static module* create(driver::context* ctx, llvm::Module *src); static module* create(driver::context* ctx, llvm::Module *src);
driver::context* context() const; driver::context* context() const;
void compile_llvm_module(llvm::Module* module, const std::string& triple, void compile_llvm_module(llvm::Module* module, const std::string& triple,
@@ -63,8 +64,9 @@ protected:
}; };
// CPU // CPU
class cpu_module: public module{ class host_module: public module{
public:
host_module(driver::context* context, llvm::Module *module);
}; };
// OpenCL // OpenCL

View File

@@ -74,11 +74,11 @@ private:
handle<cl_platform_id> cl_; handle<cl_platform_id> cl_;
}; };
// CPU // Host
class cpu_platform: public platform class host_platform: public platform
{ {
public: public:
cpu_platform(): platform("CPU") { } host_platform(): platform("CPU") { }
std::string version() const; std::string version() const;
void devices(std::vector<driver::device*> &devices) const; void devices(std::vector<driver::device*> &devices) const;
}; };

View File

@@ -41,10 +41,11 @@ class Range;
class cu_buffer; class cu_buffer;
// Base // Base
class stream: public polymorphic_resource<CUstream, cl_command_queue> { class stream: public polymorphic_resource<CUstream, cl_command_queue, host_stream_t> {
public: public:
stream(driver::context *ctx, CUstream, bool has_ownership); stream(driver::context *ctx, CUstream, bool has_ownership);
stream(driver::context *ctx, cl_command_queue, bool has_ownership); stream(driver::context *ctx, cl_command_queue, bool has_ownership);
stream(driver::context *ctx, host_stream_t, bool has_ownership);
// factory // factory
static driver::stream* create(driver::context* ctx); static driver::stream* create(driver::context* ctx);
// accessors // accessors
@@ -64,9 +65,17 @@ protected:
driver::context *ctx_; driver::context *ctx_;
}; };
// CPU // Host
class cpu_stream: public stream { class host_stream: public stream {
public:
// Constructors
host_stream(driver::context *ctx);
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
}; };
// OpenCL // OpenCL

View File

@@ -15,6 +15,7 @@
#include "triton/codegen/vectorize.h" #include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h" #include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h" #include "triton/codegen/barriers.h"
#include "triton/codegen/target.h"
#include <functional> #include <functional>
namespace llvm { namespace llvm {
@@ -42,11 +43,12 @@ public:
typedef std::function<double(driver::kernel*, launch_information)> benchmark_t; typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
struct passes_wrapper { struct passes_wrapper {
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info), passes_wrapper(codegen::target* target)
: shared(&buffer_info), liveness(&buffer_info),
allocation(&liveness, &buffer_info), allocation(&liveness, &buffer_info),
barriers(&allocation, &buffer_info), barriers(&allocation, &buffer_info),
vectorize(&tune), vectorize(&tune),
selection(&allocation, &tune, &buffer_info){ } selection(&allocation, &tune, &buffer_info, target) { }
void init(ir::module &module) { void init(ir::module &module) {
// generate ptx // generate ptx
@@ -89,6 +91,7 @@ private:
ir::context triton_context_; ir::context triton_context_;
std::map<std::string, launch_information> launch_info_map_; std::map<std::string, launch_information> launch_info_map_;
std::map<std::string, unsigned> global_ints_; std::map<std::string, unsigned> global_ints_;
std::unique_ptr<triton::codegen::target> target_;
}; };

View File

@@ -1,6 +1,7 @@
#include "triton/codegen/selection.h" #include "triton/codegen/selection.h"
#include "triton/codegen/tune.h" #include "triton/codegen/tune.h"
#include "triton/codegen/allocation.h" #include "triton/codegen/allocation.h"
#include "triton/codegen/target.h"
#include "llvm/IR/InstrTypes.h" #include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
@@ -19,59 +20,6 @@ namespace codegen{
using namespace llvm; using namespace llvm;
inline void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) {
fn->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
// module->getOrInsertNamedMetadata("opencl.ocl.version")->addOperand(llvm::MDTuple::get(ctx, {llvm::ValueAsMetadata::get(builder.getInt32(2)), llvm::ValueAsMetadata::get(builder.getInt32(0))}));
// // set metadata
// llvm::Metadata *md_args[] = {
// llvm::ValueAsMetadata::get(dst_fn),
// llvm::MDString::get(dst_ctx, "kernel"),
// llvm::ValueAsMetadata::get(dst_builder.getInt32(1))
// };
// module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(dst_ctx, md_args));
}
inline Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) {
// Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
// return builder.CreateCall(barrier, {});
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {});
}
inline Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) {
// static std::array<Intrinsic::ID, 3> ctaid = {
// Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
// Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
// Intrinsic::nvvm_read_ptx_sreg_ctaid_z
// };
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
inline Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) {
// static std::array<Intrinsic::ID, 3> ids = {
// Intrinsic::nvvm_read_ptx_sreg_tid_x,
// Intrinsic::nvvm_read_ptx_sreg_tid_y,
// Intrinsic::nvvm_read_ptx_sreg_tid_z
// };
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
/* Distributed Tile */ /* Distributed Tile */
void distributed_tile::init_indices() { void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0); std::vector<size_t> id(axes_.size(), 0);
@@ -317,7 +265,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
} }
if(dynamic_cast<ir::barrier_inst*>(inst)){ if(dynamic_cast<ir::barrier_inst*>(inst)){
Module *module = builder.GetInsertBlock()->getModule(); Module *module = builder.GetInsertBlock()->getModule();
return add_barrier(module, builder); return tgt_->add_barrier(module, builder);
} }
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){ if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type()->get_scalar_ty()); Type *ty = type(ii->get_type()->get_scalar_ty());
@@ -614,7 +562,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
// fetch linear ID // fetch linear ID
Module *mod = builder.GetInsertBlock()->getParent()->getParent(); Module *mod = builder.GetInsertBlock()->getParent()->getParent();
Value *warp_size = builder.getInt32(32); Value *warp_size = builder.getInt32(32);
Value* u_thread_id = get_local_id(mod, builder, 0); Value* u_thread_id = tgt_->get_local_id(mod, builder, 0);
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size); Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid // create grid
@@ -670,7 +618,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
const auto& shapes = ins->get_type()->get_tile_shapes(); const auto& shapes = ins->get_type()->get_tile_shapes();
// global_range // global_range
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) { if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
Value *offset = get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis()); Value *offset = tgt_->get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis());
result->for_each([&](indices_t idx){ result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]); BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin, offset)); result->set_value(idx, builder.CreateAdd(bin, offset));
@@ -783,27 +731,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
return; return;
// matrix multiplication // matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) { else if(dynamic_cast<ir::matmul_inst*>(ins)) {
// ir::value *A = ins->get_operand(0); ir::value *A = ins->get_operand(0);
// ir::value *B = ins->get_operand(1); ir::value *B = ins->get_operand(1);
// ir::value *C = ins->get_operand(2); ir::value *C = ins->get_operand(2);
// shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TA = (shared_tile*)tmap_.at(A);
// shared_tile *TB = (shared_tile*)tmap_.at(B); shared_tile *TB = (shared_tile*)tmap_.at(B);
// distributed_tile *TC = (distributed_tile*)tmap_.at(C); distributed_tile *TC = (distributed_tile*)tmap_.at(C);
// TA->set_vector_size(TC->axis(0).contiguous); TA->set_vector_size(TC->axis(0).contiguous);
// TB->set_vector_size(TC->axis(1).contiguous); TB->set_vector_size(TC->axis(1).contiguous);
// Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)}); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
// result->for_each([&](indices_t idx){ result->for_each([&](indices_t idx){
// Value *res = TC->get_value(idx); Value *res = TC->get_value(idx);
// unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value(); unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
// for(unsigned K = 0; K < NK; ++K){ for(unsigned K = 0; K < NK; ++K){
// indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t a_idx = {idx[0], builder.getInt32(K)};
// indices_t b_idx = {idx[1], builder.getInt32(K)}; indices_t b_idx = {idx[1], builder.getInt32(K)};
// Value *a = TA->get_value(a_idx); Value *a = TA->get_value(a_idx);
// Value *b = TB->get_value(b_idx); Value *b = TB->get_value(b_idx);
// res = builder.CreateCall(f_mul_add, {a, b, res}); res = builder.CreateCall(f_mul_add, {a, b, res});
// } }
// result->set_value(idx, res); result->set_value(idx, res);
// }); });
} }
// element-wise // element-wise
else { else {
@@ -869,7 +817,7 @@ void selection::run(ir::module &src, Module &dst) {
for(ir::attribute_t attr: attr_pair.second) for(ir::attribute_t attr: attr_pair.second)
dst_fn->addAttribute(id, llvm_attr(attr)); dst_fn->addAttribute(id, llvm_attr(attr));
} }
set_kernel(dst_builder, dst_ctx, &dst, dst_fn); tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
// map parameters // map parameters
for(unsigned i = 0; i < fn->args().size(); i++) for(unsigned i = 0; i < fn->args().size(); i++)
@@ -880,83 +828,86 @@ void selection::run(ir::module &src, Module &dst) {
vmap_[block] = dst_block; vmap_[block] = dst_block;
} }
dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]); dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
dst_builder.CreateRetVoid();
// // allocate shared memory // allocate shared memory
// Value *sh_mem_ptr = nullptr; Value *sh_mem_ptr = nullptr;
// if(unsigned alloc_size = alloc_->get_allocated_size()){ if(unsigned alloc_size = alloc_->get_allocated_size()){
// Type *int_8_ty = Type::getInt8Ty(dst_ctx); Type *int_8_ty = Type::getInt8Ty(dst_ctx);
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
// Type *ptr_ty = PointerType::get(int_8_ty, 3); Type *ptr_ty = PointerType::get(int_8_ty, 3);
// GlobalVariable *sh_mem_array = GlobalVariable *sh_mem_array =
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage, new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty); sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
// } }
// // create grids
// init_grids(fn, dst_builder, sh_mem_ptr); // create grids
// std::map<ir::basic_block*, BasicBlock*> last_block; init_grids(fn, dst_builder, sh_mem_ptr);
// // iterate through block
// for(ir::basic_block *block: fn->blocks()) { // iterate through block
// BasicBlock *parent = (BasicBlock*)vmap_[block]; std::map<ir::basic_block*, BasicBlock*> last_block;
// dst_builder.SetInsertPoint(parent); for(ir::basic_block *block: fn->blocks()) {
// for(ir::instruction *i: block->get_inst_list()){ BasicBlock *parent = (BasicBlock*)vmap_[block];
// BasicBlock *current = dst_builder.GetInsertBlock(); dst_builder.SetInsertPoint(parent);
// bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty(); for(ir::instruction *i: block->get_inst_list()){
// if(phi_inserted) BasicBlock *current = dst_builder.GetInsertBlock();
// dst_builder.SetInsertPoint(&*current->getFirstInsertionPt()); bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
// lower_instruction(i, dst_builder); if(phi_inserted)
// if(phi_inserted) dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
// dst_builder.SetInsertPoint(current); lower_instruction(i, dst_builder);
// last_block[block] = dst_builder.GetInsertBlock(); if(phi_inserted)
// } dst_builder.SetInsertPoint(current);
// } last_block[block] = dst_builder.GetInsertBlock();
// // add phi operands }
// for(ir::basic_block *block: fn->blocks()) }
// for(ir::instruction *inst: block->get_inst_list())
// if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){ // add phi operands
// if(buffer_info_->is_double(phi)) { for(ir::basic_block *block: fn->blocks())
// PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); for(ir::instruction *inst: block->get_inst_list())
// PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
// for(unsigned n = 0; n < phi->get_num_incoming(); n++){ if(buffer_info_->is_double(phi)) {
// ir::basic_block* inc_block = phi->get_incoming_block(n); PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
// ir::value* inc_val = phi->get_incoming_value(n); PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
// ir::value* terminator = inc_block->get_inst_list().back(); for(unsigned n = 0; n < phi->get_num_incoming(); n++){
// BasicBlock *llvm_inc_block = last_block.at(inc_block); ir::basic_block* inc_block = phi->get_incoming_block(n);
// shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); ir::value* inc_val = phi->get_incoming_value(n);
// bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); ir::value* terminator = inc_block->get_inst_list().back();
// if(is_loop_latch){ BasicBlock *llvm_inc_block = last_block.at(inc_block);
// dst_builder.SetInsertPoint(llvm_inc_block->getTerminator()); shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
// Value *next_offset = dst_builder.CreateNeg(offset); bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
// offset->addIncoming(next_offset, llvm_inc_block); if(is_loop_latch){
// } dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
// else { Value *next_offset = dst_builder.CreateNeg(offset);
// offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block); offset->addIncoming(next_offset, llvm_inc_block);
// } }
// ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); else {
// } offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
// } }
// else { ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
// for(unsigned n = 0; n < phi->get_num_incoming(); n++){ }
// ir::value *inc_val = phi->get_incoming_value(n); }
// ir::basic_block *inc_block = phi->get_incoming_block(n); else {
// BasicBlock *llvm_inc_block = last_block.at(inc_block); for(unsigned n = 0; n < phi->get_num_incoming(); n++){
// if(phi->get_type()->is_tile_ty()) { ir::value *inc_val = phi->get_incoming_value(n);
// distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi); ir::basic_block *inc_block = phi->get_incoming_block(n);
// distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val); BasicBlock *llvm_inc_block = last_block.at(inc_block);
// phi_tile->for_each([&](indices_t idx){ if(phi->get_type()->is_tile_ty()) {
// PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx); distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
// Value *llvm_inc_val = inc_tile->get_value(idx); distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
// llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); phi_tile->for_each([&](indices_t idx){
// }); PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
// } Value *llvm_inc_val = inc_tile->get_value(idx);
// else { llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
// PHINode *llvm_phi = (PHINode*)vmap_.at(phi); });
// Value *llvm_inc_val = vmap_.at(inc_val); }
// llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); else {
// } PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
// } Value *llvm_inc_val = vmap_.at(inc_val);
// } llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
// } }
}
}
}
} }
} }

View File

@@ -57,6 +57,11 @@ void backend::platforms::init() {
for(cl_platform_id id: ids) for(cl_platform_id id: ids)
cache_.push_back(new cl_platform(id)); cache_.push_back(new cl_platform(id));
} }
//if host is here
bool host_visible = true;
if(host_visible){
cache_.push_back(new host_platform());
}
if(cache_.empty()) if(cache_.empty())
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path"); throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
} }

View File

@@ -42,6 +42,10 @@ buffer::buffer(driver::context* ctx, CUdeviceptr cu, bool take_ownership)
buffer::buffer(driver::context* ctx, cl_mem cl, bool take_ownership) buffer::buffer(driver::context* ctx, cl_mem cl, bool take_ownership)
: polymorphic_resource(cl, take_ownership), context_(ctx) { } : polymorphic_resource(cl, take_ownership), context_(ctx) { }
buffer::buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), context_(ctx) { }
driver::context* buffer::context() { driver::context* buffer::context() {
return context_; return context_;
} }
@@ -50,12 +54,20 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
switch(ctx->backend()){ switch(ctx->backend()){
case CUDA: return new cu_buffer(ctx, size); case CUDA: return new cu_buffer(ctx, size);
case OpenCL: return new ocl_buffer(ctx, size); case OpenCL: return new ocl_buffer(ctx, size);
case Host: return new host_buffer(ctx, size);
default: throw std::runtime_error("unknown backend"); default: throw std::runtime_error("unknown backend");
} }
} }
// //
host_buffer::host_buffer(driver::context *context, size_t size)
: buffer(context, host_buffer_t(), true){
hst_->data = new char[size];
}
//
ocl_buffer::ocl_buffer(driver::context* context, size_t size) ocl_buffer::ocl_buffer(driver::context* context, size_t size)
: buffer(context, cl_mem(), true){ : buffer(context, cl_mem(), true){
cl_int err; cl_int err;

View File

@@ -47,13 +47,18 @@ context::context(driver::device *dev, CUcontext cu, bool take_ownership):
context::context(driver::device *dev, cl_context cl, bool take_ownership): context::context(driver::device *dev, cl_context cl, bool take_ownership):
polymorphic_resource(cl, take_ownership), polymorphic_resource(cl, take_ownership),
dev_(dev), cache_path_(get_cache_path()){ dev_(dev), cache_path_(get_cache_path()){
}
context::context(driver::device *dev, host_context_t hst, bool take_ownership):
polymorphic_resource(hst, take_ownership),
dev_(dev), cache_path_(get_cache_path()){
} }
context* context::create(driver::device *dev){ context* context::create(driver::device *dev){
switch(dev->backend()){ switch(dev->backend()){
case CUDA: return new cu_context(dev); case CUDA: return new cu_context(dev);
case OpenCL: return new ocl_context(dev); case OpenCL: return new ocl_context(dev);
case Host: return new host_context(dev);
default: throw std::runtime_error("unknown backend"); default: throw std::runtime_error("unknown backend");
} }
} }
@@ -86,6 +91,13 @@ std::string const & context::cache_path() const{
return cache_path_; return cache_path_;
} }
/* ------------------------ */
// Host //
/* ------------------------ */
host_context::host_context(driver::device* dev): context(dev, host_context_t(), true){
}
/* ------------------------ */ /* ------------------------ */
// CUDA // // CUDA //

View File

@@ -30,6 +30,15 @@ namespace triton
namespace driver namespace driver
{ {
//Host
inline void _delete(host_platform_t) { }
inline void _delete(host_device_t) { }
inline void _delete(host_context_t) { }
inline void _delete(host_module_t) { }
inline void _delete(host_stream_t) { }
inline void _delete(host_buffer_t x) { if(x.data) delete[] x.data; }
inline void _delete(host_function_t) { }
//OpenCL //OpenCL
inline void _delete(cl_platform_id) { } inline void _delete(cl_platform_id) { }
inline void _delete(cl_device_id x) { dispatch::clReleaseDevice(x); } inline void _delete(cl_device_id x) { dispatch::clReleaseDevice(x); }
@@ -58,7 +67,7 @@ handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_
template<class CUType> template<class CUType>
handle<CUType>::~handle(){ handle<CUType>::~handle(){
if(has_ownership_ && h_ && h_.unique() && *h_) if(has_ownership_ && h_ && h_.unique())
_delete(*h_); _delete(*h_);
} }
@@ -79,5 +88,14 @@ template class handle<cl_command_queue>;
template class handle<cl_mem>; template class handle<cl_mem>;
template class handle<cl_kernel>; template class handle<cl_kernel>;
template class handle<host_platform_t>;
template class handle<host_device_t>;
template class handle<host_context_t>;
template class handle<host_module_t>;
template class handle<host_stream_t>;
template class handle<host_buffer_t>;
template class handle<host_function_t>;
} }
} }

View File

@@ -22,7 +22,7 @@
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
#include "llvm/ExecutionEngine/GenericValue.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
#include "triton/driver/buffer.h" #include "triton/driver/buffer.h"
@@ -45,10 +45,15 @@ kernel::kernel(driver::module *program, cl_kernel fn, bool has_ownership):
polymorphic_resource(fn, has_ownership), program_(program){ polymorphic_resource(fn, has_ownership), program_(program){
} }
kernel::kernel(driver::module *program, host_function_t fn, bool has_ownership):
polymorphic_resource(fn, has_ownership), program_(program){
}
kernel* kernel::create(driver::module* program, const char* name) { kernel* kernel::create(driver::module* program, const char* name) {
switch(program->backend()){ switch(program->backend()){
case CUDA: return new cu_kernel(program, name); case CUDA: return new cu_kernel(program, name);
case OpenCL: return new ocl_kernel(program, name); case OpenCL: return new ocl_kernel(program, name);
case Host: return new host_kernel(program, name);
default: throw std::runtime_error("unknown backend"); default: throw std::runtime_error("unknown backend");
} }
} }
@@ -57,6 +62,32 @@ driver::module* kernel::module() {
return program_; return program_;
} }
/* ------------------------ */
// Host //
/* ------------------------ */
host_kernel::host_kernel(driver::module* program, const char *name): kernel(program, host_function_t(), true) {
hst_->fn = program->hst()->functions.at(name);
}
void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
if(index + 1> params_store_.size()){
params_store_.resize(index+1);
params_.resize(index+1);
}
params_store_[index].reset(malloc(size), free);
memcpy(params_store_[index].get(), ptr, size);
params_[index] = llvm::GenericValue(params_store_[index].get());
}
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
kernel::setArg(index, (void*)buffer->hst()->data);
}
const std::vector<llvm::GenericValue>& host_kernel::params(){
return params_;
}
/* ------------------------ */ /* ------------------------ */
// OpenCL // // OpenCL //
/* ------------------------ */ /* ------------------------ */
@@ -66,7 +97,6 @@ ocl_kernel::ocl_kernel(driver::module* program, const char* name): kernel(progra
// check(dispatch::clCreateKernelsInProgram(*program->cl(), 0, NULL, &res)); // check(dispatch::clCreateKernelsInProgram(*program->cl(), 0, NULL, &res));
// std::cout << res << std::endl; // std::cout << res << std::endl;
cl_int err; cl_int err;
std::cout << *program->cl() << std::endl;
*cl_ = dispatch::clCreateKernel(*program->cl(), "matmul", &err); *cl_ = dispatch::clCreateKernel(*program->cl(), "matmul", &err);
check(err); check(err);
} }

View File

@@ -22,7 +22,7 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <memory>
#include "triton/driver/module.h" #include "triton/driver/module.h"
#include "triton/driver/context.h" #include "triton/driver/context.h"
#include "triton/driver/error.h" #include "triton/driver/error.h"
@@ -40,12 +40,17 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h" #include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h" #include "llvm/Support/TargetSelect.h"
#include "llvm/Support/Host.h"
#include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h" #include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h" #include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopPass.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include "llvm/Transforms/Utils/Cloning.h"
namespace triton namespace triton
{ {
@@ -76,6 +81,10 @@ module::module(driver::context* ctx, cl_program mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) { : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
} }
module::module(driver::context* ctx, host_module_t mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
}
driver::context* module::context() const { driver::context* module::context() const {
return ctx_; return ctx_;
} }
@@ -84,6 +93,7 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
switch(ctx->backend()){ switch(ctx->backend()){
case CUDA: return new cu_module(ctx, src); case CUDA: return new cu_module(ctx, src);
case OpenCL: return new ocl_module(ctx, src); case OpenCL: return new ocl_module(ctx, src);
case Host: return new host_module(ctx, src);
default: throw std::runtime_error("unknown backend"); default: throw std::runtime_error("unknown backend");
} }
} }
@@ -91,7 +101,7 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
void module::compile_llvm_module(llvm::Module* module, const std::string& triple, void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
const std::string &proc, std::string layout, const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer, llvm::SmallVectorImpl<char> &buffer,
std::vector<std::string> files) { std::vector<std::string> paths) {
init_llvm(); init_llvm();
// create machine // create machine
module->setTargetTriple(triple); module->setTargetTriple(triple);
@@ -112,8 +122,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
module->setDataLayout(layout); module->setDataLayout(layout);
// link // link
for (std::string& file: files) { for (std::string& path: paths) {
std::string path = "/opt/rocm/lib/" + file;
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext()); std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext());
if (mlib.get() == nullptr) { if (mlib.get() == nullptr) {
@@ -137,46 +146,44 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
// std::cout << std::string(buffer.begin(), buffer.end()) << std::endl; // std::cout << std::string(buffer.begin(), buffer.end()) << std::endl;
} }
/* ------------------------ */
// Host //
/* ------------------------ */
host_module::host_module(driver::context * context, llvm::Module* src): module(context, host_module_t(), true) {
init_llvm();
// host info
// std::string triple = llvm::sys::getDefaultTargetTriple();
// std::string cpu = llvm::sys::getHostCPUName();
// llvm::SmallVector<char, 0> buffer;
// module::compile_llvm_module(src, triple, cpu, "", buffer);
// create execution engine
// llvm::legacy::PassManager pass;
// pass.add(llvm::createPrintModulePass(llvm::outs()));
// pass.add(llvm::createVerifierPass());
// pass.run(*src);
auto cloned = llvm::CloneModule(*src);
for(llvm::Function& fn: cloned->functions())
hst_->functions[fn.getName()] = &fn;
llvm::EngineBuilder builder(std::move(cloned));
builder.setErrorStr(&hst_->error);
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setUseOrcMCJITReplacement(true);
hst_->engine = builder.create();
}
/* ------------------------ */ /* ------------------------ */
// OpenCL // // OpenCL //
/* ------------------------ */ /* ------------------------ */
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) { ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
// const char* x = "__kernel void matmul(){ }";
// cl_int err;
// *cl_ = dispatch::clCreateProgramWithSource(*context->cl(), 1, &x, NULL, &err);
// check(err);
// return;
init_llvm(); init_llvm();
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
std::vector<std::string> files = { module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer);
"oclc_daz_opt_on.amdgcn.bc",
"ocml.amdgcn.bc",
"hc.amdgcn.bc",
"ockl.amdgcn.bc",
"oclc_correctly_rounded_sqrt_off.amdgcn.bc",
"oclc_correctly_rounded_sqrt_on.amdgcn.bc",
"oclc_daz_opt_off.amdgcn.bc",
"oclc_finite_only_off.amdgcn.bc",
"oclc_finite_only_on.amdgcn.bc",
"oclc_isa_version_803.amdgcn.bc",
"oclc_isa_version_900.amdgcn.bc",
"oclc_unsafe_math_off.amdgcn.bc",
"oclc_unsafe_math_on.amdgcn.bc",
"oclc_isa_version_700.amdgcn.bc",
"opencl.amdgcn.bc"
};
module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer, files);
// llvm::BitcodeWriter writer(buffer);
// writer.writeModule(*src);
// llvm::legacy::PassManager pass;
// llvm::raw_svector_ostream stream(buffer);
// pass.add(llvm::createPrintModulePass(stream));
// pass.run(*src);
size_t sizes[] = {buffer.size()}; size_t sizes[] = {buffer.size()};
const unsigned char* data[] = {(unsigned char*)buffer.data()}; const unsigned char* data[] = {(unsigned char*)buffer.data()};
cl_int status; cl_int status;

View File

@@ -74,9 +74,16 @@ void cl_platform::devices(std::vector<device*> &devices) const{
} }
/* ------------------------ */ /* ------------------------ */
// Vulkan // // Host //
/* ------------------------ */ /* ------------------------ */
std::string host_platform::version() const {
return "1.0";
}
void host_platform::devices(std::vector<driver::device*> &devices) const {
devices.push_back(new driver::host_device());
}
} }

View File

@@ -31,6 +31,8 @@
#include "triton/driver/event.h" #include "triton/driver/event.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
#include "triton/driver/buffer.h" #include "triton/driver/buffer.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/GenericValue.h"
namespace triton namespace triton
{ {
@@ -50,10 +52,15 @@ stream::stream(driver::context *ctx, cl_command_queue cl, bool has_ownership)
: polymorphic_resource(cl, has_ownership), ctx_(ctx) { : polymorphic_resource(cl, has_ownership), ctx_(ctx) {
} }
stream::stream(driver::context *ctx, host_stream_t cl, bool has_ownership)
: polymorphic_resource(cl, has_ownership), ctx_(ctx) {
}
driver::stream* stream::create(driver::context* ctx) { driver::stream* stream::create(driver::context* ctx) {
switch(ctx->backend()){ switch(ctx->backend()){
case CUDA: return new cu_stream(ctx); case CUDA: return new cu_stream(ctx);
case OpenCL: return new cl_stream(ctx); case OpenCL: return new cl_stream(ctx);
case Host: return new host_stream(ctx);
default: throw std::runtime_error("unknown backend"); default: throw std::runtime_error("unknown backend");
} }
} }
@@ -62,6 +69,32 @@ driver::context* stream::context() const {
return ctx_; return ctx_;
} }
/* ------------------------ */
// Host //
/* ------------------------ */
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
}
void host_stream::synchronize() {
}
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
engine->runFunction(kernel->hst()->fn, llvm::ArrayRef<llvm::GenericValue>(hst_kernel->params()));
}
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
}
void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
}
/* ------------------------ */ /* ------------------------ */
// OpenCL // // OpenCL //

View File

@@ -1,6 +1,7 @@
#include "triton/jit.h" #include "triton/jit.h"
#include <string> #include <string>
#include "triton/ast/ast.h" #include "triton/ast/ast.h"
#include "triton/codegen/target.h"
#include "triton/ir/context.h" #include "triton/ir/context.h"
#include "triton/ir/context_impl.h" #include "triton/ir/context_impl.h"
#include "triton/driver/device.h" #include "triton/driver/device.h"
@@ -89,7 +90,7 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
} }
jit::jit(driver::context *context): driver_context_(context) { jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::cpu_target()) {
} }
@@ -98,7 +99,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
auto ptt_module = make_triton_module(src); auto ptt_module = make_triton_module(src);
ir::module &tt_module = *ptt_module; ir::module &tt_module = *ptt_module;
// set parameters // set parameters
passes_wrapper passes; passes_wrapper passes(target_.get());
passes.tune.run(tt_module); passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module); auto mps = passes.tune.get_params(tt_module);
// create parameter ranges // create parameter ranges
@@ -123,7 +124,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
// Deep copy of the module and tuner // Deep copy of the module and tuner
auto ptt_module = make_triton_module(src); auto ptt_module = make_triton_module(src);
ir::module &tt_module = *ptt_module; ir::module &tt_module = *ptt_module;
passes_wrapper passes; passes_wrapper passes(target_.get());
passes.tune.run(tt_module); passes.tune.run(tt_module);
i = 0; i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){ for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
@@ -154,7 +155,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) { void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
// set parameters // set parameters
passes_wrapper passes; passes_wrapper passes(target_.get());
passes.tune.run(tt_module); passes.tune.run(tt_module);
unsigned i = 0; unsigned i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)) for(ir::metaparameter* mp: passes.tune.get_params(tt_module))