[codegen] bugfix in alignment inference
This commit is contained in:
@@ -26,9 +26,9 @@ const tunable int32 TN = {64, 128};
|
||||
const tunable int32 TK = {16};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
void matmul(restrict read_only align(4) fp16 *A,
|
||||
restrict read_only align(4) fp16 *B,
|
||||
align(4) fp32 *C,
|
||||
void matmul(restrict read_only align(16) fp16 *A,
|
||||
restrict read_only align(16) fp16 *B,
|
||||
align(16) fp32 *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc,
|
||||
int32 *locks, int32 grid0, int32 grid1) {
|
||||
|
@@ -16,20 +16,22 @@ namespace codegen{
|
||||
class axis_info {
|
||||
private:
|
||||
// helpers
|
||||
bool is_first_axis_unit(ir::value *x);
|
||||
bool is_first_axis_unit(ir::value *v);
|
||||
|
||||
// populate maps
|
||||
bool populate_is_constant(ir::value *i);
|
||||
unsigned populate_max_contiguous(ir::value *i);
|
||||
unsigned populate_multiple_of(ir::value *i);
|
||||
bool populate_is_constant(ir::value *v);
|
||||
unsigned populate_max_contiguous(ir::value *v);
|
||||
unsigned populate_starting_multiple(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get_starting_multiple(ir::value* v) const;
|
||||
unsigned get_max_contiguous(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, bool> is_constant_;
|
||||
std::map<ir::value*, unsigned> max_contiguous_;
|
||||
std::map<ir::value*, unsigned> multiple_of_;
|
||||
std::map<ir::value*, unsigned> starting_multiple_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -25,6 +25,7 @@ class shmem_allocation;
|
||||
class tune;
|
||||
class shmem_info;
|
||||
class target;
|
||||
class axis_info;
|
||||
|
||||
typedef std::vector<llvm::Value*> indices_t;
|
||||
|
||||
@@ -143,8 +144,8 @@ private:
|
||||
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
||||
|
||||
public:
|
||||
selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ }
|
||||
selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, axis_info *ax_info, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
|
||||
|
||||
void run(ir::module &src, llvm::Module &dst);
|
||||
|
||||
@@ -157,6 +158,7 @@ private:
|
||||
tune *params_;
|
||||
target *tgt_;
|
||||
shmem_info *buffer_info_;
|
||||
axis_info *axis_info_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
llvm::Value *sh_mem_ptr_;
|
||||
llvm::Value *offset_a_i_, *offset_a_k_;
|
||||
|
@@ -57,7 +57,7 @@ public:
|
||||
shmem_allocation(&shmem_liveness, &shmem_info, &tune),
|
||||
shmem_barriers(&shmem_allocation, &shmem_info),
|
||||
vectorize(&tune),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, target),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, &axis_info, target),
|
||||
optimize_dot(&tune),
|
||||
optimize_cse(),
|
||||
optimize_trans(),
|
||||
@@ -67,11 +67,11 @@ public:
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
axis_info.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
axis_info.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
|
@@ -11,7 +11,7 @@ namespace codegen{
|
||||
|
||||
template<class T>
|
||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
return map.insert(std::make_pair(i, value)).first->second;
|
||||
return map[i] = value;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,63 +23,132 @@ bool axis_info::is_first_axis_unit(ir::value *x){
|
||||
}
|
||||
|
||||
bool axis_info::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
// helper for the cache
|
||||
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
|
||||
// populate
|
||||
if(v->get_type()->is_tile_ty()){
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
bool value = populate_is_constant(x->get_operand(0));
|
||||
// check if broadcast (i.e., constant) along contiguous dimension
|
||||
if(is_first_axis_unit(x->get_operand(0))
|
||||
&& !is_first_axis_unit(x))
|
||||
return cache(value);
|
||||
}
|
||||
// otherwise the tile is not constant in the contiguous dimension
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
populate_is_constant(op);
|
||||
if(is_first_axis_unit(op))
|
||||
return cache(true);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
bool lhs = populate_is_constant(x->get_operand(0));
|
||||
bool rhs = populate_is_constant(x->get_operand(1));
|
||||
return cache(lhs && rhs);
|
||||
}
|
||||
if(v->get_type()->is_tile_ty())
|
||||
return cache(false);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
bool result = true;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(is_constant_.find(inc) != is_constant_.end())
|
||||
result = is_constant_.at(inc);
|
||||
}
|
||||
cache(result);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = result && populate_is_constant(inc);
|
||||
}
|
||||
return cache(result);
|
||||
}
|
||||
// scalars are always constant in the contiguous dimension
|
||||
return cache(true);
|
||||
}
|
||||
|
||||
unsigned axis_info::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
// helper for the cache
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
|
||||
// populate
|
||||
if(v->get_type()->is_tile_ty()){
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::get_global_range_inst*>(v))
|
||||
return cache(shapes[0]->get_value());
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
bool lhs_has_cst = populate_is_constant(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
bool rhs_has_cst = populate_is_constant(rhs);
|
||||
if(x->is_int_add_sub()){
|
||||
if(lhs_has_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_has_cst)
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::get_global_range_inst*>(v))
|
||||
return cache(shapes[0]->get_value());
|
||||
if(dynamic_cast<ir::constant_range*>(v))
|
||||
return cache(shapes[0]->get_value());
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
if(op->get_type()->is_tile_ty()){
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
if(op_shapes[0] == shapes[0])
|
||||
return cache(populate_max_contiguous(op));
|
||||
}
|
||||
return cache(1);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
bool lhs_has_cst = populate_is_constant(lhs);
|
||||
bool rhs_has_cst = populate_is_constant(rhs);
|
||||
if(x->is_int_add_sub()){
|
||||
if(lhs_has_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_has_cst)
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
bool lhs_has_cst = populate_is_constant(lhs);
|
||||
bool rhs_has_cst = populate_is_constant(rhs);
|
||||
if(lhs_has_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_has_cst)
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(max_contiguous_.find(inc) != max_contiguous_.end())
|
||||
result = max_contiguous_.at(inc);
|
||||
}
|
||||
cache(result);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = std::min(result, populate_max_contiguous(inc));
|
||||
}
|
||||
return cache(result);
|
||||
}
|
||||
return cache(1);
|
||||
}
|
||||
|
||||
unsigned axis_info::populate_multiple_of(ir::value *v){
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
|
||||
|
||||
unsigned axis_info::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
|
||||
// arguments
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
for(auto attr: attributes){
|
||||
if(attr.get_kind() == ir::multiple_of)
|
||||
return cache(attr.get_value());
|
||||
if(attr.get_kind() == ir::aligned){
|
||||
ir::type* ty = x->get_type()->get_pointer_element_ty();
|
||||
int nbits = ty->get_primitive_size_in_bits();
|
||||
int nbytes = nbits / 8;
|
||||
return cache(attr.get_value() / nbytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
int lhs = populate_multiple_of(x->get_operand(0));
|
||||
int rhs = populate_multiple_of(x->get_operand(1));
|
||||
int lhs = populate_starting_multiple(x->get_operand(0));
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
if(x->is_int_mult())
|
||||
return cache(lhs * rhs);
|
||||
if(x->is_int_add_sub())
|
||||
@@ -93,12 +162,52 @@ unsigned axis_info::populate_multiple_of(ir::value *v){
|
||||
if(x->is_shr())
|
||||
return cache(std::max(lhs >> rhs, 1));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
return cache(populate_multiple_of(x->get_operand(0)));
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
int lhs = populate_starting_multiple(x->get_operand(0));
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
return cache(std::min(lhs, rhs));
|
||||
}
|
||||
return cache(1);
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){
|
||||
return cache(v->get_type()->get_tile_shapes()[0]->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(starting_multiple_.find(inc) != starting_multiple_.end())
|
||||
result = starting_multiple_.at(inc);
|
||||
}
|
||||
cache(result);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = std::min(result, populate_starting_multiple(inc));
|
||||
}
|
||||
return cache(result);
|
||||
}
|
||||
// scalars
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
// tiles
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned result = 1;
|
||||
for(unsigned i = 0; i < shapes.size() - 1; i++)
|
||||
result *= shapes[i]->get_value();
|
||||
return cache(result);
|
||||
}
|
||||
|
||||
unsigned axis_info::get_starting_multiple(ir::value* v) const {
|
||||
return starting_multiple_.at(v);
|
||||
}
|
||||
|
||||
unsigned axis_info::get_max_contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
|
||||
void axis_info::run(ir::module &mod) {
|
||||
@@ -109,11 +218,11 @@ void axis_info::run(ir::module &mod) {
|
||||
populate_is_constant(i);
|
||||
}
|
||||
|
||||
// populate multiple_of
|
||||
// populate starting multiple
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_multiple_of(i);
|
||||
populate_starting_multiple(i);
|
||||
}
|
||||
|
||||
// populate maximum contiguous
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/axis_info.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
@@ -1027,7 +1028,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
}
|
||||
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
|
||||
unsigned vector_size = result->axis(0).contiguous;
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
result->for_each([&](indices_t idx){
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
Reference in New Issue
Block a user