[codegen] [selection] re-arranged file structure

This commit is contained in:
Philippe Tillet
2019-10-17 12:31:26 -04:00
parent a0182f41dd
commit f4f70db234
10 changed files with 1091 additions and 498 deletions

View File

@@ -178,7 +178,7 @@ public:
class machine_layout_distributed_t: public machine_layout_t {
public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout);
tile* create(ir::value *v);
@@ -186,6 +186,7 @@ public:
Builder *builder_;
target *tgt_;
Type *ty_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_;
};
@@ -195,7 +196,7 @@ class machine_layout_hmma_884_t: public machine_layout_distributed_t {
public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
@@ -209,7 +210,7 @@ class machine_layout_scanline_t: public machine_layout_distributed_t {
public:
machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout);
};
@@ -230,6 +231,7 @@ private:
public:
generator(Module *dst,
analysis::axes *a_axes,
target *tgt,
analysis::layout *layouts,
analysis::align *alignment,
@@ -298,6 +300,7 @@ private:
Module *mod_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
@@ -319,10 +322,10 @@ class selection{
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc,
analysis::align *alignment,
analysis::align *alignment, analysis::axes *axes,
analysis::layout *layouts, target *tgt, unsigned num_warps)
: liveness_(liveness), alloc_(alloc),
alignment_(alignment), layouts_(layouts),
alignment_(alignment), a_axes_(axes), layouts_(layouts),
tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst);
@@ -330,6 +333,7 @@ public:
private:
analysis::liveness *liveness_;
analysis::allocation *alloc_;
analysis::axes *a_axes_;
analysis::layout *layouts_;
analysis::align *alignment_;
target *tgt_;

View File

@@ -0,0 +1,169 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/selection/machine_value.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layout;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
// forward
class machine_layout_t;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void finalize_shared_layout(analysis::layout_shared_t*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
public:
generator(Module *dst,
analysis::axes *a_axes,
target *tgt,
analysis::layout *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_program_inst(ir::get_num_program_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
void visit_layout_scanline(analysis::layout_scanline_t*);
void visit_layout_shared(analysis::layout_shared_t*);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
target *tgt_;
analysis::layout *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *sh_mem_ptr_;
unsigned num_warps_;
std::set<ir::value*> seen_;
};
}
}
#endif

View File

@@ -0,0 +1,138 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
#define _TRITON_SELECTION_MACHINE_LAYOUT_H_
#include <map>
#include "triton/codegen/analysis/layout.h"
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class value;
}
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layout;
}
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
class distributed_axis;
class machine_layout_t;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class machine_layout_t {
public:
virtual tile* create(ir::value *v) = 0;
};
class machine_layout_shared_t: public machine_layout_t {
public:
machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::allocation* alloc_;
Value *&sh_mem_ptr_;
analysis::layout_t* layout_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
Value *offset_;
Value *ptr_;
Value *pre_ptr_;
Value *next_ptr_;
};
class machine_layout_distributed_t: public machine_layout_t {
public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
Type *ty_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_;
};
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
unsigned pack_size_1_;
unsigned num_packs_0_;
unsigned num_packs_1_;
};
class machine_layout_scanline_t: public machine_layout_distributed_t {
public:
machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout);
};
}
}
#endif

View File

@@ -0,0 +1,153 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
#define _TRITON_SELECTION_MACHINE_VALUE_H_
#include <vector>
#include <map>
#include <functional>
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layout;
}
class distributed_axis;
class machine_layout_t;
class tile;
class shared_tile;
class distributed_tile;
class target;
typedef std::vector<Value*> indices_t;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class tile {
protected:
typedef std::vector<unsigned> shapes_t;
public:
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, Value *v) = 0;
virtual Value* get_value(indices_t idx) = 0;
Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
Type *ty_;
shapes_t shapes_;
};
class shared_tile: public tile {
private:
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
public:
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, Value *);
Value* get_ptr_to(indices_t idx);
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
const std::vector<int>& get_perm() { return perm_; }
const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
private:
Value *ptr_;
bool return_vector_;
Builder &builder_;
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
std::vector<int> order_;
std::vector<int> perm_;
};
// Distribtued tile
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::map<indices_t, Value*> values_map_t;
private:
void init_indices();
Type *make_vector_ty(Type *ty, size_t vector_size);
public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
private:
axes_t axes_;
std::vector<int> order_;
indices_map_t indices_;
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
size_t vector_size_;
Builder &builder_;
};
}
}
#endif

