[CORE] Fixed several issues that arose in the development of the

torch-blocksparse package:

* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
This commit is contained in:
Philippe Tillet
2020-03-31 18:55:31 -04:00
committed by Philippe Tillet
parent ac26fbdc1f
commit 3304629de9
33 changed files with 374 additions and 135 deletions

View File

@@ -165,8 +165,10 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
arg_type ty = arg_i.type();
if(ty != param_tys_.at(i))
throw std::runtime_error("invalid type for argument " + std::to_string(i));
if(ty == BUFFER_T)
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
if(ty == BUFFER_T){
driver::buffer* buf = *((driver::buffer**)arg_i.data());
bin_->setArg(i, buf->size() == 0 ? nullptr : buf);
}
else
bin_->setArg(i, size_of(ty), arg_i.data());
}
@@ -216,6 +218,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
// ir::print(module, std::cout);
dce.run(module);
disassociate.run(module);
dce.run(module);
@@ -231,6 +234,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
dce.run(module);
reassociate.run(module);
cts.run(module);
peephole.run(module);
dce.run(module);
align.run(module);
axes.run(module);
@@ -238,7 +242,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
throw std::runtime_error("using too much shared memory");
barriers.run(module);
isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
@@ -391,6 +395,8 @@ std::string function::preheader() {
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
#define INFINITY bitcast<float>(0x7F800000)
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern int get_program_id(int);