[TRITON] Fixed misaligned address issue
This commit is contained in:
committed by
Philippe Tillet
parent
a8f1b85c5f
commit
4bb0311f60
@@ -22,6 +22,30 @@ namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Extended Euclidean algorithm.
//
// Returns the greatest common divisor g of `a` and `b`, and writes through
// `x` and `y` a pair of Bezout coefficients such that
//   a * (*x) + b * (*y) == g.
// Both output pointers must be valid. The sign of the result is not
// normalized for negative operands.
inline int gcd_impl(int a, int b, int *x, int *y)
{
  if (a != 0)
  {
    // Solve the reduced problem (b mod a, a) first.
    int sub_x = 0;
    int sub_y = 0;
    const int divisor = gcd_impl(b % a, a, &sub_x, &sub_y);

    // Back-substitute: with b = q*a + (b mod a), the coefficients of the
    // reduced problem map onto the coefficients for (a, b).
    const int quotient = b / a;
    *x = sub_y - quotient * sub_x;
    *y = sub_x;
    return divisor;
  }

  // Terminal case: gcd(0, b) == b == 0*a + 1*b.
  *x = 0;
  *y = 1;
  return b;
}
|
||||
|
||||
// Greatest common divisor of `a` and `b`.
//
// The result is computed on the magnitudes of the operands, so it is
// non-negative regardless of operand signs or order (gcd(4, -6) and
// gcd(-4, 6) are both 2). gcd(0, n) == |n| and gcd(0, 0) == 0.
// Callers in this file use the result as a vector width that they divide by,
// so a negative value must never be produced.
//
// Implemented as a plain iterative Euclid loop: the Bezout coefficients
// produced by the extended algorithm are not needed here.
inline int gcd(int a, int b) {
  // Take magnitudes in unsigned arithmetic so that negating the most
  // negative int is well defined.
  unsigned m = a < 0 ? 0u - static_cast<unsigned>(a) : static_cast<unsigned>(a);
  unsigned n = b < 0 ? 0u - static_cast<unsigned>(b) : static_cast<unsigned>(b);
  while (m != 0) {
    const unsigned r = n % m;
    n = m;
    m = r;
  }
  return static_cast<int>(n);
}
|
||||
|
||||
|
||||
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
|
||||
using llop = llvm::Instruction::BinaryOps;
|
||||
@@ -309,7 +333,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
unsigned contiguous = 1;
|
||||
if(ld < x->get_type()->get_tile_rank())
|
||||
contiguous = result->axis(ld).contiguous;
|
||||
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
|
||||
unsigned vector_size = gcd(contiguous, alignment);
|
||||
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
@@ -329,7 +353,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
unsigned contiguous = 1;
|
||||
if(ld < x->get_type()->get_tile_rank())
|
||||
contiguous = result->axis(ld).contiguous;
|
||||
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
|
||||
unsigned vector_size = gcd(contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
@@ -347,7 +371,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
@@ -400,7 +424,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// extract result element
|
||||
for_each(x, [&](indices_t idx){
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
|
||||
@@ -418,6 +442,8 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
|
||||
distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand());
|
||||
@@ -425,7 +451,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
|
||||
int vector_size = 1;
|
||||
int ld = ptrs->get_order()[0];
|
||||
unsigned alignment = alignment_->get(st->get_pointer_operand(), ld);
|
||||
vector_size = std::min<unsigned>(ptrs->axis(ld).contiguous, alignment);
|
||||
vector_size = gcd(ptrs->axis(ld).contiguous, alignment);
|
||||
// create packets
|
||||
std::map<unsigned, Value*> packets;
|
||||
ir::value *arg = st->get_value_operand();
|
||||
|
@@ -266,6 +266,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// std::cout << source << std::endl;
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
|
Reference in New Issue
Block a user