more improvements and regressions

This commit is contained in:
Philippe Tillet
2019-08-06 16:21:20 -07:00
parent 26c9849462
commit 5efdb7978e
14 changed files with 138 additions and 69 deletions

View File

@@ -26,7 +26,7 @@ struct perf_t {
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT; typedef float NumericT;
std::string ty = "float"; std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT); size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context(); triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N); std::vector<NumericT> hc(M*N);
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize(); stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
// benchmark triton // benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream); double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
// benchmark cublas // benchmark cublas
// NumericT alpha = 1; // NumericT alpha = 1;
// NumericT beta = 0; // NumericT beta = 0;
@@ -111,7 +111,7 @@ int main() {
std::vector<config_t> configs = { std::vector<config_t> configs = {
// {false, false, 8192, 512, 512}, // {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192} // {false, true, 8192, 8192, 8192}
{false, true, 128, 128, 128}, {true, true, 128, 128, 128},
// {false, true, 32768, 256, 512} // {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512}, // {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512} // {true, true, 8192, 512, 512}

View File

@@ -38,6 +38,7 @@ protected:
public: public:
virtual uint64_t get_value() const { return value_; } virtual uint64_t get_value() const { return value_; }
virtual std::string repr() const { return std::to_string(get_value()); }
static constant_int *get(type *ty, uint64_t value); static constant_int *get(type *ty, uint64_t value);
protected: protected:
@@ -57,7 +58,7 @@ public:
const std::vector<unsigned>& get_space() { return space_; } const std::vector<unsigned>& get_space() { return space_; }
void set_space(const std::vector<unsigned> &space) { space_ = space; } void set_space(const std::vector<unsigned> &space) { space_ = space; }
uint64_t get_value() const { assert(has_value_); return value_; } uint64_t get_value() const { assert(has_value_); return value_; }
std::string repr() const { return has_value_? std::to_string(value_) : "?" ;}
private: private:
std::vector<unsigned> space_; std::vector<unsigned> space_;
bool has_value_; bool has_value_;

View File

@@ -584,12 +584,18 @@ private:
class trans_inst: public builtin_inst { class trans_inst: public builtin_inst {
public: public:
ir::type* get_res_ty(ir::type* in); ir::type* get_res_ty(ir::type* in, std::vector<constant_int *> perm);
std::vector<constant_int*> get_default_perm(ir::type* ty); std::vector<constant_int*> init_perm(ir::type* ty, const std::vector<constant_int*>& perm);
private: private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next); trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; } std::string repr_impl() const {
std::string res = "trans<";
for(ir::constant_int *x: perm_)
res += x->repr() + ",";
res[res.size()-1] = '>';
return res;
}
public: public:
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
@@ -609,7 +615,7 @@ public:
class reduce_inst: public builtin_inst { class reduce_inst: public builtin_inst {
private: private:
static type* get_type(value *arg, unsigned axis); static type* get_res_type(value *arg, unsigned axis);
private: private:
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next); reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);

View File

@@ -180,11 +180,12 @@ private:
// AST node for the trans(...) builtin. The row pairs below are the old and new
// columns of the diff: the new constructor additionally takes `perm`, which it
// stores in perm_ after casting it to a list<expression*>*.
// NOTE(review): the grammar passes nullptr for `perm` when trans() is given a
// single argument, so perm_ may be null — confirm codegen() handles that.
class trans_expression: public builtin_expression{ class trans_expression: public builtin_expression{
public: public:
trans_expression(node *arg): arg_(arg) {} trans_expression(node *arg, node *perm): arg_(arg), perm_((list<expression*>*)perm) {}
ir::value* codegen(ir::module *mod) const; ir::value* codegen(ir::module *mod) const;
private: private:
// expression being transposed
node* arg_; node* arg_;
// optional permutation list (new column only)
const list<expression*>* perm_;
}; };
class sqrt_expression: public builtin_expression{ class sqrt_expression: public builtin_expression{

View File

@@ -125,7 +125,8 @@ builtin_expression
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); } | SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
| TRANS '(' expression ')' { $$ = new trans_expression($3); } | TRANS '(' expression ',' constant_expression_list ')' { $$ = new trans_expression($3, $5); }
| TRANS '(' expression ')' { $$ = new trans_expression($3, nullptr); }
| REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);} | REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); } | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }

