[CORE] Fixed several issues that arose in the development of the torch-blocksparse package:
* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
This commit is contained in:
committed by
Philippe Tillet
parent
ac26fbdc1f
commit
3304629de9
@@ -119,7 +119,7 @@ struct scanline_layout: public data_layout {
|
||||
// Returns element k of mts_. Uses std::vector::at, so an out-of-range k
// throws std::out_of_range instead of reading past the end.
int mts(size_t k) { return mts_.at(k); }
|
||||
// Returns element k of nts_, with the same bounds-checked access as mts().
int nts(size_t k) { return nts_.at(k); }
|
||||
|
||||
private:
|
||||
public:
|
||||
std::vector<int> mts_;
|
||||
std::vector<int> nts_;
|
||||
};
|
||||
|
@@ -111,6 +111,8 @@ public:
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
|
@@ -19,6 +19,7 @@ namespace transform{
|
||||
|
||||
class peephole {
|
||||
private:
|
||||
bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
|
@@ -35,7 +35,8 @@ public:
|
||||
basic_block* get_insert_block() { return block_; }
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int32(unsigned val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
@@ -63,6 +64,7 @@ public:
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
|
||||
@@ -135,6 +137,7 @@ public:
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_exp(value* arg, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
||||
value *create_sqrt(value *A, const std::string &name = "");
|
||||
|
@@ -127,6 +127,8 @@ enum value_id_t: unsigned {
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_ADD,
|
||||
// math
|
||||
INST_EXP,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
|
@@ -612,6 +612,17 @@ public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// Built-in instruction taking a single value operand; repr_impl() reports it as "exp".
// NOTE(review): presumably the element-wise exponential mentioned in the commit
// message — confirm against the code generator.
// The constructor is private: instances are obtained through the static create() factory.
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
// val: the operand; name: optional instruction name; next: optional instruction
// passed through to the base class (NOTE(review): looks like an insertion point — confirm).
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
// Short mnemonic returned for this instruction kind.
std::string repr_impl() const { return "exp"; }
|
||||
// Project macros; by their names they supply the clone and visitor-accept hooks.
_TRITON_DEFINE_CLONE(exp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(exp_inst)
|
||||
|
||||
public:
|
||||
// Factory with the same parameters as the private constructor.
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class dot_inst: public builtin_inst {
|
||||
public:
|
||||
enum TransT { NoTrans, Trans };
|
||||
|
@@ -81,6 +81,7 @@ public:
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
||||
get_integer_bitwidth() == bitwidth;}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||
|
||||
|
@@ -48,6 +48,8 @@ class splat_inst;
|
||||
class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class exp_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_program_inst;
|
||||
class atomic_cas_inst;
|
||||
@@ -114,6 +116,8 @@ public:
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
|
@@ -433,6 +433,7 @@ public:
|
||||
void UnaryArithmOpTypeChecking();
|
||||
void BitcastOpTypeChecking();
|
||||
void CastOpTypeChecking();
|
||||
void IntrinsicOpTypeChecking();
|
||||
|
||||
protected:
|
||||
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
|
||||
|
@@ -33,8 +33,8 @@ using LocationList = std::vector<std::string>;
|
||||
using StaticInitList = std::vector<StaticInitializer>;
|
||||
|
||||
// Error
|
||||
// Signals an unreachable state. Always throws std::runtime_error("should not happen").
[[noreturn]] inline void should_not_happen() { throw std::runtime_error("should not happen"); }
// Signals an unsupported construct. Always throws std::runtime_error("not implemented").
[[noreturn]] inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
// Signals an unreachable state with context. Always throws std::runtime_error whose
// message is "internal compiler error: " followed by `suffix`.
[[noreturn]] inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); }
// Signals an unsupported construct with a caller-supplied message. Always throws
// std::runtime_error(msg); `msg` is used verbatim, without any prefix.
[[noreturn]] inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); }
|
||||
|
||||
class Generator: public Visitor {
|
||||
friend class Evaluator<Addr>;
|
||||
@@ -87,6 +87,9 @@ protected:
|
||||
// Triton-IR attributes
|
||||
ir::attribute GenIRAttr(ASTNode::Attr attr);
|
||||
|
||||
// Triton-IR metadata
|
||||
void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs);
|
||||
|
||||
// Triton-IR values
|
||||
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
|
||||
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
|
||||
@@ -131,22 +134,22 @@ public:
|
||||
void VisitObject(Object* obj);
|
||||
void VisitIdentifier(Identifier* ident);
|
||||
|
||||
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
|
||||
void VisitFuncCall(FuncCall*) { should_not_happen(); }
|
||||
void VisitTransOp(TransOp*) { should_not_happen(); }
|
||||
void VisitEnumerator(Enumerator*) { should_not_happen(); }
|
||||
void VisitConstant(Constant*) { should_not_happen(); }
|
||||
void VisitTempVar(TempVar*) { should_not_happen(); }
|
||||
void VisitDeclaration(Declaration*) { should_not_happen(); }
|
||||
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
|
||||
void VisitIfStmt(IfStmt*) { should_not_happen(); }
|
||||
void VisitForStmt(ForStmt*) { should_not_happen(); }
|
||||
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
|
||||
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
|
||||
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }
|
||||
void VisitCompoundStmt(CompoundStmt*) { should_not_happen(); }
|
||||
void VisitFuncDef(FuncDef*) { should_not_happen(); }
|
||||
void VisitTranslationUnit(TranslationUnit*) { should_not_happen(); }
|
||||
// Rejection stubs: each override below ignores its argument and calls
// should_not_happen() with a message naming the construct and stating that it
// cannot be an lvalue.
void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); }
|
||||
void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); }
|
||||
void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); }
|
||||
void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); }
|
||||
void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); }
|
||||
void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); }
|
||||
void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); }
|
||||
void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); }
|
||||
void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); }
|
||||
void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); }
|
||||
void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); }
|
||||
void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); }
|
||||
void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); }
|
||||
void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); }
|
||||
void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); }
|
||||
void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); }
|
||||
|
||||
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
|
||||
rhs_ = rhs;
|
||||
|
@@ -83,6 +83,7 @@ public:
|
||||
Constant* ParseSizeof();
|
||||
Constant* ParseAlignof();
|
||||
UnaryOp* ParsePrefixIncDec(const Token* tok);
|
||||
UnaryOp* ParseUnaryIntrinsicOp(const Token* tok, int op);
|
||||
UnaryOp* ParseUnaryOp(const Token* tok, int op);
|
||||
Expr* ParseDerefOp(const Token* tok);
|
||||
|
||||
|
@@ -164,7 +164,9 @@ public:
|
||||
ALIGNOF, // _Alignof
|
||||
GENERIC, // _Generic
|
||||
IMAGINARY, // _Imaginary
|
||||
// function keywords
|
||||
BITCAST,
|
||||
EXP,
|
||||
// KEYWORD END
|
||||
|
||||
IDENTIFIER,
|
||||
|
Reference in New Issue
Block a user