From 8b040b46454ffec90ef637bf1011fa9e278d5da5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 3 Dec 2018 07:42:05 -0500 Subject: [PATCH] updates --- conv.cpp | 108 +++++++++++++++++++++++++++++++++++++++++-------------- gemm.cpp | 30 ++++++++-------- 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/conv.cpp b/conv.cpp index 7959a612d..fa99d301e 100644 --- a/conv.cpp +++ b/conv.cpp @@ -27,7 +27,57 @@ #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/Cloning.h" +// Index computation: row-major linear offset of element (x, y, z, w, u) in a 5-D array of extents (s0, s1, s2, s3, s4); s0 is not needed for the offset +inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u, + int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) +{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; } +template<class IN_DTYPE, class OUT_DTYPE> // Naive direct convolution: I is indexed (n,c,d,h,w), F is indexed (c,t,r,s,k), every O[o] is indexed (n,k,m,p,q); out-of-bounds input taps read as 0 +void cpp_conv_nchw(int32_t C, int32_t N, int32_t K, + int32_t D, int32_t H, int32_t W, + int32_t T, int32_t R, int32_t S, + int32_t pad_d, int32_t pad_h, int32_t pad_w, + int32_t stride_d, int32_t stride_h, int32_t stride_w, + int32_t M, int32_t P, int32_t Q, + std::vector<std::vector<OUT_DTYPE>>& O, IN_DTYPE* I, IN_DTYPE* F) +{ + size_t num_outputs = O.size(); + static const int PACK_IN = 1; + static const int PACK_OUT = 1; + if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4"); + if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4"); + C /= PACK_IN; + K /= PACK_OUT; + int32_t Kout = K; + IN_DTYPE accs[PACK_OUT]; + for(size_t o = 0; o < num_outputs; o++) + for(int32_t m = 0 ; m < M; ++m) + for(int32_t p = 0 ; p < P; ++p) + for(int32_t q = 0; q < Q; ++q) + for(int32_t n = 0; n < N; ++n) + for(int32_t k = 0; k < Kout ; ++k) + { + for(int32_t i = 0 ; i < PACK_OUT; ++i) + accs[i] = 0; + int32_t mm = m*stride_d - pad_d; + int32_t pp = p*stride_h - pad_h; + int32_t qq = q*stride_w - pad_w; + for(int32_t kk = 0; kk < PACK_OUT; ++kk) + for(int32_t c = 0; c < C; ++c) + for(int32_t t = 0; t < T; ++t) + for(int32_t r = 0; r < R; ++r) + for(int32_t s = 0; s < S; ++s){ + int32_t d = mm 
+ t; + int32_t h = pp + r; + int32_t w = qq + s; + bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W); + IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0; + IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)]; + accs[kk] += i*f; + } + O[o][idx(n, k, m, p, q, N, K, M, P, Q)] = accs[0]; + } +} void autotune(llvm::TargetMachine *machine, llvm::Module &module){ // Target parameters @@ -95,8 +145,6 @@ void autotune(llvm::TargetMachine *machine, llvm::Module &module){ } int main(){ -// llvm::DebugFlag = true; - std::string error; llvm::InitializeAllTargetInfos(); @@ -162,10 +210,15 @@ int main(){ llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn}); // LUT + unsigned num_delta = nlut; + unsigned num_inc_delta = nlut; + unsigned num_masks = nlut; + unsigned num_inc_masks = nlut; + unsigned cst_size = num_delta + num_inc_delta + num_masks + num_inc_masks; llvm::GlobalVariable *lut_array = - new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, nlut), false, llvm::GlobalVariable::InternalLinkage, + new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, cst_size), false, llvm::GlobalVariable::InternalLinkage, nullptr, "lut_array", nullptr, llvm::GlobalVariable::NotThreadLocal, 4); - llvm::Value *lut_ptr = builder.CreateBitCast(lut_array, lut_ptr_t); + llvm::Value *cst_ptr = builder.CreateBitCast(lut_array, lut_ptr_t); // Function @@ -177,7 +230,7 @@ int main(){ F->addAttribute(2, llvm::Attribute::ReadOnly); F->addAttribute(2, llvm::Attribute::NoAlias); std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(args), [&](llvm::Argument& x){ return &x;}); - llvm::Value *base_o_ptr = args[0], *base_i_ptr = args[1], *base_f_ptr = args[2]; + llvm::Value *base_pc = args[0], *base_pa = args[1], *base_pb = args[2]; llvm::Value *C = args[3], *H = args[4], *W = args[5], *N = args[6], *K = args[7]; llvm::Value *R = builder.getInt32(RR), *S = builder.getInt32(SS); @@ -191,10 +244,10 @@ int main(){ // 
First basic block builder.SetInsertPoint(PrologBB); - llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "i_slice_pqn"); - llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "f_slice_k"); - llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "i_slice_crs"); - llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "f_slice_crs"); + llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "sa0"); + llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "sb0"); + llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sa1"); + llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sb1"); llvm::Value* lda_w = builder.getInt32(1); llvm::Value* lda_h = builder.CreateMul(lda_w, W); @@ -227,30 +280,31 @@ int main(){ offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w}))); // Images pointer llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1}); - llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_i_ptr, off_a}, "start_i_ptr"); - llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_i_val"); - llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_i"); + llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_pa, off_a}, "start_pa"); + llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_aa"); + llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_a"); // Filters pointer llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K}); - llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_f"); - llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_f_ptr, off_b}, "start_f_ptr"); - llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_f_val"); - llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_f"); + llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, 
builder.CreateMul(sb1, tldb_s)}, "off_b"); + llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_pb, off_b}, "start_pb"); + llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_bb"); + llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_b"); // Filters increment - llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_f_0"); - llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_f_1"); - llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f"); - // Delta pointers - llvm::Value* base_incdelta = lut_ptr; + llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_b_0"); + llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_b_1"); + llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_b"); + // Pointers to constant memory: the four tables (inc_delta, delta, inc_masks, masks) sit back to back in lut_array, so each base is the sum of the table sizes before it and the last table ends at cst_size + llvm::Value* base_incdelta = builder.CreateGEP(cst_ptr, builder.getInt32(0)); + llvm::Value* base_delta = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta)); + llvm::Value* base_incmask = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta + num_delta)); + llvm::Value* base_mask = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta + num_delta + num_inc_masks)); + // Delta pointers llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta"); - llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut)); llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta"); // Masks llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)}); llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1); - llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut), "base_incmask"); llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}, "start_pincmask"); - llvm::Value* base_mask = builder.CreateGEP(lut_ptr, 
builder.getInt32(3*nlut), "base_mask"); llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}, "start_pmask"); // Enter loop builder.CreateBr(LoopBB); @@ -341,8 +395,8 @@ int main(){ // Epilogue builder.SetInsertPoint(EpilogueBB); - llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "o_slice_pqn"); - llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "o_slice_k"); + llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "sc_pqn"); + llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "sc_k"); // Output strides llvm::Value* ldc_q = builder.getInt32(1); llvm::Value* ldc_p = builder.CreateMul(lda_w, W); @@ -360,7 +414,7 @@ int main(){ llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k})); // Output pointer llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1}); - llvm::Value* pc = builder.CreateCall(gtp_2d, {base_o_ptr, offc}); + llvm::Value* pc = builder.CreateCall(gtp_2d, {base_pc, offc}); // Output masks llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN})); llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K})); diff --git a/gemm.cpp b/gemm.cpp index 5adc3ebbd..5433fd8d9 100644 --- a/gemm.cpp +++ b/gemm.cpp @@ -121,18 +121,19 @@ int main(){ llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context); llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context); - llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2); + llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2); + llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3); llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1); llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2); llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1); llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2); - llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0); + llvm::PointerType* tile2d_ptr_t = 
llvm::PointerType::get(tile2d_t, 0); llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t}); llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t}); llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t}); - llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t}); - llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t}); + llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t}); + llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t}); llvm::Intrinsic::ID mma_id; if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn; if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt; @@ -140,17 +141,18 @@ int main(){ if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt; llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t}); llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t}); - llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t}); - llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t}); + llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile3d_t}); // intrinsic ID is mma_id, selected from AT/BT above + llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t}); llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), 
llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t}); llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t}); - llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t}); - llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t}); + llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t}); + llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t}); // Hyperparameters llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0); llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1); llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2); + llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3); // Constants llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0); @@ -221,8 +223,8 @@ int main(){ llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb"); llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa"); llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb"); - llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta"); - llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb"); + llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta"); + llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb"); llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0"); llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)}); llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, 
builder.CreateMul(incb0, BS0)}); @@ -261,8 +263,8 @@ int main(){ // Pre-fetch llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa"); llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb"); - llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1}, "nexta"); - llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1}, "nextb"); + llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta"); + llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb"); a->addIncoming(starta, PrologBB); a->addIncoming(nexta, LoopBB); b->addIncoming(startb, PrologBB); @@ -283,8 +285,8 @@ int main(){ llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb"); llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa"); llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb"); - llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1}, "lasta"); - llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1}, "lastb"); + llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta"); + llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb"); llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0); a->addIncoming(lasta, LastIterBB); b->addIncoming(lastb, LastIterBB);