From 72867d17d492957c9121fb096dc6f62387e880d5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 24 Jun 2019 12:37:13 -0700 Subject: [PATCH] more cleaning --- examples/python/tensorflow/dot.cpp | 13 ++++++----- include/triton/ir/function.h | 32 ++++++++++++++++++++++---- include/triton/lang/declaration.h | 35 +++++++++++++++++++++++++---- include/triton/lang/parser.y | 9 ++++++++ lib/codegen/selection.cpp | 20 ++++++++--------- lib/driver/backend.cpp | 4 ++-- lib/lang/declaration.cpp | 36 +++++++++++++++++++++--------- 7 files changed, 113 insertions(+), 36 deletions(-) diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 937309df9..9b52020f6 100644 --- a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -26,11 +26,12 @@ const tunable int32 TN = {64, 128}; const tunable int32 TK = {16}; const tunable int32 GZ = {1}; -void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B, - fp32 *C, - int32 M, int32 N, int32 K, - int32 lda, int32 ldb, int32 ldc, - int32 *locks, int32 grid0, int32 grid1) { +void matmul(restrict read_only align(4) fp16 *A, + restrict read_only align(4) fp16 *B, + align(4) fp32 *C, + int32 M, int32 N, int32 K, + int32 lda, int32 ldb, int32 ldc, + int32 *locks, int32 grid0, int32 grid1) { int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); int32 rz = get_global_range[1](2); @@ -119,7 +120,7 @@ class BlockSparseGemmOp : public OpKernel { return 2.*M*N*K / ts * 1e-3; }; // just-in-time compile source-code - jit.autotune("matmul", src, benchmark); +// jit.autotune("matmul", src, benchmark); // jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1}); // jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1}); // jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 }); diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index cc00b4a92..cb1ab1f6d 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -28,10 +28,34 @@ private: }; /* Attribute */ -enum attribute_t { +enum attribute_kind_t { readonly, writeonly, - noalias + noalias, + aligned, + multiple_of +}; + +class attribute { +public: + attribute(attribute_kind_t kind, unsigned value = 0): + kind_(kind), value_(value){} + + bool operator<(const attribute& other) const { + return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_); + } + + const attribute_kind_t get_kind() const { + return kind_; + } + + const unsigned get_value() const { + return value_; + } + +private: + attribute_kind_t kind_; + unsigned value_; }; /* Function */ @@ -44,7 +68,7 @@ class function: public global_object{ typedef blocks_t::iterator block_iterator; typedef blocks_t::const_iterator const_block_iterator; - typedef std::map> attr_map_t; + typedef std::map> attr_map_t; private: function(function_type *ty, linkage_types_t linkage, @@ -63,7 +87,7 @@ public: void insert_block(basic_block* block, basic_block *next = nullptr); // attributes - void add_attr(unsigned arg_id, attribute_t attr) { attrs_[arg_id].insert(attr); } + void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); } const attr_map_t &attrs() { return attrs_; } private: diff --git a/include/triton/lang/declaration.h b/include/triton/lang/declaration.h index 22275630c..b5f4de412 100644 --- a/include/triton/lang/declaration.h +++ b/include/triton/lang/declaration.h @@ -3,7 +3,7 @@ #include "node.h" #include - +#include namespace triton{ @@ -41,19 +41,45 @@ public: // Types class modifier: public node { - +public: + virtual bool is_cst_space() const { return false; } + virtual bool is_tunable() const { return false; } + virtual bool is_cst() const { return false; } + virtual void add_attr(ir::function* fn, size_t pos) = 0; }; -class storage_specifier: public node { +class storage_specifier: public modifier { public: storage_specifier(STORAGE_SPEC_T value): value_(value) {} STORAGE_SPEC_T value() const { return value_; } + bool is_cst_space() const { return value_ == CONSTANT_SPACE_T; } + bool is_tunable() const { return value_ == TUNABLE_T; } + bool is_cst() const { return value_ == CONST_T; } + void add_attr(ir::function* fn, size_t pos); private: const STORAGE_SPEC_T value_; }; +class alignment_specifier: public modifier { +public: + alignment_specifier(node* value): cst_((constant*)value) { } + void add_attr(ir::function* fn, size_t pos); +private: + constant* cst_; +}; + +class multiple_of_specifier: public modifier { +public: + multiple_of_specifier(node* value): cst_((constant*)value) {} + void add_attr(ir::function* fn, size_t pos); + +private: + constant* cst_; +}; + +// declaration specifier class declaration_specifier: public node{ public: virtual ir::type* type(ir::module *mod) const = 0; @@ -70,6 +96,7 @@ private: const TYPE_T ty_; }; +// declaration modifier class declaration_modifier: public declaration_specifier { public: declaration_modifier(node* mod, node *decl_spec) @@ -91,7 +118,7 @@ public: decl_((declarator*)decl) { } ir::type* type(ir::module *mod) const; - std::vector storage() const; + std::vector modifiers() const; const identifier* id() const; public: diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 21065d94f..2c942b86c 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -353,6 +353,8 @@ parameter_declaration declaration_specifiers : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); } | storage_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } + | alignment_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } + | multiple_of_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); } ; init_declarator_list @@ -385,6 +387,13 @@ storage_class_specifier | CONSTANT_SPACE { $$ = new storage_specifier(CONSTANT_SPACE_T); } ; +alignment_class_specifier + : ALIGN '(' constant ')' { $$ = new alignment_specifier($3); } + +multiple_of_class_specifier + : MULTIPLE_OF '(' constant ')' { $$ = new multiple_of_specifier($3); } + + external_declaration : function_definition { $$ = $1; } | declaration { $$ = $1; } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 31cee7e6b..fed49407f 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1074,11 +1074,12 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { } } -inline llvm::Attribute::AttrKind llvm_attr(ir::attribute_t attr) { - switch(attr){ - case ir::noalias: return llvm::Attribute::NoAlias; - case ir::readonly: return llvm::Attribute::ReadOnly; - case ir::writeonly: return llvm::Attribute::WriteOnly; +inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { + switch(attr.get_kind()){ + case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias); + case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly); + case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly); + case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value()); default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); } } @@ -1101,6 +1102,7 @@ void selection::run(ir::module &src, Module &dst) { // iterate over functions for(ir::function *fn: src.get_function_list()) { + // create LLVM function FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx); FunctionType *dst_fn_ty = fn_ty; @@ -1114,18 +1116,16 @@ void selection::run(ir::module &src, Module &dst) { dst_fn_args_ty.push_back(dst_builder.getInt32Ty()); dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false); } + // grid indices fn->get_fn_type()->get_return_ty(); Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst); - - - // set attributes for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; - for(ir::attribute_t attr: attr_pair.second) - dst_fn->addAttribute(id, llvm_attr(attr)); + for(ir::attribute attr: attr_pair.second) + dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr)); } tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn); diff --git a/lib/driver/backend.cpp b/lib/driver/backend.cpp index 9761e94e7..6f98be75c 100755 --- a/lib/driver/backend.cpp +++ b/lib/driver/backend.cpp @@ -63,7 +63,7 @@ void backend::platforms::init() { cache_.push_back(new host_platform()); } if(cache_.empty()) - throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path"); + throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path"); } void backend::platforms::get(std::vector &results) { @@ -83,7 +83,7 @@ void backend::devices::init(std::vector const & platforms) { for(driver::platform* pf: platforms) pf->devices(cache_); if(cache_.empty()) - throw std::runtime_error("ISAAC: No device available. Make sure that your platform is configured properly"); + throw std::runtime_error("Triton: No device available. Make sure that your platform is configured properly"); } void backend::devices::get(std::vector &devs) { diff --git a/lib/lang/declaration.cpp b/lib/lang/declaration.cpp index c5a23def5..b1a455099 100644 --- a/lib/lang/declaration.cpp +++ b/lib/lang/declaration.cpp @@ -49,7 +49,7 @@ ir::type* parameter::type(ir::module *mod) const { return decl_->type(mod, spec_->type(mod), {}); } -std::vector parameter::storage() const { +std::vector parameter::modifiers() const { return spec_->modifiers(); } @@ -87,7 +87,7 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons // Pointer ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - auto is_cst = [](modifier* x){ return x->value() == CONSTANT_SPACE_T; }; + auto is_cst = [](modifier* x){ return x->is_cst_space(); }; bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end(); return ir::pointer_type::get(type, is_ptr_to_const?4:1); } @@ -137,7 +137,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ ir::type *ty = decl_->type(mod, spec_->type(mod), storage); std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); - auto is_tunable = [](modifier* x){ return x->value() == TUNABLE_T; }; + auto is_tunable = [](modifier* x){ return x->is_tunable(); }; if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){ auto csts = dynamic_cast*>((node*)expr_); if(csts == nullptr) @@ -158,7 +158,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ mod->get_scope().types[name] = ty; if(auto *x = dynamic_cast(value)) mod->add_alloc(x); - auto is_cst = [](modifier* mod){ return mod->value() == CONST_T; }; + auto is_cst = [](modifier* x){ return x->is_cst(); }; if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end()) mod->set_const(name); return value; @@ -169,9 +169,9 @@ ir::type *type_name::type(ir::module *mod) const{ return decl_->type(mod, spec_->type(mod), {}); } -/* Function definition */ -ir::attribute_t get_ir_attr(modifier* mod){ - switch(mod->value()){ +/* Storage specifier */ +inline ir::attribute_kind_t get_ir_attr(STORAGE_SPEC_T spec){ + switch(spec){ case RESTRICT_T: return ir::noalias; case READONLY_T: return ir::readonly; case WRITEONLY_T: return ir::writeonly; @@ -179,15 +179,31 @@ ir::attribute_t get_ir_attr(modifier* mod){ } } +void storage_specifier::add_attr(ir::function* fn, size_t pos) { + fn->add_attr(pos, ir::attribute(get_ir_attr(value_))); +} + +/* Alignment specifier */ +void alignment_specifier::add_attr(ir::function* fn, size_t pos) { + fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value())); +} + +/* Multiple-Of specifier */ +void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) { + fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value())); +} + + +/* Function definition */ ir::value* function_definition::codegen(ir::module *mod) const{ ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers()); const std::string &name = header_->id()->name(); ir::function *fn = mod->get_or_insert_function(name, prototype); for(unsigned i = 0; i < header_->get_num_args(); i++){ parameter *param = header_->get_arg(i); - std::vector storage = param->storage(); - for(modifier* spec: storage) - fn->add_attr(1 + i, get_ir_attr(spec)); + std::vector modifiers = param->modifiers(); + for(modifier* m: modifiers) + m->add_attr(fn, 1 + i); } header_->bind_parameters(mod, fn); ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);