diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7c7a1c0ab..694cc5578 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,7 @@ include_directories(${LLVM_INCLUDE_DIRS})
 add_definitions(${LLVM_DEFINITIONS})
 #llvm_map_components_to_libnames(llvm_libs all)
 
-#Default build type
+# Default build type
 if(NOT CMAKE_BUILD_TYPE)
   message(STATUS "Default build type: Release")
   set(CMAKE_BUILD_TYPE "Release")
@@ -63,7 +63,14 @@ endif()
 
 # Triton
 file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp)
-add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
+add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
 target_link_libraries(triton LLVM)
+
+# Warning level
+if(MSVC)
+  target_compile_options(triton PRIVATE /W4)
+else()
+  target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic)
+endif()
+
diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt
index 3366ba591..f5f6a40b8 100644
--- a/examples/cpp/CMakeLists.txt
+++ b/examples/cpp/CMakeLists.txt
@@ -1,4 +1,4 @@
-foreach(PROG dot conv shift)
+foreach(PROG dot)
   add_executable(${PROG} ${PROG}.cpp)
   set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
   include_directories(/usr/local/cuda/include/)
diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
deleted file mode 100644
index dbe0591f0..000000000
--- a/examples/cpp/conv.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-#include 
-#include 
-#include 
-#include "triton/runtime/jit.h"
-#include "triton/driver/backend.h"
-#include "triton/driver/stream.h"
-#include "triton/dnn/conv.h"
-#include "triton/tools/bench.hpp"
-
-int main() {
-  // initialize default compute device
-  auto context = triton::driver::backend::contexts::get_default();
-  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
-  // initialization
-  int32_t B = 16, NF = 128;
-  int32_t D = 1, H = 16, W = 16;
-  int32_t NC = 64, T = 1, R = 3, S = 3;
-  int32_t pad_d = 0, pad_h = 0, pad_w = 0;
-  int32_t stride_d = 1, stride_h = 1, stride_w = 1;
-  int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-//  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "float", "float", triton::dnn::conv::FPROP, 0);
-  triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
-                                  stride_d, stride_h, stride_w,
-                                  pad_d, pad_h, pad_w,
-                                  upsample_d, upsample_h, upsample_w,
-                                  "float", "float", ty, 0);
-  // convolution configuration
-  std::vector hc(configuration.c_size());
-  std::vector rc(configuration.c_size());
-  std::vector ha(configuration.a_size());
-  std::vector hb(configuration.b_size());
-  srand(0);
-  for(size_t i = 0; i < ha.size(); i++)
-    ha[i] = (float)rand()/RAND_MAX;
-  for(size_t i = 0; i < hb.size(); i++)
-    hb[i] = (float)rand()/RAND_MAX;
-  for(size_t i = 0; i < hc.size(); i++)
-    hc[i] = 0;
-  rc = hc;
-  triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
-  triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
-  triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
-  triton::driver::stream* stream = triton::driver::stream::create(context);
-  stream->write(da, true, 0, ha);
-  stream->write(db, true, 0, hb);
-  stream->write(dc, true, 0, hc);
-  stream->synchronize();
-  configuration.enqueue(stream, {da, db, dc, nullptr});
-  stream->read(dc, true, 0, hc);
-  configuration.cpu_ref(rc.data(), ha.data(), hb.data());
-  for(size_t i = 0; i < hc.size(); i++){
-    if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
-      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-      exit(EXIT_FAILURE);
-    }
-  }
-  std::cout << "Pass!" << std::endl;
-}
diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index e592da570..102380036 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -3,7 +3,6 @@
 #include 
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/dnn/dot.h"
 #include "triton/tools/bench.hpp"
 #include "triton/external/half.hpp"
 #include "triton/runtime/function.h"
diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
deleted file mode 100644
index 1495de3c4..000000000
--- a/examples/cpp/shift.cpp
+++ /dev/null
@@ -1,150 +0,0 @@
-#include 
-#include 
-#include 
-#include "cuda.h"
-#include "triton/runtime/jit.h"
-#include "triton/driver/backend.h"
-#include "triton/driver/stream.h"
-#include "triton/tools/bench.hpp"
-#include "triton/dnn/shift.h"
-#include "triton/external/half.hpp"
-
-struct perf_t {
-  double triton;
-  double cublas;
-};
-
-perf_t do_bench(triton::driver::stream *stream,
-                int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
-                triton::dnn::op_t op, triton::dnn::layout_t layout,
-                std::string numeric_t) {
-  typedef float NumericT;
-
-  // driver variables
-  triton::driver::context* context = stream->context();
-
-  // random shifts
-  std::vector shift_h(C);
-  std::vector shift_w(C);
-  for(int32_t c = 0; c < C; c++){
-    shift_h[c] = rand() % R - R / 2;
-    shift_w[c] = rand() % S - S / 2;
-  }
-  // configuration
-  triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
-                           shift_h.data(), shift_w.data(),
-                           numeric_t, numeric_t,
-                           op, false, layout);
-  // host buffers
-  size_t a_size = B*C*H*W;
-  size_t b_size = C*F;
-  size_t c_size = B*F*H*W;
-  if(op == triton::dnn::BPROP)
-    std::swap(a_size, c_size);
-  if(op == triton::dnn::WGRAD){
-    std::swap(b_size, c_size);
-    std::swap(a_size, b_size);
-  }
-  std::vector ha(a_size);
-  std::vector hb(b_size);
-  std::vector hc(c_size);
-  std::vector rc(hc.size());
-  // device buffers
-  triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
-  triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
-  triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*sizeof(NumericT));
-  // initialize host
-  srand(0);
-  for(size_t i = 0; i < ha.size(); i++)
-    ha[i] = (NumericT)rand() / RAND_MAX;
-  for(size_t i = 0; i < hb.size(); i++)
-    hb[i] = (NumericT)rand() / RAND_MAX;
-  for(size_t i = 0; i < hc.size(); i++)
-    hc[i] = 0;
-  // initialize device
-  stream->write(da, true, 0, ha);
-  stream->write(db, true, 0, hb);
-  stream->write(dc, true, 0, hc);
-  stream->synchronize();
-  // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
-  // benchmark cublas
-//  NumericT alpha = 1;
-//  NumericT beta = 0;
-//  cublasGemmAlgo_t fastest;
-//  cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-//             &alpha, da, shift.lda(),
-//             db, shift.ldb(), &beta,
-//             dc, shift.ldc(), &fastest);
-//  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-//                                                             &alpha, da, shift.lda(),
-//                                                             db, shift.ldb(),
-//                                                             &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
-  // result
-  auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
-  perf_t result;
-//  result.cublas = tflops(cublas_ns);
-  result.triton = tflops(triton_ns);
-  delete da;
-  delete db;
-  delete dc;
-  return result;
-}
-
-int main() {
-  using triton::dnn::op_t;
-  using triton::dnn::layout_t;
-
-  struct config_t{
-    int32_t B;
-    int32_t C;
-    int32_t H;
-    int32_t W;
-    int32_t R;
-    int32_t S;
-    int32_t F;
-    int32_t stride_h;
-    int32_t stride_w;
-    op_t op;
-    layout_t layout;
-    std::string ty;
-
-    std::string repr() {
-      std::ostringstream oss;
-      oss << B << ", " << C << ", " << H << ", " << W << ", " << R << ", " << S << ", " << F << ", " << op << ", " << layout << ", " << ty;
-      return oss.str();
-    }
-
-    perf_t perf(triton::driver::stream *stream){
-      return do_bench(stream, R, S, B, F, H, W, C, op, layout, ty);
-    }
-  };
-  // shapes to benchmark
-  std::vector configs;
-  std::vector resnet18 =
-  {
-    {128, 128, 32, 32, 3, 3, 128, 1, 1},
-    {128, 128, 32, 32, 3, 3, 128, 1, 1},
-    {128, 128, 32, 32, 3, 3, 256, 2, 2},
-    {128, 256, 16, 16, 3, 3, 256, 1, 1},
-    {128, 256, 16, 16, 3, 3, 512, 2, 2},
-    {128, 512, 8, 8, 3, 3, 512, 1, 1},
-    {128, 512, 8, 8, 3, 3, 1024, 1, 1},
-    {128, 1024, 8, 8, 3, 3, 1024, 1, 1}
-  };
-  for(config_t c: resnet18){
-    for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
-      configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "half"});
-    }
-  }
-
-  // initialize default compute device
-  auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream *stream = triton::driver::stream::create(context);
-
-  for(config_t c: configs){
-    std::string repr = c.repr();
-    perf_t perf = c.perf(stream);
-    std::cout << "// " << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
-  }
-}
diff --git a/examples/cpp/shift.ptx b/examples/cpp/shift.ptx
deleted file mode 100644
index 62a841909..000000000
--- a/examples/cpp/shift.ptx
+++ /dev/null
@@ -1,93 +0,0 @@
-//
-// Generated by NVIDIA NVVM Compiler
-//
-// Compiler Build ID: CL-24817639
-// Cuda compilation tools, release 10.0, V10.0.130
-// Based on LLVM 3.4svn
-//
-
-.version 6.3
-.target sm_60
-.address_size 64
-
-  // .globl  _Z25shift_cuda_forward_kernelPKfPKiPfiiii
-
-.visible .entry shift(
-  .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0,
-  .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1,
-  .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2,
-  .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3,
-  .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4,
-  .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5,
-  .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6
-)
-{
-  .reg .pred %p<10>;
-  .reg .f32 %f<2>;
-  .reg .b32 %r<31>;
-  .reg .b64 %rd<13>;
-
-  ld.param.u64 %rd1, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0];
-  ld.param.u64 %rd3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1];
-  ld.param.u64 %rd2, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2];
-  ld.param.u32 %r3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3];
-  ld.param.u32 %r4, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4];
-  ld.param.u32 %r5, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5];
-  ld.param.u32 %r6, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6];
-  cvta.to.global.u64 %rd4, %rd3;
-  mov.u32 %r7, %ntid.x;
-  mov.u32 %r8, %ctaid.x;
-  mov.u32 %r9, %tid.x;
-  mad.lo.s32 %r1, %r7, %r8, %r9;
-  mul.lo.s32 %r10, %r4, %r3;
-  mul.lo.s32 %r11, %r10, %r5;
-  mul.lo.s32 %r12, %r11, %r6;
-  mul.lo.s32 %r13, %r5, %r4;
-  mul.lo.s32 %r14, %r13, %r6;
-  rem.s32 %r15, %r1, %r14;
-  sub.s32 %r16, %r1, %r15;
-  mul.lo.s32 %r17, %r6, %r5;
-  div.s32 %r18, %r15, %r17;
-  mul.lo.s32 %r19, %r18, %r17;
-  sub.s32 %r20, %r15, %r19;
-  div.s32 %r21, %r20, %r5;
-  mul.lo.s32 %r22, %r21, %r6;
-  sub.s32 %r23, %r20, %r22;
-  shl.b32 %r24, %r18, 1;
-  mul.wide.s32 %rd5, %r24, 4;
-  add.s64 %rd6, %rd4, %rd5;
-  ld.global.nc.u32 %r25, [%rd6];
-  add.s32 %r26, %r25, %r21;
-  ld.global.nc.u32 %r27, [%rd6+4];
-  add.s32 %r28, %r23, %r27;
-  add.s32 %r29, %r16, %r19;
-  mad.lo.s32 %r30, %r26, %r5, %r29;
-  add.s32 %r2, %r30, %r28;
-  setp.lt.s32 %p1, %r1, %r12;
-  setp.gt.s32 %p2, %r26, -1;
-  and.pred %p3, %p1, %p2;
-  setp.lt.s32 %p4, %r26, %r5;
-  and.pred %p5, %p3, %p4;
-  setp.gt.s32 %p6, %r28, -1;
-  and.pred %p7, %p5, %p6;
-  setp.lt.s32 %p8, %r28, %r6;
-  and.pred %p9, %p7, %p8;
-  @!%p9 bra BB0_2;
-  bra.uni BB0_1;
-
-BB0_1:
-  cvta.to.global.u64 %rd7, %rd1;
-  mul.wide.s32 %rd8, %r1, 4;
-  add.s64 %rd9, %rd7, %rd8;
-  ld.global.nc.f32 %f1, [%rd9];
-  cvta.to.global.u64 %rd10, %rd2;
-  mul.wide.s32 %rd11, %r2, 4;
-  add.s64 %rd12, %rd10, %rd11;
-  st.global.f32 [%rd12], %f1;
-
-BB0_2:
-  ret;
-}
-
diff --git a/include/triton/codegen/selection/selection.h b/include/triton/codegen/selection/selection.h
index 433633cff..3f118d47a 100644
--- a/include/triton/codegen/selection/selection.h
+++ b/include/triton/codegen/selection/selection.h
@@ -100,8 +100,8 @@ public:
 private:
   Value *ptr_;
   bool return_vector_;
