more tinkering

This commit is contained in:
Philippe Tillet
2018-11-27 09:39:56 +01:00
parent bd5b213921
commit e0cd621bb8
2 changed files with 75 additions and 74 deletions

113
conv.cpp
View File

@@ -65,7 +65,7 @@ void autotune(llvm::TargetMachine *machine, llvm::Module &module){
for(llvm::TargetTuner::ParamType &param: tuning_params){
// This parameter has not been seen before
if(unique.insert(param.Value).second){
std::cout << instr.getName().data() << " " << param.Name << std::endl;
std::cout << "PARAM: " << instr.getName().data() << " " << param.Name << std::endl;
params.push_back(param.Value);
}
}
@@ -142,10 +142,12 @@ int main(){
llvm::Intrinsic::ID mma_id = llvm::Intrinsic::tlvm_mma_nt;
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
llvm::Function* outer_and_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int32_slice_t, int32_slice_t});
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t});
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t});
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
@@ -208,48 +210,48 @@ int main(){
llvm::Value* PQN = builder.CreateMul(H, builder.CreateMul(W, N));
// Images HWN offset
llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {sa0, N}));
llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {sa0, N}));
llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {sa0, W}));
llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {sa0, W}));
llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {sa0, lda_n}));
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {sa0, lda_h})));
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {sa0, lda_w})));
llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {bm, N}));
llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {bm, N}));
llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {bm, lda_n}));
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {bm, lda_h})));
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {bm, lda_w})));
// Images CRS offset
llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {sa1, S}));
llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {sa1, S}));
llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {sa1, R}));
llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {sa1, R}));
llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {sa1, lda_c}));
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {sa1, lda_h})));
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {sa1, lda_w})));
llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {bk, S}));
llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {bk, S}));
llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {bk, lda_c}));
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {bk, lda_h})));
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w})));
// Images pointer
llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1});
llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_i_ptr, off_a}, "start_i_ptr");
llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_i_val");
llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_i");
// Filters pointer
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {sb1, K});
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K});
llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_f");
llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_f_ptr, off_b}, "start_f_ptr");
llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_f_val");
llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_f");
// Filters increment
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {sb0, _s0}, "inc_f_0");
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(bk, ldb_k)}, "inc_f_1");
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_f_0");
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_f_1");
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f");
// Delta pointers
llvm::Value* base_incdelta = lut_ptr;
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta");
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa0}, "start_pincdelta");
llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut));
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {sa1, _s0})}, "start_pdelta");
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta");
// Masks
llvm::Value* _1 = builder.CreateCall(splat_1d, {sb1, builder.getInt32(1)});
llvm::Value* mask_a_1 = builder.CreateShl(_1, sb1);
llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut));
llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sb1});
llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut));
llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sb1});
// llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
// llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
// llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut));
// llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0});
// llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut));
// llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0});
// Enter loop
builder.CreateBr(LoopBB);
builder.SetInsertPoint(LoopBB);
@@ -262,8 +264,8 @@ int main(){
llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b");
llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3);
llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3);
llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
// llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
// llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c");
c->addIncoming(_0, PrologBB);
c->addIncoming(next_c, LoopBB);
@@ -273,24 +275,24 @@ int main(){
crs->addIncoming(next_crs, LoopBB);
// Update pointer
llvm::Value *inc_delta = builder.CreateLoad(pincdelta);
llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
// llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
llvm::Value *inc_a_1 = builder.CreateLoad(pdelta);
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {sa0, builder.getInt32(0)});
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)});
llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1});
llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_i_ptr");
llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_f_ptr");
llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta});
llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta});
llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask});
llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask});
// llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask});
// llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask});
pdelta->addIncoming(start_pdelta, PrologBB);
pdelta->addIncoming(next_pdelta, LoopBB);
pincdelta->addIncoming(start_pincdelta, PrologBB);
pincdelta->addIncoming(next_pincdelta, LoopBB);
pmasks->addIncoming(start_pmask, PrologBB);
pmasks->addIncoming(next_pmask, LoopBB);
pincmasks->addIncoming(start_pincmask, PrologBB);
pincmasks->addIncoming(next_pincmask, LoopBB);
// pmasks->addIncoming(start_pmask, PrologBB);
// pmasks->addIncoming(next_pmask, LoopBB);
// pincmasks->addIncoming(start_pincmask, PrologBB);
// pincmasks->addIncoming(next_pincmask, LoopBB);
pa->addIncoming(start_pa, PrologBB);
pa->addIncoming(next_pa, LoopBB);
pb->addIncoming(start_pb, PrologBB);
@@ -298,9 +300,10 @@ int main(){
// End condition
llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0));
// Masks
llvm::Value* mask_a_0 = builder.CreateLoad(pdelta);
llvm::Value* mask_a = builder.CreateCall(outer_and, {mask_a_0, mask_a_1});
llvm::Value* mask_b = builder.CreateCall(splat_2d, {start_bb, no_bounds_check}, "mask_b");
// llvm::Value* mask_a_0 = builder.CreateLoad(pmasks);
// llvm::Value* mask_a = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1});
llvm::Value* mask_a = builder.CreateCall(splat_2d, {bm, bk, no_bounds_check}, "mask_a");
llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b");
// Pre-fetch
llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa");
llvm::Value* next_bb = builder.CreateCall(masked_load, {next_pb, mask_b}, "next_bb");
@@ -318,8 +321,8 @@ int main(){
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
// Last Iteration
builder.SetInsertPoint(LastIterBB);
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {sb0, K}));
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {sb1, bk}));
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {bn, K}));
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {bk, next_crs}));
llvm::Value* last_maskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "last_maskb");
llvm::Value* last_bb = builder.CreateCall(masked_load, {next_pb, last_maskb}, "last_bb");
llvm::Value* last_b = builder.CreateCall(reshape, {last_bb, bn, bk}, "last_b");
@@ -332,8 +335,8 @@ int main(){
pb->addIncoming(next_pb, LastIterBB);
pdelta->addIncoming(next_pdelta, LastIterBB);
pincdelta->addIncoming(next_pincdelta, LastIterBB);
pmasks->addIncoming(next_pmask, LastIterBB);
pincmasks->addIncoming(next_pincmask, LastIterBB);
// pmasks->addIncoming(next_pmask, LastIterBB);
// pincmasks->addIncoming(next_pincmask, LastIterBB);
builder.CreateCondBr(loop, LoopBB, EpilogueBB);
// Epilogue
@@ -346,21 +349,21 @@ int main(){
llvm::Value* ldc_k = builder.CreateMul(lda_h, H);
llvm::Value* ldb_n = builder.CreateMul(lda_c, K);
// Output PQN offset
llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N}));
llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N}));
llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W}));
llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W}));
llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {sc_pqn, ldb_n}));
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {sc_pqn, ldc_p})));
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {sc_pqn, ldc_q})));
llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {bm, ldb_n}));
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {bm, ldc_p})));
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {bm, ldc_q})));
// Output K offset
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {sc_k, ldc_k}));
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k}));
// Output pointer
llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1});
llvm::Value* pc = builder.CreateCall(gtp_2d, {base_o_ptr, offc});
// Output masks
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, PQN}));
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {sc_k, K}));
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN}));
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K}));
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1});
builder.CreateCall(masked_store, {next_c, pc, maskc});
builder.CreateRet(NULL);

