[examples] back to 96 TFLOPS on V100

This commit is contained in:
Philippe Tillet
2019-08-26 22:48:15 -07:00
parent b4ae06a714
commit 37cbcfabd0
15 changed files with 140 additions and 54 deletions

View File

@@ -49,7 +49,7 @@ inline size_t size_of(DType dtype){
}
}
std::vector<cublasGemmAlgo_t> gather_all_algos() {
inline std::vector<cublasGemmAlgo_t> gather_all_algos() {
std::vector<cublasGemmAlgo_t> result;
// non-tensor ops
for(int i = -1; i < 24; i++)
@@ -124,7 +124,7 @@ inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cub
/* Get cuBLAS handle */
cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
static std::map<CUstream, cublasHandle_t> cache;
CUstream key = *stream->cu();

View File

@@ -75,17 +75,21 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
float xc[TM, TN] = 0;
#ifdef AT
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
TYPE a[TK, TM] = *pa;
bool checka[TK, TM] = rka[:, newaxis] < K;
TYPE a[TK, TM] = checka ? *pa : 0;
#else
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
TYPE a[TM, TK] = *pa;
bool checka[TM, TK] = rka[newaxis, :] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
#endif
#ifdef BT
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
TYPE b[TN, TK] = *pb;
bool checkb[TN, TK] = rkb[newaxis, :] < K;
TYPE b[TN, TK] = checkb ? *pb : 0;
#else
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
TYPE b[TK, TN] = *pb;
bool checkb[TK, TN] = rkb[:, newazis] < K;
TYPE b[TK, TN] = checkb ? *pb : 0;
#endif
for(int k = K; k > 0; k = k - TK){
xc = USEA @ USEB + xc;
@@ -99,8 +103,10 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
#else
pb = pb + TK;
#endif
a = *pa;
b = *pb;
checka = k > TK;
checkb = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
@@ -109,7 +115,7 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
*?(checkc) pc = c;
}
)";

View File

@@ -37,11 +37,12 @@ private:
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
reassociate(analysis::grids *params);
reassociate(analysis::alignment_info* align, analysis::grids *params);
void run(ir::module& module);
private:
analysis::grids* params_;
analysis::alignment_info* align_;
};
}

View File

@@ -360,11 +360,12 @@ public:
switch (op_) {
case '.': return !Type()->ToArray() && lhs_->IsLVal();
case ']': return !Type()->ToArray();
case Token::MASKED_DEREF: return true;
default: return false;
}
}
ArithmType* Convert();
void Broadcast();
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
virtual void TypeChecking();
void SubScriptingOpTypeChecking();
@@ -374,6 +375,7 @@ public:
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void MaskedDerefOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();

View File

@@ -84,6 +84,7 @@ public:
Constant* ParseAlignof();
UnaryOp* ParsePrefixIncDec(const Token* tok);
UnaryOp* ParseUnaryOp(const Token* tok, int op);
Expr* ParseDerefOp(const Token* tok);
QualType ParseTypeName();
Expr* ParseCastExpr();

View File

@@ -94,6 +94,7 @@ public:
OR_ASSIGN,
ELLIPSIS,
MASKED_DEREF,
// Punctuators end
// KEYWORD BEGIN

View File

@@ -41,7 +41,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();

View File

@@ -111,8 +111,9 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
if(!v->get_type()->is_tile_ty())
return cache(1);
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::constant_range*>(v))
if(dynamic_cast<ir::constant_range*>(v)){
return cache(shapes[0]);
}
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
ir::value *op = x->get_operand(0);
if(op->get_type()->is_tile_ty()){
@@ -305,7 +306,6 @@ void alignment_info::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << " " << max_contiguous_.at(i) << std::endl;
}
}

View File

