[ir] deleted mask/merge instructions; will be replaced by masked_load/store and select
This commit is contained in:
@@ -32,7 +32,7 @@ void distributed_tile::init_indices() {
|
||||
current.push_back(axes_[d].values[id[d]]);
|
||||
size_t sz = indices_.size();
|
||||
indices_[current] = sz;
|
||||
values_[current] = UndefValue::get(ty_);
|
||||
values_[current] = nullptr;
|
||||
ordered_indices_.push_back(current);
|
||||
id[0]++;
|
||||
while(id[k] == axes_[k].values.size()){
|
||||
@@ -57,12 +57,17 @@ distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_
|
||||
init_indices();
|
||||
}
|
||||
|
||||
void distributed_tile::set_value(indices_t idx, Value *v) {
|
||||
values_[idx] = v;
|
||||
void distributed_tile::set_value(indices_t idx, Value *x) {
|
||||
assert(x->getType() == ty_ && "cannot set a value of different type");
|
||||
Value *&result = values_[idx];
|
||||
assert(!result && "value cannot be set twice");
|
||||
result = x;
|
||||
}
|
||||
|
||||
Value* distributed_tile::get_value(indices_t idx) {
|
||||
return values_[idx];
|
||||
Value *result = values_.at(idx);
|
||||
assert(result && "value has not been set");
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned distributed_tile::get_linear_index(indices_t idx) {
|
||||
@@ -688,15 +693,15 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
|
||||
tmap_.insert({v, T});
|
||||
bool is_inserted = tmap_.insert({v, T}).second;
|
||||
// constant range
|
||||
if(dynamic_cast<ir::constant_range*>(v)){
|
||||
if(is_inserted && dynamic_cast<ir::constant_range*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
if(dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
@@ -746,31 +751,41 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
Function *fn = block->getParent();
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ir::value *mask = x->get_mask();
|
||||
if(mask) {
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptr->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(preds->get_value(idx), mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else {
|
||||
ptr->for_each([&](indices_t idx){
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr->get_value(idx)))
|
||||
if(BinaryOperator *binop = dyn_cast<BinaryOperator>(*gep->idx_begin())){
|
||||
std::cout << isa<Constant>(binop->getOperand(0)) << " " << isa<Constant>(binop->getOperand(1)) << std::endl;
|
||||
}
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
ir::value *mask = x->get_mask_operand();
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptrs->for_each([&](indices_t idx){
|
||||
Value *scalar = scalars->get_value(idx);
|
||||
Value *ptr = ptrs->get_value(idx);
|
||||
Value *pred = preds->get_value(idx);
|
||||
// std::string offset = "";
|
||||
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
// if(gep->getNumIndices() == 1)
|
||||
// if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4);
|
||||
// }
|
||||
// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false);
|
||||
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
|
||||
// builder.CreateCall(iasm, {pred, ptr, scalar});
|
||||
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(scalar, ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
ptrs->for_each([&](indices_t idx){
|
||||
builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx));
|
||||
});
|
||||
}
|
||||
else {
|
||||
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
|
||||
@@ -837,14 +852,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
return;
|
||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||
// global_range
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||
Value *offset = tgt_->get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis());
|
||||
result->for_each([&](indices_t idx){
|
||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
});
|
||||
}
|
||||
// nv_dynamic_range_idx_inst
|
||||
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
|
||||
result->for_each([&](indices_t idx){
|
||||
@@ -855,49 +862,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
// // mask
|
||||
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_then_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
// pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
// pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
// last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
// last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
// });
|
||||
// }
|
||||
// // merge
|
||||
// else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
// distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
// distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
// result->for_each([&](indices_t idx){
|
||||
// BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
// Value *value_true = value_tile_true->get_value(idx);
|
||||
// BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
// Value *value_false = value_tile_false->get_value(idx);
|
||||
// BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
// if(block_done->getTerminator())
|
||||
// builder.SetInsertPoint(block_done->getTerminator());
|
||||
// else
|
||||
// builder.SetInsertPoint(block_done);
|
||||
// PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
// phi->addIncoming(value_true, block_true);
|
||||
// phi->addIncoming(value_false,block_false);
|
||||
// result->set_value(idx, phi);
|
||||
// });
|
||||
// }
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
@@ -939,9 +903,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
in->for_each([&](indices_t idx){
|
||||
unsigned linear = in->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
Value *in_value = in->get_value(idx);
|
||||
if(linear % vector_size == 0)
|
||||
packets[id] = result->get_value(idx);
|
||||
packets[id] = builder.CreateInsertElement(packets.at(id), in->get_value(idx), linear % vector_size);
|
||||
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
|
||||
packets[id] = builder.CreateInsertElement(packets.at(id), in_value, linear % vector_size);
|
||||
});
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = in->get_linear_index(idx);
|
||||
@@ -1017,8 +982,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::vector<Value *> fc;
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
fc.push_back(result->get_value(idx));
|
||||
fc.push_back(TC->get_value(idx));
|
||||
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
|
||||
});
|
||||
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
@@ -1076,10 +1043,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
|
||||
Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1));
|
||||
Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0));
|
||||
Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1));
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
@@ -1136,24 +1103,106 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
});
|
||||
}
|
||||
}
|
||||
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
|
||||
else if(auto *ld = dynamic_cast<ir::masked_load_inst*>(ins)){
|
||||
// find vector size
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(ld->get_mask_operand());
|
||||
distributed_tile *false_values = (distributed_tile*)tmap_.at(ld->get_false_value_operand());
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0){
|
||||
Value *ptr = TP->get_value(idx);
|
||||
ptr= builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
if(linear % vector_size == 0) {
|
||||
Value *ptr = pointers->get_value(idx);
|
||||
ConstantInt *cst = nullptr;
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(gep->getNumIndices() == 1){
|
||||
cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
}
|
||||
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
Value *mask = masks->get_value(idx);
|
||||
BasicBlock *current_bb = builder.GetInsertBlock();
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(mask, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
Value *result_then = builder.CreateLoad(ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
Value *result = nullptr;
|
||||
if(false_values){
|
||||
result = builder.CreatePHI(result_then->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
|
||||
Value *result_false = false_values->get_value(idx);
|
||||
if(vector_size > 1)
|
||||
result_false = builder.CreateVectorSplat(vector_size, result_false);
|
||||
((PHINode*)result)->addIncoming(result_false, current_bb);
|
||||
}
|
||||
else
|
||||
result = result_then;
|
||||
|
||||
// std::string offset = "";
|
||||
// if(cst)
|
||||
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
|
||||
// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
||||
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
|
||||
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
|
||||
// if(false_value)
|
||||
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
|
||||
// Value *result = builder.CreateCall(iasm, {mask, ptr});
|
||||
|
||||
packets[id] = result;
|
||||
}
|
||||
});
|
||||
// extract result element
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// Value *tmp = builder.CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
|
||||
// Value *res = builder.CreateExtractElement(tmp, (linear % vector_size) % 2);
|
||||
// result->set_value(idx, res);
|
||||
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
}
|
||||
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
|
||||
// find vector size
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
// vector loads
|
||||
std::map<unsigned, Value*> packets;
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
Value *ptr = pointers->get_value(idx);
|
||||
ConstantInt *cst = nullptr;
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(gep->getNumIndices() == 1){
|
||||
cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
}
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
packets[id] = builder.CreateLoad(ptr);
|
||||
}
|
||||
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
// extract result element
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
}
|
||||
// element-wise
|
||||
|
Reference in New Issue
Block a user