From 73db84c8bacf7bf5b8ed3d164bbbd292992ed9bd Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 8 Jan 2019 17:44:31 -0500 Subject: [PATCH] [syntax tree] fixed broadcast semantics lowering --- examples/matrix.cpp | 15 ++- include/ast/ast.h | 5 + include/codegen/allocation.h | 2 +- include/codegen/layout.h | 4 +- include/codegen/liveness.h | 3 +- lib/ast/lowering.cpp | 232 ++++++++++++++++++----------------- lib/codegen/allocation.cpp | 2 +- lib/codegen/layout.cpp | 11 +- lib/codegen/liveness.cpp | 5 +- lib/codegen/tune.cpp | 2 + 10 files changed, 153 insertions(+), 128 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 78c8d7194..3a4877d36 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -4,6 +4,7 @@ #include "ir/context.h" #include "ir/module.h" #include "codegen/selection.h" +#include "codegen/tune.h" #include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Module.h" #include "llvm/IR/LLVMContext.h" @@ -20,6 +21,8 @@ extern translation_unit *ast_root; const char src[] = "\ void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\ + int32 tile[16, 16] = 0;\ + int32 test[16, 16] = tile + i;\ i = 1;\ A = A + i;\ }\ @@ -35,10 +38,14 @@ int main() { program->codegen(&module); llvm::LLVMContext llvm_context; llvm::Module llvm_module("test", llvm_context); + // lowering passes tdl::codegen::selection selection; - selection.run(module, llvm_module); - llvm::PrintModulePass print(llvm::outs()); - llvm::AnalysisManager analysis; - print.run(llvm_module, analysis); + tdl::codegen::tune tune; + tune.run(module); +// selection.run(module, llvm_module); +// // print LLVM program +// llvm::PrintModulePass print(llvm::outs()); +// llvm::AnalysisManager analysis; +// print.run(llvm_module, analysis); return 0; } diff --git a/include/ast/ast.h b/include/ast/ast.h index d9b24f8e4..2ab5c02cc 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -61,6 +61,11 @@ class identifier; // AST class node { +protected: + static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); + static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); + static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, + bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); public: virtual ir::value* codegen(ir::module *) const { return nullptr; } }; diff --git a/include/codegen/allocation.h b/include/codegen/allocation.h index c6783cc5e..b93322539 100644 --- a/include/codegen/allocation.h +++ b/include/codegen/allocation.h @@ -25,7 +25,7 @@ public: unsigned get_allocated_size() const { return allocated_size_; } // run - void run(ir::function &fn); + void run(); private: std::map offsets_; diff --git a/include/codegen/layout.h b/include/codegen/layout.h index 1fd6deeda..d63a5dbe2 100644 --- a/include/codegen/layout.h +++ b/include/codegen/layout.h @@ -7,7 +7,7 @@ namespace tdl { namespace ir { - class function; + class module; class instruction; class value; } @@ -32,7 +32,7 @@ public: shared_view_info get_shared_view(ir::value *v, unsigned idx); // run - bool run(ir::function &fn); + void run(ir::module &mod); private: std::map shared_views_; diff --git a/include/codegen/liveness.h b/include/codegen/liveness.h index 8a6806c2e..c7cfb5959 100644 --- a/include/codegen/liveness.h +++ b/include/codegen/liveness.h @@ -8,6 +8,7 @@ namespace tdl{ namespace ir{ class value; class function; + class module; } namespace codegen{ @@ -47,7 +48,7 @@ public: segment get_interval(ir::value* v) const { return intervals_.at(v); } // run - void run(ir::function *fn); + void run(ir::module &mod); private: has_storage_map_t has_dedicated_storage_; diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 9dcbb2826..52d7a4a2e 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -14,6 +14,118 @@ namespace tdl{ namespace ast{ +/* node */ +ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ + ir::type *src_ty = src->get_type()->get_scalar_ty(); + bool src_signed = false; + bool dst_signed = false; + if(src_ty == dst_ty) + return src; + else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty()) + return builder.create_si_to_fp(src, dst_ty); + + else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty()) + return builder.create_ui_to_fp(src, dst_ty); + + else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed) + return builder.create_fp_to_si(src, dst_ty); + + else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed) + return builder.create_fp_to_ui(src, dst_ty); + + else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && + src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width()) + return builder.create_fp_ext(src, dst_ty); + + else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && + src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width()) + return builder.create_fp_trunc(src, dst_ty); + + else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() && + src_ty->get_integer_bitwidth()) + return builder.create_int_cast(src, dst_ty, dst_signed); + + else + throw std::runtime_error("unreachable"); +} + + +void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, + bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ + // Input types + ir::type *left_ty = lhs->get_type()->get_scalar_ty(); + ir::type *right_ty = rhs->get_type()->get_scalar_ty(); + // One operand is pointer + if(left_ty->is_pointer_ty()){ + is_ptr = true; + } + // One operand is double + else if(left_ty->is_double_ty() || right_ty->is_double_ty()){ + ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs; + to_convert = explicit_cast(builder, to_convert, builder.get_double_ty()); + is_float = true; + } + // One operand is float + else if(left_ty->is_float_ty() || right_ty->is_float_ty()){ + ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs; + to_convert = explicit_cast(builder, to_convert, builder.get_float_ty()); + is_float = true; + } + // Both operands are integers + else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ + is_int = true; + is_signed = false; + if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ + ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; + ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; + to_convert = explicit_cast(builder, to_convert, dst_ty); + } + } + // Not reachable + else + throw std::runtime_error("unreachable"); +} + +void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){ + ir::builder &builder = mod->get_builder(); + ir::type *lhs_ty = lhs->get_type(); + ir::type *rhs_ty = rhs->get_type(); + // Both are scalar + if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) + return; + // One argument is scalar + if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){ + auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes(); + auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs; + scalar = builder.create_splat(scalar, shapes); + return; + } + // Both are arrays + std::vector lhs_shapes = lhs->get_type()->get_tile_shapes(); + std::vector rhs_shapes = rhs->get_type()->get_tile_shapes(); + int lhs_dim = lhs_shapes.size(); + int rhs_dim = rhs_shapes.size(); + std::vector &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; + std::vector &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; + size_t ndim = longest.size(); + int off = longest.size() - shortest.size(); + for(int i = longest.size(); i>= 0; i--){ + if(shortest[off + i] != longest[i]) + throw std::runtime_error("cannot broadcast"); + } + // Pad + for(size_t i = 0; i < off; i++) + shortest.insert(shortest.begin(), 1); + ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs; + target = builder.create_reshape(target, shortest); + // Broadcast + std::vector shapes(ndim); + for(size_t i = 0; i < ndim; i++) + shapes[i] = std::max(shortest[i], longest[i]); + lhs = builder.create_broadcast(lhs, shapes); + rhs = builder.create_broadcast(rhs, shapes); +} + /* Translation unit */ ir::value* translation_unit::codegen(ir::module *mod) const{ decls_->codegen(mod); @@ -195,11 +307,12 @@ void initializer::specifier(const declaration_specifier *spec) { ir::value* initializer::codegen(ir::module * mod) const{ ir::type *ty = decl_->type(mod, spec_->type(mod)); std::string name = decl_->id()->name(); - ir::value *value; - if(expr_) - value = expr_->codegen(mod); - else - value = ir::undef_value::get(ty); + ir::value *value = ir::undef_value::get(ty); + if(expr_){ + ir::value* target = expr_->codegen(mod); + explicit_cast(mod->get_builder(), target, ty->get_scalar_ty()); + implicit_broadcast(mod, value, target); + } value->set_name(name); mod->set_value(name, value); return value; @@ -208,119 +321,12 @@ ir::value* initializer::codegen(ir::module * mod) const{ /*------------------*/ /* Expression */ /*------------------*/ -ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ - ir::type *src_ty = src->get_type(); - bool src_signed = false; - bool dst_signed = false; - if(src_ty == dst_ty) - return src; - else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty()) - return builder.create_si_to_fp(src, dst_ty); - - else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty()) - return builder.create_ui_to_fp(src, dst_ty); - - else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed) - return builder.create_fp_to_si(src, dst_ty); - - else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed) - return builder.create_fp_to_ui(src, dst_ty); - - else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && - src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width()) - return builder.create_fp_ext(src, dst_ty); - - else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() && - src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width()) - return builder.create_fp_trunc(src, dst_ty); - - else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() && - src_ty->get_integer_bitwidth()) - return builder.create_int_cast(src, dst_ty, dst_signed); - - else - throw std::runtime_error("unreachable"); -} - -inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, - bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ - // Input types - ir::type *left_ty = lhs->get_type(); - ir::type *right_ty = rhs->get_type(); - // One operand is pointer - if(left_ty->is_pointer_ty()){ - is_ptr = true; - } - // One operand is double - else if(left_ty->is_double_ty() || right_ty->is_double_ty()){ - ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs; - to_convert = llvm_cast(builder, to_convert, builder.get_double_ty()); - is_float = true; - } - // One operand is float - else if(left_ty->is_float_ty() || right_ty->is_float_ty()){ - ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs; - to_convert = llvm_cast(builder, to_convert, builder.get_float_ty()); - is_float = true; - } - // Both operands are integers - else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ - is_int = true; - is_signed = false; - if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ - ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; - ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; - to_convert = llvm_cast(builder, to_convert, dst_ty); - } - } - // Not reachable - else - throw std::runtime_error("unreachable"); -} - -inline void implicit_broadcast(ir::module *mod, ir::builder &builder, ir::value *&lhs, ir::value *&rhs){ - std::vector lhs_shapes = lhs->get_type()->get_tile_shapes(); - std::vector rhs_shapes = rhs->get_type()->get_tile_shapes(); - // Both are scalar - if(lhs_shapes.empty() && rhs_shapes.empty()) - return; - // One argument is scalar - if(!lhs_shapes.empty() ^ !rhs_shapes.empty()){ - auto &shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes; - auto &target = lhs_shapes.empty()?lhs:rhs; - target = builder.create_splat(target, shapes); - return; - } - // Both are arrays - int lhs_dim = lhs_shapes.size(); - int rhs_dim = rhs_shapes.size(); - std::vector &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; - std::vector &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; - size_t ndim = longest.size(); - int off = longest.size() - shortest.size(); - for(int i = longest.size(); i>= 0; i--){ - if(shortest[off + i] != longest[i]) - throw std::runtime_error("cannot broadcast"); - } - // Pad - for(size_t i = 0; i < off; i++) - shortest.insert(shortest.begin(), 1); - ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs; - target = builder.create_reshape(target, shortest); - // Broadcast - std::vector shapes(ndim); - for(size_t i = 0; i < ndim; i++) - shapes[i] = std::max(shortest[i], longest[i]); - lhs = builder.create_broadcast(lhs, shapes); - rhs = builder.create_broadcast(rhs, shapes); -} - /* Binary operator */ ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const { bool is_float = false, is_ptr = false, is_int = false, is_signed = false; implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); -// implicit_broadcast(mod, builder, lhs, rhs); + implicit_broadcast(mod, lhs, rhs); if(op_==MUL && is_float) return builder.create_fmul(lhs, rhs, name); if(op_==MUL && is_int) diff --git a/lib/codegen/allocation.cpp b/lib/codegen/allocation.cpp index f67140bf5..1730371ae 100644 --- a/lib/codegen/allocation.cpp +++ b/lib/codegen/allocation.cpp @@ -11,7 +11,7 @@ namespace tdl{ namespace codegen{ -void allocation::run(ir::function &fn){ +void allocation::run(){ using std::max; using std::min; typedef std::multimap triples_map_type; diff --git a/lib/codegen/layout.cpp b/lib/codegen/layout.cpp index cdddb1d17..58f81227a 100644 --- a/lib/codegen/layout.cpp +++ b/lib/codegen/layout.cpp @@ -1,5 +1,6 @@ #include "codegen/layout.h" #include "ir/function.h" +#include "ir/module.h" #include "ir/basic_block.h" #include "ir/instructions.h" @@ -36,19 +37,19 @@ void layout::add_shared_views(ir::value *v){ } // Entry point -bool layout::run(ir::function &fn) { +void layout::run(ir::module &mod) { +for(ir::function *fn: mod.get_function_list()){ // Non-phis - for(ir::basic_block *block: fn.blocks()) + for(ir::basic_block *block: fn->blocks()) for(ir::instruction *instr: block->get_inst_list()) { add_shared_views(instr); } // Phi nodes - for(ir::basic_block *block: fn.blocks()) + for(ir::basic_block *block: fn->blocks()) for(ir::instruction *instr: block->get_inst_list()) { add_phi_nodes(instr); } - // Done - return false; +} } } diff --git a/lib/codegen/liveness.cpp b/lib/codegen/liveness.cpp index 9e910b420..824c95590 100644 --- a/lib/codegen/liveness.cpp +++ b/lib/codegen/liveness.cpp @@ -2,6 +2,7 @@ #include "codegen/layout.h" #include "ir/basic_block.h" #include "ir/function.h" +#include "ir/module.h" #include "ir/instructions.h" #include "ir/value.h" @@ -10,7 +11,8 @@ namespace codegen{ // Entry point -void liveness::run(ir::function *fn) { +void liveness::run(ir::module &mod) { +for(ir::function *fn: mod.get_function_list()){ // Assigns index to each instruction slot_index index = 0; for(ir::basic_block *block: fn->blocks()) @@ -35,6 +37,7 @@ void liveness::run(ir::function *fn) { intervals_[v] = segment{start, end}; } } +} } } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index e77a13773..670008389 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -108,6 +108,7 @@ for(ir::function *fn: mod.get_function_list()){ r = i; } } + // extract unique instructions in order std::vector grids; for(auto &ref: references) @@ -118,6 +119,7 @@ for(ir::function *fn: mod.get_function_list()){ int num_warps = 1; for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++) num_warps *= *params_[grids.front()]["p2.d" + to_string(k)]; + // check constraints for(ir::instruction *i: grids){ ir::type *ty = i->get_type();