basic split-k across warps working for GEMM
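The change can be read as follows: instead of a single 2-D accumulator c[TM, TN], the generated kernel keeps a 3-D accumulator xc[TM, TN, 4] in which each slice along the last axis accumulates a partial product over part of the K dimension, and the partials are combined with __sum(xc, 2) before the store. Below is a minimal NumPy sketch of that split-K idea, for orientation only; the names are illustrative, and it uses contiguous K-slices for simplicity, whereas the kernel interleaves K by reshaping a [TM, TK] tile into [TM, TK/4, 4].

# Hypothetical illustration of split-K GEMM (not Triton code):
# each K-slice accumulates its own partial C tile, then partials are summed.
import numpy as np

M, N, K, SPLIT = 32, 32, 128, 4
a = np.random.rand(M, K).astype(np.float32)
b = np.random.rand(K, N).astype(np.float32)

xc = np.zeros((M, N, SPLIT), dtype=np.float32)   # one partial accumulator per slice
for s in range(SPLIT):                           # in the kernel, slices map to warps
    ks = slice(s * K // SPLIT, (s + 1) * K // SPLIT)
    xc[:, :, s] = a[:, ks] @ b[ks, :]            # partial product over this K-slice
c = xc.sum(axis=2)                               # cross-slice reduction, like __sum(xc, 2)
assert np.allclose(c, a @ b, atol=1e-3)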
@@ -26,7 +26,7 @@ struct perf_t {
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "half";
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N);

@@ -48,28 +48,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = M;
// NumericT alpha = 1;
// NumericT beta = 0;
// int32_t lda = AT ? K : M;
// int32_t ldb = BT ? N : K;
// int32_t ldc = M;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
// &alpha, da, lda,
// db, ldb, &beta,
// dc, ldc, &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
&alpha, da, lda,
db, ldb, &beta,
dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
// &alpha, da, lda,
// db, ldb, &beta,
// dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
// result
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };

perf_t result;
result.cublas = tflops(cublas_ns);
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);

// test
stream->read(dc, true, 0, hc);
std::vector<float> rc(hc.size());
dot.cpu_ref(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;

// clean-up
delete dc;
delete da;

@@ -99,8 +111,8 @@ int main() {
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192}
{false, true, 32768, 256, 256},
{false, true, 32768, 256, 512}
{false, true, 128, 128, 128},
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
// template
triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
dot.enqueue(stream, {&da, &db, &dc});
}

@@ -23,7 +23,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb.T).T
hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)

@@ -131,6 +131,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))

#run_dot()
run_dot()
#run_shift()
run_batchnorm()
#run_batchnorm()
@@ -73,11 +73,11 @@ public:
optimize_dot.run(module);
optimize_trans.run(module);
optimize_dce.run(module);
// ir::print(module, std::cout);
}

void target_dependent(ir::module &module) {
alignment_info.run(module);
// ir::print(module, std::cout);
// reassociate.run(module);
if(target_->is_gpu()){
shmem_info.run(module);

@@ -33,8 +33,7 @@ void optimize_dot::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(auto dot = dynamic_cast<ir::dot_inst*>(i))
if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
if(auto dot = dynamic_cast<ir::dot_inst*>(i)){
builder.set_insert_point(i);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
@@ -135,8 +135,12 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]);
for(size_t i = 1; i < idx.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
Value *ld = builder.getInt32(shapes[0]);
for(size_t i = 1; i < idx.size(); i++) {
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
if(i < idx.size() - 1)
ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
}
return result;
}

@@ -854,10 +858,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *&result = x.second;
indices_t write_idx = x.first;
write_idx.insert(write_idx.begin() + axis, lane);

// shared memory write pointer
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);

// initialize shared memory
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
for(unsigned i = depth/2; i > 0; i >>= 1){
@@ -993,15 +1000,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
{
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)

@@ -1013,13 +1019,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
res = builder.CreateCall(f_mul_add, {a, b, res});

}
result->set_value(idx, res);
});
}
else
{
else {
TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4*pack_size_1_);
TA->set_return_mode(true);
@@ -42,8 +42,8 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
}

unsigned shmem_allocation::get_num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis();
ir::value *op = red->get_operand(0);
auto shapes = op->get_type()->get_tile_shapes();

@@ -54,6 +54,7 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
return num_elements * num_bytes * depth;
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = is_ld_padded(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();

@@ -24,8 +24,7 @@ bool is_hmma(ir::value *v){
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
// reduction has to be multiple of 4: TODO
}
return result;
}
@@ -66,9 +65,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
// std::cout << arg->get_name() << " " << v->get_name() << std::endl;
add_constraint({reduce, current++}, {arg, i});
}
// add_constraint({reduce, 0}, {arg, 0});
// add_constraint({reduce, 1}, {arg, 1});
return;
}
else

@@ -81,8 +81,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < shapes.size(); i ++){
bool is_one = shapes[i] == one;
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
if(is_one)
if(is_one){
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i});
}
else if(!is_skewed && is_same)
add_constraint({v, i}, {op, current++});
else{

@@ -114,9 +116,17 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
add_constraint({v, 0}, {D, 0});
add_constraint({v, 1}, {D, 1});
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one)
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
@@ -242,7 +252,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}

@@ -266,14 +276,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}
@@ -365,6 +375,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er

// check constraints
for(ir::instruction *i: grids_){
// std::cout << i->get_name() << std::endl;
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components

@@ -396,11 +407,15 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
if(num_threads % 64 != 0)
if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
}
// for(auto x: errors)
// for(auto e: x.second)
// std::cout << x.first->get_name() << ": " << e << std::endl;
// exit(EXIT_SUCCESS);
return errors.empty();
}
@@ -54,16 +54,17 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune != NO_TUNING) {
if(autotune == FULL_TUNING || autotune == PARTIAL_TUNING) {
std::vector<params_t> space = {};
if(autotune == PARTIAL_TUNING)
space = search_space();
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
params_t params = heuristics();
else{
// params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
@@ -74,12 +74,14 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
std::string usea = AT_ ? "trans(xa)" : "xa";
std::string useb = BT_ ? "trans(xb)" : "xb";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);

@@ -92,12 +94,15 @@ void dot::triton_c_src(std::ostream &os) const {
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, 4";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int TM = {16, 32, 64, 128};
const tunable int TN = {16, 32, 64, 128};
const tunable int TM = {32};
const tunable int TN = {32};
const tunable int TK = {32};
const tunable int GZ = {1};
@@ -113,7 +118,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
float c[TM, TN] = 0;
float xc[)" + XCS + R"(] = 0;
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;

@@ -121,7 +126,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
for(int k = K; k > 0; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
bool checka[)" + AS + R"(] = k > TK;

@@ -131,11 +138,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
@checkc *pc = c;
float c[TM, TN] = __sum(xc, 2);
*pc = c;
}
)";
@@ -148,13 +148,13 @@ DEFINE_UNARY_FLOAT(fneg)
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
value *rhs, const std::string &name,
bool has_nuw, bool has_nsw) {
if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
constant_expression* result = constant_expression::create(op, clhs, crhs);
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
auto *clhs = dynamic_cast<constant_int*>(lhs);
auto *crhs = dynamic_cast<constant_int*>(rhs);
if(clhs && crhs){
constant_expression* result = constant_expression::create(op, clhs, crhs);
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
else {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);

@@ -101,7 +101,6 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
ir::value* binary_expression::codegen(ir::module *mod) const{
ir::value *lhs = lhs_->codegen(mod);
ir::value *rhs = rhs_->codegen(mod);
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
return result;
}