[GENERAL] Various improvements:

* Sparse einsum in triton.ops.einsum
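* Hacky support for fixed-tile-size atomic-add (masked tile variants declared as `atomic_add_64x64` and `atomic_add_128x128`)
* Various bugfixes in parser (relational, equality and logical operators now yield `bool`; the `?(...)` predicate is parsed as a full expression)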
This commit is contained in:
Philippe Tillet
2020-10-25 11:55:58 -07:00
parent 444907589d
commit 049ab989b5
16 changed files with 574 additions and 331 deletions

View File

@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->get_order(0);
auto order = layouts_->get(ptr)->get_order();
size_t ld;
for(size_t i = 0; i < order.size(); i++){
ld = order[i];
if(ld < x->get_type()->get_tile_rank())
break;
}
//size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
}
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){
ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1);
ir::value* msk = add->get_operand(2);
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
for_each(ptr, [&](indices_t idx){
Value *rmw_ptr = ptrs->get_value(idx);
Value *rmw_val = vals->get_value(idx);
Value *rmw_msk = msks->get_value(idx);
BasicBlock *current_bb = builder_->GetInsertBlock();
Function *parent = builder_->GetInsertBlock()->getParent();
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
builder_->SetInsertPoint(mask_then_bb);
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
builder_->CreateBr(mask_done_bb);
builder_->SetInsertPoint(mask_done_bb);
});
}
else{
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vmap_.at(add->get_operand(0));
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
}
}
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
// Visits every instruction of `block` with the builder positioned at the
// BasicBlock that vmap_ holds for it, then stores the builder's current
// insert block back into vmap_ under the same key.
void generator::visit_basic_block(ir::basic_block * block) {
// fetch the block registered for `block` and make it the insertion point
BasicBlock *parent = (BasicBlock*)vmap_[block];
builder_->SetInsertPoint(parent);
// NOTE(review): the two consecutive `for` headers below look like the removed
// and the added line of this diff hunk shown together (header: -1362,8 +1395,10);
// presumably only the braced loop exists in the committed file — confirm.
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
// disabled debug aid: prints the dynamic type name of each instruction
// std::cout << typeid(*i).name() << std::endl;
visit_value(i);
}
// remap `block` to wherever the builder is inserting now; visitors such as
// visit_atomic_add_inst create new blocks and move the insert point
vmap_[block] = builder_->GetInsertBlock();
}

View File

@@ -253,7 +253,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context);
// std::cout << source << std::endl;
// std::cout << source << std::endl;
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;

View File

@@ -307,8 +307,8 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &na
return insert(atomic_exch_inst::create(ptr, val, name));
}
// Builds an atomic-add instruction through atomic_add_inst::create, hands it
// to insert(), and returns whatever insert() returns.
// NOTE(review): two signatures appear back to back; they look like the old
// (ptr, val) and new (ptr, val, msk) lines of the diff hunk (-307,8 +307,8) —
// presumably only the three-operand form is live after this commit; confirm.
value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
return insert(atomic_add_inst::create(ptr, val, name));
value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
// `msk` is forwarded unchanged as the third operand of the instruction
return insert(atomic_add_inst::create(ptr, val, msk, name));
}
value *builder::create_exp(value *arg, const std::string &name){

View File

@@ -736,14 +736,15 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
// atomic add
// Constructs an atomic-add instruction. The type passed to builtin_inst is the
// pointee type of `ptr` (ptr->get_type()->get_pointer_element_ty()), tagged
// with INST_ATOMIC_ADD.
// NOTE(review): two constructor headers appear back to back; they look like the
// old (2-operand) and new (3-operand, with `msk`) lines of the diff hunk
// (-736,14 +736,15) — presumably only the second is live; confirm.
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
// the literal 3 presumably is the operand count; it matches the three
// set_operand calls below — verify against builtin_inst's constructor
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
// operand layout: 0 = pointer, 1 = value, 2 = mask
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, msk);
}
// Factory: allocates an atomic_add_inst with `new` and returns it as an
// instruction*. Who releases the object is not visible in this hunk.
// NOTE(review): the old (ptr, val) and new (ptr, val, msk) versions appear
// interleaved here as removed/added diff lines — presumably only the version
// taking `msk` is live after this commit; confirm.
instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
return new atomic_add_inst(ptr, val, name, next);
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
// all arguments are forwarded to the constructor unchanged
return new atomic_add_inst(ptr, val, msk, name, next);
}
// exp

View File

@@ -523,7 +523,7 @@ void BinaryOp::RelationalOpTypeChecking() {
}
Convert();
}
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}
@@ -538,7 +538,7 @@ void BinaryOp::EqualityOpTypeChecking() {
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
Convert();
}
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}
@@ -558,7 +558,7 @@ void BinaryOp::LogicalOpTypeChecking() {
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
Error(this, "the operand should be arithmetic type or pointer");
type_ = ArithmType::New(T_INT);
type_ = ArithmType::New(T_BOOL);
Broadcast(this, lhs_, rhs_, type_);
}

View File

@@ -277,12 +277,14 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name == "f32_atomic_add"){
if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_add(ptr, val));
VisitExpr(funcCall->Args()->at(2));
ir::value* msk = ret_;
return set_ret(bld_->create_atomic_add(ptr, val, msk));
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
@@ -338,6 +340,7 @@ void Generator::VisitTempVar(TempVar* tempVar) {
}
// Statement
// TODO: int x = x; crashes
void Generator::VisitDeclaration(Declaration* decl) {
auto obj = decl->obj_;
// initialize to undef

View File

@@ -650,7 +650,7 @@ Expr* Parser::ParseDerefOp(const Token* tok) {
Expr* pred = nullptr;
if(ts_.Try('?')){
ts_.Expect('(');
pred = ParseCastExpr();
pred = ParseExpr();
ts_.Expect(')');
}
Expr* addr = ParseCastExpr();

View File

@@ -239,6 +239,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
throw std::runtime_error("using too much shared memory");
barriers.run(module);
isel.visit(module, *llvm);
// ir::print(module, std::cout);
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
return res;
}
@@ -351,6 +352,8 @@ std::string function::preheader() {
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern float f32_atomic_add(float*, float);
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);