more cleaning
This commit is contained in:
@@ -26,8 +26,9 @@ const tunable int32 TN = {64, 128};
|
|||||||
const tunable int32 TK = {16};
|
const tunable int32 TK = {16};
|
||||||
const tunable int32 GZ = {1};
|
const tunable int32 GZ = {1};
|
||||||
|
|
||||||
void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
|
void matmul(restrict read_only align(4) fp16 *A,
|
||||||
fp32 *C,
|
restrict read_only align(4) fp16 *B,
|
||||||
|
align(4) fp32 *C,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 lda, int32 ldb, int32 ldc,
|
int32 lda, int32 ldb, int32 ldc,
|
||||||
int32 *locks, int32 grid0, int32 grid1) {
|
int32 *locks, int32 grid0, int32 grid1) {
|
||||||
@@ -119,7 +120,7 @@ class BlockSparseGemmOp : public OpKernel {
|
|||||||
return 2.*M*N*K / ts * 1e-3;
|
return 2.*M*N*K / ts * 1e-3;
|
||||||
};
|
};
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
jit.autotune("matmul", src, benchmark);
|
// jit.autotune("matmul", src, benchmark);
|
||||||
// jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
|
// jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
|
||||||
// jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
|
// jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
|
||||||
// jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
|
// jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
|
||||||
|
@@ -28,10 +28,34 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
/* Attribute */
|
/* Attribute */
|
||||||
enum attribute_t {
|
enum attribute_kind_t {
|
||||||
readonly,
|
readonly,
|
||||||
writeonly,
|
writeonly,
|
||||||
noalias
|
noalias,
|
||||||
|
aligned,
|
||||||
|
multiple_of
|
||||||
|
};
|
||||||
|
|
||||||
|
class attribute {
|
||||||
|
public:
|
||||||
|
attribute(attribute_kind_t kind, unsigned value = 0):
|
||||||
|
kind_(kind), value_(value){}
|
||||||
|
|
||||||
|
bool operator<(const attribute& other) const {
|
||||||
|
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
|
||||||
|
}
|
||||||
|
|
||||||
|
const attribute_kind_t get_kind() const {
|
||||||
|
return kind_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const unsigned get_value() const {
|
||||||
|
return value_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
attribute_kind_t kind_;
|
||||||
|
unsigned value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Function */
|
/* Function */
|
||||||
@@ -44,7 +68,7 @@ class function: public global_object{
|
|||||||
typedef blocks_t::iterator block_iterator;
|
typedef blocks_t::iterator block_iterator;
|
||||||
typedef blocks_t::const_iterator const_block_iterator;
|
typedef blocks_t::const_iterator const_block_iterator;
|
||||||
|
|
||||||
typedef std::map<unsigned, std::set<attribute_t>> attr_map_t;
|
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
function(function_type *ty, linkage_types_t linkage,
|
function(function_type *ty, linkage_types_t linkage,
|
||||||
@@ -63,7 +87,7 @@ public:
|
|||||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||||
|
|
||||||
// attributes
|
// attributes
|
||||||
void add_attr(unsigned arg_id, attribute_t attr) { attrs_[arg_id].insert(attr); }
|
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||||
const attr_map_t &attrs() { return attrs_; }
|
const attr_map_t &attrs() { return attrs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include "node.h"
|
#include "node.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
|
|
||||||
@@ -41,19 +41,45 @@ public:
|
|||||||
|
|
||||||
// Types
|
// Types
|
||||||
class modifier: public node {
|
class modifier: public node {
|
||||||
|
public:
|
||||||
|
virtual bool is_cst_space() const { return false; }
|
||||||
|
virtual bool is_tunable() const { return false; }
|
||||||
|
virtual bool is_cst() const { return false; }
|
||||||
|
virtual void add_attr(ir::function* fn, size_t pos) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class storage_specifier: public node {
|
class storage_specifier: public modifier {
|
||||||
public:
|
public:
|
||||||
storage_specifier(STORAGE_SPEC_T value): value_(value) {}
|
storage_specifier(STORAGE_SPEC_T value): value_(value) {}
|
||||||
STORAGE_SPEC_T value() const { return value_; }
|
STORAGE_SPEC_T value() const { return value_; }
|
||||||
|
bool is_cst_space() const { return value_ == CONSTANT_SPACE_T; }
|
||||||
|
bool is_tunable() const { return value_ == TUNABLE_T; }
|
||||||
|
bool is_cst() const { return value_ == CONST_T; }
|
||||||
|
void add_attr(ir::function* fn, size_t pos);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const STORAGE_SPEC_T value_;
|
const STORAGE_SPEC_T value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class alignment_specifier: public modifier {
|
||||||
|
public:
|
||||||
|
alignment_specifier(node* value): cst_((constant*)value) { }
|
||||||
|
void add_attr(ir::function* fn, size_t pos);
|
||||||
|
|
||||||
|
private:
|
||||||
|
constant* cst_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class multiple_of_specifier: public modifier {
|
||||||
|
public:
|
||||||
|
multiple_of_specifier(node* value): cst_((constant*)value) {}
|
||||||
|
void add_attr(ir::function* fn, size_t pos);
|
||||||
|
|
||||||
|
private:
|
||||||
|
constant* cst_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// declaration specifier
|
||||||
class declaration_specifier: public node{
|
class declaration_specifier: public node{
|
||||||
public:
|
public:
|
||||||
virtual ir::type* type(ir::module *mod) const = 0;
|
virtual ir::type* type(ir::module *mod) const = 0;
|
||||||
@@ -70,6 +96,7 @@ private:
|
|||||||
const TYPE_T ty_;
|
const TYPE_T ty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// declaration modifier
|
||||||
class declaration_modifier: public declaration_specifier {
|
class declaration_modifier: public declaration_specifier {
|
||||||
public:
|
public:
|
||||||
declaration_modifier(node* mod, node *decl_spec)
|
declaration_modifier(node* mod, node *decl_spec)
|
||||||
@@ -91,7 +118,7 @@ public:
|
|||||||
decl_((declarator*)decl) { }
|
decl_((declarator*)decl) { }
|
||||||
|
|
||||||
ir::type* type(ir::module *mod) const;
|
ir::type* type(ir::module *mod) const;
|
||||||
std::vector<modifier*> storage() const;
|
std::vector<modifier*> modifiers() const;
|
||||||
const identifier* id() const;
|
const identifier* id() const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@@ -353,6 +353,8 @@ parameter_declaration
|
|||||||
declaration_specifiers
|
declaration_specifiers
|
||||||
: type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); }
|
: type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); }
|
||||||
| storage_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); }
|
| storage_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); }
|
||||||
|
| alignment_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); }
|
||||||
|
| multiple_of_class_specifier declaration_specifiers { $$ = new declaration_modifier($1, $2); }
|
||||||
;
|
;
|
||||||
|
|
||||||
init_declarator_list
|
init_declarator_list
|
||||||
@@ -385,6 +387,13 @@ storage_class_specifier
|
|||||||
| CONSTANT_SPACE { $$ = new storage_specifier(CONSTANT_SPACE_T); }
|
| CONSTANT_SPACE { $$ = new storage_specifier(CONSTANT_SPACE_T); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
alignment_class_specifier
|
||||||
|
: ALIGN '(' constant ')' { $$ = new alignment_specifier($3); }
|
||||||
|
|
||||||
|
multiple_of_class_specifier
|
||||||
|
: MULTIPLE_OF '(' constant ')' { $$ = new multiple_of_specifier($3); }
|
||||||
|
|
||||||
|
|
||||||
external_declaration
|
external_declaration
|
||||||
: function_definition { $$ = $1; }
|
: function_definition { $$ = $1; }
|
||||||
| declaration { $$ = $1; }
|
| declaration { $$ = $1; }
|
||||||
|
@@ -1074,11 +1074,12 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline llvm::Attribute::AttrKind llvm_attr(ir::attribute_t attr) {
|
inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
|
||||||
switch(attr){
|
switch(attr.get_kind()){
|
||||||
case ir::noalias: return llvm::Attribute::NoAlias;
|
case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias);
|
||||||
case ir::readonly: return llvm::Attribute::ReadOnly;
|
case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly);
|
||||||
case ir::writeonly: return llvm::Attribute::WriteOnly;
|
case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly);
|
||||||
|
case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value());
|
||||||
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
|
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1101,6 +1102,7 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
|
|
||||||
// iterate over functions
|
// iterate over functions
|
||||||
for(ir::function *fn: src.get_function_list()) {
|
for(ir::function *fn: src.get_function_list()) {
|
||||||
|
|
||||||
// create LLVM function
|
// create LLVM function
|
||||||
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
|
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
|
||||||
FunctionType *dst_fn_ty = fn_ty;
|
FunctionType *dst_fn_ty = fn_ty;
|
||||||
@@ -1114,18 +1116,16 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||||
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// grid indices
|
// grid indices
|
||||||
fn->get_fn_type()->get_return_ty();
|
fn->get_fn_type()->get_return_ty();
|
||||||
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// set attributes
|
// set attributes
|
||||||
for(auto attr_pair: fn->attrs()){
|
for(auto attr_pair: fn->attrs()){
|
||||||
unsigned id = attr_pair.first;
|
unsigned id = attr_pair.first;
|
||||||
for(ir::attribute_t attr: attr_pair.second)
|
for(ir::attribute attr: attr_pair.second)
|
||||||
dst_fn->addAttribute(id, llvm_attr(attr));
|
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
||||||
}
|
}
|
||||||
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
||||||
|
|
||||||
|
@@ -63,7 +63,7 @@ void backend::platforms::init() {
|
|||||||
cache_.push_back(new host_platform());
|
cache_.push_back(new host_platform());
|
||||||
}
|
}
|
||||||
if(cache_.empty())
|
if(cache_.empty())
|
||||||
throw std::runtime_error("ISAAC: No backend available. Make sure CUDA is available in your library path");
|
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
|
||||||
}
|
}
|
||||||
|
|
||||||
void backend::platforms::get(std::vector<platform *> &results) {
|
void backend::platforms::get(std::vector<platform *> &results) {
|
||||||
@@ -83,7 +83,7 @@ void backend::devices::init(std::vector<platform*> const & platforms) {
|
|||||||
for(driver::platform* pf: platforms)
|
for(driver::platform* pf: platforms)
|
||||||
pf->devices(cache_);
|
pf->devices(cache_);
|
||||||
if(cache_.empty())
|
if(cache_.empty())
|
||||||
throw std::runtime_error("ISAAC: No device available. Make sure that your platform is configured properly");
|
throw std::runtime_error("Triton: No device available. Make sure that your platform is configured properly");
|
||||||
}
|
}
|
||||||
|
|
||||||
void backend::devices::get(std::vector<device*> &devs) {
|
void backend::devices::get(std::vector<device*> &devs) {
|
||||||
|
@@ -49,7 +49,7 @@ ir::type* parameter::type(ir::module *mod) const {
|
|||||||
return decl_->type(mod, spec_->type(mod), {});
|
return decl_->type(mod, spec_->type(mod), {});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<modifier*> parameter::storage() const {
|
std::vector<modifier*> parameter::modifiers() const {
|
||||||
return spec_->modifiers();
|
return spec_->modifiers();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
|
|||||||
|
|
||||||
// Pointer
|
// Pointer
|
||||||
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
||||||
auto is_cst = [](modifier* x){ return x->value() == CONSTANT_SPACE_T; };
|
auto is_cst = [](modifier* x){ return x->is_cst_space(); };
|
||||||
bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end();
|
bool is_ptr_to_const = std::find_if(storage.begin(), storage.end(), is_cst) != storage.end();
|
||||||
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
|
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
|
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
|
||||||
std::string name = decl_->id()->name();
|
std::string name = decl_->id()->name();
|
||||||
ir::value *value = ir::undef_value::get(ty);
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
auto is_tunable = [](modifier* x){ return x->value() == TUNABLE_T; };
|
auto is_tunable = [](modifier* x){ return x->is_tunable(); };
|
||||||
if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){
|
if(std::find_if(storage.begin(), storage.end(), is_tunable) != storage.end()){
|
||||||
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
||||||
if(csts == nullptr)
|
if(csts == nullptr)
|
||||||
@@ -158,7 +158,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
mod->get_scope().types[name] = ty;
|
mod->get_scope().types[name] = ty;
|
||||||
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
|
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
|
||||||
mod->add_alloc(x);
|
mod->add_alloc(x);
|
||||||
auto is_cst = [](modifier* mod){ return mod->value() == CONST_T; };
|
auto is_cst = [](modifier* x){ return x->is_cst(); };
|
||||||
if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end())
|
if(std::find_if(storage.begin(), storage.end(), is_cst) != storage.end())
|
||||||
mod->set_const(name);
|
mod->set_const(name);
|
||||||
return value;
|
return value;
|
||||||
@@ -169,9 +169,9 @@ ir::type *type_name::type(ir::module *mod) const{
|
|||||||
return decl_->type(mod, spec_->type(mod), {});
|
return decl_->type(mod, spec_->type(mod), {});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Function definition */
|
/* Storage specifier */
|
||||||
ir::attribute_t get_ir_attr(modifier* mod){
|
inline ir::attribute_kind_t get_ir_attr(STORAGE_SPEC_T spec){
|
||||||
switch(mod->value()){
|
switch(spec){
|
||||||
case RESTRICT_T: return ir::noalias;
|
case RESTRICT_T: return ir::noalias;
|
||||||
case READONLY_T: return ir::readonly;
|
case READONLY_T: return ir::readonly;
|
||||||
case WRITEONLY_T: return ir::writeonly;
|
case WRITEONLY_T: return ir::writeonly;
|
||||||
@@ -179,15 +179,31 @@ ir::attribute_t get_ir_attr(modifier* mod){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void storage_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||||
|
fn->add_attr(pos, ir::attribute(get_ir_attr(value_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Alignment specifier */
|
||||||
|
void alignment_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||||
|
fn->add_attr(pos, ir::attribute(ir::aligned, cst_->value()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Multiple-Of specifier */
|
||||||
|
void multiple_of_specifier::add_attr(ir::function* fn, size_t pos) {
|
||||||
|
fn->add_attr(pos, ir::attribute(ir::multiple_of, cst_->value()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Function definition */
|
||||||
ir::value* function_definition::codegen(ir::module *mod) const{
|
ir::value* function_definition::codegen(ir::module *mod) const{
|
||||||
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers());
|
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->modifiers());
|
||||||
const std::string &name = header_->id()->name();
|
const std::string &name = header_->id()->name();
|
||||||
ir::function *fn = mod->get_or_insert_function(name, prototype);
|
ir::function *fn = mod->get_or_insert_function(name, prototype);
|
||||||
for(unsigned i = 0; i < header_->get_num_args(); i++){
|
for(unsigned i = 0; i < header_->get_num_args(); i++){
|
||||||
parameter *param = header_->get_arg(i);
|
parameter *param = header_->get_arg(i);
|
||||||
std::vector<modifier*> storage = param->storage();
|
std::vector<modifier*> modifiers = param->modifiers();
|
||||||
for(modifier* spec: storage)
|
for(modifier* m: modifiers)
|
||||||
fn->add_attr(1 + i, get_ir_attr(spec));
|
m->add_attr(fn, 1 + i);
|
||||||
}
|
}
|
||||||
header_->bind_parameters(mod, fn);
|
header_->bind_parameters(mod, fn);
|
||||||
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
|
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
|
||||||
|
Reference in New Issue
Block a user