[codegen] more hmma row-major handling

This commit is contained in:
Philippe Tillet
2019-09-24 19:35:46 -04:00
parent c24d55db23
commit a3bf3a1804
8 changed files with 60 additions and 30 deletions

View File

@@ -89,7 +89,7 @@ private:
public:
shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr);
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr);
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, Value *);
@@ -97,7 +97,8 @@ public:
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx);
const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx);
private:
Value *ptr_;
@@ -106,6 +107,7 @@ private:
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
std::vector<int> order_;
};
// Distribtued tile
@@ -123,6 +125,7 @@ public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn);

View File

@@ -57,7 +57,7 @@ unsigned allocation::num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = is_ld_padded(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0];
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
num_bytes += pad * num_bytes / ld;
}
if(liveness_->has_double(x))

View File

@@ -218,6 +218,7 @@ void tiles::run(ir::module &) {
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// find out the order of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
@@ -237,6 +238,20 @@ void tiles::run(ir::module &) {
}
order_[i] = order;
}
for(size_t i = 0; i < num_groups; i++){
bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW ||
hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW;
if(!is_hmma_op)
continue;
// extract copies to shared memory
std::vector<ir::copy_to_shared_inst*> cts;
for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
cts.push_back(x);
if(cts.empty())
continue;
order_[i] = order(cts[0]->get_operand(0));
}
// tiling parameters
for(auto x: largest_){
ir::value *i = x.second;

View File

@@ -146,26 +146,26 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx) {
Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]);
Value *ld = builder.getInt32(shapes[0]);
result = builder.CreateAdd(result, idx[order[0]]);
Value *ld = builder.getInt32(shapes[order[0]]);
for(size_t i = 1; i < idx.size(); i++) {
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld));
if(i < idx.size() - 1){
ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]]));
}
}
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
return_vector_ = false;
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx));
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
@@ -196,7 +196,7 @@ Value* shared_tile::get_value(indices_t idx) {
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx));
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -204,7 +204,7 @@ Value* shared_tile::get_value(indices_t idx) {
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, cst_idx);
Value *offset = shared_offset(builder_, shapes_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
@@ -721,10 +721,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
* ------------------- */
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end())
return;
auto order = tiles_->order(v);
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = alloc_->is_ld_padded(v);
if(pad > 0)
shapes[0] += pad;
shapes[order[0]] += pad;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
// shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
@@ -744,15 +747,15 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v)));
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)});
tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)});
tmap_.insert({phi, new shared_tile(ty, shapes, order, ptr, builder, offset)});
tmap_.insert({v, new shared_tile(ty, shapes, order, pre_ptr, builder)});
tmap_.insert({info.latch, new shared_tile(ty, shapes, order, next_ptr, builder)});
}
else {
size_t offset = alloc_->offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
tmap_.insert({v, new shared_tile(ty, shapes, order, ptr, builder)});
}
}
@@ -920,7 +923,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
write_idx.insert(write_idx.begin() + axis, lane);
// shared memory write pointer
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx);
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
// initialize shared memory
@@ -933,7 +936,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
indices_t current(write_idx.size(), builder.getInt32(0));
current[axis] = builder.getInt32(i);
// shared memory offset
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current);
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
// shared memory read pointer
@@ -949,7 +952,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
// result is on the first lane of shared memory
indices_t final = write_idx;
final[axis] = builder.getInt32(0);
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final);
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
tgt_->add_barrier(module, builder);
result = builder.CreateLoad(read_ptr);
@@ -1077,17 +1080,24 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
if(dot->is_a_trans()){
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1;
bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1;
if(is_a_row){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_a_k = builder.getInt32(0);
}
if(!dot->is_b_trans()){
if(!is_b_row){
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_b_k = builder.getInt32(0);
}
std::string op_a = dot->is_a_trans() ? "row" : "col";
std::string op_b = dot->is_b_trans() ? "row" : "col";
std::string op_a = is_a_row ? "row" : "col";
std::string op_b = is_b_row ? "row" : "col";
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, "

View File

@@ -242,7 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;

View File

@@ -222,13 +222,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
axes.run(module);
layouts.run(module);
coalesce.run(module);
// ir::print(module, std::cout);
dce.run(module);
align.run(module);
dce.run(module);
tiles.run(module);
reassociate.run(module);
peephole.run(module);
dce.run(module);
cts.run(module);
liveness.run(module);
@@ -242,6 +240,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
layouts.run(module);
align.run(module);
tiles.run(module);
// ir::print(module, std::cout);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -79,7 +79,10 @@ int main() {
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false},
{true, true}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 4096, 4096, 4096}
// config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -59,7 +59,7 @@ void copy3d(TYPE * X __noalias __readonly __aligned(16),
}
)";
const char* copy_nd[] = {copy1d, copy2d, copy3d};
const char* copy_nd[] = {copy1d, copy2d, copy3d};
}