[Intermediate Representation] Added skeleton

This commit is contained in:
Philippe Tillet
2018-12-31 22:47:31 -05:00
parent d260aefbd1
commit e7a4e70e22
22 changed files with 729 additions and 484 deletions

View File

@@ -1,141 +1,42 @@
#include <functional>
#include <algorithm>
#include "ast.h"
#include "codegen.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "ir/constant.h"
#include "ir/function.h"
#include "ir/module.h"
#include "ir/basic_block.h"
#include "ir/builder.h"
#include "ir/type.h"
#include <iostream>
using namespace llvm;
namespace tdl{
/* Nd Array utils */
inline std::vector<unsigned> array_shapes(Type *array_ty){
std::vector<unsigned> result;
Type *current = array_ty;
while(isa<ArrayType>(current)){
result.push_back(array_ty->getArrayNumElements());
current = array_ty->getArrayElementType();
printf("%d %d\n", current, current->getTypeID());
};
return result;
}
/* Context */
context::context() { }
LLVMContext *context::handle() {
return &handle_;
}
/* Module */
module::module(const std::string &name, context *ctx)
: handle_(name.c_str(), *ctx->handle()), builder_(*ctx->handle()) {
sealed_blocks_.insert(nullptr);
}
llvm::Module* module::handle() {
return &handle_;
}
llvm::IRBuilder<>& module::builder() {
return builder_;
}
void module::set_value(const std::string& name, BasicBlock *block, Value *value){
values_[val_key_t{name, block}] = value;
}
void module::set_value(const std::string& name, llvm::Value* value){
return set_value(name, builder_.GetInsertBlock(), value);
}
PHINode* module::make_phi(Type *type, unsigned num_values, BasicBlock *block){
Instruction* instr = block->getFirstNonPHIOrDbg();
if(instr)
builder_.SetInsertPoint(instr);
PHINode *res = builder_.CreatePHI(type, num_values);
if(instr)
builder_.SetInsertPoint(block);
return res;
}
Value *module::add_phi_operands(const std::string& name, PHINode *&phi){
BasicBlock *block = phi->getParent();
for(BasicBlock *pred: predecessors(block)){
llvm::Value *value = get_value(name, pred);
phi->addIncoming(value, pred);
}
return phi;
}
llvm::Value *module::get_value_recursive(const std::string& name, BasicBlock *block) {
llvm::Value *result;
if(sealed_blocks_.find(block) == sealed_blocks_.end()){
llvm::Value *pred = get_value(name, *pred_begin(block));
incomplete_phis_[block][name] = make_phi(pred->getType(), 1, block);
result = (Value*)incomplete_phis_[block][name];
}
else if(pred_size(block) <= 1){
bool has_pred = pred_size(block);
result = get_value(name, has_pred?*pred_begin(block):nullptr);
}
else{
llvm::Value *pred = get_value(name, *pred_begin(block));
result = make_phi(pred->getType(), 1, block);
set_value(name, block, result);
add_phi_operands(name, (PHINode*&)result);
}
set_value(name, block, result);
return result;
}
llvm::Value *module::get_value(const std::string& name, BasicBlock *block) {
val_key_t key(name, block);
if(values_.find(key) != values_.end()){
return values_.at(key);
}
return get_value_recursive(name, block);
}
llvm::Value *module::get_value(const std::string& name) {
return get_value(name, builder_.GetInsertBlock());
}
llvm::Value *module::seal_block(BasicBlock *block){
for(auto &x: incomplete_phis_[block])
add_phi_operands(x.first, x.second);
sealed_blocks_.insert(block);
}
namespace ast{
/* Translation unit */
Value* translation_unit::codegen(module *mod) const{
ir::value* translation_unit::codegen(ir::module *mod) const{
decls_->codegen(mod);
return nullptr;
}
/* Declaration specifier */
Type* declaration_specifier::type(module *mod) const {
LLVMContext &ctx = mod->handle()->getContext();
ir::type* declaration_specifier::type(ir::module *mod) const {
ir::context &ctx = mod->get_context();
switch (spec_) {
case VOID_T: return Type::getVoidTy(ctx);
case INT8_T: return IntegerType::get(ctx, 8);
case INT16_T: return IntegerType::get(ctx, 16);
case INT32_T: return IntegerType::get(ctx, 32);
case INT64_T: return IntegerType::get(ctx, 64);
case FLOAT32_T: return Type::getFloatTy(ctx);
case FLOAT64_T: return Type::getDoubleTy(ctx);
default: assert(false && "unreachable"); throw;
case VOID_T: return ir::type::get_void_ty(ctx);
case INT8_T: return ir::integer_type::get(ctx, 8);
case INT16_T: return ir::integer_type::get(ctx, 16);
case INT32_T: return ir::integer_type::get(ctx, 32);
case INT64_T: return ir::integer_type::get(ctx, 64);
case FLOAT32_T: return ir::type::get_float_ty(ctx);
case FLOAT64_T: return ir::type::get_double_ty(ctx);
default: throw std::runtime_error("unreachable");
}
}
/* Parameter */
Type* parameter::type(module *mod) const {
ir::type* parameter::type(ir::module *mod) const {
return decl_->type(mod, spec_->type(mod));
}
@@ -144,14 +45,14 @@ const identifier *parameter::id() const {
}
/* Declarators */
Type* declarator::type(module *mod, Type *type) const{
ir::type* declarator::type(ir::module *mod, ir::type *type) const{
if(ptr_)
return type_impl(mod, ptr_->type(mod, type));
return type_impl(mod, type);
}
// Identifier
Type* identifier::type_impl(module *, Type *type) const{
ir::type* identifier::type_impl(ir::module *, ir::type *type) const{
return type;
}
@@ -160,60 +61,57 @@ const std::string &identifier::name() const{
}
// Tile
Type* tile::type_impl(module*, Type *type) const{
Type *current = type;
unsigned i = 0;
do{
current = ArrayType::get(current, shapes_->values()[i++]->value());
}while(i < shapes_->values().size());
return current;
ir::type* tile::type_impl(ir::module*, ir::type *type) const{
std::vector<unsigned> shapes;
for(constant *cst: shapes_->values())
shapes.push_back(cst->value());
return ir::tile_type::get(type, shapes);
}
// Pointer
Type* pointer::type_impl(module*, Type *type) const{
return PointerType::get(type, 1);
ir::type* pointer::type_impl(ir::module*, ir::type *type) const{
return ir::pointer_type::get(type, 1);
}
// Function
void function::bind_parameters(module *mod, Function *fn) const{
std::vector<llvm::Value*> args;
std::transform(fn->arg_begin(), fn->arg_end(), std::back_inserter(args), [&](llvm::Argument& x){ return &x;});
void function::bind_parameters(ir::module *mod, ir::function *fn) const{
std::vector<ir::value*> args;
std::transform(fn->arg_begin(), fn->arg_end(), std::back_inserter(args), [&](ir::argument& x){ return &x;});
assert(args.size() == args_->values().size());
for(size_t i = 0; i < args.size(); i++){
parameter *param_i = args_->values().at(i);
const identifier *id_i = param_i->id();
if(id_i){
args[i]->setName(id_i->name());
args[i]->set_name(id_i->name());
mod->set_value(id_i->name(), nullptr, args[i]);
}
}
}
Type* function::type_impl(module*mod, Type *type) const{
SmallVector<Type*, 8> types;
for(parameter* param: args_->values()){
ir::type* function::type_impl(ir::module* mod, ir::type *type) const{
std::vector<ir::type*> types;
for(parameter* param: args_->values())
types.push_back(param->type(mod));
}
return FunctionType::get(type, types, false);
return ir::function_type::get(type, types);
}
/* Function definition */
Value* function_definition::codegen(module *mod) const{
FunctionType *prototype = (FunctionType *)header_->type(mod, spec_->type(mod));
ir::value* function_definition::codegen(ir::module *mod) const{
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod));
const std::string &name = header_->id()->name();
Function *fn = Function::Create(prototype, Function::ExternalLinkage, name, mod->handle());
ir::function *fn = ir::function::create(prototype, name, mod);
header_->bind_parameters(mod, fn);
BasicBlock *entry = BasicBlock::Create(mod->handle()->getContext(), "entry", fn);
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
mod->seal_block(entry);
mod->builder().SetInsertPoint(entry);
mod->get_builder().set_insert_point(entry);
body_->codegen(mod);
mod->builder().CreateRetVoid();
mod->get_builder().create_ret_void();
return nullptr;
}
/* Statements */
Value* compound_statement::codegen(module* mod) const{
ir::value* compound_statement::codegen(ir::module* mod) const{
decls_->codegen(mod);
if(statements_)
statements_->codegen(mod);
@@ -221,56 +119,56 @@ Value* compound_statement::codegen(module* mod) const{
}
/* Iteration statement */
Value* iteration_statement::codegen(module *mod) const{
IRBuilder<> &builder = mod->builder();
LLVMContext &ctx = mod->handle()->getContext();
Function *fn = builder.GetInsertBlock()->getParent();
BasicBlock *loop_bb = BasicBlock::Create(ctx, "loop", fn);
BasicBlock *next_bb = BasicBlock::Create(ctx, "postloop", fn);
ir::value* iteration_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::context &ctx = mod->get_context();
ir::function *fn = builder.get_insert_block()->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
init_->codegen(mod);
builder.CreateBr(loop_bb);
builder.SetInsertPoint(loop_bb);
builder.create_br(loop_bb);
builder.set_insert_point(loop_bb);
statements_->codegen(mod);
exec_->codegen(mod);
Value *cond = stop_->codegen(mod);
builder.CreateCondBr(cond, loop_bb, next_bb);
builder.SetInsertPoint(next_bb);
ir::value *cond = stop_->codegen(mod);
builder.create_cond_br(cond, loop_bb, next_bb);
builder.set_insert_point(next_bb);
mod->seal_block(loop_bb);
mod->seal_block(next_bb);
return nullptr;
}
/* Selection statement */
Value* selection_statement::codegen(module* mod) const{
IRBuilder<> &builder = mod->builder();
LLVMContext &ctx = mod->handle()->getContext();
Function *fn = builder.GetInsertBlock()->getParent();
Value *cond = cond_->codegen(mod);
BasicBlock *then_bb = BasicBlock::Create(ctx, "then", fn);
BasicBlock *else_bb = else_value_?BasicBlock::Create(ctx, "else", fn):nullptr;
BasicBlock *endif_bb = BasicBlock::Create(ctx, "endif", fn);
ir::value* selection_statement::codegen(ir::module* mod) const{
ir::builder &builder = mod->get_builder();
ir::context &ctx = mod->get_context();
ir::function *fn = builder.get_insert_block()->get_parent();
ir::value *cond = cond_->codegen(mod);
ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn);
ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr;
ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn);
// Branch
if(else_value_)
builder.CreateCondBr(cond, then_bb, else_bb);
builder.create_cond_br(cond, then_bb, else_bb);
else
builder.CreateCondBr(cond, then_bb, endif_bb);
builder.create_cond_br(cond, then_bb, endif_bb);
// Then
builder.SetInsertPoint(then_bb);
builder.set_insert_point(then_bb);
then_value_->codegen(mod);
if(else_value_)
builder.CreateBr(endif_bb);
builder.create_br(endif_bb);
// Else
if(else_value_){
builder.SetInsertPoint(else_bb);
builder.set_insert_point(else_bb);
else_value_->codegen(mod);
builder.CreateBr(endif_bb);
builder.create_br(endif_bb);
}
// Endif
builder.SetInsertPoint(endif_bb);
builder.set_insert_point(endif_bb);
}
/* Declaration */
Value* declaration::codegen(module* mod) const{
ir::value* declaration::codegen(ir::module* mod) const{
for(initializer *init: init_->values())
init->specifier(spec_);
init_->codegen(mod);
@@ -278,7 +176,7 @@ Value* declaration::codegen(module* mod) const{
}
/* Initializer */
Type* initializer::type_impl(module *mod, Type *type) const{
ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{
return decl_->type(mod, type);
}
@@ -286,15 +184,15 @@ void initializer::specifier(const declaration_specifier *spec) {
spec_ = spec;
}
Value* initializer::codegen(module * mod) const{
Type *ty = decl_->type(mod, spec_->type(mod));
ir::value* initializer::codegen(ir::module * mod) const{
ir::type *ty = decl_->type(mod, spec_->type(mod));
std::string name = decl_->id()->name();
Value *value;
ir::value *value;
if(expr_)
value = expr_->codegen(mod);
else
value = llvm::UndefValue::get(ty);
value->setName(name);
value = ir::undef_value::get(ty);
value->set_name(name);
mod->set_value(name, value);
return value;
}
@@ -302,97 +200,87 @@ Value* initializer::codegen(module * mod) const{
/*------------------*/
/* Expression */
/*------------------*/
llvm::Value *llvm_cast(llvm::IRBuilder<> &builder, Value *src, Type *dst_ty){
Type *src_ty = src->getType();
ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
ir::type *src_ty = src->get_type();
bool src_signed = false;
bool dst_signed = false;
if(src_ty == dst_ty)
return src;
else if(src_ty->isIntegerTy() && src_signed && dst_ty->isFloatingPointTy())
return builder.CreateSIToFP(src, dst_ty);
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
return builder.create_si_to_fp(src, dst_ty);
else if(src_ty->isIntegerTy() && !src_signed && dst_ty->isFloatingPointTy())
return builder.CreateUIToFP(src, dst_ty);
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
return builder.create_ui_to_fp(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && dst_signed)
return builder.CreateFPToSI(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
return builder.create_fp_to_si(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && !dst_signed)
return builder.CreateFPToUI(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
return builder.create_fp_to_ui(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() &&
src_ty->getFPMantissaWidth() < dst_ty->getFPMantissaWidth())
return builder.CreateFPExt(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
return builder.create_fp_ext(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() &&
src_ty->getFPMantissaWidth() > dst_ty->getFPMantissaWidth())
return builder.CreateFPTrunc(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
return builder.create_fp_trunc(src, dst_ty);
else if(src_ty->isIntegerTy() && dst_ty->isIntegerTy() &&
src_ty->getIntegerBitWidth())
return builder.CreateIntCast(src, dst_ty, dst_signed);
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
src_ty->get_integer_bit_width())
return builder.create_int_cast(src, dst_ty, dst_signed);
else{
assert(false && "unreachable");
throw;
}
else
throw std::runtime_error("unreachable");
}
inline void implicit_cast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs,
inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
// Input types
Type *left_ty = lhs->getType();
Type *right_ty = rhs->getType();
ir::type *left_ty = lhs->get_type();
ir::type *right_ty = rhs->get_type();
// One operand is pointer
if(left_ty->isPointerTy()){
if(left_ty->is_pointer_ty()){
is_ptr = true;
}
// One operand is double
else if(left_ty->isDoubleTy() || right_ty->isDoubleTy()){
Value *&to_convert = left_ty->isDoubleTy()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.getDoubleTy());
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.get_double_ty());
is_float = true;
}
// One operand is float
else if(left_ty->isFloatTy() || right_ty->isFloatTy()){
Value *&to_convert = left_ty->isFloatTy()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.getFloatTy());
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.get_float_ty());
is_float = true;
}
// Both operands are integers
else if(left_ty->isIntegerTy() && right_ty->isIntegerTy()){
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
is_int = true;
is_signed = false;
if(left_ty->getIntegerBitWidth() != right_ty->getIntegerBitWidth()){
Value *&to_convert = (left_ty->getIntegerBitWidth() > right_ty->getIntegerBitWidth())?rhs:lhs;
Type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
if(left_ty->get_integer_bit_width() != right_ty->get_integer_bit_width()){
ir::value *&to_convert = (left_ty->get_integer_bit_width() > right_ty->get_integer_bit_width())?rhs:lhs;
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
to_convert = llvm_cast(builder, to_convert, dst_ty);
}
}
// Not reachable
else{
assert(false);
throw;
}
else
throw std::runtime_error("unreachable");
}
inline void implicit_broadcast(module *mod, llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs){
std::vector<unsigned> lhs_shapes = array_shapes(lhs->getType());
std::vector<unsigned> rhs_shapes = array_shapes(rhs->getType());
inline void implicit_broadcast(ir::module *mod, ir::builder &builder, ir::value *&lhs, ir::value *&rhs){
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
// Both are scalar
if(lhs_shapes.empty() && rhs_shapes.empty())
return;
// One argument is scalar
if(!lhs_shapes.empty() ^ !rhs_shapes.empty()){
auto &ref_shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes;
auto &ref = lhs_shapes.empty()?rhs:lhs;
auto &shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes;
auto &target = lhs_shapes.empty()?lhs:rhs;
Function *splat_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_splat_2d, {ref->getType()});
SmallVector<Value*, 4> args(1 + ref_shapes.size());
for(unsigned i = 0; i < ref_shapes.size(); i++)
args[1 + i] = builder.getInt32(ref_shapes[i]);
args[0] = target;
target = builder.CreateCall(splat_fn, args);
target = builder.create_splat(target, shapes);
return;
}
// Both are arrays
@@ -407,246 +295,195 @@ inline void implicit_broadcast(module *mod, llvm::IRBuilder<> &builder, Value *&
throw std::runtime_error("cannot broadcast");
}
// Pad
for(size_t i = 0; i < off; i++){
for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), 1);
}
Value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
SmallVector<Value*, 4> args(1 + ndim);
// Reshape left hand side
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
target = builder.create_reshape(target, shortest);
// Broadcast
std::vector<unsigned> shapes(ndim);
for(size_t i = 0; i < ndim; i++)
args[1 + i] = builder.getInt32(shortest[i]);
args[0] = target;
Function *reshape_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_reshape_2d_1d, {rhs->getType(), lhs->getType()});
target = builder.CreateCall(reshape_fn, args);
// Broadcast both arguments
for(size_t i = 0; i < ndim; i++)
args[1 + i] = builder.getInt32(std::max(shortest[i], longest[i]));
Function *broadcast_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_broadcast_2d, {target->getType(), target->getType()});
// Broadcast lhs
args[0] = lhs;
lhs = builder.CreateCall(broadcast_fn, args);
// Broadcast rhs
args[0] = rhs;
rhs = builder.CreateCall(broadcast_fn, args);
shapes[i] = std::max(shortest[i], longest[i]);
lhs = builder.create_broadcast(lhs, shapes);
rhs = builder.create_broadcast(rhs, shapes);
}
/* Binary operator */
Value *binary_operator::llvm_op(module *mod, llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
{
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
// implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
// implicit_broadcast(mod, builder, lhs, rhs);
// Mul
if(op_==MUL && is_float)
return builder.CreateFMul(lhs, rhs, name);
return builder.create_fmul(lhs, rhs, name);
if(op_==MUL && is_int)
return builder.CreateMul(lhs, rhs, name);
// Div
return builder.create_mul(lhs, rhs, name);
if(op_==DIV && is_float)
return builder.CreateFDiv(lhs, rhs, name);
return builder.create_fdiv(lhs, rhs, name);
if(op_==DIV && is_int && is_signed)
return builder.CreateSDiv(lhs, rhs, name);
return builder.create_sdiv(lhs, rhs, name);
if(op_==DIV && is_int && !is_signed)
return builder.CreateUDiv(lhs, rhs, name);
// Mod
return builder.create_udiv(lhs, rhs, name);
if(op_==MOD && is_float)
return builder.CreateFRem(lhs, rhs, name);
return builder.create_frem(lhs, rhs, name);
if(op_==MOD && is_int && is_signed)
return builder.CreateSRem(lhs, rhs, name);
return builder.create_srem(lhs, rhs, name);
if(op_==MOD && is_int && !is_signed)
return builder.CreateURem(lhs, rhs, name);
// Add
return builder.create_urem(lhs, rhs, name);
if(op_==ADD && is_float)
return builder.CreateFAdd(lhs, rhs, name);
return builder.create_fadd(lhs, rhs, name);
if(op_==ADD && is_int)
return builder.CreateAdd(lhs, rhs);
return builder.create_add(lhs, rhs);
if(op_==ADD && is_ptr)
return builder.CreateGEP(lhs, {rhs});
// Sub
return builder.create_gep(lhs, {rhs});
if(op_==SUB && is_float)
return builder.CreateFSub(lhs, rhs, name);
return builder.create_fsub(lhs, rhs, name);
if(op_==SUB && is_int)
return builder.CreateSub(lhs, rhs, name);
return builder.create_sub(lhs, rhs, name);
if(op_==SUB && is_ptr)
return builder.CreateGEP(lhs, {builder.CreateNeg(rhs)});
// Left shift
if(op_==LEFT_SHIFT){
assert(is_int);
return builder.CreateLShr(lhs, rhs, name);
}
// Right shift
if(op_==RIGHT_SHIFT){
assert(is_int);
return builder.CreateAShr(lhs, rhs, name);
}
// LT
return builder.create_gep(lhs, {builder.create_neg(rhs)});
if(op_==LEFT_SHIFT)
return builder.create_lshr(lhs, rhs, name);
if(op_==RIGHT_SHIFT)
return builder.create_ashr(lhs, rhs, name);
if(op_ == LT && is_float)
return builder.CreateFCmpOLT(lhs, rhs, name);
return builder.create_fcmpOLT(lhs, rhs, name);
if(op_ == LT && is_int && is_signed)
return builder.CreateICmpSLT(lhs, rhs, name);
return builder.create_icmpSLT(lhs, rhs, name);
if(op_ == LT && is_int && !is_signed)
return builder.CreateICmpULT(lhs, rhs, name);
// GT
return builder.create_icmpULT(lhs, rhs, name);
if(op_ == GT && is_float)
return builder.CreateFCmpOGT(lhs, rhs, name);
return builder.create_fcmpOGT(lhs, rhs, name);
if(op_ == GT && is_int && is_signed)
return builder.CreateICmpSGT(lhs, rhs, name);
return builder.create_icmpSGT(lhs, rhs, name);
if(op_ == GT && is_int && !is_signed)
return builder.CreateICmpUGT(lhs, rhs, name);
// LE
return builder.create_icmpUGT(lhs, rhs, name);
if(op_ == LE && is_float)
return builder.CreateFCmpOLE(lhs, rhs, name);
return builder.create_fcmpOLE(lhs, rhs, name);
if(op_ == LE && is_int && is_signed)
return builder.CreateICmpSLE(lhs, rhs, name);
return builder.create_icmpSLE(lhs, rhs, name);
if(op_ == LE && is_int && !is_signed)
return builder.CreateICmpULE(lhs, rhs, name);
// GE
return builder.create_icmpULE(lhs, rhs, name);
if(op_ == GE && is_float)
return builder.CreateFCmpOGE(lhs, rhs, name);
return builder.create_fcmpOGE(lhs, rhs, name);
if(op_ == GE && is_int && is_signed)
return builder.CreateICmpSGE(lhs, rhs, name);
return builder.create_icmpSGE(lhs, rhs, name);
if(op_ == GE && is_int && !is_signed)
return builder.CreateICmpUGE(lhs, rhs, name);
// EQ
return builder.create_icmpUGE(lhs, rhs, name);
if(op_ == EQ && is_float)
return builder.CreateFCmpOEQ(lhs, rhs, name);
return builder.create_fcmpOEQ(lhs, rhs, name);
if(op_ == EQ && is_int)
return builder.CreateICmpEQ(lhs, rhs, name);
// NE
return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == NE && is_float)
return builder.CreateFCmpONE(lhs, rhs, name);
return builder.create_fcmpONE(lhs, rhs, name);
if(op_ == NE && is_int)
return builder.CreateICmpNE(lhs, rhs, name);
// AND
if(op_ == AND){
assert(is_int);
return builder.CreateAnd(lhs, rhs, name);
}
if(op_ == XOR){
assert(is_int);
return builder.CreateXor(lhs, rhs, name);
}
if(op_ == OR){
assert(is_int);
return builder.CreateOr(lhs, rhs, name);
}
if(op_ == LAND){
assert(is_int);
return builder.CreateAnd(lhs, rhs, name);
}
if(op_ == LOR){
assert(is_int);
return builder.CreateOr(lhs, rhs, name);
}
assert(false && "unreachable");
throw;
return builder.create_icmpNE(lhs, rhs, name);
if(op_ == AND)
return builder.create_and(lhs, rhs, name);
if(op_ == XOR)
return builder.create_xor(lhs, rhs, name);
if(op_ == OR)
return builder.create_or(lhs, rhs, name);
if(op_ == LAND)
return builder.create_and(lhs, rhs, name);
if(op_ == LOR)
return builder.create_or(lhs, rhs, name);
throw std::runtime_error("unreachable");
}
Value* binary_operator::codegen(module *mod) const{
Value *lhs = lhs_->codegen(mod);
Value *rhs = rhs_->codegen(mod);
Value *result = llvm_op(mod, mod->builder(), lhs, rhs, "");
ir::value* binary_operator::codegen(ir::module *mod) const{
ir::value *lhs = lhs_->codegen(mod);
ir::value *rhs = rhs_->codegen(mod);
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
return result;
}
/* Postfix expression */
Value* indexing_expression::codegen(module *mod) const{
Value *in = mod->get_value(id_->name());
std::vector<range_enum_t> ranges;
for(range *x: ranges_->values())
ranges.push_back(x->type());
// Type information
Function* reshape;
Type *in_type = in->getType();
size_t in_dim = in_type->getTileNumDimensions();
size_t out_dim = ranges.size();
Type *out_type = TileType::get(in_type->getTileElementType(), out_dim);
// Intrinsic function
Function *reshape_fn = Intrinsic::getDeclaration(mod->handle(), Intrinsic::tlvm_reshape_2d_1d, {out_type, in_type});
return nullptr;
ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name());
const std::vector<range*> &ranges = ranges_->values();
std::vector<unsigned> in_shapes = in->get_type()->get_tile_shapes();
std::vector<unsigned> out_shapes(ranges.size());
size_t current = 0;
for(size_t i = 0; i < out_shapes.size(); i++)
out_shapes[i] = (ranges[i]->type()==NEWAXIS)?1:in_shapes[current++];
return mod->get_builder().create_reshape(in, out_shapes);
}
/* Unary operator */
Value *unary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *arg, const std::string &name) const{
Type *atype = arg->getType();
bool is_float = atype->isFloatingPointTy();
bool is_int = atype->isIntegerTy();
if(op_ == INC){
assert(is_int);
return builder.CreateAdd(arg, builder.getInt32(1), name);
}
if(op_ == DEC){
assert(is_int);
return builder.CreateSub(arg, builder.getInt32(1), name);
}
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
ir::type *atype = arg->get_type();
bool is_float = atype->is_floating_point_ty();
bool is_int = atype->is_integer_ty();
if(op_ == INC)
return builder.create_add(arg, builder.get_int32(1), name);
if(op_ == DEC)
return builder.create_sub(arg, builder.get_int32(1), name);
if(op_ == PLUS)
return arg;
if(op_ == MINUS && is_float)
return builder.CreateFNeg(arg, name);
return builder.create_fneg(arg, name);
if(op_ == MINUS && is_int)
return builder.CreateNeg(arg, name);
return builder.create_neg(arg, name);
if(op_ == ADDR)
throw std::runtime_error("not supported");
if(op_ == DEREF)
return builder.CreateLoad(arg, name);
return builder.create_load(arg, name);
if(op_ == COMPL)
throw std::runtime_error("not supported");
if(op_ == NOT)
return builder.CreateNot(arg, name);
assert(false && "unrechable");
throw;
return builder.create_not(arg, name);
throw std::runtime_error("unreachable");
}
Value* unary_operator::codegen(module *mod) const{
Value *arg = arg_->codegen(mod);
Value *result = llvm_op(mod->builder(), arg, "");
ir::value* unary_operator::codegen(ir::module *mod) const{
ir::value *arg = arg_->codegen(mod);
ir::value *result = llvm_op(mod->get_builder(), arg, "");
return result;
}
/* Cast operator */
Value *cast_operator::llvm_op(IRBuilder<> &builder, Type *T, Value *arg, const std::string &name) const{
ir::value *cast_operator::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{
return nullptr;
}
Value* cast_operator::codegen(module *mod) const{
Value *arg = arg_->codegen(mod);
Type *T = T_->type(mod);
return llvm_op(mod->builder(), T, arg, "");
ir::value* cast_operator::codegen(ir::module *mod) const{
ir::value *arg = arg_->codegen(mod);
ir::type *T = T_->type(mod);
return llvm_op(mod->get_builder(), T, arg, "");
}
/* Conditional expression */
Value *conditional_expression::llvm_op(IRBuilder<> &builder, Value *cond, Value *true_value, Value *false_value, const std::string &name) const{
ir::value *conditional_expression::llvm_op(ir::builder &builder, ir::value *cond, ir::value *true_value, ir::value *false_value, const std::string &name) const{
return nullptr;
}
Value *conditional_expression::codegen(module *mod) const{
Value *cond = cond_->codegen(mod);
Value *true_value = true_value_->codegen(mod);
Value *false_value = false_value_->codegen(mod);
return llvm_op(mod->builder(), cond, true_value, false_value, "");
ir::value *conditional_expression::codegen(ir::module *mod) const{
ir::value *cond = cond_->codegen(mod);
ir::value *true_value = true_value_->codegen(mod);
ir::value *false_value = false_value_->codegen(mod);
return llvm_op(mod->get_builder(), cond, true_value, false_value, "");
}
/* Assignment expression */
Value *assignment_expression::codegen(module *mod) const{
Value *rvalue = rvalue_->codegen(mod);
ir::value *assignment_expression::codegen(ir::module *mod) const{
ir::value *rvalue = rvalue_->codegen(mod);
mod->set_value(lvalue_->id()->name(), rvalue);
return rvalue;
}
/* Type name */
llvm::Type *type_name::type(module *mod) const{
ir::type *type_name::type(ir::module *mod) const{
return decl_->type(mod, spec_->type(mod));
}
/* String literal */
llvm::Value* string_literal::codegen(module *mod) const{
return ConstantDataArray::getString(mod->handle()->getContext(), value_);
ir::value* string_literal::codegen(ir::module *mod) const{
return ir::constant_data_array::get_string(mod->get_context(), value_);
}
/* Constant */
llvm::Value* constant::codegen(module *mod) const{
return mod->builder().getInt32(value_);
ir::value* constant::codegen(ir::module *mod) const{
return mod->get_builder().get_int32(value_);
}
int constant::value() const{
@@ -660,7 +497,7 @@ const identifier* unary_expression::id() const{
}
/* Named */
llvm::Value* named_expression::codegen(module *mod) const{
ir::value* named_expression::codegen(ir::module *mod) const{
const std::string &name = id()->name();
return mod->get_value(name);
}