[dnn/shift] added base pointer for a, b
This commit is contained in:
@@ -8,15 +8,15 @@
|
||||
|
||||
|
||||
int main() {
|
||||
bool AT = true;
|
||||
bool BT = false;
|
||||
bool AT = false;
|
||||
bool BT = true;
|
||||
typedef float T;
|
||||
std::string ty = "fp16";
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 4096, N = 4096, K = 4096;
|
||||
int32_t M = 65536, N = 2048, K = 2048;
|
||||
std::vector<T> hc(M*N);
|
||||
std::vector<T> rc(M*N);
|
||||
std::vector<T> ha(M*K);
|
||||
@@ -37,7 +37,7 @@ int main() {
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, false);
|
||||
gemm.enqueue(stream, {da, db, dc}, true);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// gemm.cpu_ref<T>(rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
|
@@ -14,13 +14,13 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::WGRAD;
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 16, F = 4096;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 4096;
|
||||
int32_t B = 64, F = 2048;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t C = 2048;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
|
@@ -59,8 +59,9 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
|
||||
unsigned max_contiguous = populate_max_contiguous(lhs_op);
|
||||
unsigned starting_multiple = populate_starting_multiple(lhs_op);
|
||||
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0});
|
||||
// todo might not be entirely true
|
||||
unsigned num_constants = gcd(max_contiguous, rhs.value);
|
||||
return cache({num_constants, 0});
|
||||
}
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
@@ -300,6 +301,7 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1148,7 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
vector_size = result->axis(0).contiguous;
|
||||
// vector_size = result->axis(0).contiguous;
|
||||
// vector_size = 1;
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
|
@@ -354,11 +354,6 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;)";
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
|
||||
)";
|
||||
}
|
||||
|
||||
/* A offsets */
|
||||
if(op_ == FPROP){
|
||||
@@ -408,13 +403,21 @@ if(op_ == WGRAD){
|
||||
rbh = rbh * )" + stride_h + R"(;
|
||||
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis];
|
||||
)" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
|
||||
)" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)";
|
||||
}
|
||||
|
||||
/* Main loop */
|
||||
/* Increment A pointers */
|
||||
result += R"(
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
||||
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
|
||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||
@@ -440,7 +443,7 @@ if(op_ == WGRAD){
|
||||
rka = rka + TK;)"
|
||||
+ compute_bhw("ra", "TK", "rka") + R"(
|
||||
offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
pa = pa_base + offxa[:, newaxis];)";
|
||||
}
|
||||
result += R"(
|
||||
@checka a = *pa;)";
|
||||
@@ -453,7 +456,7 @@ if(op_ == WGRAD){
|
||||
rbw = rbw * )" + stride_w + R"(;
|
||||
rbh = rbh * )" + stride_h + R"(;
|
||||
offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||
pb = B + offb0 + offkb[:, newaxis] + shift;)";
|
||||
pb = pb_base + offkb[:, newaxis];)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
result += R"(
|
||||
|
Reference in New Issue
Block a user