[codegen/alignment_info] better alignment information

This commit is contained in:
Philippe Tillet
2019-07-20 21:44:18 -07:00
parent 28c250216c
commit d159455f7b
9 changed files with 145 additions and 131 deletions

View File

@@ -8,8 +8,8 @@
int main() { int main() {
bool AT = false; bool AT = true;
bool BT = true; bool BT = false;
typedef float T; typedef float T;
std::string ty = "fp16"; std::string ty = "fp16";
size_t dt_nbytes = sizeof(T); size_t dt_nbytes = sizeof(T);
@@ -37,7 +37,7 @@ int main() {
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4); triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, true); gemm.enqueue(stream, {da, db, dc}, false);
// stream->read(dc, true, 0, hc); // stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb); // gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++) // for(size_t i = 0; i < M*N; i++)

View File

@@ -14,12 +14,17 @@ namespace ir {
namespace codegen{ namespace codegen{
class alignment_info { class alignment_info {
struct cst_info {
unsigned num_cst;
unsigned value;
};
private: private:
// helpers // helpers
bool is_first_axis_unit(ir::value *v); bool is_first_axis_unit(ir::value *v);
// populate maps // populate maps
bool populate_is_constant(ir::value *v); cst_info populate_is_constant(ir::value *v);
unsigned populate_max_contiguous(ir::value *v); unsigned populate_max_contiguous(ir::value *v);
unsigned populate_starting_multiple(ir::value *v); unsigned populate_starting_multiple(ir::value *v);
@@ -29,7 +34,7 @@ public:
unsigned get_max_contiguous(ir::value* v) const; unsigned get_max_contiguous(ir::value* v) const;
private: private:
std::map<ir::value*, bool> is_constant_; std::map<ir::value*, cst_info> is_constant_;
std::map<ir::value*, unsigned> max_contiguous_; std::map<ir::value*, unsigned> max_contiguous_;
std::map<ir::value*, unsigned> starting_multiple_; std::map<ir::value*, unsigned> starting_multiple_;
}; };

View File

@@ -70,6 +70,7 @@ public:
void target_independent(ir::module &module) { void target_independent(ir::module &module) {
optimize_dot.run(module); optimize_dot.run(module);
optimize_trans.run(module); optimize_trans.run(module);
// ir::print(module, std::cout);
} }
void target_dependent(ir::module &module) { void target_dependent(ir::module &module) {

View File

@@ -9,6 +9,18 @@ namespace triton {
namespace codegen{ namespace codegen{
inline int gcd(int a, int b) {
if (a == 0)
return b;
if (b == 0)
return a;
if (a == b)
return a;
if (a > b)
return gcd(a-b, b);
return gcd(a, b-a);
}
template<class T> template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) { inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map[i] = value; return map[i] = value;
@@ -22,50 +34,69 @@ bool alignment_info::is_first_axis_unit(ir::value *x){
return true; return true;
} }
bool alignment_info::populate_is_constant(ir::value *v) { alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end()) if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v); return is_constant_.at(v);
// helper for the cache // helper for the cache
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); }; auto cache = [this,v](cst_info value){
return add_to_cache(v, value, is_constant_); }
;
// populate // populate
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){ if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
ir::value *op = x->get_operand(0); ir::value *op = x->get_operand(0);
populate_is_constant(op); auto op_cst = populate_is_constant(op);
if(is_first_axis_unit(op)) if(is_first_axis_unit(op)){
return cache(true); unsigned num_cst = x->get_type()->get_tile_shapes()[0]->get_value();
return cache({num_cst, op_cst.value});
}
} }
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v))
return cache(true); return cache({true, (unsigned)x->get_value()});
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){ if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
bool lhs = populate_is_constant(x->get_operand(0)); ir::value* lhs_op = x->get_operand(0);
bool rhs = populate_is_constant(x->get_operand(1)); ir::value* rhs_op = x->get_operand(1);
return cache(lhs && rhs); cst_info lhs = populate_is_constant(lhs_op);
cst_info rhs = populate_is_constant(rhs_op);
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
unsigned max_contiguous = populate_max_contiguous(lhs_op);
unsigned starting_multiple = populate_starting_multiple(lhs_op);
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0});
}
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
}
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
cst_info lhs = populate_is_constant(lhs_op);
cst_info rhs = populate_is_constant(rhs_op);
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
} }
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){ if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
bool value_true = populate_is_constant(x->get_value_true()); cst_info value_true = populate_is_constant(x->get_value_true());
bool value_false = populate_is_constant(x->get_value_false()); cst_info value_false = populate_is_constant(x->get_value_false());
return cache(value_true && value_false); return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
} }
if(v->get_type()->is_tile_ty()) if(v->get_type()->is_tile_ty())
return cache(false); return cache({0, 0});
if(auto *x = dynamic_cast<ir::phi_node*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion // put a conservative initial value in phi node to avoid infinite recursion
bool result = true; unsigned result = 1;
for(unsigned n = 0; n < x->get_num_incoming(); n++){ for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n); ir::value* inc = x->get_incoming_value(n);
if(is_constant_.find(inc) != is_constant_.end()) if(is_constant_.find(inc) != is_constant_.end())
result = is_constant_.at(inc); result = is_constant_.at(inc).num_cst;
} }
cache(result); cache({result, 0});
// recurse // recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){ for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n); ir::value* inc = x->get_incoming_value(n);
result = result && populate_is_constant(inc); result = std::min(result, populate_is_constant(inc).num_cst);
} }
return cache(result); return cache({result, 0});
} }
// scalars are always constant in the contiguous dimension // scalars are always constant in the contiguous dimension
return cache(true); // but value is not known at compile-time
return cache({1, 0});
} }
unsigned alignment_info::populate_max_contiguous(ir::value *v){ unsigned alignment_info::populate_max_contiguous(ir::value *v){
@@ -95,13 +126,21 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
ir::value* rhs = x->get_operand(1); ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs); unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs); unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
bool lhs_has_cst = populate_is_constant(lhs); cst_info lhs_cst_info = populate_is_constant(lhs);
bool rhs_has_cst = populate_is_constant(rhs); cst_info rhs_cst_info = populate_is_constant(rhs);
if(x->is_int_add_sub()){ if(x->is_int_rem() && rhs_cst_info.value > 0)
if(lhs_has_cst) return cache(std::min(lhs_max_contiguous, rhs_cst_info.value));
return cache(rhs_max_contiguous); if(x->is_int_mult()){
if(rhs_has_cst) if(rhs_cst_info.value == 1)
return cache(lhs_max_contiguous); return cache(lhs_max_contiguous);
if(lhs_cst_info.value == 1)
return cache(rhs_max_contiguous);
}
if(x->is_int_add_sub()){
if(lhs_cst_info.num_cst)
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst));
if(rhs_cst_info.num_cst)
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
} }
} }
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){ if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
@@ -114,11 +153,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
ir::value* rhs = x->get_operand(1); ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs); unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs); unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
bool lhs_has_cst = populate_is_constant(lhs); auto lhs_cst_info = populate_is_constant(lhs);
bool rhs_has_cst = populate_is_constant(rhs); auto rhs_cst_info = populate_is_constant(rhs);
if(lhs_has_cst) if(lhs_cst_info.num_cst)
return cache(rhs_max_contiguous); return cache(rhs_max_contiguous);
if(rhs_has_cst) if(rhs_cst_info.num_cst)
return cache(lhs_max_contiguous); return cache(lhs_max_contiguous);
} }
if(auto *x = dynamic_cast<ir::phi_node*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v)){
@@ -140,22 +179,12 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
return cache(1); return cache(1);
} }
inline int gcd(int a, int b) {
if (a == 0)
return b;
if (b == 0)
return a;
if (a == b)
return a;
if (a > b)
return gcd(a-b, b);
return gcd(a, b-a);
}
unsigned alignment_info::populate_starting_multiple(ir::value *v){ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); }; auto cache = [this,v](unsigned value){
return add_to_cache(v, value, starting_multiple_);
};
// has metadata // has metadata
if(auto *x = dynamic_cast<ir::instruction*>(v)){ if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
@@ -185,15 +214,16 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
return cache(gcd(lhs, rhs)); return cache(gcd(lhs, rhs));
if(x->is_int_div()) if(x->is_int_div())
return cache(std::max(lhs / rhs, 1)); return cache(std::max(lhs / rhs, 1));
if(x->is_int_rem()) if(x->is_int_rem() && rhs > 1)
return cache(std::max(lhs % rhs, 1)); return cache(gcd(lhs, rhs));
if(x->is_shl()) if(x->is_shl())
return cache(lhs << rhs); return cache(lhs << rhs);
if(x->is_shr()) if(x->is_shr())
return cache(std::max(lhs >> rhs, 1)); return cache(std::max(lhs >> rhs, 1));
} }
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v)){
return cache(x->get_value()); return cache(x->get_value());
}
if(auto *x = dynamic_cast<ir::constant_range*>(v)){ if(auto *x = dynamic_cast<ir::constant_range*>(v)){
return cache(x->get_first()->get_value()); return cache(x->get_first()->get_value());
} }
@@ -270,7 +300,6 @@ void alignment_info::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i); populate_max_contiguous(i);
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
} }
} }

