[syntax tree] fixed broadcast semantics lowering
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "ir/context.h"
|
||||
#include "ir/module.h"
|
||||
#include "codegen/selection.h"
|
||||
#include "codegen/tune.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
@@ -20,6 +21,8 @@ extern translation_unit *ast_root;
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
|
||||
int32 tile[16, 16] = 0;\
|
||||
int32 test[16, 16] = tile + i;\
|
||||
i = 1;\
|
||||
A = A + i;\
|
||||
}\
|
||||
@@ -35,10 +38,14 @@ int main() {
|
||||
program->codegen(&module);
|
||||
llvm::LLVMContext llvm_context;
|
||||
llvm::Module llvm_module("test", llvm_context);
|
||||
// lowering passes
|
||||
tdl::codegen::selection selection;
|
||||
selection.run(module, llvm_module);
|
||||
llvm::PrintModulePass print(llvm::outs());
|
||||
llvm::AnalysisManager<llvm::Module> analysis;
|
||||
print.run(llvm_module, analysis);
|
||||
tdl::codegen::tune tune;
|
||||
tune.run(module);
|
||||
// selection.run(module, llvm_module);
|
||||
// // print LLVM program
|
||||
// llvm::PrintModulePass print(llvm::outs());
|
||||
// llvm::AnalysisManager<llvm::Module> analysis;
|
||||
// print.run(llvm_module, analysis);
|
||||
return 0;
|
||||
}
|
||||
|
@@ -61,6 +61,11 @@ class identifier;
|
||||
|
||||
// AST
|
||||
class node {
|
||||
protected:
|
||||
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
|
||||
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
|
||||
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
|
||||
public:
|
||||
virtual ir::value* codegen(ir::module *) const { return nullptr; }
|
||||
};
|
||||
|
@@ -25,7 +25,7 @@ public:
|
||||
unsigned get_allocated_size() const { return allocated_size_; }
|
||||
|
||||
// run
|
||||
void run(ir::function &fn);
|
||||
void run();
|
||||
|
||||
private:
|
||||
std::map<ir::value*, unsigned> offsets_;
|
||||
|
@@ -7,7 +7,7 @@
|
||||
namespace tdl {
|
||||
|
||||
namespace ir {
|
||||
class function;
|
||||
class module;
|
||||
class instruction;
|
||||
class value;
|
||||
}
|
||||
@@ -32,7 +32,7 @@ public:
|
||||
shared_view_info get_shared_view(ir::value *v, unsigned idx);
|
||||
|
||||
// run
|
||||
bool run(ir::function &fn);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, shared_view_val_t> shared_views_;
|
||||
|
@@ -8,6 +8,7 @@ namespace tdl{
|
||||
namespace ir{
|
||||
class value;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -47,7 +48,7 @@ public:
|
||||
segment get_interval(ir::value* v) const { return intervals_.at(v); }
|
||||
|
||||
// run
|
||||
void run(ir::function *fn);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
|
@@ -14,6 +14,118 @@ namespace tdl{
|
||||
|
||||
namespace ast{
|
||||
|
||||
/* node */
|
||||
// Emits the instruction converting `src` to the scalar type `dst_ty`.
// Returns `src` unchanged when its scalar type already is `dst_ty`, and
// throws std::runtime_error("unreachable") when no conversion applies.
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
  ir::type *from_ty = src->get_type()->get_scalar_ty();
  // Signedness is not tracked here: both flags are constant, so only the
  // unsigned flavours of the integer conversions are ever emitted.
  const bool from_signed = false;
  const bool to_signed = false;

  // Nothing to convert
  if(from_ty == dst_ty)
    return src;

  if(from_ty->is_integer_ty()){
    // integer -> floating point
    if(dst_ty->is_floating_point_ty()){
      if(from_signed)
        return builder.create_si_to_fp(src, dst_ty);
      return builder.create_ui_to_fp(src, dst_ty);
    }
    // integer -> integer (the bit-width test is a plain non-zero check)
    if(dst_ty->is_integer_ty() && from_ty->get_integer_bitwidth())
      return builder.create_int_cast(src, dst_ty, to_signed);
  }
  else if(from_ty->is_floating_point_ty()){
    // floating point -> integer
    if(dst_ty->is_integer_ty()){
      if(to_signed)
        return builder.create_fp_to_si(src, dst_ty);
      return builder.create_fp_to_ui(src, dst_ty);
    }
    // floating point -> floating point: extend or truncate by mantissa width
    if(dst_ty->is_floating_point_ty()){
      auto from_width = from_ty->get_fp_mantissa_width();
      auto to_width = dst_ty->get_fp_mantissa_width();
      if(from_width < to_width)
        return builder.create_fp_ext(src, dst_ty);
      if(from_width > to_width)
        return builder.create_fp_trunc(src, dst_ty);
    }
  }

  // No rule matched
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
|
||||
// Brings `lhs` and `rhs` to a common scalar type and classifies the operation.
// Exactly one rule is applied, in this priority order:
//   - left operand is a pointer: sets `is_ptr`, operands untouched
//   - either operand is double:  the other one is cast to double, sets `is_float`
//   - either operand is float:   the other one is cast to float, sets `is_float`
//   - both operands are integer: the narrower one is cast to the wider type,
//                                sets `is_int` and clears `is_signed`
// Any other combination throws std::runtime_error("unreachable").
void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
                         bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
  // Scalar types of both operands
  ir::type *l_ty = lhs->get_type()->get_scalar_ty();
  ir::type *r_ty = rhs->get_type()->get_scalar_ty();

  // Pointer on the left-hand side
  if(l_ty->is_pointer_ty()){
    is_ptr = true;
    return;
  }

  // At least one double
  const bool l_is_double = l_ty->is_double_ty();
  if(l_is_double || r_ty->is_double_ty()){
    if(l_is_double)
      rhs = explicit_cast(builder, rhs, builder.get_double_ty());
    else
      lhs = explicit_cast(builder, lhs, builder.get_double_ty());
    is_float = true;
    return;
  }

  // At least one float
  const bool l_is_float = l_ty->is_float_ty();
  if(l_is_float || r_ty->is_float_ty()){
    if(l_is_float)
      rhs = explicit_cast(builder, rhs, builder.get_float_ty());
    else
      lhs = explicit_cast(builder, lhs, builder.get_float_ty());
    is_float = true;
    return;
  }

  // Two integers: widen the narrower operand
  if(l_ty->is_integer_ty() && r_ty->is_integer_ty()){
    is_int = true;
    is_signed = false;
    auto l_bits = l_ty->get_integer_bitwidth();
    auto r_bits = r_ty->get_integer_bitwidth();
    if(l_bits > r_bits)
      rhs = explicit_cast(builder, rhs, l_ty);
    else if(l_bits < r_bits)
      lhs = explicit_cast(builder, lhs, r_ty);
    return;
  }

  // Not reachable
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
// Makes the shapes of `lhs` and `rhs` agree so they can be combined
// element-wise. Operands are replaced in place:
//   - two scalars:         nothing happens
//   - one scalar, one tile: the scalar is splatted to the tile's shape
//   - two tiles:            the lower-rank tile is reshaped with leading
//                           dimensions of size 1, then both are broadcast to
//                           the common shape
// Throws std::runtime_error("cannot broadcast") when the trailing dimensions
// of two tiles differ.
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
  ir::builder &builder = mod->get_builder();
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  const bool lhs_is_tile = lhs_ty->is_tile_ty();
  const bool rhs_is_tile = rhs_ty->is_tile_ty();
  // Both are scalar
  if(!lhs_is_tile && !rhs_is_tile)
    return;
  // One argument is scalar
  if(lhs_is_tile != rhs_is_tile){
    auto &shapes = lhs_is_tile?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
    ir::value *&scalar = lhs_is_tile?rhs:lhs;
    scalar = builder.create_splat(scalar, shapes);
    return;
  }
  // Both are tiles: work on local copies of the shapes
  std::vector<unsigned> lhs_shapes = lhs_ty->get_tile_shapes();
  std::vector<unsigned> rhs_shapes = rhs_ty->get_tile_shapes();
  const bool lhs_is_shorter = lhs_shapes.size() < rhs_shapes.size();
  std::vector<unsigned> &shortest = lhs_is_shorter?lhs_shapes:rhs_shapes;
  std::vector<unsigned> &longest = lhs_is_shorter?rhs_shapes:lhs_shapes;
  const size_t ndim = longest.size();
  // Number of leading dimensions missing from the lower-rank shape
  const size_t off = ndim - shortest.size();
  // Shapes are aligned on their last dimension: shortest[i] pairs with
  // longest[off + i]. Every index stays inside both vectors.
  for(size_t i = 0; i < shortest.size(); i++){
    if(shortest[i] != longest[off + i])
      throw std::runtime_error("cannot broadcast");
  }
  // Pad the lower-rank shape with leading 1s up to the common rank
  shortest.insert(shortest.begin(), off, 1u);
  ir::value *&target = lhs_is_shorter?lhs:rhs;
  target = builder.create_reshape(target, shortest);
  // Broadcast both operands to the element-wise maximum of the two shapes
  std::vector<unsigned> shapes(ndim);
  for(size_t i = 0; i < ndim; i++)
    shapes[i] = std::max(shortest[i], longest[i]);
  lhs = builder.create_broadcast(lhs, shapes);
  rhs = builder.create_broadcast(rhs, shapes);
}
|
||||
|
||||
/* Translation unit */
|
||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||
decls_->codegen(mod);
|
||||
@@ -195,11 +307,12 @@ void initializer::specifier(const declaration_specifier *spec) {
|
||||
ir::value* initializer::codegen(ir::module * mod) const{
|
||||
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
||||
std::string name = decl_->id()->name();
|
||||
ir::value *value;
|
||||
if(expr_)
|
||||
value = expr_->codegen(mod);
|
||||
else
|
||||
value = ir::undef_value::get(ty);
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
if(expr_){
|
||||
ir::value* target = expr_->codegen(mod);
|
||||
explicit_cast(mod->get_builder(), target, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, value, target);
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
return value;
|
||||
@@ -208,119 +321,12 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
/*------------------*/
|
||||
/* Expression */
|
||||
/*------------------*/
|
||||
// Emits the instruction converting `src` to the type `dst_ty`.
// Returns `src` unchanged when its type already is `dst_ty`, and throws
// std::runtime_error("unreachable") when no conversion applies.
ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
  ir::type *from_ty = src->get_type();
  // Signedness is not tracked here: both flags are constant, so only the
  // unsigned flavours of the integer conversions are ever emitted.
  const bool from_signed = false;
  const bool to_signed = false;

  // Nothing to convert
  if(from_ty == dst_ty)
    return src;

  if(from_ty->is_integer_ty()){
    // integer -> floating point
    if(dst_ty->is_floating_point_ty()){
      if(from_signed)
        return builder.create_si_to_fp(src, dst_ty);
      return builder.create_ui_to_fp(src, dst_ty);
    }
    // integer -> integer (the bit-width test is a plain non-zero check)
    if(dst_ty->is_integer_ty() && from_ty->get_integer_bitwidth())
      return builder.create_int_cast(src, dst_ty, to_signed);
  }
  else if(from_ty->is_floating_point_ty()){
    // floating point -> integer
    if(dst_ty->is_integer_ty()){
      if(to_signed)
        return builder.create_fp_to_si(src, dst_ty);
      return builder.create_fp_to_ui(src, dst_ty);
    }
    // floating point -> floating point: extend or truncate by mantissa width
    if(dst_ty->is_floating_point_ty()){
      auto from_width = from_ty->get_fp_mantissa_width();
      auto to_width = dst_ty->get_fp_mantissa_width();
      if(from_width < to_width)
        return builder.create_fp_ext(src, dst_ty);
      if(from_width > to_width)
        return builder.create_fp_trunc(src, dst_ty);
    }
  }

  // No rule matched
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
// Brings `lhs` and `rhs` to a common type and classifies the operation.
// Exactly one rule is applied, in this priority order:
//   - left operand is a pointer: sets `is_ptr`, operands untouched
//   - either operand is double:  the other one is cast to double, sets `is_float`
//   - either operand is float:   the other one is cast to float, sets `is_float`
//   - both operands are integer: the narrower one is cast to the wider type,
//                                sets `is_int` and clears `is_signed`
// Any other combination throws std::runtime_error("unreachable").
inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
                          bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
  // Types of both operands
  ir::type *l_ty = lhs->get_type();
  ir::type *r_ty = rhs->get_type();

  // Pointer on the left-hand side
  if(l_ty->is_pointer_ty()){
    is_ptr = true;
    return;
  }

  // At least one double
  const bool l_is_double = l_ty->is_double_ty();
  if(l_is_double || r_ty->is_double_ty()){
    if(l_is_double)
      rhs = llvm_cast(builder, rhs, builder.get_double_ty());
    else
      lhs = llvm_cast(builder, lhs, builder.get_double_ty());
    is_float = true;
    return;
  }

  // At least one float
  const bool l_is_float = l_ty->is_float_ty();
  if(l_is_float || r_ty->is_float_ty()){
    if(l_is_float)
      rhs = llvm_cast(builder, rhs, builder.get_float_ty());
    else
      lhs = llvm_cast(builder, lhs, builder.get_float_ty());
    is_float = true;
    return;
  }

  // Two integers: widen the narrower operand
  if(l_ty->is_integer_ty() && r_ty->is_integer_ty()){
    is_int = true;
    is_signed = false;
    auto l_bits = l_ty->get_integer_bitwidth();
    auto r_bits = r_ty->get_integer_bitwidth();
    if(l_bits > r_bits)
      rhs = llvm_cast(builder, rhs, l_ty);
    else if(l_bits < r_bits)
      lhs = llvm_cast(builder, lhs, r_ty);
    return;
  }

  // Not reachable
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
// Makes the shapes of `lhs` and `rhs` agree so they can be combined
// element-wise. An operand whose shape list is empty is treated as a scalar.
// Operands are replaced in place:
//   - two scalars:         nothing happens
//   - one scalar, one tile: the scalar is splatted to the tile's shape
//   - two tiles:            the lower-rank tile is reshaped with leading
//                           dimensions of size 1, then both are broadcast to
//                           the common shape
// Throws std::runtime_error("cannot broadcast") when the trailing dimensions
// of two tiles differ. `mod` is not used by this function.
inline void implicit_broadcast(ir::module *mod, ir::builder &builder, ir::value *&lhs, ir::value *&rhs){
  std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
  std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
  const bool lhs_is_scalar = lhs_shapes.empty();
  const bool rhs_is_scalar = rhs_shapes.empty();
  // Both are scalar
  if(lhs_is_scalar && rhs_is_scalar)
    return;
  // One argument is scalar
  if(lhs_is_scalar != rhs_is_scalar){
    auto &shapes = lhs_is_scalar?rhs_shapes:lhs_shapes;
    auto &target = lhs_is_scalar?lhs:rhs;
    target = builder.create_splat(target, shapes);
    return;
  }
  // Both are tiles
  const bool lhs_is_shorter = lhs_shapes.size() < rhs_shapes.size();
  std::vector<unsigned> &shortest = lhs_is_shorter?lhs_shapes:rhs_shapes;
  std::vector<unsigned> &longest = lhs_is_shorter?rhs_shapes:lhs_shapes;
  const size_t ndim = longest.size();
  // Number of leading dimensions missing from the lower-rank shape
  const size_t off = ndim - shortest.size();
  // Shapes are aligned on their last dimension: shortest[i] pairs with
  // longest[off + i]. Every index stays inside both vectors.
  for(size_t i = 0; i < shortest.size(); i++){
    if(shortest[i] != longest[off + i])
      throw std::runtime_error("cannot broadcast");
  }
  // Pad the lower-rank shape with leading 1s up to the common rank
  shortest.insert(shortest.begin(), off, 1u);
  ir::value *&target = lhs_is_shorter?lhs:rhs;
  target = builder.create_reshape(target, shortest);
  // Broadcast both operands to the element-wise maximum of the two shapes
  std::vector<unsigned> shapes(ndim);
  for(size_t i = 0; i < ndim; i++)
    shapes[i] = std::max(shortest[i], longest[i]);
  lhs = builder.create_broadcast(lhs, shapes);
  rhs = builder.create_broadcast(rhs, shapes);
}
|
||||
|
||||
/* Binary operator */
|
||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
||||
{
|
||||
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
||||
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
||||
// implicit_broadcast(mod, builder, lhs, rhs);
|
||||
implicit_broadcast(mod, lhs, rhs);
|
||||
if(op_==MUL && is_float)
|
||||
return builder.create_fmul(lhs, rhs, name);
|
||||
if(op_==MUL && is_int)
|
||||
|
@@ -11,7 +11,7 @@ namespace tdl{
|
||||
namespace codegen{
|
||||
|
||||
|
||||
void allocation::run(ir::function &fn){
|
||||
void allocation::run(){
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include "codegen/layout.h"
|
||||
#include "ir/function.h"
|
||||
#include "ir/module.h"
|
||||
#include "ir/basic_block.h"
|
||||
#include "ir/instructions.h"
|
||||
|
||||
@@ -36,19 +37,19 @@ void layout::add_shared_views(ir::value *v){
|
||||
}
|
||||
|
||||
// Entry point
|
||||
bool layout::run(ir::function &fn) {
|
||||
void layout::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Non-phis
|
||||
for(ir::basic_block *block: fn.blocks())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()) {
|
||||
add_shared_views(instr);
|
||||
}
|
||||
// Phi nodes
|
||||
for(ir::basic_block *block: fn.blocks())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()) {
|
||||
add_phi_nodes(instr);
|
||||
}
|
||||
// Done
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include "codegen/layout.h"
|
||||
#include "ir/basic_block.h"
|
||||
#include "ir/function.h"
|
||||
#include "ir/module.h"
|
||||
#include "ir/instructions.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
@@ -10,7 +11,8 @@ namespace codegen{
|
||||
|
||||
|
||||
// Entry point
|
||||
void liveness::run(ir::function *fn) {
|
||||
void liveness::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -35,6 +37,7 @@ void liveness::run(ir::function *fn) {
|
||||
intervals_[v] = segment{start, end};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -108,6 +108,7 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
r = i;
|
||||
}
|
||||
}
|
||||
|
||||
// extract unique instructions in order
|
||||
std::vector<ir::instruction*> grids;
|
||||
for(auto &ref: references)
|
||||
@@ -118,6 +119,7 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
||||
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids){
|
||||
ir::type *ty = i->get_type();
|
||||
|
Reference in New Issue
Block a user