More cleaning of masks

This commit is contained in:
Philippe Tillet
2018-11-27 12:20:51 +01:00
parent e0cd621bb8
commit 68c8de88f5

View File

@@ -242,16 +242,16 @@ int main(){
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f");
// Delta pointers
llvm::Value* base_incdelta = lut_ptr;
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa0}, "start_pincdelta");
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta");
llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut));
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta");
// Masks
// llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
// llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
// llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut));
// llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0});
// llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut));
// llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0});
llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut), "base_incmask");
llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}, "start_pincmask");
llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut), "base_mask");
llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}, "start_pmask");
// Enter loop
builder.CreateBr(LoopBB);
builder.SetInsertPoint(LoopBB);
@@ -264,8 +264,8 @@ int main(){
llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b");
llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3);
llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3);
// llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
// llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c");
c->addIncoming(_0, PrologBB);
c->addIncoming(next_c, LoopBB);
@@ -275,34 +275,34 @@ int main(){
crs->addIncoming(next_crs, LoopBB);
// Update pointer
llvm::Value *inc_delta = builder.CreateLoad(pincdelta);
// llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
llvm::Value *inc_a_1 = builder.CreateLoad(pdelta);
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)});
llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1});
llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_i_ptr");
llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_f_ptr");
llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta});
llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta});
// llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask});
// llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask});
llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_pa");
llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_pb");
llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta}, "next_pdelta");
llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta}, "next_pincdelta");
llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask}, "next_pmask");
llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask}, "next_pincmask");
pdelta->addIncoming(start_pdelta, PrologBB);
pdelta->addIncoming(next_pdelta, LoopBB);
pincdelta->addIncoming(start_pincdelta, PrologBB);
pincdelta->addIncoming(next_pincdelta, LoopBB);
// pmasks->addIncoming(start_pmask, PrologBB);
// pmasks->addIncoming(next_pmask, LoopBB);
// pincmasks->addIncoming(start_pincmask, PrologBB);
// pincmasks->addIncoming(next_pincmask, LoopBB);
pmasks->addIncoming(start_pmask, PrologBB);
pmasks->addIncoming(next_pmask, LoopBB);
pincmasks->addIncoming(start_pincmask, PrologBB);
pincmasks->addIncoming(next_pincmask, LoopBB);
pa->addIncoming(start_pa, PrologBB);
pa->addIncoming(next_pa, LoopBB);
pb->addIncoming(start_pb, PrologBB);
pb->addIncoming(next_pb, LoopBB);
// End condition
llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0));
llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0), "no_bounds_check");
// Masks
// llvm::Value* mask_a_0 = builder.CreateLoad(pmasks);
// llvm::Value* mask_a = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1});
llvm::Value* mask_a = builder.CreateCall(splat_2d, {bm, bk, no_bounds_check}, "mask_a");
llvm::Value* mask_a_0 = builder.CreateLoad(pmasks, "mask_a_0");
llvm::Value* mask_a_i32 = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1}, "mask_a_i32");
llvm::Value* mask_a = builder.CreateICmpNE(mask_a_i32, llvm::ConstantTile::get(_s0, {bm, bk}), "mask_a");
llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b");
// Pre-fetch
llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa");
@@ -335,8 +335,8 @@ int main(){
pb->addIncoming(next_pb, LastIterBB);
pdelta->addIncoming(next_pdelta, LastIterBB);
pincdelta->addIncoming(next_pincdelta, LastIterBB);
// pmasks->addIncoming(next_pmask, LastIterBB);
// pincmasks->addIncoming(next_pincmask, LastIterBB);
pmasks->addIncoming(next_pmask, LastIterBB);
pincmasks->addIncoming(next_pincmask, LastIterBB);
builder.CreateCondBr(loop, LoopBB, EpilogueBB);
// Epilogue