[syntax tree] more fixes in lowering phi nodes

This commit is contained in:
Philippe Tillet
2019-02-26 12:36:37 -05:00
parent 338f291835
commit 68dea75aa0
10 changed files with 118 additions and 41 deletions

View File

@@ -42,7 +42,7 @@ const tunable int32 TM;\
const tunable int32 TN;\ const tunable int32 TN;\
const tunable int32 TK;\ const tunable int32 TK;\
\ \
void matmul(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
int32 rxa[TM] = get_global_range[TM](0);\ int32 rxa[TM] = get_global_range[TM](0);\
int32 ryb[TN] = get_global_range[TN](1);\ int32 ryb[TN] = get_global_range[TN](1);\
int32 rka[TK] = 0 ... TK;\ int32 rka[TK] = 0 ... TK;\
@@ -56,16 +56,19 @@ void matmul(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\ fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
fp32 a[TM, TK] = *pa;\ fp32 a[TM, TK] = *pa;\
fp32 b[TN, TK] = *pb;\ fp32 b[TN, TK] = *pb;\
int1 checkc0[TM] = rxc < M;\ int1 checkc0[TM];\
int1 checkc1[TN] = ryc < N;\ int1 checkc1[TN];\
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\ int1 checkc[TM, TN];\
for(k = K; k > 0; k = k - TK){\ for(k = K; k > 0; k = k - TK){\
C = dot(a, b, C);\ C = dot(a, b, C);\
pa = pa + TK*M;\ pa = pa + TK*M;\
pb = pb + TK*K;\ pb = pb + TK*K;\
a = *pa;\ a = *pa;\
b = *pb;\ b = *pb;\
}\ }\
checkc0 = rxc < M;\
checkc1 = ryc < N;\
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
@checkc *pc = C;\ @checkc *pc = C;\
}\ }\
"; ";
@@ -203,23 +206,23 @@ int main() {
// tuning parameters // tuning parameters
tune.run(module); tune.run(module);
std::vector<unsigned> params = { std::vector<unsigned> params = {
// shapes // shapes
8, 8, 8, 16, 16, 8,
// a0 // a0
1, 8, 1, 2, 8, 1,
// b0 // b0
1, 8, 1, 4, 4, 1,
// c0 // c0
1, 8, 1, 2, 8, 1,
// c1 // c1
1, 4, 2, 4, 4, 1,
// a1 // a1
1, 4, 2, 2, 4, 1,
// b1 // b1
1, 4, 2 1, 8, 1
}; };
// meta-parameters // meta-parameters
unsigned i = 0; unsigned i = 0;
context.p_impl->mp_constants_[0]->set_value(params[0]); context.p_impl->mp_constants_[0]->set_value(params[0]);
@@ -240,12 +243,13 @@ int main() {
// run passes // run passes
triton::ir::print(module, std::cout);
exit(EXIT_FAILURE);
buffer_info.run(module); buffer_info.run(module);
shared.run(module); shared.run(module);
liveness.run(module); liveness.run(module);
allocation.run(); allocation.run();
barriers.run(module); barriers.run(module);
// triton::ir::print(module, std::cout);
vectorize.run(module); vectorize.run(module);
selection.run(module, llvm_module); selection.run(module, llvm_module);
@@ -256,6 +260,7 @@ int main() {
manager.run(llvm_module); manager.run(llvm_module);
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true)); std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
std::cout << src << std::endl;
// compile machine code // compile machine code
CUdevice cu_device; CUdevice cu_device;
@@ -277,9 +282,9 @@ int main() {
std::vector<numeric_t> b(K*N); std::vector<numeric_t> b(K*N);
srand(0); srand(0);
for(size_t i = 0; i < a.size(); i++) for(size_t i = 0; i < a.size(); i++)
a[i] = (float)rand()/RAND_MAX; a[i] = 1;
for(size_t i = 0; i < b.size(); i++) for(size_t i = 0; i < b.size(); i++)
b[i] = (float)rand()/RAND_MAX; b[i] = 1;
for(size_t i = 0; i < c.size(); i++) for(size_t i = 0; i < c.size(); i++)
c[i] = 0; c[i] = 0;
CUdeviceptr d_a, d_b, d_c; CUdeviceptr d_a, d_b, d_c;

View File

@@ -60,7 +60,9 @@ enum STORAGE_SPEC_T{
CONST_T, CONST_T,
TUNABLE_T, TUNABLE_T,
KERNEL_T, KERNEL_T,
READONLY_T, WRITEONLY_T, RESTRICT_T,
READONLY_T,
WRITEONLY_T
}; };
class pointer; class pointer;
@@ -505,6 +507,8 @@ public:
: declarator(id), args_((list<parameter*>*)args) { } : declarator(id), args_((list<parameter*>*)args) { }
void bind_parameters(ir::module *mod, ir::function *fn) const; void bind_parameters(ir::module *mod, ir::function *fn) const;
unsigned get_num_args() const { return args_->values().size(); }
parameter* get_arg(unsigned i) const { return args_->values().at(i); }
public: public:
const list<parameter*>* args_; const list<parameter*>* args_;

View File

@@ -46,7 +46,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%} %}
%token IDENTIFIER CONSTANT STRING_LITERAL %token IDENTIFIER CONSTANT STRING_LITERAL
%token TUNABLE KERNEL READONLY WRITEONLY CONST %token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
@@ -363,6 +363,7 @@ storage_class_specifier
: CONST { $$ = new token(CONST_T); } : CONST { $$ = new token(CONST_T); }
| TUNABLE { $$ = new token(TUNABLE_T); } | TUNABLE { $$ = new token(TUNABLE_T); }
| KERNEL { $$ = new token(KERNEL_T); } | KERNEL { $$ = new token(KERNEL_T); }
| RESTRICT { $$ = new token(RESTRICT_T); }
| READONLY { $$ = new token(READONLY_T); } | READONLY { $$ = new token(READONLY_T); }
| WRITEONLY { $$ = new token(WRITEONLY_T); } | WRITEONLY { $$ = new token(WRITEONLY_T); }
; ;

