[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
committed by
Philippe Tillet
parent
9b4e2cae2d
commit
8cea583109
@@ -134,6 +134,9 @@ public:
|
|||||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
Value* bf16_to_fp32(Value *in0);
|
||||||
|
Value* fp32_to_bf16(Value *in0);
|
||||||
|
|
||||||
void visit_cast_inst(ir::cast_inst*);
|
void visit_cast_inst(ir::cast_inst*);
|
||||||
void visit_return_inst(ir::return_inst*);
|
void visit_return_inst(ir::return_inst*);
|
||||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||||
|
@@ -25,7 +25,7 @@ public:
|
|||||||
// non-numeric types
|
// non-numeric types
|
||||||
type void_ty, label_ty;
|
type void_ty, label_ty;
|
||||||
// floating point types
|
// floating point types
|
||||||
type fp8_ty, half_ty, float_ty, double_ty;
|
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||||
// integer types
|
// integer types
|
||||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||||
// Pointer types
|
// Pointer types
|
||||||
|
@@ -28,23 +28,21 @@ protected:
|
|||||||
public:
|
public:
|
||||||
enum id_t {
|
enum id_t {
|
||||||
// primitive types
|
// primitive types
|
||||||
VoidTyID = 0, ///< 0: type with no size
|
VoidTyID = 0, ///< type with no size
|
||||||
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
|
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
|
||||||
HalfTyID, ///< 3: 16-bit floating point type
|
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
|
||||||
FloatTyID, ///< 4: 32-bit floating point type
|
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
|
||||||
DoubleTyID, ///< 5: 64-bit floating point type
|
FP32TyID, ///< 32-bit floating point type
|
||||||
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
|
FP64TyID, ///< 64-bit floating point type
|
||||||
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
|
LabelTyID, ///< Labels
|
||||||
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
|
MetadataTyID, ///< Metadata
|
||||||
LabelTyID, ///< 9: Labels
|
TokenTyID, ///< Token
|
||||||
MetadataTyID, ///< 10: Metadata
|
|
||||||
TokenTyID, ///< 11: Token
|
|
||||||
// derived types
|
// derived types
|
||||||
IntegerTyID, ///< 12: Arbitrary bit width integers
|
IntegerTyID, ///< Arbitrary bit width integers
|
||||||
FunctionTyID, ///< 13: Functions
|
FunctionTyID, ///< Functions
|
||||||
PointerTyID, ///< 14: Pointers
|
PointerTyID, ///< Pointers
|
||||||
StructTyID, ///< 15: Struct
|
StructTyID, ///< Struct
|
||||||
BlockTyID, ///< 16: Block
|
BlockTyID, ///< Block
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -74,9 +72,10 @@ public:
|
|||||||
// primitive predicates
|
// primitive predicates
|
||||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||||
|
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||||
@@ -99,9 +98,10 @@ public:
|
|||||||
static type *get_label_ty(context &ctx);
|
static type *get_label_ty(context &ctx);
|
||||||
// half
|
// half
|
||||||
static type *get_fp8_ty(context &ctx);
|
static type *get_fp8_ty(context &ctx);
|
||||||
static type *get_half_ty(context &ctx);
|
static type *get_fp16_ty(context &ctx);
|
||||||
static type *get_float_ty(context &ctx);
|
static type *get_bf16_ty(context &ctx);
|
||||||
static type *get_double_ty(context &ctx);
|
static type *get_fp32_ty(context &ctx);
|
||||||
|
static type *get_fp64_ty(context &ctx);
|
||||||
// integer types
|
// integer types
|
||||||
static integer_type *get_int1_ty(context &ctx);
|
static integer_type *get_int1_ty(context &ctx);
|
||||||
static integer_type *get_int8_ty(context &ctx);
|
static integer_type *get_int8_ty(context &ctx);
|
||||||
@@ -128,12 +128,9 @@ public:
|
|||||||
switch(id_) {
|
switch(id_) {
|
||||||
case VoidTyID: return "void";
|
case VoidTyID: return "void";
|
||||||
case FP8TyID: return "fp8";
|
case FP8TyID: return "fp8";
|
||||||
case HalfTyID: return "f16";
|
case FP16TyID: return "f16";
|
||||||
case FloatTyID: return "f32";
|
case FP32TyID: return "f32";
|
||||||
case DoubleTyID: return "f64";
|
case FP64TyID: return "f64";
|
||||||
case X86_FP80TyID: return "f80";
|
|
||||||
case FP128TyID: return "f128";
|
|
||||||
case PPC_FP128TyID: return "ppcf128";
|
|
||||||
case LabelTyID: return "label";
|
case LabelTyID: return "label";
|
||||||
case MetadataTyID: return "md";
|
case MetadataTyID: return "md";
|
||||||
case TokenTyID: return "tok";
|
case TokenTyID: return "tok";
|
||||||
|
@@ -30,8 +30,8 @@ inline bool is_hmma_c(ir::value *v){
|
|||||||
ir::type *a_ty = a->get_type();
|
ir::type *a_ty = a->get_type();
|
||||||
ir::value *b = x->get_operand(1);
|
ir::value *b = x->get_operand(1);
|
||||||
ir::type *b_ty = b->get_type();
|
ir::type *b_ty = b->get_type();
|
||||||
result = a_ty->get_scalar_ty()->is_half_ty() &&
|
result = a_ty->get_scalar_ty()->is_fp16_ty() &&
|
||||||
b_ty->get_scalar_ty()->is_half_ty();
|
b_ty->get_scalar_ty()->is_fp16_ty();
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@@ -312,6 +312,7 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
|
|||||||
} else
|
} else
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
|
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
|
||||||
|
@@ -96,7 +96,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
|||||||
// ir::print(ir, std::cout);
|
// ir::print(ir, std::cout);
|
||||||
barriers.run(ir);
|
barriers.run(ir);
|
||||||
// ir::print(ir, std::cout);
|
// ir::print(ir, std::cout);
|
||||||
// ir::print(ir, std::cout);
|
// ir::print(ir, std::cout);
|
||||||
isel.visit(ir, *llvm);
|
isel.visit(ir, *llvm);
|
||||||
mod = driver::module::create(dev, std::move(llvm));
|
mod = driver::module::create(dev, std::move(llvm));
|
||||||
ker = driver::kernel::create(&*mod, name.c_str());
|
ker = driver::kernel::create(&*mod, name.c_str());
|
||||||
|
@@ -161,11 +161,10 @@ Type *generator::cvt(ir::type *ty) {
|
|||||||
switch(ty->get_type_id()){
|
switch(ty->get_type_id()){
|
||||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||||
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
|
||||||
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
|
||||||
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
|
||||||
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_);
|
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
|
||||||
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_);
|
|
||||||
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
||||||
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
|
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
|
||||||
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
|
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
|
||||||
@@ -428,57 +427,74 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0
|
|||||||
|
|
||||||
|
|
||||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||||
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
|
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
|
||||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||||
"{"
|
"{"
|
||||||
".reg .b32 a<2>, b<2>; \n\t"
|
".reg .b32 a<2>, b<2>; \n\t"
|
||||||
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
||||||
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
||||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
||||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
||||||
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
||||||
"shr.b32 b1, b1, 1; \n\t"
|
"shr.b32 b1, b1, 1; \n\t"
|
||||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
||||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||||
"}", "=r,=r,r", false);
|
"}", "=r,=r,r", false);
|
||||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||||
packed_in = insert_elt(packed_in, in0, (int)0);
|
packed_in = insert_elt(packed_in, in0, (int)0);
|
||||||
packed_in = insert_elt(packed_in, in1, (int)1);
|
packed_in = insert_elt(packed_in, in1, (int)1);
|
||||||
packed_in = insert_elt(packed_in, in2, (int)2);
|
packed_in = insert_elt(packed_in, in2, (int)2);
|
||||||
packed_in = insert_elt(packed_in, in3, (int)3);
|
packed_in = insert_elt(packed_in, in3, (int)3);
|
||||||
Value *in = bit_cast(packed_in, i32_ty);
|
Value *in = bit_cast(packed_in, i32_ty);
|
||||||
Value *ret = call(ptx, {in});
|
Value *ret = call(ptx, {in});
|
||||||
Value *packed_ret0 = extract_val(ret, {0});
|
Value *packed_ret0 = extract_val(ret, {0});
|
||||||
Value *packed_ret1 = extract_val(ret, {1});
|
Value *packed_ret1 = extract_val(ret, {1});
|
||||||
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
||||||
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
||||||
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
||||||
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
||||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value* generator::bf16_to_fp32(Value *in0){
|
||||||
|
Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
|
||||||
|
ret = insert_elt(ret, in0, (uint64_t)1);
|
||||||
|
ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
|
||||||
|
return bit_cast(ret, builder_->getFloatTy());
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* generator::fp32_to_bf16(Value *in0){
|
||||||
|
if(tgt_->as_nvidia()->sm() >= 80){
|
||||||
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}),
|
||||||
|
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
|
||||||
|
return call(ptx, {in0});
|
||||||
|
}
|
||||||
|
return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `cast`
|
* \brief Code Generation for `cast`
|
||||||
*/
|
*/
|
||||||
void generator::visit_cast_inst(ir::cast_inst* x) {
|
void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||||
// <> FP8
|
|
||||||
ir::value *op = x->get_operand(0);
|
ir::value *op = x->get_operand(0);
|
||||||
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
|
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
|
||||||
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
|
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
|
||||||
|
auto x_idxs = idxs_.at(x);
|
||||||
|
auto op_idxs = idxs_.at(op);
|
||||||
|
|
||||||
|
// <> FP8
|
||||||
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
|
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
|
||||||
// ensure that conversions can be vectorized
|
// ensure that conversions can be vectorized
|
||||||
int ld = layouts_->get(x)->get_order(0);
|
int ld = layouts_->get(x)->get_order(0);
|
||||||
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
|
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
|
||||||
if(contiguous % 4 != 0)
|
if(contiguous % 4 != 0)
|
||||||
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
|
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
|
||||||
auto x_idxs = idxs_.at(x);
|
|
||||||
auto op_idxs = idxs_.at(op);
|
|
||||||
// run the conversion
|
// run the conversion
|
||||||
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
||||||
if(op_sca_ty->is_float_ty() && ret_sca_ty->is_fp8_ty())
|
if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
|
||||||
return fp32x4_to_fp8x4(a, b, c, d);
|
return fp32x4_to_fp8x4(a, b, c, d);
|
||||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
|
||||||
return fp8x4_to_fp16x4(a, b, c, d);
|
return fp8x4_to_fp16x4(a, b, c, d);
|
||||||
throw std::runtime_error("unsupported conversion");
|
throw std::runtime_error("unsupported conversion");
|
||||||
};
|
};
|
||||||
@@ -494,6 +510,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// <> BF16
|
||||||
|
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
|
||||||
|
// FP32 -> BF16
|
||||||
|
if(op_sca_ty->is_fp32_ty())
|
||||||
|
for(size_t i = 0; i < x_idxs.size(); i++)
|
||||||
|
vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
|
||||||
|
// BF16 -> FP32
|
||||||
|
if(ret_sca_ty->is_fp32_ty())
|
||||||
|
for(size_t i = 0; i < x_idxs.size(); i++)
|
||||||
|
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||||
auto cvt = [](ir::cast_op_t op){
|
auto cvt = [](ir::cast_op_t op){
|
||||||
@@ -675,7 +704,6 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
curr = extract_val(_ret, {ii});
|
curr = extract_val(_ret, {ii});
|
||||||
else
|
else
|
||||||
curr = _ret;
|
curr = _ret;
|
||||||
// std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl;
|
|
||||||
rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
|
rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
|
||||||
}
|
}
|
||||||
int tmp = (width / (dtsize * 8));
|
int tmp = (width / (dtsize * 8));
|
||||||
@@ -694,6 +722,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
|||||||
/**
|
/**
|
||||||
* \brief Code Generation for a (synchronous) `store`
|
* \brief Code Generation for a (synchronous) `store`
|
||||||
*/
|
*/
|
||||||
|
|
||||||
void generator::visit_store_inst(ir::store_inst * x){
|
void generator::visit_store_inst(ir::store_inst * x){
|
||||||
ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
|
ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
|
||||||
// operands
|
// operands
|
||||||
@@ -740,6 +769,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
|||||||
visit_store_inst(x);
|
visit_store_inst(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `reshape`
|
* \brief Code Generation for `reshape`
|
||||||
*/
|
*/
|
||||||
@@ -901,7 +931,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
|||||||
int ld = ords_.at(ptr)[0];
|
int ld = ords_.at(ptr)[0];
|
||||||
unsigned alignment = alignment_->get(ptr, ld);
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
||||||
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
|
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||||
|
@@ -30,7 +30,7 @@ void prefetch::run(ir::module &mod) {
|
|||||||
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
||||||
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
||||||
// Now only do prefetching when dot is fp16
|
// Now only do prefetching when dot is fp16
|
||||||
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
|
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
|
||||||
return;
|
return;
|
||||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||||
|
@@ -283,8 +283,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
|||||||
|
|
||||||
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
|
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
// std::cout << ptx << std::endl;
|
|
||||||
|
|
||||||
try{
|
try{
|
||||||
std::string ptxas = tools::getenv("TRITON_PTXAS");
|
std::string ptxas = tools::getenv("TRITON_PTXAS");
|
||||||
|
|
||||||
@@ -324,7 +322,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device)
|
|||||||
}
|
}
|
||||||
catch(exception::cuda::invalid_ptx const &){
|
catch(exception::cuda::invalid_ptx const &){
|
||||||
//#ifdef TRITON_LOG_PTX_ERROR
|
//#ifdef TRITON_LOG_PTX_ERROR
|
||||||
// std::cout << ptx << std::endl;
|
std::cout << ptx << std::endl;
|
||||||
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
||||||
// exit(1);
|
// exit(1);
|
||||||
//#endif
|
//#endif
|
||||||
|
@@ -55,10 +55,10 @@ value *builder::get_int64(int64_t val)
|
|||||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||||
|
|
||||||
value *builder::get_float16(float val)
|
value *builder::get_float16(float val)
|
||||||
{ return constant_fp::get(type::get_half_ty(ctx_), val); }
|
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||||
|
|
||||||
value *builder::get_float32(float val)
|
value *builder::get_float32(float val)
|
||||||
{ return constant_fp::get(type::get_float_ty(ctx_), val); }
|
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
|
||||||
|
|
||||||
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
||||||
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
||||||
@@ -85,13 +85,13 @@ type *builder::get_int64_ty()
|
|||||||
{ return type::get_int64_ty(ctx_); }
|
{ return type::get_int64_ty(ctx_); }
|
||||||
|
|
||||||
type *builder::get_half_ty()
|
type *builder::get_half_ty()
|
||||||
{ return type::get_half_ty(ctx_); }
|
{ return type::get_fp16_ty(ctx_); }
|
||||||
|
|
||||||
type *builder::get_float_ty()
|
type *builder::get_float_ty()
|
||||||
{ return type::get_float_ty(ctx_); }
|
{ return type::get_fp32_ty(ctx_); }
|
||||||
|
|
||||||
type *builder::get_double_ty()
|
type *builder::get_double_ty()
|
||||||
{ return type::get_double_ty(ctx_); }
|
{ return type::get_fp64_ty(ctx_); }
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -15,12 +15,12 @@ constant *constant::get_null_value(type *ty) {
|
|||||||
switch (ty->get_scalar_ty()->get_type_id()) {
|
switch (ty->get_scalar_ty()->get_type_id()) {
|
||||||
case type::IntegerTyID:
|
case type::IntegerTyID:
|
||||||
return constant_int::get(ty, 0);
|
return constant_int::get(ty, 0);
|
||||||
case type::HalfTyID:
|
case type::FP16TyID:
|
||||||
return constant_fp::get(type::get_half_ty(ctx), 0);
|
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||||
case type::FloatTyID:
|
case type::FP32TyID:
|
||||||
return constant_fp::get(type::get_float_ty(ctx), 0);
|
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||||
case type::DoubleTyID:
|
case type::FP64TyID:
|
||||||
return constant_fp::get(type::get_double_ty(ctx), 0);
|
return constant_fp::get(type::get_fp64_ty(ctx), 0);
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Cannot create a null constant of that type!");
|
throw std::runtime_error("Cannot create a null constant of that type!");
|
||||||
}
|
}
|
||||||
|
@@ -14,9 +14,10 @@ context_impl::context_impl(context &ctx)
|
|||||||
label_ty(ctx, type::LabelTyID),
|
label_ty(ctx, type::LabelTyID),
|
||||||
// floating point
|
// floating point
|
||||||
fp8_ty(ctx, type::FP8TyID),
|
fp8_ty(ctx, type::FP8TyID),
|
||||||
half_ty(ctx, type::HalfTyID),
|
fp16_ty(ctx, type::FP16TyID),
|
||||||
float_ty(ctx, type::FloatTyID),
|
bf16_ty(ctx, type::BF16TyID),
|
||||||
double_ty(ctx, type::DoubleTyID),
|
fp32_ty(ctx, type::FP32TyID),
|
||||||
|
fp64_ty(ctx, type::FP64TyID),
|
||||||
// integers
|
// integers
|
||||||
int1_ty(ctx, 1),
|
int1_ty(ctx, 1),
|
||||||
int8_ty(ctx, 8),
|
int8_ty(ctx, 8),
|
||||||
|
@@ -37,16 +37,16 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
|
|||||||
context &ctx = a_ty->get_context();
|
context &ctx = a_ty->get_context();
|
||||||
// 1) if one operand is double, the other is implicitly
|
// 1) if one operand is double, the other is implicitly
|
||||||
// converted to double
|
// converted to double
|
||||||
if(a_ty->is_double_ty() || b_ty->is_double_ty())
|
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
|
||||||
return type::get_double_ty(ctx);
|
return type::get_fp64_ty(ctx);
|
||||||
// 2) if one operand is float, the other is implicitly
|
// 2) if one operand is float, the other is implicitly
|
||||||
// converted to float
|
// converted to float
|
||||||
if(a_ty->is_float_ty() || b_ty->is_float_ty())
|
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
|
||||||
return type::get_float_ty(ctx);
|
return type::get_fp32_ty(ctx);
|
||||||
// 3 ) if one operand is half, the other is implicitly
|
// 3 ) if one operand is half, the other is implicitly
|
||||||
// converted to half
|
// converted to half
|
||||||
if(a_ty->is_half_ty() || b_ty->is_half_ty())
|
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
|
||||||
return type::get_half_ty(ctx);
|
return type::get_fp16_ty(ctx);
|
||||||
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||||
throw_unreachable("augment_types");
|
throw_unreachable("augment_types");
|
||||||
// 4 ) both operands are integer and undergo
|
// 4 ) both operands are integer and undergo
|
||||||
|
@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
|
|||||||
unsigned type::get_primitive_size_in_bits() const {
|
unsigned type::get_primitive_size_in_bits() const {
|
||||||
switch (id_) {
|
switch (id_) {
|
||||||
case FP8TyID: return 8;
|
case FP8TyID: return 8;
|
||||||
case HalfTyID: return 16;
|
case FP16TyID: return 16;
|
||||||
case FloatTyID: return 32;
|
case BF16TyID: return 16;
|
||||||
case DoubleTyID: return 64;
|
case FP32TyID: return 32;
|
||||||
case X86_FP80TyID: return 80;
|
case FP64TyID: return 64;
|
||||||
case FP128TyID: return 128;
|
|
||||||
case PPC_FP128TyID: return 128;
|
|
||||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||||
default: return 0;
|
default: return 0;
|
||||||
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
|
|||||||
id_t id = get_scalar_ty()->id_;
|
id_t id = get_scalar_ty()->id_;
|
||||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||||
if (id == FP8TyID) return 3;
|
if (id == FP8TyID) return 3;
|
||||||
if (id == HalfTyID) return 10;
|
if (id == FP16TyID) return 10;
|
||||||
if (id == FloatTyID) return 23;
|
if (id == BF16TyID) return 7;
|
||||||
if (id == DoubleTyID) return 53;
|
if (id == FP32TyID) return 23;
|
||||||
|
if (id == FP64TyID) return 53;
|
||||||
throw std::runtime_error("unreachable");
|
throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const
|
|||||||
|
|
||||||
|
|
||||||
bool type::is_floating_point_ty() const
|
bool type::is_floating_point_ty() const
|
||||||
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
|
||||||
|
|
||||||
bool type::is_sized() const {
|
bool type::is_sized() const {
|
||||||
// primitive types are sized
|
// primitive types are sized
|
||||||
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
|||||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||||
// floating point
|
// floating point
|
||||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
|
||||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
|
||||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
|
||||||
|
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
|
||||||
// integer types
|
// integer types
|
||||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||||
|
@@ -49,7 +49,7 @@ class CMakeBuild(build_ext):
|
|||||||
self.build_extension(ext)
|
self.build_extension(ext)
|
||||||
|
|
||||||
def build_extension(self, ext):
|
def build_extension(self, ext):
|
||||||
#self.debug = True
|
# self.debug = True
|
||||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||||
# create build directories
|
# create build directories
|
||||||
build_suffix = 'debug' if self.debug else 'release'
|
build_suffix = 'debug' if self.debug else 'release'
|
||||||
|
@@ -204,9 +204,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("make_block", &ir::block_type::get, ret::reference)
|
.def("make_block", &ir::block_type::get, ret::reference)
|
||||||
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
||||||
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
|
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
|
||||||
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
|
.def("get_fp16", &ir::type::get_fp16_ty, ret::reference)
|
||||||
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
|
.def("get_bf16", &ir::type::get_bf16_ty, ret::reference)
|
||||||
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
|
.def("get_fp32", &ir::type::get_fp32_ty, ret::reference)
|
||||||
|
.def("get_fp64", &ir::type::get_fp64_ty, ret::reference)
|
||||||
.def("get_int1", &ir::type::get_int1_ty, ret::reference)
|
.def("get_int1", &ir::type::get_int1_ty, ret::reference)
|
||||||
.def("get_int8", &ir::type::get_int8_ty, ret::reference)
|
.def("get_int8", &ir::type::get_int8_ty, ret::reference)
|
||||||
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
||||||
@@ -215,9 +216,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
|
|
||||||
.def("is_void", &ir::type::is_void_ty)
|
.def("is_void", &ir::type::is_void_ty)
|
||||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||||
.def("is_fp16", &ir::type::is_half_ty)
|
.def("is_fp16", &ir::type::is_fp16_ty)
|
||||||
.def("is_fp32", &ir::type::is_float_ty)
|
.def("is_bf16", &ir::type::is_bf16_ty)
|
||||||
.def("is_fp64", &ir::type::is_double_ty)
|
.def("is_fp32", &ir::type::is_fp32_ty)
|
||||||
|
.def("is_fp64", &ir::type::is_fp64_ty)
|
||||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
||||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
||||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
||||||
|
@@ -16,6 +16,7 @@ cvt = {
|
|||||||
'int16': torch.int16,
|
'int16': torch.int16,
|
||||||
'int32': torch.int32,
|
'int32': torch.int32,
|
||||||
'int64': torch.int64,
|
'int64': torch.int64,
|
||||||
|
'bfloat16': torch.bfloat16,
|
||||||
'float16': torch.float16,
|
'float16': torch.float16,
|
||||||
'float32': torch.float32,
|
'float32': torch.float32,
|
||||||
'float64': torch.float64,
|
'float64': torch.float64,
|
||||||
@@ -292,9 +293,12 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
|||||||
# test cast
|
# test cast
|
||||||
# ---------------
|
# ---------------
|
||||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||||
(dtype_x, dtype_z, False) for dtype_x in dtypes \
|
(dtype_x, dtype_z, False) \
|
||||||
for dtype_z in dtypes
|
for dtype_x in dtypes\
|
||||||
|
for dtype_z in dtypes
|
||||||
] + [
|
] + [
|
||||||
|
('float32', 'bfloat16', False),
|
||||||
|
('bfloat16', 'float32', False),
|
||||||
('float32', 'int32', True)
|
('float32', 'int32', True)
|
||||||
])
|
])
|
||||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||||
|
@@ -465,6 +465,7 @@ class Kernel:
|
|||||||
float: 'f',
|
float: 'f',
|
||||||
bool: 'B',
|
bool: 'B',
|
||||||
triton.language.float8: 'f8',
|
triton.language.float8: 'f8',
|
||||||
|
torch.bfloat16: 'bf16',
|
||||||
torch.float16: 'f16',
|
torch.float16: 'f16',
|
||||||
torch.float32: 'f32',
|
torch.float32: 'f32',
|
||||||
torch.float64: 'f64',
|
torch.float64: 'f64',
|
||||||
@@ -484,6 +485,7 @@ class Kernel:
|
|||||||
'B': _triton.ir.type.get_int1,
|
'B': _triton.ir.type.get_int1,
|
||||||
'f8': _triton.ir.type.get_fp8,
|
'f8': _triton.ir.type.get_fp8,
|
||||||
'f16': _triton.ir.type.get_fp16,
|
'f16': _triton.ir.type.get_fp16,
|
||||||
|
'bf16': _triton.ir.type.get_bf16,
|
||||||
'f32': _triton.ir.type.get_fp32,
|
'f32': _triton.ir.type.get_fp32,
|
||||||
'f64': _triton.ir.type.get_fp64,
|
'f64': _triton.ir.type.get_fp64,
|
||||||
'i1': _triton.ir.type.get_int1,
|
'i1': _triton.ir.type.get_int1,
|
||||||
@@ -555,6 +557,7 @@ class Kernel:
|
|||||||
if len(tensor_idxs) == 0:
|
if len(tensor_idxs) == 0:
|
||||||
raise ValueError("No Tensor argument found.")
|
raise ValueError("No Tensor argument found.")
|
||||||
device = wargs[tensor_idxs[0]].device
|
device = wargs[tensor_idxs[0]].device
|
||||||
|
torch.cuda.set_device(device.index)
|
||||||
# attributes
|
# attributes
|
||||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
|
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
|
||||||
|
@@ -86,6 +86,7 @@ int32 = dtype(ir.type.get_int32)
|
|||||||
int64 = dtype(ir.type.get_int64)
|
int64 = dtype(ir.type.get_int64)
|
||||||
float8 = dtype(ir.type.get_fp8)
|
float8 = dtype(ir.type.get_fp8)
|
||||||
float16 = dtype(ir.type.get_fp16)
|
float16 = dtype(ir.type.get_fp16)
|
||||||
|
bfloat16 = dtype(ir.type.get_bf16)
|
||||||
float32 = dtype(ir.type.get_fp32)
|
float32 = dtype(ir.type.get_fp32)
|
||||||
float64 = dtype(ir.type.get_fp64)
|
float64 = dtype(ir.type.get_fp64)
|
||||||
|
|
||||||
@@ -103,6 +104,7 @@ class block:
|
|||||||
if ir_type.is_int64(): return int64
|
if ir_type.is_int64(): return int64
|
||||||
if ir_type.is_fp8(): return float8
|
if ir_type.is_fp8(): return float8
|
||||||
if ir_type.is_fp16(): return float16
|
if ir_type.is_fp16(): return float16
|
||||||
|
if ir_type.is_bf16(): return bfloat16
|
||||||
if ir_type.is_fp32(): return float32
|
if ir_type.is_fp32(): return float32
|
||||||
if ir_type.is_fp64(): return float64
|
if ir_type.is_fp64(): return float64
|
||||||
# pointer type
|
# pointer type
|
||||||
|
Reference in New Issue
Block a user