View File

@@ -8,7 +8,17 @@ namespace triton {
namespace codegen{ namespace codegen{
// Returns true when `v` is a trans_inst whose permutation exchanges the first
// two axes and leaves every other axis where it is; false for any other value.
inline bool is_trans(ir::value *v){
  auto *x = dynamic_cast<ir::trans_inst*>(v);
  if(!x)
    return false;
  const std::vector<ir::constant_int*> perm = x->get_perm();
  // With fewer than two axes there is nothing to exchange, and indexing
  // ref[1] below would be out of range.
  if(perm.size() < 2)
    return false;
  // Build the reference permutation {1, 0, 2, 3, ...}.
  std::vector<ir::constant_int*> ref;
  ref.reserve(perm.size());
  ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
  for(size_t i = 0; i < perm.size(); i++)
    ref.push_back(ir::constant_int::get(int32_ty, i));
  std::swap(ref[0], ref[1]);
  // true if perm == ref. This compares the constant_int pointers themselves;
  // NOTE(review): presumably constant_int::get returns the same object for an
  // equal (type, value) pair — confirm.
  return std::equal(perm.begin(), perm.end(), ref.begin());
}
inline bool is_hmma(ir::value *v){ inline bool is_hmma(ir::value *v){
@@ -28,7 +38,6 @@ inline bool is_hmma(ir::value *v){
void optimize_dot::run(ir::module &mod) { void optimize_dot::run(ir::module &mod) {
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// iterate // iterate
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
@@ -47,15 +56,12 @@ void optimize_dot::run(ir::module &mod) {
ir::value *BB = B; ir::value *BB = B;
if(trans_a){ if(trans_a){
AA = ((ir::trans_inst*)A)->get_operand(0); AA = ((ir::trans_inst*)A)->get_operand(0);
to_delete.push_back((ir::instruction*)A);
} }
if(trans_b){ if(trans_b){
BB = ((ir::trans_inst*)B)->get_operand(0); BB = ((ir::trans_inst*)B)->get_operand(0);
to_delete.push_back((ir::instruction*)B);
} }
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b)); ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt); dot->replace_all_uses_with(dot_atbt);
to_delete.push_back(dot);
} }
else{ else{
// dot(op(a), trans(b)) // dot(op(a), trans(b))
@@ -63,28 +69,24 @@ void optimize_dot::run(ir::module &mod) {
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0); ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D)); ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT); dot->replace_all_uses_with(NT);
to_delete.push_back((ir::instruction*)B);
to_delete.push_back(dot);
} }
// dot(op(a), b) // dot(op(a), b)
if(!trans_b){ if(!trans_b){
// create permutations
size_t size = B->get_type()->get_tile_shapes().size(); size_t size = B->get_type()->get_tile_shapes().size();
std::vector<ir::constant_int*> perm(size); std::vector<ir::constant_int*> perm(size);
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context()); ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
for(size_t i = 0; i < size; i++) for(size_t i = 0; i < size; i++)
perm[i] = ir::constant_int::get(int32_ty, i); perm[i] = ir::constant_int::get(int32_ty, i);
std::swap(perm[0], perm[1]); std::swap(perm[0], perm[1]);
// replace NN -> NT (trans)
ir::value* BB = builder.create_trans(B, perm); ir::value* BB = builder.create_trans(B, perm);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D)); ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT); dot->replace_all_uses_with(NT);
to_delete.push_back(dot);
} }
} }
} }
} }
for(ir::instruction* i: to_delete)
i->erase_from_parent();
} }
} }

View File

