[code generation] basic CPU backend
This commit is contained in:
@@ -19,22 +19,18 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
fp32 C[TM, TN] = 0;
|
||||
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
for(int32 k = K; k > 0;){
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
C = dot(a, b, C);
|
||||
pa = pa + TK*M;
|
||||
pb = pb + TK*K;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
k = k - TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
*pc = C;
|
||||
}
|
||||
)";
|
||||
|
||||
@@ -93,7 +89,7 @@ int main() {
|
||||
triton::jit jit(context);
|
||||
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 512, N = 512, K = 512;
|
||||
int32_t M = 256, N = 256, K = 256;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
@@ -155,11 +151,11 @@ int main() {
|
||||
|
||||
// just-in-time compile source-code
|
||||
std::vector<unsigned> params = {
|
||||
16, 2, 64,
|
||||
32, 2, 64,
|
||||
16, 8, 2, 2,
|
||||
8, 8,
|
||||
4,
|
||||
1, 4, 8,
|
||||
1, 4, 8,
|
||||
1, 1, 4, 4,
|
||||
1, 8,
|
||||
1,
|
||||
};
|
||||
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
||||
|
||||
|
@@ -68,10 +68,10 @@ public:
|
||||
void setArg(unsigned int index, std::size_t size, void* ptr);
|
||||
void setArg(unsigned int index, driver::buffer* buffer);
|
||||
// Params
|
||||
const std::vector<llvm::GenericValue>& params();
|
||||
const std::vector<void*>& params();
|
||||
private:
|
||||
std::vector<std::shared_ptr<void> > params_store_;
|
||||
std::vector<llvm::GenericValue> params_;
|
||||
std::vector<void*> params_;
|
||||
};
|
||||
|
||||
// OpenCL
|
||||
|
@@ -52,15 +52,13 @@ public:
|
||||
selection(&allocation, &tune, &buffer_info, target) { }
|
||||
|
||||
void init(ir::module &module) {
|
||||
// generate ptx
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
// buffer_info.run(module);
|
||||
// shared.run(module);
|
||||
// liveness.run(module);
|
||||
// allocation.run();
|
||||
// barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
// triton::ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
|
@@ -234,7 +234,7 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
|
||||
// Pointer
|
||||
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
||||
bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end();
|
||||
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
|
||||
return ir::pointer_type::get(type, is_ptr_to_const?4:0);
|
||||
}
|
||||
|
||||
// Function
|
||||
|
@@ -116,9 +116,6 @@ void allocation::run(){
|
||||
for(auto &x: offsets_){
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, x.second + get_num_bytes(x.first));
|
||||
}
|
||||
std::cout << "Allocated: " << allocated_size_ << std::endl;
|
||||
for(auto &x: offsets_)
|
||||
std::cout << x.first->get_name() << " " << x.second << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -130,7 +130,6 @@ void barriers::run(ir::module &mod) {
|
||||
n_inserted_im1 = n_inserted_i;
|
||||
}while(!done);
|
||||
for(ir::instruction* i: insert_locs){
|
||||
std::cout << i->get_name() << std::endl;
|
||||
insert_barrier(i, builder);
|
||||
}
|
||||
}
|
||||
|
@@ -46,7 +46,6 @@ void liveness::run(ir::module &mod) {
|
||||
}
|
||||
intervals_[v] = segment{start, end};
|
||||
}
|
||||
std::cout << "Number of intervals: " << intervals_.size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -810,7 +810,21 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
// create LLVM function
|
||||
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
|
||||
Function *dst_fn = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
||||
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
||||
std::vector<Type*> dst_fn_args_ty;
|
||||
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
|
||||
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
FunctionType *dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
||||
// grid indices
|
||||
fn->get_fn_type()->get_return_ty();
|
||||
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
||||
|
||||
|
||||
|
||||
|
||||
// set attributes
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
@@ -831,15 +845,15 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
|
||||
// allocate shared memory
|
||||
Value *sh_mem_ptr = nullptr;
|
||||
if(unsigned alloc_size = alloc_->get_allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
|
||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
// if(unsigned alloc_size = alloc_->get_allocated_size()){
|
||||
// Type *int_8_ty = Type::getInt8Ty(dst_ctx);
|
||||
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
// Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
// GlobalVariable *sh_mem_array =
|
||||
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
// }
|
||||
|
||||
// create grids
|
||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
||||
|
@@ -246,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
int num_threads = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
||||
if(num_threads % 64 != 0)
|
||||
if(num_threads % 1 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
|
@@ -48,20 +48,20 @@ void backend::platforms::init() {
|
||||
if(dispatch::cuinit()){
|
||||
cache_.push_back(new cu_platform());
|
||||
}
|
||||
//if OpenCL is here
|
||||
if(dispatch::clinit()){
|
||||
cl_uint num_platforms;
|
||||
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
std::vector<cl_platform_id> ids(num_platforms);
|
||||
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
for(cl_platform_id id: ids)
|
||||
cache_.push_back(new cl_platform(id));
|
||||
}
|
||||
// //if host is here
|
||||
// bool host_visible = true;
|
||||
// if(host_visible){
|
||||
// cache_.push_back(new host_platform());
|
||||
// //if OpenCL is here
|
||||
// if(dispatch::clinit()){
|
||||
// cl_uint num_platforms;
|
||||
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
// std::vector<cl_platform_id> ids(num_platforms);
|
||||
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
// for(cl_platform_id id: ids)
|
||||
// cache_.push_back(new cl_platform(id));
|
||||
// }
|
||||
//if host is here
|
||||
bool host_visible = true;
|
||||
if(host_visible){
|
||||
cache_.push_back(new host_platform());
|
||||
}
|
||||
if(cache_.empty())
|
||||
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
|
||||
}
|
||||
|
@@ -64,6 +64,9 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
|
||||
host_buffer::host_buffer(driver::context *context, size_t size)
|
||||
: buffer(context, host_buffer_t(), true){
|
||||
hst_->data = new char[size];
|
||||
std::cout << size << std::endl;
|
||||
std::cout << "allocating " << (float*)hst_->data << std::endl;
|
||||
std::cout << *((float*)(hst_->data) + 512*500) << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
|
@@ -77,14 +77,14 @@ void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
||||
}
|
||||
params_store_[index].reset(malloc(size), free);
|
||||
memcpy(params_store_[index].get(), ptr, size);
|
||||
params_[index] = llvm::GenericValue(params_store_[index].get());
|
||||
params_[index] = params_store_[index].get();
|
||||
}
|
||||
|
||||
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
|
||||
kernel::setArg(index, (void*)buffer->hst()->data);
|
||||
}
|
||||
|
||||
const std::vector<llvm::GenericValue>& host_kernel::params(){
|
||||
const std::vector<void *> &host_kernel::params(){
|
||||
return params_;
|
||||
}
|
||||
|
||||
|
@@ -151,11 +151,38 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
||||
// llvm::SmallVector<char, 0> buffer;
|
||||
// module::compile_llvm_module(src, triple, cpu, "", buffer);
|
||||
|
||||
// create kernel wrapper
|
||||
llvm::LLVMContext &ctx = src->getContext();
|
||||
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
|
||||
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
|
||||
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
|
||||
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", src);
|
||||
llvm::Function* fn = src->getFunction("matmul");
|
||||
llvm::FunctionType *fn_ty = fn->getFunctionType();
|
||||
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
|
||||
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
|
||||
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
|
||||
llvm::IRBuilder<> ir_builder(ctx);
|
||||
ir_builder.SetInsertPoint(entry);
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
|
||||
for(unsigned i = 0; i < ptrs.size(); i++){
|
||||
llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
|
||||
fn_args[i] = ir_builder.CreateLoad(addr);
|
||||
}
|
||||
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
|
||||
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
|
||||
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
|
||||
ir_builder.CreateCall(fn, fn_args);
|
||||
ir_builder.CreateRetVoid();
|
||||
|
||||
|
||||
// create execution engine
|
||||
// llvm::legacy::PassManager pass;
|
||||
// pass.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pass.add(llvm::createVerifierPass());
|
||||
// pass.run(*src);
|
||||
llvm::legacy::PassManager pass;
|
||||
pass.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
pass.add(llvm::createVerifierPass());
|
||||
pass.run(*src);
|
||||
auto cloned = llvm::CloneModule(*src);
|
||||
for(llvm::Function& fn: cloned->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
|
@@ -84,15 +84,19 @@ void host_stream::synchronize() {
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
|
||||
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
|
||||
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
|
||||
engine->runFunction(kernel->hst()->fn, llvm::ArrayRef<llvm::GenericValue>(hst_kernel->params()));
|
||||
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
|
||||
for(size_t i = 0; i < grid[0]; i++)
|
||||
for(size_t j = 0; j < grid[1]; j++)
|
||||
for(size_t k = 0; k < grid[2]; k++)
|
||||
fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k));
|
||||
}
|
||||
|
||||
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||
|
||||
std::memcpy((void*)buffer->hst()->data, ptr, size);
|
||||
}
|
||||
|
||||
void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
|
||||
|
||||
std::memcpy(ptr, (const void*)buffer->hst()->data, size);
|
||||
}
|
||||
|
||||
|
||||
|
13
lib/jit.cpp
13
lib/jit.cpp
@@ -92,7 +92,7 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
|
||||
}
|
||||
|
||||
|
||||
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
|
||||
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::cpu_target()) {
|
||||
}
|
||||
|
||||
|
||||
@@ -164,11 +164,16 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
// check constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
passes.tune.check_constraints(errors);
|
||||
for(auto x: errors){
|
||||
std::cout << x.first << std::endl;
|
||||
for(auto str: x.second)
|
||||
std::cout << str << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// driver::device* device = driver_context_->device();
|
||||
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
// throw std::runtime_error("invalid parameters");
|
||||
// triton module -> llvm module
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
|
Reference in New Issue
Block a user