[general] some cleaning:

* trans/dot -> peephole
* isel -> added function for tile-level lowering
This commit is contained in:
Philippe Tillet
2019-08-12 21:09:47 -07:00
parent 1400d960a6
commit 4bc5758a22
7 changed files with 660 additions and 712 deletions

View File

@@ -120,36 +120,66 @@ class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::IRBuilder<> Builder;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
private:
// utils
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
Type *make_vector_ty(Type *ty, size_t vector_size);
std::vector<unsigned> extract_shapes(ir::value *v);
// LLVM conversions
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::IRBuilder<> &builder);
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
llvm::Value* llvm_alloc_const(ir::alloc_const *v, llvm::Module *module, llvm::IRBuilder<> &builder);
llvm::ArrayType* llvm_linearized_tile_type(ir::type *ty, llvm::LLVMContext &ctx);
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
Value* llvm_value(ir::value *v, Builder &builder);
Instruction* llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, Builder &builder);
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
// grid construction
void create_grids(std::vector<ir::value *> &grids,
std::map<unsigned, ir::value *> &references,
ir::function *fn);
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
void create_tile(ir::value *v, Builder &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, Value *sh_mem_ptr);
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction
void lower_instruction(ir::instruction *src, Builder &builder);
// lower tile instruction
void lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_scalar_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_elementwise(ir::instruction *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_tile_instruction(ir::instruction *src, Builder &builder);
// lowering
void lower_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
public:
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
void run(ir::module &src, llvm::Module &dst);
void run(ir::module &src, Module &dst);
private:
vmap_t vmap_;
@@ -160,9 +190,9 @@ private:
analysis::shmem::info *buffer_info_;
analysis::alignment_info *axis_info_;
std::map<unsigned, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
llvm::Value *offset_a_i_, *offset_a_k_;
llvm::Value *offset_b_j_, *offset_b_k_;
Value *sh_mem_ptr_;
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned num_packs_0_, num_packs_1_;
unsigned pack_size_0_, pack_size_1_;
};

View File

@@ -1,35 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H
#include <tuple>
#include <vector>
#include <set>
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace analysis{
class tune;
}
namespace transform{
class optimize_dot {
public:
optimize_dot(analysis::tune* params): params_(params) {}
void run(ir::module &mod);
private:
analysis::tune* params_;
};
}
}
}
#endif

View File