@@ -42,22 +42,32 @@ void optimize_trans::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){ for(ir::instruction* i: block->get_inst_list()){
// filter transposition // transposition
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) { if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
auto users = trans->get_users(); auto users = trans->get_users();
auto ops = trans->ops(); auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1) if(users.size() > 1 || ops.size() > 1)
continue; continue;
ir::value* op = *ops.begin(); ir::value* op = *ops.begin();
// chains of transpositions // todo: chains of transpositions
// TODO
// trans(phi) -> phi(trans(), trans()...) // trans(phi) -> phi(trans(), trans()...)
if(dynamic_cast<ir::phi_node*>(op)){ if(dynamic_cast<ir::phi_node*>(op)){
ir::value* new_phi = replace_phi(op, builder, trans->get_perm()); ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi); trans->replace_all_uses_with(new_phi);
} }
} }
// reductions
if(auto x = dynamic_cast<ir::reduce_inst*>(i)) {
ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(i->get_type()->get_context()), 1);
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
if(shapes[x->get_axis()] == one){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
x->replace_all_uses_with(new_red);
}
}
} }
} }

View File

@@ -996,8 +996,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
distributed_tile *TC = (distributed_tile*)tmap_.at(C); distributed_tile *TC = (distributed_tile*)tmap_.at(C);
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx); Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1; size_t red_axis = dot->is_a_trans() ? 0 : 1;
unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value(); unsigned NK = A_shapes[red_axis]->get_value();
if(NK != 1) if(NK != 1)
{ {
shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TA = (shared_tile*)tmap_.at(A);
@@ -1008,18 +1009,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->for_each([&](indices_t idx){ result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx); Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){ for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]}; // input indices
indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]}; indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT) if(AT)
std::swap(a_idx[0], a_idx[1]); std::swap(a_idx[0], a_idx[1]);
if(BT) if(BT)
std::swap(b_idx[0], b_idx[1]); std::swap(b_idx[0], b_idx[1]);
// add batching dimension
for(size_t i = 2; i < idx.size(); i++){
a_idx.insert(a_idx.end(), idx[i]);
b_idx.insert(b_idx.end(), idx[i]);
}
// load value
Value *a = TA->get_value(a_idx); Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx); Value *b = TB->get_value(b_idx);
if(a->getType() != c_ty) if(a->getType() != c_ty)
a = builder.CreateFPCast(a, c_ty); a = builder.CreateFPCast(a, c_ty);
if(b->getType() != c_ty) if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty); b = builder.CreateFPCast(b, c_ty);
// a = ConstantFP::get(builder.getFloatTy(), 1);
// b = ConstantFP::get(builder.getFloatTy(), 1);
res = builder.CreateCall(f_mul_add, {a, b, res}); res = builder.CreateCall(f_mul_add, {a, b, res});
} }
result->set_value(idx, res); result->set_value(idx, res);

View File

@@ -67,8 +67,6 @@ void tune::init_c_graph(ir::instruction *v) {
continue; continue;
add_constraint({reduce, current++}, {arg, i}); add_constraint({reduce, current++}, {arg, i});
} }
// add_constraint({reduce, 0}, {arg, 0});
// add_constraint({reduce, 1}, {arg, 1});
return; return;
} }
else else
@@ -115,7 +113,7 @@ void tune::init_c_graph(ir::instruction *v) {
} }
} }
// Matrix multiplication // Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){ else if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0); ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1); ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2); ir::value *D = v->get_operand(2);
@@ -124,8 +122,8 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 2; i < shapes.size(); i++){ for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one) if(shapes[i] == one)
static_params_.insert({{v, i}, 1}); static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i}); // add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i}); // add_constraint({v, i}, {B, i});
} }
} }
// Element-wise // Element-wise
@@ -268,35 +266,53 @@ void tune::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){ for(ir::instruction *i : block->get_inst_list()){
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN) if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
continue; continue;
if(auto *ld = dynamic_cast<ir::load_inst*>(i))
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(i->get_type()->is_tile_ty()){ if(i->get_type()->is_tile_ty()){
ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty(); ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space(); size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){ if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp; *params_.at(i).at("nts.d0") = *tmp;
} }
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1)); std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1)); std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp1; *params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2; *params_.at(i).at("nts.d1") = *tmp2;
} }
} }
// initialize grids
for(ir::function *fn: mod.get_function_list()){
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
for(ir::instruction *i: grids_){
auto shapes = i->get_type()->get_tile_shapes();
for(size_t k = 0; k < shapes.size(); k++)
if(shapes[k]->get_value() == 1) {
if(fragments_.at({i, k}) == STRIDED_SCAN){
params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
}
if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
}
}
}
} }
// Sets num_threads_ to get_req_num_threads() of the first entry of grids_.
// NOTE(review): the single-column rows below (the create_grids loop and the
// "number of threads" comment) appear to be the removed side of this diff —
// confirm against the resulting file.
// NOTE(review): grids_.front() requires grids_ to be non-empty — confirm that
// it is populated before init() is called.
void tune::init(ir::module &mod) { void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// initialize grids
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
// number of threads
num_threads_ = get_req_num_threads(grids_.front()); num_threads_ = get_req_num_threads(grids_.front());
} }

