[LANG] Preliminary FP8 support (#96)
This commit is contained in:
committed by
Philippe Tillet
parent
4290be1ae8
commit
7355efa745
@@ -102,6 +102,10 @@ public:
|
|||||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||||
void visit_icmp_inst(ir::icmp_inst*);
|
void visit_icmp_inst(ir::icmp_inst*);
|
||||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
void visit_cast_inst(ir::cast_inst*);
|
void visit_cast_inst(ir::cast_inst*);
|
||||||
void visit_return_inst(ir::return_inst*);
|
void visit_return_inst(ir::return_inst*);
|
||||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||||
|
@@ -22,9 +22,11 @@ public:
|
|||||||
context_impl(context &ctx);
|
context_impl(context &ctx);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// primitive types
|
// non-numeric types
|
||||||
type void_ty, label_ty, half_ty, float_ty, double_ty;
|
type void_ty, label_ty;
|
||||||
// derived types
|
// floating point types
|
||||||
|
type fp8_ty, half_ty, float_ty, double_ty;
|
||||||
|
// integer types
|
||||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||||
// Pointer types
|
// Pointer types
|
||||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||||
|
@@ -29,21 +29,22 @@ public:
|
|||||||
enum id_t {
|
enum id_t {
|
||||||
// primitive types
|
// primitive types
|
||||||
VoidTyID = 0, ///< 0: type with no size
|
VoidTyID = 0, ///< 0: type with no size
|
||||||
HalfTyID, ///< 1: 16-bit floating point type
|
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
|
||||||
FloatTyID, ///< 2: 32-bit floating point type
|
HalfTyID, ///< 3: 16-bit floating point type
|
||||||
DoubleTyID, ///< 3: 64-bit floating point type
|
FloatTyID, ///< 4: 32-bit floating point type
|
||||||
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
|
DoubleTyID, ///< 5: 64-bit floating point type
|
||||||
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
|
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
|
||||||
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
|
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
|
||||||
LabelTyID, ///< 7: Labels
|
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
|
||||||
MetadataTyID, ///< 8: Metadata
|
LabelTyID, ///< 9: Labels
|
||||||
TokenTyID, ///< 9: Token
|
MetadataTyID, ///< 10: Metadata
|
||||||
|
TokenTyID, ///< 11: Token
|
||||||
// derived types
|
// derived types
|
||||||
IntegerTyID, ///< 10: Arbitrary bit width integers
|
IntegerTyID, ///< 12: Arbitrary bit width integers
|
||||||
FunctionTyID, ///< 11: Functions
|
FunctionTyID, ///< 13: Functions
|
||||||
PointerTyID, ///< 12: Pointers
|
PointerTyID, ///< 14: Pointers
|
||||||
StructTyID, ///< 13: Struct
|
StructTyID, ///< 15: Struct
|
||||||
BlockTyID, ///< 14: Tile
|
BlockTyID, ///< 16: Block
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -72,6 +73,7 @@ public:
|
|||||||
|
|
||||||
// primitive predicates
|
// primitive predicates
|
||||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||||
|
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
bool is_half_ty() const { return id_ == HalfTyID; }
|
||||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
bool is_float_ty() const { return id_ == FloatTyID; }
|
||||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
bool is_double_ty() const { return id_ == DoubleTyID; }
|
||||||
@@ -96,6 +98,7 @@ public:
|
|||||||
static type *get_void_ty(context &ctx);
|
static type *get_void_ty(context &ctx);
|
||||||
static type *get_label_ty(context &ctx);
|
static type *get_label_ty(context &ctx);
|
||||||
// half
|
// half
|
||||||
|
static type *get_fp8_ty(context &ctx);
|
||||||
static type *get_half_ty(context &ctx);
|
static type *get_half_ty(context &ctx);
|
||||||
static type *get_float_ty(context &ctx);
|
static type *get_float_ty(context &ctx);
|
||||||
static type *get_double_ty(context &ctx);
|
static type *get_double_ty(context &ctx);
|
||||||
@@ -124,6 +127,7 @@ public:
|
|||||||
std::string repr() const {
|
std::string repr() const {
|
||||||
switch(id_) {
|
switch(id_) {
|
||||||
case VoidTyID: return "void";
|
case VoidTyID: return "void";
|
||||||
|
case FP8TyID: return "fp8";
|
||||||
case HalfTyID: return "f16";
|
case HalfTyID: return "f16";
|
||||||
case FloatTyID: return "f32";
|
case FloatTyID: return "f32";
|
||||||
case DoubleTyID: return "f64";
|
case DoubleTyID: return "f64";
|
||||||
|
@@ -27,6 +27,7 @@ using namespace llvm;
|
|||||||
#define void_ty builder_->getVoidTy()
|
#define void_ty builder_->getVoidTy()
|
||||||
#define f16_ty builder_->getHalfTy()
|
#define f16_ty builder_->getHalfTy()
|
||||||
#define f32_ty builder_->getFloatTy()
|
#define f32_ty builder_->getFloatTy()
|
||||||
|
#define i8_ty builder_->getInt8Ty()
|
||||||
#define i32_ty builder_->getInt32Ty()
|
#define i32_ty builder_->getInt32Ty()
|
||||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||||
@@ -60,6 +61,7 @@ using namespace llvm;
|
|||||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
||||||
|
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||||
@@ -69,6 +71,7 @@ using namespace llvm;
|
|||||||
#define select(...) builder_->CreateSelect(__VA_ARGS__)
|
#define select(...) builder_->CreateSelect(__VA_ARGS__)
|
||||||
#define store(...) builder_->CreateStore(__VA_ARGS__)
|
#define store(...) builder_->CreateStore(__VA_ARGS__)
|
||||||
#define sub(...) builder_->CreateSub(__VA_ARGS__)
|
#define sub(...) builder_->CreateSub(__VA_ARGS__)
|
||||||
|
#define shl(...) builder_->CreateShl(__VA_ARGS__)
|
||||||
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
|
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
|
||||||
#define urem(...) builder_->CreateURem(__VA_ARGS__)
|
#define urem(...) builder_->CreateURem(__VA_ARGS__)
|
||||||
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
|
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
|
||||||
@@ -101,6 +104,7 @@ Type *generator::cvt(ir::type *ty) {
|
|||||||
// primitive types
|
// primitive types
|
||||||
switch(ty->get_type_id()){
|
switch(ty->get_type_id()){
|
||||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||||
|
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||||
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
||||||
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
||||||
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
||||||
@@ -316,10 +320,108 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||||
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false),
|
||||||
|
"{ \n\t"
|
||||||
|
".reg .b32 b<4>; \n\t"
|
||||||
|
"shl.b32 b0, $1, 4; \n\t" // shift into into upper byte
|
||||||
|
"shl.b32 b1, $2, 4; \n\t"
|
||||||
|
"shl.b32 b2, $3, 4; \n\t"
|
||||||
|
"shl.b32 b3, $4, 4; \n\t"
|
||||||
|
"lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign
|
||||||
|
"lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t"
|
||||||
|
"lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t"
|
||||||
|
"lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t"
|
||||||
|
"prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here)
|
||||||
|
"prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here)
|
||||||
|
"prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3
|
||||||
|
"}", "=r, r, r, r, r", false);
|
||||||
|
Value *packed_ret = call(ptx, {in0, in1, in2, in3});
|
||||||
|
Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4));
|
||||||
|
return std::make_tuple(extract_elt(ret, (int)0),
|
||||||
|
extract_elt(ret, (int)1),
|
||||||
|
extract_elt(ret, (int)2),
|
||||||
|
extract_elt(ret, (int)3));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||||
|
Value *ret0, *ret1, *ret2, *ret3;
|
||||||
|
std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
|
||||||
|
ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
|
||||||
|
ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
|
||||||
|
ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
|
||||||
|
ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
|
||||||
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||||
|
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
|
||||||
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||||
|
"{"
|
||||||
|
".reg .b32 a<2>, b<2>; \n\t"
|
||||||
|
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
||||||
|
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
||||||
|
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
||||||
|
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
||||||
|
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
||||||
|
"shr.b32 b1, b1, 1; \n\t"
|
||||||
|
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
||||||
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||||
|
"}", "=r,=r,r", false);
|
||||||
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||||
|
packed_in = insert_elt(packed_in, in0, (int)0);
|
||||||
|
packed_in = insert_elt(packed_in, in1, (int)1);
|
||||||
|
packed_in = insert_elt(packed_in, in2, (int)2);
|
||||||
|
packed_in = insert_elt(packed_in, in3, (int)3);
|
||||||
|
Value *in = bit_cast(packed_in, i32_ty);
|
||||||
|
Value *ret = call(ptx, {in});
|
||||||
|
Value *packed_ret0 = extract_val(ret, {0});
|
||||||
|
Value *packed_ret1 = extract_val(ret, {1});
|
||||||
|
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
||||||
|
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
||||||
|
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
||||||
|
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
||||||
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `cast`
|
* \brief Code Generation for `cast`
|
||||||
*/
|
*/
|
||||||
void generator::visit_cast_inst(ir::cast_inst* x) {
|
void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||||
|
// <> FP8
|
||||||
|
ir::value *op = x->get_operand(0);
|
||||||
|
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
|
||||||
|
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
|
||||||
|
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
|
||||||
|
// ensure that conversions can be vectorized
|
||||||
|
int ld = layouts_->get(x)->get_order(0);
|
||||||
|
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
|
||||||
|
if(contiguous % 4 != 0)
|
||||||
|
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
|
||||||
|
auto x_idxs = idxs_.at(x);
|
||||||
|
auto op_idxs = idxs_.at(op);
|
||||||
|
// run the conversion
|
||||||
|
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
||||||
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
|
||||||
|
return fp8x4_to_fp16x4(a, b, c, d);
|
||||||
|
throw std::runtime_error("unsupported conversion");
|
||||||
|
};
|
||||||
|
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
||||||
|
std::tie(vals_[x][x_idxs[i+0]],
|
||||||
|
vals_[x][x_idxs[i+1]],
|
||||||
|
vals_[x][x_idxs[i+2]],
|
||||||
|
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
|
||||||
|
vals_[op][op_idxs[i+1]],
|
||||||
|
vals_[op][op_idxs[i+2]],
|
||||||
|
vals_[op][op_idxs[i+3]]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||||
auto cvt = [](ir::cast_op_t op){
|
auto cvt = [](ir::cast_op_t op){
|
||||||
using ll = llvm::Instruction::CastOps;
|
using ll = llvm::Instruction::CastOps;
|
||||||
|
@@ -12,9 +12,12 @@ namespace ir{
|
|||||||
context_impl::context_impl(context &ctx)
|
context_impl::context_impl(context &ctx)
|
||||||
: void_ty(ctx, type::VoidTyID),
|
: void_ty(ctx, type::VoidTyID),
|
||||||
label_ty(ctx, type::LabelTyID),
|
label_ty(ctx, type::LabelTyID),
|
||||||
|
// floating point
|
||||||
|
fp8_ty(ctx, type::FP8TyID),
|
||||||
half_ty(ctx, type::HalfTyID),
|
half_ty(ctx, type::HalfTyID),
|
||||||
float_ty(ctx, type::FloatTyID),
|
float_ty(ctx, type::FloatTyID),
|
||||||
double_ty(ctx, type::DoubleTyID),
|
double_ty(ctx, type::DoubleTyID),
|
||||||
|
// integers
|
||||||
int1_ty(ctx, 1),
|
int1_ty(ctx, 1),
|
||||||
int8_ty(ctx, 8),
|
int8_ty(ctx, 8),
|
||||||
int16_ty(ctx, 16),
|
int16_ty(ctx, 16),
|
||||||
|
@@ -21,6 +21,7 @@ type *type::get_scalar_ty() const {
|
|||||||
|
|
||||||
unsigned type::get_primitive_size_in_bits() const {
|
unsigned type::get_primitive_size_in_bits() const {
|
||||||
switch (id_) {
|
switch (id_) {
|
||||||
|
case FP8TyID: return 8;
|
||||||
case HalfTyID: return 16;
|
case HalfTyID: return 16;
|
||||||
case FloatTyID: return 32;
|
case FloatTyID: return 32;
|
||||||
case DoubleTyID: return 64;
|
case DoubleTyID: return 64;
|
||||||
@@ -42,6 +43,7 @@ unsigned type::get_tile_bitwidth() const
|
|||||||
unsigned type::get_fp_mantissa_width() const {
|
unsigned type::get_fp_mantissa_width() const {
|
||||||
id_t id = get_scalar_ty()->id_;
|
id_t id = get_scalar_ty()->id_;
|
||||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||||
|
if (id == FP8TyID) return 3;
|
||||||
if (id == HalfTyID) return 10;
|
if (id == HalfTyID) return 10;
|
||||||
if (id == FloatTyID) return 23;
|
if (id == FloatTyID) return 23;
|
||||||
if (id == DoubleTyID) return 53;
|
if (id == DoubleTyID) return 53;
|
||||||
@@ -103,7 +105,7 @@ bool type::is_integer_ty(unsigned width) const
|
|||||||
|
|
||||||
|
|
||||||
bool type::is_floating_point_ty() const
|
bool type::is_floating_point_ty() const
|
||||||
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
|
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||||
|
|
||||||
bool type::is_sized() const {
|
bool type::is_sized() const {
|
||||||
// primitive types are sized
|
// primitive types are sized
|
||||||
@@ -120,7 +122,8 @@ bool type::is_sized() const {
|
|||||||
// primitive types
|
// primitive types
|
||||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||||
// half
|
// floating point
|
||||||
|
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||||
|
@@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("make_function", &ir::function_type::get, ret::reference)
|
.def("make_function", &ir::function_type::get, ret::reference)
|
||||||
.def("make_block", &ir::block_type::get, ret::reference)
|
.def("make_block", &ir::block_type::get, ret::reference)
|
||||||
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
||||||
|
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
|
||||||
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
|
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
|
||||||
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
|
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
|
||||||
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
|
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
|
||||||
@@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
||||||
|
|
||||||
.def("is_void", &ir::type::is_void_ty)
|
.def("is_void", &ir::type::is_void_ty)
|
||||||
|
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||||
.def("is_fp16", &ir::type::is_half_ty)
|
.def("is_fp16", &ir::type::is_half_ty)
|
||||||
.def("is_fp32", &ir::type::is_float_ty)
|
.def("is_fp32", &ir::type::is_float_ty)
|
||||||
.def("is_fp64", &ir::type::is_double_ty)
|
.def("is_fp64", &ir::type::is_double_ty)
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||||
import torch
|
import torch
|
||||||
# submodules
|
# submodules
|
||||||
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
|
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret
|
||||||
|
|
||||||
from . import language
|
from . import language
|
||||||
from . import code_gen
|
from . import code_gen
|
||||||
|
@@ -436,20 +436,23 @@ class CompilationError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class Kernel:
|
class Kernel:
|
||||||
|
@staticmethod
|
||||||
type_names = {
|
def _type_name(obj):
|
||||||
int: 'I',
|
type_names = {
|
||||||
float: 'f',
|
int: 'I',
|
||||||
bool: 'B',
|
float: 'f',
|
||||||
torch.float16: 'f16',
|
bool: 'B',
|
||||||
torch.float32: 'f32',
|
triton.language.float8: 'f8',
|
||||||
torch.float64: 'f64',
|
torch.float16: 'f16',
|
||||||
torch.bool: 'i1',
|
torch.float32: 'f32',
|
||||||
torch.int8: 'i8',
|
torch.float64: 'f64',
|
||||||
torch.int16: 'i16',
|
torch.bool: 'i1',
|
||||||
torch.int32: 'i32',
|
torch.int8: 'i8',
|
||||||
torch.int64: 'i64',
|
torch.int16: 'i16',
|
||||||
}
|
torch.int32: 'i32',
|
||||||
|
torch.int64: 'i64',
|
||||||
|
}
|
||||||
|
return type_names[obj]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_triton_ir(context, obj):
|
def _to_triton_ir(context, obj):
|
||||||
@@ -457,6 +460,7 @@ class Kernel:
|
|||||||
'I': _triton.ir.type.get_int32,
|
'I': _triton.ir.type.get_int32,
|
||||||
'f': _triton.ir.type.get_fp32,
|
'f': _triton.ir.type.get_fp32,
|
||||||
'B': _triton.ir.type.get_int1,
|
'B': _triton.ir.type.get_int1,
|
||||||
|
'f8': _triton.ir.type.get_fp8,
|
||||||
'f16': _triton.ir.type.get_fp16,
|
'f16': _triton.ir.type.get_fp16,
|
||||||
'f32': _triton.ir.type.get_fp32,
|
'f32': _triton.ir.type.get_fp32,
|
||||||
'f64': _triton.ir.type.get_fp64,
|
'f64': _triton.ir.type.get_fp64,
|
||||||
@@ -467,12 +471,12 @@ class Kernel:
|
|||||||
'i64': _triton.ir.type.get_int64,
|
'i64': _triton.ir.type.get_int64,
|
||||||
}
|
}
|
||||||
# convert torch.Tensor to Triton IR pointers
|
# convert torch.Tensor to Triton IR pointers
|
||||||
if isinstance(obj, torch.Tensor):
|
if hasattr(obj, 'data_ptr'):
|
||||||
name = Kernel.type_names[obj.dtype]
|
name = Kernel._type_name(obj.dtype)
|
||||||
elt_ty = type_map[name](context)
|
elt_ty = type_map[name](context)
|
||||||
return _triton.ir.type.make_ptr(elt_ty, 1)
|
return _triton.ir.type.make_ptr(elt_ty, 1)
|
||||||
# default path returns triton.ir.type directly
|
# default path returns triton.ir.type directly
|
||||||
name = Kernel.type_names[obj.__class__]
|
name = Kernel._type_name(obj.__class__)
|
||||||
return type_map[name](context)
|
return type_map[name](context)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -481,7 +485,7 @@ class Kernel:
|
|||||||
types_key = [None] * len(wargs)
|
types_key = [None] * len(wargs)
|
||||||
for i, arg in enumerate(wargs):
|
for i, arg in enumerate(wargs):
|
||||||
prefix = 'P' if i in tensor_idxs else ''
|
prefix = 'P' if i in tensor_idxs else ''
|
||||||
suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
|
suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
|
||||||
types_key[i] = prefix + suffix
|
types_key[i] = prefix + suffix
|
||||||
return tuple(types_key)
|
return tuple(types_key)
|
||||||
|
|
||||||
@@ -523,7 +527,7 @@ class Kernel:
|
|||||||
|
|
||||||
def __call__(self, *wargs, grid, num_warps=4, **meta):
|
def __call__(self, *wargs, grid, num_warps=4, **meta):
|
||||||
# device inference
|
# device inference
|
||||||
tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
|
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||||
if len(tensor_idxs) == 0:
|
if len(tensor_idxs) == 0:
|
||||||
raise ValueError("No Tensor argument found.")
|
raise ValueError("No Tensor argument found.")
|
||||||
device = wargs[tensor_idxs[0]].device
|
device = wargs[tensor_idxs[0]].device
|
||||||
@@ -545,7 +549,7 @@ class Kernel:
|
|||||||
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
|
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
|
||||||
)
|
)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
|
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||||
params = struct.pack(fmt, *args)
|
params = struct.pack(fmt, *args)
|
||||||
# enqueue cached function into stream
|
# enqueue cached function into stream
|
||||||
binary = cache[key]
|
binary = cache[key]
|
||||||
@@ -703,3 +707,19 @@ def jit(fn):
|
|||||||
|
|
||||||
def cdiv(x, y):
|
def cdiv(x, y):
|
||||||
return (x + y - 1) // y
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
|
||||||
|
######
|
||||||
|
|
||||||
|
|
||||||
|
class TensorWrapper:
|
||||||
|
def __init__(self, data_ptr, dtype):
|
||||||
|
self._data_ptr = data_ptr
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def data_ptr(self):
|
||||||
|
return self._data_ptr
|
||||||
|
|
||||||
|
|
||||||
|
def reinterpret(tensor, dtype):
|
||||||
|
return TensorWrapper(tensor.data_ptr(), dtype)
|
@@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8)
|
|||||||
int16 = dtype(ir.type.get_int16)
|
int16 = dtype(ir.type.get_int16)
|
||||||
int32 = dtype(ir.type.get_int32)
|
int32 = dtype(ir.type.get_int32)
|
||||||
int64 = dtype(ir.type.get_int64)
|
int64 = dtype(ir.type.get_int64)
|
||||||
|
float8 = dtype(ir.type.get_fp8)
|
||||||
float16 = dtype(ir.type.get_fp16)
|
float16 = dtype(ir.type.get_fp16)
|
||||||
float32 = dtype(ir.type.get_fp32)
|
float32 = dtype(ir.type.get_fp32)
|
||||||
float64 = dtype(ir.type.get_fp64)
|
float64 = dtype(ir.type.get_fp64)
|
||||||
@@ -98,6 +99,7 @@ class block:
|
|||||||
if ir_type.is_int16(): return int16
|
if ir_type.is_int16(): return int16
|
||||||
if ir_type.is_int32(): return int32
|
if ir_type.is_int32(): return int32
|
||||||
if ir_type.is_int64(): return int64
|
if ir_type.is_int64(): return int64
|
||||||
|
if ir_type.is_fp8(): return float8
|
||||||
if ir_type.is_fp16(): return float16
|
if ir_type.is_fp16(): return float16
|
||||||
if ir_type.is_fp32(): return float32
|
if ir_type.is_fp32(): return float32
|
||||||
if ir_type.is_fp64(): return float64
|
if ir_type.is_fp64(): return float64
|
||||||
|
Reference in New Issue
Block a user