[code generation] added masked loads

This commit is contained in:
Philippe Tillet
2019-02-15 11:14:50 -05:00
parent 896e856b07
commit 5f5959dc6e
11 changed files with 128 additions and 54 deletions

View File

@@ -15,7 +15,7 @@ find_package(LLVM REQUIRED CONFIG)
message(STATUS ${LLVM_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all)
#llvm_map_components_to_libnames(llvm_libs all)
#Default build type
if(NOT CMAKE_BUILD_TYPE)
@@ -34,7 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
# TDL
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
target_link_libraries(tdl ${llvm_libs})
target_link_libraries(tdl LLVM)
# Examples
add_subdirectory(examples)

View File

@@ -36,7 +36,7 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
int32 rxa[16] = get_global_range[16](0);\
int32 ryb[16] = get_global_range[16](1);\
int32 rka[8] = 0 ... 8;\
@@ -50,15 +50,17 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
fp32 a[16, 8] = *pa;\
fp32 b[16, 8] = *pb;\
int1 checkc0[16] = (rxc < M);\
int1 checkc1[16] = (ryc < N);\
int1 checkc0[16] = rxc < M;\
int1 checkc1[16] = ryc < N;\
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
for(k = K; k > 0; k = k - 8){\
int1 sanitya[16, 8] = (k >= bound);\
int1 sanityb[16, 8] = (k >= bound);\
C = dot(a, b, C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
a = *pa;\
b = *pb;\
@sanitya a = *pa;\
@sanityb b = *pb;\
}\
@checkc *pc = C;\
}\
@@ -201,6 +203,8 @@ int main() {
for(auto &e: x.second)
std::cout << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// run passes
shared.run(module);
@@ -213,7 +217,7 @@ int main() {
// llvm source
llvm::legacy::PassManager manager;
// manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createVerifierPass(true));
manager.run(llvm_module);
@@ -233,6 +237,7 @@ int main() {
// Allocate buffers
typedef float numeric_t;
size_t M = 128, N = 128, K = 128;
size_t bound = 8;
std::vector<numeric_t> c(M*N);
std::vector<numeric_t> rc(M*N);
std::vector<numeric_t> a(M*K);
@@ -252,13 +257,13 @@ int main() {
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
// Launch kernel
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K};
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
int num_regs;
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
unsigned TM = 16;
unsigned TN = 16;
unsigned nthreads = 32;
checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
checkCudaErrors(cuStreamSynchronize(cu_stream));
// Write back
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));

View File

@@ -185,7 +185,8 @@ private:
public:
binary_operator(BIN_OP_T op, node *lhs, node *rhs)
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { }
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) {
}
ir::value* codegen(ir::module *) const;
private:
@@ -320,14 +321,14 @@ private:
class expression_statement: public statement{
public:
expression_statement(node *expr, node *pred = nullptr)
: expr_((expression*)expr), pred_((expression*)pred){ }
expression_statement(node *expr, node *mask = nullptr)
: expr_((expression*)expr), mask_((expression*)mask){ }
ir::value* codegen(ir::module * mod) const;
private:
expression *expr_;
expression *pred_;
expression *mask_;
};
class compound_statement: public statement{

View File

@@ -121,7 +121,7 @@ primary_expression
| constant ELLIPSIS constant { $$ = new constant_range($1, $3); }
| builtin { $$ = $1; }
| STRING_LITERAL { $$ = new string_literal(yytext); }
| '(' expression ')' { $$ = $1; }
| '(' expression ')' { $$ = $2; }
;
slice
@@ -155,7 +155,7 @@ unary_operator
cast_expression
: unary_expression { $$ = $1; }
| '(' type_name ')' cast_expression { $$ = new cast_operator($1, $2); }
| '(' type_name ')' cast_expression { $$ = new cast_operator($2, $4); }
;
multiplicative_expression

View File

@@ -54,13 +54,15 @@ private:
llvm::Value* shared_offset(indices_t idx);
public:
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder);
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
void set_value(indices_t, llvm::Value *);
llvm::Value* get_value(indices_t idx);
llvm::Value* get_pointer() { return ptr_; }
llvm::Value* get_offset() { return offset_; }
private:
llvm::Value *ptr_;
llvm::Value *offset_;
llvm::IRBuilder<> &builder_;
std::map<indices_t, llvm::Value*> ptr_cache_;
};