View File

@@ -64,7 +64,9 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
else{ else{
// params_t params = heuristics(); // params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str()); // params_t params = jit->get_valid(name_.c_str(), src.c_str());
params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1}; // params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT
jit->add_module(name_.c_str(), src.c_str(), params); jit->add_module(name_.c_str(), src.c_str(), params);
} }
triton::driver::kernel* kernel = jit->get_function(name_.c_str()); triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -74,22 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::triton_c_src(std::ostream &os) const { void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK"; std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN"; std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4"; std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1";
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4"; std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = ""; std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb"; std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(xa)" : "xa"; std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
std::string useb = BT_ ? "trans(xb)" : "xb"; std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
if(AT_){ if(AT_){
std::swap(AS0, AS1); std::swap(AS0, AS1);
std::swap(XAS0, XAS1); std::swap(XAS0, XAS1);
std::swap(XAS1, XAS2);
std::swap(bca0, bca1); std::swap(bca0, bca1);
std::swap(lda0, lda1); std::swap(lda0, lda1);
} }
if(BT_){ if(BT_){
std::swap(BS0, BS1); std::swap(BS0, BS1);
std::swap(XBS1, XBS2);
std::swap(XBS0, XBS1); std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1); std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1); std::swap(ldb0, ldb1);
@@ -98,7 +100,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1; std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2; std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2; std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, 4"; std::string XCS = "TM, TN, 1";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")"; std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res = std::string res =
@@ -146,7 +148,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
} }
)"; )";
std::cout << res << std::endl; // std::cout << res << std::endl;
os << res; os << res;
} }

View File

