[code generation] added masked loads
This commit is contained in:
@@ -15,7 +15,7 @@ find_package(LLVM REQUIRED CONFIG)
|
||||
message(STATUS ${LLVM_INCLUDE_DIRS})
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all)
|
||||
#llvm_map_components_to_libnames(llvm_libs all)
|
||||
|
||||
#Default build type
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
@@ -34,7 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
|
||||
# TDL
|
||||
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
|
||||
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
|
||||
target_link_libraries(tdl ${llvm_libs})
|
||||
target_link_libraries(tdl LLVM)
|
||||
|
||||
# Examples
|
||||
add_subdirectory(examples)
|
||||
|
@@ -36,7 +36,7 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
int32 rxa[16] = get_global_range[16](0);\
|
||||
int32 ryb[16] = get_global_range[16](1);\
|
||||
int32 rka[8] = 0 ... 8;\
|
||||
@@ -50,15 +50,17 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
|
||||
fp32 a[16, 8] = *pa;\
|
||||
fp32 b[16, 8] = *pb;\
|
||||
int1 checkc0[16] = (rxc < M);\
|
||||
int1 checkc1[16] = (ryc < N);\
|
||||
int1 checkc0[16] = rxc < M;\
|
||||
int1 checkc1[16] = ryc < N;\
|
||||
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
for(k = K; k > 0; k = k - 8){\
|
||||
int1 sanitya[16, 8] = (k >= bound);\
|
||||
int1 sanityb[16, 8] = (k >= bound);\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + 8*M;\
|
||||
pb = pb + 8*K;\
|
||||
a = *pa;\
|
||||
b = *pb;\
|
||||
@sanitya a = *pa;\
|
||||
@sanityb b = *pb;\
|
||||
}\
|
||||
@checkc *pc = C;\
|
||||
}\
|
||||
@@ -201,6 +203,8 @@ int main() {
|
||||
for(auto &e: x.second)
|
||||
std::cout << e << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
// run passes
|
||||
shared.run(module);
|
||||
@@ -213,7 +217,7 @@ int main() {
|
||||
|
||||
// llvm source
|
||||
llvm::legacy::PassManager manager;
|
||||
// manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createVerifierPass(true));
|
||||
manager.run(llvm_module);
|
||||
|
||||
@@ -233,6 +237,7 @@ int main() {
|
||||
// Allocate buffers
|
||||
typedef float numeric_t;
|
||||
size_t M = 128, N = 128, K = 128;
|
||||
size_t bound = 8;
|
||||
std::vector<numeric_t> c(M*N);
|
||||
std::vector<numeric_t> rc(M*N);
|
||||
std::vector<numeric_t> a(M*K);
|
||||
@@ -252,13 +257,13 @@ int main() {
|
||||
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
|
||||
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
|
||||
// Launch kernel
|
||||
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K};
|
||||
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
|
||||
int num_regs;
|
||||
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
|
||||
unsigned TM = 16;
|
||||
unsigned TN = 16;
|
||||
unsigned nthreads = 32;
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuStreamSynchronize(cu_stream));
|
||||
// Write back
|
||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||
|
@@ -185,7 +185,8 @@ private:
|
||||
|
||||
public:
|
||||
binary_operator(BIN_OP_T op, node *lhs, node *rhs)
|
||||
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { }
|
||||
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) {
|
||||
}
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
@@ -320,14 +321,14 @@ private:
|
||||
|
||||
class expression_statement: public statement{
|
||||
public:
|
||||
expression_statement(node *expr, node *pred = nullptr)
|
||||
: expr_((expression*)expr), pred_((expression*)pred){ }
|
||||
expression_statement(node *expr, node *mask = nullptr)
|
||||
: expr_((expression*)expr), mask_((expression*)mask){ }
|
||||
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
|
||||
private:
|
||||
expression *expr_;
|
||||
expression *pred_;
|
||||
expression *mask_;
|
||||
};
|
||||
|
||||
class compound_statement: public statement{
|
||||
|
@@ -121,7 +121,7 @@ primary_expression
|
||||
| constant ELLIPSIS constant { $$ = new constant_range($1, $3); }
|
||||
| builtin { $$ = $1; }
|
||||
| STRING_LITERAL { $$ = new string_literal(yytext); }
|
||||
| '(' expression ')' { $$ = $1; }
|
||||
| '(' expression ')' { $$ = $2; }
|
||||
;
|
||||
|
||||
slice
|
||||
@@ -155,7 +155,7 @@ unary_operator
|
||||
|
||||
cast_expression
|
||||
: unary_expression { $$ = $1; }
|
||||
| '(' type_name ')' cast_expression { $$ = new cast_operator($1, $2); }
|
||||
| '(' type_name ')' cast_expression { $$ = new cast_operator($2, $4); }
|
||||
;
|
||||
|
||||
multiplicative_expression
|
||||
|
@@ -54,13 +54,15 @@ private:
|
||||
llvm::Value* shared_offset(indices_t idx);
|
||||
|
||||
public:
|
||||
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder);
|
||||
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
|
||||
void set_value(indices_t, llvm::Value *);
|
||||
llvm::Value* get_value(indices_t idx);
|
||||
llvm::Value* get_pointer() { return ptr_; }
|
||||
llvm::Value* get_offset() { return offset_; }
|
||||
|
||||
private:
|
||||
llvm::Value *ptr_;
|
||||
llvm::Value *offset_;
|
||||
llvm::IRBuilder<> &builder_;
|
||||
std::map<indices_t, llvm::Value*> ptr_cache_;
|
||||
};
|
||||
|
@@ -26,12 +26,16 @@ public:
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// mask
|
||||
value* set_mask(value *mask) { mask_ = mask; }
|
||||
value* get_mask() { return mask_; }
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
value *pred_;
|
||||
value *mask_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -16,33 +16,34 @@ namespace ast{
|
||||
|
||||
/* node */
|
||||
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
||||
ir::type *src_ty = src->get_type()->get_scalar_ty();
|
||||
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
||||
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
||||
bool src_signed = false;
|
||||
bool dst_signed = false;
|
||||
if(src_ty == dst_ty)
|
||||
if(src_scalar_ty == dst_scalar_ty)
|
||||
return src;
|
||||
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
|
||||
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
return builder.create_si_to_fp(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
|
||||
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
return builder.create_ui_to_fp(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
|
||||
return builder.create_fp_to_si(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
|
||||
return builder.create_fp_to_ui(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
||||
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
|
||||
return builder.create_fp_ext(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
||||
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
|
||||
return builder.create_fp_trunc(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
|
||||
src_ty->get_integer_bitwidth())
|
||||
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
|
||||
src_scalar_ty->get_integer_bitwidth())
|
||||
return builder.create_int_cast(src, dst_ty, dst_signed);
|
||||
|
||||
else
|
||||
@@ -247,7 +248,14 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
|
||||
/* expression statement */
|
||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||
return expr_->codegen(mod);
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
if(mask_) {
|
||||
ir::instruction *itn = dynamic_cast<ir::instruction*>(expr);
|
||||
assert(itn);
|
||||
ir::value *mask = mask_->codegen(mod);
|
||||
itn->set_mask(mask);
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
/* Iteration statement */
|
||||
@@ -325,7 +333,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||
value = explicit_cast(mod->get_builder(), value, ty);
|
||||
implicit_broadcast(mod, value, ty);
|
||||
}
|
||||
value->set_name(name);
|
||||
@@ -526,7 +534,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
|
||||
assert(x->get_op()==DEREF);
|
||||
assert(x->lvalue());
|
||||
ir::value *ptr = x->lvalue()->codegen(mod);
|
||||
mod->get_builder().create_store(ptr, rvalue);
|
||||
rvalue = mod->get_builder().create_store(ptr, rvalue);
|
||||
}
|
||||
return rvalue;
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#include "codegen/selection.h"
|
||||
#include "codegen/tune.h"
|
||||
#include "codegen/allocation.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "ir/context.h"
|
||||
@@ -9,6 +10,8 @@
|
||||
#include "ir/type.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopInfo.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace codegen{
|
||||
@@ -121,8 +124,8 @@ Value* shared_tile::shared_offset(indices_t idx) {
|
||||
return result;
|
||||
}
|
||||
|
||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder):
|
||||
tile(ty, shapes), ptr_(ptr), builder_(builder) {
|
||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
||||
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset) {
|
||||
}
|
||||
|
||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||
@@ -404,25 +407,17 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
std::swap(id_pre, id_loop);
|
||||
ir::value *pre_value = phi->get_incoming_value(id_pre);
|
||||
ir::value *loop_value = phi->get_incoming_value(id_loop);
|
||||
BasicBlock *pre_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_pre)];
|
||||
BasicBlock *loop_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_loop)];
|
||||
if(parent->empty())
|
||||
builder.SetInsertPoint(parent);
|
||||
else
|
||||
builder.SetInsertPoint(&*parent->getFirstInsertionPt());
|
||||
PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
|
||||
// offset
|
||||
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
|
||||
Value *next_offset = builder.CreateNeg(offset);
|
||||
offset->addIncoming(builder.getInt32(alloc_->get_num_bytes(phi) / 2 / 4), pre_block);
|
||||
offset->addIncoming(next_offset, loop_block);
|
||||
// next pointer
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset);
|
||||
ptr->addIncoming(pre_ptr, pre_block);
|
||||
ptr->addIncoming(next_ptr, loop_block);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
|
||||
tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)});
|
||||
tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)});
|
||||
}
|
||||
@@ -483,14 +478,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
|
||||
|
||||
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
|
||||
Module *module = builder.GetInsertBlock()->getModule();
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Module *module = block->getModule();
|
||||
Function *function = block->getParent();
|
||||
ir::value *mask = ins->get_mask();
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
// helper to handle masks
|
||||
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Value *result;
|
||||
if(mask){
|
||||
Value *llvm_mask = tmap_.at(mask)->get_value(idx);
|
||||
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||
builder.SetInsertPoint(then_bb);
|
||||
result = insert_value();
|
||||
builder.CreateBr(done_bb);
|
||||
builder.SetInsertPoint(done_bb);
|
||||
if(!ins->get_type()->is_void_ty()){
|
||||
Type *ty = result->getType();
|
||||
PHINode *phi = builder.CreatePHI(ty, 2);
|
||||
phi->addIncoming(llvm::UndefValue::get(ty), block);
|
||||
phi->addIncoming(result, then_bb);
|
||||
return (Value*)phi;
|
||||
}
|
||||
}
|
||||
else
|
||||
result = insert_value();
|
||||
return result;
|
||||
};
|
||||
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); });
|
||||
});
|
||||
}
|
||||
else {
|
||||
@@ -511,7 +535,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
|
||||
result->for_each([&](indices_t idx){
|
||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
|
||||
});
|
||||
}
|
||||
// reshape
|
||||
@@ -530,7 +554,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
||||
result->for_each([&](indices_t idx) {
|
||||
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
|
||||
result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); }));
|
||||
});
|
||||
}
|
||||
// broadcast
|
||||
@@ -603,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
};
|
||||
result->set_value(idx, llvm_inst(ins, value, builder));
|
||||
result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); }));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -625,6 +649,7 @@ void selection::run(ir::module &src, Module &dst){
|
||||
vmap_.clear();
|
||||
LLVMContext &dst_ctx = dst.getContext();
|
||||
IRBuilder<> dst_builder(dst_ctx);
|
||||
std::map<ir::value*, llvm::BasicBlock*> block_of;
|
||||
|
||||
// iterate over functions
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
@@ -661,6 +686,7 @@ void selection::run(ir::module &src, Module &dst){
|
||||
}
|
||||
// create grids
|
||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||
// iterate through block
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||
@@ -671,6 +697,7 @@ void selection::run(ir::module &src, Module &dst){
|
||||
lower_instruction(i, dst_builder);
|
||||
if(dynamic_cast<ir::phi_node*>(i))
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
last_block[block] = dst_builder.GetInsertBlock();
|
||||
}
|
||||
}
|
||||
// add phi operands
|
||||
@@ -678,12 +705,31 @@ void selection::run(ir::module &src, Module &dst){
|
||||
for(ir::instruction *inst: block->get_inst_list())
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
|
||||
if(buffer_info_->is_shared(phi)) {
|
||||
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
||||
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
|
||||
GetElementPtrInst *inc_ptr = dyn_cast<GetElementPtrInst>(inc_shared->get_pointer());
|
||||
if(inc_ptr && ptr == inc_ptr->getPointerOperand()){
|
||||
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
|
||||
Value *next_offset = dst_builder.CreateNeg(offset);
|
||||
offset->addIncoming(next_offset, llvm_inc_block);
|
||||
}
|
||||
else {
|
||||
offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
|
||||
}
|
||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
BasicBlock *llvm_inc_block = (BasicBlock*)vmap_[inc_block];
|
||||
std::cout << typeid(*inc_val).name() << " " << inc_val << " " << inc_block << std::endl;
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
if(phi->get_type()->is_tile_ty()) {
|
||||
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
|
||||
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
|
||||
|
@@ -67,10 +67,17 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)){
|
||||
std::cout << typeid(*v).name() << std::endl;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++)
|
||||
for(ir::value* op: v->ops()){
|
||||
for(ir::value* op: v->ops())
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
|
||||
/* Add mask constraints */
|
||||
if(ir::value *mask = v->get_mask()){
|
||||
std::cout << typeid(*mask).name() << " " << typeid(*v->ops()[0]).name() << std::endl;
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v->ops()[0], i}, {mask, i});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +106,7 @@ std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && *x.second == 0){
|
||||
std::cout << typeid(*i).name() << std::endl;
|
||||
result.push_back(x.second);
|
||||
}
|
||||
return result;
|
||||
|
@@ -186,8 +186,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n
|
||||
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
|
||||
type *arg_ty = arg->get_type();
|
||||
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
|
||||
unsigned arg_bits = arg_ty->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_integer_bitwidth();
|
||||
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
op_t op = (arg_bits == dst_bits ? ic::BitCast :
|
||||
(arg_bits > dst_bits ? ic::Trunc :
|
||||
(is_signed ? ic::SExt : ic::ZExt)));
|
||||
|
@@ -33,7 +33,7 @@ unsigned type::get_primitive_size_in_bits() const {
|
||||
}
|
||||
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ return ((integer_type*)(this))->get_bitwidth(); }
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((tile_type*)(this))->get_bitwidth(); }
|
||||
|
Reference in New Issue
Block a user