Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -55,8 +55,8 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_tile_ty())
return ty->get_tile_shapes();
if(ty->is_block_ty())
return ty->get_block_shapes();
else
return {1};
}
@@ -95,7 +95,7 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
@@ -119,7 +119,7 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
@@ -229,7 +229,7 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
@@ -251,7 +251,7 @@ std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_ins
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
@@ -317,9 +317,9 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_tile_ty())
if(!v->get_type()->is_block_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_tile_shapes();
auto shapes = v->get_type()->get_block_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
@@ -450,8 +450,8 @@ std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
if(ty->is_block_ty()) {
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
@@ -462,7 +462,7 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8;
int nbytes = std::max<int>(nbits / 8, 1);
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}

View File

@@ -15,7 +15,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
auto in_shapes = arg->get_type()->get_block_shapes();
unsigned current = 0;
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
@@ -29,8 +29,8 @@ void axes::update_graph_reshape(ir::instruction *i) {
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_tile_shapes();
auto res_shapes = reshape->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto res_shapes = reshape->get_type()->get_block_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
@@ -58,10 +58,10 @@ void axes::update_graph_trans(ir::instruction *i) {
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_tile_shapes();
auto shapes = broadcast->get_type()->get_block_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
const auto& op_shapes = op_ty->get_block_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
@@ -70,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_tile_shapes();
auto shapes = dot->get_type()->get_block_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
@@ -83,7 +83,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_tile_ty())
if(!op->get_type()->is_block_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
@@ -96,7 +96,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_tile_ty())
if(!i->get_type()->is_block_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)

View File

@@ -325,9 +325,9 @@ layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps,
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
if(!x->get_type()->is_block_ty())
return;
if(!y->get_type()->is_tile_ty())
if(!y->get_type()->is_block_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
@@ -364,7 +364,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
const auto& shapes = largest->get_type()->get_block_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
@@ -411,7 +411,7 @@ void layouts::run(ir::module &mod) {
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_tile_shapes();
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
@@ -425,8 +425,8 @@ void layouts::run(ir::module &mod) {
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
ir::type::block_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)

103
lib/codegen/pass.cc Normal file
View File

@@ -0,0 +1,103 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
namespace triton {
namespace codegen {
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
// generate llvm code
llvm::LLVMContext ctx;
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
std::unique_ptr<codegen::target> target = dev->make_target();
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
dce.run(ir);
//ir::print(ir, std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
//ir::print(ir, std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu()) {
reassociate.run(ir);
cts.run(ir);
}
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
barriers.run(ir);
// ir::print(ir, std::cout);
isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str());
shared_mem = allocation.allocated_size();
}
} // namespace codegen
} // namespace triton

View File

@@ -150,7 +150,7 @@ generator::generator(analysis::axes *a_axes,
void generator::visit_value(ir::value* v) {
if(!seen_.insert(v).second)
return;
if(v->get_type()->is_tile_ty()){
if(v->get_type()->is_block_ty()){
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
auto double_buffer = layout->get_double_buffer();
// offset
@@ -384,7 +384,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// compute vector width
size_t vec = 1;
if(op->get_type()->is_tile_ty()){
if(op->get_type()->is_block_ty()){
auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]);
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
@@ -407,7 +407,10 @@ void generator::visit_load_inst(ir::load_inst* x){
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
Instruction *then_term;
Instruction *else_term;
builder_->SetInsertPoint(_ret->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
dummy->removeFromParent();
builder_->SetInsertPoint(then_term);
Value* then_ret = load(ptr);
builder_->SetInsertPoint(else_term);
@@ -441,7 +444,7 @@ void generator::visit_store_inst(ir::store_inst * x){
ir::value *val_op = x->get_value_operand();
// vector size
size_t vec = 1;
if(val_op->get_type()->is_tile_ty()){
if(val_op->get_type()->is_block_ty()){
auto ord = ords_.at(x->get_pointer_operand());
size_t aln = alignment_->get(ptr_op, ord[0]);
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
@@ -461,7 +464,10 @@ void generator::visit_store_inst(ir::store_inst * x){
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
builder_->SetInsertPoint(no_op->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
store(val, ptr);
builder_->SetInsertPoint(no_op);
@@ -501,13 +507,15 @@ void generator::visit_splat_inst(ir::splat_inst* x) {
*/
void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
ir::value* op = x->get_operand(0);
const auto& shape = op->get_type()->get_tile_shapes();
const auto& shape = op->get_type()->get_block_shapes();
for(auto out_idx: idxs_.at(x)){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++)
in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
vals_[x][out_idx] = vals_[op][in_idx];
}
// for(size_t i = 0; i < idxs_.at(x).size(); i++)
// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
}
/**
@@ -527,9 +535,9 @@ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
}
/**
* \brief Code Generation for `get_num_program`
* \brief Code Generation for `get_num_programs`
*/
void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
Module *module = builder_->GetInsertBlock()->getModule();
Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
vals_[np][{}] = ret;
@@ -621,7 +629,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
//TODO: clean-up
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){
if(add->get_type()->is_block_ty()){
ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1);
ir::value* msk = add->get_operand(2);
@@ -706,9 +714,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
//TODO: clean-up
void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
// shapes
auto shape_c = C->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
// order
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
@@ -877,7 +885,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
*/
//TODO: clean-up
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes();
const auto& shapes = dot->get_type()->get_block_shapes();
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
@@ -887,8 +895,8 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
fcs[key].push_back(vals_[D][idx]);
};
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
@@ -1059,9 +1067,9 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
* \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
*/
void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
auto shape_c = C->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
@@ -1161,7 +1169,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
ir::value *D = dot->get_operand(2);
Type *c_ty = cvt(D->get_type()->get_scalar_ty());
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
auto A_shapes = A->get_type()->get_block_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
@@ -1236,7 +1244,10 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
// reduce across warps
Value *cond = icmp_eq(warp, i32(0));
Instruction *barrier = add_barrier();
builder_->SetInsertPoint(barrier->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
Value* ret = load(gep(base, thread));
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
@@ -1359,10 +1370,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
* \brief Code Generation for `select`
*/
void generator::visit_select_inst(ir::select_inst* x) {
for(indices_t idx: idxs_.at(x))
for(indices_t idx: idxs_.at(x)){
vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
vals_[x->get_operand(1)][idx],
vals_[x->get_operand(2)][idx]);
}
}
/**
@@ -1370,7 +1382,7 @@ void generator::visit_select_inst(ir::select_inst* x) {
*/
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
ir::value *op = rc->get_operand(0);
ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(rc->get_type()->get_scalar_ty());
// layout
@@ -1435,7 +1447,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_tile_shapes();
auto shapes = x->get_type()->get_block_shapes();
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
@@ -1520,7 +1532,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_tile_shapes();
auto shapes = cts->get_type()->get_block_shapes();
// store to shared
Value *current = nullptr;
@@ -1901,13 +1913,13 @@ void generator::visit_argument(ir::argument* arg) {
void generator::init_idx(ir::value *v) {
idxs_[v].clear();
if(!v->get_type()->is_tile_ty()){
if(!v->get_type()->is_block_ty()){
idxs_[v].push_back({});
return;
}
if(layouts_->get(v)->to_shared())
return;
const auto &shapes = v->get_type()->get_tile_shapes();
const auto &shapes = v->get_type()->get_block_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> ord(rank);

View File

@@ -37,7 +37,7 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_tile_ty())
if(!a->get_type()->is_block_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
if(!a_layout)
@@ -45,7 +45,7 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
for(ir::value* b: bs){
if(!b->get_type()->is_tile_ty())
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
if(!b_layout)
@@ -80,7 +80,7 @@ void membar::transfer(ir::basic_block *block,
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));

View File

@@ -58,7 +58,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, 0) + c -> dot(a, b, c)
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) {
ir::value *lhs = add->get_operand(0);
@@ -131,10 +132,10 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
if(!x)
return false;
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
auto shapes = arg->get_type()->get_block_shapes();
if(shapes[x->get_axis()] == 1){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
x->replace_all_uses_with(new_red);
return true;
}

View File

@@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
@@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());

View File

@@ -40,7 +40,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
// handle retiling
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
auto shapes = op->get_type()->get_tile_shapes();
auto shapes = op->get_type()->get_block_shapes();
ir::value *old_arg = op->get_operand(0);
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
// retile(x + y) = retile(x) + retile(y)
@@ -54,19 +54,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
builder.set_insert_point(op);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
}
}
@@ -84,10 +84,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::value *rlhs = bin_lhs->get_operand(1);
// (cst + x) + y -> cst + (x + y)
if(is_cst(llhs))
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
@@ -95,10 +95,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::value *rrhs = bin_rhs->get_operand(1);
// x + (cst + y) -> cst + (x + y)
if(is_cst(lrhs))
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
// x + (y + cst) -> cst + (x + y)
if(is_cst(rrhs))
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
}
}
// extract constant and non-constant
@@ -166,7 +166,7 @@ void reassociate::run(ir::module &mod) {
ir::value* dyn = infos.at(op).dyn_ptr;
ir::value* cst = *sta->idx_begin();
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
auto shapes = rt->get_type()->get_tile_shapes();
auto shapes = rt->get_type()->get_block_shapes();
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
@@ -202,7 +202,7 @@ void reassociate::run(ir::module &mod) {
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
@@ -235,7 +235,8 @@ void reassociate::run(ir::module &mod) {
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
// re-add the offset
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
phi_sta->set_name( phi->get_name() + "_sta");
phi->replace_all_uses_with(phi_sta);
// remove offset from pz
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
@@ -245,8 +246,8 @@ void reassociate::run(ir::module &mod) {
builder.set_insert_point(*it);
}
ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_tile_ty())
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
if(off->get_type()->is_block_ty())
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));

View File

@@ -11,38 +11,38 @@ namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
ir::builder &builder = mod.get_builder();
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
// ir::builder &builder = mod.get_builder();
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
ir::value* _ptr = ld->get_pointer_operand();
ir::value* _msk = ld->get_mask_operand();
ir::value* _val = ld->get_false_value_operand();
auto ptr = std::find(block->begin(), block->end(), _ptr);
auto msk = std::find(block->begin(), block->end(), _msk);
auto val = std::find(block->begin(), block->end(), _val);
if(ptr == block->end() || msk == block->end() || val == block->end())
continue;
auto it = std::find(block->begin(), block->end(), i);
int dist_ptr = std::distance(ptr, it);
int dist_msk = std::distance(msk, it);
int dist_val = std::distance(val, it);
if(dist_ptr < dist_msk && dist_ptr < dist_val)
builder.set_insert_point(++ptr);
if(dist_msk < dist_ptr && dist_msk < dist_val)
builder.set_insert_point(++msk);
if(dist_val < dist_ptr && dist_val < dist_msk)
builder.set_insert_point(++val);
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
to_replace.push_back(std::make_pair(ld, new_ld));
}
}
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction* i: block->get_inst_list()){
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
// ir::value* _ptr = ld->get_pointer_operand();
// ir::value* _msk = ld->get_mask_operand();
// ir::value* _val = ld->get_false_value_operand();
// auto ptr = std::find(block->begin(), block->end(), _ptr);
// auto msk = std::find(block->begin(), block->end(), _msk);
// auto val = std::find(block->begin(), block->end(), _val);
// if(ptr == block->end() || msk == block->end() || val == block->end())
// continue;
// auto it = std::find(block->begin(), block->end(), i);
// int dist_ptr = std::distance(ptr, it);
// int dist_msk = std::distance(msk, it);
// int dist_val = std::distance(val, it);
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
// builder.set_insert_point(++ptr);
// if(dist_msk < dist_ptr && dist_msk < dist_val)
// builder.set_insert_point(++msk);
// if(dist_val < dist_ptr && dist_val < dist_msk)
// builder.set_insert_point(++val);
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
// to_replace.push_back(std::make_pair(ld, new_ld));
// }
// }
for(auto& x: to_replace)
x.first->replace_all_uses_with(x.second);
// for(auto& x: to_replace)
// x.first->replace_all_uses_with(x.second);
}