-  Value *offset_;
   Builder &builder_;
+  Value *offset_;
   std::map ptr_cache_;
   unsigned vector_size_;
 };
@@ -206,9 +206,9 @@ private:
   tmap_t tmap_;
   analysis::shmem::allocation *alloc_;
   analysis::tune *params_;
-  target *tgt_;
   analysis::shmem::info *buffer_info_;
   analysis::alignment_info *alignment_;
+  target *tgt_;
   std::map axes_;
   Value *sh_mem_ptr_;
   Value *offset_a_i_, *offset_a_k_;
diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h
index bde5218b2..9cfc89931 100644
--- a/include/triton/ir/function.h
+++ b/include/triton/ir/function.h
@@ -47,11 +47,11 @@ public:
     return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
   }
 
-  const attribute_kind_t get_kind() const {
+  attribute_kind_t get_kind() const {
     return kind_;
   }
 
-  const unsigned get_value() const {
+  unsigned get_value() const {
     return value_;
   }
diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h
index 7724fdd61..a3574f15d 100644
--- a/include/triton/lang/expression.h
+++ b/include/triton/lang/expression.h
@@ -344,8 +344,8 @@ public:
   const expression *rvalue() const { return rvalue_; }
 
 public:
-  ASSIGN_OP_T op_;
   const expression *lvalue_;
+  ASSIGN_OP_T op_;
   const expression *rvalue_;
diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 2cbd65fd4..af849448b 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -76,8 +76,8 @@ private:
   void operator()(driver::stream *stream, const std::array& grid, const std::vector& args) const;
 private:
-  std::shared_ptr parent_;
   std::shared_ptr bin_;
+  std::shared_ptr parent_;
   std::vector param_tys_;
   size_t n_threads_;
 };
diff --git a/lib/codegen/analysis/alignment.cpp b/lib/codegen/analysis/alignment.cpp
index a602c87ca..6383ed850 100644
--- a/lib/codegen/analysis/alignment.cpp
+++ b/lib/codegen/analysis/alignment.cpp
@@ -227,7 +227,7 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
   if(auto *x = dynamic_cast(v)){
     return cache(x->get_first()->get_value());
   }
-  if(auto *x = dynamic_cast(v)){
+  if(dynamic_cast(v)){
     return cache(128);
   }
   if(auto *x = dynamic_cast(v)){
diff --git a/lib/codegen/analysis/shmem/info.cpp b/lib/codegen/analysis/shmem/info.cpp
index b674560bf..8f0dac32c 100644
--- a/lib/codegen/analysis/shmem/info.cpp
+++ b/lib/codegen/analysis/shmem/info.cpp
@@ -19,7 +19,7 @@ bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
   if(auto *br = dynamic_cast(terminator))
     return br->get_true_dest() == phi->get_parent()
            || br->get_false_dest() == phi->get_parent();
-  else if(auto *br = dynamic_cast(terminator))
+  else if(dynamic_cast(terminator))
     return false;
   else
     throw std::runtime_error("unreachable");
@@ -36,15 +36,15 @@ void info::replace(ir::value* before, ir::value *after) {
 }
 
 inline bool get_is_shared(ir::value* v) {
-  if(auto x = dynamic_cast(v))
+  if(dynamic_cast(v))
    return true;
-  if(auto x = dynamic_cast(v))
+  if(dynamic_cast(v))
    return true;
-  if(auto x = dynamic_cast(v))
+  if(dynamic_cast(v))
    return true;
-  if(auto x = dynamic_cast(v))
+  if(dynamic_cast(v))
    return true;
-  if(auto x = dynamic_cast(v)){
+  if(auto *x = dynamic_cast(v)){
     bool res = true;
     for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
       res = res && get_is_shared(x->get_incoming_value(inc));
diff --git a/lib/codegen/analysis/tune.cpp b/lib/codegen/analysis/tune.cpp
index ec67ef254..c43a7126b 100644
--- a/lib/codegen/analysis/tune.cpp
+++ b/lib/codegen/analysis/tune.cpp
@@ -58,7 +58,7 @@ void tune::init_c_graph(ir::instruction *v) {
     shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
   else if(auto *atom = dynamic_cast(v))
     shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
-  else if(auto *downcast = dynamic_cast(v))
+  else if(dynamic_cast(v))
     return;
   else if(auto *reduce = dynamic_cast(v)) {
     unsigned axis = reduce->get_axis();
@@ -116,7 +116,7 @@
   }
   }
   // Matrix multiplication
-  else if(auto *x = dynamic_cast(v)){
+  else if(dynamic_cast(v)){
     ir::value *A = v->get_operand(0);
     ir::value *B = v->get_operand(1);
     ir::value *D = v->get_operand(2);
@@ -166,7 +166,7 @@ void tune::connected_components(node_t x, const std::vector
   if(nodes.find(x) != nodes.end()){
     nodes.erase(x);
     std::string suffix = ".d" + std::to_string(x.second);
-    for(int i = 0; i < mps.size(); i++)
+    for(unsigned i = 0; i < mps.size(); i++)
      params_[x.first].insert({prefixes[i] + suffix, mps[i]});
     ir::type *ty = x.first->get_type();
     if(ty->is_tile_ty()){
@@ -254,24 +254,24 @@ void tune::init(ir::module &mod) {
     create_grids(grids_, references, fn);
   }
 
-  int num_threads = get_num_threads();
-  auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
+  unsigned num_threads = get_num_threads();
+  auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); };
 
   for(ir::value *i: grids_){
    if(!i->get_type()->is_tile_ty())
      continue;
    auto shapes = i->get_type()->get_tile_shapes();
-   int shape_0 = shapes[0]->get_value();
-   int shape_1 = shapes[1]->get_value();
-   int size = i->get_type()->get_tile_num_elements();
+   unsigned shape_0 = shapes[0]->get_value();
+   unsigned shape_1 = shapes[1]->get_value();
+   unsigned size = i->get_type()->get_tile_num_elements();
 
    /* HMMA parameters*/
    if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
    /* fragments per warp */
    // try to make things as square as possible to maximize data re-use
-   std::vector fpw = {1, 1, 1};
-   std::vector fpw_nm1;
-   int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
+   std::vector fpw = {1, 1, 1};
+   std::vector fpw_nm1;
+   unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
    do {
      fpw_nm1 = fpw;
     if(fpw[0]*fpw[1] < num_fragments)
@@ -280,13 +280,13 @@ void tune::init(ir::module &mod) {
       fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
    }while(fpw_nm1 != fpw);
    // store parameters
-   for(int d = 0; d < shapes.size(); d++)
+   for(unsigned d = 0; d < shapes.size(); d++)
      params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
 
    /* warps per tile */
    // try to make things as square as possible to maximize data re-use
-   std::vector wpt = {1, 1, 1};
-   std::vector wpt_nm1;
+   std::vector wpt = {1, 1, 1};
+   std::vector wpt_nm1;
    do{
      wpt_nm1 = wpt;
      if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
@@ -295,7 +295,7 @@
       wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
    }while(wpt_nm1 != wpt);
    // store parameters
-   for(int d = 0; d < shapes.size(); d++)
+   for(unsigned d = 0; d < shapes.size(); d++)
      params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
 
    /* sanity check */
@@ -309,8 +309,8 @@
 
    /* Scan-line */
    else{
-     int shape = shapes[0]->get_value();
-     int current = num_threads;
+     unsigned shape = shapes[0]->get_value();
+     unsigned current = num_threads;
      params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
      params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
      current = current / params_.at(i).at("mts.d0")->get_value();
diff --git a/lib/codegen/selection/selection.cpp b/lib/codegen/selection/selection.cpp
index 166b423bb..4b31dce52 100644
--- a/lib/codegen/selection/selection.cpp
+++ b/lib/codegen/selection/selection.cpp
@@ -226,6 +226,7 @@ llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
   case ttop::Or: return llop::Or;
   case ttop::Xor: return llop::Xor;
   }
+  throw std::runtime_error("unknown operator");
 }
 
 llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) {
@@ -246,6 +247,7 @@
   case ttop::BitCast: return llop::BitCast;
   case ttop::AddrSpaceCast: return llop::AddrSpaceCast;
   }
+  throw std::runtime_error("unknown operator");
 }
 
 llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
@@ -283,6 +285,7 @@
   case ttop::ICMP_SLE: return llop::ICMP_SLE;
   case ttop::LAST_ICMP_PREDICATE: return llop::LAST_ICMP_PREDICATE;
   }
+  throw std::runtime_error("unknown operator");
 }
 
 /* convert ir::type to Type */
@@ -468,7 +471,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function(inst)){
     Value *ptr = value(ii->get_operand(0));
     Value *val = value(ii->get_operand(1));
-    Value *atom_f_add;
+    Value *atom_f_add = nullptr;
     if(val->getType()->isFloatTy())
       atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
    else if(val->getType()->isHalfTy()){
@@ -477,6 +480,8 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::functiongetPointerTo(), fp16}, false);
       atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
     }
+    if(atom_f_add == nullptr)
+      throw std::runtime_error("unsupported atomic add");
     Value *res = builder.CreateCall(atom_f_add, {ptr, val});
     return (Instruction*)res;
   }
@@ -607,7 +612,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   Value *_2 = builder.getInt32(2);
   Value *_3 = builder.getInt32(3);
   Value *_4 = builder.getInt32(4);
-  Value *_8 = builder.getInt32(8);
   Value *_16 = builder.getInt32(16);
 
   // fragments per warp
@@ -1303,11 +1307,10 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
     unsigned id = linear / vector_size;
     if(linear % vector_size == 0) {
       Value *ptr = pointers->get_value(idx);
-      ConstantInt *cst = nullptr;
-      if(GetElementPtrInst *gep = dyn_cast(ptr))
-        if(gep->getNumIndices() == 1){
-          cst = dyn_cast(gep->idx_begin());
-        }
+//      ConstantInt *cst = nullptr;
+//      if(GetElementPtrInst *gep = dyn_cast(ptr))
+//        if(gep->getNumIndices() == 1)
+//          cst = dyn_cast(gep->idx_begin());
       ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
                                                         ptr->getType()->getPointerAddressSpace()));
@@ -1374,10 +1377,6 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
     unsigned id = linear / vector_size;
     if(linear % vector_size == 0) {
       Value *ptr = pointers->get_value(idx);
-      ConstantInt *cst = nullptr;
-      if(GetElementPtrInst *gep = dyn_cast(ptr))
-        if(gep->getNumIndices() == 1)
-          cst = dyn_cast(gep->idx_begin());
       ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
                                                         ptr->getType()->getPointerAddressSpace()));
       packets[id] = builder.CreateLoad(ptr);
diff --git a/lib/codegen/transform/peephole.cpp b/lib/codegen/transform/peephole.cpp
index d5d678628..73885c772 100644
--- a/lib/codegen/transform/peephole.cpp
+++ b/lib/codegen/transform/peephole.cpp
@@ -60,6 +60,7 @@ ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
     trans->set_operand(0, i);
     return trans;
   }
