[CORE] Fixed several issues that arose in the development of the
torch-blocksparse package:
* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
This commit is contained in:
committed by
Philippe Tillet
parent
ac26fbdc1f
commit
3304629de9
@@ -165,8 +165,10 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
||||
arg_type ty = arg_i.type();
|
||||
if(ty != param_tys_.at(i))
|
||||
throw std::runtime_error("invalid type for argument " + std::to_string(i));
|
||||
if(ty == BUFFER_T)
|
||||
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
||||
if(ty == BUFFER_T){
|
||||
driver::buffer* buf = *((driver::buffer**)arg_i.data());
|
||||
bin_->setArg(i, buf->size() == 0 ? nullptr : buf);
|
||||
}
|
||||
else
|
||||
bin_->setArg(i, size_of(ty), arg_i.data());
|
||||
}
|
||||
@@ -216,6 +218,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
codegen::transform::cts cts;
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
// ir::print(module, std::cout);
|
||||
dce.run(module);
|
||||
disassociate.run(module);
|
||||
dce.run(module);
|
||||
@@ -231,6 +234,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
dce.run(module);
|
||||
reassociate.run(module);
|
||||
cts.run(module);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
@@ -238,7 +242,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
throw std::runtime_error("using too much shared memory");
|
||||
barriers.run(module);
|
||||
isel.visit(module, *llvm);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
@@ -391,6 +395,8 @@ std::string function::preheader() {
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
#define INFINITY bitcast<float>(0x7F800000)
|
||||
|
||||
extern int atomic_cas(int*, int, int);
|
||||
extern int atomic_xchg(int*, int);
|
||||
extern int get_program_id(int);
|
||||
|
Reference in New Issue
Block a user