[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)

This commit is contained in:
Keren Zhou
2022-12-03 11:14:12 -08:00
committed by GitHub
parent 8edfe813a5
commit f2fcaeabf3
5 changed files with 105 additions and 72 deletions

View File

@@ -708,19 +708,19 @@ public:
Type elemTy = type::f32Ty(ctx);
Type elemPtrTy = ptr_ty(elemTy);
if (kOrder == 1) {
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
elems[2] =
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
elems[3] =
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
} else {
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
elems[1] =
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
elems[3] =
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
}
return {elems[0], elems[1], elems[2], elems[3]};