[language] added alignment metadata for variables

This commit is contained in:
Philippe Tillet
2019-06-29 13:58:46 -07:00
parent d8c3d58593
commit 9a86bc51e1
12 changed files with 3183 additions and 9 deletions

View File

@@ -16,7 +16,7 @@ int main() {
triton::jit jit(context); triton::jit jit(context);
// matrix multiplication parameters // matrix multiplication parameters
int32_t M = 1024, N = 1024, K = 1024; int32_t M = 32768, N = 1024, K = 1024;
std::vector<float> hc(M*N); std::vector<float> hc(M*N);
std::vector<float> rc(M*N); std::vector<float> rc(M*N);
std::vector<float> ha(M*K); std::vector<float> ha(M*K);

View File

@@ -21,6 +21,7 @@ int main() {
int32_t BS = 32, F = 1024; int32_t BS = 32, F = 1024;
int32_t H = 32, W = 32; int32_t H = 32, W = 32;
int32_t C = 1024; int32_t C = 1024;
// random shifts // random shifts
std::vector<int32_t> shift_h(C); std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C); std::vector<int32_t> shift_w(C);

3067
include/triton/external/half.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -2,8 +2,9 @@
#define TDL_INCLUDE_IR_INSTRUCTIONS_H #define TDL_INCLUDE_IR_INSTRUCTIONS_H
#include <vector> #include <vector>
#include "value.h" #include "triton/ir/value.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "llvm/IR/Instructions.h" #include "llvm/IR/Instructions.h"
namespace triton{ namespace triton{
@@ -48,12 +49,16 @@ public:
// results // results
unsigned get_num_results() const { return results_.size(); } unsigned get_num_results() const { return results_.size(); }
value* get_result(unsigned i) { return results_.at(i); } value* get_result(unsigned i) { return results_.at(i); }
// metadata
// Records `value` as this instruction's metadata of the given kind,
// overwriting any value previously stored for that kind.
void set_metadata(ir::metadata::kind_t kind, unsigned value) {
  metadatas_[kind] = value;
}
// Returns the value stored for metadata `kind`, or 0 when none has been set.
// The lookup goes through find() instead of operator[] so that querying an
// absent kind does not silently insert a zero entry into the map.
unsigned get_metadata(ir::metadata::kind_t kind) {
  auto it = metadatas_.find(kind);
  if(it == metadatas_.end())
    return 0;
  return it->second;
}
private: private:
basic_block *parent_; basic_block *parent_;
value *pred_; value *pred_;
value *mask_pred_; value *mask_pred_;
std::vector<value*> results_; std::vector<value*> results_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
}; };
// result reference // result reference

View File

@@ -0,0 +1,29 @@
#ifndef TDL_INCLUDE_IR_METADATA_H
#define TDL_INCLUDE_IR_METADATA_H
namespace triton{
namespace ir{
/* Metadata */
// A metadata record: a kind tag paired with an unsigned value.
class metadata {
public:
  // Kinds of metadata that can be recorded.
  enum kind_t {
    multiple_of
  };

  // Factory function; the constructor itself is private.
  static metadata* get(kind_t kind, unsigned value);

private:
  metadata(kind_t kind, unsigned value);

  kind_t kind_;
  unsigned value_;
};
}
}
#endif

View File

@@ -7,6 +7,7 @@
#include <string> #include <string>
#include <functional> #include <functional>
#include "builder.h" #include "builder.h"
#include "metadata.h"
namespace triton{ namespace triton{
@@ -38,6 +39,7 @@ struct scope {
class module { class module {
typedef std::pair<std::string, basic_block*> val_key_t; typedef std::pair<std::string, basic_block*> val_key_t;
friend class function; friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public: public:
typedef std::map<std::string, global_value*> symbols_map_t; typedef std::map<std::string, global_value*> symbols_map_t;
@@ -84,6 +86,8 @@ public:
// Register global // Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; } const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
// Associates the (kind, value) pair `x` with the variable called `name`,
// replacing any pair already registered under that name.
void add_metadata(const std::string &name, md_pair_t x) {
  metadatas_[name] = x;
}
private: private:
std::string name_; std::string name_;
@@ -101,6 +105,7 @@ private:
std::stack<scope> scopes_; std::stack<scope> scopes_;
std::vector<ir::alloc_const*> allocs_; std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_; std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
}; };
} }

View File

@@ -45,7 +45,9 @@ public:
virtual bool is_cst_space() const { return false; } virtual bool is_cst_space() const { return false; }
virtual bool is_tunable() const { return false; } virtual bool is_tunable() const { return false; }
virtual bool is_cst() const { return false; } virtual bool is_cst() const { return false; }
// Tells whether this modifier is a multiple-of specifier.
// The base implementation answers false.
virtual bool is_multiple_of() const {
  return false;
}
virtual void add_attr(ir::function* fn, size_t pos) = 0; virtual void add_attr(ir::function* fn, size_t pos) = 0;
// Registers the metadata carried by this modifier for the variable `name` in `mod`.
// Of the implementations visible in this change, multiple_of_specifier forwards to
// module::add_metadata; storage_specifier and alignment_specifier throw std::runtime_error.
virtual void add_metadata(ir::module* mod, std::string name) = 0;
}; };
class storage_specifier: public modifier { class storage_specifier: public modifier {
@@ -56,6 +58,7 @@ public:
bool is_tunable() const { return value_ == TUNABLE_T; } bool is_tunable() const { return value_ == TUNABLE_T; }
bool is_cst() const { return value_ == CONST_T; } bool is_cst() const { return value_ == CONST_T; }
void add_attr(ir::function* fn, size_t pos); void add_attr(ir::function* fn, size_t pos);
void add_metadata(ir::module* mod, std::string name);
private: private:
const STORAGE_SPEC_T value_; const STORAGE_SPEC_T value_;
@@ -65,6 +68,7 @@ class alignment_specifier: public modifier {
public: public:
alignment_specifier(node* value): cst_((constant*)value) { } alignment_specifier(node* value): cst_((constant*)value) { }
void add_attr(ir::function* fn, size_t pos); void add_attr(ir::function* fn, size_t pos);
void add_metadata(ir::module* mod, std::string name);
private: private:
constant* cst_; constant* cst_;
@@ -74,6 +78,8 @@ class multiple_of_specifier: public modifier {
public: public:
multiple_of_specifier(node* value): cst_((constant*)value) {} multiple_of_specifier(node* value): cst_((constant*)value) {}
void add_attr(ir::function* fn, size_t pos); void add_attr(ir::function* fn, size_t pos);
void add_metadata(ir::module* mod, std::string name);
// A multiple_of_specifier always identifies itself as a multiple-of modifier.
bool is_multiple_of() const {
  return true;
}
private: private:
constant* cst_; constant* cst_;

View File

@@ -39,6 +39,11 @@ bool alignment_info::populate_is_constant(ir::value *v) {
bool rhs = populate_is_constant(x->get_operand(1)); bool rhs = populate_is_constant(x->get_operand(1));
return cache(lhs && rhs); return cache(lhs && rhs);
} }
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
bool value_true = populate_is_constant(x->get_value_true());
bool value_false = populate_is_constant(x->get_value_false());
return cache(value_true && value_false);
}
if(v->get_type()->is_tile_ty()) if(v->get_type()->is_tile_ty())
return cache(false); return cache(false);
if(auto *x = dynamic_cast<ir::phi_node*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v)){
@@ -97,6 +102,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
return cache(lhs_max_contiguous); return cache(lhs_max_contiguous);
} }
} }
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
int value_true = populate_max_contiguous(x->get_value_true());
int value_false = populate_max_contiguous(x->get_value_false());
return cache(std::min(value_true, value_false));
}
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){ if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs = x->get_operand(0); ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1); ir::value* rhs = x->get_operand(1);
@@ -132,6 +142,12 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); }; auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
// has metadata
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return cache(multiple_of);
}
// arguments // arguments
if(auto *x = dynamic_cast<ir::argument*>(v)){ if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x); std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
@@ -174,6 +190,11 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){ if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){
return cache(v->get_type()->get_tile_shapes()[0]->get_value()); return cache(v->get_type()->get_tile_shapes()[0]->get_value());
} }
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
int value_true = populate_starting_multiple(x->get_value_true());
int value_false = populate_starting_multiple(x->get_value_false());
return cache(std::min(value_true, value_false));
}
if(auto *x = dynamic_cast<ir::phi_node*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion // put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1; unsigned result = 1;

View File

@@ -123,7 +123,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
restrict read_only align(16) )" << b_ty_ << R"( *b, restrict read_only align(16) )" << b_ty_ << R"( *b,
fp32 *c, fp32 *c,
multiple_of(4) int32 M, multiple_of(4) int32 N, multiple_of(4) int32 K, multiple_of(4) int32 M, multiple_of(4) int32 N, multiple_of(4) int32 K,
int32 lda, multiple_of(4) int32 lda,
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) { int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
int32 rxa[TM] = get_global_range[TM](0); int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1); int32 ryb[TN] = get_global_range[TN](1);
@@ -140,7 +140,8 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis]; int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
__constant__ int32* pd[TK] = delta + rka; __constant__ int32* pd[TK] = delta + rka;
int32 d[TK] = *pd; multiple_of(4) int32 d[TK];
d = *pd;
int32 offa1[TK] = rka*lda; int32 offa1[TK] = rka*lda;
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :]; int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc; )" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;