View File

@@ -138,14 +138,12 @@ int main(){
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
if(AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_tn;
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
llvm::Function* broadcast_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int32_tile_t, int32_slice_t});
llvm::Function* broadcast_int1 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int1_tile_t, int1_slice_t});
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t});
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t});
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
@@ -215,8 +213,8 @@ int main(){
std::swap(incb0, incb1);
}
llvm::CallInst* tlda = builder.CreateCall(splat_1d, {sa1, AS0}, "lda");
llvm::CallInst* tldb = builder.CreateCall(splat_1d, {sb1, BS1}, "ldb");
llvm::CallInst* tlda = builder.CreateCall(splat_1d, {ba1, AS0}, "lda");
llvm::CallInst* tldb = builder.CreateCall(splat_1d, {bb1, BS1}, "ldb");
llvm::CallInst* offa = builder.CreateCall(outer_add, {sa0, builder.CreateMul(sa1, tlda)}, "offa");
llvm::CallInst* offb = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb)}, "offb");
llvm::CallInst* startpa = builder.CreateCall(gtp, {arguments[0], offa}, "startpa");
@@ -225,10 +223,10 @@ int main(){
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta");
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb");
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {sa0, builder.CreateMul(inca0, AS0)});
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {sa1, builder.CreateMul(inca1, AS1)});
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {sb0, builder.CreateMul(incb0, BS0)});
llvm::Value* tincb1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(incb1, BS1)});
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
llvm::Value* tincb1 = builder.CreateCall(splat_1d, {bb1, builder.CreateMul(incb1, BS1)});
llvm::Value* inca = builder.CreateCall(outer_add, {tinca0, tinca1}, "inca");
llvm::Value* incb = builder.CreateCall(outer_add, {tincb0, tincb1}, "incb");
// Enter loop
@@ -258,8 +256,8 @@ int main(){
// End condition
llvm::Value* no_bounds_check = builder.CreateICmpSGT(nextk, bound);
// Masks
llvm::Value* maska = builder.CreateCall(splat_2d, {startfa, no_bounds_check}, "maska");
llvm::Value* maskb = builder.CreateCall(splat_2d, {startfb, no_bounds_check}, "maskb");
llvm::Value* maska = builder.CreateCall(splat_2d, {ba0, ba1, no_bounds_check}, "maska");
llvm::Value* maskb = builder.CreateCall(splat_2d, {bb0, bb1, no_bounds_check}, "maskb");
// Pre-fetch
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
@@ -277,10 +275,10 @@ int main(){
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
// Last Iteration
builder.SetInsertPoint(LastIterBB);
llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {aasm, M}));
llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ask, bk}));
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bbsn, N}));
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bsk, bk}));
llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {ba0, M}));
llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ba1, bk}));
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bb0, N}));
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bb1, bk}));
llvm::Value* lastmaska = builder.CreateCall(outer_and, {in_bounds_a0, in_bounds_a1}, "lastmaska");
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
@@ -299,11 +297,11 @@ int main(){
builder.SetInsertPoint(EpilogueBB);
llvm::CallInst* sm = builder.CreateCall(read_slice_x, {bm}, "sm");
llvm::CallInst* sn = builder.CreateCall(read_slice_y, {bn}, "sn");
llvm::CallInst* ldc = builder.CreateCall(splat_1d, {sn, M}, "lda");
llvm::CallInst* ldc = builder.CreateCall(splat_1d, {bn, M}, "lda");
llvm::CallInst* offc = builder.CreateCall(outer_add, {sm, builder.CreateMul(sn, ldc)}, "offc");
llvm::CallInst* pc = builder.CreateCall(gtp, {arguments[2], offc}, "pc");
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {sm, M}));
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {sn, N}));
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {bm, M}));
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {bn, N}));
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}, "maskc");
builder.CreateCall(masked_store, {nextc, pc, maskc});
builder.CreateRet(NULL);