updates
This commit is contained in:
30
gemm.cpp
30
gemm.cpp
@@ -121,18 +121,19 @@ int main(){
|
||||
llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context);
|
||||
llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context);
|
||||
|
||||
llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3);
|
||||
llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1);
|
||||
llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2);
|
||||
llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1);
|
||||
llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2);
|
||||
|
||||
llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0);
|
||||
llvm::PointerType* tile2d_ptr_t = llvm::PointerType::get(tile2d_t, 0);
|
||||
llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t});
|
||||
llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t});
|
||||
llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t});
|
||||
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t});
|
||||
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t});
|
||||
llvm::Intrinsic::ID mma_id;
|
||||
if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn;
|
||||
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
@@ -140,17 +141,18 @@ int main(){
|
||||
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma, {tile3d_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
|
||||
// Hyperparameters
|
||||
llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0);
|
||||
llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1);
|
||||
llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2);
|
||||
llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3);
|
||||
|
||||
// Constants
|
||||
llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0);
|
||||
@@ -221,8 +223,8 @@ int main(){
|
||||
llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb");
|
||||
llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa");
|
||||
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb");
|
||||
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
|
||||
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
|
||||
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
|
||||
@@ -261,8 +263,8 @@ int main(){
|
||||
// Pre-fetch
|
||||
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
|
||||
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
|
||||
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1}, "nexta");
|
||||
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1}, "nextb");
|
||||
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta");
|
||||
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb");
|
||||
a->addIncoming(starta, PrologBB);
|
||||
a->addIncoming(nexta, LoopBB);
|
||||
b->addIncoming(startb, PrologBB);
|
||||
@@ -283,8 +285,8 @@ int main(){
|
||||
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
|
||||
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
|
||||
llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb");
|
||||
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1}, "lasta");
|
||||
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1}, "lastb");
|
||||
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta");
|
||||
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb");
|
||||
llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0);
|
||||
a->addIncoming(lasta, LastIterBB);
|
||||
b->addIncoming(lastb, LastIterBB);
|
||||
|
Reference in New Issue
Block a user