[code generation] basic CPU backend

Philippe Tillet
2019-03-27 11:13:36 -07:00
parent 9d6fc1c051
commit e04253c0dd
15 changed files with 110 additions and 68 deletions

View File

@@ -19,22 +19,18 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
fp32 C[TM, TN] = 0;
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){
for(int32 k = K; k > 0;){
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
C = dot(a, b, C);
pa = pa + TK*M;
pb = pb + TK*K;
a = *pa;
b = *pb;
k = k - TK;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
*pc = C;
}
)";
@@ -93,7 +89,7 @@ int main() {
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 512, N = 512, K = 512;
int32_t M = 256, N = 256, K = 256;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -155,11 +151,11 @@ int main() {
// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4,
1, 4, 8,
1, 4, 8,
1, 1, 4, 4,
1, 8,
1,
};
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
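For checking hc against rc, a naive reference that matches the pointer arithmetic of the Triton-C kernel above could look like the sketch below. The indexing (a at m + k*M, b at n + k*K, c at m + n*M) is inferred from the pa/pb/pc expressions, and hb is a hypothetical host-side buffer for b, declared like ha.

    // Naive reference for the 256x256x256 problem; indexing inferred from the kernel above.
    for(int32_t m = 0; m < M; m++)
    for(int32_t n = 0; n < N; n++){
      float acc = 0;
      for(int32_t k = 0; k < K; k++)
        acc += ha[m + k*M] * hb[n + k*K];   // hb: hypothetical host buffer for b
      rc[m + n*M] = acc;
    }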

View File

@@ -68,10 +68,10 @@ public:
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, driver::buffer* buffer);
// Params
const std::vector<llvm::GenericValue>& params();
const std::vector<void*>& params();
private:
std::vector<std::shared_ptr<void> > params_store_;
std::vector<llvm::GenericValue> params_;
std::vector<void*> params_;
};
// OpenCL

View File

@@ -52,15 +52,13 @@ public:
selection(&allocation, &tune, &buffer_info, target) { }
void init(ir::module &module) {
// generate ptx
buffer_info.run(module);
shared.run(module);
triton::ir::print(module, std::cout);
liveness.run(module);
allocation.run();
barriers.run(module);
// buffer_info.run(module);
// shared.run(module);
// liveness.run(module);
// allocation.run();
// barriers.run(module);
vectorize.run(module);
triton::ir::print(module, std::cout);
// triton::ir::print(module, std::cout);
}
codegen::tune tune;

View File

@@ -234,7 +234,7 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
// Pointer
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end();
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
return ir::pointer_type::get(type, is_ptr_to_const?4:0);
}
// Function

View File

@@ -116,9 +116,6 @@ void allocation::run(){
for(auto &x: offsets_){
allocated_size_ = std::max<size_t>(allocated_size_, x.second + get_num_bytes(x.first));
}
std::cout << "Allocated: " << allocated_size_ << std::endl;
for(auto &x: offsets_)
std::cout << x.first->get_name() << " " << x.second << std::endl;
}
}

View File

@@ -130,7 +130,6 @@ void barriers::run(ir::module &mod) {
n_inserted_im1 = n_inserted_i;
}while(!done);
for(ir::instruction* i: insert_locs){
std::cout << i->get_name() << std::endl;
insert_barrier(i, builder);
}
}

View File

@@ -46,7 +46,6 @@ void liveness::run(ir::module &mod) {
}
intervals_[v] = segment{start, end};
}
std::cout << "Number of intervals: " << intervals_.size() << std::endl;
}
}

View File

@@ -810,7 +810,21 @@ void selection::run(ir::module &src, Module &dst) {
for(ir::function *fn: src.get_function_list()) {
// create LLVM function
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
Function *dst_fn = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
Type *dst_fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> dst_fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
FunctionType *dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
// grid indices
fn->get_fn_type()->get_return_ty();
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
// set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
@@ -831,15 +845,15 @@ void selection::run(ir::module &src, Module &dst) {
// allocate shared memory
Value *sh_mem_ptr = nullptr;
if(unsigned alloc_size = alloc_->get_allocated_size()){
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
// if(unsigned alloc_size = alloc_->get_allocated_size()){
// Type *int_8_ty = Type::getInt8Ty(dst_ctx);
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
// Type *ptr_ty = PointerType::get(int_8_ty, 3);
// GlobalVariable *sh_mem_array =
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
// }
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
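Seen from the host, the three i32 parameters appended above give every generated function its grid position explicitly rather than reading it from GPU intrinsics. A rough C-level view of the emitted signature for the matmul example, assuming scalar parameters M, N, K (the Triton-C signature is truncated in the first hunk, so these are an assumption):

    // Illustrative prototype only; parameter names are not taken from the commit.
    extern "C" void matmul(float *a, float *b, float *c,
                           int32_t M, int32_t N, int32_t K,
                           int32_t grid_x, int32_t grid_y, int32_t grid_z);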

View File

@@ -246,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
if(num_threads % 64 != 0)
if(num_threads % 1 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");

View File

@@ -48,20 +48,20 @@ void backend::platforms::init() {
if(dispatch::cuinit()){
cache_.push_back(new cu_platform());
}
//if OpenCL is here
if(dispatch::clinit()){
cl_uint num_platforms;
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
std::vector<cl_platform_id> ids(num_platforms);
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
for(cl_platform_id id: ids)
cache_.push_back(new cl_platform(id));
}
// //if host is here
// bool host_visible = true;
// if(host_visible){
// cache_.push_back(new host_platform());
// //if OpenCL is here
// if(dispatch::clinit()){
// cl_uint num_platforms;
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
// std::vector<cl_platform_id> ids(num_platforms);
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
// for(cl_platform_id id: ids)
// cache_.push_back(new cl_platform(id));
// }
//if host is here
bool host_visible = true;
if(host_visible){
cache_.push_back(new host_platform());
}
if(cache_.empty())
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
}

View File

@@ -64,6 +64,9 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
host_buffer::host_buffer(driver::context *context, size_t size)
: buffer(context, host_buffer_t(), true){
hst_->data = new char[size];
std::cout << size << std::endl;
std::cout << "allocating " << (float*)hst_->data << std::endl;
std::cout << *((float*)(hst_->data) + 512*500) << std::endl;
}
//

View File

@@ -77,14 +77,14 @@ void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
}
params_store_[index].reset(malloc(size), free);
memcpy(params_store_[index].get(), ptr, size);
params_[index] = llvm::GenericValue(params_store_[index].get());
params_[index] = params_store_[index].get();
}
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
kernel::setArg(index, (void*)buffer->hst()->data);
}
const std::vector<llvm::GenericValue>& host_kernel::params(){
const std::vector<void *> &host_kernel::params(){
return params_;
}
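With params_ now holding plain void*, each slot points at a private copy of the corresponding argument made by setArg, and the generated wrapper (see the host_module change below) performs the final load. A minimal sketch of that double indirection, assuming the first kernel argument is an fp32 pointer and kernel is a host_kernel:

    float* a_host = /* host pointer backing the first kernel argument */ nullptr;
    kernel.setArg(0, sizeof(float*), &a_host);    // malloc + memcpy of the 8-byte pointer value
    void* const* args = kernel.params().data();   // one void* per argument slot
    float* recovered = *(float**)args[0];         // == a_host; the generated wrapper does this load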

View File

@@ -151,11 +151,38 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
// llvm::SmallVector<char, 0> buffer;
// module::compile_llvm_module(src, triple, cpu, "", buffer);
// create kernel wrapper
llvm::LLVMContext &ctx = src->getContext();
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", src);
llvm::Function* fn = src->getFunction("matmul");
llvm::FunctionType *fn_ty = fn->getFunctionType();
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
llvm::IRBuilder<> ir_builder(ctx);
ir_builder.SetInsertPoint(entry);
for(unsigned i = 0; i < ptrs.size(); i++)
ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
for(unsigned i = 0; i < ptrs.size(); i++){
llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
fn_args[i] = ir_builder.CreateLoad(addr);
}
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
ir_builder.CreateCall(fn, fn_args);
ir_builder.CreateRetVoid();
// create execution engine
// llvm::legacy::PassManager pass;
// pass.add(llvm::createPrintModulePass(llvm::outs()));
// pass.add(llvm::createVerifierPass());
// pass.run(*src);
llvm::legacy::PassManager pass;
pass.add(llvm::createPrintModulePass(llvm::outs()));
pass.add(llvm::createVerifierPass());
pass.run(*src);
auto cloned = llvm::CloneModule(*src);
for(llvm::Function& fn: cloned->functions())
hst_->functions[fn.getName()] = &fn;
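The generated "main" is an argument-unpacking trampoline around the kernel. Written by hand it would be roughly the sketch below; the scalar parameters follow the matmul example and the three grid indices appended in selection::run, so their number and types are assumptions.

    // Hand-written equivalent of the IR emitted above: each args[i] points at a
    // copy of the i-th kernel argument (set up by host_kernel::setArg), and the
    // last three parameters are forwarded as grid indices.
    void main_wrapper(char** args, int32_t x, int32_t y, int32_t z) {
      float*  a = *(float**) args[0];
      float*  b = *(float**) args[1];
      float*  c = *(float**) args[2];
      int32_t M = *(int32_t*)args[3];   // scalar arguments are assumptions
      int32_t N = *(int32_t*)args[4];
      int32_t K = *(int32_t*)args[5];
      matmul(a, b, c, M, N, K, x, y, z);
    }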

View File

@@ -84,15 +84,19 @@ void host_stream::synchronize() {
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
engine->runFunction(kernel->hst()->fn, llvm::ArrayRef<llvm::GenericValue>(hst_kernel->params()));
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
for(size_t i = 0; i < grid[0]; i++)
for(size_t j = 0; j < grid[1]; j++)
for(size_t k = 0; k < grid[2]; k++)
fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k));
}
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
std::memcpy((void*)buffer->hst()->data, ptr, size);
}
void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
std::memcpy(ptr, (const void*)buffer->hst()->data, size);
}
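With the wrapper in place, a launch on the CPU path reduces to setting arguments and letting enqueue walk the grid sequentially. A hypothetical launch for the 256x256x256 example above (the buffer names, the 32x32 tiling behind the 8x8 grid, and the null event arguments are all assumptions):

    kernel->setArg(0, da);                    // driver::buffer -> raw host pointer
    kernel->setArg(1, db);
    kernel->setArg(2, dc);
    kernel->setArg(3, sizeof(int32_t), &M);   // scalars are copied by value
    kernel->setArg(4, sizeof(int32_t), &N);
    kernel->setArg(5, sizeof(int32_t), &K);
    stream->enqueue(kernel, {8, 8, 1}, {1, 1, 1}, nullptr, nullptr);  // one wrapper call per grid point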

View File

@@ -92,7 +92,7 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
}
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::cpu_target()) {
}
@@ -164,11 +164,16 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
// check constraints
std::map<ir::value*, std::vector<std::string>> errors;
passes.tune.check_constraints(errors);
for(auto x: errors){
std::cout << x.first << std::endl;
for(auto str: x.second)
std::cout << str << std::endl;
}
if(errors.size())
throw std::runtime_error("invalid parameters");
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
throw std::runtime_error("invalid parameters");
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code