This commit is contained in:
Philippe Tillet
2018-12-03 07:42:05 -05:00
parent 68c8de88f5
commit 8b040b4645
2 changed files with 97 additions and 41 deletions

View File

@@ -121,18 +121,19 @@ int main(){
llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context);
llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context);
llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2);
llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2);
llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3);
llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1);
llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2);
llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1);
llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2);
llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0);
llvm::PointerType* tile2d_ptr_t = llvm::PointerType::get(tile2d_t, 0);
llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t});
llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t});
llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t});
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t});
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t});
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t});
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t});
llvm::Intrinsic::ID mma_id;
if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn;
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
@@ -140,17 +141,18 @@ int main(){
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
// Declare the matrix-multiply intrinsic selected by mma_id (picked from AT/BT above).
// The intrinsic ID must be mma_id; passing `mma` here would reference the variable
// being initialized and is not an llvm::Intrinsic::ID.
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile3d_t});
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t});
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t});
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t});
// Hyperparameters
llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0);
llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1);
llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2);
llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3);
// Constants
llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0);
@@ -221,8 +223,8 @@ int main(){
llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb");
llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa");
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta");
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb");
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta");
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb");
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
@@ -261,8 +263,8 @@ int main(){
// Pre-fetch
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1}, "nexta");
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1}, "nextb");
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta");
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb");
a->addIncoming(starta, PrologBB);
a->addIncoming(nexta, LoopBB);
b->addIncoming(startb, PrologBB);
@@ -283,8 +285,8 @@ int main(){
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb");
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1}, "lasta");
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1}, "lastb");
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta");
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb");
llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0);
a->addIncoming(lasta, LastIterBB);
b->addIncoming(lastb, LastIterBB);