[codegen] more hmma row-major handling
This commit is contained in:
@@ -89,7 +89,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr);
|
||||
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr);
|
||||
void set_vector_size(unsigned vector_size);
|
||||
void set_return_mode(bool return_vector);
|
||||
void set_value(indices_t, Value *);
|
||||
@@ -97,7 +97,8 @@ public:
|
||||
Value* get_value(indices_t idx);
|
||||
Value* get_pointer() { return ptr_; }
|
||||
Value* get_offset() { return offset_; }
|
||||
static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx);
|
||||
const std::vector<int>& get_order() { return order_; }
|
||||
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx);
|
||||
|
||||
private:
|
||||
Value *ptr_;
|
||||
@@ -106,6 +107,7 @@ private:
|
||||
Value *offset_;
|
||||
std::map<indices_t, Value*> ptr_cache_;
|
||||
unsigned vector_size_;
|
||||
std::vector<int> order_;
|
||||
};
|
||||
|
||||
// Distribtued tile
|
||||
@@ -123,6 +125,7 @@ public:
|
||||
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
|
||||
void set_value(indices_t idx, Value *v);
|
||||
Value* get_value(indices_t idx);
|
||||
const std::vector<int>& get_order() { return order_; }
|
||||
unsigned get_linear_index(indices_t idx);
|
||||
indices_t get_ordered_indices(unsigned id);
|
||||
void for_each(std::function<void(indices_t)> fn);
|
||||
|
@@ -57,7 +57,7 @@ unsigned allocation::num_bytes(ir::value *x) {
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
unsigned pad = is_ld_padded(x);
|
||||
if(pad > 0){
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[0];
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
||||
num_bytes += pad * num_bytes / ld;
|
||||
}
|
||||
if(liveness_->has_double(x))
|
||||
|
@@ -218,6 +218,7 @@ void tiles::run(ir::module &) {
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||
}
|
||||
|
||||
// find out the order of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
@@ -237,6 +238,20 @@ void tiles::run(ir::module &) {
|
||||
}
|
||||
order_[i] = order;
|
||||
}
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW ||
|
||||
hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW;
|
||||
if(!is_hmma_op)
|
||||
continue;
|
||||
// extract copies to shared memory
|
||||
std::vector<ir::copy_to_shared_inst*> cts;
|
||||
for(ir::value* v: layout_->values(i))
|
||||
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
cts.push_back(x);
|
||||
if(cts.empty())
|
||||
continue;
|
||||
order_[i] = order(cts[0]->get_operand(0));
|
||||
}
|
||||
// tiling parameters
|
||||
for(auto x: largest_){
|
||||
ir::value *i = x.second;
|
||||
|
@@ -146,26 +146,26 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
|
||||
}
|
||||
|
||||
|
||||
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
|
||||
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx) {
|
||||
Value *result = builder.getInt32(0);
|
||||
result = builder.CreateAdd(result, idx[0]);
|
||||
Value *ld = builder.getInt32(shapes[0]);
|
||||
result = builder.CreateAdd(result, idx[order[0]]);
|
||||
Value *ld = builder.getInt32(shapes[order[0]]);
|
||||
for(size_t i = 1; i < idx.size(); i++) {
|
||||
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
|
||||
result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld));
|
||||
if(i < idx.size() - 1){
|
||||
ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
|
||||
ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]]));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
||||
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
|
||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
||||
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
|
||||
return_vector_ = false;
|
||||
}
|
||||
|
||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx));
|
||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx));
|
||||
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
|
||||
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
|
||||
builder_.CreateStore(value, ptr);
|
||||
@@ -196,7 +196,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
// if(isa<Instruction>(non_cst_idx.front())){
|
||||
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
|
||||
// }
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx));
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx));
|
||||
if(vector_size_ > 1){
|
||||
Type *vec_ty = VectorType::get(ty, vector_size);
|
||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||
@@ -204,7 +204,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
}
|
||||
// builder_.SetInsertPoint(store);
|
||||
}
|
||||
Value *offset = shared_offset(builder_, shapes_, cst_idx);
|
||||
Value *offset = shared_offset(builder_, shapes_, order_, cst_idx);
|
||||
Value *div = offset;
|
||||
if(vector_size_ > 1)
|
||||
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
||||
@@ -721,10 +721,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
* ------------------- */
|
||||
|
||||
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
|
||||
if(tmap_.find(v) != tmap_.end())
|
||||
return;
|
||||
auto order = tiles_->order(v);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned pad = alloc_->is_ld_padded(v);
|
||||
if(pad > 0)
|
||||
shapes[0] += pad;
|
||||
shapes[order[0]] += pad;
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
@@ -744,15 +747,15 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v)));
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)});
|
||||
tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)});
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, order, ptr, builder, offset)});
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, order, pre_ptr, builder)});
|
||||
tmap_.insert({info.latch, new shared_tile(ty, shapes, order, next_ptr, builder)});
|
||||
}
|
||||
else {
|
||||
size_t offset = alloc_->offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, order, ptr, builder)});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -920,7 +923,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
write_idx.insert(write_idx.begin() + axis, lane);
|
||||
|
||||
// shared memory write pointer
|
||||
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
|
||||
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx);
|
||||
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
||||
|
||||
// initialize shared memory
|
||||
@@ -933,7 +936,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
current[axis] = builder.getInt32(i);
|
||||
// shared memory offset
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current);
|
||||
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
||||
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
||||
// shared memory read pointer
|
||||
@@ -949,7 +952,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
// result is on the first lane of shared memory
|
||||
indices_t final = write_idx;
|
||||
final[axis] = builder.getInt32(0);
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final);
|
||||
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
||||
tgt_->add_barrier(module, builder);
|
||||
result = builder.CreateLoad(read_ptr);
|
||||
@@ -1077,17 +1080,24 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
Value *offset_b_k = offset_b_k_;
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||
if(dot->is_a_trans()){
|
||||
|
||||
auto ord_a = tiles_->order(dot->get_operand(0));
|
||||
auto ord_b = tiles_->order(dot->get_operand(1));
|
||||
|
||||
bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1;
|
||||
bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1;
|
||||
|
||||
if(is_a_row){
|
||||
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_a_k = builder.getInt32(0);
|
||||
}
|
||||
if(!dot->is_b_trans()){
|
||||
if(!is_b_row){
|
||||
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||
offset_b_k = builder.getInt32(0);
|
||||
}
|
||||
|
||||
std::string op_a = dot->is_a_trans() ? "row" : "col";
|
||||
std::string op_b = dot->is_b_trans() ? "row" : "col";
|
||||
std::string op_a = is_a_row ? "row" : "col";
|
||||
std::string op_b = is_b_row ? "row" : "col";
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||
|
@@ -242,7 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
|
@@ -222,13 +222,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
coalesce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
tiles.run(module);
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
liveness.run(module);
|
||||
@@ -242,6 +240,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
layouts.run(module);
|
||||
align.run(module);
|
||||
tiles.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -79,7 +79,10 @@ int main() {
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
||||
{false, true},
|
||||
{true, false},
|
||||
{true, true}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{x[0], x[1], 4096, 4096, 4096}
|
||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||
|
@@ -59,7 +59,7 @@ void copy3d(TYPE * X __noalias __readonly __aligned(16),
|
||||
}
|
||||
)";
|
||||
|
||||
const char* copy_nd[] = {copy1d, copy2d, copy3d};
|
||||
const char* copy_nd[] = {copy1d, copy2d, copy3d};
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user