+  return nullptr;
 }
 
 bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
@@ -76,6 +77,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
   if(!phi)
     return false;
   ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
+  if(!new_phi)
+    return false;
   trans->replace_all_uses_with(new_phi);
   return true;
diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp
index 4b06af60e..5ace19a04 100644
--- a/lib/ir/constant.cpp
+++ b/lib/ir/constant.cpp
@@ -67,8 +67,7 @@ constant_range::constant_range(type *ty, constant_int *first, constant_int *last
 constant *constant_range::get(constant_int *first, constant_int *last) {
   assert(first->get_type()->is_integer_ty());
   assert(first->get_type() == last->get_type());
-  unsigned vfirst = ((constant_int*)first)->get_value();
-  assert(vfirst == 0);
+  assert(((constant_int*)first)->get_value() == 0);
   type *ty = tile_type::get(first->get_type(), {last});
   return new constant_range(ty, first, last);
 }
diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp
index 85b6eee5c..bd06668e6 100644
--- a/lib/ir/instructions.cpp
+++ b/lib/ir/instructions.cpp
@@ -359,8 +359,11 @@ getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::
   : instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next),
     source_elt_ty(pointee_ty),
     res_elt_ty(get_indexed_type(pointee_ty, idx)){
-  type *expected_ty = ((pointer_type*)(get_type()->get_scalar_ty()))->get_element_ty();
+  // sanity check
+  type *expected_ty = get_type()->get_scalar_ty();
+  expected_ty = ((pointer_type*)expected_ty)->get_element_ty();
   assert(res_elt_ty == expected_ty);
+  // set operands
   set_operand(0, ptr);
   for(size_t i = 0; i < idx.size(); i++)
     set_operand(1 + i, idx[i]);
@@ -574,7 +577,7 @@ ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector perm)
   // permutate argument shapes
   perm = init_perm(ty, perm);
   ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