14
lib/ir/metadata.cpp Normal file
View File

@@ -0,0 +1,14 @@
#include "triton/ir/metadata.h"
namespace triton{
namespace ir{
metadata::metadata(kind_t kind, unsigned value)
: kind_(kind), value_(value) { }
// Allocates a new metadata object for (kind, value) and returns a pointer to it.
// NOTE(review): no owner or matching delete is visible in this file — confirm
// who is responsible for releasing the returned object.
metadata* metadata::get(kind_t kind, unsigned value) {
  metadata *result = new metadata(kind, value);
  return result;
}
}
}

View File

@@ -23,6 +23,11 @@ ir::context& module::get_context() {
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value; values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
} }
void module::set_value(const std::string& name, ir::value *value){ void module::set_value(const std::string& name, ir::value *value){

View File

@@ -5,6 +5,7 @@
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/builder.h" #include "triton/ir/builder.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/metadata.h"
namespace triton{ namespace triton{
@@ -133,12 +134,12 @@ void initializer::set_specifier(const declaration_specifier *spec) {
} }
ir::value* initializer::codegen(ir::module * mod) const{ ir::value* initializer::codegen(ir::module * mod) const{
std::vector<modifier*> storage = spec_->modifiers(); std::vector<modifier*> modifiers = spec_->modifiers();
ir::type *ty = decl_->type(mod, spec_->type(mod), storage); ir::type *ty = decl_->type(mod, spec_->type(mod), modifiers);
std::string name = decl_->id()->name(); std::string name = decl_->id()->name();
ir::value *value = ir::undef_value::get(ty); ir::value *value = ir::undef_value::get(ty);
auto is_tunable = [](modifier* x){ return x->is_tunable(); }; auto is_tunable = [](modifier* x){ return x->is_tunable(); };
if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){ if(std::find_if(modifiers.begin(), modifiers.end(), is_tunable) != modifiers.end()){
auto csts = dynamic_cast<list<constant*>*>((node*)expr_); auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
if(csts == nullptr) if(csts == nullptr)
throw std::runtime_error("must specify constant list for metaparameters"); throw std::runtime_error("must specify constant list for metaparameters");
@@ -154,12 +155,19 @@ ir::value* initializer::codegen(ir::module * mod) const{
implicit_broadcast(mod, ty, value); implicit_broadcast(mod, ty, value);
} }
value->set_name(name); value->set_name(name);
// metadata
auto is_multiple_of = [](modifier* x){ return x->is_multiple_of(); };
auto it = std::find_if(modifiers.begin(), modifiers.end(), is_multiple_of);
if(it != modifiers.end())
(*it)->add_metadata(mod, name);
// register
mod->set_value(name, value); mod->set_value(name, value);
mod->get_scope().types[name] = ty; mod->get_scope().types[name] = ty;
if(auto *x = dynamic_cast<ir::alloc_const*>(value)) if(auto *x = dynamic_cast<ir::alloc_const*>(value))
mod->add_alloc(x); mod->add_alloc(x);
// constants
auto is_cst = [](modifier* x){ return x->is_cst(); }; auto is_cst = [](modifier* x){ return x->is_cst(); };
if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end()) if(std::find_if(modifiers.begin(), modifiers.end(), is_cst) != modifiers.end())
mod->set_const(name); mod->set_const(name);
return value; return value;
} }
@@ -183,16 +191,28 @@ void storage_specifier::add_attr(ir::function* fn, size_t pos) {
fn->add_attr(pos, ir::attribute(get_ir_attr(value_))); fn->add_attr(pos, ir::attribute(get_ir_attr(value_)));
} }
// Storage specifiers cannot be registered as metadata: always throws
// std::runtime_error, ignoring both arguments.
void storage_specifier::add_metadata(ir::module* /*mod*/, std::string /*name*/) {
  throw std::runtime_error("storage specifier is not a metadata");
}
/* Alignment specifier */ /* Alignment specifier */
void alignment_specifier::add_attr(ir::function* fn, size_t pos) { void alignment_specifier::add_attr(ir::function* fn, size_t pos) {
fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value())); fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value()));
} }
void alignment_specifier::add_metadata(ir::module *mod, std::string name) {
throw std::runtime_error("alignment specifier is not a metadata");
}
/* Multiple-Of specifier */ /* Multiple-Of specifier */
void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) { void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) {
fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value())); fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value()));
} }
// Registers a multiple_of metadata entry for the variable `name` in `mod`,
// using the value of the constant attached to this specifier.
void multiple_of_specifier::add_metadata(ir::module *mod, std::string name) {
  mod->add_metadata(
      name,
      {ir::metadata::multiple_of, cst_->value()});
}
/* Function definition */ /* Function definition */
ir::value* function_definition::codegen(ir::module *mod) const{ ir::value* function_definition::codegen(ir::module *mod) const{