[codegen/alignment_info] better alignment information
This commit is contained in:
@@ -8,8 +8,8 @@
|
||||
|
||||
|
||||
int main() {
|
||||
bool AT = false;
|
||||
bool BT = true;
|
||||
bool AT = true;
|
||||
bool BT = false;
|
||||
typedef float T;
|
||||
std::string ty = "fp16";
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
@@ -37,7 +37,7 @@ int main() {
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, true);
|
||||
gemm.enqueue(stream, {da, db, dc}, false);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// gemm.cpu_ref<T>(rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
|
@@ -14,12 +14,17 @@ namespace ir {
|
||||
namespace codegen{
|
||||
|
||||
class alignment_info {
|
||||
// Result record produced by populate_is_constant() and stored per value in
// the is_constant_ map.
struct cst_info {
|
||||
// NOTE(review): appears to be how many elements along the first axis share
// one value -- 0 for a tile with no known constancy, 1 for scalars, and the
// first tile dimension for a retile of a unit-first-axis operand. Confirm
// against populate_is_constant().
unsigned num_cst;
|
||||
// The integer value itself when it comes from an integer constant;
// 0 is used when the value is not known at compile time.
unsigned value;
|
||||
};
|
||||
|
||||
private:
|
||||
// helpers
|
||||
bool is_first_axis_unit(ir::value *v);
|
||||
|
||||
// populate maps
|
||||
bool populate_is_constant(ir::value *v);
|
||||
cst_info populate_is_constant(ir::value *v);
|
||||
unsigned populate_max_contiguous(ir::value *v);
|
||||
unsigned populate_starting_multiple(ir::value *v);
|
||||
|
||||
@@ -29,7 +34,7 @@ public:
|
||||
unsigned get_max_contiguous(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, bool> is_constant_;
|
||||
std::map<ir::value*, cst_info> is_constant_;
|
||||
std::map<ir::value*, unsigned> max_contiguous_;
|
||||
std::map<ir::value*, unsigned> starting_multiple_;
|
||||
};
|
||||
|
@@ -70,6 +70,7 @@ public:
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
|
@@ -9,6 +9,18 @@ namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
|
||||
/// Greatest common divisor of `a` and `b`.
///
/// Computed with the iterative, remainder-based Euclidean algorithm, so the
/// cost is logarithmic in the inputs and no recursion is involved. (The
/// subtraction-based recursive form needs a call depth proportional to
/// max(a, b) / min(a, b) and never terminates for negative arguments.)
///
/// @param a first operand; its sign is ignored
/// @param b second operand; its sign is ignored
/// @return the non-negative gcd of |a| and |b|. gcd(x, 0) == |x| and
///         gcd(0, 0) == 0. A result of 2^31 (only reachable when the
///         arguments are INT_MIN and INT_MIN/0) does not fit in `int`.
inline int gcd(int a, int b) {
  // Take magnitudes in unsigned arithmetic so that negating INT_MIN cannot
  // overflow a signed int.
  unsigned x = a < 0 ? 0u - static_cast<unsigned>(a) : static_cast<unsigned>(a);
  unsigned y = b < 0 ? 0u - static_cast<unsigned>(b) : static_cast<unsigned>(b);
  while (y != 0) {
    const unsigned r = x % y;
    x = y;
    y = r;
  }
  return static_cast<int>(x);
}
|
||||
|
||||
template<class T>
|
||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
return map[i] = value;
|
||||
@@ -22,50 +34,69 @@ bool alignment_info::is_first_axis_unit(ir::value *x){
|
||||
return true;
|
||||
}
|
||||
|
||||
bool alignment_info::populate_is_constant(ir::value *v) {
|
||||
alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
// helper for the cache
|
||||
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
|
||||
auto cache = [this,v](cst_info value){
|
||||
return add_to_cache(v, value, is_constant_); }
|
||||
;
|
||||
// populate
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
populate_is_constant(op);
|
||||
if(is_first_axis_unit(op))
|
||||
return cache(true);
|
||||
auto op_cst = populate_is_constant(op);
|
||||
if(is_first_axis_unit(op)){
|
||||
unsigned num_cst = x->get_type()->get_tile_shapes()[0]->get_value();
|
||||
return cache({num_cst, op_cst.value});
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return cache(true);
|
||||
return cache({true, (unsigned)x->get_value()});
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
bool lhs = populate_is_constant(x->get_operand(0));
|
||||
bool rhs = populate_is_constant(x->get_operand(1));
|
||||
return cache(lhs && rhs);
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
cst_info lhs = populate_is_constant(lhs_op);
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
|
||||
unsigned max_contiguous = populate_max_contiguous(lhs_op);
|
||||
unsigned starting_multiple = populate_starting_multiple(lhs_op);
|
||||
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0});
|
||||
}
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
cst_info lhs = populate_is_constant(lhs_op);
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
bool value_true = populate_is_constant(x->get_value_true());
|
||||
bool value_false = populate_is_constant(x->get_value_false());
|
||||
return cache(value_true && value_false);
|
||||
cst_info value_true = populate_is_constant(x->get_value_true());
|
||||
cst_info value_false = populate_is_constant(x->get_value_false());
|
||||
return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
|
||||
}
|
||||
if(v->get_type()->is_tile_ty())
|
||||
return cache(false);
|
||||
return cache({0, 0});
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
bool result = true;
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(is_constant_.find(inc) != is_constant_.end())
|
||||
result = is_constant_.at(inc);
|
||||
result = is_constant_.at(inc).num_cst;
|
||||
}
|
||||
cache(result);
|
||||
cache({result, 0});
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = result && populate_is_constant(inc);
|
||||
result = std::min(result, populate_is_constant(inc).num_cst);
|
||||
}
|
||||
return cache(result);
|
||||
return cache({result, 0});
|
||||
}
|
||||
// scalars are always constant in the contiguous dimension
|
||||
return cache(true);
|
||||
// but value is not known at compile-time
|
||||
return cache({1, 0});
|
||||
}
|
||||
|
||||
unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
@@ -95,13 +126,21 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
bool lhs_has_cst = populate_is_constant(lhs);
|
||||
bool rhs_has_cst = populate_is_constant(rhs);
|
||||
if(x->is_int_add_sub()){
|
||||
if(lhs_has_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_has_cst)
|
||||
cst_info lhs_cst_info = populate_is_constant(lhs);
|
||||
cst_info rhs_cst_info = populate_is_constant(rhs);
|
||||
if(x->is_int_rem() && rhs_cst_info.value > 0)
|
||||
return cache(std::min(lhs_max_contiguous, rhs_cst_info.value));
|
||||
if(x->is_int_mult()){
|
||||
if(rhs_cst_info.value == 1)
|
||||
return cache(lhs_max_contiguous);
|
||||
if(lhs_cst_info.value == 1)
|
||||
return cache(rhs_max_contiguous);
|
||||
}
|
||||
if(x->is_int_add_sub()){
|
||||
if(lhs_cst_info.num_cst)
|
||||
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst));
|
||||
if(rhs_cst_info.num_cst)
|
||||
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
@@ -114,11 +153,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
bool lhs_has_cst = populate_is_constant(lhs);
|
||||
bool rhs_has_cst = populate_is_constant(rhs);
|
||||
if(lhs_has_cst)
|
||||
auto lhs_cst_info = populate_is_constant(lhs);
|
||||
auto rhs_cst_info = populate_is_constant(rhs);
|
||||
if(lhs_cst_info.num_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_has_cst)
|
||||
if(rhs_cst_info.num_cst)
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
@@ -140,22 +179,12 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
return cache(1);
|
||||
}
|
||||
|
||||
/// Greatest common divisor of `a` and `b`.
///
/// Computed with the iterative, remainder-based Euclidean algorithm, so the
/// cost is logarithmic in the inputs and no recursion is involved. (The
/// subtraction-based recursive form needs a call depth proportional to
/// max(a, b) / min(a, b) and never terminates for negative arguments.)
///
/// @param a first operand; its sign is ignored
/// @param b second operand; its sign is ignored
/// @return the non-negative gcd of |a| and |b|. gcd(x, 0) == |x| and
///         gcd(0, 0) == 0. A result of 2^31 (only reachable when the
///         arguments are INT_MIN and INT_MIN/0) does not fit in `int`.
inline int gcd(int a, int b) {
  // Take magnitudes in unsigned arithmetic so that negating INT_MIN cannot
  // overflow a signed int.
  unsigned x = a < 0 ? 0u - static_cast<unsigned>(a) : static_cast<unsigned>(a);
  unsigned y = b < 0 ? 0u - static_cast<unsigned>(b) : static_cast<unsigned>(b);
  while (y != 0) {
    const unsigned r = x % y;
    x = y;
    y = r;
  }
  return static_cast<int>(x);
}
|
||||
|
||||
unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
|
||||
auto cache = [this,v](unsigned value){
|
||||
return add_to_cache(v, value, starting_multiple_);
|
||||
};
|
||||
// has metadata
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
@@ -185,15 +214,16 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
return cache(gcd(lhs, rhs));
|
||||
if(x->is_int_div())
|
||||
return cache(std::max(lhs / rhs, 1));
|
||||
if(x->is_int_rem())
|
||||
return cache(std::max(lhs % rhs, 1));
|
||||
if(x->is_int_rem() && rhs > 1)
|
||||
return cache(gcd(lhs, rhs));
|
||||
if(x->is_shl())
|
||||
return cache(lhs << rhs);
|
||||
if(x->is_shr())
|
||||
return cache(std::max(lhs >> rhs, 1));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v)){
|
||||
return cache(x->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(x->get_first()->get_value());
|
||||
}
|
||||
@@ -270,7 +300,6 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -233,10 +233,15 @@ void tune::run(ir::module &mod) {
|
||||
for(ir::instruction *i : block->get_inst_list()){
|
||||
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
|
||||
continue;
|
||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
if(auto *ld = dynamic_cast<ir::load_inst*>(i))
|
||||
if(i->get_type()->is_tile_ty()){
|
||||
ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
|
@@ -51,8 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
// jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||
jit->add_module(name_.c_str(), src.c_str(), {32, 128, 16, 128, 2, 2, 2, 2, 4, 4, 32, 8, 4, 1});
|
||||
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
|
@@ -113,8 +113,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
||||
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
||||
int32 rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int32 ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 c[TM, TN] = 0;
|
||||
|
@@ -27,7 +27,7 @@ shift::shift(int B, int C,
|
||||
layout_(layout){
|
||||
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||
// max number of channels
|
||||
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 16;
|
||||
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
// activation sizes
|
||||
CD_ = AD_ / stride_d_;
|
||||
@@ -223,7 +223,7 @@ void shift::deinit_impl() {
|
||||
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
runtime::launch_information info) {
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN");
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned num_locks = grid_0 * grid_1;
|
||||
@@ -278,6 +278,8 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string stride_h = std::to_string(stride_h_);
|
||||
std::string stride_w = std::to_string(stride_w_);
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
@@ -290,6 +292,11 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
bool is_chwn = layout_ == CHWN;
|
||||
|
||||
std::string lda_b = is_chwn ? "1" : "lda_b";
|
||||
std::string ldb_b = is_chwn ? "1" : "ldb_b";
|
||||
std::string ldc_b = is_chwn ? "1" : "ldc_b";
|
||||
|
||||
|
||||
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
|
||||
std::string B = std::to_string(B_);
|
||||
std::string CW = std::to_string(ICW_);
|
||||
@@ -317,7 +324,7 @@ const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
|
||||
if(op_ == WGRAD)
|
||||
result += "const tunable int32 GZ = {1, 4, 16};";
|
||||
result += "const tunable int32 GZ = {1};";
|
||||
else
|
||||
result += "const tunable int32 GZ = {1};";
|
||||
|
||||
@@ -329,30 +336,27 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
)" + c_ty_ + R"( *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c,
|
||||
multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c,
|
||||
multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c,
|
||||
multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c,
|
||||
multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c,
|
||||
multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c,
|
||||
int32 NB,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32* locks, int32 grid0, int32 grid1, int32 grid2) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rz = get_global_range[1](2);
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
int32 rz = get_range_id(2);
|
||||
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
||||
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;
|
||||
int32 div = K / grid2;
|
||||
int32 rem = K % grid2;
|
||||
K = select(rz < rem, div - 1, div);
|
||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
|
||||
int32 pad_w = BW / 2;)";
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rka = rka + offk;
|
||||
rkb = rkb + offk;
|
||||
|
||||
)";
|
||||
}
|
||||
|
||||
@@ -360,31 +364,26 @@ if(op_ == WGRAD){
|
||||
if(op_ == FPROP){
|
||||
result +=
|
||||
compute_bhw("ra", "TM", "rxa") + R"(
|
||||
raw = raw * stride_w;
|
||||
rah = rah * stride_h;
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
raw = raw * )" + stride_w + R"(;
|
||||
rah = rah * )" + stride_h + R"(;
|
||||
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta_a + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
multiple_of(8) int32 d[TK] = *pd;
|
||||
int32 offa1[TM, TK] = d[newaxis, :];)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
result +=
|
||||
compute_bhw("ra", "TM", "rxa") + R"(
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == CHWN){
|
||||
result += R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(op_ == WGRAD){
|
||||
result +=
|
||||
compute_bhw("ra", "TK", "rka") + R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||
}
|
||||
|
||||
@@ -403,11 +402,11 @@ if(op_ == WGRAD){
|
||||
result +=
|
||||
compute_bhw("rb", "TK", "rkb") + R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
multiple_of(8) int32 d[TN] = *pd;
|
||||
multiple_of(8) int32 shift[TK, TN] = d[newaxis, :];
|
||||
rbw = rbw * )" + stride_w + R"(;
|
||||
rbh = rbh * )" + stride_h + R"(;
|
||||
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
|
||||
}
|
||||
@@ -416,8 +415,8 @@ if(op_ == WGRAD){
|
||||
result += R"(
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
||||
int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(;
|
||||
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
|
||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
@@ -436,15 +435,11 @@ if(op_ == BPROP){
|
||||
result += R"(
|
||||
pa = pa + TK * lda_c;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == CHWN){
|
||||
result += R"(
|
||||
pa = pa + TK;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rka = rka + TK;)"
|
||||
+ compute_bhw("ra", "TK", "rka") + R"(
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
}
|
||||
result += R"(
|
||||
@@ -455,9 +450,9 @@ if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rkb = rkb + TK;)"
|
||||
+ compute_bhw("rb", "TK", "rkb") + R"(
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
rbw = rbw * )" + stride_w + R"(;
|
||||
rbh = rbh * )" + stride_h + R"(;
|
||||
offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||
pb = B + offb0 + offkb[:, newaxis] + shift;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
@@ -471,21 +466,21 @@ if(op_ == BPROP){
|
||||
result += R"(
|
||||
@checkb b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);)";
|
||||
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy*TN + (0 ... TN);)";
|
||||
|
||||
/* C offsets */
|
||||
if(op_ == BPROP){
|
||||
result +=
|
||||
compute_bhw("rc", "TM", "rxc") + R"(
|
||||
rcw = rcw * stride_w;
|
||||
rch = rch * stride_h;
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
rcw = rcw * )" + stride_w + R"(;
|
||||
rch = rch * )" + stride_h + R"(;
|
||||
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
result +=
|
||||
compute_bhw("rc", "TM", "rxc") + R"(
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
@@ -506,27 +501,7 @@ if(op_ == BPROP){
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int1 has_lock = (GZ > 1) && (locks != 0);
|
||||
if(has_lock){
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
int32 *plock = locks + ridx + ridy*grid0;
|
||||
int32 *pcount = plock + grid0*grid1;
|
||||
while(__atomic_cas(plock, 0, 1) == 1);
|
||||
int32 count = *pcount;
|
||||
int32 countp1 = select(count == grid2 - 1, 0, count + 1);
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
}
|
||||
*pcount = countp1;
|
||||
*plock = 0;
|
||||
}
|
||||
else{
|
||||
@checkc *pc = c;
|
||||
})";
|
||||
@checkc *pc = c;)";
|
||||
}
|
||||
result += R"(
|
||||
})";
|
||||
|
@@ -130,7 +130,7 @@ bool binary_operator::is_int_mult() const {
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_add_sub() const {
|
||||
return op_ == llop::Add || llop::Sub;
|
||||
return op_ == llop::Add || op_ == llop::Sub;
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user