-  for(int i = 0; i < perm.size(); i++)
+  for(size_t i = 0; i < perm.size(); i++)
     res_shapes[i] = arg_shapes[perm[i]->get_value()];
   // construct type
   return tile_type::get(ty->get_scalar_ty(), res_shapes);
@@ -587,16 +590,17 @@ std::vector trans_inst::init_perm(ir::type* ty, const std::vector
   ir::type* int32_ty = type::get_int32_ty(ty->get_context());
   std::vector result;
   result.push_back(ir::constant_int::get(int32_ty, size - 1));
-  for(int i = 0; i < size - 1; i++)
+  for(size_t i = 0; i < size - 1; i++)
     result.push_back(ir::constant_int::get(int32_ty, i));
   return result;
 }
 
 trans_inst::trans_inst(value *arg, const std::vector& perm, const std::string &name, instruction *next)
   : builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
+  // sanity check
   perm_ = init_perm(arg->get_type(), perm);
-  auto size = arg->get_type()->get_tile_shapes().size();
-  assert(perm_.size() == size);
+  //auto size = arg->get_type()->get_tile_shapes().size();
+  //assert(perm_.size() == size);
   set_operand(0, arg);
 }
diff --git a/lib/ir/module.cpp b/lib/ir/module.cpp
index 7adcbb14a..3d995558e 100644
--- a/lib/ir/module.cpp
+++ b/lib/ir/module.cpp
@@ -96,8 +96,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
   bool is_const = const_.find(name) != const_.end();
   auto &preds = block->get_predecessors();
   ir::type *ty = get_scope().types.at(name);
