[examples] added basic skeleton to generate matrix multiplication PTX
This commit is contained in:
@@ -441,6 +441,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
ti->set_value(idx, in->get_value(idx));
|
||||
});
|
||||
}
|
||||
// matrix multiplication
|
||||
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
|
||||
ir::value *A = ins->get_operand(0);
|
||||
ir::value *B = ins->get_operand(1);
|
||||
ir::value *C = ins->get_operand(2);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = tmap_.at(C)->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
Value *a = tmap_.at(A)->get_value(a_idx);
|
||||
Value *b = tmap_.at(B)->get_value(b_idx);
|
||||
res = builder.CreateAdd(res, builder.CreateMul(a, b));
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
// element-wise
|
||||
else {
|
||||
result->for_each([&](indices_t idx){
|
||||
|
@@ -17,10 +17,12 @@ void place_shared_copy::run(ir::module &mod) {
|
||||
builder.set_insert_point(i);
|
||||
ir::value *x = i->get_operand(0);
|
||||
ir::value *y = i->get_operand(1);
|
||||
ir::value *rx = builder.create_copy_to_shared(x);
|
||||
ir::value *ry = builder.create_copy_to_shared(y);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
|
||||
ir::instruction *ry = (ir::instruction*)builder.create_copy_to_shared(y);
|
||||
x->replace_all_uses_with(rx);
|
||||
y->replace_all_uses_with(ry);
|
||||
rx->set_operand(0, x);
|
||||
ry->set_operand(0, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user