From 6a9810ccf273e1274a60e1f83230053da324ccb1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 1 May 2021 16:14:58 -0400 Subject: [PATCH] [codegen] small bugfix: (#97) * Added fp32 -> fp8 for ConstantFP = 0 * Added some more robust semantic check for atomic_add --- lib/codegen/selection/generator.cc | 30 +++++++++--------------------- lib/ir/dispatch.cc | 13 +++++++++++++ 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index f19892cd3..328e51d05 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -322,27 +322,13 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) { std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false), - "{ \n\t" - ".reg .b32 b<4>; \n\t" - "shl.b32 b0, $1, 4; \n\t" // shift into into upper byte - "shl.b32 b1, $2, 4; \n\t" - "shl.b32 b2, $3, 4; \n\t" - "shl.b32 b3, $4, 4; \n\t" - "lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign - "lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t" - "lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t" - "lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t" - "prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here) - "prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here) - "prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3 - "}", "=r, r, r, r, r", false); - Value *packed_ret = call(ptx, {in0, in1, in2, in3}); - Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4)); - return std::make_tuple(extract_elt(ret, (int)0), - extract_elt(ret, (int)1), - extract_elt(ret, (int)2), - extract_elt(ret, (int)3)); + auto cvt = [this](Value *v){ + if(ConstantFP* ci = dyn_cast<ConstantFP>(v)) + if(ci->getValue().convertToFloat() == 0) + return builder_->getInt8(0); + throw std::runtime_error("unsupported cast"); + }; + return 
std::make_tuple(cvt(in0), cvt(in1), cvt(in2), cvt(in3)); } std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ @@ -405,6 +391,8 @@ void generator::visit_cast_inst(ir::cast_inst* x) { auto op_idxs = idxs_.at(op); // run the conversion auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ + if(op_sca_ty->is_float_ty() && ret_sca_ty->is_fp8_ty()) + return fp32x4_to_fp8x4(a, b, c, d); if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty()) return fp8x4_to_fp16x4(a, b, c, d); throw std::runtime_error("unsupported conversion"); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 53c0dd2e8..5f6d66916 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -522,11 +522,24 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu } ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()){ + if(mask){ + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + } + if(val){ + val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); + } + } + val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); + if(!mask){ mask = builder->get_int1(true); if(ptr->get_type()->is_block_ty()) mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); } + return builder->create_atomic_add(ptr, val, mask); }