From f257884eb74712d7f864a9a8987daef8082cb54c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 24 Jun 2019 09:31:34 -0700 Subject: [PATCH] some cleaning --- include/triton/dnn/gemm.h | 3 +++ include/triton/lang/declaration.h | 33 +++++++++++++++++++++-------- include/triton/lang/node.h | 2 +- include/triton/lang/parser.y | 22 +++++++++---------- include/triton/lang/scanner.l | 10 +++++---- lib/codegen/shmem_allocation.cpp | 14 ++++++------- lib/lang/declaration.cpp | 35 +++++++++++++++++-------------- 7 files changed, 71 insertions(+), 48 deletions(-) diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/gemm.h index 0697ea981..e44c9631d 100644 --- a/include/triton/dnn/gemm.h +++ b/include/triton/dnn/gemm.h @@ -8,11 +8,14 @@ namespace dnn{ class gemm { public: static void init(driver::stream* stream, driver::buffer* locks); + static void set_arg(driver::kernel *kernel, driver::buffer *a, driver::buffer *b, driver::buffer *c, int32_t M, int32_t N, int32_t K, driver::buffer *locks, int32_t grid_0, int32_t grid_1); + static std::vector default_params(bool AT, bool BT); + static std::string src(bool AT, bool BT); template diff --git a/include/triton/lang/declaration.h b/include/triton/lang/declaration.h index a7dbdb97e..22275630c 100644 --- a/include/triton/lang/declaration.h +++ b/include/triton/lang/declaration.h @@ -40,34 +40,49 @@ public: }; // Types +class modifier: public node { + +}; + +class storage_specifier: public node { +public: + storage_specifier(STORAGE_SPEC_T value): value_(value) {} + STORAGE_SPEC_T value() const { return value_; } + +private: + const STORAGE_SPEC_T value_; +}; + + class declaration_specifier: public node{ public: virtual ir::type* type(ir::module *mod) const = 0; - virtual std::vector storage() const = 0; + virtual std::vector modifiers() const = 0; }; class typed_declaration_specifier: public declaration_specifier { public: typed_declaration_specifier(TYPE_T ty): ty_(ty){ } ir::type* type(ir::module *mod) const; - std::vector storage() const; + std::vector modifiers() const; private: const TYPE_T ty_; }; -class storage_declaration_specifier: public declaration_specifier { +class declaration_modifier: public declaration_specifier { public: - storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec) - : storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {} + declaration_modifier(node* mod, node *decl_spec) + : mod_((modifier*)mod), decl_spec_((declaration_specifier*)decl_spec) {} ir::type* type(ir::module *mod) const; - std::vector storage() const; + std::vector modifiers() const; private: - const STORAGE_SPEC_T storage_spec_; + modifier* mod_; const declaration_specifier* decl_spec_; }; + class declarator; class parameter: public node { public: @@ -76,7 +91,7 @@ public: decl_((declarator*)decl) { } ir::type* type(ir::module *mod) const; - std::vector storage() const; + std::vector storage() const; const identifier* id() const; public: @@ -87,7 +102,7 @@ public: /* Declarators */ class declarator: public node{ protected: - typedef std::vector storage_spec_vec_t; + typedef std::vector storage_spec_vec_t; typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t; public: diff --git a/include/triton/lang/node.h b/include/triton/lang/node.h index e689f6f16..c9bd0b011 100644 --- a/include/triton/lang/node.h +++ b/include/triton/lang/node.h @@ -23,7 +23,7 @@ class identifier; class constant; class compound_statement; class initializer; -class declaration_specifier; +class modifier; class function; // Node diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 18fc3bbed..21065d94f 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -47,7 +47,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %} %token IDENTIFIER CONSTANT STRING_LITERAL -%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE +%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE ALIGN MULTIPLE_OF %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN @@ -351,8 +351,8 @@ parameter_declaration declaration_specifiers - : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); } - | storage_class_specifier declaration_specifiers { $$ = new storage_declaration_specifier(get_storage_spec($1), $2); } + : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); } + | storage_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } ; init_declarator_list @@ -376,13 +376,13 @@ init_declarator ; storage_class_specifier - : CONST { $$ = new token(CONST_T); } - | TUNABLE { $$ = new token(TUNABLE_T); } - | KERNEL { $$ = new token(KERNEL_T); } - | RESTRICT { $$ = new token(RESTRICT_T); } - | READONLY { $$ = new token(READONLY_T); } - | WRITEONLY { $$ = new token(WRITEONLY_T); } - | CONSTANT_SPACE { $$ = new token(CONSTANT_SPACE_T); } + : CONST { $$ = new storage_specifier(CONST_T); } + | TUNABLE { $$ = new storage_specifier(TUNABLE_T); } + | KERNEL { $$ = new storage_specifier(KERNEL_T); } + | RESTRICT { $$ = new storage_specifier(RESTRICT_T); } + | READONLY { $$ = new storage_specifier(READONLY_T); } + | WRITEONLY { $$ = new storage_specifier(WRITEONLY_T); } + | CONSTANT_SPACE { $$ = new storage_specifier(CONSTANT_SPACE_T); } ; external_declaration @@ -399,7 +399,7 @@ function_definition /* -------------------------- */ translation_unit - : external_declaration { ast_root = new translation_unit($1); $$ = ast_root; } + : external_declaration { ast_root = new translation_unit($1); $$ = ast_root; } | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); } ; diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index a2cd50922..e91b25961 100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -21,11 +21,13 @@ using triton::lang::return_void; "restrict" { return return_impl(RESTRICT, yytext); } "read_only" { return return_impl(READONLY, yytext); } "write_only" { return return_impl(WRITEONLY, yytext); } +"align" { return return_impl(ALIGN, yytext); } +"multiple_of" { return return_impl(MULTIPLE_OF, yytext); } "@" { return return_impl(AT, yytext); } -"newaxis" { return return_impl(NEWAXIS, yytext); } -"if" { return return_impl(IF, yytext); } -"else" { return return_impl(ELSE, yytext); } -"for" { return return_impl(FOR, yytext); } +"newaxis" { return return_impl(NEWAXIS, yytext); } +"if" { return return_impl(IF, yytext); } +"else" { return return_impl(ELSE, yytext); } +"for" { return return_impl(FOR, yytext); } "while" { return return_impl(WHILE, yytext); } "void" { return return_impl(VOID, yytext); } "uint1" { return return_impl(UINT1, yytext); } diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp index 7940808bb..1558c663d 100644 --- a/lib/codegen/shmem_allocation.cpp +++ b/lib/codegen/shmem_allocation.cpp @@ -12,12 +12,6 @@ namespace triton{ namespace codegen{ unsigned shmem_allocation::is_ld_padded(ir::value *x) { - if(auto* phi = dynamic_cast(x)) { - unsigned result = 0; - for(unsigned i = 0; i < phi->get_num_incoming(); i++) - result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); - return result; - } if(dynamic_cast(x)) return 4; for(ir::user* user: x->get_users()) @@ -25,7 +19,13 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) { if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){ return 16; } - return 16; + if(auto* phi = dynamic_cast(x)) { + unsigned result = 0; + for(unsigned i = 0; i < phi->get_num_incoming(); i++) + result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); + return result; + } + return 0; } unsigned shmem_allocation::get_num_bytes(ir::value *x) { diff --git a/lib/lang/declaration.cpp b/lib/lang/declaration.cpp index 46fa6b597..c5a23def5 100644 --- a/lib/lang/declaration.cpp +++ b/lib/lang/declaration.cpp @@ -28,18 +28,18 @@ ir::type* typed_declaration_specifier::type(ir::module *mod) const { } } -std::vector typed_declaration_specifier::storage() const { +std::vector typed_declaration_specifier::modifiers() const { return {}; } -ir::type* storage_declaration_specifier::type(ir::module *mod) const { +ir::type* declaration_modifier::type(ir::module *mod) const { return decl_spec_->type(mod); } -std::vector storage_declaration_specifier::storage() const { - auto result = decl_spec_->storage(); - result.push_back(storage_spec_); +std::vector declaration_modifier::modifiers() const { + auto result = decl_spec_->modifiers(); + result.push_back(mod_); return result; } @@ -49,8 +49,8 @@ ir::type* parameter::type(ir::module *mod) const { return decl_->type(mod, spec_->type(mod), {}); } -std::vector parameter::storage() const { - return spec_->storage(); +std::vector parameter::storage() const { + return spec_->modifiers(); } const identifier *parameter::id() const { @@ -87,7 +87,8 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons // Pointer ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end(); + auto is_cst = [](modifier* x){ return x->value() == CONSTANT_SPACE_T; }; + bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end(); return ir::pointer_type::get(type, is_ptr_to_const?4:1); } @@ -132,11 +133,12 @@ void initializer::set_specifier(const declaration_specifier *spec) { } ir::value* initializer::codegen(ir::module * mod) const{ - std::vector storage = spec_->storage(); + std::vector storage = spec_->modifiers(); ir::type *ty = decl_->type(mod, spec_->type(mod), storage); std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); - if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ + auto is_tunable = [](modifier* x){ return x->value() == TUNABLE_T; }; + if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){ auto csts = dynamic_cast*>((node*)expr_); if(csts == nullptr) throw std::runtime_error("must specify constant list for metaparameters"); @@ -156,7 +158,8 @@ ir::value* initializer::codegen(ir::module * mod) const{ mod->get_scope().types[name] = ty; if(auto *x = dynamic_cast(value)) mod->add_alloc(x); - if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) + auto is_cst = [](modifier* mod){ return mod->value() == CONST_T; }; + if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end()) mod->set_const(name); return value; } @@ -167,8 +170,8 @@ ir::type *type_name::type(ir::module *mod) const{ } /* Function definition */ -ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){ - switch(spec){ +ir::attribute_t get_ir_attr(modifier* mod){ + switch(mod->value()){ case RESTRICT_T: return ir::noalias; case READONLY_T: return ir::readonly; case WRITEONLY_T: return ir::writeonly; @@ -177,13 +180,13 @@ ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){ } ir::value* function_definition::codegen(ir::module *mod) const{ - ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage()); + ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers()); const std::string &name = header_->id()->name(); ir::function *fn = mod->get_or_insert_function(name, prototype); for(unsigned i = 0; i < header_->get_num_args(); i++){ parameter *param = header_->get_arg(i); - std::vector storage = param->storage(); - for(STORAGE_SPEC_T spec: storage) + std::vector storage = param->storage(); + for(modifier* spec: storage) fn->add_attr(1 + i, get_ir_attr(spec)); } header_->bind_parameters(mod, fn);