[codegen] more hmma row-major handling

This commit is contained in:
Philippe Tillet
2019-09-24 19:35:46 -04:00
parent c24d55db23
commit a3bf3a1804
8 changed files with 60 additions and 30 deletions

View File

@@ -89,7 +89,7 @@ private:
public: public:
shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr); shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr);
void set_vector_size(unsigned vector_size); void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector); void set_return_mode(bool return_vector);
void set_value(indices_t, Value *); void set_value(indices_t, Value *);
@@ -97,7 +97,8 @@ public:
Value* get_value(indices_t idx); Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; } Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; } Value* get_offset() { return offset_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx); const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx);
private: private:
Value *ptr_; Value *ptr_;
@@ -106,6 +107,7 @@ private:
Value *offset_; Value *offset_;
std::map<indices_t, Value*> ptr_cache_; std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_; unsigned vector_size_;
std::vector<int> order_;
}; };
// Distribtued tile // Distribtued tile
@@ -123,6 +125,7 @@ public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize); distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
void set_value(indices_t idx, Value *v); void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx); Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx); unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id); indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn); void for_each(std::function<void(indices_t)> fn);

View File

@@ -57,7 +57,7 @@ unsigned allocation::num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = is_ld_padded(x); unsigned pad = is_ld_padded(x);
if(pad > 0){ if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0]; unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
num_bytes += pad * num_bytes / ld; num_bytes += pad * num_bytes / ld;
} }
if(liveness_->has_double(x)) if(liveness_->has_double(x))

View File

