diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 8fa19a840..aa3b39b9b 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -17,14 +17,11 @@ const char src[] = "\ void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\ int32 j = 1;\ + int32 test[16, 16] = 0;\ + int32 test2[16, 16];\ + int32 test3[16, 16];\ int32 k;\ - i = i + j;\ - for(k = 0; k < 10; k = k+1){\ - int32 u = 1;\ - u = u + i;\ - if(k == 0)\ - u = u + 2;\ - }\ + test = test2 + test3;\ }\ "; diff --git a/include/ast.h b/include/ast.h index 072644f1e..282e421e3 100644 --- a/include/ast.h +++ b/include/ast.h @@ -145,7 +145,7 @@ public: class binary_operator: public expression{ private: - llvm::Value* llvm_op(llvm::IRBuilder<> &bld, llvm::Value *lhs, llvm::Value *rhs, const std::string &name) const; + llvm::Value* llvm_op(module *mod, llvm::IRBuilder<> &bld, llvm::Value *lhs, llvm::Value *rhs, const std::string &name) const; public: binary_operator(BIN_OP_T op, node *lhs, node *rhs) @@ -163,6 +163,7 @@ class constant: public expression{ public: constant(int value): value_(value) { } llvm::Value* codegen(module *mod) const; + int value() const; private: const int value_; diff --git a/include/parser.y b/include/parser.y index 74b5f8bbd..f43b45265 100644 --- a/include/parser.y +++ b/include/parser.y @@ -92,7 +92,7 @@ constant : constant_list : constant { $$ = new list((constant*)$1); } - | constant_list ',' constant { $$ = append_ptr_list($1, $2); } + | constant_list ',' constant { $$ = append_ptr_list($1, $3); } ; type_name diff --git a/lib/codegen.cpp b/lib/codegen.cpp index 97b9407b8..f7af12a60 100644 --- a/lib/codegen.cpp +++ b/lib/codegen.cpp @@ -12,6 +12,18 @@ using namespace llvm; namespace tdl{ +/* Nd Array utils */ +inline std::vector array_shapes(Type *array_ty){ + std::vector result; + Type *current = array_ty; + while(isa(current)){ + result.push_back(array_ty->getArrayNumElements()); + current = array_ty->getArrayElementType(); + printf("%d %d\n", current, current->getTypeID()); + }; + return result; +} + /* Context */ context::context() { } @@ -149,13 +161,14 @@ const std::string &identifier::name() const{ // Tile Type* tile::type_impl(module*, Type *type) const{ - return TileType::get(type, shapes_->values().size()); + Type *current = type; + unsigned i = 0; + do{ + current = ArrayType::get(current, shapes_->values()[i++]->value()); + }while(i < shapes_->values().size()); + return current; } -// Initializer -Type* initializer::type_impl(module *mod, Type *type) const{ - return decl_->type(mod, type); -} // Pointer Type* pointer::type_impl(module*, Type *type) const{ @@ -265,6 +278,10 @@ Value* declaration::codegen(module* mod) const{ } /* Initializer */ +Type* initializer::type_impl(module *mod, Type *type) const{ + return decl_->type(mod, type); +} + void initializer::specifier(const declaration_specifier *spec) { spec_ = spec; } @@ -359,16 +376,66 @@ inline void implicit_cast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs, } } -//inline void implicit_broadcast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs){ -// return; -//} +inline void implicit_broadcast(module *mod, llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs){ + std::vector lhs_shapes = array_shapes(lhs->getType()); + std::vector rhs_shapes = array_shapes(rhs->getType()); + // Both are scalar + if(lhs_shapes.empty() && rhs_shapes.empty()) + return; + // One argument is scalar + if(!lhs_shapes.empty() ^ !rhs_shapes.empty()){ + auto &ref_shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes; + auto &ref = lhs_shapes.empty()?rhs:lhs; + auto &target = lhs_shapes.empty()?lhs:rhs; + Function *splat_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_splat_2d, {ref->getType()}); + SmallVector args(1 + ref_shapes.size()); + for(unsigned i = 0; i < ref_shapes.size(); i++) + args[1 + i] = builder.getInt32(ref_shapes[i]); + args[0] = target; + target = builder.CreateCall(splat_fn, args); + return; + } + // Both are arrays + int lhs_dim = lhs_shapes.size(); + int rhs_dim = rhs_shapes.size(); + std::vector &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; + std::vector &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; + size_t ndim = longest.size(); + int off = longest.size() - shortest.size(); + for(int i = longest.size(); i>= 0; i--){ + if(shortest[off + i] != longest[i]) + throw std::runtime_error("cannot broadcast"); + } + // Pad + for(size_t i = 0; i < off; i++){ + shortest.insert(shortest.begin(), 1); + } + Value *&target = (lhs_dim < rhs_dim)?lhs:rhs; + SmallVector args(1 + ndim); + // Reshape left hand side + for(size_t i = 0; i < ndim; i++) + args[1 + i] = builder.getInt32(shortest[i]); + args[0] = target; + Function *reshape_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_reshape_2d_1d, {rhs->getType(), lhs->getType()}); + target = builder.CreateCall(reshape_fn, args); + // Broadcast both arguments + for(size_t i = 0; i < ndim; i++) + args[1 + i] = builder.getInt32(std::max(shortest[i], longest[i])); + Function *broadcast_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_broadcast_2d, {target->getType(), target->getType()}); + // Broadcast lhs + args[0] = lhs; + lhs = builder.CreateCall(broadcast_fn, args); + // Broadcast rhs + args[0] = rhs; + rhs = builder.CreateCall(broadcast_fn, args); +} /* Binary operator */ -Value *binary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const +Value *binary_operator::llvm_op(module *mod, llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const { bool is_float = false, is_ptr = false, is_int = false, is_signed = false; - implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); -// implicit_broadcast(builder, lhs, rhs); +// implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); +// implicit_broadcast(mod, builder, lhs, rhs); // Mul if(op_==MUL && is_float) return builder.CreateFMul(lhs, rhs, name); @@ -478,12 +545,25 @@ Value *binary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *lhs, Value *r Value* binary_operator::codegen(module *mod) const{ Value *lhs = lhs_->codegen(mod); Value *rhs = rhs_->codegen(mod); - Value *result = llvm_op(mod->builder(), lhs, rhs, ""); + Value *result = llvm_op(mod, mod->builder(), lhs, rhs, ""); return result; } /* Postfix expression */ Value* indexing_expression::codegen(module *mod) const{ + Value *in = mod->get_value(id_->name()); + std::vector ranges; + for(range *x: ranges_->values()) + ranges.push_back(x->type()); + // Type information + Function* reshape; + Type *in_type = in->getType(); + size_t in_dim = in_type->getTileNumDimensions(); + size_t out_dim = ranges.size(); + Type *out_type = TileType::get(in_type->getTileElementType(), out_dim); + // Intrinsic function + Function *reshape_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_reshape_2d_1d, {out_type, in_type}); + return nullptr; } @@ -569,6 +649,11 @@ llvm::Value* constant::codegen(module *mod) const{ return mod->builder().getInt32(value_); } +int constant::value() const{ + return value_; +} + + /* Unary expression */ const identifier* unary_expression::id() const{ return id_;