[BACKEND/CODE_GEN] Fixed float32 matmul problem (#380)

This commit is contained in:
Philippe Tillet
2021-11-30 22:00:56 -08:00
committed by GitHub
parent c86ad9c9ab
commit 8ec9f037bb
2 changed files with 17 additions and 13 deletions

View File

@@ -788,7 +788,6 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
}
// std::cout << "!" << std::endl;
}
@@ -1660,13 +1659,17 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
std::map<indices_t, Value*> ret = vals_[D];
std::map<std::pair<int, int>, Value*> has, hbs;
auto ord = layout_c->get_order();
for(unsigned k = 0; k < NK; k++){
int z = 0;
for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
{
for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
unsigned m = (ord[0] == 1) ? i : j;
unsigned n = (ord[0] == 1) ? j : i;
unsigned mm = (ord[0] == 1) ? ii : jj;
unsigned nn = (ord[0] == 1) ? jj : ii;
if(has.find({m + mm, k}) == has.end()){
Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
Value* va = load(pa);