[GENERAL] Merged v1.0alpha into master. Added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
return set_ret(bld_->create_dot(lhs, rhs, _0));
}
case Token::MASKED_DEREF: {
// TODO: FIXME
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
auto it = bld_->get_insert_block();
if(ret_ty->is_tile_ty())
false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
bld_->set_insert_point(it);
return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
}
case Token::ELLIPSIS: {
@@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
cond,
false_val);
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
ld->replace_all_uses_with(new_ld);
ld->erase_from_parent();
return set_ret(new_ld);
@@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
});
if(init_)
VisitStmt(init_);
// VisitExpr(cond_);
// ir::value *cond = ret_;
// bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->create_br(loop_bb);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
// bld_->create_br(loop_bb);
bld_->set_insert_point(loop_bb);
if(body_)
VisitStmt(body_);