[CODEGEN] Improvements and bugfixes (#463)

This commit is contained in:
Philippe Tillet
2022-02-24 14:56:24 -08:00
committed by GitHub
parent a9dfdcaaa9
commit 98ed7db8c1
14 changed files with 154 additions and 81 deletions

View File

@@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
@@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
// <> BF16
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty())
// for(size_t i = 0; i < x_idxs.size(); i++)
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
if(op_sca_ty->is_fp32_ty()){
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
return;
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty())
if(ret_sca_ty->is_fp32_ty()){
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return;
return;
}
}
@@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){
std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate
asm_oss << " ld";
// std::cout << x->get_is_volatile() << std::endl;
if(x->get_is_volatile())
asm_oss << ".volatile";
asm_oss << ".global";
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size