History prior to this date belonged to the now-deprecated ISAAC project and was deleted to save space
lib/codegen/selection/generator.cc (new file, 1308 lines)
File diff suppressed because it is too large
lib/codegen/selection/machine_layout.cc (new file, 326 lines)
@@ -0,0 +1,326 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"

namespace triton{
namespace codegen{

using namespace llvm;

inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
  // function
  if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
    Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
    std::vector<Type*> param_tys;
    std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
                   [&ctx](ir::type* t){ return llvm_type(t, ctx);});
    return FunctionType::get(return_ty, param_tys, false);
  }
  // pointer
  if(ty->is_pointer_ty()){
    Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
    unsigned addr_space = ty->get_pointer_address_space();
    return PointerType::get(elt_ty, addr_space);
  }
  // integer
  if(ty->is_integer_ty()){
    unsigned bitwidth = ty->get_integer_bitwidth();
    return IntegerType::get(ctx, bitwidth);
  }
  // primitive types
  switch(ty->get_type_id()){
    case ir::type::VoidTyID:      return Type::getVoidTy(ctx);
    case ir::type::HalfTyID:      return Type::getHalfTy(ctx);
    case ir::type::FloatTyID:     return Type::getFloatTy(ctx);
    case ir::type::DoubleTyID:    return Type::getDoubleTy(ctx);
    case ir::type::X86_FP80TyID:  return Type::getX86_FP80Ty(ctx);
    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
    case ir::type::LabelTyID:     return Type::getLabelTy(ctx);
    case ir::type::MetadataTyID:  return Type::getMetadataTy(ctx);
    case ir::type::TokenTyID:     return Type::getTokenTy(ctx);
    default: break;
  }
  // unknown type
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
  size_t dim = shapes.size();
  std::vector<Value*> result(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder.getInt32(shapes[order[k]]);
    Value *rem = builder.CreateURem(trailing, dim_k);
    trailing = builder.CreateUDiv(trailing, dim_k);
    result[order[k]] = rem;
  }
  result[order[dim - 1]] = trailing;
  return result;
}

inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}

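// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// `delinearize` above splits a flat index into per-dimension coordinates,
// consuming dimensions in the order given by `order` (fastest-varying first).
// The same arithmetic on plain integers, using only the standard library:
#include <vector>

std::vector<int> delinearize_int(int flat, const std::vector<int>& order,
                                 const std::vector<int>& shapes) {
  size_t dim = shapes.size();
  std::vector<int> result(dim);
  for(size_t k = 0; k < dim - 1; k++){
    result[order[k]] = flat % shapes[order[k]];  // coordinate along this dimension
    flat /= shapes[order[k]];                    // strip it off the flat index
  }
  result[order[dim - 1]] = flat;                 // slowest dimension gets the rest
  return result;
}
// e.g. shapes = {4, 8}, order = {1, 0}: flat = 11 -> result = {1, 3}, because
// dimension 1 varies fastest (11 % 8 = 3) and dimension 0 gets 11 / 8 = 1.
// ---------------------------------------------------------------------------
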
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
                                             std::map<ir::value *, Value *>& vmap,
                                             std::map<ir::value *, tile *>& tmap)
  : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {

  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
  // double-buffered
  if(layout_->get_double_buffer()) {
    BasicBlock *current = builder_->GetInsertBlock();
    auto info = *layout_->get_double_buffer();
    ir::phi_node *phi = info.phi;
    BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
    if(parent->empty())
      builder_->SetInsertPoint(parent);
    else
      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
    // create pointers
    ptr_ = builder_->CreatePHI(ptr_ty, 2);
    pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
    pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
    offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
    next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
    builder_->SetInsertPoint(current);
  }
  else{
    size_t offset = alloc_->offset(layout_);
    ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
    ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
  }
}

tile* machine_shared_layout::create(ir::value *v) {
  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  auto double_buffer = layout_->get_double_buffer();
  // offset
  Value *offset = nullptr;
  if(double_buffer && v == double_buffer->phi)
    offset = offset_;
  // base pointer
  Value *ptr = ptr_;
  if(double_buffer && v == double_buffer->latch)
    ptr = next_ptr_;
  else if(double_buffer && v == double_buffer->first)
    ptr = pre_ptr_;
  // create tile
  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}

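// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// For a double-buffered shared layout, the constructor above materializes
// three pointers: `pre_ptr_` (the half written before the loop), a PHI `ptr_`
// (the half read in the current iteration) and `next_ptr_ = ptr_ + offset_`
// (the half written for the next iteration); `create` then hands `pre_ptr_`
// to the first value, `next_ptr_` to the latch value, and `ptr_`/`offset_` to
// the loop phi. The generator is expected to wire the PHIs so the two halves
// alternate. A scalar sketch of one such rotation, assuming a buffer split
// into two halves of `half` elements:
#include <cstdio>

int main() {
  const int half = 4;
  float buffer[2 * half] = {};
  float *pre_ptr = buffer;            // half written by the prologue
  float *ptr     = buffer;            // half read in the current iteration
  int    offset  = half;              // signed distance to the other half
  for(int iter = 0; iter < 4; iter++){
    float *next_ptr = ptr + offset;   // half written for the next iteration
    std::printf("iter %d: read half %d, write half %d\n",
                iter, int(ptr - buffer) / half, int(next_ptr - buffer) / half);
    // rotate: what was just written becomes what is read, the offset flips sign
    ptr    = next_ptr;
    offset = -offset;
  }
  (void)pre_ptr;
  return 0;
}
// ---------------------------------------------------------------------------
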
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                                                       analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                                                       analysis::data_layout *layout)
  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {

}

tile *machine_distributed_layout::create(ir::value *v) {
  Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  size_t rank = shapes.size();
  std::vector<distributed_axis> axes(rank);
  std::vector<int> order(rank);
  // compute axes
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] > 1){
      unsigned x = a_axes_->get(v, d);
      axes[d] = axes_.at(x);
    }
    else{
      axes[d].contiguous = 1;
      axes[d].values = {builder_->getInt32(0)};
    }
  }
  // compute order
  std::iota(order.begin(), order.end(), 0);
  auto cmp = [&](int x, int y) {
    unsigned axx = a_axes_->get(v, x);
    unsigned axy = a_axes_->get(v, y);
    size_t posx = layout_->find_axis(axx);
    size_t posy = layout_->find_axis(axy);
    if(posx < rank && posy < rank)
      return layout_->get_order(posx) < layout_->get_order(posy);
    return false;
  };
  std::sort(order.begin(), order.end(), cmp);

  return new distributed_tile(ty, shapes, order, axes, *builder_);
}

machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
                                             target *tgt, analysis::axes *a_axes,
                                             std::map<unsigned, distributed_axis>& axes,
                                             analysis::mma884_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  const auto& shape = layout->get_shape();
  if(shape.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shape.size() >= 3;

  Value *_1 = builder_->getInt32(1);
  Value *_2 = builder_->getInt32(2);
  Value *_3 = builder_->getInt32(3);
  Value *_4 = builder_->getInt32(4);
  Value *_16 = builder_->getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout->fpw(0);
  unsigned fpw_1 = layout->fpw(1);
  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout->wpt(0);
  unsigned wpt_1 = layout->wpt(1);
  unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
  // mma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // mma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetitions
  unsigned num_rep_0 = shape[0] / hmma_bts_0;
  unsigned num_rep_1 = shape[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_1 * pack_size_1_));

  // quad pair id
  Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
  // quad pair offset
  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offsets
  offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
                                          builder_->CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));

  /* axes */
  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}

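// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// The constructor above is pure lane arithmetic: every instruction it emits
// computes a per-thread offset from the lane id. The same formulas evaluated
// on plain integers, for one assumed configuration (fpw = {2, 2},
// pack_size_0_ = pack_size_1_ = 1, warp offsets taken as 0, i.e. warp 0):
#include <cstdio>

int main() {
  const int fpw_0 = 2, fpw_1 = 2, pack_size_0 = 1, pack_size_1 = 1;
  for(int lane = 0; lane < 32; lane++){
    int in_pair_off_a = ((lane & 16) / 4) * (fpw_0 * pack_size_0);
    int in_pair_off_b = ((lane & 16) / 4) * (fpw_1 * pack_size_1);
    int pair_a_id     = ((lane % 16) / 4) % fpw_0;
    int pair_b_id     = (((lane % 16) / 4) / fpw_0) % fpw_1;
    int pair_a_off    = pair_a_id * 4 * pack_size_0;
    int pair_b_off    = pair_b_id * 4 * pack_size_1;
    int offset_a_i    = pair_a_off + in_pair_off_a;  // warp_offset_i assumed 0
    int offset_b_j    = pair_b_off + in_pair_off_b;  // warp_offset_j assumed 0
    int offset_a_k    = lane & 3;
    std::printf("lane %2d -> a_i %2d, b_j %2d, k %d\n",
                lane, offset_a_i, offset_b_j, offset_a_k);
  }
  return 0;
}
// ---------------------------------------------------------------------------
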
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
                                                 target *tgt,
                                                 analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
                                                 analysis::scanline_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  auto order = layout->get_order();
  const auto& shape = layout->get_shape();
  Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
  // Delinearize
  size_t dim = shape.size();
  std::vector<Value*> thread_id(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
    Value *rem = builder_->CreateURem(full_thread_id, dim_k);
    full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
    thread_id[order[k]] = rem;
  }
  thread_id[order[dim - 1]] = full_thread_id;
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    int nts = layout->nts(k);
    int mts = layout->mts(k);
    std::string str_k = std::to_string(k);
    Value *contiguous_k = builder_->getInt32(nts);
    Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
    unsigned per_block = nts * mts;
    unsigned per_thread = nts * shape[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      unsigned offset = n / nts * per_block + n % nts;
      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
  }
}

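// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// Along one dimension of a scanline layout, each thread owns `nts` contiguous
// elements, the `mts` threads of that dimension cover a block of nts*mts
// elements, and the block is repeated until `shape` is covered. The index
// list built above, reproduced on plain integers for one dimension with an
// assumed example configuration:
#include <cstdio>
#include <vector>

int main() {
  const int nts = 2, mts = 4, shape = 16;
  const int per_block  = nts * mts;               // elements covered per block
  const int per_thread = nts * shape / per_block; // elements owned per thread
  for(int tid = 0; tid < mts; tid++){
    std::vector<int> idx(per_thread);
    for(int n = 0; n < per_thread; n++)
      idx[n] = tid * nts + (n / nts) * per_block + n % nts;
    std::printf("thread %d owns:", tid);
    for(int v : idx) std::printf(" %2d", v);
    std::printf("\n");
  }
  return 0;
}
// thread 0 owns 0 1 8 9, thread 1 owns 2 3 10 11, and so on.
// ---------------------------------------------------------------------------
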
}
}
lib/codegen/selection/machine_value.cc (new file, 214 lines)
@@ -0,0 +1,214 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"

namespace triton{
namespace codegen{

using namespace llvm;

/* Distributed Tile */
void distributed_tile::init_indices() {
  std::vector<size_t> id(axes_.size(), 0);
  // build
  size_t k = 0;
  while(true) {
    indices_t current;
    for(size_t d = 0; d < id.size(); d++)
      current.push_back(axes_[d].values[id[d]]);
    size_t sz = indices_.size();
    indices_[current] = sz;
    values_[current] = nullptr;
    ordered_indices_.push_back(current);
    id[order_[0]]++;
    while(id[order_[k]] == axes_[order_[k]].values.size()){
      if(k == id.size() - 1)
        return;
      id[order_[k++]] = 0;
      id[order_[k]]++;
    }
    k = 0;
  }
}

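// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// `init_indices` enumerates the cartesian product of the per-axis value lists
// odometer-style: the dimension order_[0] ticks fastest and an overflow
// carries into the next dimension listed in `order_`. The same loop over
// plain sizes, with an assumed two-dimensional example:
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> sizes = {2, 3};   // values per axis
  std::vector<int> order = {1, 0};   // axis 1 varies fastest
  std::vector<int> id(sizes.size(), 0);
  size_t k = 0;
  while(true) {
    std::printf("(%d, %d)\n", id[0], id[1]);
    id[order[0]]++;
    while(id[order[k]] == sizes[order[k]]){  // propagate the carry
      if(k == id.size() - 1)
        return 0;                            // last dimension overflowed: done
      id[order[k++]] = 0;
      id[order[k]]++;
    }
    k = 0;
  }
}
// Prints (0,0) (0,1) (0,2) (1,0) (1,1) (1,2): all 2*3 = 6 index tuples.
// ---------------------------------------------------------------------------
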
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
  : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
  init_indices();
}

void distributed_tile::set_value(indices_t idx, Value *x) {
  assert(x->getType() == ty_ && "cannot set a value of different type");
  Value *&result = values_[idx];
  assert(!result && "value cannot be set twice");
  result = x;
}

Value* distributed_tile::get_value(indices_t idx) {
  Value *result = values_.at(idx);
  assert(result && "value has not been set");
  return result;
}

unsigned distributed_tile::get_linear_index(indices_t idx) {
  return indices_[idx];
}

indices_t distributed_tile::get_ordered_indices(unsigned id) {
  return ordered_indices_.at(id);
}


void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
  if(end < 0)
    end = ordered_indices_.size() + end + 1;
  for(unsigned i = start; i < end; i++)
    fn(ordered_indices_[i]);
}

void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
  int rank = sizes.size();
  int len = 1;
  for(int s: sizes)
    len *= s;

  for(int i = 0; i < len; i++){
    indices_t idx(rank);
    int current = i;
    for(int k = 0; k < rank; k++){
      idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
      current = current / sizes[k];
    }
    fn(idx);
  }
}

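// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// The second `for_each` overload above visits a rectangular sub-block of the
// tile: `starts[k]` is the first value index used along axis k and `sizes[k]`
// how many values are taken. Decoding the flat loop counter the same way, on
// plain integers with an assumed 2x3 sub-block:
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> starts = {1, 0};
  std::vector<int> sizes  = {2, 3};
  int len = 1;
  for(int s : sizes) len *= s;
  for(int i = 0; i < len; i++){
    int current = i;
    std::vector<int> idx(sizes.size());
    for(size_t k = 0; k < sizes.size(); k++){
      idx[k] = starts[k] + current % sizes[k];  // axis-k value index
      current /= sizes[k];
    }
    std::printf("(%d, %d)\n", idx[0], idx[1]);
  }
  return 0;
}
// Visits (1,0) (2,0) (1,1) (2,1) (1,2) (2,2): axis 0 varies fastest.
// ---------------------------------------------------------------------------
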
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
  BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
  Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
  if(dyn_cast<Constant>(arg)){
    cst = arg;
    non_cst = _0;
    return;
  }
  if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
    non_cst = arg;
    cst = _0;
    return;
  }
  Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
  Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
  if(cst_lhs && cst_rhs){
    cst = arg;
    non_cst = _0;
  }
  else if(cst_lhs){
    cst = cst_lhs;
    non_cst = bin_op->getOperand(1);
  }
  else if(cst_rhs){
    cst = cst_rhs;
    non_cst = bin_op->getOperand(0);
  }
  else{
    non_cst = arg;
    cst = _0;
  }
}

void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
  non_cst_idx.clear();
  cst_idx.clear();
  for(Value *idx: arg_idx){
    Value *non_cst, *cst;
    extract_constant(idx, non_cst, cst);
    non_cst_idx.push_back(non_cst);
    cst_idx.push_back(cst);
  }
}

Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
                                  const std::vector<int>& perm, const std::vector<int>& order,
                                  indices_t idx) {
  // strides
  std::vector<Value*> strides(order.size());
  strides[order[0]] = builder.getInt32(1);
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
  // result
  Value *result = builder.getInt32(0);
  for(size_t i = 0; i < strides.size(); i++)
    result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
  return result;
}

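// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// `shared_offset` linearizes a multi-dimensional index: the dimension order[0]
// gets stride 1 and each following dimension's stride is the previous stride
// times the previous dimension's shape; `perm` then decides which index
// component is paired with which stride. The same computation on plain
// integers:
#include <cstdio>
#include <vector>

int shared_offset_int(const std::vector<int>& shapes, const std::vector<int>& perm,
                      const std::vector<int>& order, const std::vector<int>& idx) {
  std::vector<int> strides(order.size());
  strides[order[0]] = 1;                          // contiguous dimension
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = strides[order[i-1]] * shapes[order[i-1]];
  int result = 0;
  for(size_t i = 0; i < strides.size(); i++)
    result += idx[perm[i]] * strides[i];          // dot(permuted index, strides)
  return result;
}

int main() {
  // shapes 4x8, dimension 0 contiguous, identity permutation:
  // idx (3, 2) -> 3*1 + 2*4 = 11
  std::printf("%d\n", shared_offset_int({4, 8}, {0, 1}, {0, 1}, {3, 2}));
  return 0;
}
// ---------------------------------------------------------------------------
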
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
  tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
  return_vector_ = false;
  if(perm_.empty()){
    perm_.resize(shapes.size());
    std::iota(perm_.begin(), perm_.end(), 0);
  }
}

void shared_tile::set_value(indices_t idx, Value *value) {
  Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
  unsigned addr_space = ptr->getType()->getPointerAddressSpace();
  ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
  builder_.CreateStore(value, ptr);
}

void shared_tile::set_vector_size(unsigned vector_size) {
  vector_size_ = vector_size;
}

void shared_tile::set_return_mode(bool return_vector){
  return_vector_ = return_vector;
}

Value* shared_tile::get_value(indices_t idx) {
  indices_t non_cst_idx, cst_idx;
  extract_constant(idx, non_cst_idx, cst_idx);
  Value *&base_ptr = ptr_cache_[non_cst_idx];
  unsigned vector_size = vector_size_;
  Type *ty = ty_;
  if(ty->isHalfTy() && (vector_size % 2 == 0)){
    ty = IntegerType::get(ty->getContext(), 32);
    vector_size = vector_size / 2;
  }
  if(base_ptr == nullptr){
//    BasicBlock* store = builder_.GetInsertBlock();
//    if(!non_cst_idx.empty())
//      if(isa<Instruction>(non_cst_idx.front())){
//        builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
//      }
    base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
    if(vector_size_ > 1){
      Type *vec_ty = VectorType::get(ty, vector_size);
      Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
      base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
    }
//    builder_.SetInsertPoint(store);
  }
  Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
  Value *div = offset;
  if(vector_size_ > 1)
    div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
  Value *ptr = builder_.CreateGEP(base_ptr, div);
  Value *result = builder_.CreateLoad(ptr);
  if(return_vector_ == false && vector_size_ > 1) {
    Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
    result = builder_.CreateExtractElement(result, rem);
  }
  return result;
}


}
}
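// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of this commit).
// When `vector_size_ > 1`, `shared_tile::get_value` above loads a whole vector
// and, unless return_vector_ is set, extracts a single element: the scalar
// offset is split into a vector index (offset / vector_size) used for the GEP
// and a lane (offset % vector_size) used for the extract. On plain integers,
// with an assumed vector width of 4:
#include <cstdio>

int main() {
  const int vector_size = 4;
  for(int offset = 0; offset < 8; offset++){
    int vec  = offset / vector_size;   // which vector to load
    int lane = offset % vector_size;   // which element to extract from it
    std::printf("offset %d -> vector %d, lane %d\n", offset, vec, lane);
  }
  return 0;
}
// ---------------------------------------------------------------------------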