more improvements and regressions
This commit is contained in:
@@ -26,7 +26,7 @@ struct perf_t {
|
||||
|
||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float NumericT;
|
||||
std::string ty = "float";
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
triton::driver::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
@@ -111,7 +111,7 @@ int main() {
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
// {false, true, 8192, 8192, 8192}
|
||||
{false, true, 128, 128, 128},
|
||||
{true, true, 128, 128, 128},
|
||||
// {false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
// {true, true, 8192, 512, 512}
|
||||
|
@@ -38,6 +38,7 @@ protected:
|
||||
|
||||
public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
virtual std::string repr() const { return std::to_string(get_value()); }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
|
||||
protected:
|
||||
@@ -57,7 +58,7 @@ public:
|
||||
const std::vector<unsigned>& get_space() { return space_; }
|
||||
void set_space(const std::vector<unsigned> &space) { space_ = space; }
|
||||
uint64_t get_value() const { assert(has_value_); return value_; }
|
||||
|
||||
std::string repr() const { return has_value_? std::to_string(value_) : "?" ;}
|
||||
private:
|
||||
std::vector<unsigned> space_;
|
||||
bool has_value_;
|
||||
|
@@ -584,12 +584,18 @@ private:
|
||||
|
||||
class trans_inst: public builtin_inst {
|
||||
public:
|
||||
ir::type* get_res_ty(ir::type* in);
|
||||
std::vector<constant_int*> get_default_perm(ir::type* ty);
|
||||
ir::type* get_res_ty(ir::type* in, std::vector<constant_int *> perm);
|
||||
std::vector<constant_int*> init_perm(ir::type* ty, const std::vector<constant_int*>& perm);
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
std::string repr_impl() const {
|
||||
std::string res = "trans<";
|
||||
for(ir::constant_int *x: perm_)
|
||||
res += x->repr() + ",";
|
||||
res[res.size()-1] = '>';
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -609,7 +615,7 @@ public:
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
private:
|
||||
static type* get_type(value *arg, unsigned axis);
|
||||
static type* get_res_type(value *arg, unsigned axis);
|
||||
|
||||
private:
|
||||
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
|
||||
|
@@ -180,11 +180,12 @@ private:
|
||||
|
||||
class trans_expression: public builtin_expression{
|
||||
public:
|
||||
trans_expression(node *arg): arg_(arg) {}
|
||||
trans_expression(node *arg, node *perm): arg_(arg), perm_((list<expression*>*)perm) {}
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
node* arg_;
|
||||
const list<expression*>* perm_;
|
||||
};
|
||||
|
||||
class sqrt_expression: public builtin_expression{
|
||||
|
@@ -125,7 +125,8 @@ builtin_expression
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||
| TRANS '(' expression ',' constant_expression_list ')' { $$ = new trans_expression($3, $5); }
|
||||
| TRANS '(' expression ')' { $$ = new trans_expression($3, nullptr); }
|
||||
| REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
|
||||
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
|
||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||
|
@@ -8,7 +8,17 @@ namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
inline bool is_trans(ir::value *v){
|
||||
return dynamic_cast<ir::trans_inst*>(v) != nullptr;
|
||||
auto *x = dynamic_cast<ir::trans_inst*>(v);
|
||||
if(!x)
|
||||
return false;
|
||||
std::vector<ir::constant_int*> perm = x->get_perm();
|
||||
std::vector<ir::constant_int*> ref;
|
||||
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
|
||||
for(size_t i = 0; i < perm.size(); i++)
|
||||
ref.push_back(ir::constant_int::get(int32_ty, i));
|
||||
std::swap(ref[0], ref[1]);
|
||||
// true is perm == ref
|
||||
return std::equal(perm.begin(), perm.end(), ref.begin());
|
||||
}
|
||||
|
||||
inline bool is_hmma(ir::value *v){
|
||||
@@ -28,7 +38,6 @@ inline bool is_hmma(ir::value *v){
|
||||
|
||||
void optimize_dot::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
// iterate
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -47,15 +56,12 @@ void optimize_dot::run(ir::module &mod) {
|
||||
ir::value *BB = B;
|
||||
if(trans_a){
|
||||
AA = ((ir::trans_inst*)A)->get_operand(0);
|
||||
to_delete.push_back((ir::instruction*)A);
|
||||
}
|
||||
if(trans_b){
|
||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||
to_delete.push_back((ir::instruction*)B);
|
||||
}
|
||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||
dot->replace_all_uses_with(dot_atbt);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
else{
|
||||
// dot(op(a), trans(b))
|
||||
@@ -63,28 +69,24 @@ void optimize_dot::run(ir::module &mod) {
|
||||
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||
dot->replace_all_uses_with(NT);
|
||||
to_delete.push_back((ir::instruction*)B);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
// dot(op(a), b)
|
||||
if(!trans_b){
|
||||
// create permutations
|
||||
size_t size = B->get_type()->get_tile_shapes().size();
|
||||
std::vector<ir::constant_int*> perm(size);
|
||||
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
||||
for(size_t i = 0; i < size; i++)
|
||||
perm[i] = ir::constant_int::get(int32_ty, i);
|
||||
std::swap(perm[0], perm[1]);
|
||||
// replace NN -> NT (trans)
|
||||
ir::value* BB = builder.create_trans(B, perm);
|
||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||
dot->replace_all_uses_with(NT);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -42,22 +42,32 @@ void optimize_trans::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
// filter transposition
|
||||
// transposition
|
||||
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
|
||||
auto users = trans->get_users();
|
||||
auto ops = trans->ops();
|
||||
if(users.size() > 1 || ops.size() > 1)
|
||||
continue;
|
||||
ir::value* op = *ops.begin();
|
||||
// chains of transpositions
|
||||
// TODO
|
||||
|
||||
// todo: chains of transpositions
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
if(dynamic_cast<ir::phi_node*>(op)){
|
||||
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
}
|
||||
}
|
||||
// reductions
|
||||
if(auto x = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||
ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(i->get_type()->get_context()), 1);
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto shapes = arg->get_type()->get_tile_shapes();
|
||||
if(shapes[x->get_axis()] == one){
|
||||
builder.set_insert_point(x);
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
|
||||
x->replace_all_uses_with(new_red);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -996,8 +996,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value();
|
||||
unsigned NK = A_shapes[red_axis]->get_value();
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
@@ -1008,18 +1009,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
|
||||
// input indices
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
// add batching dimension
|
||||
for(size_t i = 2; i < idx.size(); i++){
|
||||
a_idx.insert(a_idx.end(), idx[i]);
|
||||
b_idx.insert(b_idx.end(), idx[i]);
|
||||
}
|
||||
// load value
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
if(a->getType() != c_ty)
|
||||
a = builder.CreateFPCast(a, c_ty);
|
||||
if(b->getType() != c_ty)
|
||||
b = builder.CreateFPCast(b, c_ty);
|
||||
// a = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
// b = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
|
@@ -67,8 +67,6 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
continue;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
// add_constraint({reduce, 0}, {arg, 0});
|
||||
// add_constraint({reduce, 1}, {arg, 1});
|
||||
return;
|
||||
}
|
||||
else
|
||||
@@ -115,7 +113,7 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
}
|
||||
// Matrix multiplication
|
||||
else if(dynamic_cast<ir::dot_inst*>(v)){
|
||||
else if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *A = v->get_operand(0);
|
||||
ir::value *B = v->get_operand(1);
|
||||
ir::value *D = v->get_operand(2);
|
||||
@@ -124,8 +122,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
if(shapes[i] == one)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
// add_constraint({v, i}, {A, i});
|
||||
// add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
@@ -268,35 +266,53 @@ void tune::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()){
|
||||
|
||||
|
||||
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
|
||||
continue;
|
||||
if(auto *ld = dynamic_cast<ir::load_inst*>(i))
|
||||
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(i->get_type()->is_tile_ty()){
|
||||
ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
}
|
||||
}
|
||||
|
||||
// initialize grids
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
for(ir::instruction *i: grids_){
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
if(shapes[k]->get_value() == 1) {
|
||||
if(fragments_.at({i, k}) == STRIDED_SCAN){
|
||||
params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
|
||||
params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
|
||||
}
|
||||
if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
|
||||
params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
|
||||
params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// initialize grids
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
// number of threads
|
||||
num_threads_ = get_req_num_threads(grids_.front());
|
||||
}
|
||||
|
||||
|
@@ -64,7 +64,9 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
else{
|
||||
// params_t params = heuristics();
|
||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
|
||||
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
|
||||
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
|
||||
params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -74,22 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
|
||||
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
|
||||
std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1";
|
||||
std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(xa)" : "xa";
|
||||
std::string useb = BT_ ? "trans(xb)" : "xb";
|
||||
std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
|
||||
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(XAS1, XAS2);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS1, XBS2);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
@@ -98,7 +100,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN, 4";
|
||||
std::string XCS = "TM, TN, 1";
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
@@ -146,7 +148,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
}
|
||||
)";
|
||||
|
||||
std::cout << res << std::endl;
|
||||
// std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
@@ -482,7 +482,8 @@ std::string retile_inst::shape_suffix(ir::type* ty){
|
||||
std::string res = "[";
|
||||
const auto& shapes = ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
res += std::to_string(ty->get_tile_shapes()[i]->get_value());
|
||||
ir::constant_int *shape_i = ty->get_tile_shapes()[i];
|
||||
res += shape_i->repr();
|
||||
if(i < shapes.size() - 1)
|
||||
res += ", ";
|
||||
}
|
||||
@@ -566,26 +567,33 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
|
||||
// trans instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::type* trans_inst::get_res_ty(ir::type* ty) {
|
||||
auto shapes = ty->get_tile_shapes();
|
||||
std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
|
||||
return tile_type::get(ty->get_scalar_ty(), shapes);
|
||||
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm) {
|
||||
// get argument shapes
|
||||
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
|
||||
// permutate argument shapes
|
||||
perm = init_perm(ty, perm);
|
||||
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
|
||||
for(int i = 0; i < perm.size(); i++)
|
||||
res_shapes[i] = arg_shapes[perm[i]->get_value()];
|
||||
// construct type
|
||||
return tile_type::get(ty->get_scalar_ty(), res_shapes);
|
||||
}
|
||||
|
||||
std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) {
|
||||
std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<constant_int*>& perm) {
|
||||
if(!perm.empty())
|
||||
return perm;
|
||||
auto size = ty->get_tile_shapes().size();
|
||||
ir::type* int32_ty = type::get_int32_ty(ty->get_context());
|
||||
std::vector<constant_int*> result;
|
||||
for(size_t i = 0; i < size; i++)
|
||||
result.push_back(ir::constant_int::get(int32_ty, i + 1 % size));
|
||||
result.push_back(ir::constant_int::get(int32_ty, size - 1));
|
||||
for(int i = 0; i < size - 1; i++)
|
||||
result.push_back(ir::constant_int::get(int32_ty, i));
|
||||
return result;
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
|
||||
perm_ = perm;
|
||||
if(perm_.empty())
|
||||
perm_ = get_default_perm(arg->get_type());
|
||||
: builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
|
||||
perm_ = init_perm(arg->get_type(), perm);
|
||||
auto size = arg->get_type()->get_tile_shapes().size();
|
||||
assert(perm_.size() == size);
|
||||
set_operand(0, arg);
|
||||
@@ -615,7 +623,7 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
type* reduce_inst::get_type(value *arg, unsigned axis) {
|
||||
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
||||
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
type *scalar_ty = arg->get_type()->get_scalar_ty();
|
||||
@@ -626,7 +634,7 @@ type* reduce_inst::get_type(value *arg, unsigned axis) {
|
||||
}
|
||||
|
||||
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_type(arg, axis), 1, 1, name, next),
|
||||
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
|
||||
axis_(axis){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
@@ -203,7 +203,17 @@ ir::value* select_expression::codegen(ir::module *mod) const {
|
||||
|
||||
// trans
|
||||
ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||
// shapes
|
||||
std::vector<ir::constant_int*> perm;
|
||||
if(perm_) {
|
||||
for(expression *expr: perm_->values()){
|
||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||
if(shape == nullptr)
|
||||
throw std::runtime_error("tile shapes must be constant expressions");
|
||||
perm.push_back(shape);
|
||||
}
|
||||
}
|
||||
return mod->get_builder().create_trans(arg_->codegen(mod), perm);
|
||||
}
|
||||
|
||||
// sqrt
|
||||
|
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// thread pools
|
||||
// ThreadPool pool(nthreads);
|
||||
ThreadPool pool(nthreads);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
while(true){
|
||||
// Execute function
|
||||
// pool.enqueue(f,values);
|
||||
f(values);
|
||||
pool.enqueue(f,values);
|
||||
// f(values);
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
return;
|
||||
|
Reference in New Issue
Block a user