[CORE] Fixed several issues that arose in the development of the
torch-blocksparse package:
* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
This commit is contained in:
committed by
Philippe Tillet
parent
ac26fbdc1f
commit
3304629de9
@@ -9,6 +9,7 @@ option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
|||||||
|
|
||||||
# LLVM
|
# LLVM
|
||||||
find_package(LLVM REQUIRED)
|
find_package(LLVM REQUIRED)
|
||||||
|
link_directories(${LLVM_LIBRARY_DIRS})
|
||||||
include_directories(${LLVM_INCLUDE_DIRS})
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
|
|
||||||
@@ -40,6 +41,5 @@ endif()
|
|||||||
# Triton
|
# Triton
|
||||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
link_directories(${LLVM_LIBRARY_DIRS})
|
|
||||||
target_link_libraries(triton ${LLVM_LIBRARIES})
|
target_link_libraries(triton ${LLVM_LIBRARIES})
|
||||||
|
|
||||||
|
@@ -119,7 +119,7 @@ struct scanline_layout: public data_layout {
|
|||||||
int mts(size_t k) { return mts_.at(k); }
|
int mts(size_t k) { return mts_.at(k); }
|
||||||
int nts(size_t k) { return nts_.at(k); }
|
int nts(size_t k) { return nts_.at(k); }
|
||||||
|
|
||||||
private:
|
public:
|
||||||
std::vector<int> mts_;
|
std::vector<int> mts_;
|
||||||
std::vector<int> nts_;
|
std::vector<int> nts_;
|
||||||
};
|
};
|
||||||
|
@@ -111,6 +111,8 @@ public:
|
|||||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||||
void visit_downcast_inst(ir::downcast_inst*);
|
void visit_downcast_inst(ir::downcast_inst*);
|
||||||
|
|
||||||
|
void visit_exp_inst(ir::exp_inst*);
|
||||||
|
|
||||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||||
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
||||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||||
|
@@ -19,6 +19,7 @@ namespace transform{
|
|||||||
|
|
||||||
class peephole {
|
class peephole {
|
||||||
private:
|
private:
|
||||||
|
bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
|
@@ -35,7 +35,8 @@ public:
|
|||||||
basic_block* get_insert_block() { return block_; }
|
basic_block* get_insert_block() { return block_; }
|
||||||
iterator get_insert_point() { return insert_point_;}
|
iterator get_insert_point() { return insert_point_;}
|
||||||
// Constants
|
// Constants
|
||||||
value *get_int32(unsigned val);
|
value *get_int32(int32_t val);
|
||||||
|
value *get_int64(int64_t val);
|
||||||
// Types
|
// Types
|
||||||
type *get_void_ty();
|
type *get_void_ty();
|
||||||
type *get_int1_ty();
|
type *get_int1_ty();
|
||||||
@@ -63,6 +64,7 @@ public:
|
|||||||
value* create_ret_void();
|
value* create_ret_void();
|
||||||
// Cast instructions
|
// Cast instructions
|
||||||
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
|
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||||
|
value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = "");
|
||||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||||
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||||
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
|
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
|
||||||
@@ -135,6 +137,7 @@ public:
|
|||||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||||
|
value *create_exp(value* arg, const std::string &name = "");
|
||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
||||||
value *create_sqrt(value *A, const std::string &name = "");
|
value *create_sqrt(value *A, const std::string &name = "");
|
||||||
|
@@ -127,6 +127,8 @@ enum value_id_t: unsigned {
|
|||||||
INST_ATOMIC_CAS,
|
INST_ATOMIC_CAS,
|
||||||
INST_ATOMIC_EXCH,
|
INST_ATOMIC_EXCH,
|
||||||
INST_ATOMIC_ADD,
|
INST_ATOMIC_ADD,
|
||||||
|
// math
|
||||||
|
INST_EXP,
|
||||||
// array arithmetic
|
// array arithmetic
|
||||||
INST_TRANS,
|
INST_TRANS,
|
||||||
INST_REDUCE,
|
INST_REDUCE,
|
||||||
|
@@ -612,6 +612,17 @@ public:
|
|||||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// IR instruction node for the exponential builtin; its textual
// representation is "exp" (see repr_impl below).
// The constructor is private, so the public static create() is the only
// visible way to obtain an instance.
// NOTE(review): create() is only declared here; presumably it forwards to the
// private constructor -- confirm against the definition.
class exp_inst: public builtin_inst {
|
||||||
|
private:
|
||||||
|
// val: the single operand; name: optional value name;
// next: presumably the instruction to insert before -- TODO confirm.
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
// Mnemonic used when this instruction is printed.
std::string repr_impl() const { return "exp"; }
|
||||||
|
// Project macros; by their names they supply the clone and visitor-accept
// hooks for this node type (macro bodies are not visible here).
_TRITON_DEFINE_CLONE(exp_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(exp_inst)
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Factory with the same parameters as the private constructor.
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
class dot_inst: public builtin_inst {
|
class dot_inst: public builtin_inst {
|
||||||
public:
|
public:
|
||||||
enum TransT { NoTrans, Trans };
|
enum TransT { NoTrans, Trans };
|
||||||
|
@@ -81,6 +81,7 @@ public:
|
|||||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||||
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
||||||
get_integer_bitwidth() == bitwidth;}
|
get_integer_bitwidth() == bitwidth;}
|
||||||
|
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||||
|
|
||||||
|
@@ -48,6 +48,8 @@ class splat_inst;
|
|||||||
class broadcast_inst;
|
class broadcast_inst;
|
||||||
class downcast_inst;
|
class downcast_inst;
|
||||||
|
|
||||||
|
class exp_inst;
|
||||||
|
|
||||||
class get_program_id_inst;
|
class get_program_id_inst;
|
||||||
class get_num_program_inst;
|
class get_num_program_inst;
|
||||||
class atomic_cas_inst;
|
class atomic_cas_inst;
|
||||||
@@ -114,6 +116,8 @@ public:
|
|||||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||||
|
|
||||||
|
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||||
|
@@ -433,6 +433,7 @@ public:
|
|||||||
void UnaryArithmOpTypeChecking();
|
void UnaryArithmOpTypeChecking();
|
||||||
void BitcastOpTypeChecking();
|
void BitcastOpTypeChecking();
|
||||||
void CastOpTypeChecking();
|
void CastOpTypeChecking();
|
||||||
|
void IntrinsicOpTypeChecking();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
|
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
|
||||||
|
@@ -33,8 +33,8 @@ using LocationList = std::vector<std::string>;
|
|||||||
using StaticInitList = std::vector<StaticInitializer>;
|
using StaticInitList = std::vector<StaticInitializer>;
|
||||||
|
|
||||||
// Error
|
// Error
|
||||||
inline void should_not_happen() { throw std::runtime_error("should not happen"); }
|
inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); }
|
||||||
inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
|
inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); }
|
||||||
|
|
||||||
class Generator: public Visitor {
|
class Generator: public Visitor {
|
||||||
friend class Evaluator<Addr>;
|
friend class Evaluator<Addr>;
|
||||||
@@ -87,6 +87,9 @@ protected:
|
|||||||
// Triton-IR attributes
|
// Triton-IR attributes
|
||||||
ir::attribute GenIRAttr(ASTNode::Attr attr);
|
ir::attribute GenIRAttr(ASTNode::Attr attr);
|
||||||
|
|
||||||
|
// Triton-IR metadata
|
||||||
|
void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs);
|
||||||
|
|
||||||
// Triton-IR values
|
// Triton-IR values
|
||||||
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
|
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
|
||||||
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
|
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
|
||||||
@@ -131,22 +134,22 @@ public:
|
|||||||
void VisitObject(Object* obj);
|
void VisitObject(Object* obj);
|
||||||
void VisitIdentifier(Identifier* ident);
|
void VisitIdentifier(Identifier* ident);
|
||||||
|
|
||||||
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
|
void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); }
|
||||||
void VisitFuncCall(FuncCall*) { should_not_happen(); }
|
void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); }
|
||||||
void VisitTransOp(TransOp*) { should_not_happen(); }
|
void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); }
|
||||||
void VisitEnumerator(Enumerator*) { should_not_happen(); }
|
void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); }
|
||||||
void VisitConstant(Constant*) { should_not_happen(); }
|
void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); }
|
||||||
void VisitTempVar(TempVar*) { should_not_happen(); }
|
void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); }
|
||||||
void VisitDeclaration(Declaration*) { should_not_happen(); }
|
void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); }
|
||||||
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
|
void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); }
|
||||||
void VisitIfStmt(IfStmt*) { should_not_happen(); }
|
void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); }
|
||||||
void VisitForStmt(ForStmt*) { should_not_happen(); }
|
void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); }
|
||||||
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
|
void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); }
|
||||||
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
|
void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); }
|
||||||
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }
|
void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); }
|
||||||
void VisitCompoundStmt(CompoundStmt*) { should_not_happen(); }
|
void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); }
|
||||||
void VisitFuncDef(FuncDef*) { should_not_happen(); }
|
void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); }
|
||||||
void VisitTranslationUnit(TranslationUnit*) { should_not_happen(); }
|
void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); }
|
||||||
|
|
||||||
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
|
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
|
||||||
rhs_ = rhs;
|
rhs_ = rhs;
|
||||||
|
@@ -83,6 +83,7 @@ public:
|
|||||||
Constant* ParseSizeof();
|
Constant* ParseSizeof();
|
||||||
Constant* ParseAlignof();
|
Constant* ParseAlignof();
|
||||||
UnaryOp* ParsePrefixIncDec(const Token* tok);
|
UnaryOp* ParsePrefixIncDec(const Token* tok);
|
||||||
|
UnaryOp* ParseUnaryIntrinsicOp(const Token* tok, int op);
|
||||||
UnaryOp* ParseUnaryOp(const Token* tok, int op);
|
UnaryOp* ParseUnaryOp(const Token* tok, int op);
|
||||||
Expr* ParseDerefOp(const Token* tok);
|
Expr* ParseDerefOp(const Token* tok);
|
||||||
|
|
||||||
|
@@ -164,7 +164,9 @@ public:
|
|||||||
ALIGNOF, // _Alignof
|
ALIGNOF, // _Alignof
|
||||||
GENERIC, // _Generic
|
GENERIC, // _Generic
|
||||||
IMAGINARY, // _Imaginary
|
IMAGINARY, // _Imaginary
|
||||||
|
// function keywords
|
||||||
BITCAST,
|
BITCAST,
|
||||||
|
EXP,
|
||||||
// KEYWORD END
|
// KEYWORD END
|
||||||
|
|
||||||
IDENTIFIER,
|
IDENTIFIER,
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
#include "triton/ir/utils.h"
|
#include "triton/ir/utils.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
|
@@ -16,7 +16,9 @@ namespace analysis{
|
|||||||
* Helper Functions *
|
* Helper Functions *
|
||||||
* -------------------------------- */
|
* -------------------------------- */
|
||||||
|
|
||||||
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
|
||||||
|
unsigned lo = std::min(a, b);
|
||||||
|
unsigned hi = std::max(a, b);
|
||||||
return std::min(std::max(x, lo), hi);
|
return std::min(std::max(x, lo), hi);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,7 +99,9 @@ data_layout::data_layout(id_t id,
|
|||||||
order_.resize(axes_.size());
|
order_.resize(axes_.size());
|
||||||
std::iota(order_.begin(), order_.end(), 0);
|
std::iota(order_.begin(), order_.end(), 0);
|
||||||
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
||||||
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
|
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||||
|
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||||
|
return xx < yy;
|
||||||
});
|
});
|
||||||
if(*largest){
|
if(*largest){
|
||||||
auto max_contiguous = align->contiguous(*largest);
|
auto max_contiguous = align->contiguous(*largest);
|
||||||
@@ -201,8 +205,9 @@ scanline_layout::scanline_layout(size_t num_warps,
|
|||||||
for(size_t d = 0; d < shape_.size(); d++)
|
for(size_t d = 0; d < shape_.size(); d++)
|
||||||
effective_num_threads *= mts_[d];
|
effective_num_threads *= mts_[d];
|
||||||
|
|
||||||
if(num_warps * 32 != effective_num_threads)
|
// std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
|
||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
// if(num_warps * 32 != effective_num_threads)
|
||||||
|
// throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -355,8 +360,9 @@ void layouts::make_graph(ir::instruction *i) {
|
|||||||
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||||
auto cmp = [](ir::value* x, ir::value *y) {
|
auto cmp = [](ir::value* x, ir::value *y) {
|
||||||
return x->get_type()->get_tile_ranks1() <
|
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||||
y->get_type()->get_tile_ranks1();
|
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||||
|
return xx < yy;
|
||||||
};
|
};
|
||||||
std::vector<ir::value*> lvalue = values;
|
std::vector<ir::value*> lvalue = values;
|
||||||
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
|
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
|
||||||
@@ -402,11 +408,8 @@ void layouts::run(ir::module &mod) {
|
|||||||
unsigned axis = red->get_axis();
|
unsigned axis = red->get_axis();
|
||||||
// shape
|
// shape
|
||||||
auto shapes = arg->get_type()->get_tile_shapes();
|
auto shapes = arg->get_type()->get_tile_shapes();
|
||||||
unsigned shape_ax = shapes[axis];
|
|
||||||
scanline_layout *layout = get(arg)->to_scanline();
|
scanline_layout *layout = get(arg)->to_scanline();
|
||||||
unsigned per_thread = layout->nts(axis);
|
shapes[axis] = layout->mts(axis);
|
||||||
unsigned depth = shape_ax / per_thread;
|
|
||||||
shapes[axis] = depth;
|
|
||||||
// create layout
|
// create layout
|
||||||
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
|
||||||
tmp_[red] = id;
|
tmp_[red] = id;
|
||||||
|
@@ -196,8 +196,9 @@ void generator::visit_value(ir::value* v) {
|
|||||||
BasicBlock *current = builder_->GetInsertBlock();
|
BasicBlock *current = builder_->GetInsertBlock();
|
||||||
auto *inst = dynamic_cast<ir::instruction*>(v);
|
auto *inst = dynamic_cast<ir::instruction*>(v);
|
||||||
if(inst && !dynamic_cast<ir::phi_node*>(v))
|
if(inst && !dynamic_cast<ir::phi_node*>(v))
|
||||||
for(ir::value *op: inst->ops())
|
for(ir::value *op: inst->ops()){
|
||||||
visit_value(op);
|
visit_value(op);
|
||||||
|
}
|
||||||
// change insert point for phi node
|
// change insert point for phi node
|
||||||
builder_->SetInsertPoint(current);
|
builder_->SetInsertPoint(current);
|
||||||
auto *phi = dynamic_cast<ir::phi_node*>(v);
|
auto *phi = dynamic_cast<ir::phi_node*>(v);
|
||||||
@@ -547,6 +548,24 @@ void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
|
|||||||
vmap_[np] = ret;
|
vmap_[np] = ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generates LLVM IR for an exp instruction, element by element.
// Each element of the operand tile is multiplied by log2(e) and passed to the
// inline-asm instruction `ex2.approx.ftz.f32`, i.e. exp(v) is computed as
// 2^(v * log2(e)). The result is stored at the same index of x.
// NOTE(review): the constant, the asm signature and the "=f,f" constraints are
// all hard-coded to 32-bit float; operands of another element type presumably
// will not match -- confirm before using with other types.
void generator::visit_exp_inst(ir::exp_inst* x){
|
||||||
|
// Tile holding the already-generated values of the operand.
distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0));
|
||||||
|
// Disabled alternative: obtain ex2 through the NVVM intrinsic declaration
// instead of inline asm.
// Function *fn = builder_->GetInsertBlock()->getParent();
|
||||||
|
// Module *module = fn->getParent();
|
||||||
|
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
|
||||||
|
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
|
||||||
|
// log2(e) as a float constant: scales the argument from base e to base 2.
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
|
||||||
|
|
||||||
|
// float -> float callable wrapping the approximate base-2 exponential.
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
|
||||||
|
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
|
||||||
|
|
||||||
|
|
||||||
|
// Emit one multiply and one ex2 call per index visited for x.
for_each(x, [&](indices_t idx){
|
||||||
|
Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
|
||||||
|
set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
||||||
BasicBlock *current = builder_->GetInsertBlock();
|
BasicBlock *current = builder_->GetInsertBlock();
|
||||||
Module *module = current->getModule();
|
Module *module = current->getModule();
|
||||||
@@ -587,6 +606,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
|||||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||||
tgt_->add_memfence(module, *builder_);
|
tgt_->add_memfence(module, *builder_);
|
||||||
|
tgt_->add_barrier(module, *builder_);
|
||||||
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
||||||
builder_->SetInsertPoint(tid_0_bb);
|
builder_->SetInsertPoint(tid_0_bb);
|
||||||
builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val,
|
builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val,
|
||||||
@@ -825,24 +845,111 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
ir::value *arg = x->get_operand(0);
|
ir::value *arg = x->get_operand(0);
|
||||||
distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
||||||
ir::reduce_inst::op_t op = x->get_op();
|
ir::reduce_inst::op_t op = x->get_op();
|
||||||
|
unsigned axis = x->get_axis();
|
||||||
|
|
||||||
|
Type *fp32_ty = builder_->getFloatTy();
|
||||||
|
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
|
||||||
|
InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
|
||||||
|
InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);
|
||||||
|
|
||||||
auto accumulate = [&](Value* x, Value *y) -> Value* {
|
auto accumulate = [&](Value* x, Value *y) -> Value* {
|
||||||
switch(op) {
|
switch(op) {
|
||||||
case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
|
case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
|
||||||
case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
|
case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
|
||||||
case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
|
case ir::reduce_inst::MAX:{
|
||||||
case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
|
if(x->getType()->isIntegerTy())
|
||||||
|
return builder_->CreateSelect(builder_->CreateICmpSGE(x, y), x, y);
|
||||||
|
else
|
||||||
|
return builder_->CreateMaxNum(x, y);
|
||||||
|
}
|
||||||
|
case ir::reduce_inst::MIN:{
|
||||||
|
if(x->getType()->isIntegerTy())
|
||||||
|
return builder_->CreateSelect(builder_->CreateICmpSLE(x, y), x, y);
|
||||||
|
else
|
||||||
|
return builder_->CreateMinNum(x, y);
|
||||||
|
}
|
||||||
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
|
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
|
||||||
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
|
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
|
||||||
case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
|
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
|
||||||
case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
|
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
|
||||||
default: break;
|
default: assert(false); return nullptr;
|
||||||
}
|
}
|
||||||
assert(false);
|
|
||||||
return nullptr;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Value *neutral;
|
||||||
|
switch(op) {
|
||||||
|
case ir::reduce_inst::ADD: neutral = builder_->getInt32(0); break;
|
||||||
|
case ir::reduce_inst::SUB: neutral = builder_->getInt32(0); break;
|
||||||
|
case ir::reduce_inst::MAX: neutral = builder_->getInt32(INT32_MIN); break;
|
||||||
|
case ir::reduce_inst::MIN: neutral = builder_->getInt32(INT32_MAX); break;
|
||||||
|
case ir::reduce_inst::FADD: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break;
|
||||||
|
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break;
|
||||||
|
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(arg_tile->get_ty(), -INFINITY); break;
|
||||||
|
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(arg_tile->get_ty(), INFINITY); break;
|
||||||
|
default: assert(false); break;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
analysis::data_layout* arg_layout = layouts_->get(arg);
|
||||||
|
if(auto* L = dynamic_cast<analysis::scanline_layout*>(arg_layout)){
|
||||||
|
bool can_optimize = true;
|
||||||
|
for(size_t r = 0; r < L->get_rank(); r++){
|
||||||
|
if(r != axis)
|
||||||
|
can_optimize = can_optimize && (L->mts(r) == L->get_shape()[r]);
|
||||||
|
}
|
||||||
|
if(can_optimize){
|
||||||
|
Value *thread_acc = nullptr;
|
||||||
|
// reduce within thread
|
||||||
|
arg_tile->for_each([&](indices_t idx) {
|
||||||
|
Value *current = arg_tile->get_value(idx);
|
||||||
|
if(thread_acc == nullptr)
|
||||||
|
thread_acc = current;
|
||||||
|
else
|
||||||
|
thread_acc = accumulate(thread_acc, current);
|
||||||
|
});
|
||||||
|
// reduce within warp
|
||||||
|
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
|
||||||
|
InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
|
||||||
|
Value *warp_acc = thread_acc;
|
||||||
|
for(int i = 16; i > 0; i >>= 1)
|
||||||
|
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
|
||||||
|
// shared memory pointer
|
||||||
|
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||||
|
Type *res_ty = arg_tile->get_ty();
|
||||||
|
Value *sh_mem_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||||
|
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
|
||||||
|
Value* warp_id = builder_->CreateUDiv(u_thread_id, builder_->getInt32(32));
|
||||||
|
Value *write_ptr = builder_->CreateGEP(sh_mem_ptr, warp_id);
|
||||||
|
// store warp result in shared memory
|
||||||
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
|
builder_->CreateStore(warp_acc, write_ptr);
|
||||||
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
|
// accumulate all warps
|
||||||
|
Value *load_ptr = builder_->CreateGEP(sh_mem_ptr, u_thread_id);
|
||||||
|
Value* is_first_warp = builder_->CreateICmpEQ(warp_id, builder_->getInt32(0));
|
||||||
|
BasicBlock* bb_final_acc = BasicBlock::Create(*ctx_, "bb_final_acc", builder_->GetInsertBlock()->getParent());
|
||||||
|
BasicBlock* bb_final_acc_done = BasicBlock::Create(*ctx_, "bb_final_acc_done", builder_->GetInsertBlock()->getParent());
|
||||||
|
builder_->CreateCondBr(is_first_warp, bb_final_acc, bb_final_acc_done);
|
||||||
|
builder_->SetInsertPoint(bb_final_acc);
|
||||||
|
Value* final_val = builder_->CreateLoad(load_ptr);
|
||||||
|
for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
|
||||||
|
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
|
||||||
|
builder_->CreateStore(final_val, load_ptr);
|
||||||
|
builder_->CreateBr(bb_final_acc_done);
|
||||||
|
// // store first warp done
|
||||||
|
builder_->SetInsertPoint(bb_final_acc_done);
|
||||||
|
// write back
|
||||||
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
|
final_val = builder_->CreateLoad(sh_mem_ptr);
|
||||||
|
for_each(x, [&](indices_t idx) {
|
||||||
|
set_value(x, idx, final_val);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reduce within thread
|
// reduce within thread
|
||||||
unsigned axis = x->get_axis();
|
|
||||||
arg_tile->for_each([&](indices_t idx) {
|
arg_tile->for_each([&](indices_t idx) {
|
||||||
indices_t pidx = idx;
|
indices_t pidx = idx;
|
||||||
pidx[axis] = builder_->getInt32(0);
|
pidx[axis] = builder_->getInt32(0);
|
||||||
@@ -861,7 +968,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
unsigned depth = stile->get_shapes()[axis];
|
unsigned depth = stile->get_shapes()[axis];
|
||||||
|
|
||||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||||
Type *res_ty = builder_->getFloatTy();
|
Type *res_ty = arg_tile->get_ty();
|
||||||
Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||||
for(auto& x: partial) {
|
for(auto& x: partial) {
|
||||||
// current element being computed
|
// current element being computed
|
||||||
@@ -891,10 +998,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
// accumulate
|
// accumulate
|
||||||
result = accumulate(result, next);
|
result = accumulate(result, next);
|
||||||
// write back
|
// write back
|
||||||
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
builder_->CreateStore(result, write_ptr);
|
builder_->CreateStore(result, write_ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tgt_->add_barrier(mod_, *builder_);
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
|
|
||||||
// write back
|
// write back
|
||||||
for_each(x, [&](indices_t idx) {
|
for_each(x, [&](indices_t idx) {
|
||||||
indices_t red_idx = idx;
|
indices_t red_idx = idx;
|
||||||
@@ -1169,8 +1278,9 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
}
|
}
|
||||||
builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
|
builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
|
||||||
// initialize layouts
|
// initialize layouts
|
||||||
for(auto x: layouts_->get_all())
|
for(auto x: layouts_->get_all()){
|
||||||
visit_layout(x.second);
|
visit_layout(x.second);
|
||||||
|
}
|
||||||
// generate LLVM-IR code
|
// generate LLVM-IR code
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
visit_basic_block(block);
|
visit_basic_block(block);
|
||||||
|
@@ -158,7 +158,6 @@ tile *machine_distributed_layout::create(ir::value *v) {
|
|||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
std::sort(order.begin(), order.end(), cmp);
|
std::sort(order.begin(), order.end(), cmp);
|
||||||
|
|
||||||
return new distributed_tile(ty, shapes, order, axes, *builder_);
|
return new distributed_tile(ty, shapes, order, axes, *builder_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -135,13 +135,13 @@ Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& sh
|
|||||||
const std::vector<int>& perm, const std::vector<int>& order,
|
const std::vector<int>& perm, const std::vector<int>& order,
|
||||||
indices_t idx) {
|
indices_t idx) {
|
||||||
// strides
|
// strides
|
||||||
std::vector<Value*> strides(order.size());
|
std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
|
||||||
strides[order[0]] = builder.getInt32(1);
|
strides[order[0]] = builder.getInt32(1);
|
||||||
for(size_t i = 1; i < idx.size(); i++)
|
for(size_t i = 1; i < idx.size(); i++)
|
||||||
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
|
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
|
||||||
// result
|
// result
|
||||||
Value *result = builder.getInt32(0);
|
Value *result = builder.getInt32(0);
|
||||||
for(size_t i = 0; i < strides.size(); i++)
|
for(size_t i = 0; i < idx.size(); i++)
|
||||||
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
|
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@@ -26,8 +26,6 @@ inline bool is_shmem_res(ir::value* v){
|
|||||||
return false;
|
return false;
|
||||||
if(i->get_id() == ir::INST_TRANS)
|
if(i->get_id() == ir::INST_TRANS)
|
||||||
return true;
|
return true;
|
||||||
if(i->get_id() == ir::INST_REDUCE)
|
|
||||||
return true;
|
|
||||||
if(i->get_id() == ir::INST_COPY_TO_SHARED)
|
if(i->get_id() == ir::INST_COPY_TO_SHARED)
|
||||||
return true;
|
return true;
|
||||||
return false;
|
return false;
|
||||||
@@ -76,8 +74,9 @@ void cts::run(ir::module &mod) {
|
|||||||
size_t num_op = i->get_num_operands();
|
size_t num_op = i->get_num_operands();
|
||||||
// copy to shared operands
|
// copy to shared operands
|
||||||
for(size_t k = 0; k < num_op; k++)
|
for(size_t k = 0; k < num_op; k++)
|
||||||
if(is_shmem_op(i, k))
|
if(is_shmem_op(i, k)){
|
||||||
add_copy(i, i->get_operand(k), builder, true);
|
add_copy(i, i->get_operand(k), builder, true);
|
||||||
|
}
|
||||||
// copy from shared operands
|
// copy from shared operands
|
||||||
for(size_t k = 0; k < num_op; k++)
|
for(size_t k = 0; k < num_op; k++)
|
||||||
if(!dynamic_cast<ir::phi_node*>(i) &&
|
if(!dynamic_cast<ir::phi_node*>(i) &&
|
||||||
|
@@ -83,6 +83,19 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
|
||||||
|
auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
|
||||||
|
if(cfs) {
|
||||||
|
ir::value *arg = cfs->get_operand(0);
|
||||||
|
ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
|
||||||
|
if(!cts)
|
||||||
|
return false;
|
||||||
|
cfs->replace_all_uses_with(cts->get_operand(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||||
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
||||||
if(!x)
|
if(!x)
|
||||||
@@ -183,6 +196,7 @@ void peephole::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
bool was_modified = false;
|
bool was_modified = false;
|
||||||
was_modified = was_modified || rewrite_mult(i, builder);
|
was_modified = was_modified || rewrite_mult(i, builder);
|
||||||
|
was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||||
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||||
|
@@ -91,7 +91,7 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
|||||||
const std::string& features,
|
const std::string& features,
|
||||||
file_type_t ft) {
|
file_type_t ft) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// debug
|
// // debug
|
||||||
// llvm::legacy::PassManager pm;
|
// llvm::legacy::PassManager pm;
|
||||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||||
// pm.add(llvm::createVerifierPass());
|
// pm.add(llvm::createVerifierPass());
|
||||||
|
@@ -44,9 +44,11 @@ void builder::set_insert_point(basic_block *block){
|
|||||||
// convenience functions
|
// convenience functions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
value *builder::get_int32(unsigned val) {
|
value *builder::get_int32(int32_t val)
|
||||||
return constant_int::get(type::get_int32_ty(ctx_), val);
|
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
|
||||||
}
|
|
||||||
|
value *builder::get_int64(int64_t val)
|
||||||
|
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||||
|
|
||||||
type *builder::get_void_ty()
|
type *builder::get_void_ty()
|
||||||
{ return type::get_void_ty(ctx_); }
|
{ return type::get_void_ty(ctx_); }
|
||||||
@@ -103,6 +105,7 @@ value *builder::create_ret_void() {
|
|||||||
return create_cast(OPCODE, src, dst_ty, name);\
|
return create_cast(OPCODE, src, dst_ty, name);\
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
|
||||||
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
||||||
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
||||||
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
|
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
|
||||||
@@ -308,6 +311,10 @@ value *builder::create_atomic_add(value *ptr, value *val, const std::string &nam
|
|||||||
return insert(atomic_add_inst::create(ptr, val, name));
|
return insert(atomic_add_inst::create(ptr, val, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value *builder::create_exp(value *arg, const std::string &name){
|
||||||
|
return insert(exp_inst::create(arg, name));
|
||||||
|
}
|
||||||
|
|
||||||
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
||||||
return insert(dot_inst::create_nn(A, B, C, name));
|
return insert(dot_inst::create_nn(A, B, C, name));
|
||||||
}
|
}
|
||||||
|
@@ -64,7 +64,7 @@ constant *constant_fp::get_negative_zero(type *ty){
|
|||||||
|
|
||||||
constant *constant_fp::get_zero_value_for_negation(type *ty) {
|
constant *constant_fp::get_zero_value_for_negation(type *ty) {
|
||||||
if(ty->get_scalar_ty()->is_floating_point_ty())
|
if(ty->get_scalar_ty()->is_floating_point_ty())
|
||||||
return get_negative_zero(ty);
|
return constant_fp::get(ty, 0);
|
||||||
return constant::get_null_value(ty);
|
return constant::get_null_value(ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -746,6 +746,18 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
|
|||||||
return new atomic_add_inst(ptr, val, name, next);
|
return new atomic_add_inst(ptr, val, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// exp
|
||||||
|
|
||||||
|
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||||
|
: builtin_inst(val->get_type(), INST_EXP, 1, name, next) {
|
||||||
|
set_operand(0, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction* exp_inst::create(value *val, const std::string& name, instruction *next) {
|
||||||
|
return new exp_inst(val, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// intrinsic instructions
|
// intrinsic instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -655,6 +655,9 @@ void UnaryOp::TypeChecking() {
|
|||||||
case Token::REDUCE:
|
case Token::REDUCE:
|
||||||
return ReduceOpTypeChecking();
|
return ReduceOpTypeChecking();
|
||||||
|
|
||||||
|
case Token::EXP:
|
||||||
|
return IntrinsicOpTypeChecking();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
@@ -769,6 +772,10 @@ void UnaryOp::CastOpTypeChecking() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UnaryOp::IntrinsicOpTypeChecking() {
|
||||||
|
type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_FLOAT));
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Transposition Operator
|
* Transposition Operator
|
||||||
*/
|
*/
|
||||||
|
@@ -42,8 +42,8 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
case '^': return set_ret(bld_->create_xor(lhs, rhs));
|
case '^': return set_ret(bld_->create_xor(lhs, rhs));
|
||||||
case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs));
|
case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs));
|
||||||
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
||||||
case '.': return error_not_implemented();
|
case '.': return error_not_implemented(". binary operator not implemented");
|
||||||
case ',': return error_not_implemented();
|
case ',': return error_not_implemented(", binary operator not implemented");
|
||||||
case '@' : {
|
case '@' : {
|
||||||
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
|
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
|
||||||
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
|
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
|
||||||
@@ -66,7 +66,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
||||||
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
||||||
if(!clhs || !crhs)
|
if(!clhs || !crhs)
|
||||||
should_not_happen();
|
error_not_implemented("ellipsis between variables not implemented");
|
||||||
return set_ret(bld_->insert(ir::make_range::create(clhs, crhs)));
|
return set_ret(bld_->insert(ir::make_range::create(clhs, crhs)));
|
||||||
}
|
}
|
||||||
case '+':
|
case '+':
|
||||||
@@ -97,7 +97,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else if(!sign)
|
else if(!sign)
|
||||||
return set_ret(bld_->create_udiv(lhs, rhs));
|
return set_ret(bld_->create_udiv(lhs, rhs));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("/ should not encounter type not in {float, int}");
|
||||||
case '%':
|
case '%':
|
||||||
if(flt)
|
if(flt)
|
||||||
return set_ret(bld_->create_frem(lhs, rhs));
|
return set_ret(bld_->create_frem(lhs, rhs));
|
||||||
@@ -113,7 +113,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else if(!sign)
|
else if(!sign)
|
||||||
return set_ret(bld_->create_icmpULT(lhs, rhs));
|
return set_ret(bld_->create_icmpULT(lhs, rhs));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("< should not encounter type not in {float, int}");
|
||||||
case '>':
|
case '>':
|
||||||
if(flt)
|
if(flt)
|
||||||
return set_ret(bld_->create_fcmpOGT(lhs, rhs));
|
return set_ret(bld_->create_fcmpOGT(lhs, rhs));
|
||||||
@@ -122,7 +122,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else if(!sign)
|
else if(!sign)
|
||||||
return set_ret(bld_->create_icmpUGT(lhs, rhs));
|
return set_ret(bld_->create_icmpUGT(lhs, rhs));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("> should not encounter type not in {float, int}");
|
||||||
case Token::LE:
|
case Token::LE:
|
||||||
if(flt)
|
if(flt)
|
||||||
return set_ret(bld_->create_fcmpOLE(lhs, rhs));
|
return set_ret(bld_->create_fcmpOLE(lhs, rhs));
|
||||||
@@ -131,7 +131,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else if(!sign)
|
else if(!sign)
|
||||||
return set_ret(bld_->create_icmpULE(lhs, rhs));
|
return set_ret(bld_->create_icmpULE(lhs, rhs));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("<= should not encounter type not in {float, int}");
|
||||||
case Token::GE:
|
case Token::GE:
|
||||||
if(flt)
|
if(flt)
|
||||||
return set_ret(bld_->create_fcmpOGE(lhs, rhs));
|
return set_ret(bld_->create_fcmpOGE(lhs, rhs));
|
||||||
@@ -140,7 +140,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else if(!sign)
|
else if(!sign)
|
||||||
return set_ret(bld_->create_icmpUGE(lhs, rhs));
|
return set_ret(bld_->create_icmpUGE(lhs, rhs));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen(">= should not encounter type not in {float, int}");
|
||||||
case Token::EQ:
|
case Token::EQ:
|
||||||
if(flt)
|
if(flt)
|
||||||
return set_ret(bld_->create_fcmpOEQ(lhs, rhs));
|
return set_ret(bld_->create_fcmpOEQ(lhs, rhs));
|
||||||
@@ -152,9 +152,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
else
|
else
|
||||||
return set_ret(bld_->create_icmpNE(lhs, rhs));
|
return set_ret(bld_->create_icmpNE(lhs, rhs));
|
||||||
default:
|
default:
|
||||||
error_not_implemented();
|
return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented");
|
||||||
}
|
}
|
||||||
error_not_implemented();
|
should_not_happen("");
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
||||||
@@ -166,7 +166,7 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
|||||||
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
|
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
should_not_happen();
|
error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented");
|
||||||
return reduce_inst::op_t();
|
return reduce_inst::op_t();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +176,10 @@ ir::value* Generator::GenUnaryMinus(ir::value* arg) {
|
|||||||
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
|
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
|
||||||
if(ty->is_tile_ty())
|
if(ty->is_tile_ty())
|
||||||
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
|
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
|
||||||
return bld_->create_sub(_0, arg);
|
if(sca_ty->is_floating_point_ty())
|
||||||
|
return bld_->create_fsub(_0, arg);
|
||||||
|
else
|
||||||
|
return bld_->create_sub(_0, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||||
@@ -187,18 +190,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
|
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
|
||||||
// return
|
// return
|
||||||
switch (unary->op_) {
|
switch (unary->op_) {
|
||||||
case Token::PREFIX_INC: return error_not_implemented();
|
case Token::PREFIX_INC: return error_not_implemented("prefix increment not implemented");
|
||||||
case Token::PREFIX_DEC: return error_not_implemented();
|
case Token::PREFIX_DEC: return error_not_implemented("prefix decrement not implemented");
|
||||||
case Token::POSTFIX_INC: return error_not_implemented();
|
case Token::POSTFIX_INC: return error_not_implemented("postfix increment not implemented");
|
||||||
case Token::POSTFIX_DEC: return error_not_implemented();
|
case Token::POSTFIX_DEC: return error_not_implemented("postfix decrement not implemented");
|
||||||
case Token::ADDR: return error_not_implemented();
|
case Token::ADDR: return error_not_implemented("unary & not implemented");
|
||||||
case Token::DEREF: return set_ret(bld_->create_load(arg));
|
case Token::DEREF: return set_ret(bld_->create_load(arg));
|
||||||
case Token::PLUS: return error_not_implemented();
|
case Token::PLUS: return error_not_implemented("unary + not implemented");
|
||||||
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
|
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
|
||||||
case '~': return error_not_implemented();
|
case '~': return error_not_implemented("unary ~ not implemented");
|
||||||
case '!': return error_not_implemented();
|
case '!': return error_not_implemented("unary ! not implemented");
|
||||||
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
|
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
|
||||||
case Token::REDUCE: {
|
case Token::REDUCE: {
|
||||||
int ax, tag;
|
int ax, tag;
|
||||||
UnaryOp::decodeRed(unary->info_, ax, tag);
|
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||||
@@ -206,9 +210,9 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
|
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
|
||||||
return set_ret(bld_->create_reduce(arg, op, ax));
|
return set_ret(bld_->create_reduce(arg, op, ax));
|
||||||
}
|
}
|
||||||
default: error_not_implemented();
|
default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented");
|
||||||
}
|
}
|
||||||
return error_not_implemented();
|
return should_not_happen("");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitTransOp(TransOp *trans) {
|
void Generator::VisitTransOp(TransOp *trans) {
|
||||||
@@ -225,7 +229,9 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
|||||||
ir::value* true_val = ret_;
|
ir::value* true_val = ret_;
|
||||||
VisitExpr(condOp->exprFalse_);
|
VisitExpr(condOp->exprFalse_);
|
||||||
ir::value* false_val = ret_;
|
ir::value* false_val = ret_;
|
||||||
if(ir::load_inst* ld = dynamic_cast<ir::load_inst*>(true_val)) {
|
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
|
||||||
|
if(!false_val->get_type()->is_tile_ty())
|
||||||
|
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
|
||||||
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
||||||
cond,
|
cond,
|
||||||
false_val);
|
false_val);
|
||||||
@@ -233,7 +239,8 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
|||||||
ld->erase_from_parent();
|
ld->erase_from_parent();
|
||||||
return set_ret(new_ld);
|
return set_ret(new_ld);
|
||||||
}
|
}
|
||||||
return error_not_implemented();
|
return set_ret(bld_->create_select(cond, true_val, false_val));
|
||||||
|
// return error_not_implemented();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitFuncCall(FuncCall* funcCall) {
|
void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||||
@@ -244,7 +251,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
|||||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||||
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("get_program_id argument should be constant");
|
||||||
}
|
}
|
||||||
if(name == "get_num_programs"){
|
if(name == "get_num_programs"){
|
||||||
VisitExpr(funcCall->Args()->at(0));
|
VisitExpr(funcCall->Args()->at(0));
|
||||||
@@ -252,7 +259,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
|||||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||||
return set_ret(bld_->create_get_num_program(axis->get_value()));
|
return set_ret(bld_->create_get_num_program(axis->get_value()));
|
||||||
else
|
else
|
||||||
return should_not_happen();
|
return should_not_happen("get_num_programs argument should be constant");
|
||||||
}
|
}
|
||||||
if(name == "atomic_cas"){
|
if(name == "atomic_cas"){
|
||||||
VisitExpr(funcCall->Args()->at(0));
|
VisitExpr(funcCall->Args()->at(0));
|
||||||
@@ -294,7 +301,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
|||||||
ir::value* false_val = ret_;
|
ir::value* false_val = ret_;
|
||||||
return set_ret(bld_->create_select(cond, true_val, false_val));
|
return set_ret(bld_->create_select(cond, true_val, false_val));
|
||||||
}
|
}
|
||||||
return error_not_implemented();
|
return error_not_implemented("function calls not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitObject(Object* obj) {
|
void Generator::VisitObject(Object* obj) {
|
||||||
@@ -302,7 +309,7 @@ void Generator::VisitObject(Object* obj) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitEnumerator(Enumerator* enumer) {
|
void Generator::VisitEnumerator(Enumerator* enumer) {
|
||||||
return error_not_implemented();
|
return error_not_implemented("enumeration not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitIdentifier(Identifier* ident) {
|
void Generator::VisitIdentifier(Identifier* ident) {
|
||||||
@@ -316,31 +323,36 @@ void Generator::VisitConstant(Constant* cons) {
|
|||||||
return set_ret(ir::constant_int::get(type, cons->IVal()));
|
return set_ret(ir::constant_int::get(type, cons->IVal()));
|
||||||
if(ctype->IsFloat() && ctype->IsReal())
|
if(ctype->IsFloat() && ctype->IsReal())
|
||||||
return set_ret(ir::constant_fp::get(type, cons->FVal()));
|
return set_ret(ir::constant_fp::get(type, cons->FVal()));
|
||||||
return error_not_implemented();
|
return error_not_implemented("constant of type not in {int, float} not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitTempVar(TempVar* tempVar) {
|
void Generator::VisitTempVar(TempVar* tempVar) {
|
||||||
return error_not_implemented();
|
return error_not_implemented("temporary variable not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Statement
|
// Statement
|
||||||
void Generator::VisitDeclaration(Declaration* decl) {
|
void Generator::VisitDeclaration(Declaration* decl) {
|
||||||
auto obj = decl->obj_;
|
auto obj = decl->obj_;
|
||||||
// initialize to undef
|
// initialize to undef
|
||||||
|
|
||||||
ir::type* ty = GenIRType(obj->Type(), *ctx_);
|
ir::type* ty = GenIRType(obj->Type(), *ctx_);
|
||||||
ir::value* val = ir::undef_value::get(ty);
|
ir::value* val = ir::undef_value::get(ty);
|
||||||
|
//obj->GetAttrList()
|
||||||
// compute initializers
|
// compute initializers
|
||||||
std::vector<ir::value*> inits;
|
std::vector<ir::value*> inits;
|
||||||
for (const Initializer& init: decl->Inits()) {
|
for (const Initializer& init: decl->Inits()) {
|
||||||
VisitExpr(init.expr_);
|
VisitExpr(init.expr_);
|
||||||
inits.push_back(ret_);
|
ir::value *val = ret_;
|
||||||
|
for(const auto& attr: obj->GetAttrList())
|
||||||
|
SetIRMetadata(attr, val);
|
||||||
|
inits.push_back(val);
|
||||||
}
|
}
|
||||||
// initialize declaration
|
// initialize declaration
|
||||||
ir::type::id_t id = ty->get_type_id();
|
ir::type::id_t id = ty->get_type_id();
|
||||||
if(id == ir::type::StructTyID)
|
if(id == ir::type::StructTyID)
|
||||||
should_not_happen();
|
error_not_implemented("struct not implemented");
|
||||||
if(inits.size() > 1)
|
if(inits.size() > 1)
|
||||||
should_not_happen();
|
error_not_implemented("initializer list > 1 element not implemented");
|
||||||
if(inits.size() > 0)
|
if(inits.size() > 0)
|
||||||
val = inits[0];
|
val = inits[0];
|
||||||
assert(val->get_type() == ty);
|
assert(val->get_type() == ty);
|
||||||
@@ -427,20 +439,20 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
|
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
|
||||||
return error_not_implemented();
|
return error_not_implemented("jump not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitReturnStmt(ReturnStmt* returnStmt) {
|
void Generator::VisitReturnStmt(ReturnStmt* returnStmt) {
|
||||||
ir::value *ret;
|
ir::value *ret;
|
||||||
if(returnStmt->expr_)
|
if(returnStmt->expr_)
|
||||||
return error_not_implemented();
|
return error_not_implemented("non-void return not implemented");
|
||||||
else
|
else
|
||||||
ret = bld_->create_ret_void();
|
ret = bld_->create_ret_void();
|
||||||
return set_ret(ret);
|
return set_ret(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
|
void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
|
||||||
return error_not_implemented();
|
return error_not_implemented("label not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
||||||
@@ -458,7 +470,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
|
|||||||
FuncType* type = funcDef->FuncType();
|
FuncType* type = funcDef->FuncType();
|
||||||
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
|
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
|
||||||
if(!prototype)
|
if(!prototype)
|
||||||
should_not_happen();
|
should_not_happen("could not parse function prototype");
|
||||||
ir::function *fn = mod_->get_or_insert_function(name, prototype);
|
ir::function *fn = mod_->get_or_insert_function(name, prototype);
|
||||||
std::vector<ir::argument*> args = fn->args();
|
std::vector<ir::argument*> args = fn->args();
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
@@ -529,7 +541,7 @@ ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
|
|||||||
for(size_t d = 0; d < padded_shapes.size(); d++){
|
for(size_t d = 0; d < padded_shapes.size(); d++){
|
||||||
if(dst_shapes[d] != padded_shapes[d] &&
|
if(dst_shapes[d] != padded_shapes[d] &&
|
||||||
padded_shapes[d] != 1)
|
padded_shapes[d] != 1)
|
||||||
should_not_happen();
|
should_not_happen("broadcast should not happen between these shapes");
|
||||||
}
|
}
|
||||||
// pad and broadcast
|
// pad and broadcast
|
||||||
ir::value *padded = bld_->create_reshape(src, padded_shapes);
|
ir::value *padded = bld_->create_reshape(src, padded_shapes);
|
||||||
@@ -555,6 +567,9 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
|
|||||||
bool dst_signed = false;
|
bool dst_signed = false;
|
||||||
if(src_scalar_ty == dst_scalar_ty)
|
if(src_scalar_ty == dst_scalar_ty)
|
||||||
return src;
|
return src;
|
||||||
|
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty())
|
||||||
|
return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())),
|
||||||
|
bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes()));
|
||||||
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
return bld_->create_si_to_fp(src, dst_ty);
|
return bld_->create_si_to_fp(src, dst_ty);
|
||||||
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
@@ -575,7 +590,7 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
|
|||||||
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
|
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
|
||||||
return bld_->create_cast(ir::BitCast, src, dst_ty);
|
return bld_->create_cast(ir::BitCast, src, dst_ty);
|
||||||
else{
|
else{
|
||||||
should_not_happen();
|
error_not_implemented("cast type not implemented");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -594,7 +609,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
|
|||||||
if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
|
if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
|
||||||
VisitExpr(attr.vals[0]);
|
VisitExpr(attr.vals[0]);
|
||||||
auto cst = dynamic_cast<ir::constant_int*>(ret_);
|
auto cst = dynamic_cast<ir::constant_int*>(ret_);
|
||||||
if(!cst) should_not_happen();
|
if(!cst) should_not_happen("multipleof only works on constants");
|
||||||
return ir::attribute(ir::multiple_of, cst->get_value());
|
return ir::attribute(ir::multiple_of, cst->get_value());
|
||||||
}
|
}
|
||||||
if(attr.kind == ASTNode::Attr::ALIGNED) {
|
if(attr.kind == ASTNode::Attr::ALIGNED) {
|
||||||
@@ -608,7 +623,15 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
|
|||||||
return ir::attribute(ir::readonly);
|
return ir::attribute(ir::readonly);
|
||||||
if(attr.kind == ASTNode::Attr::WRITEONLY)
|
if(attr.kind == ASTNode::Attr::WRITEONLY)
|
||||||
return ir::attribute(ir::writeonly);
|
return ir::attribute(ir::writeonly);
|
||||||
should_not_happen();
|
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
|
||||||
|
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||||
|
if(!i)
|
||||||
|
return;
|
||||||
|
if(attr.kind == ASTNode::Attr::MULTIPLEOF)
|
||||||
|
i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Triton-IR Types
|
// Triton-IR Types
|
||||||
@@ -684,12 +707,12 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
|
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
|
||||||
error_not_implemented();
|
error_not_implemented("struct not implemented");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) {
|
void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) {
|
||||||
return error_not_implemented();
|
return error_not_implemented("alloc not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSA
|
// SSA
|
||||||
@@ -704,7 +727,7 @@ void Generator::popScope() {
|
|||||||
// LValue Generator
|
// LValue Generator
|
||||||
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
||||||
if(binary->op_ != Token::MASKED_DEREF)
|
if(binary->op_ != Token::MASKED_DEREF)
|
||||||
error_not_implemented();
|
error_not_implemented("lvalue for binary non masked-deref not implemented");
|
||||||
gen_->VisitExpr(binary->lhs_);
|
gen_->VisitExpr(binary->lhs_);
|
||||||
ir::value* mask = gen_->ret_;
|
ir::value* mask = gen_->ret_;
|
||||||
gen_->VisitExpr(binary->rhs_);
|
gen_->VisitExpr(binary->rhs_);
|
||||||
@@ -714,7 +737,7 @@ void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
|
|
||||||
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
|
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
|
||||||
if(unary->op_ != Token::DEREF)
|
if(unary->op_ != Token::DEREF)
|
||||||
should_not_happen();
|
error_not_implemented("lvalue for unary non deref not implemented");
|
||||||
gen_->VisitExpr(unary->operand_);
|
gen_->VisitExpr(unary->operand_);
|
||||||
ir::value* addr = gen_->ret_;
|
ir::value* addr = gen_->ret_;
|
||||||
ret_ = gen_->bld_->create_store(addr, rhs_);
|
ret_ = gen_->bld_->create_store(addr, rhs_);
|
||||||
|
@@ -258,7 +258,6 @@ Constant* Parser::ParseFloat(const Token* tok) {
|
|||||||
}
|
}
|
||||||
if (str[end] != 0)
|
if (str[end] != 0)
|
||||||
Error(tok, "invalid suffix");
|
Error(tok, "invalid suffix");
|
||||||
|
|
||||||
return Constant::New(tok, tag, val);
|
return Constant::New(tok, tag, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -571,6 +570,7 @@ Expr* Parser::ParseUnaryExpr() {
|
|||||||
case Token::SIZEOF: return ParseSizeof();
|
case Token::SIZEOF: return ParseSizeof();
|
||||||
case Token::INC: return ParsePrefixIncDec(tok);
|
case Token::INC: return ParsePrefixIncDec(tok);
|
||||||
case Token::DEC: return ParsePrefixIncDec(tok);
|
case Token::DEC: return ParsePrefixIncDec(tok);
|
||||||
|
case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions
|
||||||
case '&': return ParseUnaryOp(tok, Token::ADDR);
|
case '&': return ParseUnaryOp(tok, Token::ADDR);
|
||||||
case '*': return ParseDerefOp(tok);
|
case '*': return ParseDerefOp(tok);
|
||||||
case '+': return ParseUnaryOp(tok, Token::PLUS);
|
case '+': return ParseUnaryOp(tok, Token::PLUS);
|
||||||
@@ -634,6 +634,12 @@ UnaryOp* Parser::ParsePrefixIncDec(const Token* tok) {
|
|||||||
return UnaryOp::New(op, operand);
|
return UnaryOp::New(op, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UnaryOp* Parser::ParseUnaryIntrinsicOp(const Token* tok, int op) {
|
||||||
|
ts_.Expect('(');
|
||||||
|
auto operand = ParseExpr();
|
||||||
|
ts_.Expect(')');
|
||||||
|
return UnaryOp::New(op, operand);
|
||||||
|
}
|
||||||
|
|
||||||
UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) {
|
UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) {
|
||||||
auto operand = ParseCastExpr();
|
auto operand = ParseCastExpr();
|
||||||
|
@@ -45,6 +45,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
|
|||||||
{ "volatile", Token::VOLATILE },
|
{ "volatile", Token::VOLATILE },
|
||||||
{ "while", Token::WHILE },
|
{ "while", Token::WHILE },
|
||||||
{ "bitcast", Token::BITCAST },
|
{ "bitcast", Token::BITCAST },
|
||||||
|
{ "exp", Token::EXP },
|
||||||
{ "_Alignas", Token::ALIGNAS },
|
{ "_Alignas", Token::ALIGNAS },
|
||||||
{ "_Alignof", Token::ALIGNOF },
|
{ "_Alignof", Token::ALIGNOF },
|
||||||
{ "_Atomic", Token::ATOMIC },
|
{ "_Atomic", Token::ATOMIC },
|
||||||
@@ -147,6 +148,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
|||||||
{ Token::VOLATILE, "volatile" },
|
{ Token::VOLATILE, "volatile" },
|
||||||
{ Token::WHILE, "while" },
|
{ Token::WHILE, "while" },
|
||||||
{ Token::BITCAST, "bitcast" },
|
{ Token::BITCAST, "bitcast" },
|
||||||
|
{ Token::EXP, "exp" },
|
||||||
{ Token::ALIGNAS, "_Alignas" },
|
{ Token::ALIGNAS, "_Alignas" },
|
||||||
{ Token::ALIGNOF, "_Alignof" },
|
{ Token::ALIGNOF, "_Alignof" },
|
||||||
{ Token::ATOMIC, "_Atomic" },
|
{ Token::ATOMIC, "_Atomic" },
|
||||||
|
@@ -165,8 +165,10 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
|||||||
arg_type ty = arg_i.type();
|
arg_type ty = arg_i.type();
|
||||||
if(ty != param_tys_.at(i))
|
if(ty != param_tys_.at(i))
|
||||||
throw std::runtime_error("invalid type for argument " + std::to_string(i));
|
throw std::runtime_error("invalid type for argument " + std::to_string(i));
|
||||||
if(ty == BUFFER_T)
|
if(ty == BUFFER_T){
|
||||||
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
driver::buffer* buf = *((driver::buffer**)arg_i.data());
|
||||||
|
bin_->setArg(i, buf->size() == 0 ? nullptr : buf);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
bin_->setArg(i, size_of(ty), arg_i.data());
|
bin_->setArg(i, size_of(ty), arg_i.data());
|
||||||
}
|
}
|
||||||
@@ -216,6 +218,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
|||||||
codegen::transform::cts cts;
|
codegen::transform::cts cts;
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
|
// ir::print(module, std::cout);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
disassociate.run(module);
|
disassociate.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
@@ -231,6 +234,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
|||||||
dce.run(module);
|
dce.run(module);
|
||||||
reassociate.run(module);
|
reassociate.run(module);
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
align.run(module);
|
align.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
@@ -238,7 +242,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
|||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run(module);
|
allocation.run(module);
|
||||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
return std::unique_ptr<driver::module>();
|
throw std::runtime_error("using too much shared memory");
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||||
@@ -391,6 +395,8 @@ std::string function::preheader() {
|
|||||||
#define __aligned(A) __attribute__((aligned(A)))
|
#define __aligned(A) __attribute__((aligned(A)))
|
||||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||||
|
|
||||||
|
#define INFINITY bitcast<float>(0x7F800000)
|
||||||
|
|
||||||
extern int atomic_cas(int*, int, int);
|
extern int atomic_cas(int*, int, int);
|
||||||
extern int atomic_xchg(int*, int);
|
extern int atomic_xchg(int*, int);
|
||||||
extern int get_program_id(int);
|
extern int get_program_id(int);
|
||||||
|
@@ -13,18 +13,18 @@ configs = []
|
|||||||
|
|
||||||
# Matrix multiplication
|
# Matrix multiplication
|
||||||
MNK = [
|
MNK = [
|
||||||
(512, 512 ,512),
|
(1024, 1024, 1024),
|
||||||
(2048, 2048, 2048),
|
(2048, 2048, 2048),
|
||||||
#(8192, 8192, 8192),
|
(8192, 8192, 8192),
|
||||||
|
|
||||||
(64, 64, 64000),
|
#(64, 64, 64000),
|
||||||
(64, 64, 128000),
|
#(64, 64, 128000),
|
||||||
(256, 256, 64000),
|
#(256, 256, 64000),
|
||||||
(256, 256, 128000),
|
#(256, 256, 128000),
|
||||||
|
|
||||||
(1536, 16, 1536),
|
#(1536, 16, 1536),
|
||||||
(1536, 32, 1536),
|
#(1536, 32, 1536),
|
||||||
(1536, 64, 1536),
|
#(1536, 64, 1536),
|
||||||
# (1536, 128, 1536),
|
# (1536, 128, 1536),
|
||||||
# (4096, 16, 4096),
|
# (4096, 16, 4096),
|
||||||
# (4096, 32, 4096),
|
# (4096, 32, 4096),
|
||||||
@@ -33,9 +33,9 @@ MNK = [
|
|||||||
|
|
||||||
# (127008, 768, 576)
|
# (127008, 768, 576)
|
||||||
]
|
]
|
||||||
for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
matmul = lambda a, b: torch.matmul(a, b)
|
# matmul = lambda a, b: torch.matmul(a, b)
|
||||||
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
||||||
#for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||||
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
||||||
@@ -94,8 +94,8 @@ for N, C, H, K, R in NCHKR:
|
|||||||
|
|
||||||
# 2D Dense convolution
|
# 2D Dense convolution
|
||||||
NCHWKRS = [
|
NCHWKRS = [
|
||||||
#(8, 64, 128, 128, 768, 3, 3),
|
(8, 64, 128, 128, 768, 3, 3),
|
||||||
(128, 3, 32, 32, 64, 3, 3),
|
#(128, 3, 32, 32, 64, 3, 3),
|
||||||
#(8, 256, 32, 32, 512, 3, 3),
|
#(8, 256, 32, 32, 512, 3, 3),
|
||||||
#(8, 512, 32, 32, 1024, 3, 3)
|
#(8, 512, 32, 32, 1024, 3, 3)
|
||||||
]
|
]
|
||||||
@@ -160,22 +160,39 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
|||||||
b = b.permute(1, 0)
|
b = b.permute(1, 0)
|
||||||
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
|
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
|
||||||
return torch.nn.functional.conv2d(a, b)
|
return torch.nn.functional.conv2d(a, b)
|
||||||
configs += [([N, C, H, W],
|
configs += [([N, C, H, W], [C, K], [N, K, H, W],
|
||||||
[C, K],
|
shift_conv,
|
||||||
[N, K, H, W],
|
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
||||||
shift_conv,
|
{'sh': shift_h, 'sw': shift_w})]
|
||||||
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
|
||||||
{'sh': shift_h, 'sw': shift_w})]
|
NCHWKX = [
|
||||||
|
#(8, 64, 128, 128, 128, 7)
|
||||||
|
]
|
||||||
|
for N, C, H, W, K, X in NCHWKX:
|
||||||
|
off_h = np.array([0, 0, 0, 1, 2, 3, 4], dtype=np.int32)
|
||||||
|
off_w = np.array([0, 1, 3, 1, 3, 0, 4], dtype=np.int32)
|
||||||
|
R, S = 5, 5
|
||||||
|
def sparse_conv(a, b, **kwargs):
|
||||||
|
off_h, off_w = kwargs['off_h'], kwargs['off_w']
|
||||||
|
K, C, X = b.shape
|
||||||
|
cvtb = torch.zeros([K, C, R, S], dtype=b.dtype, device=b.device)
|
||||||
|
cvtb[:, :, off_h, off_w] = b
|
||||||
|
return torch.nn.functional.conv2d(a, cvtb)
|
||||||
|
configs += [([N, C, H, W], [K, C, X], [N, K, H - R + 1, W - S + 1],
|
||||||
|
sparse_conv,
|
||||||
|
'nc(h + off_h[x])(w + off_w[x]),kcx->nkhw',
|
||||||
|
{'off_h': off_h, 'off_w': off_w})]
|
||||||
|
|
||||||
|
|
||||||
# Benchmark
|
# Benchmark
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||||
dtype = torch.cuda.FloatTensor
|
dtype = torch.cuda.HalfTensor
|
||||||
# initialize input tensors
|
# initialize input tensors
|
||||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||||
# triton output
|
# triton output
|
||||||
tc = torch.empty(c_shape, device=a.device)
|
tc = torch.zeros(c_shape, dtype=a.dtype, device=a.device)
|
||||||
triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
|
triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
|
||||||
# reference output
|
# reference output
|
||||||
if torch_fn:
|
if torch_fn:
|
||||||
@@ -185,12 +202,13 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
|||||||
# performance relative to equivalent matrix multiplication
|
# performance relative to equivalent matrix multiplication
|
||||||
ctx = triton.ops._einsum.registry[tc]
|
ctx = triton.ops._einsum.registry[tc]
|
||||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||||
cmp_eqbmm = False
|
cmp_eqbmm = True
|
||||||
if cmp_eqbmm:
|
if cmp_eqbmm:
|
||||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
tmmc = torch.empty([B, M, N]).type(dtype).cuda()
|
||||||
ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
|
triton.ops.einsum('bmk,bkn->bmn', a, b, tmmc, bench = True)
|
||||||
|
ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
|
||||||
cmp_str = f'({ratio:4.2f})'
|
cmp_str = f'({ratio:4.2f})'
|
||||||
else:
|
else:
|
||||||
cmp_str = ''
|
cmp_str = ''
|
||||||
|
@@ -329,14 +329,16 @@ __global__ void {name}(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
# print(src)
|
||||||
# compilation options
|
# compilation options
|
||||||
TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
|
#TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
|
||||||
TK = 16 if dtype==torch.float16 else 8
|
#TK = 16 if dtype==torch.float16 else 8
|
||||||
|
TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16
|
||||||
defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||||
# create kernel
|
# create kernel
|
||||||
ret = triton.kernel(src, defines=defines)
|
ret = triton.kernel(src, defines=defines, num_warps=[4])
|
||||||
# set constant
|
# set constant
|
||||||
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
||||||
ret.set_constant('AD', delta_a)
|
ret.set_constant('AD', delta_a)
|
||||||
|
@@ -13,7 +13,7 @@ int main() {
|
|||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
// config_t{ord, x[0], x[1], 512, 512, 512},
|
// config_t{ord, x[0], x[1], 512, 512, 512},
|
||||||
config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
config_t{ord, x[0], x[1], 1024, 1024, 1024},
|
||||||
// config_t{ord, x[0], x[1], 127008, 768, 576},
|
// config_t{ord, x[0], x[1], 127008, 768, 576},
|
||||||
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
|
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
|
||||||
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
||||||
|
@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
opt.num_warps = {nwarp};
|
opt.num_warps = {nwarp};
|
||||||
}
|
}
|
||||||
if(mode == BENCH) {
|
if(mode == BENCH) {
|
||||||
opt.defines.push_back({"TM", {"32", "64", "128"}});
|
opt.defines.push_back({"TM", {"128"}});
|
||||||
opt.defines.push_back({"TN", {"32", "64", "128"}});
|
opt.defines.push_back({"TN", {"32"}});
|
||||||
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
|
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
|
||||||
opt.num_warps = {2, 4, 8};
|
opt.num_warps = {4};
|
||||||
}
|
}
|
||||||
|
|
||||||
// kernels
|
// kernels
|
||||||
|
Reference in New Issue
Block a user