diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 2a65ba27c..58908ace1 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -134,6 +134,9 @@ public:
   std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
   std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
   std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
+  Value* bf16_to_fp32(Value *in0);
+  Value* fp32_to_bf16(Value *in0);
+
   void visit_cast_inst(ir::cast_inst*);
   void visit_return_inst(ir::return_inst*);
   void visit_cond_branch_inst(ir::cond_branch_inst*);
diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h
index ba175a434..7d18a3b4c 100644
--- a/include/triton/ir/context_impl.h
+++ b/include/triton/ir/context_impl.h
@@ -25,7 +25,7 @@ public:
   // non-numeric types
   type void_ty, label_ty;
   // floating point types
-  type fp8_ty, half_ty, float_ty, double_ty;
+  type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
   // integer types
   integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
   // Pointer types
diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h
index 804566362..33c74f245 100644
--- a/include/triton/ir/type.h
+++ b/include/triton/ir/type.h
@@ -28,23 +28,21 @@ protected:
 public:
   enum id_t {
     // primitive types
-    VoidTyID = 0,    ///< 0: type with no size
-    FP8TyID,         ///< 1: 8-bit floating point type (3 bits mantissa)
-    HalfTyID,        ///< 3: 16-bit floating point type
-    FloatTyID,       ///< 4: 32-bit floating point type
-    DoubleTyID,      ///< 5: 64-bit floating point type
-    X86_FP80TyID,    ///< 6: 80-bit floating point type (X87)
-    FP128TyID,       ///< 7: 128-bit floating point type (112-bit mantissa)
-    PPC_FP128TyID,   ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
-    LabelTyID,       ///< 9: Labels
-    MetadataTyID,    ///< 10: Metadata
-    TokenTyID,       ///< 11: Token
+    VoidTyID = 0,    ///< type with no size
+    FP8TyID,         ///< 8-bit floating point type (3 bits mantissa)
+    FP16TyID,        ///< 16-bit floating point type (10 bits mantissa)
+    BF16TyID,        ///< 16-bit floating point type (7 bits mantissa)
+    FP32TyID,        ///< 32-bit floating point type
+    FP64TyID,        ///< 64-bit floating point type
+    LabelTyID,       ///< Labels
+    MetadataTyID,    ///< Metadata
+    TokenTyID,       ///< Token
     // derived types
-    IntegerTyID,     ///< 12: Arbitrary bit width integers
-    FunctionTyID,    ///< 13: Functions
-    PointerTyID,     ///< 14: Pointers
-    StructTyID,      ///< 15: Struct
-    BlockTyID,       ///< 16: Block
+    IntegerTyID,     ///< Arbitrary bit width integers
+    FunctionTyID,    ///< Functions
+    PointerTyID,     ///< Pointers
+    StructTyID,      ///< Struct
+    BlockTyID,       ///< Block
   };

 public:
@@ -74,9 +72,10 @@ public:
   // primitive predicates
   bool is_void_ty() const { return id_ == VoidTyID; }
   bool is_fp8_ty() const { return id_ == FP8TyID; }
-  bool is_half_ty() const { return id_ == HalfTyID; }
-  bool is_float_ty() const { return id_ == FloatTyID; }
-  bool is_double_ty() const { return id_ == DoubleTyID; }
+  bool is_fp16_ty() const { return id_ == FP16TyID; }
+  bool is_bf16_ty() const { return id_ == BF16TyID; }
+  bool is_fp32_ty() const { return id_ == FP32TyID; }
+  bool is_fp64_ty() const { return id_ == FP64TyID; }
   bool is_label_ty() const { return id_ == LabelTyID;}
   bool is_metadata_ty() const { return id_ == MetadataTyID; }
   bool is_token_ty() const { return id_ == TokenTyID; }
@@ -99,9 +98,10 @@ public:
   static type *get_label_ty(context &ctx);
   // half
   static type *get_fp8_ty(context &ctx);
-  static type *get_half_ty(context &ctx);
-  static type *get_float_ty(context &ctx);
-  static type *get_double_ty(context &ctx);
+  static type *get_fp16_ty(context &ctx);
+  static type *get_bf16_ty(context &ctx);
+  static type *get_fp32_ty(context &ctx);
+  static type *get_fp64_ty(context &ctx);
   // integer types
   static integer_type *get_int1_ty(context &ctx);
   static integer_type *get_int8_ty(context &ctx);
@@ -128,12 +128,9 @@ public:
     switch(id_) {
     case VoidTyID: return "void";
     case FP8TyID: return "fp8";
-    case HalfTyID: return "f16";
-    case FloatTyID: return "f32";
-    case DoubleTyID: return "f64";
-    case X86_FP80TyID: return "f80";
-    case FP128TyID: return "f128";
-    case PPC_FP128TyID: return "ppcf128";
+    case FP16TyID: return "f16";
+    case FP32TyID: return "f32";
+    case FP64TyID: return "f64";
     case LabelTyID: return "label";
     case MetadataTyID: return "md";
     case TokenTyID: return "tok";
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 2b3c3d1ac..46f1649b0 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -30,8 +30,8 @@ inline bool is_hmma_c(ir::value *v){
     ir::type *a_ty = a->get_type();
     ir::value *b = x->get_operand(1);
     ir::type *b_ty = b->get_type();
-    result = a_ty->get_scalar_ty()->is_half_ty() &&
-             b_ty->get_scalar_ty()->is_half_ty();
+    result = a_ty->get_scalar_ty()->is_fp16_ty() &&
+             b_ty->get_scalar_ty()->is_fp16_ty();
   }
   return result;
 }
@@ -312,6 +312,7 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
     } else
       return false;
   }
+  return false;
 }

 void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index d3b78f28f..a88f4259e 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -96,7 +96,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   // ir::print(ir, std::cout);
   barriers.run(ir);
   // ir::print(ir, std::cout);
-  // ir::print(ir, std::cout);
+// ir::print(ir, std::cout);
   isel.visit(ir, *llvm);
   mod = driver::module::create(dev, std::move(llvm));
   ker = driver::kernel::create(&*mod, name.c_str());
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index e93fe7895..d827928a2 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -161,11 +161,10 @@ Type *generator::cvt(ir::type *ty) {
   switch(ty->get_type_id()){
     case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
     case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
-    case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
-    case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
-    case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
-    case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_);
-    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_);
+    case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
+    case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
+    case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
+    case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
     case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
     case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
     case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
@@ -428,57 +427,74 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0
 std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
-    Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
-    InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
-                       "{"
-                       ".reg .b32 a<2>, b<2>; \n\t"
-                       "prmt.b32 a0, 0, $2, 0x5140; \n\t"
-                       "prmt.b32 a1, 0, $2, 0x7362; \n\t"
-                       "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
-                       "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
-                       "shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
-                       "shr.b32 b1, b1, 1; \n\t"
-                       "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
-                       "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
-                       "}", "=r,=r,r", false);
-    Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
-    packed_in = insert_elt(packed_in, in0, (int)0);
-    packed_in = insert_elt(packed_in, in1, (int)1);
-    packed_in = insert_elt(packed_in, in2, (int)2);
-    packed_in = insert_elt(packed_in, in3, (int)3);
-    Value *in = bit_cast(packed_in, i32_ty);
-    Value *ret = call(ptx, {in});
-    Value *packed_ret0 = extract_val(ret, {0});
-    Value *packed_ret1 = extract_val(ret, {1});
-    Value *ret0 = extract_elt(packed_ret0, (int)0);
-    Value *ret1 = extract_elt(packed_ret0, (int)1);
-    Value *ret2 = extract_elt(packed_ret1, (int)0);
-    Value *ret3 = extract_elt(packed_ret1, (int)1);
-    return std::make_tuple(ret0, ret1, ret2, ret3);
+  Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
+  InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
+                     "{"
+                     ".reg .b32 a<2>, b<2>; \n\t"
+                     "prmt.b32 a0, 0, $2, 0x5140; \n\t"
+                     "prmt.b32 a1, 0, $2, 0x7362; \n\t"
+                     "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
+                     "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
+                     "shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
+                     "shr.b32 b1, b1, 1; \n\t"
+                     "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
+                     "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
+                     "}", "=r,=r,r", false);
+  Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
+  packed_in = insert_elt(packed_in, in0, (int)0);
+  packed_in = insert_elt(packed_in, in1, (int)1);
+  packed_in = insert_elt(packed_in, in2, (int)2);
+  packed_in = insert_elt(packed_in, in3, (int)3);
+  Value *in = bit_cast(packed_in, i32_ty);
+  Value *ret = call(ptx, {in});
+  Value *packed_ret0 = extract_val(ret, {0});
+  Value *packed_ret1 = extract_val(ret, {1});
+  Value *ret0 = extract_elt(packed_ret0, (int)0);
+  Value *ret1 = extract_elt(packed_ret0, (int)1);
+  Value *ret2 = extract_elt(packed_ret1, (int)0);
+  Value *ret3 = extract_elt(packed_ret1, (int)1);
+  return std::make_tuple(ret0, ret1, ret2, ret3);
 }

+Value* generator::bf16_to_fp32(Value *in0){
+  Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
+  ret = insert_elt(ret, in0, (uint64_t)1);
+  ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
+  return bit_cast(ret, builder_->getFloatTy());
+}
+
+Value* generator::fp32_to_bf16(Value *in0){
+  if(tgt_->as_nvidia()->sm() >= 80){
+    InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
+                                    "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
+    return call(ptx, {in0});
+  }
+  return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
+}
+
 /**
  * \brief Code Generation for `cast`
  */
 void generator::visit_cast_inst(ir::cast_inst* x) {
-  // <> FP8
   ir::value *op = x->get_operand(0);
   ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
   ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
+  auto x_idxs = idxs_.at(x);
+  auto op_idxs = idxs_.at(op);
+
+  // <> FP8
   if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
     // ensure that conversions can be vectorized
     int ld = layouts_->get(x)->get_order(0);
     int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
     if(contiguous % 4 != 0)
       throw std::runtime_error("unsupported fp32 -> fp8 conversion");
-    auto x_idxs = idxs_.at(x);
-    auto op_idxs = idxs_.at(op);
+
     // run the conversion
     auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
-      if(op_sca_ty->is_float_ty() && ret_sca_ty->is_fp8_ty())
+      if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
         return fp32x4_to_fp8x4(a, b, c, d);
-      if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
+      if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
         return fp8x4_to_fp16x4(a, b, c, d);
       throw std::runtime_error("unsupported conversion");
     };
@@ -494,6 +510,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
     return;
   }

+  // <> BF16
+  if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
+    // FP32 -> BF16
+    if(op_sca_ty->is_fp32_ty())
+      for(size_t i = 0; i < x_idxs.size(); i++)
+        vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
+    // BF16 -> FP32
+    if(ret_sca_ty->is_fp32_ty())
+      for(size_t i = 0; i < x_idxs.size(); i++)
+        vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
+    return;
+  }
+
   Type *ty = cvt(x->get_type()->get_scalar_ty());

   auto cvt = [](ir::cast_op_t op){
@@ -675,7 +704,6 @@ void generator::visit_load_inst(ir::load_inst* x){
         curr = extract_val(_ret, {ii});
       else
         curr = _ret;
-//      std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl;
       rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
     }
     int tmp = (width / (dtsize * 8));
@@ -694,6 +722,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
 /**
  * \brief Code Generation for a (synchronous) `store`
  */
+
 void generator::visit_store_inst(ir::store_inst * x){
   ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
   // operands
@@ -740,6 +769,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
   visit_store_inst(x);
 }

+
 /**
  * \brief Code Generation for `reshape`
  */
@@ -901,7 +931,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
     int ld = ords_.at(ptr)[0];
     unsigned alignment = alignment_->get(ptr, ld);
     vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
-    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
+    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
   }
   for(int i = 0; i < idxs_.at(val).size(); i += vec){
@@ -1105,10 +1135,10 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
   ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
-  
+
   // Cache lds value. If values are prefetched, create phi node
   // @param inc: incoming block (0 = header, 1 = loop)
-  auto register_lds = 
+  auto register_lds =
     [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
       if (K == 0 && is_prefetch) {
         ir::basic_block* inc_block = phiA->get_incoming_block(inc);
@@ -1208,7 +1238,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
       load_a(m, 0, 0, true);
     for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
       load_b(n, 0, 0, true);
-    
+
     // update accumulators
     builder_->SetInsertPoint(curr_bb);
     for (unsigned K = 0; K < NK; K += 4) {
@@ -1225,7 +1255,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
           call_mma(m, n, K);
       }
     }
-  } else { // not prefetched 
+  } else { // not prefetched
     for(unsigned K = 0; K < NK; K += 4)
       for(unsigned m = 0; m < num_m/2; m++)
       for(unsigned n = 0; n < num_n/2; n++) {
@@ -1356,7 +1386,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                             "{$0, $1, $2, $3}, "
                             "{$4, $5, $6, $7}, "
                             "{$8, $9}, "
-                            "{$10, $11, $12, $13};", 
+                            "{$10, $11, $12, $13};",
                             "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);

   unsigned num_rep_0 = shapes[0] / layout->spt(0);
@@ -1416,8 +1446,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
       int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
       int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
       InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
-                                                "{$0, $1, $2, $3}, [$4 + " + 
-                                                std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", 
+                                                "{$0, $1, $2, $3}, [$4 + " +
+                                                std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
                                                 "=r,=r,=r,=r,r", true);
       Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
       if(K == 0 && inc == 1 && is_prefetch)
@@ -1444,8 +1474,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
       int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
      int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
       InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
-                                                "{$0, $1, $2, $3}, [$4 + " + 
-                                                std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", 
+                                                "{$0, $1, $2, $3}, [$4 + " +
+                                                std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
                                                 "=r,=r,=r,=r,r", true);
       Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
       if(K == 0 && inc == 1 && is_prefetch)
@@ -2058,7 +2088,7 @@ void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
   } else {
     // If dot has been visitied, insert prefetched lds
     assert(inc == 1);
-    assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() && 
+    assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
            "dot hasn't be visited");
     // sink lds & extract element
     // move lds & all uses to current location
@@ -2081,7 +2111,7 @@ void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
       assert(m_instr->getParent() == &*builder_->GetInsertBlock());
       builder_->SetInsertPoint(m_instr->getParent());
     }
-  } 
+  }
 }

 void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
@@ -2384,7 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
     builder_->SetInsertPoint(parent);
   else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
     builder_->SetInsertPoint(&*parent->getFirstNonPHI());
-  } else 
+  } else
     builder_->SetInsertPoint(parent);

   // create smem_idx
@@ -2507,7 +2537,7 @@ void generator::finalize_shared_layout(analysis::shared_layout *shared) {
     Value *idx = smem_idx[shared];
     builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
     Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
-    PHINode *_ret = phi(i32_ty, 2); 
+    PHINode *_ret = phi(i32_ty, 2);
     Instruction *then_term = nullptr;
     Instruction *else_term = nullptr;
     Instruction *dummy = builder_->CreateRet(nullptr);
@@ -2544,7 +2574,7 @@ void generator::finalize_shared_layout(analysis::shared_layout *shared) {
     if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
       curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
       curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
-    } else 
+    } else
       throw std::runtime_error("Should be PHINode");

     BasicBlock *current = builder_->GetInsertBlock();
diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc
index 6c6d28ad7..e4e32b2ec 100644
--- a/lib/codegen/transform/prefetch.cc
+++ b/lib/codegen/transform/prefetch.cc
@@ -30,7 +30,7 @@ void prefetch::run(ir::module &mod) {
   ir::for_each_instruction(mod, [&](ir::instruction *i) {
     if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
       // Now only do prefetching when dot is fp16
-      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
+      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
         return;
       auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
       auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index abe839d7e..aba374c7a 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -283,8 +283,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*

 void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
   // JIT compile source-code
-//  std::cout << ptx << std::endl;
-
   try{

     std::string ptxas = tools::getenv("TRITON_PTXAS");
@@ -324,7 +322,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device)
   }
   catch(exception::cuda::invalid_ptx const &){
 //#ifdef TRITON_LOG_PTX_ERROR
-//   std::cout << ptx << std::endl;
+    std::cout << ptx << std::endl;
     std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
 //  exit(1);
 //#endif
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index 80914d145..2bd745884 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -55,10 +55,10 @@ value *builder::get_int64(int64_t val)
 { return constant_int::get(type::get_int64_ty(ctx_), val);}

 value *builder::get_float16(float val)
-{ return constant_fp::get(type::get_half_ty(ctx_), val); }
+{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }

 value *builder::get_float32(float val)
-{ return constant_fp::get(type::get_float_ty(ctx_), val); }
+{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }

 value *builder::get_range(int32_t _lo, int32_t _hi) {
   constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
@@ -85,13 +85,13 @@ type *builder::get_int64_ty()
 { return type::get_int64_ty(ctx_); }

 type *builder::get_half_ty()
-{ return type::get_half_ty(ctx_); }
+{ return type::get_fp16_ty(ctx_); }

 type *builder::get_float_ty()
-{ return type::get_float_ty(ctx_); }
+{ return type::get_fp32_ty(ctx_); }

 type *builder::get_double_ty()
-{ return type::get_double_ty(ctx_); }
+{ return type::get_fp64_ty(ctx_); }


 //===----------------------------------------------------------------------===//
diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc
index ece402979..3a06b5dea 100644
--- a/lib/ir/constant.cc
+++ b/lib/ir/constant.cc
@@ -15,12 +15,12 @@ constant *constant::get_null_value(type *ty) {
   switch (ty->get_scalar_ty()->get_type_id()) {
   case type::IntegerTyID:
     return constant_int::get(ty, 0);
-  case type::HalfTyID:
-    return constant_fp::get(type::get_half_ty(ctx), 0);
-  case type::FloatTyID:
-    return constant_fp::get(type::get_float_ty(ctx), 0);
-  case type::DoubleTyID:
-    return constant_fp::get(type::get_double_ty(ctx), 0);
+  case type::FP16TyID:
+    return constant_fp::get(type::get_fp16_ty(ctx), 0);
+  case type::FP32TyID:
+    return constant_fp::get(type::get_fp32_ty(ctx), 0);
+  case type::FP64TyID:
+    return constant_fp::get(type::get_fp64_ty(ctx), 0);
   default:
     throw std::runtime_error("Cannot create a null constant of that type!");
   }
diff --git a/lib/ir/context.cc b/lib/ir/context.cc
index 7aa79dde4..9bd66ec9a 100644
--- a/lib/ir/context.cc
+++ b/lib/ir/context.cc
@@ -14,9 +14,10 @@ context_impl::context_impl(context &ctx)
     label_ty(ctx, type::LabelTyID),
     // floating point
     fp8_ty(ctx, type::FP8TyID),
-    half_ty(ctx, type::HalfTyID),
-    float_ty(ctx, type::FloatTyID),
-    double_ty(ctx, type::DoubleTyID),
+    fp16_ty(ctx, type::FP16TyID),
+    bf16_ty(ctx, type::BF16TyID),
+    fp32_ty(ctx, type::FP32TyID),
+    fp64_ty(ctx, type::FP64TyID),
     // integers
     int1_ty(ctx, 1),
     int8_ty(ctx, 8),
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 83076e685..cd7223239 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -37,16 +37,16 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
   context &ctx = a_ty->get_context();
   // 1) if one operand is double, the other is implicitly
   //    converted to double
-  if(a_ty->is_double_ty() || b_ty->is_double_ty())
-    return type::get_double_ty(ctx);
+  if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
+    return type::get_fp64_ty(ctx);
   // 2) if one operand is float, the other is implicitly
   //    converted to float
-  if(a_ty->is_float_ty() || b_ty->is_float_ty())
-    return type::get_float_ty(ctx);
+  if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
+    return type::get_fp32_ty(ctx);
   // 3 ) if one operand is half, the other is implicitly
   //    converted to half
-  if(a_ty->is_half_ty() || b_ty->is_half_ty())
-    return type::get_half_ty(ctx);
+  if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
+    return type::get_fp16_ty(ctx);
   if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
     throw_unreachable("augment_types");
   // 4 ) both operands are integer and undergo
diff --git a/lib/ir/type.cc b/lib/ir/type.cc
index 9607e7db2..47ae5bde0 100644
--- a/lib/ir/type.cc
+++ b/lib/ir/type.cc
@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
 unsigned type::get_primitive_size_in_bits() const {
   switch (id_) {
   case FP8TyID: return 8;
-  case HalfTyID: return 16;
-  case FloatTyID: return 32;
-  case DoubleTyID: return 64;
-  case X86_FP80TyID: return 80;
-  case FP128TyID: return 128;
-  case PPC_FP128TyID: return 128;
+  case FP16TyID: return 16;
+  case BF16TyID: return 16;
+  case FP32TyID: return 32;
+  case FP64TyID: return 64;
   case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
   case BlockTyID: return ((block_type*)(this))->get_bitwidth();
   default: return 0;
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
   id_t id = get_scalar_ty()->id_;
   assert(is_floating_point_ty() && "Not a floating point type!");
   if (id == FP8TyID) return 3;
-  if (id == HalfTyID) return 10;
-  if (id == FloatTyID) return 23;
-  if (id == DoubleTyID) return 53;
+  if (id == FP16TyID) return 10;
+  if (id == BF16TyID) return 7;
+  if (id == FP32TyID) return 23;
+  if (id == FP64TyID) return 53;
   throw std::runtime_error("unreachable");
 }

@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const


 bool type::is_floating_point_ty() const
-{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
+{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }

 bool type::is_sized() const {
   // primitive types are sized
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
 type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
 // floating point
 type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
-type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
-type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
-type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
+type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
+type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
+type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
+type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
 // integer types
 integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
 integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
diff --git a/python/setup.py b/python/setup.py
index a11b81ebf..bf7675f32 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -49,7 +49,7 @@ class CMakeBuild(build_ext):
             self.build_extension(ext)

     def build_extension(self, ext):
-        #self.debug = True
+        # self.debug = True
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
diff --git a/python/src/triton.cc b/python/src/triton.cc
index c72f753a5..591e4bf79 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -204,9 +204,10 @@ void init_triton_ir(py::module &&m) {
       .def("make_block", &ir::block_type::get, ret::reference)
       .def("get_void", &ir::type::get_void_ty, ret::reference)
       .def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
-      .def("get_fp16", &ir::type::get_half_ty, ret::reference)
-      .def("get_fp32", &ir::type::get_float_ty, ret::reference)
-      .def("get_fp64", &ir::type::get_double_ty, ret::reference)
+      .def("get_fp16", &ir::type::get_fp16_ty, ret::reference)
+      .def("get_bf16", &ir::type::get_bf16_ty, ret::reference)
+      .def("get_fp32", &ir::type::get_fp32_ty, ret::reference)
+      .def("get_fp64", &ir::type::get_fp64_ty, ret::reference)
       .def("get_int1", &ir::type::get_int1_ty, ret::reference)
       .def("get_int8", &ir::type::get_int8_ty, ret::reference)
       .def("get_int16", &ir::type::get_int16_ty, ret::reference)
@@ -215,9 +216,10 @@ void init_triton_ir(py::module &&m) {

       .def("is_void", &ir::type::is_void_ty)
       .def("is_fp8", &ir::type::is_fp8_ty)
-      .def("is_fp16", &ir::type::is_half_ty)
-      .def("is_fp32", &ir::type::is_float_ty)
-      .def("is_fp64", &ir::type::is_double_ty)
+      .def("is_fp16", &ir::type::is_fp16_ty)
+      .def("is_bf16", &ir::type::is_bf16_ty)
+      .def("is_fp32", &ir::type::is_fp32_ty)
+      .def("is_fp64", &ir::type::is_fp64_ty)
       .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
       .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
       .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
diff --git a/python/test/test_language.py b/python/test/test_language.py
index dd5215b24..3757cdf22 100644
--- a/python/test/test_language.py
+++ b/python/test/test_language.py
@@ -16,6 +16,7 @@ cvt = {
     'int16': torch.int16,
     'int32': torch.int32,
     'int64': torch.int64,
+    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
     'float32': torch.float32,
     'float64': torch.float64,
@@ -292,9 +293,12 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
 # test cast
 # ---------------
 @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-    (dtype_x, dtype_z, False) for dtype_x in dtypes \
-    for dtype_z in dtypes
-] + [
+    (dtype_x, dtype_z, False) \
+    for dtype_x in dtypes\
+    for dtype_z in dtypes
+] + [
+    ('float32', 'bfloat16', False),
+    ('bfloat16', 'float32', False),
     ('float32', 'int32', True)
 ])
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 1c28d0b4e..5b0863814 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -465,6 +465,7 @@ class Kernel:
             float: 'f',
             bool: 'B',
             triton.language.float8: 'f8',
+            torch.bfloat16: 'bf16',
             torch.float16: 'f16',
             torch.float32: 'f32',
             torch.float64: 'f64',
@@ -484,6 +485,7 @@ class Kernel:
             'B': _triton.ir.type.get_int1,
             'f8': _triton.ir.type.get_fp8,
             'f16': _triton.ir.type.get_fp16,
+            'bf16': _triton.ir.type.get_bf16,
             'f32': _triton.ir.type.get_fp32,
             'f64': _triton.ir.type.get_fp64,
             'i1': _triton.ir.type.get_int1,
@@ -555,6 +557,7 @@ class Kernel:
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         device = wargs[tensor_idxs[0]].device
+        torch.cuda.set_device(device.index)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
diff --git a/python/triton/language.py b/python/triton/language.py
index 6029058bd..a8e36b9aa 100644
--- a/python/triton/language.py
+++ b/python/triton/language.py
@@ -86,6 +86,7 @@ int32 = dtype(ir.type.get_int32)
 int64 = dtype(ir.type.get_int64)
 float8 = dtype(ir.type.get_fp8)
 float16 = dtype(ir.type.get_fp16)
+bfloat16 = dtype(ir.type.get_bf16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)

@@ -103,6 +104,7 @@ class block:
         if ir_type.is_int64(): return int64
         if ir_type.is_fp8(): return float8
         if ir_type.is_fp16(): return float16
+        if ir_type.is_bf16(): return bfloat16
         if ir_type.is_fp32(): return float32
         if ir_type.is_fp64(): return float64
         # pointer type
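
Usage note (not part of the patch above): a minimal sketch of how the bfloat16 support introduced by this diff would be exercised from the Python API. It relies on the `tl.bfloat16` alias and the fp32 <-> bf16 casts added above, but the kernel name, the `BLOCK` meta-parameter, and the `**meta`-style launch are illustrative assumptions about this version of Triton, not part of the patch.

import torch
import triton
import triton.language as tl


@triton.jit
def round_to_bf16(X, Z, N, **meta):
    # Illustrative sketch: load fp32 values, round them through bf16,
    # widen back to fp32, and store the result.
    pid = tl.program_id(0)
    offs = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    z = x.to(tl.bfloat16).to(tl.float32)
    tl.store(Z + offs, z, mask=mask)


x = torch.randn(1024, device='cuda', dtype=torch.float32)
z = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
round_to_bf16[grid](x, z, x.numel(), BLOCK=256)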