more improvements and regressions

This commit is contained in:
Philippe Tillet
2019-08-06 16:21:20 -07:00
parent 26c9849462
commit 5efdb7978e
14 changed files with 138 additions and 69 deletions

View File

@@ -26,7 +26,7 @@ struct perf_t {
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "float";
std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N);
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;
@@ -111,7 +111,7 @@ int main() {
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192}
{false, true, 128, 128, 128},
{true, true, 128, 128, 128},
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}

View File

@@ -38,6 +38,7 @@ protected:
public:
virtual uint64_t get_value() const { return value_; }
virtual std::string repr() const { return std::to_string(get_value()); }
static constant_int *get(type *ty, uint64_t value);
protected:
@@ -57,7 +58,7 @@ public:
const std::vector<unsigned>& get_space() { return space_; }
void set_space(const std::vector<unsigned> &space) { space_ = space; }
uint64_t get_value() const { assert(has_value_); return value_; }
std::string repr() const { return has_value_? std::to_string(value_) : "?" ;}
private:
std::vector<unsigned> space_;
bool has_value_;

View File

@@ -584,12 +584,18 @@ private:
class trans_inst: public builtin_inst {
public:
ir::type* get_res_ty(ir::type* in);
std::vector<constant_int*> get_default_perm(ir::type* ty);
ir::type* get_res_ty(ir::type* in, std::vector<constant_int *> perm);
std::vector<constant_int*> init_perm(ir::type* ty, const std::vector<constant_int*>& perm);
private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; }
std::string repr_impl() const {
std::string res = "trans<";
for(ir::constant_int *x: perm_)
res += x->repr() + ",";
res[res.size()-1] = '>';
return res;
}
public:
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
@@ -609,7 +615,7 @@ public:
class reduce_inst: public builtin_inst {
private:
static type* get_type(value *arg, unsigned axis);
static type* get_res_type(value *arg, unsigned axis);
private:
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);

View File

@@ -180,11 +180,12 @@ private:
class trans_expression: public builtin_expression{
public:
trans_expression(node *arg): arg_(arg) {}
trans_expression(node *arg, node *perm): arg_(arg), perm_((list<expression*>*)perm) {}
ir::value* codegen(ir::module *mod) const;
private:
node* arg_;
const list<expression*>* perm_;
};
class sqrt_expression: public builtin_expression{

View File

@@ -125,7 +125,8 @@ builtin_expression
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
| TRANS '(' expression ',' constant_expression_list ')' { $$ = new trans_expression($3, $5); }
| TRANS '(' expression ')' { $$ = new trans_expression($3, nullptr); }
| REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }

View File

@@ -8,7 +8,17 @@ namespace triton {
namespace codegen{
inline bool is_trans(ir::value *v){
return dynamic_cast<ir::trans_inst*>(v) != nullptr;
auto *x = dynamic_cast<ir::trans_inst*>(v);
if(!x)
return false;
std::vector<ir::constant_int*> perm = x->get_perm();
std::vector<ir::constant_int*> ref;
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
for(size_t i = 0; i < perm.size(); i++)
ref.push_back(ir::constant_int::get(int32_ty, i));
std::swap(ref[0], ref[1]);
// true is perm == ref
return std::equal(perm.begin(), perm.end(), ref.begin());
}
inline bool is_hmma(ir::value *v){
@@ -28,7 +38,6 @@ inline bool is_hmma(ir::value *v){
void optimize_dot::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// iterate
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
@@ -47,15 +56,12 @@ void optimize_dot::run(ir::module &mod) {
ir::value *BB = B;
if(trans_a){
AA = ((ir::trans_inst*)A)->get_operand(0);
to_delete.push_back((ir::instruction*)A);
}
if(trans_b){
BB = ((ir::trans_inst*)B)->get_operand(0);
to_delete.push_back((ir::instruction*)B);
}
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);
to_delete.push_back(dot);
}
else{
// dot(op(a), trans(b))
@@ -63,28 +69,24 @@ void optimize_dot::run(ir::module &mod) {
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
to_delete.push_back((ir::instruction*)B);
to_delete.push_back(dot);
}
// dot(op(a), b)
if(!trans_b){
// create permutations
size_t size = B->get_type()->get_tile_shapes().size();
std::vector<ir::constant_int*> perm(size);
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
for(size_t i = 0; i < size; i++)
perm[i] = ir::constant_int::get(int32_ty, i);
std::swap(perm[0], perm[1]);
// replace NN -> NT (trans)
ir::value* BB = builder.create_trans(B, perm);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
to_delete.push_back(dot);
}
}
}
}
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}

View File

@@ -42,22 +42,32 @@ void optimize_trans::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// filter transposition
// transposition
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
continue;
ir::value* op = *ops.begin();
// chains of transpositions
// TODO
// todo: chains of transpositions
// trans(phi) -> phi(trans(), trans()...)
if(dynamic_cast<ir::phi_node*>(op)){
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi);
}
}
// reductions
if(auto x = dynamic_cast<ir::reduce_inst*>(i)) {
ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(i->get_type()->get_context()), 1);
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
if(shapes[x->get_axis()] == one){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
x->replace_all_uses_with(new_red);
}
}
}
}

View File