View File

@@ -19,6 +19,7 @@ int comment();
"const" { count(); return(CONST); } "const" { count(); return(CONST); }
"tunable" { count(); return(TUNABLE); } "tunable" { count(); return(TUNABLE); }
"kernel" { count(); return(KERNEL); } "kernel" { count(); return(KERNEL); }
"restrict" { count(); return(RESTRICT); }
"readonly" { count(); return(READONLY); } "readonly" { count(); return(READONLY); }
"writeonly" { count(); return(WRITEONLY); } "writeonly" { count(); return(WRITEONLY); }
"@" { count(); return(AT); } "@" { count(); return(AT); }

View File

@@ -55,6 +55,7 @@ private:
public: public:
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr); shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
void set_vector_size(unsigned vector_size);
void set_value(indices_t, llvm::Value *); void set_value(indices_t, llvm::Value *);
llvm::Value* get_value(indices_t idx); llvm::Value* get_value(indices_t idx);
llvm::Value* get_pointer() { return ptr_; } llvm::Value* get_pointer() { return ptr_; }
@@ -65,6 +66,7 @@ private:
llvm::Value *offset_; llvm::Value *offset_;
llvm::IRBuilder<> &builder_; llvm::IRBuilder<> &builder_;
std::map<indices_t, llvm::Value*> ptr_cache_; std::map<indices_t, llvm::Value*> ptr_cache_;
unsigned vector_size_;
}; };
class distributed_tile: public tile{ class distributed_tile: public tile{

View File

@@ -2,6 +2,7 @@
#define TDL_INCLUDE_IR_FUNCTION_H #define TDL_INCLUDE_IR_FUNCTION_H
#include <string> #include <string>
#include <map>
#include "value.h" #include "value.h"
#include "constant.h" #include "constant.h"
@@ -27,8 +28,10 @@ private:
}; };
/* Attribute */ /* Attribute */
class attribute { enum attribute_t {
readonly,
writeonly,
noalias
}; };
/* Function */ /* Function */
@@ -41,6 +44,8 @@ class function: public global_object{
typedef blocks_t::iterator block_iterator; typedef blocks_t::iterator block_iterator;
typedef blocks_t::const_iterator const_block_iterator; typedef blocks_t::const_iterator const_block_iterator;
typedef std::map<unsigned, std::set<attribute_t>> attr_map_t;
private: private:
function(function_type *ty, linkage_types_t linkage, function(function_type *ty, linkage_types_t linkage,
const std::string &name = "", module *parent = nullptr); const std::string &name = "", module *parent = nullptr);
@@ -49,6 +54,7 @@ public:
// accessors // accessors
const args_t &args() { return args_; } const args_t &args() { return args_; }
function_type* get_fn_type() { return fn_ty_; } function_type* get_fn_type() { return fn_ty_; }
// factory methods // factory methods
static function *create(function_type *ty, linkage_types_t linkage, static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod); const std::string &name, module *mod);
@@ -56,12 +62,17 @@ public:
const blocks_t &blocks() { return blocks_; } const blocks_t &blocks() { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr); void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute_t attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
private: private:
module *parent_; module *parent_;
bool init_; bool init_;
function_type *fn_ty_; function_type *fn_ty_;
args_t args_; args_t args_;
blocks_t blocks_; blocks_t blocks_;
attr_map_t attrs_;
}; };
} }

