[dnn/shift] added base pointer for a, b

This commit is contained in:
Philippe Tillet
2019-07-20 22:05:16 -07:00
parent d159455f7b
commit 484e3871cf
5 changed files with 26 additions and 21 deletions

View File

@@ -8,15 +8,15 @@
int main() { int main() {
bool AT = true; bool AT = false;
bool BT = false; bool BT = true;
typedef float T; typedef float T;
std::string ty = "fp16"; std::string ty = "fp16";
size_t dt_nbytes = sizeof(T); size_t dt_nbytes = sizeof(T);
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters // matrix multiplication parameters
int32_t M = 4096, N = 4096, K = 4096; int32_t M = 65536, N = 2048, K = 2048;
std::vector<T> hc(M*N); std::vector<T> hc(M*N);
std::vector<T> rc(M*N); std::vector<T> rc(M*N);
std::vector<T> ha(M*K); std::vector<T> ha(M*K);
@@ -37,7 +37,7 @@ int main() {
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4); triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, false); gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc); // stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb); // gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++) // for(size_t i = 0; i < M*N; i++)

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::WGRAD; auto op = triton::dnn::shift::FPROP;
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t B = 16, F = 4096; int32_t B = 64, F = 2048;
int32_t H = 16, W = 16; int32_t H = 32, W = 32;
int32_t C = 4096; int32_t C = 2048;
// random shifts // random shifts
std::vector<int32_t> shift_h(C); std::vector<int32_t> shift_h(C);

View File

@@ -59,8 +59,9 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
cst_info rhs = populate_is_constant(rhs_op); cst_info rhs = populate_is_constant(rhs_op);
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){ if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
unsigned max_contiguous = populate_max_contiguous(lhs_op); unsigned max_contiguous = populate_max_contiguous(lhs_op);
unsigned starting_multiple = populate_starting_multiple(lhs_op); // todo might not be entirely true
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0}); unsigned num_constants = gcd(max_contiguous, rhs.value);
return cache({num_constants, 0});
} }
return cache({std::min(lhs.num_cst, rhs.num_cst), 0}); return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
} }
@@ -300,6 +301,7 @@ void alignment_info::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i); populate_max_contiguous(i);
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
} }
} }

View File

@@ -1148,7 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr); unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous); unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment); unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
vector_size = result->axis(0).contiguous; // vector_size = result->axis(0).contiguous;
// vector_size = 1; // vector_size = 1;
std::map<unsigned, Value*> packets; std::map<unsigned, Value*> packets;
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand()); distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());

View File

@@ -354,11 +354,6 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
fp32 acc[TM, TN] = 0; fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2; int32 pad_h = BH / 2;
int32 pad_w = BW / 2;)"; int32 pad_w = BW / 2;)";
if(op_ == WGRAD){
result += R"(
)";
}
/* A offsets */ /* A offsets */
if(op_ == FPROP){ if(op_ == FPROP){
@@ -408,13 +403,21 @@ if(op_ == WGRAD){
rbh = rbh * )" + stride_h + R"(; rbh = rbh * )" + stride_h + R"(;
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)"; int32 offb1[TK, TN] = offkb[:, newaxis];
)" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
)" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
)" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)";
}
else{
result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)";
} }
/* Main loop */ /* Main loop */
/* Increment A pointers */
result += R"( result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
@@ -440,7 +443,7 @@ if(op_ == WGRAD){
rka = rka + TK;)" rka = rka + TK;)"
+ compute_bhw("ra", "TK", "rka") + R"( + compute_bhw("ra", "TK", "rka") + R"(
offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)"; pa = pa_base + offxa[:, newaxis];)";
} }
result += R"( result += R"(
@checka a = *pa;)"; @checka a = *pa;)";
@@ -453,7 +456,7 @@ if(op_ == WGRAD){
rbw = rbw * )" + stride_w + R"(; rbw = rbw * )" + stride_w + R"(;
rbh = rbh * )" + stride_h + R"(; rbh = rbh * )" + stride_h + R"(;
offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
pb = B + offb0 + offkb[:, newaxis] + shift;)"; pb = pb_base + offkb[:, newaxis];)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
result += R"( result += R"(