updates
This commit is contained in:
108
conv.cpp
108
conv.cpp
@@ -27,7 +27,57 @@
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
|
||||
// Index computation
|
||||
inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
|
||||
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
|
||||
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpp_conv_nchw(int32_t C, int32_t N, int32_t K,
|
||||
int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w,
|
||||
int32_t stride_d, int32_t stride_h, int32_t stride_w,
|
||||
int32_t M, int32_t P, int32_t Q,
|
||||
std::vector<std::vector<OUT_DTYPE>>& O, IN_DTYPE* I, IN_DTYPE* F)
|
||||
{
|
||||
size_t num_outputs = O.size();
|
||||
static const int PACK_IN = 1;
|
||||
static const int PACK_OUT = 1;
|
||||
if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4");
|
||||
if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4");
|
||||
C /= PACK_IN;
|
||||
K /= PACK_OUT;
|
||||
int32_t Kout = K;
|
||||
IN_DTYPE accs[PACK_OUT];
|
||||
for(size_t o = 0; o < num_outputs; o++)
|
||||
for(int32_t m = 0 ; m < M; ++m)
|
||||
for(int32_t p = 0 ; p < P; ++p)
|
||||
for(int32_t q = 0; q < Q; ++q)
|
||||
for(int32_t n = 0; n < N; ++n)
|
||||
for(int32_t k = 0; k < Kout ; ++k)
|
||||
{
|
||||
for(int32_t i = 0 ; i < PACK_OUT; ++i)
|
||||
accs[i] = 0;
|
||||
int32_t mm = m*stride_d - pad_d;
|
||||
int32_t pp = p*stride_h - pad_h;
|
||||
int32_t qq = q*stride_w - pad_w;
|
||||
for(int32_t kk = 0; kk < PACK_OUT; ++kk)
|
||||
for(int32_t c = 0; c < C; ++c)
|
||||
for(int32_t t = 0; t < T; ++t)
|
||||
for(int32_t r = 0; r < R; ++r)
|
||||
for(int32_t s = 0; s < S; ++s){
|
||||
int32_t d = mm + t;
|
||||
int32_t h = pp + r;
|
||||
int32_t w = qq + s;
|
||||
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W);
|
||||
IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0;
|
||||
IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)];
|
||||
accs[kk] += i*f;
|
||||
}
|
||||
O[o][idx(n, k, m, p, q, N, K, M, P, Q)] = accs[0];
|
||||
}
|
||||
}
|
||||
|
||||
void autotune(llvm::TargetMachine *machine, llvm::Module &module){
|
||||
// Target parameters
|
||||
@@ -95,8 +145,6 @@ void autotune(llvm::TargetMachine *machine, llvm::Module &module){
|
||||
}
|
||||
|
||||
int main(){
|
||||
// llvm::DebugFlag = true;
|
||||
|
||||
std::string error;
|
||||
|
||||
llvm::InitializeAllTargetInfos();
|
||||
@@ -162,10 +210,15 @@ int main(){
|
||||
llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn});
|
||||
|
||||
// LUT
|
||||
unsigned num_delta = nlut;
|
||||
unsigned num_inc_delta = nlut;
|
||||
unsigned num_masks = nlut;
|
||||
unsigned num_inc_masks = nlut;
|
||||
unsigned cst_size = num_delta + num_inc_delta + num_masks + num_inc_masks;
|
||||
llvm::GlobalVariable *lut_array =
|
||||
new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, nlut), false, llvm::GlobalVariable::InternalLinkage,
|
||||
new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, cst_size), false, llvm::GlobalVariable::InternalLinkage,
|
||||
nullptr, "lut_array", nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
llvm::Value *lut_ptr = builder.CreateBitCast(lut_array, lut_ptr_t);
|
||||
llvm::Value *cst_ptr = builder.CreateBitCast(lut_array, lut_ptr_t);
|
||||
|
||||
|
||||
// Function
|
||||
@@ -177,7 +230,7 @@ int main(){
|
||||
F->addAttribute(2, llvm::Attribute::ReadOnly);
|
||||
F->addAttribute(2, llvm::Attribute::NoAlias);
|
||||
std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(args), [&](llvm::Argument& x){ return &x;});
|
||||
llvm::Value *base_o_ptr = args[0], *base_i_ptr = args[1], *base_f_ptr = args[2];
|
||||
llvm::Value *base_pc = args[0], *base_pa = args[1], *base_pb = args[2];
|
||||
llvm::Value *C = args[3], *H = args[4], *W = args[5], *N = args[6], *K = args[7];
|
||||
llvm::Value *R = builder.getInt32(RR), *S = builder.getInt32(SS);
|
||||
|
||||
@@ -191,10 +244,10 @@ int main(){
|
||||
|
||||
// First basic block
|
||||
builder.SetInsertPoint(PrologBB);
|
||||
llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "i_slice_pqn");
|
||||
llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "f_slice_k");
|
||||
llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "i_slice_crs");
|
||||
llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "f_slice_crs");
|
||||
llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "sa0");
|
||||
llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "sb0");
|
||||
llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sa1");
|
||||
llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sb1");
|
||||
|
||||
llvm::Value* lda_w = builder.getInt32(1);
|
||||
llvm::Value* lda_h = builder.CreateMul(lda_w, W);
|
||||
@@ -227,30 +280,31 @@ int main(){
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w})));
|
||||
// Images pointer
|
||||
llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1});
|
||||
llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_i_ptr, off_a}, "start_i_ptr");
|
||||
llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_i_val");
|
||||
llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_i");
|
||||
llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_pa, off_a}, "start_pa");
|
||||
llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_aa");
|
||||
llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_a");
|
||||
// Filters pointer
|
||||
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K});
|
||||
llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_f");
|
||||
llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_f_ptr, off_b}, "start_f_ptr");
|
||||
llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_f_val");
|
||||
llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_f");
|
||||
llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_b");
|
||||
llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_pb, off_b}, "start_pb");
|
||||
llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_bb");
|
||||
llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_b");
|
||||
// Filters increment
|
||||
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_f_0");
|
||||
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_f_1");
|
||||
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_f");
|
||||
// Delta pointers
|
||||
llvm::Value* base_incdelta = lut_ptr;
|
||||
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_b_0");
|
||||
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_b_1");
|
||||
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_b");
|
||||
// Pointers to constant memory
|
||||
llvm::Value* base_incdelta = builder.CreateGEP(cst_ptr, builder.getInt32(0));
|
||||
llvm::Value* base_delta = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta));
|
||||
llvm::Value* base_incmask = builder.CreateGEP(cst_ptr, builder.getInt32(num_delta));
|
||||
llvm::Value* base_mask = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_masks));
|
||||
// Delta pointers
|
||||
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta");
|
||||
llvm::Value* base_delta = builder.CreateGEP(lut_ptr, builder.getInt32(nlut));
|
||||
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta");
|
||||
// Masks
|
||||
llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
|
||||
llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
|
||||
llvm::Value* base_incmask = builder.CreateGEP(lut_ptr, builder.getInt32(2*nlut), "base_incmask");
|
||||
llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}, "start_pincmask");
|
||||
llvm::Value* base_mask = builder.CreateGEP(lut_ptr, builder.getInt32(3*nlut), "base_mask");
|
||||
llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}, "start_pmask");
|
||||
// Enter loop
|
||||
builder.CreateBr(LoopBB);
|
||||
@@ -341,8 +395,8 @@ int main(){
|
||||
|
||||
// Epilogue
|
||||
builder.SetInsertPoint(EpilogueBB);
|
||||
llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "o_slice_pqn");
|
||||
llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "o_slice_k");
|
||||
llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "sc_pqn");
|
||||
llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "sc_k");
|
||||
// Output strides
|
||||
llvm::Value* ldc_q = builder.getInt32(1);
|
||||
llvm::Value* ldc_p = builder.CreateMul(lda_w, W);
|
||||
@@ -360,7 +414,7 @@ int main(){
|
||||
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k}));
|
||||
// Output pointer
|
||||
llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1});
|
||||
llvm::Value* pc = builder.CreateCall(gtp_2d, {base_o_ptr, offc});
|
||||
llvm::Value* pc = builder.CreateCall(gtp_2d, {base_pc, offc});
|
||||
// Output masks
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K}));
|
||||
|
30
gemm.cpp
30
gemm.cpp
@@ -121,18 +121,19 @@ int main(){
|
||||
llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context);
|
||||
llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context);
|
||||
|
||||
llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3);
|
||||
llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1);
|
||||
llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2);
|
||||
llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1);
|
||||
llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2);
|
||||
|
||||
llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0);
|
||||
llvm::PointerType* tile2d_ptr_t = llvm::PointerType::get(tile2d_t, 0);
|
||||
llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t});
|
||||
llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t});
|
||||
llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t});
|
||||
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t});
|
||||
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t});
|
||||
llvm::Intrinsic::ID mma_id;
|
||||
if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn;
|
||||
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
@@ -140,17 +141,18 @@ int main(){
|
||||
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma, {tile3d_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
|
||||
// Hyperparameters
|
||||
llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0);
|
||||
llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1);
|
||||
llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2);
|
||||
llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3);
|
||||
|
||||
// Constants
|
||||
llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0);
|
||||
@@ -221,8 +223,8 @@ int main(){
|
||||
llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb");
|
||||
llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa");
|
||||
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1}, "startb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb");
|
||||
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
|
||||
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
|
||||
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
|
||||
@@ -261,8 +263,8 @@ int main(){
|
||||
// Pre-fetch
|
||||
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
|
||||
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
|
||||
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1}, "nexta");
|
||||
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1}, "nextb");
|
||||
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta");
|
||||
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb");
|
||||
a->addIncoming(starta, PrologBB);
|
||||
a->addIncoming(nexta, LoopBB);
|
||||
b->addIncoming(startb, PrologBB);
|
||||
@@ -283,8 +285,8 @@ int main(){
|
||||
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
|
||||
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
|
||||
llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb");
|
||||
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1}, "lasta");
|
||||
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1}, "lastb");
|
||||
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta");
|
||||
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb");
|
||||
llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0);
|
||||
a->addIncoming(lasta, LastIterBB);
|
||||
b->addIncoming(lastb, LastIterBB);
|
||||
|
Reference in New Issue
Block a user