View File

@@ -26,12 +26,16 @@ public:
const basic_block *get_parent() const { return parent_; }
basic_block *get_parent() { return parent_; }
void erase_from_parent();
// mask
value* set_mask(value *mask) { mask_ = mask; }
value* get_mask() { return mask_; }
// helpers
bool has_tile_result_or_op();
private:
basic_block *parent_;
value *pred_;
value *mask_;
};
//===----------------------------------------------------------------------===//

View File

@@ -16,33 +16,34 @@ namespace ast{
/* node */
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
ir::type *src_ty = src->get_type()->get_scalar_ty();
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
bool src_signed = false;
bool dst_signed = false;
if(src_ty == dst_ty)
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return builder.create_si_to_fp(src, dst_ty);
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
return builder.create_ui_to_fp(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
return builder.create_fp_to_si(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
return builder.create_fp_to_ui(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
return builder.create_fp_ext(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
return builder.create_fp_trunc(src, dst_ty);
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
src_ty->get_integer_bitwidth())
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return builder.create_int_cast(src, dst_ty, dst_signed);
else
@@ -247,7 +248,14 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
/* expression statement */
ir::value* expression_statement::codegen(ir::module *mod) const{
return expr_->codegen(mod);
ir::value *expr = expr_->codegen(mod);
if(mask_) {
ir::instruction *itn = dynamic_cast<ir::instruction*>(expr);
assert(itn);
ir::value *mask = mask_->codegen(mod);
itn->set_mask(mask);
}
return expr;
}
/* Iteration statement */
@@ -325,7 +333,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
ir::value *value = ir::undef_value::get(ty);
if(expr_){
value = expr_->codegen(mod);
value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
value = explicit_cast(mod->get_builder(), value, ty);
implicit_broadcast(mod, value, ty);
}
value->set_name(name);
@@ -526,7 +534,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
assert(x->get_op()==DEREF);
assert(x->lvalue());
ir::value *ptr = x->lvalue()->codegen(mod);
mod->get_builder().create_store(ptr, rvalue);
rvalue = mod->get_builder().create_store(ptr, rvalue);
}
return rvalue;
}

View File

@@ -1,6 +1,7 @@
#include "codegen/selection.h"
#include "codegen/tune.h"
#include "codegen/allocation.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "ir/context.h"
@@ -9,6 +10,8 @@
#include "ir/type.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/BasicBlock.h"
namespace tdl{
namespace codegen{
@@ -121,8 +124,8 @@ Value* shared_tile::shared_offset(indices_t idx) {
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder):
tile(ty, shapes), ptr_(ptr), builder_(builder) {
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset) {
}
void shared_tile::set_value(indices_t idx, Value *value) {
@@ -404,25 +407,17 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::swap(id_pre, id_loop);
ir::value *pre_value = phi->get_incoming_value(id_pre);
ir::value *loop_value = phi->get_incoming_value(id_loop);
BasicBlock *pre_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_pre)];
BasicBlock *loop_block = (BasicBlock*)vmap_[phi->get_incoming_block(id_loop)];
if(parent->empty())
builder.SetInsertPoint(parent);
else
builder.SetInsertPoint(&*parent->getFirstInsertionPt());
PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
// offset
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
Value *next_offset = builder.CreateNeg(offset);
offset->addIncoming(builder.getInt32(alloc_->get_num_bytes(phi) / 2 / 4), pre_block);
offset->addIncoming(next_offset, loop_block);
// next pointer
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset);
ptr->addIncoming(pre_ptr, pre_block);
ptr->addIncoming(next_ptr, loop_block);
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)});
tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)});
}
@@ -483,14 +478,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
Module *module = builder.GetInsertBlock()->getModule();
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
Function *function = block->getParent();
ir::value *mask = ins->get_mask();
LLVMContext &ctx = builder.getContext();
// helper to handle masks
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
BasicBlock *block = builder.GetInsertBlock();
Value *result;
if(mask){
Value *llvm_mask = tmap_.at(mask)->get_value(idx);
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
builder.SetInsertPoint(then_bb);
result = insert_value();
builder.CreateBr(done_bb);
builder.SetInsertPoint(done_bb);
if(!ins->get_type()->is_void_ty()){
Type *ty = result->getType();
PHINode *phi = builder.CreatePHI(ty, 2);
phi->addIncoming(llvm::UndefValue::get(ty), block);
phi->addIncoming(result, then_bb);
return (Value*)phi;
}
}
else
result = insert_value();
return result;
};
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); });
});
}
else {
@@ -511,7 +535,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin, offset));
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
});
}
// reshape
@@ -530,7 +554,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); }));
});
}
// broadcast
@@ -603,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else
return llvm_value(x, builder);
};
result->set_value(idx, llvm_inst(ins, value, builder));
result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); }));
});
}
}
@@ -625,6 +649,7 @@ void selection::run(ir::module &src, Module &dst){
vmap_.clear();
LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx);
std::map<ir::value*, llvm::BasicBlock*> block_of;
// iterate over functions
for(ir::function *fn: src.get_function_list()) {
@@ -661,6 +686,7 @@ void selection::run(ir::module &src, Module &dst){
}
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
std::map<ir::basic_block*, BasicBlock*> last_block;
// iterate through block
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
@@ -671,6 +697,7 @@ void selection::run(ir::module &src, Module &dst){
lower_instruction(i, dst_builder);
if(dynamic_cast<ir::phi_node*>(i))
dst_builder.SetInsertPoint(parent);
last_block[block] = dst_builder.GetInsertBlock();
}
}
// add phi operands
@@ -678,12 +705,31 @@ void selection::run(ir::module &src, Module &dst){
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
if(buffer_info_->is_shared(phi)) {
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = last_block.at(inc_block);
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
GetElementPtrInst *inc_ptr = dyn_cast<GetElementPtrInst>(inc_shared->get_pointer());
if(inc_ptr && ptr == inc_ptr->getPointerOperand()){
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
Value *next_offset = dst_builder.CreateNeg(offset);
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
continue;
}
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = (BasicBlock*)vmap_[inc_block];
std::cout << typeid(*inc_val).name() << " " << inc_val << " " << inc_block << std::endl;
BasicBlock *llvm_inc_block = last_block.at(inc_block);
if(phi->get_type()->is_tile_ty()) {
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);

View File

@@ -67,10 +67,17 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)){
std::cout << typeid(*v).name() << std::endl;
for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops()){
for(ir::value* op: v->ops())
add_constraint({v, i}, {op, i});
}
}
/* Add mask constraints */
if(ir::value *mask = v->get_mask()){
std::cout << typeid(*mask).name() << " " << typeid(*v->ops()[0]).name() << std::endl;
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v->ops()[0], i}, {mask, i});
}
}
@@ -99,6 +106,7 @@ std::vector<unsigned*> tune::get_params(ir::module &mod) {
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && *x.second == 0){
std::cout << typeid(*i).name() << std::endl;
result.push_back(x.second);
}
return result;

View File

@@ -186,8 +186,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
type *arg_ty = arg->get_type();
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
unsigned arg_bits = arg_ty->get_integer_bitwidth();
unsigned dst_bits = ty->get_integer_bitwidth();
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
op_t op = (arg_bits == dst_bits ? ic::BitCast :
(arg_bits > dst_bits ? ic::Trunc :
(is_signed ? ic::SExt : ic::ZExt)));

View File

@@ -33,7 +33,7 @@ unsigned type::get_primitive_size_in_bits() const {
}
unsigned type::get_integer_bitwidth() const
{ return ((integer_type*)(this))->get_bitwidth(); }
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
unsigned type::get_tile_bitwidth() const
{ return ((tile_type*)(this))->get_bitwidth(); }