[codegen] small bugfix: (#97)

* Added fp32 -> fp8 conversion for ConstantFP = 0
* Added a more robust semantic check for atomic_add
Philippe Tillet
2021-05-01 16:14:58 -04:00
committed by Philippe Tillet
parent 7355efa745
commit 6a9810ccf2
2 changed files with 22 additions and 21 deletions


@@ -322,27 +322,13 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
  InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false),
      "{ \n\t"
      ".reg .b32 b<4>; \n\t"
      "shl.b32 b0, $1, 4; \n\t" // shift into upper byte
      "shl.b32 b1, $2, 4; \n\t"
      "shl.b32 b2, $3, 4; \n\t"
      "shl.b32 b3, $4, 4; \n\t"
      "lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign
      "lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t"
      "lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t"
      "lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t"
      "prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (selector digits 6, 2 unused here)
      "prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (selector digits 6, 2 unused here)
      "prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3
      "}", "=r, r, r, r, r", false);
  Value *packed_ret = call(ptx, {in0, in1, in2, in3});
  Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4));
  return std::make_tuple(extract_elt(ret, (int)0),
                         extract_elt(ret, (int)1),
                         extract_elt(ret, (int)2),
                         extract_elt(ret, (int)3));
  auto cvt = [this](Value *v){
    if(ConstantFP* ci = dyn_cast<ConstantFP>(v))
      if(ci->getValue().convertToFloat() == 0)
        return builder_->getInt8(0);
    throw std::runtime_error("unsupported cast");
  };
  return std::make_tuple(cvt(in0), cvt(in1), cvt(in2), cvt(in3));
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
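For reference, the removed PTX computes each fp8 lane as the top byte of the fp32 bit pattern shifted left by 4, with the original sign bit restored, and packs four such lanes into one 32-bit register. A scalar C++ sketch of the same bit trick (illustrative only; fp32_to_fp8_scalar is a hypothetical helper, not code from this commit):

#include <cstdint>
#include <cstring>

// One lane of the deleted PTX sequence, in scalar form (assumed semantics).
static uint8_t fp32_to_fp8_scalar(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof x);              // fp32 bit pattern
  uint32_t b = x << 4;                        // shl: move the target bits into the upper byte
  b = (b & 0x7fffffffu) | (x & 0x80000000u);  // lop3 0xb8: restore the original sign bit
  return uint8_t(b >> 24);                    // prmt: keep the upper byte of the lane
}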
@@ -405,6 +391,8 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
  auto op_idxs = idxs_.at(op);
  // run the conversion
  auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
    if(op_sca_ty->is_float_ty() && ret_sca_ty->is_fp8_ty())
      return fp32x4_to_fp8x4(a, b, c, d);
    if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
      return fp8x4_to_fp16x4(a, b, c, d);
    throw std::runtime_error("unsupported conversion");

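Taken together, the two hunks mean that the only fp32 -> fp8 cast the backend accepts for now is a constant zero; anything else throws "unsupported cast". A standalone sketch of that constant fold against the plain LLVM API (hedged illustration; fold_fp32_zero_to_fp8 is hypothetical, and the commit itself goes through the generator's builder_):

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"

// Fold a constant 0.0f into the fp8 bit pattern 0x00; return nullptr for
// anything else (the generator throws in that case).
llvm::Value *fold_fp32_zero_to_fp8(llvm::Value *v, llvm::IRBuilder<> &b) {
  if (auto *c = llvm::dyn_cast<llvm::ConstantFP>(v))
    if (c->getValueAPF().isZero())  // same test as convertToFloat() == 0 above
      return b.getInt8(0);
  return nullptr;
}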

@@ -522,11 +522,24 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){
}
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error("Pointer argument of atomic_add instruction is " + ptr->get_type()->repr());
  if(ptr->get_type()->is_block_ty()){
    if(mask){
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
    }
    if(val){
      val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
    }
  }
  val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
  if(!mask){
    mask = builder->get_int1(true);
    if(ptr->get_type()->is_block_ty())
      mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
  }
  return builder->create_atomic_add(ptr, val, mask);
}
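Net effect of the new checks: block operands are broadcast to the pointer's block shape, val is cast to the pointee type, and a missing mask defaults to all-true. The defaulting rule in isolation, reusing only the builder calls visible above (default_atomic_mask is a hypothetical helper, not part of the commit):

// All-true mask, splatted across the block shape when the pointer is a block.
ir::value *default_atomic_mask(ir::value *ptr, ir::builder *builder) {
  ir::value *mask = builder->get_int1(true);
  if(ptr->get_type()->is_block_ty())
    mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
  return mask;
}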