[examples] back to 96 TFLOPS on V100
This commit is contained in:
@@ -49,7 +49,7 @@ inline size_t size_of(DType dtype){
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<cublasGemmAlgo_t> gather_all_algos() {
|
||||
inline std::vector<cublasGemmAlgo_t> gather_all_algos() {
|
||||
std::vector<cublasGemmAlgo_t> result;
|
||||
// non-tensor ops
|
||||
for(int i = -1; i < 24; i++)
|
||||
@@ -124,7 +124,7 @@ inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cub
|
||||
|
||||
|
||||
/* Get cuBLAS handle */
|
||||
cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
|
||||
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
|
||||
static std::map<CUstream, cublasHandle_t> cache;
|
||||
CUstream key = *stream->cu();
|
||||
|
||||
|
@@ -75,17 +75,21 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
float xc[TM, TN] = 0;
|
||||
#ifdef AT
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
TYPE a[TK, TM] = *pa;
|
||||
bool checka[TK, TM] = rka[:, newaxis] < K;
|
||||
TYPE a[TK, TM] = checka ? *pa : 0;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
TYPE a[TM, TK] = *pa;
|
||||
bool checka[TM, TK] = rka[newaxis, :] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
#endif
|
||||
#ifdef BT
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
TYPE b[TN, TK] = *pb;
|
||||
bool checkb[TN, TK] = rkb[newaxis, :] < K;
|
||||
TYPE b[TN, TK] = checkb ? *pb : 0;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
TYPE b[TK, TN] = *pb;
|
||||
bool checkb[TK, TN] = rkb[:, newaxis] < K;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
#endif
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = USEA @ USEB + xc;
|
||||
@@ -99,8 +103,10 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#endif
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
checka = k > TK;
|
||||
checkb = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
@@ -109,7 +115,7 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
|
@@ -37,11 +37,12 @@ private:
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(analysis::grids *params);
|
||||
reassociate(analysis::alignment_info* align, analysis::grids *params);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
analysis::grids* params_;
|
||||
analysis::alignment_info* align_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -360,11 +360,12 @@ public:
|
||||
switch (op_) {
|
||||
case '.': return !Type()->ToArray() && lhs_->IsLVal();
|
||||
case ']': return !Type()->ToArray();
|
||||
case Token::MASKED_DEREF: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
ArithmType* Convert();
|
||||
void Broadcast();
|
||||
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
|
||||
|
||||
virtual void TypeChecking();
|
||||
void SubScriptingOpTypeChecking();
|
||||
@@ -374,6 +375,7 @@ public:
|
||||
void ShiftOpTypeChecking();
|
||||
void RangeOpTypeChecking();
|
||||
void MatmulOpTypeChecking();
|
||||
void MaskedDerefOpTypeChecking();
|
||||
void RelationalOpTypeChecking();
|
||||
void EqualityOpTypeChecking();
|
||||
void BitwiseOpTypeChecking();
|
||||
|
@@ -84,6 +84,7 @@ public:
|
||||
Constant* ParseAlignof();
|
||||
UnaryOp* ParsePrefixIncDec(const Token* tok);
|
||||
UnaryOp* ParseUnaryOp(const Token* tok, int op);
|
||||
Expr* ParseDerefOp(const Token* tok);
|
||||
|
||||
QualType ParseTypeName();
|
||||
Expr* ParseCastExpr();
|
||||
|
@@ -94,6 +94,7 @@ public:
|
||||
OR_ASSIGN,
|
||||
|
||||
ELLIPSIS,
|
||||
MASKED_DEREF,
|
||||
// Punctuators end
|
||||
|
||||
// KEYWORD BEGIN
|
||||
|
@@ -41,7 +41,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
|
||||
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
|
||||
tmr.start();
|
||||
op();
|
||||
|
@@ -111,8 +111,9 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::constant_range*>(v))
|
||||
if(dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(shapes[0]);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
if(op->get_type()->is_tile_ty()){
|
||||
@@ -305,7 +306,6 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << " " << max_contiguous_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -93,6 +93,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(new_lhs, old_value);
|
||||
params_->copy(new_rhs, old_value);
|
||||
align_->copy(new_value, old_value);
|
||||
align_->copy(new_lhs, old_value);
|
||||
align_->copy(new_rhs, old_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -130,6 +133,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
|
||||
align_->copy(new_value, old_value);
|
||||
align_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
|
||||
align_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,8 +161,8 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(analysis::grids* params)
|
||||
: params_(params)
|
||||
reassociate::reassociate(analysis::alignment_info *align, analysis::grids* params)
|
||||
: params_(params), align_(align)
|
||||
{ }
|
||||
|
||||
|
||||
@@ -185,6 +191,9 @@ void reassociate::run(ir::module &mod) {
|
||||
params_->copy(dyn_range, old_range);
|
||||
params_->copy(static_range, old_range);
|
||||
params_->copy(new_range, old_range);
|
||||
align_->copy(dyn_range, old_range);
|
||||
align_->copy(static_range, old_range);
|
||||
align_->copy(new_range, old_range);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +226,8 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
params_->copy(dyn_ptr, pz);
|
||||
params_->copy(sta_ptr, pz);
|
||||
align_->copy(dyn_ptr, pz);
|
||||
align_->copy(sta_ptr, pz);
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
@@ -233,6 +244,8 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
|
||||
params_->copy(pz_dyn, pz);
|
||||
params_->copy(pz_sta, pz);
|
||||
align_->copy(pz_dyn, pz);
|
||||
align_->copy(pz_sta, pz);
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
@@ -283,6 +296,11 @@ void reassociate::run(ir::module &mod) {
|
||||
params_->copy(neg_off, off);
|
||||
params_->copy(phi_dyn, phi);
|
||||
params_->copy(phi_sta, phi);
|
||||
align_->copy(pz_dyn, pz);
|
||||
align_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
|
||||
align_->copy(neg_off, off);
|
||||
align_->copy(phi_dyn, phi);
|
||||
align_->copy(phi_sta, phi);
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
||||
replaced.insert(phi);
|
||||
|
@@ -240,7 +240,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -190,6 +190,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
|
||||
case Token::LOGICAL_OR:
|
||||
case Token::ELLIPSIS:
|
||||
case Token::MATMUL:
|
||||
case Token::MASKED_DEREF:
|
||||
break;
|
||||
default:
|
||||
assert(0);
|
||||
@@ -218,22 +219,22 @@ ArithmType* BinaryOp::Convert() {
|
||||
return maxType;
|
||||
}
|
||||
|
||||
void BinaryOp::Broadcast() {
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
auto rhsType = rhs_->Type()->ToTile();
|
||||
auto eleType = type_->ScalarType();
|
||||
void BinaryOp::Broadcast(Expr* loc, Expr *&lhs, Expr *&rhs, QualType& type) {
|
||||
auto lhsType = lhs->Type()->ToTile();
|
||||
auto rhsType = rhs->Type()->ToTile();
|
||||
auto eleType = type->ScalarType();
|
||||
assert(eleType);
|
||||
if(!lhsType && !rhsType)
|
||||
return ;
|
||||
else if(lhsType && !rhsType){
|
||||
type_ = TileType::New(lhsType->Shape(), eleType);
|
||||
::Type* rtype = TileType::New(lhsType->Shape(), rhs_->Type()->ScalarType());
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
|
||||
type = TileType::New(lhsType->Shape(), eleType);
|
||||
::Type* rtype = TileType::New(lhsType->Shape(), rhs->Type()->ScalarType());
|
||||
rhs = UnaryOp::New(Token::CAST, rhs, rtype);
|
||||
}
|
||||
else if(!lhsType && rhsType){
|
||||
type_ = TileType::New(rhsType->Shape(), eleType);
|
||||
::Type* ltype = TileType::New(rhsType->Shape(), lhs_->Type()->ScalarType());
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
|
||||
type = TileType::New(rhsType->Shape(), eleType);
|
||||
::Type* ltype = TileType::New(rhsType->Shape(), lhs->Type()->ScalarType());
|
||||
lhs = UnaryOp::New(Token::CAST, lhs, ltype);
|
||||
|
||||
}
|
||||
else {
|
||||
@@ -257,17 +258,17 @@ void BinaryOp::Broadcast() {
|
||||
else if(lhsShape[i] == rhsShape[i])
|
||||
retShape[i] = lhsShape[i];
|
||||
else
|
||||
Error(this, "cannot broadcast dimension %d "
|
||||
Error(loc, "cannot broadcast dimension %d "
|
||||
"for operands of shape %d and %d",
|
||||
i, lhsShape[i], rhsShape[i]);
|
||||
}
|
||||
::Type* ltype = TileType::New(retShape, lhsType->ScalarType());
|
||||
::Type* rtype = TileType::New(retShape, rhsType->ScalarType());
|
||||
type_ = TileType::New(retShape, eleType);
|
||||
type = TileType::New(retShape, eleType);
|
||||
if(retShape != lhsShape)
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
|
||||
lhs = UnaryOp::New(Token::CAST, lhs, ltype);
|
||||
if(retShape != rhsShape)
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
|
||||
rhs = UnaryOp::New(Token::CAST, rhs, rtype);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,6 +341,9 @@ void BinaryOp::TypeChecking() {
|
||||
case Token::MATMUL:
|
||||
return MatmulOpTypeChecking();
|
||||
|
||||
case Token::MASKED_DEREF:
|
||||
return MaskedDerefOpTypeChecking();
|
||||
|
||||
default:
|
||||
assert(0);
|
||||
}
|
||||
@@ -375,7 +379,7 @@ void BinaryOp::MultiOpTypeChecking() {
|
||||
Error(this, "operands of '%%' should be integers");
|
||||
}
|
||||
type_ = Convert();
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -425,7 +429,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
|
||||
}
|
||||
type_ = Convert();
|
||||
}
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
void BinaryOp::RangeOpTypeChecking() {
|
||||
@@ -443,6 +447,19 @@ void BinaryOp::RangeOpTypeChecking() {
|
||||
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
|
||||
}
|
||||
|
||||
// Type-checks the masked-dereference operator '*?'.
// lhs_ is the mask operand and rhs_ is the pointer operand. The scalar type
// extracted from rhs_ must be a pointer and the one extracted from lhs_ must
// be a bool arithmetic type; otherwise an error is reported. The result type
// is built from the pointee type via ScalarOrLikeTile(rhs_, ...), and the
// operands are then broadcast against each other.
// NOTE(review): the code after the checks dereferences rhsType, so it relies
// on Error() not returning — confirm.
void BinaryOp::MaskedDerefOpTypeChecking() {
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
// mask must be arithmetic (bool); address must be a pointer
auto lhsType = lhsScalType->ToArithm();
|
||||
auto rhsType = rhsScalType->ToPointer();
|
||||
if (!rhsType)
|
||||
Error(this, "pointer expected for deref pointer in operator '*?'");
|
||||
// the inner `lhsType &&` is redundant after `!lhsType ||`; kept as written
if (!lhsType || (lhsType && !lhsType->IsBool()))
|
||||
Error(this, "bool expected for deref mask in operator '*?'");
|
||||
// result type comes from the pointee type of the address operand
type_ = ScalarOrLikeTile(rhs_, rhsType->Derived().GetPtr());
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
void BinaryOp::MatmulOpTypeChecking() {
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
auto rhsType = rhs_->Type()->ToTile();
|
||||
@@ -477,7 +494,7 @@ void BinaryOp::ShiftOpTypeChecking() {
|
||||
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
|
||||
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
|
||||
type_ = lhs_->Type();
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -493,7 +510,7 @@ void BinaryOp::RelationalOpTypeChecking() {
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -508,7 +525,7 @@ void BinaryOp::EqualityOpTypeChecking() {
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -518,7 +535,7 @@ void BinaryOp::BitwiseOpTypeChecking() {
|
||||
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
|
||||
Error(this, "operands of '&' should be integer");
|
||||
type_ = Convert();
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -528,7 +545,7 @@ void BinaryOp::LogicalOpTypeChecking() {
|
||||
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
|
||||
Error(this, "the operand should be arithmetic type or pointer");
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
@@ -548,10 +565,9 @@ void BinaryOp::AssignOpTypeChecking() {
|
||||
// The other constraints are left to the cast operator
|
||||
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
||||
type_ = lhs_->Type();
|
||||
Broadcast();
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Unary Operators
|
||||
*/
|
||||
@@ -734,8 +750,8 @@ ConditionalOp* ConditionalOp::New(const Token* tok,
|
||||
|
||||
|
||||
ArithmType* ConditionalOp::Convert() {
|
||||
auto lhsType = exprTrue_->Type()->ToArithm();
|
||||
auto rhsType = exprFalse_->Type()->ToArithm();
|
||||
auto lhsType = exprTrue_->Type()->ScalarType()->ToArithm();
|
||||
auto rhsType = exprFalse_->Type()->ScalarType()->ToArithm();
|
||||
assert(lhsType && rhsType);
|
||||
auto type = ArithmType::MaxType(lhsType, rhsType);
|
||||
if (lhsType != type) { // Pointer comparison is enough!
|
||||
@@ -750,18 +766,21 @@ ArithmType* ConditionalOp::Convert() {
|
||||
|
||||
|
||||
void ConditionalOp::TypeChecking() {
|
||||
if (!cond_->Type()->IsScalar()) {
|
||||
Error(cond_->Tok(), "scalar is required");
|
||||
auto condScalarType = TryExtractScalarType(this, cond_);
|
||||
|
||||
if (!condScalarType) {
|
||||
Error(cond_->Tok(), "condition must be tile or scalar");
|
||||
}
|
||||
|
||||
auto lhsType = exprTrue_->Type();
|
||||
auto rhsType = exprFalse_->Type();
|
||||
auto lhsType = TryExtractScalarType(this, exprTrue_);
|
||||
auto rhsType = TryExtractScalarType(this, exprFalse_);
|
||||
if (lhsType->ToArithm() && rhsType->ToArithm()) {
|
||||
type_ = Convert();
|
||||
} else {
|
||||
EnsureCompatibleOrVoidPointer(lhsType, rhsType);
|
||||
type_ = lhsType;
|
||||
}
|
||||
BinaryOp::Broadcast(this, exprFalse_, exprTrue_, type_);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -55,6 +55,10 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
_0 = bld_->create_splat(_0, ret_ty->get_tile_shapes());
|
||||
return set_ret(bld_->create_dot(lhs, rhs, _0));
|
||||
}
|
||||
case Token::MASKED_DEREF: {
|
||||
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
|
||||
return set_ret(bld_->create_masked_load(rhs, lhs, ir::undef_value::get(ret_ty)));
|
||||
}
|
||||
case Token::ELLIPSIS: {
|
||||
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
||||
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
||||
@@ -176,6 +180,21 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
}
|
||||
|
||||
// Generates IR for a conditional expression (cond ? exprTrue : exprFalse).
// The condition and both branch expressions are visited unconditionally.
// Only one shape is supported: when the true branch produced an
// ir::load_inst, that load is replaced by a masked load on the same pointer,
// with cond as the mask and the false-branch value as the third argument
// (presumably the value used where the mask is false — confirm against
// create_masked_load). Every other case reports "not implemented".
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||
// auto &instructions = bld_->get_insert_block()->get_inst_list();
|
||||
VisitExpr(condOp->cond_);
|
||||
ir::value* cond = ret_;
|
||||
VisitExpr(condOp->exprTrue_);
|
||||
ir::value* true_val = ret_;
|
||||
VisitExpr(condOp->exprFalse_);
|
||||
ir::value* false_val = ret_;
|
||||
// rewrite `cond ? *ptr : other` into a single masked load
if(ir::load_inst* ld = dynamic_cast<ir::load_inst*>(true_val)) {
|
||||
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
||||
cond,
|
||||
false_val);
|
||||
// the original unmasked load is superseded: redirect its users, then remove it
ld->replace_all_uses_with(new_ld);
|
||||
ld->erase_from_parent();
|
||||
return set_ret(new_ld);
|
||||
}
|
||||
return error_not_implemented();
|
||||
}
|
||||
|
||||
@@ -528,7 +547,7 @@ ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) {
|
||||
|
||||
ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
|
||||
ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
|
||||
unsigned addr_space = 0;
|
||||
unsigned addr_space = 1;
|
||||
return ir::pointer_type::get(ele_ty, addr_space);
|
||||
}
|
||||
|
||||
@@ -552,7 +571,13 @@ void Generator::popScope() {
|
||||
|
||||
// LValue Generator
|
||||
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
||||
error_not_implemented();
|
||||
if(binary->op_ != Token::MASKED_DEREF)
|
||||
error_not_implemented();
|
||||
gen_->VisitExpr(binary->lhs_);
|
||||
ir::value* mask = gen_->ret_;
|
||||
gen_->VisitExpr(binary->rhs_);
|
||||
ir::value* addr = gen_->ret_;
|
||||
ret_ = gen_->bld_->create_masked_store(addr, rhs_, mask);
|
||||
}
|
||||
|
||||
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
|
||||
|
@@ -517,7 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
|
||||
case Token::INC: return ParsePrefixIncDec(tok);
|
||||
case Token::DEC: return ParsePrefixIncDec(tok);
|
||||
case '&': return ParseUnaryOp(tok, Token::ADDR);
|
||||
case '*': return ParseUnaryOp(tok, Token::DEREF);
|
||||
case '*': return ParseDerefOp(tok);
|
||||
case '+': return ParseUnaryOp(tok, Token::PLUS);
|
||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||
case '~': return ParseUnaryOp(tok, '~');
|
||||
@@ -577,6 +577,19 @@ UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) {
|
||||
return UnaryOp::New(op, operand);
|
||||
}
|
||||
|
||||
// Parses the operand(s) of a dereference; ParseUnaryExpr dispatches here for '*'.
// If the next token is '?', a parenthesized predicate is parsed first and the
// result is a BinaryOp with op Token::MASKED_DEREF and operands (pred, addr).
// Otherwise the result is a plain UnaryOp with op Token::DEREF on addr.
// tok: the token passed by the caller; used only as the location token for
// the MASKED_DEREF BinaryOp.
Expr* Parser::ParseDerefOp(const Token* tok) {
|
||||
Expr* pred = nullptr;
|
||||
// optional mask: '?' '(' cast-expression ')'
if(ts_.Try('?')){
|
||||
ts_.Expect('(');
|
||||
pred = ParseCastExpr();
|
||||
ts_.Expect(')');
|
||||
}
|
||||
// the address expression follows in both forms
Expr* addr = ParseCastExpr();
|
||||
if(pred)
|
||||
return BinaryOp::New(tok, Token::MASKED_DEREF, pred, addr);
|
||||
else
|
||||
return UnaryOp::New(Token::DEREF, addr);
|
||||
}
|
||||
|
||||
QualType Parser::ParseTypeName() {
|
||||
auto type = ParseSpecQual();
|
||||
|
@@ -107,7 +107,6 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
||||
{ Token::XOR_ASSIGN, "^=" },
|
||||
{ Token::OR_ASSIGN, "|=" },
|
||||
{ Token::ELLIPSIS, "..." },
|
||||
|
||||
{ Token::AUTO, "auto" },
|
||||
{ Token::BREAK, "break" },
|
||||
{ Token::CASE, "case" },
|
||||
|
@@ -172,8 +172,10 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
caller call(tmp, std::move(bin), opt);
|
||||
double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream);
|
||||
// save best
|
||||
if(ts < best_ts)
|
||||
if(ts < best_ts) {
|
||||
best_ts = ts;
|
||||
ret.reset(new caller(call));
|
||||
}
|
||||
};
|
||||
_parallel_loop_nest<std::string>(space, benchmark, 1);
|
||||
return *ret;
|
||||
@@ -192,12 +194,13 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::vectorize vectorize(&grids);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&grids);
|
||||
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
grids.run(module);
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
if(target->is_gpu()){
|
||||
@@ -207,8 +210,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
dce.run(module);
|
||||
ir::print(module, std::cout);
|
||||
alignment_info.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
// generate llvm code
|
||||
|
Reference in New Issue
Block a user