View File

@@ -0,0 +1,70 @@
#pragma once
#ifndef _TRITON_SELECTION_SELECTION_H_
#define _TRITON_SELECTION_SELECTION_H_
#include <map>
namespace llvm{
class Module;
class Value;
}
namespace triton{
namespace ir{
class value;
class module;
}
namespace codegen{
// typedef
typedef llvm::Module Module;
typedef llvm::Value Value;
// forward
namespace analysis{
class liveness;
class align;
class allocation;
class axes;
class layout;
}
class target;
class tile;
}
}
namespace triton{
namespace codegen{
// Selection pass
class selection{
typedef std::map<ir::value *, Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc,
analysis::align *alignment, analysis::axes *axes,
analysis::layout *layouts, target *tgt, unsigned num_warps)
: liveness_(liveness), alloc_(alloc),
alignment_(alignment), a_axes_(axes), layouts_(layouts),
tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst);
private:
analysis::liveness *liveness_;
analysis::allocation *alloc_;
analysis::axes *a_axes_;
analysis::layout *layouts_;
analysis::align *alignment_;
target *tgt_;
unsigned num_warps_;
};
}
}
#endif

View File

@@ -1,9 +1,8 @@
#include <numeric>
#include "triton/codegen/selection.h"
#include <numeric>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
@@ -12,12 +11,8 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
@@ -27,198 +22,6 @@ namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// create iteration order
std::vector<size_t> order(id.size());
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
return axes_[x].contiguous > axes_[y].contiguous;
};
std::sort(order.begin(), order.end(), cmp);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order[0]]++;
while(id[order[k]] == axes_[order[k]].values.size()){
if(k == id.size() - 1)
return;
id[order[k++]] = 0;
id[order[k]]++;
}
k = 0;
}
}
llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) {
if(vector_size == 1)
return ty;
return VectorType::get(ty, vector_size);
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) {
vector_size_ = vectorize?ty_->getVectorNumElements():1;
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
for(unsigned i = 0; i < ordered_indices_.size(); i++){
if(i % vector_size_ == 0)
fn(ordered_indices_[i]);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < strides.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
using llop = llvm::Instruction::BinaryOps;
@@ -306,7 +109,7 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
}
Type *type(ir::type *ty, LLVMContext &ctx) {
inline Type *type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = type(tt->get_return_ty(), ctx);
@@ -344,34 +147,17 @@ Type *type(ir::type *ty, LLVMContext &ctx) {
}
/* -------------------
* ---- Init Axes ----
* ------------------- */
// Grid construction
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
switch(attr.get_kind()){
case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias);
case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly);
case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly);
case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value());
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
/* -------------------
* ---- Init Tiles ----
* ------------------- */
bool is_trans(ir::value *v) {
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
@@ -386,34 +172,9 @@ bool is_trans(ir::value *v) {
/* ----------------------------
* ---- Generate LLVM code ----
* ---------------------------- */
inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
switch(attr.get_kind()){
case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias);
case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly);
case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly);
case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value());
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
}
void selection::run(ir::module &src, Module &dst) {
// create tile
generator gen(&dst, tgt_, layouts_, alignment_, alloc_, num_warps_ );
for(ir::alloc_const *x: src.allocs())
gen.visit_value(x);
for(ir::function *fn: src.get_function_list())
gen.visit_value(fn);
}
generator::generator(Module *dst,
analysis::axes *a_axes,
target *tgt,
analysis::layout *layouts,
analysis::align *alignment,
@@ -421,7 +182,7 @@ generator::generator(Module *dst,
unsigned num_warps)
: ctx_(&dst->getContext()), mod_(dst),
builder_(new Builder(dst->getContext())),
tgt_(tgt),
a_axes_(a_axes), tgt_(tgt),
layouts_(layouts), alignment_(alignment), alloc_(alloc),
num_warps_(num_warps) {
@@ -1163,14 +924,12 @@ void generator::visit_function(ir::function* fn) {
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout);
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
@@ -1215,240 +974,6 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
}
machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::layout_t *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = type(layout_->ty, builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->double_buffer) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->double_buffer;
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at(phi->get_parent());
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_layout_shared_t::create(ir::value *v) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = type(layout_->ty, builder_->getContext());
// double-buffered
if(layout_->double_buffer) {
if(v == layout_->double_buffer->phi)
return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_);
if(v == layout_->double_buffer->latch)
return new shared_tile(ty, shapes, order, next_ptr_, *builder_);
return new shared_tile(ty, shapes, order, pre_ptr_, *builder_);
}
else {
return new shared_tile(ty, shapes, order, ptr_, *builder_);
}
}
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), axes_(axes), layout_(layout) {
}
tile *machine_layout_distributed_t::create(ir::value *v) {
Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = layout_->axes[d];
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->order;
const auto& shapes = layout->shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout->nts;
std::vector<int> mts = layout->mts;
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts[k]);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
if(shared->double_buffer) {
auto info = *shared->double_buffer;

View File

@@ -0,0 +1,308 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"
namespace triton{
namespace codegen{
using namespace llvm;
inline Type *type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::layout_t *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = type(layout_->ty, builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->double_buffer) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->double_buffer;
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_layout_shared_t::create(ir::value *v) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = type(layout_->ty, builder_->getContext());
// double-buffered
if(layout_->double_buffer) {
if(v == layout_->double_buffer->phi)
return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_);
if(v == layout_->double_buffer->latch)
return new shared_tile(ty, shapes, order, next_ptr_, *builder_);
return new shared_tile(ty, shapes, order, pre_ptr_, *builder_);
}
else {
return new shared_tile(ty, shapes, order, ptr_, *builder_);
}
}
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_layout_distributed_t::create(ir::value *v) {
Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->order;
const auto& shapes = layout->shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout->nts;
std::vector<int> mts = layout->mts;
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts[k]);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
}
}