@@ -996,8 +996,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1;
unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value();
unsigned NK = A_shapes[red_axis]->get_value();
if(NK != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
@@ -1008,18 +1009,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
// input indices
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
// add batching dimension
for(size_t i = 2; i < idx.size(); i++){
a_idx.insert(a_idx.end(), idx[i]);
b_idx.insert(b_idx.end(), idx[i]);
}
// load value
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
if(a->getType() != c_ty)
a = builder.CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
// a = ConstantFP::get(builder.getFloatTy(), 1);
// b = ConstantFP::get(builder.getFloatTy(), 1);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);

View File

@@ -67,8 +67,6 @@ void tune::init_c_graph(ir::instruction *v) {
continue;
add_constraint({reduce, current++}, {arg, i});
}
// add_constraint({reduce, 0}, {arg, 0});
// add_constraint({reduce, 1}, {arg, 1});
return;
}
else
@@ -115,7 +113,7 @@ void tune::init_c_graph(ir::instruction *v) {
}
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
else if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
@@ -124,8 +122,8 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one)
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
// add_constraint({v, i}, {A, i});
// add_constraint({v, i}, {B, i});
}
}
// Element-wise
@@ -268,35 +266,53 @@ void tune::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
continue;
if(auto *ld = dynamic_cast<ir::load_inst*>(i))
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(i->get_type()->is_tile_ty()){
ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}
}
// initialize grids
for(ir::function *fn: mod.get_function_list()){
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
for(ir::instruction *i: grids_){
auto shapes = i->get_type()->get_tile_shapes();
for(size_t k = 0; k < shapes.size(); k++)
if(shapes[k]->get_value() == 1) {
if(fragments_.at({i, k}) == STRIDED_SCAN){
params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
}
if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
}
}
}
}
void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// initialize grids
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
// number of threads
num_threads_ = get_req_num_threads(grids_.front());
}

View File

@@ -64,7 +64,9 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
else{
// params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -74,22 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1";
std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(xa)" : "xa";
std::string useb = BT_ ? "trans(xb)" : "xb";
std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
if(AT_){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
std::swap(XAS1, XAS2);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(XBS1, XBS2);
std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
@@ -98,7 +100,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, 4";
std::string XCS = "TM, TN, 1";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
@@ -146,7 +148,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
)";
std::cout << res << std::endl;
// std::cout << res << std::endl;
os << res;
}

View File

@@ -482,7 +482,8 @@ std::string retile_inst::shape_suffix(ir::type* ty){
std::string res = "[";
const auto& shapes = ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i++){
res += std::to_string(ty->get_tile_shapes()[i]->get_value());
ir::constant_int *shape_i = ty->get_tile_shapes()[i];
res += shape_i->repr();
if(i < shapes.size() - 1)
res += ", ";
}
@@ -566,26 +567,33 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
// trans instructions
//===----------------------------------------------------------------------===//
ir::type* trans_inst::get_res_ty(ir::type* ty) {
auto shapes = ty->get_tile_shapes();
std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
return tile_type::get(ty->get_scalar_ty(), shapes);
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm) {
// get argument shapes
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
// permutate argument shapes
perm = init_perm(ty, perm);
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
for(int i = 0; i < perm.size(); i++)
res_shapes[i] = arg_shapes[perm[i]->get_value()];
// construct type
return tile_type::get(ty->get_scalar_ty(), res_shapes);
}
std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) {
std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<constant_int*>& perm) {
if(!perm.empty())
return perm;
auto size = ty->get_tile_shapes().size();
ir::type* int32_ty = type::get_int32_ty(ty->get_context());
std::vector<constant_int*> result;
for(size_t i = 0; i < size; i++)
result.push_back(ir::constant_int::get(int32_ty, i + 1 % size));
result.push_back(ir::constant_int::get(int32_ty, size - 1));
for(int i = 0; i < size - 1; i++)
result.push_back(ir::constant_int::get(int32_ty, i));
return result;
}
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
perm_ = perm;
if(perm_.empty())
perm_ = get_default_perm(arg->get_type());
: builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
perm_ = init_perm(arg->get_type(), perm);
auto size = arg->get_type()->get_tile_shapes().size();
assert(perm_.size() == size);
set_operand(0, arg);
@@ -615,7 +623,7 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
//===----------------------------------------------------------------------===//
// reduce instructions
//===----------------------------------------------------------------------===//
type* reduce_inst::get_type(value *arg, unsigned axis) {
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
shapes.erase(shapes.begin() + axis);
type *scalar_ty = arg->get_type()->get_scalar_ty();
@@ -626,7 +634,7 @@ type* reduce_inst::get_type(value *arg, unsigned axis) {
}
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(get_type(arg, axis), 1, 1, name, next),
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
axis_(axis){
set_operand(0, arg);
}

View File

@@ -203,7 +203,17 @@ ir::value* select_expression::codegen(ir::module *mod) const {
// trans
ir::value* trans_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_trans(arg_->codegen(mod));
// shapes
std::vector<ir::constant_int*> perm;
if(perm_) {
for(expression *expr: perm_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
if(shape == nullptr)
throw std::runtime_error("tile shapes must be constant expressions");
perm.push_back(shape);
}
}
return mod->get_builder().create_trans(arg_->codegen(mod), perm);
}
// sqrt

View File

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
// ThreadPool pool(nthreads);
ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
// pool.enqueue(f,values);
f(values);
pool.enqueue(f,values);
// f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;