[language] added alignment metadata for variables
This commit is contained in:
@@ -16,7 +16,7 @@ int main() {
|
||||
triton::jit jit(context);
|
||||
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 1024, N = 1024, K = 1024;
|
||||
int32_t M = 32768, N = 1024, K = 1024;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
|
@@ -21,6 +21,7 @@ int main() {
|
||||
int32_t BS = 32, F = 1024;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t C = 1024;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
std::vector<int32_t> shift_w(C);
|
||||
|
3067
include/triton/external/half.hpp
vendored
Normal file
3067
include/triton/external/half.hpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,8 +2,9 @@
|
||||
#define TDL_INCLUDE_IR_INSTRUCTIONS_H
|
||||
|
||||
#include <vector>
|
||||
#include "value.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
|
||||
namespace triton{
|
||||
@@ -48,12 +49,16 @@ public:
|
||||
// results
|
||||
unsigned get_num_results() const { return results_.size(); }
|
||||
value* get_result(unsigned i) { return results_.at(i); }
|
||||
|
||||
// metadata
|
||||
// Records `value` for the metadata kind `kind` on this instruction.
// An existing entry of the same kind is overwritten (std::map::operator[]).
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
private:
|
||||
basic_block *parent_;
|
||||
value *pred_;
|
||||
value *mask_pred_;
|
||||
std::vector<value*> results_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
};
|
||||
|
||||
// result reference
|
||||
|
29
include/triton/ir/metadata.h
Normal file
29
include/triton/ir/metadata.h
Normal file
@@ -0,0 +1,29 @@
|
||||
#ifndef TDL_INCLUDE_IR_METADATA_H
|
||||
#define TDL_INCLUDE_IR_METADATA_H
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Metadata */
|
||||
// A (kind, value) pair describing an extra property attached to IR values.
// Objects are created through the static get() factory; the constructor is
// private.
class metadata{
|
||||
public:
|
||||
// Kinds of metadata understood by the IR. Only `multiple_of` exists so far.
enum kind_t{
|
||||
multiple_of
|
||||
};
|
||||
|
||||
private:
|
||||
// Private: use get() to obtain instances.
metadata(kind_t kind, unsigned value);
|
||||
|
||||
public:
|
||||
// Factory returning a pointer to a metadata object holding (kind, value).
// NOTE(review): ownership of the returned pointer is not documented in this
// header — confirm who is expected to release it.
static metadata* get(kind_t kind, unsigned value);
|
||||
|
||||
private:
|
||||
// Which property this metadata describes.
kind_t kind_;
|
||||
// Unsigned payload associated with kind_.
unsigned value_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "builder.h"
|
||||
#include "metadata.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -38,6 +39,7 @@ struct scope {
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
@@ -84,6 +86,8 @@ public:
|
||||
// Register global
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
@@ -101,6 +105,7 @@ private:
|
||||
std::stack<scope> scopes_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -45,7 +45,9 @@ public:
|
||||
virtual bool is_cst_space() const { return false; }
|
||||
virtual bool is_tunable() const { return false; }
|
||||
virtual bool is_cst() const { return false; }
|
||||
virtual bool is_multiple_of() const { return false; }
|
||||
virtual void add_attr(ir::function* fn, size_t pos) = 0;
|
||||
virtual void add_metadata(ir::module* mod, std::string name) = 0;
|
||||
};
|
||||
|
||||
class storage_specifier: public modifier {
|
||||
@@ -56,6 +58,7 @@ public:
|
||||
bool is_tunable() const { return value_ == TUNABLE_T; }
|
||||
bool is_cst() const { return value_ == CONST_T; }
|
||||
void add_attr(ir::function* fn, size_t pos);
|
||||
void add_metadata(ir::module* mod, std::string name);
|
||||
|
||||
private:
|
||||
const STORAGE_SPEC_T value_;
|
||||
@@ -65,6 +68,7 @@ class alignment_specifier: public modifier {
|
||||
public:
|
||||
alignment_specifier(node* value): cst_((constant*)value) { }
|
||||
void add_attr(ir::function* fn, size_t pos);
|
||||
void add_metadata(ir::module* mod, std::string name);
|
||||
|
||||
private:
|
||||
constant* cst_;
|
||||
@@ -74,6 +78,8 @@ class multiple_of_specifier: public modifier {
|
||||
public:
|
||||
multiple_of_specifier(node* value): cst_((constant*)value) {}
|
||||
void add_attr(ir::function* fn, size_t pos);
|
||||
void add_metadata(ir::module* mod, std::string name);
|
||||
bool is_multiple_of() const { return true; }
|
||||
|
||||
private:
|
||||
constant* cst_;
|
||||
|
@@ -39,6 +39,11 @@ bool alignment_info::populate_is_constant(ir::value *v) {
|
||||
bool rhs = populate_is_constant(x->get_operand(1));
|
||||
return cache(lhs && rhs);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
bool value_true = populate_is_constant(x->get_value_true());
|
||||
bool value_false = populate_is_constant(x->get_value_false());
|
||||
return cache(value_true && value_false);
|
||||
}
|
||||
if(v->get_type()->is_tile_ty())
|
||||
return cache(false);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
@@ -97,6 +102,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
int value_true = populate_max_contiguous(x->get_value_true());
|
||||
int value_false = populate_max_contiguous(x->get_value_false());
|
||||
return cache(std::min(value_true, value_false));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
@@ -132,6 +142,12 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
|
||||
// has metadata
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return cache(multiple_of);
|
||||
}
|
||||
// arguments
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
@@ -174,6 +190,11 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){
|
||||
return cache(v->get_type()->get_tile_shapes()[0]->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
int value_true = populate_starting_multiple(x->get_value_true());
|
||||
int value_false = populate_starting_multiple(x->get_value_false());
|
||||
return cache(std::min(value_true, value_false));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
|
@@ -123,7 +123,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||
fp32 *c,
|
||||
multiple_of(4) int32 M, multiple_of(4) int32 N, multiple_of(4) int32 K,
|
||||
int32 lda,
|
||||
multiple_of(4) int32 lda,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
@@ -140,7 +140,8 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
int32 d[TK] = *pd;
|
||||
multiple_of(4) int32 d[TK];
|
||||
d = *pd;
|
||||
int32 offa1[TK] = rka*lda;
|
||||
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
|
||||
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
|
||||
|
14
lib/ir/metadata.cpp
Normal file
14
lib/ir/metadata.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#include "triton/ir/metadata.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
// Stores the given kind and value as-is; no validation is performed.
metadata::metadata(kind_t kind, unsigned value)
|
||||
: kind_(kind), value_(value) { }
|
||||
|
||||
// Factory for metadata objects. Every call allocates a fresh object with
// `new`; identical (kind, value) pairs are not shared.
// NOTE(review): nothing here releases the allocation — confirm who owns the
// returned pointer.
metadata* metadata::get(kind_t kind, unsigned value) {
|
||||
return new metadata(kind, value);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -23,6 +23,11 @@ ir::context& module::get_context() {
|
||||
|
||||
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||
values_[val_key_t{name, block}] = value;
|
||||
auto it = metadatas_.find(name);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(value))
|
||||
if(it != metadatas_.end()){
|
||||
x->set_metadata(it->second.first, it->second.second);
|
||||
}
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::value *value){
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
|
||||
|
||||
namespace triton{
|
||||
@@ -133,12 +134,12 @@ void initializer::set_specifier(const declaration_specifier *spec) {
|
||||
}
|
||||
|
||||
ir::value* initializer::codegen(ir::module * mod) const{
|
||||
std::vector<modifier*> storage = spec_->modifiers();
|
||||
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
|
||||
std::vector<modifier*> modifiers = spec_->modifiers();
|
||||
ir::type *ty = decl_->type(mod, spec_->type(mod), modifiers);
|
||||
std::string name = decl_->id()->name();
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
auto is_tunable = [](modifier* x){ return x->is_tunable(); };
|
||||
if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){
|
||||
if(std::find_if(modifiers.begin(), modifiers.end(), is_tunable) != modifiers.end()){
|
||||
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
||||
if(csts == nullptr)
|
||||
throw std::runtime_error("must specify constant list for metaparameters");
|
||||
@@ -154,12 +155,19 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
implicit_broadcast(mod, ty, value);
|
||||
}
|
||||
value->set_name(name);
|
||||
// metadata
|
||||
auto is_multiple_of = [](modifier* x){ return x->is_multiple_of(); };
|
||||
auto it = std::find_if(modifiers.begin(), modifiers.end(), is_multiple_of);
|
||||
if(it != modifiers.end())
|
||||
(*it)->add_metadata(mod, name);
|
||||
// register
|
||||
mod->set_value(name, value);
|
||||
mod->get_scope().types[name] = ty;
|
||||
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
|
||||
mod->add_alloc(x);
|
||||
// constants
|
||||
auto is_cst = [](modifier* x){ return x->is_cst(); };
|
||||
if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end())
|
||||
if(std::find_if(modifiers.begin(), modifiers.end(), is_cst) != modifiers.end())
|
||||
mod->set_const(name);
|
||||
return value;
|
||||
}
|
||||
@@ -183,16 +191,28 @@ void storage_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||
fn->add_attr(pos, ir::attribute(get_ir_attr(value_)));
|
||||
}
|
||||
|
||||
// Storage specifiers cannot be registered as metadata: this implementation
// ignores its arguments and always throws std::runtime_error.
void storage_specifier::add_metadata(ir::module*, std::string) {
|
||||
throw std::runtime_error("storage specifier is not a metadata");
|
||||
}
|
||||
|
||||
/* Alignment specifier */
|
||||
// Adds an ir::aligned attribute, carrying the value of this specifier's
// constant, to `fn` at position `pos`.
void alignment_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||
fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value()));
|
||||
}
|
||||
|
||||
// Alignment specifiers cannot be registered as metadata: both arguments are
// unused and the call always throws std::runtime_error.
void alignment_specifier::add_metadata(ir::module *mod, std::string name) {
|
||||
throw std::runtime_error("alignment specifier is not a metadata");
|
||||
}
|
||||
|
||||
/* Multiple-Of specifier */
|
||||
// Adds an ir::multiple_of attribute, carrying the value of this specifier's
// constant, to `fn` at position `pos`.
void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||
fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value()));
|
||||
}
|
||||
|
||||
// Registers a (ir::metadata::multiple_of, constant value) pair on `mod`
// under the variable name `name`.
void multiple_of_specifier::add_metadata(ir::module *mod, std::string name) {
|
||||
mod->add_metadata(name, {ir::metadata::multiple_of, cst_->value()});
|
||||
}
|
||||
|
||||
|
||||
/* Function definition */
|
||||
ir::value* function_definition::codegen(ir::module *mod) const{
|
||||
|
Reference in New Issue
Block a user