[syntax tree] fixed broadcast semantics lowering

This commit is contained in:
Philippe Tillet
2019-01-08 17:44:31 -05:00
parent 7a14693f51
commit 73db84c8ba
10 changed files with 153 additions and 128 deletions

View File

@@ -4,6 +4,7 @@
#include "ir/context.h"
#include "ir/module.h"
#include "codegen/selection.h"
#include "codegen/tune.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
@@ -20,6 +21,8 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
int32 tile[16, 16] = 0;\
int32 test[16, 16] = tile + i;\
i = 1;\
A = A + i;\
}\
@@ -35,10 +38,14 @@ int main() {
program->codegen(&module);
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("test", llvm_context);
// lowering passes
tdl::codegen::selection selection;
selection.run(module, llvm_module);
llvm::PrintModulePass print(llvm::outs());
llvm::AnalysisManager<llvm::Module> analysis;
print.run(llvm_module, analysis);
tdl::codegen::tune tune;
tune.run(module);
// selection.run(module, llvm_module);
// // print LLVM program
// llvm::PrintModulePass print(llvm::outs());
// llvm::AnalysisManager<llvm::Module> analysis;
// print.run(llvm_module, analysis);
return 0;
}

View File

@@ -61,6 +61,11 @@ class identifier;
// AST
class node {
protected:
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
public:
virtual ir::value* codegen(ir::module *) const { return nullptr; }
};

View File

@@ -25,7 +25,7 @@ public:
unsigned get_allocated_size() const { return allocated_size_; }
// run
void run(ir::function &fn);
void run();
private:
std::map<ir::value*, unsigned> offsets_;

View File

