[GENERAL] Various improvements:
* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in parser
This commit is contained in:
@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
auto order = layouts_->get(ptr)->get_order();
|
||||
size_t ld;
|
||||
for(size_t i = 0; i < order.size(); i++){
|
||||
ld = order[i];
|
||||
if(ld < x->get_type()->get_tile_rank())
|
||||
break;
|
||||
}
|
||||
//size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
}
|
||||
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
if(add->get_type()->is_tile_ty()){
|
||||
ir::value* ptr = add->get_operand(0);
|
||||
ir::value* val = add->get_operand(1);
|
||||
ir::value* msk = add->get_operand(2);
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
||||
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
||||
for_each(ptr, [&](indices_t idx){
|
||||
Value *rmw_ptr = ptrs->get_value(idx);
|
||||
Value *rmw_val = vals->get_value(idx);
|
||||
Value *rmw_msk = msks->get_value(idx);
|
||||
BasicBlock *current_bb = builder_->GetInsertBlock();
|
||||
Function *parent = builder_->GetInsertBlock()->getParent();
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
||||
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_then_bb);
|
||||
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
|
||||
AtomicOrdering::Monotonic,
|
||||
SyncScope::System);
|
||||
builder_->CreateBr(mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else{
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *rmw_ptr = vmap_.at(add->get_operand(0));
|
||||
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
builder_->CreateBr(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||
// Lowers every instruction of `block`, in list order, through visit_value().
//
// The builder is first positioned on the LLVM basic block that vmap_ records
// for `block`. Once all instructions have been visited, vmap_[block] is
// overwritten with the builder's *current* insertion block: code generated
// while visiting may have created new basic blocks and moved the insertion
// point (presumably e.g. the masked tile path of visit_atomic_add_inst, which
// ends on its "mask_done" block), so the block we end on is not necessarily
// the one we started from.
void generator::visit_basic_block(ir::basic_block * block) {
  BasicBlock *parent = (BasicBlock*)vmap_[block];
  builder_->SetInsertPoint(parent);
  // Exactly one pass over the instruction list: each instruction must be
  // visited once, so there is a single (braced) loop header here.
  for(ir::instruction *i: block->get_inst_list()){
    // std::cout << typeid(*i).name() << std::endl;
    visit_value(i);
  }
  vmap_[block] = builder_->GetInsertBlock();
}
|
||||
|
||||
|
@@ -253,7 +253,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
|
@@ -307,8 +307,8 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &na
|
||||
return insert(atomic_exch_inst::create(ptr, val, name));
|
||||
}
|
||||
|
||||
// Creates an atomic-add instruction and hands it to insert().
//
//   ptr  - pointer operand (operand 0 of the instruction)
//   val  - value to add    (operand 1)
//   msk  - mask operand    (operand 2)
//   name - name given to the created instruction
//
// Returns whatever insert() returns for the newly created instruction.
value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
  return insert(atomic_add_inst::create(ptr, val, msk, name));
}
|
||||
|
||||
value *builder::create_exp(value *arg, const std::string &name){
|
||||
|
@@ -736,14 +736,15 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
|
||||
|
||||
// atomic add
|
||||
|
||||
// Constructs an atomic-add instruction with three operands:
//   operand 0 = ptr (target pointer),
//   operand 1 = val (value to add),
//   operand 2 = msk (mask).
// The instruction's type is the pointee type of `ptr`. The operand count
// passed to builtin_inst (3) must match the number of set_operand() calls
// below. `name` and `next` are forwarded unchanged to builtin_inst.
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
  set_operand(0, ptr);
  set_operand(1, val);
  set_operand(2, msk);
}
|
||||
|
||||
// Factory for atomic_add_inst: allocates a new instruction with `new` and
// returns it as an instruction*. All arguments are forwarded unchanged to the
// (ptr, val, msk, name, next) constructor.
// NOTE(review): who owns/frees the returned object is not visible here —
// confirm against the IR container that receives it.
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
  return new atomic_add_inst(ptr, val, msk, name, next);
}
|
||||
|
||||
// exp
|
||||
|
@@ -523,7 +523,7 @@ void BinaryOp::RelationalOpTypeChecking() {
|
||||
}
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
type_ = ArithmType::New(T_BOOL);
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
@@ -538,7 +538,7 @@ void BinaryOp::EqualityOpTypeChecking() {
|
||||
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
type_ = ArithmType::New(T_BOOL);
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
@@ -558,7 +558,7 @@ void BinaryOp::LogicalOpTypeChecking() {
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
|
||||
Error(this, "the operand should be arithmetic type or pointer");
|
||||
type_ = ArithmType::New(T_INT);
|
||||
type_ = ArithmType::New(T_BOOL);
|
||||
Broadcast(this, lhs_, rhs_, type_);
|
||||
}
|
||||
|
||||
|
@@ -277,12 +277,14 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
ir::value* val = ret_;
|
||||
return set_ret(bld_->create_atomic_exch(ptr, val));
|
||||
}
|
||||
if(name == "f32_atomic_add"){
|
||||
if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
ir::value* ptr = ret_;
|
||||
VisitExpr(funcCall->Args()->at(1));
|
||||
ir::value* val = ret_;
|
||||
return set_ret(bld_->create_atomic_add(ptr, val));
|
||||
VisitExpr(funcCall->Args()->at(2));
|
||||
ir::value* msk = ret_;
|
||||
return set_ret(bld_->create_atomic_add(ptr, val, msk));
|
||||
}
|
||||
if(name == "sqrtf"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
@@ -338,6 +340,7 @@ void Generator::VisitTempVar(TempVar* tempVar) {
|
||||
}
|
||||
|
||||
// Statement
|
||||
// TODO: int x = x; crashes
|
||||
void Generator::VisitDeclaration(Declaration* decl) {
|
||||
auto obj = decl->obj_;
|
||||
// initialize to undef
|
||||
|
@@ -650,7 +650,7 @@ Expr* Parser::ParseDerefOp(const Token* tok) {
|
||||
Expr* pred = nullptr;
|
||||
if(ts_.Try('?')){
|
||||
ts_.Expect('(');
|
||||
pred = ParseCastExpr();
|
||||
pred = ParseExpr();
|
||||
ts_.Expect(')');
|
||||
}
|
||||
Expr* addr = ParseCastExpr();
|
||||
|
@@ -239,6 +239,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
throw std::runtime_error("using too much shared memory");
|
||||
barriers.run(module);
|
||||
isel.visit(module, *llvm);
|
||||
// ir::print(module, std::cout);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
return res;
|
||||
}
|
||||
@@ -351,6 +352,8 @@ std::string function::preheader() {
|
||||
extern int atomic_cas(int*, int, int);
|
||||
extern int atomic_xchg(int*, int);
|
||||
extern float f32_atomic_add(float*, float);
|
||||
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
|
||||
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
|
||||
extern int get_program_id(int);
|
||||
extern int get_num_programs(int);
|
||||
extern float sqrtf(float);
|
||||
|
Reference in New Issue
Block a user