View File

@@ -233,10 +233,15 @@ void tune::run(ir::module &mod) {
for(ir::instruction *i : block->get_inst_list()){ for(ir::instruction *i : block->get_inst_list()){
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN) if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
continue; continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){ if(auto *ld = dynamic_cast<ir::load_inst*>(i))
ir::type *ty = mod.get_builder().get_int32_ty(); if(i->get_type()->is_tile_ty()){
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8)); ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
*params_.at(i).at("nts.d0") = *tmp; size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();

View File

@@ -51,8 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), best.params); jit->add_module(name_.c_str(), src.c_str(), best.params);
} }
else { else {
// jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
jit->add_module(name_.c_str(), src.c_str(), {32, 128, 16, 128, 2, 2, 2, 2, 4, 4, 32, 8, 4, 1});
} }
triton::driver::kernel* kernel = jit->get_function(name_.c_str()); triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

View File

@@ -113,8 +113,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
int32 bound, int32 *locks, int32 grid0, int32 grid1) { int32 bound, int32 *locks, int32 grid0, int32 grid1) {
int32 ridx = get_range_id(0); int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1); int32 ridy = get_range_id(1);
int32 rxa[TM] = ridx*TM + (0 ... TM); int32 rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = ridy*TN + (0 ... TN); int32 ryb[TN] = ridy * TN + (0 ... TN);
int32 rka[TK] = 0 ... TK; int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0; fp32 c[TM, TN] = 0;

