[CORE] Fixed several issues that arose in the development of the torch-blocksparse package:
* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
commit 3304629de9 (parent ac26fbdc1f)
committed by Philippe Tillet
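The headline item is the warp-shuffle reduction path. As an illustrative CUDA sketch (not part of this commit; the function names and the use of a sum are assumptions), the butterfly pattern that the generated shfl.sync.bfly.b32 implements behaves like this:

    // Butterfly (XOR) reduction over the 32 lanes of a warp: after the loop,
    // every lane holds the sum of all 32 lane values.
    __device__ float warp_sum(float v) {
      for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffffu, v, offset);
      return v;
    }

    // Launch with one warp (32 threads): out receives the sum of in[0..31].
    __global__ void reduce_one_warp(const float* in, float* out) {
      float v = warp_sum(in[threadIdx.x]);
      if (threadIdx.x == 0)
        *out = v;
    }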
@@ -2,7 +2,7 @@
 #include "triton/ir/utils.h"
 #include "triton/ir/instructions.h"
 #include "triton/ir/type.h"
-
+#include <iostream>


 namespace triton{
@@ -16,7 +16,9 @@ namespace analysis{
  * Helper Functions *
  * -------------------------------- */

-inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
+inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
+  unsigned lo = std::min(a, b);
+  unsigned hi = std::max(a, b);
   return std::min(std::max(x, lo), hi);
 }

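The reworked clamp normalizes its bounds first, so callers no longer need to pass them in order. A small self-contained check of that behavior (hypothetical test values, not from the repository):

    #include <algorithm>
    #include <cassert>

    inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
      unsigned lo = std::min(a, b);   // bounds may arrive in either order
      unsigned hi = std::max(a, b);
      return std::min(std::max(x, lo), hi);
    }

    int main() {
      assert(clamp(5, 2, 8) == 5);  // inside the range
      assert(clamp(5, 8, 2) == 5);  // same range, bounds swapped
      assert(clamp(1, 8, 2) == 2);  // clamped up to the lower bound
    }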
@@ -97,7 +99,9 @@ data_layout::data_layout(id_t id,
   order_.resize(axes_.size());
   std::iota(order_.begin(), order_.end(), 0);
   auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
-    return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
+    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
+    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
+    return xx < yy;
   });
   if(*largest){
     auto max_contiguous = align->contiguous(*largest);
@@ -201,8 +205,9 @@ scanline_layout::scanline_layout(size_t num_warps,
   for(size_t d = 0; d < shape_.size(); d++)
     effective_num_threads *= mts_[d];

-  if(num_warps * 32 != effective_num_threads)
-    throw std::runtime_error("cannot create a kernel with this amount of warps");
+//  std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
+//  if(num_warps * 32 != effective_num_threads)
+//    throw std::runtime_error("cannot create a kernel with this amount of warps");
 }


@@ -355,8 +360,9 @@ void layouts::make_graph(ir::instruction *i) {
 void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
   auto cmp = [](ir::value* x, ir::value *y) {
-    return x->get_type()->get_tile_ranks1() <
-           y->get_type()->get_tile_ranks1();
+    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
+    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
+    return xx < yy;
   };
   std::vector<ir::value*> lvalue = values;
   std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
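Both comparators now order values by the pair (tile rank, number of elements) instead of rank alone, so ties on rank are broken by tile size. std::pair's lexicographic operator< gives exactly that ordering; a minimal sketch with made-up sizes:

    #include <cassert>
    #include <utility>

    int main() {
      std::pair<int, int> a{2, 1024};   // rank-2 tile, 1024 elements
      std::pair<int, int> b{2, 4096};   // rank-2 tile, 4096 elements
      std::pair<int, int> c{1, 65536};  // rank-1 tile, even more elements
      assert(a < b);           // equal rank: the larger tile wins
      assert(c < a && c < b);  // lower rank loses regardless of element count
    }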
@@ -402,11 +408,8 @@ void layouts::run(ir::module &mod) {
     unsigned axis = red->get_axis();
     // shape
     auto shapes = arg->get_type()->get_tile_shapes();
-    unsigned shape_ax = shapes[axis];
     scanline_layout *layout = get(arg)->to_scanline();
-    unsigned per_thread = layout->nts(axis);
-    unsigned depth = shape_ax / per_thread;
-    shapes[axis] = depth;
+    shapes[axis] = layout->mts(axis);
     // create layout
     layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
     tmp_[red] = id;
@@ -196,8 +196,9 @@ void generator::visit_value(ir::value* v) {
   BasicBlock *current = builder_->GetInsertBlock();
   auto *inst = dynamic_cast<ir::instruction*>(v);
   if(inst && !dynamic_cast<ir::phi_node*>(v))
-    for(ir::value *op: inst->ops())
+    for(ir::value *op: inst->ops()){
       visit_value(op);
+    }
   // change insert point for phi node
   builder_->SetInsertPoint(current);
   auto *phi = dynamic_cast<ir::phi_node*>(v);
@@ -547,6 +548,24 @@ void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
   vmap_[np] = ret;
 }

+void generator::visit_exp_inst(ir::exp_inst* x){
+  distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0));
+//  Function *fn = builder_->GetInsertBlock()->getParent();
+//  Module *module = fn->getParent();
+//  Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
+//  Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
+  Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
+
+  FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
+  InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
+
+
+  for_each(x, [&](indices_t idx){
+    Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
+    set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
+  });
+}
+
 void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
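visit_exp_inst lowers the new exponential through the PTX ex2.approx instruction using exp(v) = 2^(v * log2(e)), which is where the 1.4426950408889634 constant comes from. A host-side sketch of that identity (standard <cmath>, not part of the diff):

    #include <cassert>
    #include <cmath>

    int main() {
      const double log2e = 1.4426950408889634;  // log2(e)
      double v = 3.7;
      // exp(v) == 2^(v * log2(e)); the hardware ex2.approx.ftz.f32 supplies the 2^x part.
      assert(std::fabs(std::exp(v) - std::exp2(v * log2e)) < 1e-9);
    }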
@@ -587,6 +606,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
   BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
   BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+  tgt_->add_memfence(module, *builder_);
   tgt_->add_barrier(module, *builder_);
   builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
   builder_->SetInsertPoint(tid_0_bb);
   builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val,
@@ -825,24 +845,111 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   ir::value *arg = x->get_operand(0);
   distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
   ir::reduce_inst::op_t op = x->get_op();
+  unsigned axis = x->get_axis();
+
+  Type *fp32_ty = builder_->getFloatTy();
+  FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
+  InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
+  InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);
+
   auto accumulate = [&](Value* x, Value *y) -> Value* {
     switch(op) {
     case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
     case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
-    case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
-    case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
+    case ir::reduce_inst::MAX:{
+      if(x->getType()->isIntegerTy())
+        return builder_->CreateSelect(builder_->CreateICmpSGE(x, y), x, y);
+      else
+        return builder_->CreateMaxNum(x, y);
+    }
+    case ir::reduce_inst::MIN:{
+      if(x->getType()->isIntegerTy())
+        return builder_->CreateSelect(builder_->CreateICmpSLE(x, y), x, y);
+      else
+        return builder_->CreateMinNum(x, y);
+    }
     case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
     case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
-    case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
-    case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
-    default: break;
+    case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
+    case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
+    default: assert(false); return nullptr;
     }
     assert(false);
     return nullptr;
   };

+  Value *neutral;
+  switch(op) {
+  case ir::reduce_inst::ADD: neutral = builder_->getInt32(0); break;
+  case ir::reduce_inst::SUB: neutral = builder_->getInt32(0); break;
+  case ir::reduce_inst::MAX: neutral = builder_->getInt32(INT32_MIN); break;
+  case ir::reduce_inst::MIN: neutral = builder_->getInt32(INT32_MAX); break;
+  case ir::reduce_inst::FADD: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break;
+  case ir::reduce_inst::FSUB: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break;
+  case ir::reduce_inst::FMAX: neutral = ConstantFP::get(arg_tile->get_ty(), -INFINITY); break;
+  case ir::reduce_inst::FMIN: neutral = ConstantFP::get(arg_tile->get_ty(), INFINITY); break;
+  default: assert(false); break;
+  }
+
+
+
+  analysis::data_layout* arg_layout = layouts_->get(arg);
+  if(auto* L = dynamic_cast<analysis::scanline_layout*>(arg_layout)){
+    bool can_optimize = true;
+    for(size_t r = 0; r < L->get_rank(); r++){
+      if(r != axis)
+        can_optimize = can_optimize && (L->mts(r) == L->get_shape()[r]);
+    }
+    if(can_optimize){
+      Value *thread_acc = nullptr;
+      // reduce within thread
+      arg_tile->for_each([&](indices_t idx) {
+        Value *current = arg_tile->get_value(idx);
+        if(thread_acc == nullptr)
+          thread_acc = current;
+        else
+          thread_acc = accumulate(thread_acc, current);
+      });
+      // reduce within wrap
+      FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
+      InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
+      Value *warp_acc = thread_acc;
+      for(int i = 16; i > 0; i >>= 1)
+        warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
+      // shared memory pointer
+      unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
+      Type *res_ty = arg_tile->get_ty();
+      Value *sh_mem_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
+      Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
+      Value* warp_id = builder_->CreateUDiv(u_thread_id, builder_->getInt32(32));
+      Value *write_ptr = builder_->CreateGEP(sh_mem_ptr, warp_id);
+      // store warp result in shared memory
+      tgt_->add_barrier(mod_, *builder_);
+      builder_->CreateStore(warp_acc, write_ptr);
+      tgt_->add_barrier(mod_, *builder_);
+      // accumulate all warps
+      Value *load_ptr = builder_->CreateGEP(sh_mem_ptr, u_thread_id);
+      Value* is_first_warp = builder_->CreateICmpEQ(warp_id, builder_->getInt32(0));
+      BasicBlock* bb_final_acc = BasicBlock::Create(*ctx_, "bb_final_acc", builder_->GetInsertBlock()->getParent());
+      BasicBlock* bb_final_acc_done = BasicBlock::Create(*ctx_, "bb_final_acc_done", builder_->GetInsertBlock()->getParent());
+      builder_->CreateCondBr(is_first_warp, bb_final_acc, bb_final_acc_done);
+      builder_->SetInsertPoint(bb_final_acc);
+      Value* final_val = builder_->CreateLoad(load_ptr);
+      for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
+        final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
+      builder_->CreateStore(final_val, load_ptr);
+      builder_->CreateBr(bb_final_acc_done);
+      // // store first warp done
+      builder_->SetInsertPoint(bb_final_acc_done);
+      // write back
+      tgt_->add_barrier(mod_, *builder_);
+      final_val = builder_->CreateLoad(sh_mem_ptr);
+      for_each(x, [&](indices_t idx) {
+        set_value(x, idx, final_val);
+      });
+      return;
+    }
+  }
+
   // reduce within thread
-  unsigned axis = x->get_axis();
   arg_tile->for_each([&](indices_t idx) {
     indices_t pidx = idx;
     pidx[axis] = builder_->getInt32(0);
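The fast path above reduces within each thread, then across the warp with shfl.sync.bfly, parks one partial result per warp in shared memory, and lets the first warp fold those partials; the neutral value (e.g. -INFINITY for a max reduction) pads lanes with nothing to contribute. A CUDA sketch of the same block-level scheme (block shape, names, and the choice of max are illustrative assumptions, not the generated code):

    #include <math.h>  // INFINITY

    // One float per thread -> one max per block: thread -> warp (butterfly
    // shuffle) -> block (one shared-memory slot per warp), as in the fast path.
    __global__ void block_max(const float* in, float* out) {
      __shared__ float warp_partial[32];          // at most 32 warps per block
      float v = in[blockIdx.x * blockDim.x + threadIdx.x];

      // 1) reduce within the warp
      for (int offset = 16; offset > 0; offset >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));

      // 2) one partial result per warp goes to shared memory
      int lane = threadIdx.x % 32;
      int warp = threadIdx.x / 32;
      if (lane == 0)
        warp_partial[warp] = v;
      __syncthreads();

      // 3) the first warp folds the per-warp partials
      if (warp == 0) {
        int num_warps = blockDim.x / 32;  // assumes blockDim.x is a multiple of 32
        v = (lane < num_warps) ? warp_partial[lane] : -INFINITY;  // neutral element for max
        for (int offset = 16; offset > 0; offset >>= 1)
          v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
        if (lane == 0)
          out[blockIdx.x] = v;
      }
    }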
@@ -861,7 +968,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   unsigned depth = stile->get_shapes()[axis];

   unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
-  Type *res_ty = builder_->getFloatTy();
+  Type *res_ty = arg_tile->get_ty();
   Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
   for(auto& x: partial) {
     // current element being computed
@@ -891,10 +998,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
       // accumulate
       result = accumulate(result, next);
+      // write back
+      tgt_->add_barrier(mod_, *builder_);
       builder_->CreateStore(result, write_ptr);
     }
   }
   tgt_->add_barrier(mod_, *builder_);

   // write back
   for_each(x, [&](indices_t idx) {
     indices_t red_idx = idx;
@@ -1169,8 +1278,9 @@ void generator::visit_function(ir::function* fn) {
   }
   builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
   // initialize layouts
-  for(auto x: layouts_->get_all())
+  for(auto x: layouts_->get_all()){
     visit_layout(x.second);
+  }
   // generate LLVM-IR code
   for(ir::basic_block *block: fn->blocks())
     visit_basic_block(block);
@@ -158,7 +158,6 @@ tile *machine_distributed_layout::create(ir::value *v) {
     return false;
   };
   std::sort(order.begin(), order.end(), cmp);
-
   return new distributed_tile(ty, shapes, order, axes, *builder_);
 }

@@ -135,13 +135,13 @@ Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& sh
                                   const std::vector<int>& perm, const std::vector<int>& order,
                                   indices_t idx) {
   // strides
-  std::vector<Value*> strides(order.size());
+  std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
   strides[order[0]] = builder.getInt32(1);
   for(size_t i = 1; i < idx.size(); i++)
     strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
   // result
   Value *result = builder.getInt32(0);
-  for(size_t i = 0; i < strides.size(); i++)
+  for(size_t i = 0; i < idx.size(); i++)
     result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
   return result;
 }
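shared_offset linearizes a tile index: strides follow the storage order (order[0] is the contiguous dimension) and the offset is the dot product of the index with those strides. A plain-integer sketch of that computation (perm omitted, shapes made up for illustration):

    #include <cassert>
    #include <vector>

    // Same scheme as shared_offset, but with plain ints instead of LLVM IR values.
    int linear_offset(const std::vector<int>& shapes,
                      const std::vector<int>& order,
                      const std::vector<int>& idx) {
      std::vector<int> strides(shapes.size(), 0);
      strides[order[0]] = 1;
      for (size_t i = 1; i < idx.size(); i++)
        strides[order[i]] = strides[order[i - 1]] * shapes[order[i - 1]];
      int result = 0;
      for (size_t i = 0; i < idx.size(); i++)
        result += idx[i] * strides[i];
      return result;
    }

    int main() {
      // 4 x 8 tile, dimension 1 contiguous (row-major): offset(2, 3) = 2*8 + 3.
      assert(linear_offset({4, 8}, {1, 0}, {2, 3}) == 19);
    }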
@@ -26,8 +26,6 @@ inline bool is_shmem_res(ir::value* v){
     return false;
   if(i->get_id() == ir::INST_TRANS)
     return true;
-  if(i->get_id() == ir::INST_REDUCE)
-    return true;
   if(i->get_id() == ir::INST_COPY_TO_SHARED)
     return true;
   return false;
@@ -76,8 +74,9 @@ void cts::run(ir::module &mod) {
     size_t num_op = i->get_num_operands();
     // copy to shared operands
     for(size_t k = 0; k < num_op; k++)
-      if(is_shmem_op(i, k))
+      if(is_shmem_op(i, k)){
         add_copy(i, i->get_operand(k), builder, true);
+      }
     // copy from shared operands
     for(size_t k = 0; k < num_op; k++)
       if(!dynamic_cast<ir::phi_node*>(i) &&
@@ -83,6 +83,19 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
   }
 }

+bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
+  auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
+  if(cfs) {
+    ir::value *arg = cfs->get_operand(0);
+    ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
+    if(!cts)
+      return false;
+    cfs->replace_all_uses_with(cts->get_operand(0));
+    return true;
+  }
+
+}
+
 bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
   auto x = dynamic_cast<ir::reduce_inst*>(value);
   if(!x)
@@ -183,6 +196,7 @@ void peephole::run(ir::module &mod) {
       continue;
     bool was_modified = false;
     was_modified = was_modified || rewrite_mult(i, builder);
+    was_modified = was_modified || rewrite_cts_cfs(i, builder);
     was_modified = was_modified || rewrite_trans_phi(i, builder);
     was_modified = was_modified || rewrite_unit_red(i, builder);
     was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);