[examples] added basic skeleton to generate matrix multiplication PTX

This commit is contained in:
Philippe Tillet
2019-02-07 22:42:54 -05:00
parent 1b9a7a8e97
commit dd35277858
4 changed files with 74 additions and 11 deletions

View File

@@ -441,6 +441,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(idx, in->get_value(idx));
});
}
// matrix multiplication
// Lowers a matmul instruction: for every index of the result tile, start from
// the value of operand 2 (C) at that index and add the products of operand 0
// (A) and operand 1 (B) over the reduction index K.
// The dynamic_cast is used only as a type test; its result is not kept.
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
result->for_each([&](indices_t idx){
// the accumulator is seeded with C's value at the same index as the result
Value *res = tmap_.at(C)->get_value(idx);
// the reduction length is read from entry 1 of A's tile shapes
unsigned NK = A->get_type()->get_tile_shapes()[1];
// this host-side loop emits one multiply and one add per value of K
for(unsigned K = 0; K < NK; ++K){
// A is read at (idx[0], K) and B at (idx[1], K)
// NOTE(review): B is indexed with K as its second coordinate, i.e. as if it
// were laid out (N, K) -- confirm this is the intended operand layout
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = tmap_.at(A)->get_value(a_idx);
Value *b = tmap_.at(B)->get_value(b_idx);
// NOTE(review): CreateAdd/CreateMul are LLVM's integer arithmetic builders;
// floating-point operands would presumably need CreateFAdd/CreateFMul --
// confirm the element type handled here
res = builder.CreateAdd(res, builder.CreateMul(a, b));
}
// store the fully accumulated value for this result index
result->set_value(idx, res);
});
}
// element-wise
else {
result->for_each([&](indices_t idx){

View File

@@ -17,10 +17,12 @@ void place_shared_copy::run(ir::module &mod) {
builder.set_insert_point(i);
ir::value *x = i->get_operand(0);
ir::value *y = i->get_operand(1);
ir::value *rx = builder.create_copy_to_shared(x);
ir::value *ry = builder.create_copy_to_shared(y);
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
ir::instruction *ry = (ir::instruction*)builder.create_copy_to_shared(y);
x->replace_all_uses_with(rx);
y->replace_all_uses_with(ry);
rx->set_operand(0, x);
ry->set_operand(0, y);
}
}