@@ -14,6 +14,7 @@ namespace ir {
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
@@ -22,6 +23,8 @@ namespace transform{
class peephole {
private:
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);

View File

@@ -15,9 +15,8 @@
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/transform/dot.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/trans.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
@@ -64,7 +63,6 @@ public:
shmem_barriers(&shmem_allocation, &shmem_info),
vectorize(&tune),
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
optimize_dot(&tune),
dce(),
peephole(),
alignment_info(),
@@ -72,7 +70,6 @@ public:
target_(target) { }
void target_independent(ir::module &module) {
ir::print(module, std::cout);
peephole.run(module);
dce.run(module);
}
@@ -89,7 +86,6 @@ public:
alignment_info.run(module);
vectorize.run(module);
dce.run(module);
ir::print(module, std::cout);
}
codegen::selection selection;
@@ -100,7 +96,6 @@ public:
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize;
codegen::transform::optimize_dot optimize_dot;
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate;

View File

@@ -798,14 +798,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
}
}
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
LLVMContext &ctx = builder.getContext();
Function *fn = block->getParent();
// store
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
ir::value *mask = x->get_mask_operand();
@@ -821,7 +814,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
// std::string offset = "";
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
@@ -834,21 +826,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// builder.CreateCall(iasm, {pred, ptr, scalar});
});
}
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
ptrs->for_each([&](indices_t idx){
builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx));
});
}
else {
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
return;
}
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
ir::instruction *ins = (ir::instruction*)x;
Module *module = fn->getParent();
std::map<indices_t, Value*> partial;
ir::value *op = ins->get_operand(0);
ir::value *op = x->get_operand(0);
distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
unsigned axis = x->get_axis();
@@ -917,15 +912,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(x.first, result);
}
}
return;
}
tile *ti = tmap_[ins];
distributed_tile* result = (distributed_tile*)ti;
if(!ins->get_type()->is_tile_ty())
return;
const auto& shapes = ins->get_type()->get_tile_shapes();
// nv_dynamic_range_idx_inst
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
@@ -934,9 +924,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, res);
});
}
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
void selection::lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value* in = x->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
unsigned pos = result->get_linear_index(out_idx);
@@ -944,15 +935,17 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(out_idx, in_tile->get_value(in_idx));
});
}
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
void selection::lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
result->set_value(idx, llvm_value(x->get_operand(0), builder));
});
}
// broadcast
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value* in = x->get_operand(0);
const auto& in_shapes = in->get_type()->get_tile_shapes();
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
@@ -964,9 +957,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(out_idx, in_tile->get_value(in_idx));
});
}
// vectorize
else if(dynamic_cast<ir::vectorize_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
void selection::lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0));
unsigned vector_size = result->axis(0).contiguous;
std::map<unsigned, Value*> packets;
in->for_each([&](indices_t idx){
@@ -984,31 +978,42 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, packets[id]);
});
}
// copy to shared
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0));
in->for_each([&](indices_t idx){
ti->set_value(idx, in->get_value(idx));
result->set_value(idx, in->get_value(idx));
});
}
// trans
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0));
auto perm = x->get_perm();
in->for_each([&](indices_t idx){
indices_t out_idx(idx.size());
for(size_t i = 0; i < idx.size(); i++)
out_idx[i] = idx[perm[i]->get_value()];
ti->set_value(out_idx, in->get_value(idx));
result->set_value(out_idx, in->get_value(idx));
});
}
else if(buffer_info_->is_shared(ins))
return;
// dot
else if(auto dot = dynamic_cast<ir::dot_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
void selection::lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
}
void selection::lower_scalar_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
}
void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
const auto& shapes = dot->get_type()->get_tile_shapes();
distributed_tile* result = (distributed_tile*)tmap_.at(dot);
Module *module = fn->getParent();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *C = dot->get_operand(2);
bool AT = dot->is_a_trans();
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
@@ -1023,7 +1028,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(ins, 0) == analysis::tune::STRIDED_SCAN) {
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
@@ -1193,16 +1198,18 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
});
}
}
else if(auto *ld = dynamic_cast<ir::masked_load_inst*>(ins)){
void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
// find vector size
ir::value *ptr = ld->get_pointer_operand();
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(ld->get_mask_operand());
distributed_tile *false_values = (distributed_tile*)tmap_.at(ld->get_false_value_operand());
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
std::map<unsigned, Value*> packets;
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
@@ -1263,9 +1270,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = ld->get_pointer_operand();
ir::value *ptr = x->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
@@ -1294,21 +1303,56 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
// element-wise
else {
void selection::lower_elementwise(ir::instruction *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx){
auto value = [&](ir::value *x) {
if(auto *cst = dynamic_cast<ir::constant_int*>(x))
auto value = [&](ir::value *v) {
if(auto *cst = dynamic_cast<ir::constant_int*>(v))
return (Value*)llvm_constant(cst, ctx);
else if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
else if(v->get_type()->is_tile_ty())
return tmap_.at(v)->get_value(idx);
else
return llvm_value(x, builder);
return llvm_value(v, builder);
};
result->set_value(idx, llvm_inst(ins, value, builder));
result->set_value(idx, llvm_inst(x, value, builder));
});
}
}
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
BasicBlock *block = builder.GetInsertBlock();
LLVMContext &ctx = builder.getContext();
Function *fn = block->getParent();
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins))
lower_masked_store(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::store_inst*>(ins))
lower_store(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::downcast_inst*>(ins))
lower_downcast(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reduce_inst*>(ins))
lower_reduce(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins))
lower_dynamic_range_idx(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reshape_inst*>(ins))
lower_reshape(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::splat_inst*>(ins))
lower_splat(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::broadcast_inst*>(ins))
lower_broadcast(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::vectorize_inst*>(ins))
lower_vectorize(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(ins))
lower_copy_to_shared(x, ctx, fn, builder);
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins))
lower_trans(x, ctx, fn, builder);
else if(auto x = dynamic_cast<ir::dot_inst*>(ins))
lower_dot(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::masked_load_inst*>(ins))
lower_masked_load(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::load_inst*>(ins))
lower_load(x, ctx, fn, builder);
else if(!buffer_info_->is_shared(ins))
lower_elementwise(ins, ctx, fn, builder);
}
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {

View File

@@ -1,113 +0,0 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/dot.h"
#include "triton/codegen/analysis/tune.h"
namespace triton {
namespace codegen{
namespace transform{
inline bool is_trans(ir::value *v){
auto *x = dynamic_cast<ir::trans_inst*>(v);
if(!x)
return false;
std::vector<ir::constant_int*> perm = x->get_perm();
std::vector<ir::constant_int*> ref;
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
for(size_t i = 0; i < perm.size(); i++)
ref.push_back(ir::constant_int::get(int32_ty, i));
std::swap(ref[0], ref[1]);
// true is perm == ref
return std::equal(perm.begin(), perm.end(), ref.begin());
}
inline bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
}
return result;
}
void optimize_dot::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// iterate
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(auto dot = dynamic_cast<ir::dot_inst*>(i)){
builder.set_insert_point(i);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
bool trans_a = is_trans(A);
bool trans_b = is_trans(B);
if(!dot->is_a_trans() && !dot->is_b_trans()){
if(is_hmma(dot)){
ir::value *AA = A;
ir::value *BB = B;
if(trans_a){
AA = ((ir::trans_inst*)A)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
if(trans_b){
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
// if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
// std::vector<ir::constant_int*> perm(T->get_perm());
// std::swap(perm[0], perm[1]);
// AA = builder.create_trans(T->get_operand(0), perm);
// T->replace_all_uses_with(AA);
// trans_a = true;
// }
}
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);
}
else{
// dot(op(a), trans(b))
if(trans_b){
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
}
// dot(op(a), b)
if(!trans_b){
// create permutations
size_t size = B->get_type()->get_tile_shapes().size();
std::vector<ir::constant_int*> perm(size);
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
for(size_t i = 0; i < size; i++)
perm[i] = ir::constant_int::get(int32_ty, i);
std::swap(perm[0], perm[1]);
// replace NN -> NT (trans)
ir::value* BB = builder.create_trans(B, perm);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
}
}
}
}
}
}
}
}

