diff --git a/CMakeLists.txt b/CMakeLists.txt index 2531e84ca..814206cfc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ find_package(LLVM REQUIRED CONFIG) message(STATUS ${LLVM_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) -llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all) +#llvm_map_components_to_libnames(llvm_libs all) #Default build type if(NOT CMAKE_BUILD_TYPE) @@ -34,7 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11") # TDL file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp) add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS}) -target_link_libraries(tdl ${llvm_libs}) +target_link_libraries(tdl LLVM) # Examples add_subdirectory(examples) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index f5334769e..8145ddb90 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -36,7 +36,7 @@ extern translation_unit *ast_root; const char src[] = "\ -void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ +void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\ int32 rxa[16] = get_global_range[16](0);\ int32 ryb[16] = get_global_range[16](1);\ int32 rka[8] = 0 ... 8;\ @@ -50,15 +50,17 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\ fp32 a[16, 8] = *pa;\ fp32 b[16, 8] = *pb;\ - int1 checkc0[16] = (rxc < M);\ - int1 checkc1[16] = (ryc < N);\ + int1 checkc0[16] = rxc < M;\ + int1 checkc1[16] = ryc < N;\ int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\ for(k = K; k > 0; k = k - 8){\ + int1 sanitya[16, 8] = (k >= bound);\ + int1 sanityb[16, 8] = (k >= bound);\ C = dot(a, b, C);\ pa = pa + 8*M;\ pb = pb + 8*K;\ - a = *pa;\ - b = *pb;\ + @sanitya a = *pa;\ + @sanityb b = *pb;\ }\ @checkc *pc = C;\ }\ @@ -201,6 +203,8 @@ int main() { for(auto &e: x.second) std::cout << e << std::endl; } + if(errors.size()) + exit(EXIT_FAILURE); // run passes shared.run(module); @@ -213,7 +217,7 @@ int main() { // llvm source llvm::legacy::PassManager manager; -// manager.add(llvm::createPrintModulePass(llvm::outs())); + manager.add(llvm::createPrintModulePass(llvm::outs())); manager.add(llvm::createVerifierPass(true)); manager.run(llvm_module); @@ -233,6 +237,7 @@ int main() { // Allocate buffers typedef float numeric_t; size_t M = 128, N = 128, K = 128; + size_t bound = 8; std::vector c(M*N); std::vector rc(M*N); std::vector a(M*K); @@ -252,13 +257,13 @@ int main() { checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size())); checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size())); // Launch kernel - void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K}; + void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound}; int num_regs; cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel); unsigned TM = 16; unsigned TN = 16; unsigned nthreads = 32; - checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL)); + checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL)); checkCudaErrors(cuStreamSynchronize(cu_stream)); // Write back checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size())); diff --git a/include/ast/ast.h b/include/ast/ast.h index 6471b2296..4a9889093 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -185,7 +185,8 @@ private: public: binary_operator(BIN_OP_T op, node *lhs, node *rhs) - : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { } + : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { + } ir::value* codegen(ir::module *) const; private: @@ -320,14 +321,14 @@ private: class expression_statement: public statement{ public: - expression_statement(node *expr, node *pred = nullptr) - : expr_((expression*)expr), pred_((expression*)pred){ } + expression_statement(node *expr, node *mask = nullptr) + : expr_((expression*)expr), mask_((expression*)mask){ } ir::value* codegen(ir::module * mod) const; private: expression *expr_; - expression *pred_; + expression *mask_; }; class compound_statement: public statement{ diff --git a/include/ast/parser.y b/include/ast/parser.y index 442bee12e..905541d70 100644 --- a/include/ast/parser.y +++ b/include/ast/parser.y @@ -121,7 +121,7 @@ primary_expression | constant ELLIPSIS constant { $$ = new constant_range($1, $3); } | builtin { $$ = $1; } | STRING_LITERAL { $$ = new string_literal(yytext); } - | '(' expression ')' { $$ = $1; } + | '(' expression ')' { $$ = $2; } ; slice @@ -155,7 +155,7 @@ unary_operator cast_expression : unary_expression { $$ = $1; } - | '(' type_name ')' cast_expression { $$ = new cast_operator($1, $2); } + | '(' type_name ')' cast_expression { $$ = new cast_operator($2, $4); } ; multiplicative_expression diff --git a/include/codegen/selection.h b/include/codegen/selection.h index 6580ade98..2531dc74c 100644 --- a/include/codegen/selection.h +++ b/include/codegen/selection.h @@ -54,13 +54,15 @@ private: llvm::Value* shared_offset(indices_t idx); public: - shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder); + shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr); void set_value(indices_t, llvm::Value *); llvm::Value* get_value(indices_t idx); llvm::Value* get_pointer() { return ptr_; } + llvm::Value* get_offset() { return offset_; } private: llvm::Value *ptr_; + llvm::Value *offset_; llvm::IRBuilder<> &builder_; std::map ptr_cache_; }; diff --git a/include/ir/instructions.h b/include/ir/instructions.h index 28feeb442..6c835ec2e 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -26,12 +26,16 @@ public: const basic_block *get_parent() const { return parent_; } basic_block *get_parent() { return parent_; } void erase_from_parent(); + // mask + value* set_mask(value *mask) { mask_ = mask; } + value* get_mask() { return mask_; } // helpers bool has_tile_result_or_op(); private: basic_block *parent_; value *pred_; + value *mask_; }; //===----------------------------------------------------------------------===// diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 36bd50adb..4a9d7ff3e 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -16,33 +16,34 @@ namespace ast{ /* node */ ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ - ir::type *src_ty = src->get_type()->get_scalar_ty(); + ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); + ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); bool src_signed = false; bool dst_signed = false; - if(src_ty == dst_ty) + if(src_scalar_ty == dst_scalar_ty) return src; - else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty()) + else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) return builder.create_si_to_fp(src, dst_ty); - else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty()) + else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) return builder.create_ui_to_fp(src, dst_ty); - else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed) + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed) return builder.create_fp_to_si(src, dst_ty); - else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed) + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed) return builder.create_fp_to_ui(src, dst_ty); - else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && - src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width()) + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && + src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width()) return builder.create_fp_ext(src, dst_ty); - else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && - src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width()) + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && + src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width()) return builder.create_fp_trunc(src, dst_ty); - else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() && - src_ty->get_integer_bitwidth()) + else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && + src_scalar_ty->get_integer_bitwidth()) return builder.create_int_cast(src, dst_ty, dst_signed); else @@ -247,7 +248,14 @@ ir::value* compound_statement::codegen(ir::module* mod) const{ /* expression statement */ ir::value* expression_statement::codegen(ir::module *mod) const{ - return expr_->codegen(mod); + ir::value *expr = expr_->codegen(mod); + if(mask_) { + ir::instruction *itn = dynamic_cast(expr); + assert(itn); + ir::value *mask = mask_->codegen(mod); + itn->set_mask(mask); + } + return expr; } /* Iteration statement */ @@ -325,7 +333,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ ir::value *value = ir::undef_value::get(ty); if(expr_){ value = expr_->codegen(mod); - value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); + value = explicit_cast(mod->get_builder(), value, ty); implicit_broadcast(mod, value, ty); } value->set_name(name); @@ -526,7 +534,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{ assert(x->get_op()==DEREF); assert(x->lvalue()); ir::value *ptr = x->lvalue()->codegen(mod); - mod->get_builder().create_store(ptr, rvalue); + rvalue = mod->get_builder().create_store(ptr, rvalue); } return rvalue; } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 9ef405e06..85fdb2189 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1,6 +1,7 @@ #include "codegen/selection.h" #include "codegen/tune.h" #include "codegen/allocation.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" #include "ir/context.h" @@ -9,6 +10,8 @@ #include "ir/type.h" #include "llvm/Transforms/Scalar/EarlyCSE.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/IR/BasicBlock.h" namespace tdl{ namespace codegen{ @@ -121,8 +124,8 @@ Value* shared_tile::shared_offset(indices_t idx) { return result; } -shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder): - tile(ty, shapes), ptr_(ptr), builder_(builder) { +shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): + tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset) { } void shared_tile::set_value(indices_t idx, Value *value) { @@ -404,25 +407,17 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, std::swap(id_pre, id_loop); ir::value *pre_value = phi->get_incoming_value(id_pre); ir::value *loop_value = phi->get_incoming_value(id_loop); - BasicBlock *pre_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_pre)]; - BasicBlock *loop_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_loop)]; if(parent->empty()) builder.SetInsertPoint(parent); else builder.SetInsertPoint(&*parent->getFirstInsertionPt()); PHINode *ptr = builder.CreatePHI(ptr_ty, 2); - // offset PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); - Value *next_offset = builder.CreateNeg(offset); - offset->addIncoming(builder.getInt32(alloc_->get_num_bytes(phi) / 2 / 4), pre_block); - offset->addIncoming(next_offset, loop_block); // next pointer Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi))); pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); Value *next_ptr = builder.CreateGEP(ptr, offset); - ptr->addIncoming(pre_ptr, pre_block); - ptr->addIncoming(next_ptr, loop_block); - tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); + tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)}); tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)}); } @@ -483,14 +478,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) { - Module *module = builder.GetInsertBlock()->getModule(); + BasicBlock *block = builder.GetInsertBlock(); + Module *module = block->getModule(); + Function *function = block->getParent(); + ir::value *mask = ins->get_mask(); LLVMContext &ctx = builder.getContext(); + // helper to handle masks + auto insert_masked = [&](indices_t idx, std::function insert_value) { + BasicBlock *block = builder.GetInsertBlock(); + Value *result; + if(mask){ + Value *llvm_mask = tmap_.at(mask)->get_value(idx); + BasicBlock *then_bb = BasicBlock::Create(ctx, "", function); + BasicBlock *done_bb = BasicBlock::Create(ctx, "", function); + builder.CreateCondBr(llvm_mask, then_bb, done_bb); + builder.SetInsertPoint(then_bb); + result = insert_value(); + builder.CreateBr(done_bb); + builder.SetInsertPoint(done_bb); + if(!ins->get_type()->is_void_ty()){ + Type *ty = result->getType(); + PHINode *phi = builder.CreatePHI(ty, 2); + phi->addIncoming(llvm::UndefValue::get(ty), block); + phi->addIncoming(result, then_bb); + return (Value*)phi; + } + } + else + result = insert_value(); + return result; + }; + // store if(auto *x = dynamic_cast(ins)) { distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand()); tile *value = tmap_.at(x->get_value_operand()); ptr->for_each([&](indices_t idx){ - builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); + insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); }); }); } else { @@ -511,7 +535,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id); result->for_each([&](indices_t idx){ BinaryOperator *bin = static_cast(idx[0]); - result->set_value(idx, builder.CreateAdd(bin, offset)); + result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); })); }); } // reshape @@ -530,7 +554,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & // splat else if(dynamic_cast(ins)) { result->for_each([&](indices_t idx) { - result->set_value(idx, llvm_value(ins->get_operand(0), builder)); + result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); })); }); } // broadcast @@ -603,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & else return llvm_value(x, builder); }; - result->set_value(idx, llvm_inst(ins, value, builder)); + result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); })); }); } } @@ -625,6 +649,7 @@ void selection::run(ir::module &src, Module &dst){ vmap_.clear(); LLVMContext &dst_ctx = dst.getContext(); IRBuilder<> dst_builder(dst_ctx); + std::map block_of; // iterate over functions for(ir::function *fn: src.get_function_list()) { @@ -661,6 +686,7 @@ void selection::run(ir::module &src, Module &dst){ } // create grids init_grids(fn, dst_builder, sh_mem_ptr); + std::map last_block; // iterate through block for(ir::basic_block *block: fn->blocks()) { BasicBlock *parent = (BasicBlock*)vmap_[block]; @@ -671,6 +697,7 @@ void selection::run(ir::module &src, Module &dst){ lower_instruction(i, dst_builder); if(dynamic_cast(i)) dst_builder.SetInsertPoint(parent); + last_block[block] = dst_builder.GetInsertBlock(); } } // add phi operands @@ -678,12 +705,31 @@ void selection::run(ir::module &src, Module &dst){ for(ir::instruction *inst: block->get_inst_list()) if(auto *phi = dynamic_cast(inst)){ if(buffer_info_->is_shared(phi)) { + PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); + PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::value *inc_val = phi->get_incoming_value(n); + ir::basic_block *inc_block = phi->get_incoming_block(n); + BasicBlock *llvm_inc_block = last_block.at(inc_block); + shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); + GetElementPtrInst *inc_ptr = dyn_cast(inc_shared->get_pointer()); + if(inc_ptr && ptr == inc_ptr->getPointerOperand()){ + dst_builder.SetInsertPoint(llvm_inc_block->getTerminator()); + Value *next_offset = dst_builder.CreateNeg(offset); + offset->addIncoming(next_offset, llvm_inc_block); + } + else { + offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block); + } + ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); + } continue; } for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::value *inc_val = phi->get_incoming_value(n); ir::basic_block *inc_block = phi->get_incoming_block(n); - BasicBlock *llvm_inc_block = (BasicBlock*)vmap_[inc_block]; + std::cout << typeid(*inc_val).name() << " " << inc_val << " " << inc_block << std::endl; + BasicBlock *llvm_inc_block = last_block.at(inc_block); if(phi->get_type()->is_tile_ty()) { distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi); distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val); diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index c98b2ae66..924392cab 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -67,10 +67,17 @@ void tune::init_c_graph(ir::instruction *v) { } // Element-wise else if(dynamic_cast(v)){ + std::cout << typeid(*v).name() << std::endl; for(unsigned i = 0; i < shapes.size(); i ++) - for(ir::value* op: v->ops()){ + for(ir::value* op: v->ops()) add_constraint({v, i}, {op, i}); - } + } + + /* Add mask constraints */ + if(ir::value *mask = v->get_mask()){ + std::cout << typeid(*mask).name() << " " << typeid(*v->ops()[0]).name() << std::endl; + for(unsigned i = 0; i < shapes.size(); i++) + add_constraint({v->ops()[0], i}, {mask, i}); } } @@ -99,6 +106,7 @@ std::vector tune::get_params(ir::module &mod) { for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) if(seen.insert(x.second).second && *x.second == 0){ + std::cout << typeid(*i).name() << std::endl; result.push_back(x.second); } return result; diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index f335bbeea..56f583141 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -186,8 +186,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){ type *arg_ty = arg->get_type(); assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!"); - unsigned arg_bits = arg_ty->get_integer_bitwidth(); - unsigned dst_bits = ty->get_integer_bitwidth(); + unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); + unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); op_t op = (arg_bits == dst_bits ? ic::BitCast : (arg_bits > dst_bits ? ic::Trunc : (is_signed ? ic::SExt : ic::ZExt))); diff --git a/lib/ir/type.cpp b/lib/ir/type.cpp index 075bcd88b..c790120fb 100644 --- a/lib/ir/type.cpp +++ b/lib/ir/type.cpp @@ -33,7 +33,7 @@ unsigned type::get_primitive_size_in_bits() const { } unsigned type::get_integer_bitwidth() const -{ return ((integer_type*)(this))->get_bitwidth(); } +{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } unsigned type::get_tile_bitwidth() const { return ((tile_type*)(this))->get_bitwidth(); }