@@ -93,6 +93,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
params_->copy(new_value, old_value);
params_->copy(new_lhs, old_value);
params_->copy(new_rhs, old_value);
align_->copy(new_value, old_value);
align_->copy(new_lhs, old_value);
align_->copy(new_rhs, old_value);
}
}
}
@@ -130,6 +133,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
params_->copy(new_value, old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
align_->copy(new_value, old_value);
align_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
align_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
}
}
@@ -155,8 +161,8 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
return new_value;
}
reassociate::reassociate(analysis::grids* params)
: params_(params)
reassociate::reassociate(analysis::alignment_info *align, analysis::grids* params)
: params_(params), align_(align)
{ }
@@ -185,6 +191,9 @@ void reassociate::run(ir::module &mod) {
params_->copy(dyn_range, old_range);
params_->copy(static_range, old_range);
params_->copy(new_range, old_range);
align_->copy(dyn_range, old_range);
align_->copy(static_range, old_range);
align_->copy(new_range, old_range);
}
}
@@ -217,6 +226,8 @@ void reassociate::run(ir::module &mod) {
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
params_->copy(dyn_ptr, pz);
params_->copy(sta_ptr, pz);
align_->copy(dyn_ptr, pz);
align_->copy(sta_ptr, pz);
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
@@ -233,6 +244,8 @@ void reassociate::run(ir::module &mod) {
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
params_->copy(pz_dyn, pz);
params_->copy(pz_sta, pz);
align_->copy(pz_dyn, pz);
align_->copy(pz_sta, pz);
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
@@ -283,6 +296,11 @@ void reassociate::run(ir::module &mod) {
params_->copy(neg_off, off);
params_->copy(phi_dyn, phi);
params_->copy(phi_sta, phi);
align_->copy(pz_dyn, pz);
align_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
align_->copy(neg_off, off);
align_->copy(phi_dyn, phi);
align_->copy(phi_sta, phi);
infos[phi_sta].dyn_ptr = phi_dyn;
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
replaced.insert(phi);

View File

@@ -240,7 +240,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -190,6 +190,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
case Token::LOGICAL_OR:
case Token::ELLIPSIS:
case Token::MATMUL:
case Token::MASKED_DEREF:
break;
default:
assert(0);
@@ -218,22 +219,22 @@ ArithmType* BinaryOp::Convert() {
return maxType;
}
void BinaryOp::Broadcast() {
auto lhsType = lhs_->Type()->ToTile();
auto rhsType = rhs_->Type()->ToTile();
auto eleType = type_->ScalarType();
void BinaryOp::Broadcast(Expr* loc, Expr *&lhs, Expr *&rhs, QualType& type) {
auto lhsType = lhs->Type()->ToTile();
auto rhsType = rhs->Type()->ToTile();
auto eleType = type->ScalarType();
assert(eleType);
if(!lhsType && !rhsType)
return ;
else if(lhsType && !rhsType){
type_ = TileType::New(lhsType->Shape(), eleType);
::Type* rtype = TileType::New(lhsType->Shape(), rhs_->Type()->ScalarType());
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
type = TileType::New(lhsType->Shape(), eleType);
::Type* rtype = TileType::New(lhsType->Shape(), rhs->Type()->ScalarType());
rhs = UnaryOp::New(Token::CAST, rhs, rtype);
}
else if(!lhsType && rhsType){
type_ = TileType::New(rhsType->Shape(), eleType);
::Type* ltype = TileType::New(rhsType->Shape(), lhs_->Type()->ScalarType());
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
type = TileType::New(rhsType->Shape(), eleType);
::Type* ltype = TileType::New(rhsType->Shape(), lhs->Type()->ScalarType());
lhs = UnaryOp::New(Token::CAST, lhs, ltype);
}
else {
@@ -257,17 +258,17 @@ void BinaryOp::Broadcast() {
else if(lhsShape[i] == rhsShape[i])
retShape[i] = lhsShape[i];
else
Error(this, "cannot broadcast dimension %d "
Error(loc, "cannot broadcast dimension %d "
"for operands of shape %d and %d",
i, lhsShape[i], rhsShape[i]);
}
::Type* ltype = TileType::New(retShape, lhsType->ScalarType());
::Type* rtype = TileType::New(retShape, rhsType->ScalarType());
type_ = TileType::New(retShape, eleType);
type = TileType::New(retShape, eleType);
if(retShape != lhsShape)
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
lhs = UnaryOp::New(Token::CAST, lhs, ltype);
if(retShape != rhsShape)
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
rhs = UnaryOp::New(Token::CAST, rhs, rtype);
}
}
@@ -340,6 +341,9 @@ void BinaryOp::TypeChecking() {
case Token::MATMUL:
return MatmulOpTypeChecking();
case Token::MASKED_DEREF:
return MaskedDerefOpTypeChecking();
default:
assert(0);
}
@@ -375,7 +379,7 @@ void BinaryOp::MultiOpTypeChecking() {
Error(this, "operands of '%%' should be integers");
}
type_ = Convert();
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -425,7 +429,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
}
type_ = Convert();
}
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
void BinaryOp::RangeOpTypeChecking() {
@@ -443,6 +447,19 @@ void BinaryOp::RangeOpTypeChecking() {
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
}
void BinaryOp::MaskedDerefOpTypeChecking() {
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsType = lhsScalType->ToArithm();
auto rhsType = rhsScalType->ToPointer();
if (!rhsType)
Error(this, "pointer expected for deref pointer in operator '*?'");
if (!lhsType || (lhsType && !lhsType->IsBool()))
Error(this, "bool expected for deref mask in operator '*?'");
type_ = ScalarOrLikeTile(rhs_, rhsType->Derived().GetPtr());
Broadcast(this, lhs_, rhs_, type_);
}
void BinaryOp::MatmulOpTypeChecking() {
auto lhsType = lhs_->Type()->ToTile();
auto rhsType = rhs_->Type()->ToTile();
@@ -477,7 +494,7 @@ void BinaryOp::ShiftOpTypeChecking() {
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
type_ = lhs_->Type();
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -493,7 +510,7 @@ void BinaryOp::RelationalOpTypeChecking() {
Convert();
}
type_ = ArithmType::New(T_INT);
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -508,7 +525,7 @@ void BinaryOp::EqualityOpTypeChecking() {
Convert();
}
type_ = ArithmType::New(T_INT);
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -518,7 +535,7 @@ void BinaryOp::BitwiseOpTypeChecking() {
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
Error(this, "operands of '&' should be integer");
type_ = Convert();
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -528,7 +545,7 @@ void BinaryOp::LogicalOpTypeChecking() {
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
Error(this, "the operand should be arithmetic type or pointer");
type_ = ArithmType::New(T_INT);
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
@@ -548,10 +565,9 @@ void BinaryOp::AssignOpTypeChecking() {
// The other constraints are lefted to cast operator
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
type_ = lhs_->Type();
Broadcast();
Broadcast(this, lhs_, rhs_, type_);
}
/*
* Unary Operators
*/
@@ -734,8 +750,8 @@ ConditionalOp* ConditionalOp::New(const Token* tok,
ArithmType* ConditionalOp::Convert() {
auto lhsType = exprTrue_->Type()->ToArithm();
auto rhsType = exprFalse_->Type()->ToArithm();
auto lhsType = exprTrue_->Type()->ScalarType()->ToArithm();
auto rhsType = exprFalse_->Type()->ScalarType()->ToArithm();
assert(lhsType && rhsType);
auto type = ArithmType::MaxType(lhsType, rhsType);
if (lhsType != type) { // Pointer comparation is enough!
@@ -750,18 +766,21 @@ ArithmType* ConditionalOp::Convert() {
void ConditionalOp::TypeChecking() {
if (!cond_->Type()->IsScalar()) {
Error(cond_->Tok(), "scalar is required");
auto condScalarType = TryExtractScalarType(this, cond_);
if (!condScalarType) {
Error(cond_->Tok(), "condition must be tile or scalar");
}
auto lhsType = exprTrue_->Type();
auto rhsType = exprFalse_->Type();
auto lhsType = TryExtractScalarType(this, exprTrue_);
auto rhsType = TryExtractScalarType(this, exprFalse_);
if (lhsType->ToArithm() && rhsType->ToArithm()) {
type_ = Convert();
} else {
EnsureCompatibleOrVoidPointer(lhsType, rhsType);
type_ = lhsType;
}
BinaryOp::Broadcast(this, exprFalse_, exprTrue_, type_);
}

View File

@@ -55,6 +55,10 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
_0 = bld_->create_splat(_0, ret_ty->get_tile_shapes());
return set_ret(bld_->create_dot(lhs, rhs, _0));
}
case Token::MASKED_DEREF: {
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
return set_ret(bld_->create_masked_load(rhs, lhs, ir::undef_value::get(ret_ty)));
}
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
@@ -176,6 +180,21 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
}
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
// auto &instructions = bld_->get_insert_block()->get_inst_list();
VisitExpr(condOp->cond_);
ir::value* cond = ret_;
VisitExpr(condOp->exprTrue_);
ir::value* true_val = ret_;
VisitExpr(condOp->exprFalse_);
ir::value* false_val = ret_;
if(ir::load_inst* ld = dynamic_cast<ir::load_inst*>(true_val)) {
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
cond,
false_val);
ld->replace_all_uses_with(new_ld);
ld->erase_from_parent();
return set_ret(new_ld);
}
return error_not_implemented();
}
@@ -528,7 +547,7 @@ ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) {
ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
unsigned addr_space = 0;
unsigned addr_space = 1;
return ir::pointer_type::get(ele_ty, addr_space);
}
@@ -552,7 +571,13 @@ void Generator::popScope() {
// LValue Generator
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
error_not_implemented();
if(binary->op_ != Token::MASKED_DEREF)
error_not_implemented();
gen_->VisitExpr(binary->lhs_);
ir::value* mask = gen_->ret_;
gen_->VisitExpr(binary->rhs_);
ir::value* addr = gen_->ret_;
ret_ = gen_->bld_->create_masked_store(addr, rhs_, mask);
}
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {

View File

@@ -517,7 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
case Token::INC: return ParsePrefixIncDec(tok);
case Token::DEC: return ParsePrefixIncDec(tok);
case '&': return ParseUnaryOp(tok, Token::ADDR);
case '*': return ParseUnaryOp(tok, Token::DEREF);
case '*': return ParseDerefOp(tok);
case '+': return ParseUnaryOp(tok, Token::PLUS);
case '-': return ParseUnaryOp(tok, Token::MINUS);
case '~': return ParseUnaryOp(tok, '~');
@@ -577,6 +577,19 @@ UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) {
return UnaryOp::New(op, operand);
}
Expr* Parser::ParseDerefOp(const Token* tok) {
Expr* pred = nullptr;
if(ts_.Try('?')){
ts_.Expect('(');
pred = ParseCastExpr();
ts_.Expect(')');
}
Expr* addr = ParseCastExpr();
if(pred)
return BinaryOp::New(tok, Token::MASKED_DEREF, pred, addr);
else
return UnaryOp::New(Token::DEREF, addr);
}
QualType Parser::ParseTypeName() {
auto type = ParseSpecQual();

View File

@@ -107,7 +107,6 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ Token::XOR_ASSIGN, "^=" },
{ Token::OR_ASSIGN, "|=" },
{ Token::ELLIPSIS, "..." },
{ Token::AUTO, "auto" },
{ Token::BREAK, "break" },
{ Token::CASE, "case" },

View File

@@ -172,8 +172,10 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
caller call(tmp, std::move(bin), opt);
double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream);
// save best
if(ts < best_ts)
if(ts < best_ts) {
best_ts = ts;
ret.reset(new caller(call));
}
};
_parallel_loop_nest<std::string>(space, benchmark, 1);
return *ret;
@@ -192,12 +194,13 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&grids);
codegen::transform::reassociate reassociate(&alignment_info, &grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
// run passes
peephole.run(module);
dce.run(module);
grids.run(module);
alignment_info.run(module);
reassociate.run(module);
peephole.run(module);
if(target->is_gpu()){
@@ -207,8 +210,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
shmem_barriers.run(module);
}
dce.run(module);
ir::print(module, std::cout);
alignment_info.run(module);
vectorize.run(module);
dce.run(module);
// generate llvm code