diff --git a/conv.cpp b/conv.cpp index 0132b3fe2..d806b87c7 100644 --- a/conv.cpp +++ b/conv.cpp @@ -65,7 +65,7 @@ void autotune(llvm::TargetMachine *machine, llvm::Module &module){ for(llvm::TargetTuner::ParamType ¶m: tuning_params){ // This parameter has not been seen before if(unique.insert(param.Value).second){ - std::cout << instr.getName().data() << " " << param.Name << std::endl; + std::cout << "PARAM: " << instr.getName().data() << " " << param.Name << std::endl; params.push_back(param.Value); } } @@ -142,10 +142,12 @@ int main(){ llvm::Intrinsic::ID mma_id = llvm::Intrinsic::tlvm_mma_nt; llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t}); llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t}); + llvm::Function* outer_and_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int32_slice_t, int32_slice_t}); llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t}); llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t}); - llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t}); - llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t}); + llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t}); + llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t}); + llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t}); llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t}); @@ -208,48 +210,48 @@ int main(){ llvm::Value* PQN = builder.CreateMul(H, builder.CreateMul(W, N)); // Images HWN offset - llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {sa0, N})); - llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {sa0, N})); - llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {sa0, W})); - llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {sa0, W})); - llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {sa0, lda_n})); - offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {sa0, lda_h}))); - offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {sa0, lda_w}))); + llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {bm, N})); + llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {bm, N})); + llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {bm, W})); + llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {bm, W})); + llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {bm, lda_n})); + offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {bm, lda_h}))); + offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {bm, lda_w}))); // Images CRS offset - llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {sa1, S})); - llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {sa1, S})); - llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {sa1, R})); - llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {sa1, R})); - llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {sa1, lda_c})); - offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {sa1, lda_h}))); - offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {sa1, lda_w}))); + llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {bk, S})); + llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {bk, S})); + llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {bk, R})); + llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {bk, R})); + llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {bk, lda_c})); + offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {bk, lda_h}))); + offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w}))); // Images pointer llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1}); llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_i_ptr, off_a}, "start_i_ptr"); llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_i_val"); llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_i"); // Filters pointer - llvm::Value* tldb_s = builder.CreateCall(splat_1d, {sb1, K}); + llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K}); llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_f"); llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_f_ptr, off_b}, "start_f_ptr"); llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_f_val"); llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_f"); // Filters increment - llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {sb0, _s0}, "inc_f_0"); - llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(bk, ldb_k)}, "inc_f_1"); + llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_f_0"); + llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_f_1"); llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f"); // Delta pointers llvm::Value* base_incdelta = lut_ptr; - llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta"); + llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa0}, "start_pincdelta"); llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut)); - llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {sa1, _s0})}, "start_pdelta"); + llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta"); // Masks - llvm::Value* _1 = builder.CreateCall(splat_1d, {sb1, builder.getInt32(1)}); - llvm::Value* mask_a_1 = builder.CreateShl(_1, sb1); - llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut)); - llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sb1}); - llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut)); - llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sb1}); +// llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)}); +// llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1); +// llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut)); +// llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}); +// llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut)); +// llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}); // Enter loop builder.CreateBr(LoopBB); builder.SetInsertPoint(LoopBB); @@ -262,8 +264,8 @@ int main(){ llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b"); llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3); llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3); - llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3); - llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3); +// llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3); +// llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3); llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c"); c->addIncoming(_0, PrologBB); c->addIncoming(next_c, LoopBB); @@ -273,24 +275,24 @@ int main(){ crs->addIncoming(next_crs, LoopBB); // Update pointer llvm::Value *inc_delta = builder.CreateLoad(pincdelta); - llvm::Value *inc_mask = builder.CreateLoad(pincmasks); +// llvm::Value *inc_mask = builder.CreateLoad(pincmasks); llvm::Value *inc_a_1 = builder.CreateLoad(pdelta); - llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {sa0, builder.getInt32(0)}); + llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)}); llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1}); llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_i_ptr"); llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_f_ptr"); llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta}); llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta}); - llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask}); - llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask}); +// llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask}); +// llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask}); pdelta->addIncoming(start_pdelta, PrologBB); pdelta->addIncoming(next_pdelta, LoopBB); pincdelta->addIncoming(start_pincdelta, PrologBB); pincdelta->addIncoming(next_pincdelta, LoopBB); - pmasks->addIncoming(start_pmask, PrologBB); - pmasks->addIncoming(next_pmask, LoopBB); - pincmasks->addIncoming(start_pincmask, PrologBB); - pincmasks->addIncoming(next_pincmask, LoopBB); +// pmasks->addIncoming(start_pmask, PrologBB); +// pmasks->addIncoming(next_pmask, LoopBB); +// pincmasks->addIncoming(start_pincmask, PrologBB); +// pincmasks->addIncoming(next_pincmask, LoopBB); pa->addIncoming(start_pa, PrologBB); pa->addIncoming(next_pa, LoopBB); pb->addIncoming(start_pb, PrologBB); @@ -298,9 +300,10 @@ int main(){ // End condition llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0)); // Masks - llvm::Value* mask_a_0 = builder.CreateLoad(pdelta); - llvm::Value* mask_a = builder.CreateCall(outer_and, {mask_a_0, mask_a_1}); - llvm::Value* mask_b = builder.CreateCall(splat_2d, {start_bb, no_bounds_check}, "mask_b"); +// llvm::Value* mask_a_0 = builder.CreateLoad(pmasks); +// llvm::Value* mask_a = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1}); + llvm::Value* mask_a = builder.CreateCall(splat_2d, {bm, bk, no_bounds_check}, "mask_a"); + llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b"); // Pre-fetch llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa"); llvm::Value* next_bb = builder.CreateCall(masked_load, {next_pb, mask_b}, "next_bb"); @@ -318,8 +321,8 @@ int main(){ builder.CreateCondBr(exit, EpilogueBB, LastIterBB); // Last Iteration builder.SetInsertPoint(LastIterBB); - llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {sb0, K})); - llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {sb1, bk})); + llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {bn, K})); + llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {bk, next_crs})); llvm::Value* last_maskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "last_maskb"); llvm::Value* last_bb = builder.CreateCall(masked_load, {next_pb, last_maskb}, "last_bb"); llvm::Value* last_b = builder.CreateCall(reshape, {last_bb, bn, bk}, "last_b"); @@ -332,8 +335,8 @@ int main(){ pb->addIncoming(next_pb, LastIterBB); pdelta->addIncoming(next_pdelta, LastIterBB); pincdelta->addIncoming(next_pincdelta, LastIterBB); - pmasks->addIncoming(next_pmask, LastIterBB); - pincmasks->addIncoming(next_pincmask, LastIterBB); +// pmasks->addIncoming(next_pmask, LastIterBB); +// pincmasks->addIncoming(next_pincmask, LastIterBB); builder.CreateCondBr(loop, LoopBB, EpilogueBB); // Epilogue @@ -346,21 +349,21 @@ int main(){ llvm::Value* ldc_k = builder.CreateMul(lda_h, H); llvm::Value* ldb_n = builder.CreateMul(lda_c, K); // Output PQN offset - llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N})); - llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N})); - llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W})); - llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W})); - llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {sc_pqn, ldb_n})); - offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {sc_pqn, ldc_p}))); - offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {sc_pqn, ldc_q}))); + llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {bm, N})); + llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {bm, N})); + llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {bm, W})); + llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {bm, W})); + llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {bm, ldb_n})); + offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {bm, ldc_p}))); + offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {bm, ldc_q}))); // Output K offset - llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {sc_k, ldc_k})); + llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k})); // Output pointer llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1}); llvm::Value* pc = builder.CreateCall(gtp_2d, {base_o_ptr, offc}); // Output masks - llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, PQN})); - llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {sc_k, K})); + llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN})); + llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K})); llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}); builder.CreateCall(masked_store, {next_c, pc, maskc}); builder.CreateRet(NULL); diff --git a/gemm.cpp b/gemm.cpp index 6a63327b9..5adc3ebbd 100644 --- a/gemm.cpp +++ b/gemm.cpp @@ -138,14 +138,12 @@ int main(){ if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt; if(AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_tn; if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt; - llvm::Function* broadcast_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int32_tile_t, int32_slice_t}); - llvm::Function* broadcast_int1 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int1_tile_t, int1_slice_t}); llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t}); llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t}); llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t}); llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t}); - llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t}); - llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t}); + llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t}); + llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t}); llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t}); llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t}); @@ -215,8 +213,8 @@ int main(){ std::swap(incb0, incb1); } - llvm::CallInst* tlda = builder.CreateCall(splat_1d, {sa1, AS0}, "lda"); - llvm::CallInst* tldb = builder.CreateCall(splat_1d, {sb1, BS1}, "ldb"); + llvm::CallInst* tlda = builder.CreateCall(splat_1d, {ba1, AS0}, "lda"); + llvm::CallInst* tldb = builder.CreateCall(splat_1d, {bb1, BS1}, "ldb"); llvm::CallInst* offa = builder.CreateCall(outer_add, {sa0, builder.CreateMul(sa1, tlda)}, "offa"); llvm::CallInst* offb = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb)}, "offb"); llvm::CallInst* startpa = builder.CreateCall(gtp, {arguments[0], offa}, "startpa"); @@ -225,10 +223,10 @@ int main(){ llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb"); llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta"); llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb"); - llvm::Value* tinca0 = builder.CreateCall(splat_1d, {sa0, builder.CreateMul(inca0, AS0)}); - llvm::Value* tinca1 = builder.CreateCall(splat_1d, {sa1, builder.CreateMul(inca1, AS1)}); - llvm::Value* tincb0 = builder.CreateCall(splat_1d, {sb0, builder.CreateMul(incb0, BS0)}); - llvm::Value* tincb1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(incb1, BS1)}); + llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0"); + llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)}); + llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)}); + llvm::Value* tincb1 = builder.CreateCall(splat_1d, {bb1, builder.CreateMul(incb1, BS1)}); llvm::Value* inca = builder.CreateCall(outer_add, {tinca0, tinca1}, "inca"); llvm::Value* incb = builder.CreateCall(outer_add, {tincb0, tincb1}, "incb"); // Enter loop @@ -258,8 +256,8 @@ int main(){ // End condition llvm::Value* no_bounds_check = builder.CreateICmpSGT(nextk, bound); // Masks - llvm::Value* maska = builder.CreateCall(splat_2d, {startfa, no_bounds_check}, "maska"); - llvm::Value* maskb = builder.CreateCall(splat_2d, {startfb, no_bounds_check}, "maskb"); + llvm::Value* maska = builder.CreateCall(splat_2d, {ba0, ba1, no_bounds_check}, "maska"); + llvm::Value* maskb = builder.CreateCall(splat_2d, {bb0, bb1, no_bounds_check}, "maskb"); // Pre-fetch llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa"); llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb"); @@ -277,10 +275,10 @@ int main(){ builder.CreateCondBr(exit, EpilogueBB, LastIterBB); // Last Iteration builder.SetInsertPoint(LastIterBB); - llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {aasm, M})); - llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ask, bk})); - llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bbsn, N})); - llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bsk, bk})); + llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {ba0, M})); + llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ba1, bk})); + llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bb0, N})); + llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bb1, bk})); llvm::Value* lastmaska = builder.CreateCall(outer_and, {in_bounds_a0, in_bounds_a1}, "lastmaska"); llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb"); llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa"); @@ -299,11 +297,11 @@ int main(){ builder.SetInsertPoint(EpilogueBB); llvm::CallInst* sm = builder.CreateCall(read_slice_x, {bm}, "sm"); llvm::CallInst* sn = builder.CreateCall(read_slice_y, {bn}, "sn"); - llvm::CallInst* ldc = builder.CreateCall(splat_1d, {sn, M}, "lda"); + llvm::CallInst* ldc = builder.CreateCall(splat_1d, {bn, M}, "lda"); llvm::CallInst* offc = builder.CreateCall(outer_add, {sm, builder.CreateMul(sn, ldc)}, "offc"); llvm::CallInst* pc = builder.CreateCall(gtp, {arguments[2], offc}, "pc"); - llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {sm, M})); - llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {sn, N})); + llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {bm, M})); + llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {bn, N})); llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}, "maskc"); builder.CreateCall(masked_store, {nextc, pc, maskc}); builder.CreateRet(NULL);