View File

@@ -0,0 +1,206 @@
#include <numeric>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"
namespace triton{
namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// create iteration order
std::vector<size_t> order(id.size());
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
return axes_[x].contiguous > axes_[y].contiguous;
};
std::sort(order.begin(), order.end(), cmp);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order[0]]++;
while(id[order[k]] == axes_[order[k]].values.size()){
if(k == id.size() - 1)
return;
id[order[k++]] = 0;
id[order[k]]++;
}
k = 0;
}
}
llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) {
if(vector_size == 1)
return ty;
return VectorType::get(ty, vector_size);
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) {
vector_size_ = vectorize?ty_->getVectorNumElements():1;
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
for(unsigned i = 0; i < ordered_indices_.size(); i++){
if(i % vector_size_ == 0)
fn(ordered_indices_[i]);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < strides.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
}
}

View File

@@ -0,0 +1,20 @@
#include <numeric>
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection/generator.h"
#include "triton/ir/module.h"
namespace triton{
namespace codegen{
using namespace llvm;
void selection::run(ir::module &src, Module &dst) {
generator gen(&dst, a_axes_, tgt_, layouts_, alignment_, alloc_, num_warps_ );
for(ir::alloc_const *x: src.allocs())
gen.visit_alloc_const(x);
for(ir::function *fn: src.get_function_list())
gen.visit_function(fn);
}
}
}

View File

@@ -217,7 +217,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::reassociate reassociate(&align);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts;
codegen::selection selection(&liveness, &allocation, &align, &layouts, target.get(), opt.num_warps);
codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps);
// run passes
// ir::print(module, std::cout);
peephole.run(module);