more tinkering
This commit is contained in:
113
conv.cpp
113
conv.cpp
@@ -65,7 +65,7 @@ void autotune(llvm::TargetMachine *machine, llvm::Module &module){
|
||||
for(llvm::TargetTuner::ParamType ¶m: tuning_params){
|
||||
// This parameter has not been seen before
|
||||
if(unique.insert(param.Value).second){
|
||||
std::cout << instr.getName().data() << " " << param.Name << std::endl;
|
||||
std::cout << "PARAM: " << instr.getName().data() << " " << param.Name << std::endl;
|
||||
params.push_back(param.Value);
|
||||
}
|
||||
}
|
||||
@@ -142,10 +142,12 @@ int main(){
|
||||
llvm::Intrinsic::ID mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* outer_and_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
|
||||
@@ -208,48 +210,48 @@ int main(){
|
||||
llvm::Value* PQN = builder.CreateMul(H, builder.CreateMul(W, N));
|
||||
|
||||
// Images HWN offset
|
||||
llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {sa0, N}));
|
||||
llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {sa0, N}));
|
||||
llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {sa0, W}));
|
||||
llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {sa0, W}));
|
||||
llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {sa0, lda_n}));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {sa0, lda_h})));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {sa0, lda_w})));
|
||||
llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {bm, lda_n}));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {bm, lda_h})));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {bm, lda_w})));
|
||||
// Images CRS offset
|
||||
llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {sa1, S}));
|
||||
llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {sa1, S}));
|
||||
llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {sa1, R}));
|
||||
llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {sa1, R}));
|
||||
llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {sa1, lda_c}));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {sa1, lda_h})));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {sa1, lda_w})));
|
||||
llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {bk, S}));
|
||||
llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {bk, S}));
|
||||
llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
|
||||
llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
|
||||
llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {bk, lda_c}));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {bk, lda_h})));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w})));
|
||||
// Images pointer
|
||||
llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1});
|
||||
llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_i_ptr, off_a}, "start_i_ptr");
|
||||
llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_i_val");
|
||||
llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_i");
|
||||
// Filters pointer
|
||||
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {sb1, K});
|
||||
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K});
|
||||
llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_f");
|
||||
llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_f_ptr, off_b}, "start_f_ptr");
|
||||
llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_f_val");
|
||||
llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_f");
|
||||
// Filters increment
|
||||
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {sb0, _s0}, "inc_f_0");
|
||||
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(bk, ldb_k)}, "inc_f_1");
|
||||
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_f_0");
|
||||
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_f_1");
|
||||
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f");
|
||||
// Delta pointers
|
||||
llvm::Value* base_incdelta = lut_ptr;
|
||||
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta");
|
||||
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa0}, "start_pincdelta");
|
||||
llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut));
|
||||
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {sa1, _s0})}, "start_pdelta");
|
||||
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta");
|
||||
// Masks
|
||||
llvm::Value* _1 = builder.CreateCall(splat_1d, {sb1, builder.getInt32(1)});
|
||||
llvm::Value* mask_a_1 = builder.CreateShl(_1, sb1);
|
||||
llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut));
|
||||
llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sb1});
|
||||
llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut));
|
||||
llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sb1});
|
||||
// llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
|
||||
// llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
|
||||
// llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut));
|
||||
// llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0});
|
||||
// llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut));
|
||||
// llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0});
|
||||
// Enter loop
|
||||
builder.CreateBr(LoopBB);
|
||||
builder.SetInsertPoint(LoopBB);
|
||||
@@ -262,8 +264,8 @@ int main(){
|
||||
llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b");
|
||||
llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3);
|
||||
llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3);
|
||||
llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
|
||||
llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
|
||||
// llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
|
||||
// llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
|
||||
llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c");
|
||||
c->addIncoming(_0, PrologBB);
|
||||
c->addIncoming(next_c, LoopBB);
|
||||
@@ -273,24 +275,24 @@ int main(){
|
||||
crs->addIncoming(next_crs, LoopBB);
|
||||
// Update pointer
|
||||
llvm::Value *inc_delta = builder.CreateLoad(pincdelta);
|
||||
llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
|
||||
// llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
|
||||
llvm::Value *inc_a_1 = builder.CreateLoad(pdelta);
|
||||
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {sa0, builder.getInt32(0)});
|
||||
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)});
|
||||
llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1});
|
||||
llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_i_ptr");
|
||||
llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_f_ptr");
|
||||
llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta});
|
||||
llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta});
|
||||
llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask});
|
||||
llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask});
|
||||
// llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask});
|
||||
// llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask});
|
||||
pdelta->addIncoming(start_pdelta, PrologBB);
|
||||
pdelta->addIncoming(next_pdelta, LoopBB);
|
||||
pincdelta->addIncoming(start_pincdelta, PrologBB);
|
||||
pincdelta->addIncoming(next_pincdelta, LoopBB);
|
||||
pmasks->addIncoming(start_pmask, PrologBB);
|
||||
pmasks->addIncoming(next_pmask, LoopBB);
|
||||
pincmasks->addIncoming(start_pincmask, PrologBB);
|
||||
pincmasks->addIncoming(next_pincmask, LoopBB);
|
||||
// pmasks->addIncoming(start_pmask, PrologBB);
|
||||
// pmasks->addIncoming(next_pmask, LoopBB);
|
||||
// pincmasks->addIncoming(start_pincmask, PrologBB);
|
||||
// pincmasks->addIncoming(next_pincmask, LoopBB);
|
||||
pa->addIncoming(start_pa, PrologBB);
|
||||
pa->addIncoming(next_pa, LoopBB);
|
||||
pb->addIncoming(start_pb, PrologBB);
|
||||
@@ -298,9 +300,10 @@ int main(){
|
||||
// End condition
|
||||
llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0));
|
||||
// Masks
|
||||
llvm::Value* mask_a_0 = builder.CreateLoad(pdelta);
|
||||
llvm::Value* mask_a = builder.CreateCall(outer_and, {mask_a_0, mask_a_1});
|
||||
llvm::Value* mask_b = builder.CreateCall(splat_2d, {start_bb, no_bounds_check}, "mask_b");
|
||||
// llvm::Value* mask_a_0 = builder.CreateLoad(pmasks);
|
||||
// llvm::Value* mask_a = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1});
|
||||
llvm::Value* mask_a = builder.CreateCall(splat_2d, {bm, bk, no_bounds_check}, "mask_a");
|
||||
llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b");
|
||||
// Pre-fetch
|
||||
llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa");
|
||||
llvm::Value* next_bb = builder.CreateCall(masked_load, {next_pb, mask_b}, "next_bb");
|
||||
@@ -318,8 +321,8 @@ int main(){
|
||||
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
|
||||
// Last Iteration
|
||||
builder.SetInsertPoint(LastIterBB);
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {sb0, K}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {sb1, bk}));
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {bn, K}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {bk, next_crs}));
|
||||
llvm::Value* last_maskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "last_maskb");
|
||||
llvm::Value* last_bb = builder.CreateCall(masked_load, {next_pb, last_maskb}, "last_bb");
|
||||
llvm::Value* last_b = builder.CreateCall(reshape, {last_bb, bn, bk}, "last_b");
|
||||
@@ -332,8 +335,8 @@ int main(){
|
||||
pb->addIncoming(next_pb, LastIterBB);
|
||||
pdelta->addIncoming(next_pdelta, LastIterBB);
|
||||
pincdelta->addIncoming(next_pincdelta, LastIterBB);
|
||||
pmasks->addIncoming(next_pmask, LastIterBB);
|
||||
pincmasks->addIncoming(next_pincmask, LastIterBB);
|
||||
// pmasks->addIncoming(next_pmask, LastIterBB);
|
||||
// pincmasks->addIncoming(next_pincmask, LastIterBB);
|
||||
builder.CreateCondBr(loop, LoopBB, EpilogueBB);
|
||||
|
||||
// Epilogue
|
||||
@@ -346,21 +349,21 @@ int main(){
|
||||
llvm::Value* ldc_k = builder.CreateMul(lda_h, H);
|
||||
llvm::Value* ldb_n = builder.CreateMul(lda_c, K);
|
||||
// Output PQN offset
|
||||
llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N}));
|
||||
llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, N}));
|
||||
llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W}));
|
||||
llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {sc_pqn, W}));
|
||||
llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {sc_pqn, ldb_n}));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {sc_pqn, ldc_p})));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {sc_pqn, ldc_q})));
|
||||
llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {bm, ldb_n}));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {bm, ldc_p})));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {bm, ldc_q})));
|
||||
// Output K offset
|
||||
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {sc_k, ldc_k}));
|
||||
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k}));
|
||||
// Output pointer
|
||||
llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1});
|
||||
llvm::Value* pc = builder.CreateCall(gtp_2d, {base_o_ptr, offc});
|
||||
// Output masks
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {sc_pqn, PQN}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {sc_k, K}));
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K}));
|
||||
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1});
|
||||
builder.CreateCall(masked_store, {next_c, pc, maskc});
|
||||
builder.CreateRet(NULL);
|
||||
|
36
gemm.cpp
36
gemm.cpp
@@ -138,14 +138,12 @@ int main(){
|
||||
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
if(AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_tn;
|
||||
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
|
||||
llvm::Function* broadcast_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int32_tile_t, int32_slice_t});
|
||||
llvm::Function* broadcast_int1 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_broadcast_1d, {int1_tile_t, int1_slice_t});
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_slice_t, int32_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
|
||||
@@ -215,8 +213,8 @@ int main(){
|
||||
std::swap(incb0, incb1);
|
||||
}
|
||||
|
||||
llvm::CallInst* tlda = builder.CreateCall(splat_1d, {sa1, AS0}, "lda");
|
||||
llvm::CallInst* tldb = builder.CreateCall(splat_1d, {sb1, BS1}, "ldb");
|
||||
llvm::CallInst* tlda = builder.CreateCall(splat_1d, {ba1, AS0}, "lda");
|
||||
llvm::CallInst* tldb = builder.CreateCall(splat_1d, {bb1, BS1}, "ldb");
|
||||
llvm::CallInst* offa = builder.CreateCall(outer_add, {sa0, builder.CreateMul(sa1, tlda)}, "offa");
|
||||
llvm::CallInst* offb = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb)}, "offb");
|
||||
llvm::CallInst* startpa = builder.CreateCall(gtp, {arguments[0], offa}, "startpa");
|
||||
@@ -225,10 +223,10 @@ int main(){
|
||||
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb");
|
||||
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {sa0, builder.CreateMul(inca0, AS0)});
|
||||
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {sa1, builder.CreateMul(inca1, AS1)});
|
||||
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {sb0, builder.CreateMul(incb0, BS0)});
|
||||
llvm::Value* tincb1 = builder.CreateCall(splat_1d, {sb1, builder.CreateMul(incb1, BS1)});
|
||||
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
|
||||
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
|
||||
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
|
||||
llvm::Value* tincb1 = builder.CreateCall(splat_1d, {bb1, builder.CreateMul(incb1, BS1)});
|
||||
llvm::Value* inca = builder.CreateCall(outer_add, {tinca0, tinca1}, "inca");
|
||||
llvm::Value* incb = builder.CreateCall(outer_add, {tincb0, tincb1}, "incb");
|
||||
// Enter loop
|
||||
@@ -258,8 +256,8 @@ int main(){
|
||||
// End condition
|
||||
llvm::Value* no_bounds_check = builder.CreateICmpSGT(nextk, bound);
|
||||
// Masks
|
||||
llvm::Value* maska = builder.CreateCall(splat_2d, {startfa, no_bounds_check}, "maska");
|
||||
llvm::Value* maskb = builder.CreateCall(splat_2d, {startfb, no_bounds_check}, "maskb");
|
||||
llvm::Value* maska = builder.CreateCall(splat_2d, {ba0, ba1, no_bounds_check}, "maska");
|
||||
llvm::Value* maskb = builder.CreateCall(splat_2d, {bb0, bb1, no_bounds_check}, "maskb");
|
||||
// Pre-fetch
|
||||
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
|
||||
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
|
||||
@@ -277,10 +275,10 @@ int main(){
|
||||
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
|
||||
// Last Iteration
|
||||
builder.SetInsertPoint(LastIterBB);
|
||||
llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {aasm, M}));
|
||||
llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ask, bk}));
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bbsn, N}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bsk, bk}));
|
||||
llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {ba0, M}));
|
||||
llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ba1, bk}));
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bb0, N}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bb1, bk}));
|
||||
llvm::Value* lastmaska = builder.CreateCall(outer_and, {in_bounds_a0, in_bounds_a1}, "lastmaska");
|
||||
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
|
||||
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
|
||||
@@ -299,11 +297,11 @@ int main(){
|
||||
builder.SetInsertPoint(EpilogueBB);
|
||||
llvm::CallInst* sm = builder.CreateCall(read_slice_x, {bm}, "sm");
|
||||
llvm::CallInst* sn = builder.CreateCall(read_slice_y, {bn}, "sn");
|
||||
llvm::CallInst* ldc = builder.CreateCall(splat_1d, {sn, M}, "lda");
|
||||
llvm::CallInst* ldc = builder.CreateCall(splat_1d, {bn, M}, "lda");
|
||||
llvm::CallInst* offc = builder.CreateCall(outer_add, {sm, builder.CreateMul(sn, ldc)}, "offc");
|
||||
llvm::CallInst* pc = builder.CreateCall(gtp, {arguments[2], offc}, "pc");
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {sm, M}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {sn, N}));
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {bm, M}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {bn, N}));
|
||||
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}, "maskc");
|
||||
builder.CreateCall(masked_store, {nextc, pc, maskc});
|
||||
builder.CreateRet(NULL);
|
||||
|
Reference in New Issue
Block a user