[ir] deleted mask/merge instructions; will be replaced by masked_load/store and select

This commit is contained in:
Philippe Tillet
2019-07-25 15:06:15 -07:00
parent 6ce82dfcdb
commit 2a377bc8b1
27 changed files with 387 additions and 407 deletions

View File

@@ -32,7 +32,7 @@ void distributed_tile::init_indices() {
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = UndefValue::get(ty_);
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[0]++;
while(id[k] == axes_[k].values.size()){
@@ -57,12 +57,17 @@ distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *v) {
values_[idx] = v;
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
return values_[idx];
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
@@ -688,15 +693,15 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
}
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
tmap_.insert({v, T});
bool is_inserted = tmap_.insert({v, T}).second;
// constant range
if(dynamic_cast<ir::constant_range*>(v)){
if(is_inserted && dynamic_cast<ir::constant_range*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
});
}
if(dynamic_cast<ir::nv_static_range_idx*>(v)){
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
@@ -746,31 +751,41 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
LLVMContext &ctx = builder.getContext();
Function *fn = block->getParent();
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ir::value *mask = x->get_mask();
if(mask) {
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
ptr->for_each([&](indices_t idx){
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(preds->get_value(idx), mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
});
}
else {
ptr->for_each([&](indices_t idx){
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr->get_value(idx)))
if(BinaryOperator *binop = dyn_cast<BinaryOperator>(*gep->idx_begin())){
std::cout << isa<Constant>(binop->getOperand(0)) << " " << isa<Constant>(binop->getOperand(1)) << std::endl;
}
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
});
}
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
ir::value *mask = x->get_mask_operand();
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
ptrs->for_each([&](indices_t idx){
Value *scalar = scalars->get_value(idx);
Value *ptr = ptrs->get_value(idx);
Value *pred = preds->get_value(idx);
// std::string offset = "";
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
// if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4);
// }
// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false);
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
// builder.CreateCall(iasm, {pred, ptr, scalar});
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
});
}
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
ptrs->for_each([&](indices_t idx){
builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx));
});
}
else {
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
@@ -837,14 +852,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(!ins->get_type()->is_tile_ty())
return;
const auto& shapes = ins->get_type()->get_tile_shapes();
// global_range
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
Value *offset = tgt_->get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis());
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin, offset));
});
}
// nv_dynamic_range_idx_inst
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
result->for_each([&](indices_t idx){
@@ -855,49 +862,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, res);
});
}
// // mask
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
// pred->for_each([&](indices_t idx){
// BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
// builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
// builder.SetInsertPoint(mask_then_bb);
// builder.CreateBr(mask_done_bb);
// builder.SetInsertPoint(mask_else_bb);
// builder.CreateBr(mask_done_bb);
// builder.SetInsertPoint(mask_done_bb);
// pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
// pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
// last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
// last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
// });
// }
// // merge
// else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
// distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
// distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
// result->for_each([&](indices_t idx){
// BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
// Value *value_true = value_tile_true->get_value(idx);
// BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
// Value *value_false = value_tile_false->get_value(idx);
// BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
// if(block_done->getTerminator())
// builder.SetInsertPoint(block_done->getTerminator());
// else
// builder.SetInsertPoint(block_done);
// PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
// phi->addIncoming(value_true, block_true);
// phi->addIncoming(value_false,block_false);
// result->set_value(idx, phi);
// });
// }
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
@@ -939,9 +903,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
in->for_each([&](indices_t idx){
unsigned linear = in->get_linear_index(idx);
unsigned id = linear / vector_size;
Value *in_value = in->get_value(idx);
if(linear % vector_size == 0)
packets[id] = result->get_value(idx);
packets[id] = builder.CreateInsertElement(packets.at(id), in->get_value(idx), linear % vector_size);
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
packets[id] = builder.CreateInsertElement(packets.at(id), in_value, linear % vector_size);
});
result->for_each([&](indices_t idx){
unsigned linear = in->get_linear_index(idx);
@@ -1017,8 +982,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
TB->set_return_mode(true);
std::vector<Value *> fc;
result->for_each([&](indices_t idx){
fc.push_back(result->get_value(idx));
fc.push_back(TC->get_value(idx));
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
});
Type *fp32_ty = builder.getFloatTy();
@@ -1076,10 +1043,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1));
Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0));
Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1));
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
@@ -1136,24 +1103,106 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
});
}
}
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
else if(auto *ld = dynamic_cast<ir::masked_load_inst*>(ins)){
// find vector size
ir::value *ptr = ld->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(ld->get_mask_operand());
distributed_tile *false_values = (distributed_tile*)tmap_.at(ld->get_false_value_operand());
std::map<unsigned, Value*> packets;
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0){
Value *ptr = TP->get_value(idx);
ptr= builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1){
cst = dyn_cast<ConstantInt>(gep->idx_begin());
}
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
Value *mask = masks->get_value(idx);
BasicBlock *current_bb = builder.GetInsertBlock();
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(mask, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
Value *result_then = builder.CreateLoad(ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
Value *result = nullptr;
if(false_values){
result = builder.CreatePHI(result_then->getType(), 2);
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
Value *result_false = false_values->get_value(idx);
if(vector_size > 1)
result_false = builder.CreateVectorSplat(vector_size, result_false);
((PHINode*)result)->addIncoming(result_false, current_bb);
}
else
result = result_then;
// std::string offset = "";
// if(cst)
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
// if(false_value)
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
// Value *result = builder.CreateCall(iasm, {mask, ptr});
packets[id] = result;
}
});
// extract result element
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// Value *tmp = builder.CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
// Value *res = builder.CreateExtractElement(tmp, (linear % vector_size) % 2);
// result->set_value(idx, res);
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
// find vector size
ir::value *ptr = ld->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
// vector loads
std::map<unsigned, Value*> packets;
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1){
cst = dyn_cast<ConstantInt>(gep->idx_begin());
}
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
packets[id] = builder.CreateLoad(ptr);
}
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
// extract result element
result->for_each([&](indices_t idx){
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
});
}
// element-wise