-  if(block)
-  if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
+  if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
     incomplete_phis_[block][name] = make_phi(ty, 1, block);
     result = (ir::value*)incomplete_phis_[block][name];
   }
@@ -106,9 +105,9 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
     result = get_value(name, has_pred?preds.front():nullptr);
   }
   else{
-    result = make_phi(ty, 1, block);
-    set_value(name, block, result);
-    result = add_phi_operands(name, (ir::phi_node*&)result);
+    ir::phi_node* phi = make_phi(ty, 1, block);
+    set_value(name, block, phi);
+    result = add_phi_operands(name, phi);
   }
   if(auto *phi = dynamic_cast(result))
     result = try_remove_trivial_phis(phi);
diff --git a/lib/lang/node.cpp b/lib/lang/node.cpp
index 29d61cdb8..dda7126bd 100644
--- a/lib/lang/node.cpp
+++ b/lib/lang/node.cpp
@@ -106,7 +106,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
   size_t res_size = std::max(lhs_size, rhs_size);
   ir::type::tile_shapes_t res_shapes(res_size);
   ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
-  for(int i = 0; i < res_size; i++){
+  for(size_t i = 0; i < res_size; i++){
    if(i >= res_size - lhs_size && i >= res_size - rhs_size)
      res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
    else if(i >= res_size - lhs_size)
@@ -147,7 +147,7 @@ void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
   int src_dim = src_shapes.size();
   // Pad
   int off = dst_dim - src_dim;
-  for(size_t i = 0; i < off; i++)
+  for(int i = 0; i < off; i++)
    src_shapes.insert(src_shapes.begin(), one);
   if(off > 0)
    src = builder.create_reshape(src, src_shapes);
diff --git a/lib/runtime/function.cpp b/lib/runtime/function.cpp
index 034738c93..d69049291 100644
--- a/lib/runtime/function.cpp
+++ b/lib/runtime/function.cpp
@@ -88,10 +88,10 @@ arg_type convert(ir::type *ty) {
 }
 
 function::caller::caller(ir::function *ir, std::shared_ptr parent, size_t n_threads)
-  : bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), n_threads_(n_threads), parent_(parent) {
+  : bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), parent_(parent), n_threads_(n_threads) {
   // extract signature
   ir::function_type* ty = ir->get_fn_type();
-  for(int i = 0; i < ty->get_num_params(); i++)
+  for(size_t i = 0; i < ty->get_num_params(); i++)
     param_tys_.push_back(convert(ty->get_param_ty(i)));
 }
 
diff --git a/python/examples/dot.py b/python/examples/dot.py
index 75fe931bc..638d49c20 100644
--- a/python/examples/dot.py
+++ b/python/examples/dot.py
@@ -11,7 +11,8 @@ void matmul(restrict read_only align(16) half *A,
             restrict read_only align(16) half *B,
             restrict read_only align(16) half *C,
             int M, int N, int K,
-            multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
+            multiple_of(8) int lda, multiple_of(8) int ldb, int ldc)
+{
   int ridx = get_program_id(0);
   int ridy = get_program_id(1);
   int rxa[TM] = ridx * TM + (0 ... TM);
diff --git a/python/triton/ops.py b/python/triton/ops.py
index ea782ad08..a10739903 100644
--- a/python/triton/ops.py
+++ b/python/triton/ops.py
@@ -17,8 +17,8 @@ import tensorflow as tf
 
 extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
 
-def make_bindings(src, outputs, grids):
-  return libtriton.make_tensorflow_src(src, outputs, grids)
+def make_bindings(src, out, grid):
+  return libtriton.make_tensorflow_src(src, out, grid)
 
 def make_cache_path(src):
   md5 = hashlib.sha1(src.encode())