View File

@@ -44,7 +44,7 @@ public:
private: private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi, value **pre_user); value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi); value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block); value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); } void push_function(function *fn) { functions_.push_back(fn); }

View File

@@ -255,10 +255,25 @@ ir::type* function::type_impl(ir::module* mod, ir::type *type) const{
} }
/* Function definition */ /* Function definition */
ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){
switch(spec){
case RESTRICT_T: return ir::noalias;
case READONLY_T: return ir::readonly;
case WRITEONLY_T: return ir::writeonly;
default: throw std::runtime_error("cannot convert storage specifier to IR function attribute");
}
}
ir::value* function_definition::codegen(ir::module *mod) const{ ir::value* function_definition::codegen(ir::module *mod) const{
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod)); ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod));
const std::string &name = header_->id()->name(); const std::string &name = header_->id()->name();
ir::function *fn = mod->get_or_insert_function(name, prototype); ir::function *fn = mod->get_or_insert_function(name, prototype);
for(unsigned i = 0; i < header_->get_num_args(); i++){
parameter *param = header_->get_arg(i);
std::vector<STORAGE_SPEC_T> storage = param->storage();
for(STORAGE_SPEC_T spec: storage)
fn->add_attr(1 + i, get_ir_attr(spec));
}
header_->bind_parameters(mod, fn); header_->bind_parameters(mod, fn);
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn); ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
mod->seal_block(entry); mod->seal_block(entry);

View File