@@ -218,6 +218,7 @@ void tiles::run(ir::module &) {
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp); largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
} }
// find out the order of a group // find out the order of a group
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io; std::set<ir::io_inst*> io;
@@ -237,6 +238,20 @@ void tiles::run(ir::module &) {
} }
order_[i] = order; order_[i] = order;
} }
for(size_t i = 0; i < num_groups; i++){
bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW ||
hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW;
if(!is_hmma_op)
continue;
// extract copies to shared memory
std::vector<ir::copy_to_shared_inst*> cts;
for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
cts.push_back(x);
if(cts.empty())
continue;
order_[i] = order(cts[0]->get_operand(0));
}
// tiling parameters // tiling parameters
for(auto x: largest_){ for(auto x: largest_){
ir::value *i = x.second; ir::value *i = x.second;

View File

@@ -146,26 +146,26 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
} }
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) { Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx) {
Value *result = builder.getInt32(0); Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]); result = builder.CreateAdd(result, idx[order[0]]);
Value *ld = builder.getInt32(shapes[0]); Value *ld = builder.getInt32(shapes[order[0]]);
for(size_t i = 1; i < idx.size(); i++) { for(size_t i = 1; i < idx.size(); i++) {
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld)); result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld));
if(i < idx.size() - 1){ if(i < idx.size() - 1){
ld = builder.CreateMul(ld, builder.getInt32(shapes[i])); ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]]));
} }
} }
return result; return result;
} }
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){ tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
return_vector_ = false; return_vector_ = false;
} }
void shared_tile::set_value(indices_t idx, Value *value) { void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx)); Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace(); unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr); builder_.CreateStore(value, ptr);
@@ -196,7 +196,7 @@ Value* shared_tile::get_value(indices_t idx) {
// if(isa<Instruction>(non_cst_idx.front())){ // if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); // builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// } // }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx)); base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx));
if(vector_size_ > 1){ if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size); Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -204,7 +204,7 @@ Value* shared_tile::get_value(indices_t idx) {
} }
// builder_.SetInsertPoint(store); // builder_.SetInsertPoint(store);
} }
Value *offset = shared_offset(builder_, shapes_, cst_idx); Value *offset = shared_offset(builder_, shapes_, order_, cst_idx);
Value *div = offset; Value *div = offset;
if(vector_size_ > 1) if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
@@ -721,10 +721,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
* ------------------- */ * ------------------- */
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end())
return;
auto order = tiles_->order(v);
auto shapes = v->get_type()->get_tile_shapes(); auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = alloc_->is_ld_padded(v); unsigned pad = alloc_->is_ld_padded(v);
if(pad > 0) if(pad > 0)
shapes[0] += pad; shapes[order[0]] += pad;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
// shared copy // shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
@@ -744,15 +747,15 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v))); Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v)));
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); tmap_.insert({phi, new shared_tile(ty, shapes, order, ptr, builder, offset)});
tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)}); tmap_.insert({v, new shared_tile(ty, shapes, order, pre_ptr, builder)});
tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)}); tmap_.insert({info.latch, new shared_tile(ty, shapes, order, next_ptr, builder)});
} }
else { else {
size_t offset = alloc_->offset(v); size_t offset = alloc_->offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty); ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); tmap_.insert({v, new shared_tile(ty, shapes, order, ptr, builder)});
} }
} }
@@ -920,7 +923,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
write_idx.insert(write_idx.begin() + axis, lane); write_idx.insert(write_idx.begin() + axis, lane);
// shared memory write pointer // shared memory write pointer
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx); Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx);
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset); Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
// initialize shared memory // initialize shared memory
@@ -933,7 +936,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
indices_t current(write_idx.size(), builder.getInt32(0)); indices_t current(write_idx.size(), builder.getInt32(0));
current[axis] = builder.getInt32(i); current[axis] = builder.getInt32(i);
// shared memory offset // shared memory offset
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current); Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current);
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i)); Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0)); read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
// shared memory read pointer // shared memory read pointer
@@ -949,7 +952,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
// result is on the first lane of shared memory // result is on the first lane of shared memory
indices_t final = write_idx; indices_t final = write_idx;
final[axis] = builder.getInt32(0); final[axis] = builder.getInt32(0);
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final); Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final);
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset); Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
tgt_->add_barrier(module, builder); tgt_->add_barrier(module, builder);
result = builder.CreateLoad(read_ptr); result = builder.CreateLoad(read_ptr);
@@ -1077,17 +1080,24 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value *offset_b_k = offset_b_k_; Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0); Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
if(dot->is_a_trans()){
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1;
bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1;
if(is_a_row){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_a_k = builder.getInt32(0); offset_a_k = builder.getInt32(0);
} }
if(!dot->is_b_trans()){ if(!is_b_row){
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4))); offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_b_k = builder.getInt32(0); offset_b_k = builder.getInt32(0);
} }
std::string op_a = dot->is_a_trans() ? "row" : "col"; std::string op_a = is_a_row ? "row" : "col";
std::string op_b = dot->is_b_trans() ? "row" : "col"; std::string op_b = is_b_row ? "row" : "col";
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 " InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, " "{$0, $1, $2, $3, $4, $5, $6, $7}, "

View File

@@ -242,7 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl; // std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context); cu_context::context_switcher ctx(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096; unsigned int errbufsize = 8096;

View File

@@ -222,13 +222,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
axes.run(module); axes.run(module);
layouts.run(module); layouts.run(module);
coalesce.run(module); coalesce.run(module);
// ir::print(module, std::cout);
dce.run(module); dce.run(module);
align.run(module); align.run(module);
dce.run(module); dce.run(module);
tiles.run(module); tiles.run(module);
reassociate.run(module); reassociate.run(module);
peephole.run(module);
dce.run(module); dce.run(module);
cts.run(module); cts.run(module);
liveness.run(module); liveness.run(module);
@@ -242,6 +240,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
layouts.run(module); layouts.run(module);
align.run(module); align.run(module);
tiles.run(module); tiles.run(module);
// ir::print(module, std::cout);
selection.run(module, *llvm); selection.run(module, *llvm);
// return binary // return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm))); std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -79,7 +79,10 @@ int main() {
// shapes to benchmark // shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t; typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs; std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){ for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false},
{true, true}}){
std::vector<config_t> tmp = { std::vector<config_t> tmp = {
config_t{x[0], x[1], 4096, 4096, 4096} config_t{x[0], x[1], 4096, 4096, 4096}
// config_t{x[0], x[1], 16, 2048, 2048}, // config_t{x[0], x[1], 16, 2048, 2048},