@@ -482,7 +482,8 @@ std::string retile_inst::shape_suffix(ir::type* ty){
std::string res = "["; std::string res = "[";
const auto& shapes = ty->get_tile_shapes(); const auto& shapes = ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i++){ for(unsigned i = 0; i < shapes.size(); i++){
res += std::to_string(ty->get_tile_shapes()[i]->get_value()); ir::constant_int *shape_i = ty->get_tile_shapes()[i];
res += shape_i->repr();
if(i < shapes.size() - 1) if(i < shapes.size() - 1)
res += ", "; res += ", ";
} }
@@ -566,26 +567,33 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
// trans instructions // trans instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Computes the result type of a transposition of `ty` under `perm`.
// An empty `perm` selects the default permutation produced by init_perm().
// Result axis i takes the shape of argument axis perm[i]; the scalar type is
// unchanged.
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm) {
  // get argument shapes
  ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
  // resolve the permutation (explicit one, or the default when empty)
  perm = init_perm(ty, perm);
  // The constructor calls this from its initializer list, i.e. before the
  // size assert in its body runs, so the permutation is validated here
  // before it is used as an index.
  assert(perm.size() == arg_shapes.size());
  // permute argument shapes
  ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
  for(size_t i = 0; i < perm.size(); i++){
    const auto src = perm[i]->get_value();
    assert(src < arg_shapes.size());
    res_shapes[i] = arg_shapes[src];
  }
  // construct type
  return tile_type::get(ty->get_scalar_ty(), res_shapes);
}
std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) { std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<constant_int*>& perm) {
if(!perm.empty())
return perm;
auto size = ty->get_tile_shapes().size(); auto size = ty->get_tile_shapes().size();
ir::type* int32_ty = type::get_int32_ty(ty->get_context()); ir::type* int32_ty = type::get_int32_ty(ty->get_context());
std::vector<constant_int*> result; std::vector<constant_int*> result;
for(size_t i = 0; i < size; i++) result.push_back(ir::constant_int::get(int32_ty, size - 1));
result.push_back(ir::constant_int::get(int32_ty, i + 1 % size)); for(int i = 0; i < size - 1; i++)
result.push_back(ir::constant_int::get(int32_ty, i));
return result; return result;
} }
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next) trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) { : builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
perm_ = perm; perm_ = init_perm(arg->get_type(), perm);
if(perm_.empty())
perm_ = get_default_perm(arg->get_type());
auto size = arg->get_type()->get_tile_shapes().size(); auto size = arg->get_type()->get_tile_shapes().size();
assert(perm_.size() == size); assert(perm_.size() == size);
set_operand(0, arg); set_operand(0, arg);
@@ -615,7 +623,7 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// reduce instructions // reduce instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
type* reduce_inst::get_type(value *arg, unsigned axis) { type* reduce_inst::get_res_type(value *arg, unsigned axis) {
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes(); ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
shapes.erase(shapes.begin() + axis); shapes.erase(shapes.begin() + axis);
type *scalar_ty = arg->get_type()->get_scalar_ty(); type *scalar_ty = arg->get_type()->get_scalar_ty();
@@ -626,7 +634,7 @@ type* reduce_inst::get_type(value *arg, unsigned axis) {
} }
// Constructs a reduce instruction over `arg` along `axis`. The builtin_inst
// base is initialised with the type computed from (arg, axis) — via get_type
// in the old column, get_res_type in the new one — then `axis` is stored in
// axis_ and `arg` is installed as operand 0.
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next) reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(get_type(arg, axis), 1, 1, name, next), : builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
axis_(axis){ axis_(axis){
set_operand(0, arg); set_operand(0, arg);
} }

View File

@@ -203,7 +203,17 @@ ir::value* select_expression::codegen(ir::module *mod) const {
// trans // trans
// Lowers a trans(...) expression to IR.
// Every entry of perm_ (when perm_ is non-null) is code-generated and must
// evaluate to an ir::constant_int; the collected indices and the generated
// argument are handed to the builder's create_trans(). When perm_ is null the
// permutation passed on is empty.
// Throws std::runtime_error if a permutation entry is not a constant integer.
ir::value* trans_expression::codegen(ir::module *mod) const {
  // permutation indices
  std::vector<ir::constant_int*> perm;
  if(perm_) {
    for(expression *expr: perm_->values()){
      auto *index = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
      if(index == nullptr)
        throw std::runtime_error("permutation indices must be constant expressions");
      perm.push_back(index);
    }
  }
  return mod->get_builder().create_trans(arg_->codegen(mod), perm);
}
// sqrt // sqrt

View File

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size(); size_t D = ranges.size();
std::vector<size_t> values(D, 0); std::vector<size_t> values(D, 0);
// thread pools // thread pools
// ThreadPool pool(nthreads); ThreadPool pool(nthreads);
// Start with innermost loop // Start with innermost loop
size_t i = D - 1; size_t i = D - 1;
while(true){ while(true){
// Execute function // Execute function
// pool.enqueue(f,values); pool.enqueue(f,values);
f(values); // f(values);
while(values[i]++ == ranges[i] - 1){ while(values[i]++ == ranges[i] - 1){
if(i == 0) if(i == 0)
return; return;