[syntax tree] fixed broadcast semantics lowering
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
#include "ir/context.h"
|
#include "ir/context.h"
|
||||||
#include "ir/module.h"
|
#include "ir/module.h"
|
||||||
#include "codegen/selection.h"
|
#include "codegen/selection.h"
|
||||||
|
#include "codegen/tune.h"
|
||||||
#include "llvm/IR/IRPrintingPasses.h"
|
#include "llvm/IR/IRPrintingPasses.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LLVMContext.h"
|
#include "llvm/IR/LLVMContext.h"
|
||||||
@@ -20,6 +21,8 @@ extern translation_unit *ast_root;
|
|||||||
const char src[] =
|
const char src[] =
|
||||||
"\
|
"\
|
||||||
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
|
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
|
||||||
|
int32 tile[16, 16] = 0;\
|
||||||
|
int32 test[16, 16] = tile + i;\
|
||||||
i = 1;\
|
i = 1;\
|
||||||
A = A + i;\
|
A = A + i;\
|
||||||
}\
|
}\
|
||||||
@@ -35,10 +38,14 @@ int main() {
|
|||||||
program->codegen(&module);
|
program->codegen(&module);
|
||||||
llvm::LLVMContext llvm_context;
|
llvm::LLVMContext llvm_context;
|
||||||
llvm::Module llvm_module("test", llvm_context);
|
llvm::Module llvm_module("test", llvm_context);
|
||||||
|
// lowering passes
|
||||||
tdl::codegen::selection selection;
|
tdl::codegen::selection selection;
|
||||||
selection.run(module, llvm_module);
|
tdl::codegen::tune tune;
|
||||||
llvm::PrintModulePass print(llvm::outs());
|
tune.run(module);
|
||||||
llvm::AnalysisManager<llvm::Module> analysis;
|
// selection.run(module, llvm_module);
|
||||||
print.run(llvm_module, analysis);
|
// // print LLVM program
|
||||||
|
// llvm::PrintModulePass print(llvm::outs());
|
||||||
|
// llvm::AnalysisManager<llvm::Module> analysis;
|
||||||
|
// print.run(llvm_module, analysis);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@@ -61,6 +61,11 @@ class identifier;
|
|||||||
|
|
||||||
// AST
|
// AST
|
||||||
class node {
|
class node {
|
||||||
|
protected:
|
||||||
|
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
|
||||||
|
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
|
||||||
|
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||||
|
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
|
||||||
public:
|
public:
|
||||||
virtual ir::value* codegen(ir::module *) const { return nullptr; }
|
virtual ir::value* codegen(ir::module *) const { return nullptr; }
|
||||||
};
|
};
|
||||||
|
@@ -25,7 +25,7 @@ public:
|
|||||||
unsigned get_allocated_size() const { return allocated_size_; }
|
unsigned get_allocated_size() const { return allocated_size_; }
|
||||||
|
|
||||||
// run
|
// run
|
||||||
void run(ir::function &fn);
|
void run();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<ir::value*, unsigned> offsets_;
|
std::map<ir::value*, unsigned> offsets_;
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
namespace tdl {
|
namespace tdl {
|
||||||
|
|
||||||
namespace ir {
|
namespace ir {
|
||||||
class function;
|
class module;
|
||||||
class instruction;
|
class instruction;
|
||||||
class value;
|
class value;
|
||||||
}
|
}
|
||||||
@@ -32,7 +32,7 @@ public:
|
|||||||
shared_view_info get_shared_view(ir::value *v, unsigned idx);
|
shared_view_info get_shared_view(ir::value *v, unsigned idx);
|
||||||
|
|
||||||
// run
|
// run
|
||||||
bool run(ir::function &fn);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<ir::value*, shared_view_val_t> shared_views_;
|
std::map<ir::value*, shared_view_val_t> shared_views_;
|
||||||
|
@@ -8,6 +8,7 @@ namespace tdl{
|
|||||||
namespace ir{
|
namespace ir{
|
||||||
class value;
|
class value;
|
||||||
class function;
|
class function;
|
||||||
|
class module;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
@@ -47,7 +48,7 @@ public:
|
|||||||
segment get_interval(ir::value* v) const { return intervals_.at(v); }
|
segment get_interval(ir::value* v) const { return intervals_.at(v); }
|
||||||
|
|
||||||
// run
|
// run
|
||||||
void run(ir::function *fn);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
has_storage_map_t has_dedicated_storage_;
|
has_storage_map_t has_dedicated_storage_;
|
||||||
|
@@ -14,6 +14,118 @@ namespace tdl{
|
|||||||
|
|
||||||
namespace ast{
|
namespace ast{
|
||||||
|
|
||||||
|
/* node */
|
||||||
|
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
||||||
|
ir::type *src_ty = src->get_type()->get_scalar_ty();
|
||||||
|
bool src_signed = false;
|
||||||
|
bool dst_signed = false;
|
||||||
|
if(src_ty == dst_ty)
|
||||||
|
return src;
|
||||||
|
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
|
||||||
|
return builder.create_si_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
|
||||||
|
return builder.create_ui_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
|
||||||
|
return builder.create_fp_to_si(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
|
||||||
|
return builder.create_fp_to_ui(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
||||||
|
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
|
||||||
|
return builder.create_fp_ext(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
||||||
|
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
|
||||||
|
return builder.create_fp_trunc(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
|
||||||
|
src_ty->get_integer_bitwidth())
|
||||||
|
return builder.create_int_cast(src, dst_ty, dst_signed);
|
||||||
|
|
||||||
|
else
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||||
|
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
|
||||||
|
// Input types
|
||||||
|
ir::type *left_ty = lhs->get_type()->get_scalar_ty();
|
||||||
|
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
|
||||||
|
// One operand is pointer
|
||||||
|
if(left_ty->is_pointer_ty()){
|
||||||
|
is_ptr = true;
|
||||||
|
}
|
||||||
|
// One operand is double
|
||||||
|
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
|
||||||
|
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, builder.get_double_ty());
|
||||||
|
is_float = true;
|
||||||
|
}
|
||||||
|
// One operand is float
|
||||||
|
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
|
||||||
|
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
|
||||||
|
is_float = true;
|
||||||
|
}
|
||||||
|
// Both operands are integers
|
||||||
|
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
||||||
|
is_int = true;
|
||||||
|
is_signed = false;
|
||||||
|
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
|
||||||
|
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
|
||||||
|
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, dst_ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Not reachable
|
||||||
|
else
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::type *lhs_ty = lhs->get_type();
|
||||||
|
ir::type *rhs_ty = rhs->get_type();
|
||||||
|
// Both are scalar
|
||||||
|
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||||
|
return;
|
||||||
|
// One argument is scalar
|
||||||
|
if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
|
||||||
|
auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
|
||||||
|
auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
|
||||||
|
scalar = builder.create_splat(scalar, shapes);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Both are arrays
|
||||||
|
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
|
||||||
|
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
|
||||||
|
int lhs_dim = lhs_shapes.size();
|
||||||
|
int rhs_dim = rhs_shapes.size();
|
||||||
|
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
||||||
|
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
||||||
|
size_t ndim = longest.size();
|
||||||
|
int off = longest.size() - shortest.size();
|
||||||
|
for(int i = longest.size(); i>= 0; i--){
|
||||||
|
if(shortest[off + i] != longest[i])
|
||||||
|
throw std::runtime_error("cannot broadcast");
|
||||||
|
}
|
||||||
|
// Pad
|
||||||
|
for(size_t i = 0; i < off; i++)
|
||||||
|
shortest.insert(shortest.begin(), 1);
|
||||||
|
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
||||||
|
target = builder.create_reshape(target, shortest);
|
||||||
|
// Broadcast
|
||||||
|
std::vector<unsigned> shapes(ndim);
|
||||||
|
for(size_t i = 0; i < ndim; i++)
|
||||||
|
shapes[i] = std::max(shortest[i], longest[i]);
|
||||||
|
lhs = builder.create_broadcast(lhs, shapes);
|
||||||
|
rhs = builder.create_broadcast(rhs, shapes);
|
||||||
|
}
|
||||||
|
|
||||||
/* Translation unit */
|
/* Translation unit */
|
||||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||||
decls_->codegen(mod);
|
decls_->codegen(mod);
|
||||||
@@ -195,11 +307,12 @@ void initializer::specifier(const declaration_specifier *spec) {
|
|||||||
ir::value* initializer::codegen(ir::module * mod) const{
|
ir::value* initializer::codegen(ir::module * mod) const{
|
||||||
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
||||||
std::string name = decl_->id()->name();
|
std::string name = decl_->id()->name();
|
||||||
ir::value *value;
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
if(expr_)
|
if(expr_){
|
||||||
value = expr_->codegen(mod);
|
ir::value* target = expr_->codegen(mod);
|
||||||
else
|
explicit_cast(mod->get_builder(), target, ty->get_scalar_ty());
|
||||||
value = ir::undef_value::get(ty);
|
implicit_broadcast(mod, value, target);
|
||||||
|
}
|
||||||
value->set_name(name);
|
value->set_name(name);
|
||||||
mod->set_value(name, value);
|
mod->set_value(name, value);
|
||||||
return value;
|
return value;
|
||||||
@@ -208,119 +321,12 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
/*------------------*/
|
/*------------------*/
|
||||||
/* Expression */
|
/* Expression */
|
||||||
/*------------------*/
|
/*------------------*/
|
||||||
ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
|
||||||
ir::type *src_ty = src->get_type();
|
|
||||||
bool src_signed = false;
|
|
||||||
bool dst_signed = false;
|
|
||||||
if(src_ty == dst_ty)
|
|
||||||
return src;
|
|
||||||
else if(src_ty->is_integer_ty() && src_signed && dst_ty->is_floating_point_ty())
|
|
||||||
return builder.create_si_to_fp(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_integer_ty() && !src_signed && dst_ty->is_floating_point_ty())
|
|
||||||
return builder.create_ui_to_fp(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && dst_signed)
|
|
||||||
return builder.create_fp_to_si(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_integer_ty() && !dst_signed)
|
|
||||||
return builder.create_fp_to_ui(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
|
||||||
src_ty->get_fp_mantissa_width() < dst_ty->get_fp_mantissa_width())
|
|
||||||
return builder.create_fp_ext(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_floating_point_ty() && dst_ty->is_floating_point_ty() &&
|
|
||||||
src_ty->get_fp_mantissa_width() > dst_ty->get_fp_mantissa_width())
|
|
||||||
return builder.create_fp_trunc(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
|
|
||||||
src_ty->get_integer_bitwidth())
|
|
||||||
return builder.create_int_cast(src, dst_ty, dst_signed);
|
|
||||||
|
|
||||||
else
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
|
||||||
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
|
|
||||||
// Input types
|
|
||||||
ir::type *left_ty = lhs->get_type();
|
|
||||||
ir::type *right_ty = rhs->get_type();
|
|
||||||
// One operand is pointer
|
|
||||||
if(left_ty->is_pointer_ty()){
|
|
||||||
is_ptr = true;
|
|
||||||
}
|
|
||||||
// One operand is double
|
|
||||||
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
|
|
||||||
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
|
|
||||||
to_convert = llvm_cast(builder, to_convert, builder.get_double_ty());
|
|
||||||
is_float = true;
|
|
||||||
}
|
|
||||||
// One operand is float
|
|
||||||
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
|
|
||||||
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
|
|
||||||
to_convert = llvm_cast(builder, to_convert, builder.get_float_ty());
|
|
||||||
is_float = true;
|
|
||||||
}
|
|
||||||
// Both operands are integers
|
|
||||||
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
|
||||||
is_int = true;
|
|
||||||
is_signed = false;
|
|
||||||
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
|
|
||||||
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
|
|
||||||
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
|
|
||||||
to_convert = llvm_cast(builder, to_convert, dst_ty);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Not reachable
|
|
||||||
else
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void implicit_broadcast(ir::module *mod, ir::builder &builder, ir::value *&lhs, ir::value *&rhs){
|
|
||||||
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
|
|
||||||
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
|
|
||||||
// Both are scalar
|
|
||||||
if(lhs_shapes.empty() && rhs_shapes.empty())
|
|
||||||
return;
|
|
||||||
// One argument is scalar
|
|
||||||
if(!lhs_shapes.empty() ^ !rhs_shapes.empty()){
|
|
||||||
auto &shapes = lhs_shapes.empty()?rhs_shapes:lhs_shapes;
|
|
||||||
auto &target = lhs_shapes.empty()?lhs:rhs;
|
|
||||||
target = builder.create_splat(target, shapes);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Both are arrays
|
|
||||||
int lhs_dim = lhs_shapes.size();
|
|
||||||
int rhs_dim = rhs_shapes.size();
|
|
||||||
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
|
||||||
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
|
||||||
size_t ndim = longest.size();
|
|
||||||
int off = longest.size() - shortest.size();
|
|
||||||
for(int i = longest.size(); i>= 0; i--){
|
|
||||||
if(shortest[off + i] != longest[i])
|
|
||||||
throw std::runtime_error("cannot broadcast");
|
|
||||||
}
|
|
||||||
// Pad
|
|
||||||
for(size_t i = 0; i < off; i++)
|
|
||||||
shortest.insert(shortest.begin(), 1);
|
|
||||||
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
|
||||||
target = builder.create_reshape(target, shortest);
|
|
||||||
// Broadcast
|
|
||||||
std::vector<unsigned> shapes(ndim);
|
|
||||||
for(size_t i = 0; i < ndim; i++)
|
|
||||||
shapes[i] = std::max(shortest[i], longest[i]);
|
|
||||||
lhs = builder.create_broadcast(lhs, shapes);
|
|
||||||
rhs = builder.create_broadcast(rhs, shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Binary operator */
|
/* Binary operator */
|
||||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
||||||
{
|
{
|
||||||
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
||||||
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
||||||
// implicit_broadcast(mod, builder, lhs, rhs);
|
implicit_broadcast(mod, lhs, rhs);
|
||||||
if(op_==MUL && is_float)
|
if(op_==MUL && is_float)
|
||||||
return builder.create_fmul(lhs, rhs, name);
|
return builder.create_fmul(lhs, rhs, name);
|
||||||
if(op_==MUL && is_int)
|
if(op_==MUL && is_int)
|
||||||
|
@@ -11,7 +11,7 @@ namespace tdl{
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
|
|
||||||
void allocation::run(ir::function &fn){
|
void allocation::run(){
|
||||||
using std::max;
|
using std::max;
|
||||||
using std::min;
|
using std::min;
|
||||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
#include "codegen/layout.h"
|
#include "codegen/layout.h"
|
||||||
#include "ir/function.h"
|
#include "ir/function.h"
|
||||||
|
#include "ir/module.h"
|
||||||
#include "ir/basic_block.h"
|
#include "ir/basic_block.h"
|
||||||
#include "ir/instructions.h"
|
#include "ir/instructions.h"
|
||||||
|
|
||||||
@@ -36,19 +37,19 @@ void layout::add_shared_views(ir::value *v){
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Entry point
|
// Entry point
|
||||||
bool layout::run(ir::function &fn) {
|
void layout::run(ir::module &mod) {
|
||||||
|
for(ir::function *fn: mod.get_function_list()){
|
||||||
// Non-phis
|
// Non-phis
|
||||||
for(ir::basic_block *block: fn.blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *instr: block->get_inst_list()) {
|
for(ir::instruction *instr: block->get_inst_list()) {
|
||||||
add_shared_views(instr);
|
add_shared_views(instr);
|
||||||
}
|
}
|
||||||
// Phi nodes
|
// Phi nodes
|
||||||
for(ir::basic_block *block: fn.blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *instr: block->get_inst_list()) {
|
for(ir::instruction *instr: block->get_inst_list()) {
|
||||||
add_phi_nodes(instr);
|
add_phi_nodes(instr);
|
||||||
}
|
}
|
||||||
// Done
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#include "codegen/layout.h"
|
#include "codegen/layout.h"
|
||||||
#include "ir/basic_block.h"
|
#include "ir/basic_block.h"
|
||||||
#include "ir/function.h"
|
#include "ir/function.h"
|
||||||
|
#include "ir/module.h"
|
||||||
#include "ir/instructions.h"
|
#include "ir/instructions.h"
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
|
|
||||||
@@ -10,7 +11,8 @@ namespace codegen{
|
|||||||
|
|
||||||
|
|
||||||
// Entry point
|
// Entry point
|
||||||
void liveness::run(ir::function *fn) {
|
void liveness::run(ir::module &mod) {
|
||||||
|
for(ir::function *fn: mod.get_function_list()){
|
||||||
// Assigns index to each instruction
|
// Assigns index to each instruction
|
||||||
slot_index index = 0;
|
slot_index index = 0;
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
@@ -35,6 +37,7 @@ void liveness::run(ir::function *fn) {
|
|||||||
intervals_[v] = segment{start, end};
|
intervals_[v] = segment{start, end};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -108,6 +108,7 @@ for(ir::function *fn: mod.get_function_list()){
|
|||||||
r = i;
|
r = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract unique instructions in order
|
// extract unique instructions in order
|
||||||
std::vector<ir::instruction*> grids;
|
std::vector<ir::instruction*> grids;
|
||||||
for(auto &ref: references)
|
for(auto &ref: references)
|
||||||
@@ -118,6 +119,7 @@ for(ir::function *fn: mod.get_function_list()){
|
|||||||
int num_warps = 1;
|
int num_warps = 1;
|
||||||
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
||||||
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
|
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
|
||||||
|
|
||||||
// check constraints
|
// check constraints
|
||||||
for(ir::instruction *i: grids){
|
for(ir::instruction *i: grids){
|
||||||
ir::type *ty = i->get_type();
|
ir::type *ty = i->get_type();
|
||||||
|
Reference in New Issue
Block a user