[code generation] simple matrix-multiplication working

This commit is contained in:
Philippe Tillet
2019-02-09 19:20:50 -05:00
parent 4c8dbcccdc
commit d39f97ef38
2 changed files with 39 additions and 19 deletions

View File

@@ -317,6 +317,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
if(is_shared){
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ty->getPointerTo(ptr->getType()->getPointerAddressSpace()));
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
}
// create distributed tile
@@ -445,6 +446,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
+Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){
Value *res = tmap_.at(C)->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1];
@@ -453,7 +455,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = tmap_.at(A)->get_value(a_idx);
Value *b = tmap_.at(B)->get_value(b_idx);
-res = builder.CreateAdd(res, builder.CreateMul(a, b));
+res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});