@@ -12,6 +12,7 @@
#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/BasicBlock.h" #include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
namespace triton{ namespace triton{
namespace codegen{ namespace codegen{
@@ -125,7 +126,7 @@ Value* shared_tile::shared_offset(indices_t idx) {
} }
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset) { tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
} }
void shared_tile::set_value(indices_t idx, Value *value) { void shared_tile::set_value(indices_t idx, Value *value) {
@@ -135,18 +136,33 @@ void shared_tile::set_value(indices_t idx, Value *value) {
builder_.CreateStore(value, ptr); builder_.CreateStore(value, ptr);
} }
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
Value* shared_tile::get_value(indices_t idx) { Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx; indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx); extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx]; Value *&base_ptr = ptr_cache_[non_cst_idx];
if(base_ptr == nullptr){ if(base_ptr == nullptr){
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx)); base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
// Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vec_); if(vector_size_ > 1){
// Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerElementType()); Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
// base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty); Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
} }
Value *ptr = builder_.CreateGEP(base_ptr, shared_offset(cst_idx)); Value *offset = shared_offset(cst_idx);
return builder_.CreateLoad(ptr); Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
} }
/* convert ir::type to Type */ /* convert ir::type to Type */
@@ -623,15 +639,20 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ir::value *A = ins->get_operand(0); ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1); ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2); ir::value *C = ins->get_operand(2);
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)}); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){ result->for_each([&](indices_t idx){
Value *res = tmap_.at(C)->get_value(idx); Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value(); unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){ for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)}; indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = tmap_.at(A)->get_value(a_idx); Value *a = TA->get_value(a_idx);
Value *b = tmap_.at(B)->get_value(b_idx); Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res}); res = builder.CreateCall(f_mul_add, {a, b, res});
} }
result->set_value(idx, res); result->set_value(idx, res);
@@ -660,10 +681,20 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
} }
else { else {
Instruction *i = (Instruction*)llvm_value(src, builder); Instruction *i = (Instruction*)llvm_value(src, builder);
std::cout << "instruction: " << src->get_name() << " " << src->has_tile_result_or_op() << std::endl;
vmap_[src] = i; vmap_[src] = i;
} }
} }
inline llvm::Attribute::AttrKind llvm_attr(ir::attribute_t attr) {
switch(attr){
case ir::noalias: return llvm::Attribute::NoAlias;
case ir::readonly: return llvm::Attribute::ReadOnly;
case ir::writeonly: return llvm::Attribute::WriteOnly;
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
}
void selection::run(ir::module &src, Module &dst){ void selection::run(ir::module &src, Module &dst){
vmap_.clear(); vmap_.clear();
LLVMContext &dst_ctx = dst.getContext(); LLVMContext &dst_ctx = dst.getContext();
@@ -675,7 +706,13 @@ void selection::run(ir::module &src, Module &dst){
// create LLVM function // create LLVM function
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx); FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
Function *dst_fn = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), &dst); Function *dst_fn = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
// Set metadata // set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute_t attr: attr_pair.second)
dst_fn->addAttribute(id, llvm_attr(attr));
}
// set metadata
llvm::Metadata *md_args[] = { llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(dst_fn), llvm::ValueAsMetadata::get(dst_fn),
llvm::MDString::get(dst_ctx, "kernel"), llvm::MDString::get(dst_ctx, "kernel"),
@@ -760,6 +797,7 @@ void selection::run(ir::module &src, Module &dst){
}); });
} }
else { else {
std::cout << phi->get_name() << " " << inc_val->get_name() << std::endl;
PHINode *llvm_phi = (PHINode*)vmap_.at(phi); PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val); Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);

View File

@@ -60,7 +60,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
return res; return res;
} }
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user){ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references // find non-self references
std::set<ir::value*> non_self_ref; std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
@@ -76,7 +76,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_u
for(ir::user* u: users) for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u)) if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi) if(uphi != phi)
try_remove_trivial_phis(uphi, &same); try_remove_trivial_phis(uphi);
return same; return same;
} }
@@ -113,7 +113,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
result = add_phi_operands(name, (ir::phi_node*&)result); result = add_phi_operands(name, (ir::phi_node*&)result);
} }
if(auto *phi = dynamic_cast<ir::phi_node*>(result)) if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi, nullptr); result = try_remove_trivial_phis(phi);
set_value(name, block, result); set_value(name, block, result);
return result; return result;
} }
@@ -155,7 +155,7 @@ ir::type *module::get_type(const std::string &name) {
void module::seal_block(ir::basic_block *block){ void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){ for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second); add_phi_operands(x.first, x.second);
try_remove_trivial_phis(x.second, nullptr); set_value(x.first, try_remove_trivial_phis(x.second));
} }
sealed_blocks_.insert(block); sealed_blocks_.insert(block);
incomplete_phis_[block].clear(); incomplete_phis_[block].clear();