[code generation] simple matrix-multiplication working
This commit is contained in:
@@ -317,6 +317,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
if(is_shared){
|
||||
size_t offset = alloc_->get_offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ty->getPointerTo(ptr->getType()->getPointerAddressSpace()));
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
|
||||
}
|
||||
// create distributed tile
|
||||
@@ -445,6 +446,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
ir::value *A = ins->get_operand(0);
|
||||
ir::value *B = ins->get_operand(1);
|
||||
ir::value *C = ins->get_operand(2);
|
||||
+Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = tmap_.at(C)->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
||||
@@ -453,7 +455,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
Value *a = tmap_.at(A)->get_value(a_idx);
|
||||
Value *b = tmap_.at(B)->get_value(b_idx);
|
||||
-res = builder.CreateAdd(res, builder.CreateMul(a, b));
|
||||
+res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
|
Reference in New Issue
Block a user