@@ -7,7 +7,7 @@
namespace tdl {
namespace ir {
class function;
class module;
class instruction;
class value;
}
@@ -32,7 +32,7 @@ public:
shared_view_info get_shared_view(ir::value *v, unsigned idx);
// run
bool run(ir::function &fn);
void run(ir::module &mod);
private:
std::map<ir::value*, shared_view_val_t> shared_views_;

View File

@@ -8,6 +8,7 @@ namespace tdl{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
@@ -47,7 +48,7 @@ public:
segment get_interval(ir::value* v) const { return intervals_.at(v); }
// run
void run(ir::function *fn);
void run(ir::module &mod);
private:
has_storage_map_t has_dedicated_storage_;

View File

@@ -14,6 +14,118 @@ namespace tdl{
namespace ast{
/* node */
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
ir::type *src_ty = src->get_type()->get_scalar_ty();
bool src_signed = false;
bool dst_signed = false;
if(src_ty == dst_ty)
return src;
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
return builder.create_si_to_fp(src, dst_ty);
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
return builder.create_ui_to_fp(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
return builder.create_fp_to_si(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
return builder.create_fp_to_ui(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
return builder.create_fp_ext(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
return builder.create_fp_trunc(src, dst_ty);
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
src_ty->get_integer_bitwidth())
return builder.create_int_cast(src, dst_ty, dst_signed);
else
throw std::runtime_error("unreachable");
}
void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
// Input types
ir::type *left_ty = lhs->get_type()->get_scalar_ty();
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
// One operand is pointer
if(left_ty->is_pointer_ty()){
is_ptr = true;
}
// One operand is double
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
to_convert = explicit_cast(builder, to_convert, builder.get_double_ty());
is_float = true;
}
// One operand is float
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
is_float = true;
}
// Both operands are integers
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
is_int = true;
is_signed = false;
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
to_convert = explicit_cast(builder, to_convert, dst_ty);
}
}
// Not reachable
else
throw std::runtime_error("unreachable");
}
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
ir::builder &builder = mod->get_builder();
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// Both are scalar
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
return;
// One argument is scalar
if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
scalar = builder.create_splat(scalar, shapes);
return;
}
// Both are arrays
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
int lhs_dim = lhs_shapes.size();
int rhs_dim = rhs_shapes.size();
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
size_t ndim = longest.size();
int off = longest.size() - shortest.size();
for(int i = longest.size(); i>= 0; i--){
if(shortest[off + i] != longest[i])
throw std::runtime_error("cannot broadcast");
}
// Pad
for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), 1);
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
target = builder.create_reshape(target, shortest);
// Broadcast
std::vector<unsigned> shapes(ndim);
for(size_t i = 0; i < ndim; i++)
shapes[i] = std::max(shortest[i], longest[i]);
lhs = builder.create_broadcast(lhs, shapes);
rhs = builder.create_broadcast(rhs, shapes);
}
/* Translation unit */
ir::value* translation_unit::codegen(ir::module *mod) const{
decls_->codegen(mod);
@@ -195,11 +307,12 @@ void initializer::specifier(const declaration_specifier *spec) {
ir::value* initializer::codegen(ir::module * mod) const{
ir::type *ty = decl_->type(mod, spec_->type(mod));
std::string name = decl_->id()->name();
ir::value *value;
if(expr_)
value = expr_->codegen(mod);
else
value = ir::undef_value::get(ty);
ir::value *value = ir::undef_value::get(ty);
if(expr_){
ir::value* target = expr_->codegen(mod);
explicit_cast(mod->get_builder(), target, ty->get_scalar_ty());
implicit_broadcast(mod, value, target);
}
value->set_name(name);
mod->set_value(name, value);
return value;
@@ -208,119 +321,12 @@ ir::value* initializer::codegen(ir::module * mod) const{
/*------------------*/
/* Expression */
/*------------------*/
ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
ir::type *src_ty = src->get_type();
bool src_signed = false;
bool dst_signed = false;
if(src_ty == dst_ty)
return src;
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
return builder.create_si_to_fp(src, dst_ty);
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
return builder.create_ui_to_fp(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
return builder.create_fp_to_si(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
return builder.create_fp_to_ui(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
return builder.create_fp_ext(src, dst_ty);
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
return builder.create_fp_trunc(src, dst_ty);
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
src_ty->get_integer_bitwidth())
return builder.create_int_cast(src, dst_ty, dst_signed);
else
throw std::runtime_error("unreachable");
}
inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
// Input types
ir::type *left_ty = lhs->get_type();
ir::type *right_ty = rhs->get_type();
// One operand is pointer
if(left_ty->is_pointer_ty()){
is_ptr = true;
}
// One operand is double
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.get_double_ty());
is_float = true;
}
// One operand is float
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.get_float_ty());
is_float = true;
}
// Both operands are integers
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
is_int = true;
is_signed = false;
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
to_convert = llvm_cast(builder, to_convert, dst_ty);
}
}
// Not reachable
else
throw std::runtime_error("unreachable");
}
inline void implicit_broadcast(ir::module *mod, ir::builder &builder, ir::value *&lhs, ir::value *&rhs){
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
// Both are scalar
if(lhs_shapes.empty() && rhs_shapes.empty())
return;
// One argument is scalar
if(!lhs_shapes.empty() ^ !rhs_shapes.empty()){
auto &shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes;
auto &target = lhs_shapes.empty()?lhs:rhs;
target = builder.create_splat(target, shapes);
return;
}
// Both are arrays
int lhs_dim = lhs_shapes.size();
int rhs_dim = rhs_shapes.size();
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
size_t ndim = longest.size();
int off = longest.size() - shortest.size();
for(int i = longest.size(); i>= 0; i--){
if(shortest[off + i] != longest[i])
throw std::runtime_error("cannot broadcast");
}
// Pad
for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), 1);
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
target = builder.create_reshape(target, shortest);
// Broadcast
std::vector<unsigned> shapes(ndim);
for(size_t i = 0; i < ndim; i++)
shapes[i] = std::max(shortest[i], longest[i]);
lhs = builder.create_broadcast(lhs, shapes);
rhs = builder.create_broadcast(rhs, shapes);
}
/* Binary operator */
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
{
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
// implicit_broadcast(mod, builder, lhs, rhs);
implicit_broadcast(mod, lhs, rhs);
if(op_==MUL && is_float)
return builder.create_fmul(lhs, rhs, name);
if(op_==MUL && is_int)

View File

@@ -11,7 +11,7 @@ namespace tdl{
namespace codegen{
void allocation::run(ir::function &fn){
void allocation::run(){
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;

View File

@@ -1,5 +1,6 @@
#include "codegen/layout.h"
#include "ir/function.h"
#include "ir/module.h"
#include "ir/basic_block.h"
#include "ir/instructions.h"
@@ -36,19 +37,19 @@ void layout::add_shared_views(ir::value *v){
}
// Entry point
bool layout::run(ir::function &fn) {
void layout::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Non-phis
for(ir::basic_block *block: fn.blocks())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()) {
add_shared_views(instr);
}
// Phi nodes
for(ir::basic_block *block: fn.blocks())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()) {
add_phi_nodes(instr);
}
// Done
return false;
}
}
}

View File

@@ -2,6 +2,7 @@
#include "codegen/layout.h"
#include "ir/basic_block.h"
#include "ir/function.h"
#include "ir/module.h"
#include "ir/instructions.h"
#include "ir/value.h"
@@ -10,7 +11,8 @@ namespace codegen{
// Entry point
void liveness::run(ir::function *fn) {
void liveness::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Assigns index to each instruction
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
@@ -35,6 +37,7 @@ void liveness::run(ir::function *fn) {
intervals_[v] = segment{start, end};
}
}
}
}
}

View File

@@ -108,6 +108,7 @@ for(ir::function *fn: mod.get_function_list()){
r = i;
}
}
// extract unique instructions in order
std::vector<ir::instruction*> grids;
for(auto &ref: references)
@@ -118,6 +119,7 @@ for(ir::function *fn: mod.get_function_list()){
int num_warps = 1;
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
// check constraints
for(ir::instruction *i: grids){
ir::type *ty = i->get_type();