[codegen/alignment_info] better alignment information
This commit is contained in:
@@ -8,8 +8,8 @@
|
|||||||
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = false;
|
bool AT = true;
|
||||||
bool BT = true;
|
bool BT = false;
|
||||||
typedef float T;
|
typedef float T;
|
||||||
std::string ty = "fp16";
|
std::string ty = "fp16";
|
||||||
size_t dt_nbytes = sizeof(T);
|
size_t dt_nbytes = sizeof(T);
|
||||||
@@ -37,7 +37,7 @@ int main() {
|
|||||||
stream->write(dc, true, 0, hc);
|
stream->write(dc, true, 0, hc);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
|
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
|
||||||
gemm.enqueue(stream, {da, db, dc}, true);
|
gemm.enqueue(stream, {da, db, dc}, false);
|
||||||
// stream->read(dc, true, 0, hc);
|
// stream->read(dc, true, 0, hc);
|
||||||
// gemm.cpu_ref<T>(rc, ha, hb);
|
// gemm.cpu_ref<T>(rc, ha, hb);
|
||||||
// for(size_t i = 0; i < M*N; i++)
|
// for(size_t i = 0; i < M*N; i++)
|
||||||
|
@@ -14,12 +14,17 @@ namespace ir {
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
class alignment_info {
|
class alignment_info {
|
||||||
|
struct cst_info {
|
||||||
|
unsigned num_cst;
|
||||||
|
unsigned value;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// helpers
|
// helpers
|
||||||
bool is_first_axis_unit(ir::value *v);
|
bool is_first_axis_unit(ir::value *v);
|
||||||
|
|
||||||
// populate maps
|
// populate maps
|
||||||
bool populate_is_constant(ir::value *v);
|
cst_info populate_is_constant(ir::value *v);
|
||||||
unsigned populate_max_contiguous(ir::value *v);
|
unsigned populate_max_contiguous(ir::value *v);
|
||||||
unsigned populate_starting_multiple(ir::value *v);
|
unsigned populate_starting_multiple(ir::value *v);
|
||||||
|
|
||||||
@@ -29,7 +34,7 @@ public:
|
|||||||
unsigned get_max_contiguous(ir::value* v) const;
|
unsigned get_max_contiguous(ir::value* v) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<ir::value*, bool> is_constant_;
|
std::map<ir::value*, cst_info> is_constant_;
|
||||||
std::map<ir::value*, unsigned> max_contiguous_;
|
std::map<ir::value*, unsigned> max_contiguous_;
|
||||||
std::map<ir::value*, unsigned> starting_multiple_;
|
std::map<ir::value*, unsigned> starting_multiple_;
|
||||||
};
|
};
|
||||||
|
@@ -70,6 +70,7 @@ public:
|
|||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
void target_dependent(ir::module &module) {
|
void target_dependent(ir::module &module) {
|
||||||
|
@@ -9,6 +9,18 @@ namespace triton {
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
|
|
||||||
|
inline int gcd(int a, int b) {
|
||||||
|
if (a == 0)
|
||||||
|
return b;
|
||||||
|
if (b == 0)
|
||||||
|
return a;
|
||||||
|
if (a == b)
|
||||||
|
return a;
|
||||||
|
if (a > b)
|
||||||
|
return gcd(a-b, b);
|
||||||
|
return gcd(a, b-a);
|
||||||
|
}
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||||
return map[i] = value;
|
return map[i] = value;
|
||||||
@@ -22,50 +34,69 @@ bool alignment_info::is_first_axis_unit(ir::value *x){
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool alignment_info::populate_is_constant(ir::value *v) {
|
alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||||
if(is_constant_.find(v) != is_constant_.end())
|
if(is_constant_.find(v) != is_constant_.end())
|
||||||
return is_constant_.at(v);
|
return is_constant_.at(v);
|
||||||
// helper for the cache
|
// helper for the cache
|
||||||
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
|
auto cache = [this,v](cst_info value){
|
||||||
|
return add_to_cache(v, value, is_constant_); }
|
||||||
|
;
|
||||||
// populate
|
// populate
|
||||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||||
ir::value *op = x->get_operand(0);
|
ir::value *op = x->get_operand(0);
|
||||||
populate_is_constant(op);
|
auto op_cst = populate_is_constant(op);
|
||||||
if(is_first_axis_unit(op))
|
if(is_first_axis_unit(op)){
|
||||||
return cache(true);
|
unsigned num_cst = x->get_type()->get_tile_shapes()[0]->get_value();
|
||||||
|
return cache({num_cst, op_cst.value});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||||
return cache(true);
|
return cache({true, (unsigned)x->get_value()});
|
||||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||||
bool lhs = populate_is_constant(x->get_operand(0));
|
ir::value* lhs_op = x->get_operand(0);
|
||||||
bool rhs = populate_is_constant(x->get_operand(1));
|
ir::value* rhs_op = x->get_operand(1);
|
||||||
return cache(lhs && rhs);
|
cst_info lhs = populate_is_constant(lhs_op);
|
||||||
|
cst_info rhs = populate_is_constant(rhs_op);
|
||||||
|
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
|
||||||
|
unsigned max_contiguous = populate_max_contiguous(lhs_op);
|
||||||
|
unsigned starting_multiple = populate_starting_multiple(lhs_op);
|
||||||
|
return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0});
|
||||||
|
}
|
||||||
|
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||||
|
}
|
||||||
|
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||||
|
ir::value* lhs_op = x->get_operand(0);
|
||||||
|
ir::value* rhs_op = x->get_operand(1);
|
||||||
|
cst_info lhs = populate_is_constant(lhs_op);
|
||||||
|
cst_info rhs = populate_is_constant(rhs_op);
|
||||||
|
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||||
bool value_true = populate_is_constant(x->get_value_true());
|
cst_info value_true = populate_is_constant(x->get_value_true());
|
||||||
bool value_false = populate_is_constant(x->get_value_false());
|
cst_info value_false = populate_is_constant(x->get_value_false());
|
||||||
return cache(value_true && value_false);
|
return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
|
||||||
}
|
}
|
||||||
if(v->get_type()->is_tile_ty())
|
if(v->get_type()->is_tile_ty())
|
||||||
return cache(false);
|
return cache({0, 0});
|
||||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||||
// put a conservative initial value in phi node to avoid infinite recursion
|
// put a conservative initial value in phi node to avoid infinite recursion
|
||||||
bool result = true;
|
unsigned result = 1;
|
||||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||||
ir::value* inc = x->get_incoming_value(n);
|
ir::value* inc = x->get_incoming_value(n);
|
||||||
if(is_constant_.find(inc) != is_constant_.end())
|
if(is_constant_.find(inc) != is_constant_.end())
|
||||||
result = is_constant_.at(inc);
|
result = is_constant_.at(inc).num_cst;
|
||||||
}
|
}
|
||||||
cache(result);
|
cache({result, 0});
|
||||||
// recurse
|
// recurse
|
||||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||||
ir::value* inc = x->get_incoming_value(n);
|
ir::value* inc = x->get_incoming_value(n);
|
||||||
result = result && populate_is_constant(inc);
|
result = std::min(result, populate_is_constant(inc).num_cst);
|
||||||
}
|
}
|
||||||
return cache(result);
|
return cache({result, 0});
|
||||||
}
|
}
|
||||||
// scalars are always constant in the contiguous dimension
|
// scalars are always constant in the contiguous dimension
|
||||||
return cache(true);
|
// but value is not known at compile-time
|
||||||
|
return cache({1, 0});
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||||
@@ -95,13 +126,21 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
|||||||
ir::value* rhs = x->get_operand(1);
|
ir::value* rhs = x->get_operand(1);
|
||||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||||
bool lhs_has_cst = populate_is_constant(lhs);
|
cst_info lhs_cst_info = populate_is_constant(lhs);
|
||||||
bool rhs_has_cst = populate_is_constant(rhs);
|
cst_info rhs_cst_info = populate_is_constant(rhs);
|
||||||
if(x->is_int_add_sub()){
|
if(x->is_int_rem() && rhs_cst_info.value > 0)
|
||||||
if(lhs_has_cst)
|
return cache(std::min(lhs_max_contiguous, rhs_cst_info.value));
|
||||||
return cache(rhs_max_contiguous);
|
if(x->is_int_mult()){
|
||||||
if(rhs_has_cst)
|
if(rhs_cst_info.value == 1)
|
||||||
return cache(lhs_max_contiguous);
|
return cache(lhs_max_contiguous);
|
||||||
|
if(lhs_cst_info.value == 1)
|
||||||
|
return cache(rhs_max_contiguous);
|
||||||
|
}
|
||||||
|
if(x->is_int_add_sub()){
|
||||||
|
if(lhs_cst_info.num_cst)
|
||||||
|
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst));
|
||||||
|
if(rhs_cst_info.num_cst)
|
||||||
|
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||||
@@ -114,11 +153,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
|||||||
ir::value* rhs = x->get_operand(1);
|
ir::value* rhs = x->get_operand(1);
|
||||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||||
bool lhs_has_cst = populate_is_constant(lhs);
|
auto lhs_cst_info = populate_is_constant(lhs);
|
||||||
bool rhs_has_cst = populate_is_constant(rhs);
|
auto rhs_cst_info = populate_is_constant(rhs);
|
||||||
if(lhs_has_cst)
|
if(lhs_cst_info.num_cst)
|
||||||
return cache(rhs_max_contiguous);
|
return cache(rhs_max_contiguous);
|
||||||
if(rhs_has_cst)
|
if(rhs_cst_info.num_cst)
|
||||||
return cache(lhs_max_contiguous);
|
return cache(lhs_max_contiguous);
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||||
@@ -140,22 +179,12 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
|||||||
return cache(1);
|
return cache(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int gcd(int a, int b) {
|
|
||||||
if (a == 0)
|
|
||||||
return b;
|
|
||||||
if (b == 0)
|
|
||||||
return a;
|
|
||||||
if (a == b)
|
|
||||||
return a;
|
|
||||||
if (a > b)
|
|
||||||
return gcd(a-b, b);
|
|
||||||
return gcd(a, b-a);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||||
return starting_multiple_.at(v);
|
return starting_multiple_.at(v);
|
||||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
|
auto cache = [this,v](unsigned value){
|
||||||
|
return add_to_cache(v, value, starting_multiple_);
|
||||||
|
};
|
||||||
// has metadata
|
// has metadata
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||||
@@ -185,15 +214,16 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
|||||||
return cache(gcd(lhs, rhs));
|
return cache(gcd(lhs, rhs));
|
||||||
if(x->is_int_div())
|
if(x->is_int_div())
|
||||||
return cache(std::max(lhs / rhs, 1));
|
return cache(std::max(lhs / rhs, 1));
|
||||||
if(x->is_int_rem())
|
if(x->is_int_rem() && rhs > 1)
|
||||||
return cache(std::max(lhs % rhs, 1));
|
return cache(gcd(lhs, rhs));
|
||||||
if(x->is_shl())
|
if(x->is_shl())
|
||||||
return cache(lhs << rhs);
|
return cache(lhs << rhs);
|
||||||
if(x->is_shr())
|
if(x->is_shr())
|
||||||
return cache(std::max(lhs >> rhs, 1));
|
return cache(std::max(lhs >> rhs, 1));
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
if(auto *x = dynamic_cast<ir::constant_int*>(v)){
|
||||||
return cache(x->get_value());
|
return cache(x->get_value());
|
||||||
|
}
|
||||||
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
||||||
return cache(x->get_first()->get_value());
|
return cache(x->get_first()->get_value());
|
||||||
}
|
}
|
||||||
@@ -270,7 +300,6 @@ void alignment_info::run(ir::module &mod) {
|
|||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *i: block->get_inst_list()){
|
for(ir::instruction *i: block->get_inst_list()){
|
||||||
populate_max_contiguous(i);
|
populate_max_contiguous(i);
|
||||||
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -233,10 +233,15 @@ void tune::run(ir::module &mod) {
|
|||||||
for(ir::instruction *i : block->get_inst_list()){
|
for(ir::instruction *i : block->get_inst_list()){
|
||||||
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
|
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
|
||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(auto *ld = dynamic_cast<ir::load_inst*>(i))
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
if(i->get_type()->is_tile_ty()){
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||||
|
if(addr_space < 4){
|
||||||
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
|
||||||
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
|
@@ -51,8 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
|||||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||||
jit->add_module(name_.c_str(), src.c_str(), {32, 128, 16, 128, 2, 2, 2, 2, 4, 4, 32, 8, 4, 1});
|
|
||||||
}
|
}
|
||||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||||
|
@@ -113,8 +113,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
|||||||
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
|
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
|
||||||
int32 ridx = get_range_id(0);
|
int32 ridx = get_range_id(0);
|
||||||
int32 ridy = get_range_id(1);
|
int32 ridy = get_range_id(1);
|
||||||
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
int32 rxa[TM] = ridx * TM + (0 ... TM);
|
||||||
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
int32 ryb[TN] = ridy * TN + (0 ... TN);
|
||||||
int32 rka[TK] = 0 ... TK;
|
int32 rka[TK] = 0 ... TK;
|
||||||
int32 rkb[TK] = 0 ... TK;
|
int32 rkb[TK] = 0 ... TK;
|
||||||
fp32 c[TM, TN] = 0;
|
fp32 c[TM, TN] = 0;
|
||||||
|
@@ -27,7 +27,7 @@ shift::shift(int B, int C,
|
|||||||
layout_(layout){
|
layout_(layout){
|
||||||
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||||
// max number of channels
|
// max number of channels
|
||||||
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 16;
|
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32;
|
||||||
MAX_C_ = 8192 + TK_;
|
MAX_C_ = 8192 + TK_;
|
||||||
// activation sizes
|
// activation sizes
|
||||||
CD_ = AD_ / stride_d_;
|
CD_ = AD_ / stride_d_;
|
||||||
@@ -223,7 +223,7 @@ void shift::deinit_impl() {
|
|||||||
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||||
std::vector<driver::buffer *> args,
|
std::vector<driver::buffer *> args,
|
||||||
runtime::launch_information info) {
|
runtime::launch_information info) {
|
||||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN");
|
||||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||||
unsigned num_locks = grid_0 * grid_1;
|
unsigned num_locks = grid_0 * grid_1;
|
||||||
@@ -278,6 +278,8 @@ void shift::triton_c_src(std::ostream &os) const {
|
|||||||
std::string usea = AT_ ? "trans(a)" : "a";
|
std::string usea = AT_ ? "trans(a)" : "a";
|
||||||
std::string useb = BT_ ? "trans(b)" : "b";
|
std::string useb = BT_ ? "trans(b)" : "b";
|
||||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||||
|
std::string stride_h = std::to_string(stride_h_);
|
||||||
|
std::string stride_w = std::to_string(stride_w_);
|
||||||
if(AT_){
|
if(AT_){
|
||||||
std::swap(AS0, AS1);
|
std::swap(AS0, AS1);
|
||||||
std::swap(bca0, bca1);
|
std::swap(bca0, bca1);
|
||||||
@@ -290,6 +292,11 @@ void shift::triton_c_src(std::ostream &os) const {
|
|||||||
std::string BS = BS0 + ", " + BS1;
|
std::string BS = BS0 + ", " + BS1;
|
||||||
bool is_chwn = layout_ == CHWN;
|
bool is_chwn = layout_ == CHWN;
|
||||||
|
|
||||||
|
std::string lda_b = is_chwn ? "1" : "lda_b";
|
||||||
|
std::string ldb_b = is_chwn ? "1" : "ldb_b";
|
||||||
|
std::string ldc_b = is_chwn ? "1" : "ldc_b";
|
||||||
|
|
||||||
|
|
||||||
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
|
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
|
||||||
std::string B = std::to_string(B_);
|
std::string B = std::to_string(B_);
|
||||||
std::string CW = std::to_string(ICW_);
|
std::string CW = std::to_string(ICW_);
|
||||||
@@ -317,7 +324,7 @@ const tunable int32 TM = {16, 32, 64, 128};
|
|||||||
const tunable int32 TN = {16, 32, 64, 128};
|
const tunable int32 TN = {16, 32, 64, 128};
|
||||||
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
|
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
|
||||||
if(op_ == WGRAD)
|
if(op_ == WGRAD)
|
||||||
result += "const tunable int32 GZ = {1, 4, 16};";
|
result += "const tunable int32 GZ = {1};";
|
||||||
else
|
else
|
||||||
result += "const tunable int32 GZ = {1};";
|
result += "const tunable int32 GZ = {1};";
|
||||||
|
|
||||||
@@ -329,30 +336,27 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
|||||||
)" + c_ty_ + R"( *C,
|
)" + c_ty_ + R"( *C,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 stride_h, int32 stride_w,
|
int32 stride_h, int32 stride_w,
|
||||||
multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c,
|
multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c,
|
||||||
multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c,
|
multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c,
|
||||||
multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c,
|
multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c,
|
||||||
int32 NB,
|
int32 NB,
|
||||||
int32 AH, int32 AW,
|
int32 AH, int32 AW,
|
||||||
int32 BH, int32 BW,
|
int32 BH, int32 BW,
|
||||||
int32 CH, int32 CW,
|
int32 CH, int32 CW,
|
||||||
int32* locks, int32 grid0, int32 grid1, int32 grid2) {
|
int32* locks, int32 grid0, int32 grid1, int32 grid2) {
|
||||||
int32 rxa[TM] = get_global_range[TM](0);
|
int32 ridx = get_range_id(0);
|
||||||
int32 ryb[TN] = get_global_range[TN](1);
|
int32 ridy = get_range_id(1);
|
||||||
int32 rz = get_global_range[1](2);
|
int32 rz = get_range_id(2);
|
||||||
|
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
||||||
|
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
||||||
int32 rka[TK] = 0 ... TK;
|
int32 rka[TK] = 0 ... TK;
|
||||||
int32 rkb[TK] = 0 ... TK;
|
int32 rkb[TK] = 0 ... TK;
|
||||||
fp32 acc[TM, TN] = 0;
|
fp32 acc[TM, TN] = 0;
|
||||||
int32 pad_h = BH / 2;
|
int32 pad_h = BH / 2;
|
||||||
int32 pad_w = BW / 2;
|
int32 pad_w = BW / 2;)";
|
||||||
int32 div = K / grid2;
|
|
||||||
int32 rem = K % grid2;
|
|
||||||
K = select(rz < rem, div - 1, div);
|
|
||||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
|
|
||||||
if(op_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
result += R"(
|
result += R"(
|
||||||
rka = rka + offk;
|
|
||||||
rkb = rkb + offk;
|
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,31 +364,26 @@ if(op_ == WGRAD){
|
|||||||
if(op_ == FPROP){
|
if(op_ == FPROP){
|
||||||
result +=
|
result +=
|
||||||
compute_bhw("ra", "TM", "rxa") + R"(
|
compute_bhw("ra", "TM", "rxa") + R"(
|
||||||
raw = raw * stride_w;
|
raw = raw * )" + stride_w + R"(;
|
||||||
rah = rah * stride_h;
|
rah = rah * )" + stride_h + R"(;
|
||||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||||
__constant__ int32* pd[TK] = delta_a + rka;
|
__constant__ int32* pd[TK] = delta_a + rka;
|
||||||
multiple_of(4) int32 d[TK] = *pd;
|
multiple_of(8) int32 d[TK] = *pd;
|
||||||
int32 offa1[TM, TK] = d[newaxis, :];)";
|
int32 offa1[TM, TK] = d[newaxis, :];)";
|
||||||
}
|
}
|
||||||
if(op_ == BPROP){
|
if(op_ == BPROP){
|
||||||
result +=
|
result +=
|
||||||
compute_bhw("ra", "TM", "rxa") + R"(
|
compute_bhw("ra", "TM", "rxa") + R"(
|
||||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||||
}
|
}
|
||||||
if(op_ == WGRAD && layout_ == CHWN){
|
if(op_ == WGRAD){
|
||||||
result += R"(
|
|
||||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
|
||||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
|
||||||
}
|
|
||||||
if(op_ == WGRAD && layout_ == NCHW){
|
|
||||||
result +=
|
result +=
|
||||||
compute_bhw("ra", "TK", "rka") + R"(
|
compute_bhw("ra", "TK", "rka") + R"(
|
||||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||||
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||||
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,11 +402,11 @@ if(op_ == WGRAD){
|
|||||||
result +=
|
result +=
|
||||||
compute_bhw("rb", "TK", "rkb") + R"(
|
compute_bhw("rb", "TK", "rkb") + R"(
|
||||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||||
int32 d[TN] = *pd;
|
multiple_of(8) int32 d[TN] = *pd;
|
||||||
int32 shift[TK, TN] = d[newaxis, :];
|
multiple_of(8) int32 shift[TK, TN] = d[newaxis, :];
|
||||||
rbw = rbw * stride_w;
|
rbw = rbw * )" + stride_w + R"(;
|
||||||
rbh = rbh * stride_h;
|
rbh = rbh * )" + stride_h + R"(;
|
||||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||||
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
|
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
|
||||||
}
|
}
|
||||||
@@ -416,8 +415,8 @@ if(op_ == WGRAD){
|
|||||||
result += R"(
|
result += R"(
|
||||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
||||||
int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(;
|
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
|
||||||
int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(;
|
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
|
||||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||||
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
||||||
for(int32 k = K; k > 0; k = k - TK){
|
for(int32 k = K; k > 0; k = k - TK){
|
||||||
@@ -436,15 +435,11 @@ if(op_ == BPROP){
|
|||||||
result += R"(
|
result += R"(
|
||||||
pa = pa + TK * lda_c;)";
|
pa = pa + TK * lda_c;)";
|
||||||
}
|
}
|
||||||
if(op_ == WGRAD && layout_ == CHWN){
|
if(op_ == WGRAD){
|
||||||
result += R"(
|
|
||||||
pa = pa + TK;)";
|
|
||||||
}
|
|
||||||
if(op_ == WGRAD && layout_ == NCHW){
|
|
||||||
result += R"(
|
result += R"(
|
||||||
rka = rka + TK;)"
|
rka = rka + TK;)"
|
||||||
+ compute_bhw("ra", "TK", "rka") + R"(
|
+ compute_bhw("ra", "TK", "rka") + R"(
|
||||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
|
||||||
pa = A + offa0 + offxa[:, newaxis];)";
|
pa = A + offa0 + offxa[:, newaxis];)";
|
||||||
}
|
}
|
||||||
result += R"(
|
result += R"(
|
||||||
@@ -455,9 +450,9 @@ if(op_ == WGRAD){
|
|||||||
result += R"(
|
result += R"(
|
||||||
rkb = rkb + TK;)"
|
rkb = rkb + TK;)"
|
||||||
+ compute_bhw("rb", "TK", "rkb") + R"(
|
+ compute_bhw("rb", "TK", "rkb") + R"(
|
||||||
rbw = rbw * stride_w;
|
rbw = rbw * )" + stride_w + R"(;
|
||||||
rbh = rbh * stride_h;
|
rbh = rbh * )" + stride_h + R"(;
|
||||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
|
||||||
pb = B + offb0 + offkb[:, newaxis] + shift;)";
|
pb = B + offb0 + offkb[:, newaxis] + shift;)";
|
||||||
}
|
}
|
||||||
if(op_ == FPROP){
|
if(op_ == FPROP){
|
||||||
@@ -471,21 +466,21 @@ if(op_ == BPROP){
|
|||||||
result += R"(
|
result += R"(
|
||||||
@checkb b = *pb;
|
@checkb b = *pb;
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);)";
|
int32 ryc[TN] = ridy*TN + (0 ... TN);)";
|
||||||
|
|
||||||
/* C offsets */
|
/* C offsets */
|
||||||
if(op_ == BPROP){
|
if(op_ == BPROP){
|
||||||
result +=
|
result +=
|
||||||
compute_bhw("rc", "TM", "rxc") + R"(
|
compute_bhw("rc", "TM", "rxc") + R"(
|
||||||
rcw = rcw * stride_w;
|
rcw = rcw * )" + stride_w + R"(;
|
||||||
rch = rch * stride_h;
|
rch = rch * )" + stride_h + R"(;
|
||||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
|
||||||
}
|
}
|
||||||
if(op_ == FPROP){
|
if(op_ == FPROP){
|
||||||
result +=
|
result +=
|
||||||
compute_bhw("rc", "TM", "rxc") + R"(
|
compute_bhw("rc", "TM", "rxc") + R"(
|
||||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
|
||||||
}
|
}
|
||||||
if(op_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
result += R"(
|
result += R"(
|
||||||
@@ -506,27 +501,7 @@ if(op_ == BPROP){
|
|||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
result += R"(
|
result += R"(
|
||||||
int1 has_lock = (GZ > 1) && (locks != 0);
|
@checkc *pc = c;)";
|
||||||
if(has_lock){
|
|
||||||
int32 ridx = get_range_id(0);
|
|
||||||
int32 ridy = get_range_id(1);
|
|
||||||
int32 *plock = locks + ridx + ridy*grid0;
|
|
||||||
int32 *pcount = plock + grid0*grid1;
|
|
||||||
while(__atomic_cas(plock, 0, 1) == 1);
|
|
||||||
int32 count = *pcount;
|
|
||||||
int32 countp1 = select(count == grid2 - 1, 0, count + 1);
|
|
||||||
if(count == 0) {
|
|
||||||
@checkc *pc = c;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
@checkc *pc = c + *pc;
|
|
||||||
}
|
|
||||||
*pcount = countp1;
|
|
||||||
*plock = 0;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
@checkc *pc = c;
|
|
||||||
})";
|
|
||||||
}
|
}
|
||||||
result += R"(
|
result += R"(
|
||||||
})";
|
})";
|
||||||
|
@@ -130,7 +130,7 @@ bool binary_operator::is_int_mult() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool binary_operator::is_int_add_sub() const {
|
bool binary_operator::is_int_add_sub() const {
|
||||||
return op_ == llop::Add || llop::Sub;
|
return op_ == llop::Add || op_ == llop::Sub;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user