View File

@@ -1,6 +1,6 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/trans.h"
#include "triton/codegen/transform/peephole.h"
namespace triton {
namespace codegen{
@@ -70,25 +70,18 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
if(users.size() > 1 || ops.size() > 1)
return false;
ir::value* op = *ops.begin();
// trans(phi) -> phi(trans(), trans()...)
auto* phi = dynamic_cast<ir::phi_node*>(op);
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(op, builder, trans->get_perm());
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi);
return true;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
if(auto dot = dynamic_cast<ir::dot_inst*>(value)){
builder.set_insert_point(value);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
bool trans_a = is_trans(A);
bool trans_b = is_trans(B);
// NN
if(!dot->is_a_trans() && !dot->is_b_trans()){
if(is_hmma(dot)) {
bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
ir::value *A, ir::value *B, ir::value *D){
ir::value *AA = A;
ir::value *BB = B;
if(trans_a){
@@ -119,7 +112,9 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
dot->replace_all_uses_with(dot_atbt);
return true;
}
else{
bool peephole::rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
ir::value *A, ir::value *B, ir::value *D){
// dot(op(a), trans(b))
if(trans_b){
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
@@ -142,12 +137,29 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
dot->replace_all_uses_with(NT);
return true;
}
}
}
}
return false;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
auto dot = dynamic_cast<ir::dot_inst*>(value);
if(!dot)
return false;
builder.set_insert_point(value);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
bool trans_a = is_trans(A);
bool trans_b = is_trans(B);
// only consider dot-nn
if(dot->is_a_trans() || dot->is_b_trans())
return false;
// hmma
if(is_hmma(dot))
return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D);
else
return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D);
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
@@ -190,28 +202,40 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::buil
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
bool was_modified = false;
std::set<ir::value*> seen;
size_t n_seen;
// rewrite dots first
do{
was_modified = false;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list())
rewrite_dot(i, builder);
}while(was_modified);
// rewrite other ops
do{
was_modified = false;
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = rewrite_dot(i, builder);
if(was_modified)
seen.insert(i);
}
}while(seen.size() != n_seen);
// rewrite other ops
seen.clear();
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
if(was_modified)
seen.insert(i);
}
}while(was_modified);
}while(seen.size() != n_seen);
}
}