Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C with a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See the documentation for more information on the new API.
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -55,8 +55,8 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
|
||||
std::vector<unsigned> align::get_shapes(ir::value *v) {
|
||||
ir::type *ty = v->get_type();
|
||||
if(ty->is_tile_ty())
|
||||
return ty->get_tile_shapes();
|
||||
if(ty->is_block_ty())
|
||||
return ty->get_block_shapes();
|
||||
else
|
||||
return {1};
|
||||
}
|
||||
@@ -95,7 +95,7 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_cst = populate_is_constant(op);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
@@ -119,7 +119,7 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_cst = populate_is_constant(op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++)
|
||||
if(op_shapes[d] == 1)
|
||||
@@ -229,7 +229,7 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
@@ -251,7 +251,7 @@ std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_ins
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
if(op_shapes[d] == 1)
|
||||
@@ -317,9 +317,9 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
if(!v->get_type()->is_block_ty())
|
||||
return add_to_cache(v, {1}, max_contiguous_);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
auto shapes = v->get_type()->get_block_shapes();
|
||||
if(dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
@@ -450,8 +450,8 @@ std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
ir::type* ty = v->get_type();
|
||||
if(ty->is_tile_ty()) {
|
||||
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
|
||||
if(ty->is_block_ty()) {
|
||||
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
@@ -462,7 +462,7 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
if(attr.get_kind() == ir::aligned){
|
||||
ir::type* ty = x->get_type()->get_pointer_element_ty();
|
||||
int nbits = ty->get_primitive_size_in_bits();
|
||||
int nbytes = nbits / 8;
|
||||
int nbytes = std::max<int>(nbits / 8, 1);
|
||||
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
|
||||
}
|
||||
}
|
||||
|
@@ -15,7 +15,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
|
||||
auto* red = static_cast<ir::reduce_inst*>(i);
|
||||
unsigned axis = red->get_axis();
|
||||
ir::value *arg = red->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
||||
auto in_shapes = arg->get_type()->get_block_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned d = 0; d < in_shapes.size(); d++){
|
||||
if(d == axis)
|
||||
@@ -29,8 +29,8 @@ void axes::update_graph_reshape(ir::instruction *i) {
|
||||
// operands
|
||||
ir::value *op = reshape->get_operand(0);
|
||||
// shapes
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto res_shapes = reshape->get_type()->get_tile_shapes();
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto res_shapes = reshape->get_type()->get_block_shapes();
|
||||
// construct edges
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
@@ -58,10 +58,10 @@ void axes::update_graph_trans(ir::instruction *i) {
|
||||
|
||||
void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
|
||||
auto shapes = broadcast->get_type()->get_tile_shapes();
|
||||
auto shapes = broadcast->get_type()->get_block_shapes();
|
||||
ir::value *op = broadcast->get_operand(0);
|
||||
ir::type *op_ty = op->get_type();
|
||||
const auto& op_shapes = op_ty->get_tile_shapes();
|
||||
const auto& op_shapes = op_ty->get_block_shapes();
|
||||
// add edge between non-broadcast axes
|
||||
for(unsigned d = 0; d < shapes.size(); d ++)
|
||||
if(op_shapes[d] == shapes[d])
|
||||
@@ -70,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
|
||||
void axes::update_graph_dot(ir::instruction *i) {
|
||||
auto *dot = static_cast<ir::dot_inst*>(i);
|
||||
auto shapes = dot->get_type()->get_tile_shapes();
|
||||
auto shapes = dot->get_type()->get_block_shapes();
|
||||
ir::value *A = dot->get_operand(0);
|
||||
ir::value *B = dot->get_operand(1);
|
||||
ir::value *D = dot->get_operand(2);
|
||||
@@ -83,7 +83,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
ir::value *op = i->get_operand(0);
|
||||
if(!op->get_type()->is_tile_ty())
|
||||
if(!op->get_type()->is_block_ty())
|
||||
return;
|
||||
auto rank = op->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
@@ -96,7 +96,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
|
||||
}
|
||||
|
||||
void axes::update_graph_no_edge(ir::instruction *i) {
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
if(!i->get_type()->is_block_ty())
|
||||
return;
|
||||
auto rank = i->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
|
@@ -325,9 +325,9 @@ layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps,
|
||||
void layouts::connect(ir::value *x, ir::value *y) {
|
||||
if(x == y)
|
||||
return;
|
||||
if(!x->get_type()->is_tile_ty())
|
||||
if(!x->get_type()->is_block_ty())
|
||||
return;
|
||||
if(!y->get_type()->is_tile_ty())
|
||||
if(!y->get_type()->is_block_ty())
|
||||
return;
|
||||
std::vector<int> x_axes = axes_->get(x);
|
||||
std::vector<int> y_axes = axes_->get(y);
|
||||
@@ -364,7 +364,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
|
||||
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_tile_shapes();
|
||||
const auto& shapes = largest->get_type()->get_block_shapes();
|
||||
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
|
||||
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
|
||||
dynamic_cast<ir::masked_load_async_inst*>(v);
|
||||
@@ -411,7 +411,7 @@ void layouts::run(ir::module &mod) {
|
||||
ir::value *arg = red->get_operand(0);
|
||||
unsigned axis = red->get_axis();
|
||||
// shape
|
||||
auto shapes = arg->get_type()->get_tile_shapes();
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
scanline_layout *layout = get(arg)->to_scanline();
|
||||
shapes[axis] = layout->mts(axis);
|
||||
// create layout
|
||||
@@ -425,8 +425,8 @@ void layouts::run(ir::module &mod) {
|
||||
if(!in_layout || !out_layout)
|
||||
return;
|
||||
id++;
|
||||
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
|
||||
ir::type::tile_shapes_t shape(in_shape.size());
|
||||
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
|
||||
ir::type::block_shapes_t shape(in_shape.size());
|
||||
size_t ld = out_layout->get_order(0);
|
||||
shape[ld] = in_shape[ld];
|
||||
for(size_t k = 0; k < in_shape.size(); k++)
|
||||
|
103
lib/codegen/pass.cc
Normal file
103
lib/codegen/pass.cc
Normal file
@@ -0,0 +1,103 @@
|
||||
#include "triton/codegen/pass.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager here!
|
||||
/**
 * \brief Runs the code-generation passes on `ir` and creates a driver module and kernel from the result
 *
 * The passes are created and run by hand in a fixed order; several analyses
 * (align, axes, layouts) are re-run after the transformations that precede them.
 *
 * \param ir          Triton-IR module to compile; the name of its first function
 *                    names both the LLVM module and the created kernel
 * \param dev         device used to make the codegen target and to create the driver module
 * \param num_warps   forwarded to the layouts analysis and to the code generator
 * \param mod         [out] receives the result of `driver::module::create`
 * \param ker         [out] receives the result of `driver::kernel::create`
 * \param shared_mem  [out] receives `allocation.allocated_size()` once all passes have run
 */
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
                            driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
  // generate llvm code
  llvm::LLVMContext ctx;
  std::string name = ir.get_function_list()[0]->get_name();
  std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
  // optimizations
  std::unique_ptr<codegen::target> target = dev->make_target();
  // Only query sm() when as_nvidia() yields a non-null target. The is_gpu()
  // checks below show that other kinds of targets can reach this function.
  // NOTE(review): presumably as_nvidia() is null for non-NVIDIA targets -- confirm.
  bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
  // create passes
  codegen::analysis::align align;
  codegen::analysis::axes axes;
  codegen::transform::cts cts(cts_use_async);
  codegen::transform::pipeline pipeline(cts_use_async);
  codegen::transform::disassociate disassociate;
  codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
  codegen::analysis::liveness liveness(&layouts);
  codegen::analysis::swizzle swizzle(&layouts, target.get());
  codegen::analysis::allocation allocation(&liveness);
  codegen::transform::membar barriers(&liveness, &layouts, &allocation);
  codegen::transform::dce dce;
  codegen::transform::peephole peephole(target.get(), &layouts);
  codegen::transform::reassociate reassociate;
  codegen::transform::coalesce coalesce(&align, &layouts);
  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
  // run passes
  dce.run(ir);
  //ir::print(ir, std::cout);
  peephole.run(ir);
  dce.run(ir);
  pipeline.run(ir);
  dce.run(ir);
  //ir::print(ir, std::cout);
  disassociate.run(ir);
  dce.run(ir);
  align.run(ir);
  axes.run(ir);
  layouts.run(ir);
  peephole.run(ir);
  dce.run(ir);
  if (target->is_gpu())
    cts.run(ir);
  align.run(ir);
  axes.run(ir);
  layouts.run(ir);
  coalesce.run(ir);
  dce.run(ir);
  align.run(ir);
  dce.run(ir);
  if (target->is_gpu()) {
    reassociate.run(ir);
    cts.run(ir);
  }
  dce.run(ir);
  align.run(ir);
  axes.run(ir);
  layouts.run(ir);
  peephole.run(ir);
  dce.run(ir);
  align.run(ir);
  axes.run(ir);
  layouts.run(ir);
  swizzle.run(ir);
  liveness.run(ir);
  allocation.run(ir);
  barriers.run(ir);
  // ir::print(ir, std::cout);
  // lower to LLVM-IR, then hand the LLVM module over to the driver
  isel.visit(ir, *llvm);
  mod = driver::module::create(dev, std::move(llvm));
  ker = driver::kernel::create(&*mod, name.c_str());
  shared_mem = allocation.allocated_size();
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
@@ -150,7 +150,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
void generator::visit_value(ir::value* v) {
|
||||
if(!seen_.insert(v).second)
|
||||
return;
|
||||
if(v->get_type()->is_tile_ty()){
|
||||
if(v->get_type()->is_block_ty()){
|
||||
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
|
||||
auto double_buffer = layout->get_double_buffer();
|
||||
// offset
|
||||
@@ -384,7 +384,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_tile_ty()){
|
||||
if(op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(op);
|
||||
size_t aln = alignment_->get(op, ord[0]);
|
||||
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
|
||||
@@ -407,7 +407,10 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
|
||||
Instruction *then_term;
|
||||
Instruction *else_term;
|
||||
builder_->SetInsertPoint(_ret->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(then_term);
|
||||
Value* then_ret = load(ptr);
|
||||
builder_->SetInsertPoint(else_term);
|
||||
@@ -441,7 +444,7 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
ir::value *val_op = x->get_value_operand();
|
||||
// vector size
|
||||
size_t vec = 1;
|
||||
if(val_op->get_type()->is_tile_ty()){
|
||||
if(val_op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(x->get_pointer_operand());
|
||||
size_t aln = alignment_->get(ptr_op, ord[0]);
|
||||
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
|
||||
@@ -461,7 +464,10 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
if(mx){
|
||||
Value *msk = vals_[mx->get_mask_operand()][idx];
|
||||
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
|
||||
builder_->SetInsertPoint(no_op->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
store(val, ptr);
|
||||
builder_->SetInsertPoint(no_op);
|
||||
@@ -501,13 +507,15 @@ void generator::visit_splat_inst(ir::splat_inst* x) {
|
||||
*/
|
||||
void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
|
||||
ir::value* op = x->get_operand(0);
|
||||
const auto& shape = op->get_type()->get_tile_shapes();
|
||||
const auto& shape = op->get_type()->get_block_shapes();
|
||||
for(auto out_idx: idxs_.at(x)){
|
||||
indices_t in_idx = out_idx;
|
||||
for(size_t k = 0; k < in_idx.size(); k++)
|
||||
in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
|
||||
vals_[x][out_idx] = vals_[op][in_idx];
|
||||
}
|
||||
// for(size_t i = 0; i < idxs_.at(x).size(); i++)
|
||||
// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -527,9 +535,9 @@ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `get_num_program`
|
||||
* \brief Code Generation for `get_num_programs`
|
||||
*/
|
||||
void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
|
||||
void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
|
||||
Module *module = builder_->GetInsertBlock()->getModule();
|
||||
Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
|
||||
vals_[np][{}] = ret;
|
||||
@@ -621,7 +629,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
//TODO: clean-up
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
|
||||
if(add->get_type()->is_tile_ty()){
|
||||
if(add->get_type()->is_block_ty()){
|
||||
ir::value* ptr = add->get_operand(0);
|
||||
ir::value* val = add->get_operand(1);
|
||||
ir::value* msk = add->get_operand(2);
|
||||
@@ -706,9 +714,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
//TODO: clean-up
|
||||
void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
|
||||
// shapes
|
||||
auto shape_c = C->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_c = C->get_type()->get_block_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
// order
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
@@ -877,7 +885,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
*/
|
||||
//TODO: clean-up
|
||||
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
|
||||
const auto& shapes = dot->get_type()->get_tile_shapes();
|
||||
const auto& shapes = dot->get_type()->get_block_shapes();
|
||||
|
||||
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
|
||||
|
||||
@@ -887,8 +895,8 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
|
||||
fcs[key].push_back(vals_[D][idx]);
|
||||
};
|
||||
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
|
||||
@@ -1059,9 +1067,9 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
|
||||
* \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
|
||||
*/
|
||||
void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
|
||||
auto shape_c = C->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_c = C->get_type()->get_block_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
|
||||
@@ -1161,7 +1169,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
ir::value *D = dot->get_operand(2);
|
||||
Type *c_ty = cvt(D->get_type()->get_scalar_ty());
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
auto A_shapes = A->get_type()->get_block_shapes();
|
||||
size_t red_axis = 1;
|
||||
unsigned NK = A_shapes[red_axis];
|
||||
bool is_outer = NK == 1;
|
||||
@@ -1236,7 +1244,10 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
@@ -1359,10 +1370,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
* \brief Code Generation for `select`
|
||||
*/
|
||||
void generator::visit_select_inst(ir::select_inst* x) {
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
|
||||
vals_[x->get_operand(1)][idx],
|
||||
vals_[x->get_operand(2)][idx]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1370,7 +1382,7 @@ void generator::visit_select_inst(ir::select_inst* x) {
|
||||
*/
|
||||
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
ir::value *op = rc->get_operand(0);
|
||||
ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
|
||||
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = cvt(rc->get_type()->get_scalar_ty());
|
||||
// layout
|
||||
@@ -1435,7 +1447,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
||||
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
|
||||
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
auto shapes = x->get_type()->get_block_shapes();
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
std::map<std::pair<int, int>, Value*> tmp;
|
||||
@@ -1520,7 +1532,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
auto shapes = cts->get_type()->get_tile_shapes();
|
||||
auto shapes = cts->get_type()->get_block_shapes();
|
||||
|
||||
// store to shared
|
||||
Value *current = nullptr;
|
||||
@@ -1901,13 +1913,13 @@ void generator::visit_argument(ir::argument* arg) {
|
||||
|
||||
void generator::init_idx(ir::value *v) {
|
||||
idxs_[v].clear();
|
||||
if(!v->get_type()->is_tile_ty()){
|
||||
if(!v->get_type()->is_block_ty()){
|
||||
idxs_[v].push_back({});
|
||||
return;
|
||||
}
|
||||
if(layouts_->get(v)->to_shared())
|
||||
return;
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
const auto &shapes = v->get_type()->get_block_shapes();
|
||||
size_t rank = shapes.size();
|
||||
std::vector<distributed_axis> axes(rank);
|
||||
std::vector<int> ord(rank);
|
||||
|
@@ -37,7 +37,7 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
|
||||
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
|
||||
val_set_t ret;
|
||||
for(ir::value* a: as){
|
||||
if(!a->get_type()->is_tile_ty())
|
||||
if(!a->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
|
||||
if(!a_layout)
|
||||
@@ -45,7 +45,7 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
|
||||
int a_start = alloc_->offset(a_layout);
|
||||
int a_end = a_start + a_layout->get_size();
|
||||
for(ir::value* b: bs){
|
||||
if(!b->get_type()->is_tile_ty())
|
||||
if(!b->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
|
||||
if(!b_layout)
|
||||
@@ -80,7 +80,7 @@ void membar::transfer(ir::basic_block *block,
|
||||
// Get shared memory reads
|
||||
std::set<ir::value*> read;
|
||||
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
|
||||
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
|
||||
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
|
||||
// RAW (async)
|
||||
val_set_t tmp;
|
||||
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
|
||||
|
@@ -58,7 +58,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
|
||||
}
|
||||
|
||||
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, 0) + c -> dot(a, b, c)
|
||||
// dot(a, b, c) + d -> dot(a, b, c + d)
|
||||
// d + dot(a, b, c) -> dot(a, b, c + d)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(add && add->get_op() == ir::binary_op_t::FAdd) {
|
||||
ir::value *lhs = add->get_operand(0);
|
||||
@@ -131,10 +132,10 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
if(!x)
|
||||
return false;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto shapes = arg->get_type()->get_tile_shapes();
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
if(shapes[x->get_axis()] == 1){
|
||||
builder.set_insert_point(x);
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
|
||||
x->replace_all_uses_with(new_red);
|
||||
return true;
|
||||
}
|
||||
|
@@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
/**
 * \brief Re-creates `v` and, recursively, the instructions it is computed from
 *
 * Values that are not instructions are returned unchanged. A phi node is
 * replaced by its incoming value at position `phi_idx`. Any other instruction
 * is cloned, the clone's operands are set to the rematerialized operands of
 * the original, and the clone is inserted through `builder`.
 *
 * \param builder  builder that receives every cloned instruction
 * \param v        value to rematerialize
 * \param phi_idx  index of the incoming value selected for each phi node encountered
 * \return `v` itself, a phi incoming value, or the newly inserted clone
 */
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
  // anything that is not an instruction is used as-is
  ir::instruction* inst = dynamic_cast<ir::instruction*>(v);
  if(inst == nullptr)
    return v;
  // phi nodes collapse to the requested incoming value
  ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v);
  if(phi != nullptr)
    return phi->get_incoming_value(phi_idx);

  // operands are rematerialized first, before the instruction itself is cloned
  std::vector<ir::value*> remat_ops;
  for(ir::value* operand: inst->ops())
    remat_ops.push_back(rematerialize(builder, operand, phi_idx));

  // clone, rewire the clone onto the rematerialized operands, then insert it
  ir::instruction* copy = inst->clone();
  size_t pos = 0;
  for(ir::value* remat_op: remat_ops)
    copy->set_operand(pos++, remat_op);
  builder.insert(copy);
  return copy;
}
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
@@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) {
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
|
||||
false_value = masked_load->get_false_value_operand();
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
|
||||
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
|
@@ -40,7 +40,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
|
||||
// handle retiling
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_tile_shapes();
|
||||
auto shapes = op->get_type()->get_block_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
@@ -54,19 +54,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -84,10 +84,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(is_cst(llhs))
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
@@ -95,10 +95,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(is_cst(lrhs))
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
|
||||
}
|
||||
}
|
||||
// extract constant and non-constant
|
||||
@@ -166,7 +166,7 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
||||
ir::value* cst = *sta->idx_begin();
|
||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
||||
auto shapes = rt->get_type()->get_tile_shapes();
|
||||
auto shapes = rt->get_type()->get_block_shapes();
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
@@ -202,7 +202,7 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
@@ -235,7 +235,8 @@ void reassociate::run(ir::module &mod) {
|
||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
||||
// re-add the offset
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
|
||||
phi_sta->set_name( phi->get_name() + "_sta");
|
||||
phi->replace_all_uses_with(phi_sta);
|
||||
// remove offset from pz
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
||||
@@ -245,8 +246,8 @@ void reassociate::run(ir::module &mod) {
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *_0 = builder.get_int32(0);
|
||||
if(off->get_type()->is_tile_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
|
||||
if(off->get_type()->is_block_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
|
||||
ir::value *neg_off = builder.create_sub(_0, off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
|
@@ -11,38 +11,38 @@ namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void reorder::run(ir::module& mod){
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
// ir::builder &builder = mod.get_builder();
|
||||
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
ir::value* _ptr = ld->get_pointer_operand();
|
||||
ir::value* _msk = ld->get_mask_operand();
|
||||
ir::value* _val = ld->get_false_value_operand();
|
||||
auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
auto val = std::find(block->begin(), block->end(), _val);
|
||||
if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
continue;
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
int dist_ptr = std::distance(ptr, it);
|
||||
int dist_msk = std::distance(msk, it);
|
||||
int dist_val = std::distance(val, it);
|
||||
if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
builder.set_insert_point(++ptr);
|
||||
if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
builder.set_insert_point(++msk);
|
||||
if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
builder.set_insert_point(++val);
|
||||
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
}
|
||||
}
|
||||
// for(ir::function *fn: mod.get_function_list())
|
||||
// for(ir::basic_block *block: fn->blocks())
|
||||
// for(ir::instruction* i: block->get_inst_list()){
|
||||
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
// ir::value* _ptr = ld->get_pointer_operand();
|
||||
// ir::value* _msk = ld->get_mask_operand();
|
||||
// ir::value* _val = ld->get_false_value_operand();
|
||||
// auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
// auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
// auto val = std::find(block->begin(), block->end(), _val);
|
||||
// if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
// continue;
|
||||
// auto it = std::find(block->begin(), block->end(), i);
|
||||
// int dist_ptr = std::distance(ptr, it);
|
||||
// int dist_msk = std::distance(msk, it);
|
||||
// int dist_val = std::distance(val, it);
|
||||
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
// builder.set_insert_point(++ptr);
|
||||
// if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
// builder.set_insert_point(++msk);
|
||||
// if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
// builder.set_insert_point(++val);
|
||||
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
// to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
// }
|
||||
// }
|
||||
|
||||
for(auto& x: to_replace)
|
||||
x.first->replace_all_uses_with(x.second);
|
||||
// for(auto& x: to_replace)
|
||||
// x.first->replace_all_uses_with(x.second);
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user