[general] some cleaning:
* trans/dot -> peephole: peephole.h replaces trans.h, and the separate optimize_dot pass (dot.h / dot.cpp) is removed
* isel -> added functions for tile-level lowering
This commit is contained in:
@@ -120,36 +120,66 @@ class selection{
|
|||||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||||
typedef std::map<ir::value *, tile *> tmap_t;
|
typedef std::map<ir::value *, tile *> tmap_t;
|
||||||
|
|
||||||
|
typedef llvm::LLVMContext LLVMContext;
|
||||||
|
typedef llvm::IRBuilder<> Builder;
|
||||||
|
typedef llvm::Type Type;
|
||||||
|
typedef llvm::Value Value;
|
||||||
|
typedef llvm::Module Module;
|
||||||
|
typedef llvm::Instruction Instruction;
|
||||||
|
typedef llvm::Constant Constant;
|
||||||
|
typedef llvm::ArrayType ArrayType;
|
||||||
|
typedef llvm::Function Function;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// utils
|
// utils
|
||||||
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
|
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||||
std::vector<unsigned> extract_shapes(ir::value *v);
|
std::vector<unsigned> extract_shapes(ir::value *v);
|
||||||
|
|
||||||
// LLVM conversions
|
// LLVM conversions
|
||||||
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
|
||||||
llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder);
|
Value* llvm_value(ir::value *v, Builder &builder);
|
||||||
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::IRBuilder<> &builder);
|
Instruction* llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, Builder &builder);
|
||||||
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
|
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
|
||||||
llvm::Value* llvm_alloc_const(ir::alloc_const *v, llvm::Module *module, llvm::IRBuilder<> &builder);
|
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
||||||
llvm::ArrayType* llvm_linearized_tile_type(ir::type *ty, llvm::LLVMContext &ctx);
|
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
|
||||||
|
|
||||||
// grid construction
|
// grid construction
|
||||||
void create_grids(std::vector<ir::value *> &grids,
|
void create_grids(std::vector<ir::value *> &grids,
|
||||||
std::map<unsigned, ir::value *> &references,
|
std::map<unsigned, ir::value *> &references,
|
||||||
ir::function *fn);
|
ir::function *fn);
|
||||||
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
|
void create_tile(ir::value *v, Builder &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, Value *sh_mem_ptr);
|
||||||
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
|
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||||
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
|
void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||||
|
|
||||||
|
// lower scalar instruction
|
||||||
|
void lower_instruction(ir::instruction *src, Builder &builder);
|
||||||
|
// lower tile instruction
|
||||||
|
void lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_scalar_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_elementwise(ir::instruction *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||||
|
void lower_tile_instruction(ir::instruction *src, Builder &builder);
|
||||||
|
|
||||||
|
|
||||||
// lowering
|
|
||||||
void lower_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
|
||||||
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
|
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
|
||||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
|
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
|
||||||
|
|
||||||
void run(ir::module &src, llvm::Module &dst);
|
void run(ir::module &src, Module &dst);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vmap_t vmap_;
|
vmap_t vmap_;
|
||||||
@@ -160,9 +190,9 @@ private:
|
|||||||
analysis::shmem::info *buffer_info_;
|
analysis::shmem::info *buffer_info_;
|
||||||
analysis::alignment_info *axis_info_;
|
analysis::alignment_info *axis_info_;
|
||||||
std::map<unsigned, distributed_axis> axes_;
|
std::map<unsigned, distributed_axis> axes_;
|
||||||
llvm::Value *sh_mem_ptr_;
|
Value *sh_mem_ptr_;
|
||||||
llvm::Value *offset_a_i_, *offset_a_k_;
|
Value *offset_a_i_, *offset_a_k_;
|
||||||
llvm::Value *offset_b_j_, *offset_b_k_;
|
Value *offset_b_j_, *offset_b_k_;
|
||||||
unsigned num_packs_0_, num_packs_1_;
|
unsigned num_packs_0_, num_packs_1_;
|
||||||
unsigned pack_size_0_, pack_size_1_;
|
unsigned pack_size_0_, pack_size_1_;
|
||||||
};
|
};
|
||||||
|
@@ -1,35 +0,0 @@
|
|||||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
|
|
||||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
|
|
||||||
|
|
||||||
#include <tuple>
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
namespace triton {
|
|
||||||
|
|
||||||
namespace ir {
|
|
||||||
class module;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace codegen{
|
|
||||||
|
|
||||||
namespace analysis{
|
|
||||||
class tune;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace transform{
|
|
||||||
|
|
||||||
class optimize_dot {
|
|
||||||
public:
|
|
||||||
optimize_dot(analysis::tune* params): params_(params) {}
|
|
||||||
void run(ir::module &mod);
|
|
||||||
|
|
||||||
private:
|
|
||||||
analysis::tune* params_;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@@ -14,6 +14,7 @@ namespace ir {
|
|||||||
class trans_inst;
|
class trans_inst;
|
||||||
class builder;
|
class builder;
|
||||||
class constant_int;
|
class constant_int;
|
||||||
|
class dot_inst;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
@@ -22,6 +23,8 @@ namespace transform{
|
|||||||
class peephole {
|
class peephole {
|
||||||
private:
|
private:
|
||||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||||
|
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
|
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
@@ -15,9 +15,8 @@
|
|||||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||||
#include "triton/codegen/analysis/shmem/info.h"
|
#include "triton/codegen/analysis/shmem/info.h"
|
||||||
#include "triton/codegen/analysis/alignment.h"
|
#include "triton/codegen/analysis/alignment.h"
|
||||||
#include "triton/codegen/transform/dot.h"
|
|
||||||
#include "triton/codegen/transform/dce.h"
|
#include "triton/codegen/transform/dce.h"
|
||||||
#include "triton/codegen/transform/trans.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
#include "triton/codegen/transform/shmem/barriers.h"
|
#include "triton/codegen/transform/shmem/barriers.h"
|
||||||
#include "triton/codegen/transform/reassociate.h"
|
#include "triton/codegen/transform/reassociate.h"
|
||||||
#include "triton/codegen/transform/vectorize.h"
|
#include "triton/codegen/transform/vectorize.h"
|
||||||
@@ -64,7 +63,6 @@ public:
|
|||||||
shmem_barriers(&shmem_allocation, &shmem_info),
|
shmem_barriers(&shmem_allocation, &shmem_info),
|
||||||
vectorize(&tune),
|
vectorize(&tune),
|
||||||
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
||||||
optimize_dot(&tune),
|
|
||||||
dce(),
|
dce(),
|
||||||
peephole(),
|
peephole(),
|
||||||
alignment_info(),
|
alignment_info(),
|
||||||
@@ -72,7 +70,6 @@ public:
|
|||||||
target_(target) { }
|
target_(target) { }
|
||||||
|
|
||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
ir::print(module, std::cout);
|
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
}
|
}
|
||||||
@@ -89,7 +86,6 @@ public:
|
|||||||
alignment_info.run(module);
|
alignment_info.run(module);
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
ir::print(module, std::cout);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
codegen::selection selection;
|
codegen::selection selection;
|
||||||
@@ -100,7 +96,6 @@ public:
|
|||||||
codegen::analysis::alignment_info alignment_info;
|
codegen::analysis::alignment_info alignment_info;
|
||||||
codegen::transform::shmem_barriers shmem_barriers;
|
codegen::transform::shmem_barriers shmem_barriers;
|
||||||
codegen::transform::vectorize vectorize;
|
codegen::transform::vectorize vectorize;
|
||||||
codegen::transform::optimize_dot optimize_dot;
|
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole;
|
codegen::transform::peephole peephole;
|
||||||
codegen::transform::reassociate reassociate;
|
codegen::transform::reassociate reassociate;
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -1,113 +0,0 @@
|
|||||||
#include "triton/ir/function.h"
|
|
||||||
#include "triton/ir/basic_block.h"
|
|
||||||
#include "triton/ir/module.h"
|
|
||||||
#include "triton/codegen/transform/dot.h"
|
|
||||||
#include "triton/codegen/analysis/tune.h"
|
|
||||||
|
|
||||||
namespace triton {
|
|
||||||
namespace codegen{
|
|
||||||
namespace transform{
|
|
||||||
|
|
||||||
inline bool is_trans(ir::value *v){
|
|
||||||
auto *x = dynamic_cast<ir::trans_inst*>(v);
|
|
||||||
if(!x)
|
|
||||||
return false;
|
|
||||||
std::vector<ir::constant_int*> perm = x->get_perm();
|
|
||||||
std::vector<ir::constant_int*> ref;
|
|
||||||
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
|
|
||||||
for(size_t i = 0; i < perm.size(); i++)
|
|
||||||
ref.push_back(ir::constant_int::get(int32_ty, i));
|
|
||||||
std::swap(ref[0], ref[1]);
|
|
||||||
// true if perm == ref
|
|
||||||
return std::equal(perm.begin(), perm.end(), ref.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_hmma(ir::value *v){
|
|
||||||
bool result = false;
|
|
||||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
|
||||||
ir::value *a = x->get_operand(0);
|
|
||||||
ir::type *a_ty = a->get_type();
|
|
||||||
ir::value *b = x->get_operand(1);
|
|
||||||
ir::type *b_ty = b->get_type();
|
|
||||||
// inputs have to be FP16
|
|
||||||
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
|
||||||
// reduction has to be multiple of 4
|
|
||||||
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void optimize_dot::run(ir::module &mod) {
|
|
||||||
ir::builder &builder = mod.get_builder();
|
|
||||||
// iterate
|
|
||||||
for(ir::function *fn: mod.get_function_list())
|
|
||||||
for(ir::basic_block *block: fn->blocks())
|
|
||||||
for(ir::instruction *i: block->get_inst_list())
|
|
||||||
if(auto dot = dynamic_cast<ir::dot_inst*>(i)){
|
|
||||||
builder.set_insert_point(i);
|
|
||||||
ir::value *A = dot->get_operand(0);
|
|
||||||
ir::value *B = dot->get_operand(1);
|
|
||||||
ir::value *D = dot->get_operand(2);
|
|
||||||
bool trans_a = is_trans(A);
|
|
||||||
bool trans_b = is_trans(B);
|
|
||||||
|
|
||||||
if(!dot->is_a_trans() && !dot->is_b_trans()){
|
|
||||||
if(is_hmma(dot)){
|
|
||||||
ir::value *AA = A;
|
|
||||||
ir::value *BB = B;
|
|
||||||
if(trans_a){
|
|
||||||
AA = ((ir::trans_inst*)A)->get_operand(0);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
|
||||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
|
||||||
std::swap(perm[0], perm[1]);
|
|
||||||
AA = builder.create_trans(T->get_operand(0), perm);
|
|
||||||
T->replace_all_uses_with(AA);
|
|
||||||
trans_a = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if(trans_b){
|
|
||||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
// if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
|
||||||
// std::vector<ir::constant_int*> perm(T->get_perm());
|
|
||||||
// std::swap(perm[0], perm[1]);
|
|
||||||
// AA = builder.create_trans(T->get_operand(0), perm);
|
|
||||||
// T->replace_all_uses_with(AA);
|
|
||||||
// trans_a = true;
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
|
||||||
dot->replace_all_uses_with(dot_atbt);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
// dot(op(a), trans(b))
|
|
||||||
if(trans_b){
|
|
||||||
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
|
||||||
dot->replace_all_uses_with(NT);
|
|
||||||
}
|
|
||||||
// dot(op(a), b)
|
|
||||||
if(!trans_b){
|
|
||||||
// create permutations
|
|
||||||
size_t size = B->get_type()->get_tile_shapes().size();
|
|
||||||
std::vector<ir::constant_int*> perm(size);
|
|
||||||
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
|
||||||
for(size_t i = 0; i < size; i++)
|
|
||||||
perm[i] = ir::constant_int::get(int32_ty, i);
|
|
||||||
std::swap(perm[0], perm[1]);
|
|
||||||
// replace NN -> NT (trans)
|
|
||||||
ir::value* BB = builder.create_trans(B, perm);
|
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
|
||||||
dot->replace_all_uses_with(NT);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,6 +1,6 @@
|
|||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
#include "triton/codegen/transform/trans.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
@@ -70,84 +70,96 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
|
|||||||
if(users.size() > 1 || ops.size() > 1)
|
if(users.size() > 1 || ops.size() > 1)
|
||||||
return false;
|
return false;
|
||||||
ir::value* op = *ops.begin();
|
ir::value* op = *ops.begin();
|
||||||
|
// trans(phi) -> phi(trans(), trans()...)
|
||||||
auto* phi = dynamic_cast<ir::phi_node*>(op);
|
auto* phi = dynamic_cast<ir::phi_node*>(op);
|
||||||
if(!phi)
|
if(!phi)
|
||||||
return false;
|
return false;
|
||||||
ir::value* new_phi = rewrite_trans_phi_impl(op, builder, trans->get_perm());
|
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
|
||||||
trans->replace_all_uses_with(new_phi);
|
trans->replace_all_uses_with(new_phi);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
|
||||||
if(auto dot = dynamic_cast<ir::dot_inst*>(value)){
|
ir::value *A, ir::value *B, ir::value *D){
|
||||||
builder.set_insert_point(value);
|
ir::value *AA = A;
|
||||||
ir::value *A = dot->get_operand(0);
|
ir::value *BB = B;
|
||||||
ir::value *B = dot->get_operand(1);
|
if(trans_a){
|
||||||
ir::value *D = dot->get_operand(2);
|
AA = ((ir::trans_inst*)A)->get_operand(0);
|
||||||
bool trans_a = is_trans(A);
|
}
|
||||||
bool trans_b = is_trans(B);
|
else{
|
||||||
// NN
|
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
||||||
if(!dot->is_a_trans() && !dot->is_b_trans()){
|
std::vector<ir::constant_int*> perm(T->get_perm());
|
||||||
if(is_hmma(dot)) {
|
std::swap(perm[0], perm[1]);
|
||||||
ir::value *AA = A;
|
AA = builder.create_trans(T->get_operand(0), perm);
|
||||||
ir::value *BB = B;
|
T->replace_all_uses_with(AA);
|
||||||
if(trans_a){
|
trans_a = true;
|
||||||
AA = ((ir::trans_inst*)A)->get_operand(0);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
|
||||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
|
||||||
std::swap(perm[0], perm[1]);
|
|
||||||
AA = builder.create_trans(T->get_operand(0), perm);
|
|
||||||
T->replace_all_uses_with(AA);
|
|
||||||
trans_a = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if(trans_b){
|
|
||||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
|
||||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
|
||||||
std::swap(perm[0], perm[1]);
|
|
||||||
AA = builder.create_trans(T->get_operand(0), perm);
|
|
||||||
T->replace_all_uses_with(AA);
|
|
||||||
trans_a = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
|
||||||
dot->replace_all_uses_with(dot_atbt);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
// dot(op(a), trans(b))
|
|
||||||
if(trans_b){
|
|
||||||
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
|
||||||
dot->replace_all_uses_with(NT);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// dot(op(a), b)
|
|
||||||
if(!trans_b){
|
|
||||||
// create permutations
|
|
||||||
size_t size = B->get_type()->get_tile_shapes().size();
|
|
||||||
std::vector<ir::constant_int*> perm(size);
|
|
||||||
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
|
||||||
for(size_t i = 0; i < size; i++)
|
|
||||||
perm[i] = ir::constant_int::get(int32_ty, i);
|
|
||||||
std::swap(perm[0], perm[1]);
|
|
||||||
// replace NN -> NT (trans)
|
|
||||||
ir::value* BB = builder.create_trans(B, perm);
|
|
||||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
|
||||||
dot->replace_all_uses_with(NT);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(trans_b){
|
||||||
|
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
||||||
|
std::vector<ir::constant_int*> perm(T->get_perm());
|
||||||
|
std::swap(perm[0], perm[1]);
|
||||||
|
AA = builder.create_trans(T->get_operand(0), perm);
|
||||||
|
T->replace_all_uses_with(AA);
|
||||||
|
trans_a = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||||
|
dot->replace_all_uses_with(dot_atbt);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool peephole::rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
|
||||||
|
ir::value *A, ir::value *B, ir::value *D){
|
||||||
|
// dot(op(a), trans(b))
|
||||||
|
if(trans_b){
|
||||||
|
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||||
|
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||||
|
dot->replace_all_uses_with(NT);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// dot(op(a), b)
|
||||||
|
if(!trans_b){
|
||||||
|
// create permutations
|
||||||
|
size_t size = B->get_type()->get_tile_shapes().size();
|
||||||
|
std::vector<ir::constant_int*> perm(size);
|
||||||
|
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
||||||
|
for(size_t i = 0; i < size; i++)
|
||||||
|
perm[i] = ir::constant_int::get(int32_ty, i);
|
||||||
|
std::swap(perm[0], perm[1]);
|
||||||
|
// replace NN -> NT (trans)
|
||||||
|
ir::value* BB = builder.create_trans(B, perm);
|
||||||
|
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||||
|
dot->replace_all_uses_with(NT);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||||
|
auto dot = dynamic_cast<ir::dot_inst*>(value);
|
||||||
|
if(!dot)
|
||||||
|
return false;
|
||||||
|
builder.set_insert_point(value);
|
||||||
|
ir::value *A = dot->get_operand(0);
|
||||||
|
ir::value *B = dot->get_operand(1);
|
||||||
|
ir::value *D = dot->get_operand(2);
|
||||||
|
bool trans_a = is_trans(A);
|
||||||
|
bool trans_b = is_trans(B);
|
||||||
|
// only consider dot-nn
|
||||||
|
if(dot->is_a_trans() || dot->is_b_trans())
|
||||||
|
return false;
|
||||||
|
// hmma
|
||||||
|
if(is_hmma(dot))
|
||||||
|
return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D);
|
||||||
|
else
|
||||||
|
return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D);
|
||||||
|
}
|
||||||
|
|
||||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||||
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
||||||
if(!x)
|
if(!x)
|
||||||
@@ -190,28 +202,40 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::buil
|
|||||||
void peephole::run(ir::module &mod) {
|
void peephole::run(ir::module &mod) {
|
||||||
ir::builder &builder = mod.get_builder();
|
ir::builder &builder = mod.get_builder();
|
||||||
// keep track of whether any modification was made
|
// keep track of whether any modification was made
|
||||||
bool was_modified = false;
|
std::set<ir::value*> seen;
|
||||||
|
size_t n_seen;
|
||||||
|
|
||||||
// rewrite dots first
|
// rewrite dots first
|
||||||
do{
|
do{
|
||||||
was_modified = false;
|
n_seen = seen.size();
|
||||||
for(ir::function *fn: mod.get_function_list())
|
|
||||||
for(ir::basic_block *block: fn->blocks())
|
|
||||||
for(ir::instruction* i: block->get_inst_list())
|
|
||||||
rewrite_dot(i, builder);
|
|
||||||
}while(was_modified);
|
|
||||||
|
|
||||||
// rewrite other ops
|
|
||||||
do{
|
|
||||||
was_modified = false;
|
|
||||||
for(ir::function *fn: mod.get_function_list())
|
for(ir::function *fn: mod.get_function_list())
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction* i: block->get_inst_list()){
|
for(ir::instruction* i: block->get_inst_list()){
|
||||||
|
if(seen.find(i) != seen.end())
|
||||||
|
continue;
|
||||||
|
bool was_modified = rewrite_dot(i, builder);
|
||||||
|
if(was_modified)
|
||||||
|
seen.insert(i);
|
||||||
|
}
|
||||||
|
}while(seen.size() != n_seen);
|
||||||
|
|
||||||
|
// rewrite other ops
|
||||||
|
seen.clear();
|
||||||
|
do{
|
||||||
|
n_seen = seen.size();
|
||||||
|
for(ir::function *fn: mod.get_function_list())
|
||||||
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
for(ir::instruction* i: block->get_inst_list()){
|
||||||
|
if(seen.find(i) != seen.end())
|
||||||
|
continue;
|
||||||
|
bool was_modified = false;
|
||||||
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||||
|
if(was_modified)
|
||||||
|
seen.insert(i);
|
||||||
}
|
}
|
||||||
}while(was_modified);
|
}while(seen.size() != n_seen);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Reference in New Issue
Block a user