[dnn/shift] added base pointer for a, b

This commit is contained in:
Philippe Tillet
2019-07-20 22:05:16 -07:00
parent d159455f7b
commit 484e3871cf
5 changed files with 26 additions and 21 deletions

View File

@@ -8,15 +8,15 @@
int main() {
bool AT = true;
bool BT = false;
bool AT = false;
bool BT = true;
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 4096, N = 4096, K = 4096;
int32_t M = 65536, N = 2048, K = 2048;
std::vector<T> hc(M*N);
std::vector<T> rc(M*N);
std::vector<T> ha(M*K);
@@ -37,7 +37,7 @@ int main() {
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, false);
gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::WGRAD;
auto op = triton::dnn::shift::FPROP;
// initialization
int32_t R = 3, S = 3;
int32_t B = 16, F = 4096;
int32_t H = 16, W = 16;
int32_t C = 4096;
int32_t B = 64, F = 2048;
int32_t H = 32, W = 32;
int32_t C = 2048;
// random shifts
std::vector<int32_t> shift_h(C);

View File

@@ -59,8 +59,9 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
cst_info rhs = populate_is_constant(rhs_op);
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
unsigned max_contiguous = populate_max_contiguous(lhs_op);
unsigned starting_multiple = populate_starting_multiple(lhs_op);
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0});
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous, rhs.value);
return cache({num_constants, 0});
}
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
}
@@ -300,6 +301,7 @@ void alignment_info::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
}
}

View File

@@ -1148,7 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
vector_size = result->axis(0).contiguous;
// vector_size = result->axis(0).contiguous;
// vector_size = 1;
std::map<unsigned, Value*> packets;
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());

View File

@@ -354,11 +354,6 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2;
int32 pad_w = BW / 2;)";
if(op_ == WGRAD){
result += R"(
)";
}
/* A offsets */
if(op_ == FPROP){
@@ -408,13 +403,21 @@ if(op_ == WGRAD){
rbh = rbh * )" + stride_h + R"(;
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
int32 offb1[TK, TN] = offkb[:, newaxis];
)" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
)" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
)" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)";
}
else{
result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)";
}
/* Main loop */
/* Increment A pointers */
result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
@@ -440,7 +443,7 @@ if(op_ == WGRAD){
rka = rka + TK;)"
+ compute_bhw("ra", "TK", "rka") + R"(
offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)";
pa = pa_base + offxa[:, newaxis];)";
}
result += R"(
@checka a = *pa;)";
@@ -453,7 +456,7 @@ if(op_ == WGRAD){
rbw = rbw * )" + stride_w + R"(;
rbh = rbh * )" + stride_h + R"(;
offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
pb = B + offb0 + offkb[:, newaxis] + shift;)";
pb = pb_base + offkb[:, newaxis];)";
}
if(op_ == FPROP){
result += R"(