View File

@@ -27,7 +27,7 @@ shift::shift(int B, int C,
layout_(layout){ layout_(layout){
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl; // std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
// max number of channels // max number of channels
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 16; TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32;
MAX_C_ = 8192 + TK_; MAX_C_ = 8192 + TK_;
// activation sizes // activation sizes
CD_ = AD_ / stride_d_; CD_ = AD_ / stride_d_;
@@ -223,7 +223,7 @@ void shift::deinit_impl() {
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args, std::vector<driver::buffer *> args,
runtime::launch_information info) { runtime::launch_information info) {
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1]; unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN");
unsigned grid_0 = (M_ + TM - 1)/TM; unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN; unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned num_locks = grid_0 * grid_1; unsigned num_locks = grid_0 * grid_1;
@@ -278,6 +278,8 @@ void shift::triton_c_src(std::ostream &os) const {
std::string usea = AT_ ? "trans(a)" : "a"; std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b"; std::string useb = BT_ ? "trans(b)" : "b";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string stride_h = std::to_string(stride_h_);
std::string stride_w = std::to_string(stride_w_);
if(AT_){ if(AT_){
std::swap(AS0, AS1); std::swap(AS0, AS1);
std::swap(bca0, bca1); std::swap(bca0, bca1);
@@ -290,6 +292,11 @@ void shift::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1; std::string BS = BS0 + ", " + BS1;
bool is_chwn = layout_ == CHWN; bool is_chwn = layout_ == CHWN;
std::string lda_b = is_chwn ? "1" : "lda_b";
std::string ldb_b = is_chwn ? "1" : "ldb_b";
std::string ldc_b = is_chwn ? "1" : "ldc_b";
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){ auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
std::string B = std::to_string(B_); std::string B = std::to_string(B_);
std::string CW = std::to_string(ICW_); std::string CW = std::to_string(ICW_);
@@ -317,7 +324,7 @@ const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {)" + std::to_string(TK_) + "};"; const tunable int32 TK = {)" + std::to_string(TK_) + "};";
if(op_ == WGRAD) if(op_ == WGRAD)
result += "const tunable int32 GZ = {1, 4, 16};"; result += "const tunable int32 GZ = {1};";
else else
result += "const tunable int32 GZ = {1};"; result += "const tunable int32 GZ = {1};";
@@ -329,30 +336,27 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
)" + c_ty_ + R"( *C, )" + c_ty_ + R"( *C,
int32 M, int32 N, int32 K, int32 M, int32 N, int32 K,
int32 stride_h, int32 stride_w, int32 stride_h, int32 stride_w,
multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c, multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c,
multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c, multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c,
multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c, multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c,
int32 NB, int32 NB,
int32 AH, int32 AW, int32 AH, int32 AW,
int32 BH, int32 BW, int32 BH, int32 BW,
int32 CH, int32 CW, int32 CH, int32 CW,
int32* locks, int32 grid0, int32 grid1, int32 grid2) { int32* locks, int32 grid0, int32 grid1, int32 grid2) {
int32 rxa[TM] = get_global_range[TM](0); int32 ridx = get_range_id(0);
int32 ryb[TN] = get_global_range[TN](1); int32 ridy = get_range_id(1);
int32 rz = get_global_range[1](2); int32 rz = get_range_id(2);
int32 rxa[TM] = ridx*TM + (0 ... TM);
int32 ryb[TN] = ridy*TN + (0 ... TN);
int32 rka[TK] = 0 ... TK; int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK;
fp32 acc[TM, TN] = 0; fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2; int32 pad_h = BH / 2;
int32 pad_w = BW / 2; int32 pad_w = BW / 2;)";
int32 div = K / grid2;
int32 rem = K % grid2;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
if(op_ == WGRAD){ if(op_ == WGRAD){
result += R"( result += R"(
rka = rka + offk;
rkb = rkb + offk;
)"; )";
} }
@@ -360,31 +364,26 @@ if(op_ == WGRAD){
if(op_ == FPROP){ if(op_ == FPROP){
result += result +=
compute_bhw("ra", "TM", "rxa") + R"( compute_bhw("ra", "TM", "rxa") + R"(
raw = raw * stride_w; raw = raw * )" + stride_w + R"(;
rah = rah * stride_h; rah = rah * )" + stride_h + R"(;
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
__constant__ int32* pd[TK] = delta_a + rka; __constant__ int32* pd[TK] = delta_a + rka;
multiple_of(4) int32 d[TK] = *pd; multiple_of(8) int32 d[TK] = *pd;
int32 offa1[TM, TK] = d[newaxis, :];)"; int32 offa1[TM, TK] = d[newaxis, :];)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
result += result +=
compute_bhw("ra", "TM", "rxa") + R"( compute_bhw("ra", "TM", "rxa") + R"(
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
} }
if(op_ == WGRAD && layout_ == CHWN){ if(op_ == WGRAD){
result += R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offa1[TK, TM] = rka[:, newaxis];)";
}
if(op_ == WGRAD && layout_ == NCHW){
result += result +=
compute_bhw("ra", "TK", "rka") + R"( compute_bhw("ra", "TK", "rka") + R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)"; int32 offa1[TK, TM] = offxa[:, newaxis];)";
} }
@@ -403,11 +402,11 @@ if(op_ == WGRAD){
result += result +=
compute_bhw("rb", "TK", "rkb") + R"( compute_bhw("rb", "TK", "rkb") + R"(
__constant__ int32* pd[TN] = delta_a + ryb; __constant__ int32* pd[TN] = delta_a + ryb;
int32 d[TN] = *pd; multiple_of(8) int32 d[TN] = *pd;
int32 shift[TK, TN] = d[newaxis, :]; multiple_of(8) int32 shift[TK, TN] = d[newaxis, :];
rbw = rbw * stride_w; rbw = rbw * )" + stride_w + R"(;
rbh = rbh * stride_h; rbh = rbh * )" + stride_h + R"(;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)"; int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
} }
@@ -416,8 +415,8 @@ if(op_ == WGRAD){
result += R"( result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1; )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(; int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(; int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){ for(int32 k = K; k > 0; k = k - TK){
@@ -436,15 +435,11 @@ if(op_ == BPROP){
result += R"( result += R"(
pa = pa + TK * lda_c;)"; pa = pa + TK * lda_c;)";
} }
if(op_ == WGRAD && layout_ == CHWN){ if(op_ == WGRAD){
result += R"(
pa = pa + TK;)";
}
if(op_ == WGRAD && layout_ == NCHW){
result += R"( result += R"(
rka = rka + TK;)" rka = rka + TK;)"
+ compute_bhw("ra", "TK", "rka") + R"( + compute_bhw("ra", "TK", "rka") + R"(
offxa = rab*lda_b + raw*lda_w + rah*lda_h; offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)"; pa = A + offa0 + offxa[:, newaxis];)";
} }
result += R"( result += R"(
@@ -455,9 +450,9 @@ if(op_ == WGRAD){
result += R"( result += R"(
rkb = rkb + TK;)" rkb = rkb + TK;)"
+ compute_bhw("rb", "TK", "rkb") + R"( + compute_bhw("rb", "TK", "rkb") + R"(
rbw = rbw * stride_w; rbw = rbw * )" + stride_w + R"(;
rbh = rbh * stride_h; rbh = rbh * )" + stride_h + R"(;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
pb = B + offb0 + offkb[:, newaxis] + shift;)"; pb = B + offb0 + offkb[:, newaxis] + shift;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
@@ -471,21 +466,21 @@ if(op_ == BPROP){
result += R"( result += R"(
@checkb b = *pb; @checkb b = *pb;
} }
int32 rxc[TM] = get_global_range[TM](0); int32 rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = get_global_range[TN](1);)"; int32 ryc[TN] = ridy*TN + (0 ... TN);)";
/* C offsets */ /* C offsets */
if(op_ == BPROP){ if(op_ == BPROP){
result += result +=
compute_bhw("rc", "TM", "rxc") + R"( compute_bhw("rc", "TM", "rxc") + R"(
rcw = rcw * stride_w; rcw = rcw * )" + stride_w + R"(;
rch = rch * stride_h; rch = rch * )" + stride_h + R"(;
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
result += result +=
compute_bhw("rc", "TM", "rxc") + R"( compute_bhw("rc", "TM", "rxc") + R"(
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
result += R"( result += R"(
@@ -506,27 +501,7 @@ if(op_ == BPROP){
} }
else{ else{
result += R"( result += R"(
int1 has_lock = (GZ > 1) && (locks != 0); @checkc *pc = c;)";
if(has_lock){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0;
int32 *pcount = plock + grid0*grid1;
while(__atomic_cas(plock, 0, 1) == 1);
int32 count = *pcount;
int32 countp1 = select(count == grid2 - 1, 0, count + 1);
if(count == 0) {
@checkc *pc = c;
}
else {
@checkc *pc = c + *pc;
}
*pcount = countp1;
*plock = 0;
}
else{
@checkc *pc = c;
})";
} }
result += R"( result += R"(
})"; })";

View File

@@ -130,7 +130,7 @@ bool binary_operator::is_int_mult() const {
} }
bool binary_operator::is_int_add_sub() const { bool binary_operator::is_int_add_sub() const {
return op_ == llop::Add || llop::Sub; return op_ == llop::Add || op_ == llop::Sub;
} }