[CODEGEN] More work on the CPU backend
This commit is contained in:
committed by
Philippe Tillet
parent
64eaec016f
commit
840308ab5d
@@ -168,9 +168,9 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
|
||||
analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){
|
||||
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = num_warps * 32;
|
||||
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
|
||||
nts_.resize(shape_.size());
|
||||
mts_.resize(shape_.size());
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
@@ -324,8 +324,8 @@ shared_layout::shared_layout(const data_layout *arg,
|
||||
* ---- Layouts Inference Pass ---- *
|
||||
* -------------------------------- */
|
||||
|
||||
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps) { }
|
||||
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }
|
||||
|
||||
|
||||
void layouts::connect(ir::value *x, ir::value *y) {
|
||||
@@ -382,7 +382,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
||||
}
|
||||
else
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
|
||||
}
|
||||
|
||||
void layouts::run(ir::module &mod) {
|
||||
|
@@ -488,41 +488,47 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
|
||||
ptr = gep->getPointerOperand();
|
||||
}
|
||||
ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1));
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
arg_ty.push_back(ty->getScalarType());
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
|
||||
// asm string
|
||||
std::string asm_str;
|
||||
asm_str += "@$0 st.global";
|
||||
if(vector_size > 1)
|
||||
asm_str += ".v" + std::to_string(vector_size);
|
||||
asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
|
||||
if(vector_size > 1)
|
||||
asm_str += "{";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
if(v > 0)
|
||||
asm_str += ", ";
|
||||
asm_str += "$" + std::to_string(2 + v);
|
||||
if(tgt_->is_gpu()){
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
arg_ty.push_back(ty->getScalarType());
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
|
||||
// asm string
|
||||
std::string asm_str;
|
||||
asm_str += "@$0 st.global";
|
||||
if(vector_size > 1)
|
||||
asm_str += ".v" + std::to_string(vector_size);
|
||||
asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
|
||||
if(vector_size > 1)
|
||||
asm_str += "{";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
if(v > 0)
|
||||
asm_str += ", ";
|
||||
asm_str += "$" + std::to_string(2 + v);
|
||||
}
|
||||
if(vector_size > 1)
|
||||
asm_str += "}";
|
||||
asm_str += ";";
|
||||
// asm constraint
|
||||
std::string constraint = "b,l";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
constraint += ",";
|
||||
constraint += (nbits == 32 ? "r" : "h");
|
||||
}
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
|
||||
builder_->CreateCall(iasm, args);
|
||||
}
|
||||
if(vector_size > 1)
|
||||
asm_str += "}";
|
||||
asm_str += ";";
|
||||
// asm constraint
|
||||
std::string constraint = "b,l";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
constraint += ",";
|
||||
constraint += (nbits == 32 ? "r" : "h");
|
||||
else{
|
||||
builder_->CreateMaskedStore(elt, ptr, alignment, builder_->CreateVectorSplat(vector_size, pred));
|
||||
}
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
|
||||
builder_->CreateCall(iasm, args);
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1302,17 +1308,22 @@ void generator::visit_function(ir::function* fn) {
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
for(ir::attribute attr: attr_pair.second)
|
||||
if(attr.is_llvm_attr())
|
||||
ret->addAttribute(id, llvm_attr(ctx, attr));
|
||||
if(attr.is_llvm_attr()){
|
||||
llvm::Attribute llattr = llvm_attr(ctx, attr);
|
||||
if(llattr.getKindAsEnum() != llvm::Attribute::None)
|
||||
ret->addAttribute(id, llvm_attr(ctx, attr));
|
||||
}
|
||||
}
|
||||
// set metadata
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
if(tgt_->is_gpu()){
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
}
|
||||
// set arguments
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
|
||||
|
Reference in New Issue
Block a user