From d5eaa8dfa0fc8a941a09da8d00b7d4c48b8ebd87 Mon Sep 17 00:00:00 2001 From: daadaada Date: Tue, 24 May 2022 23:56:36 +0800 Subject: [PATCH] Making the generated Triton IR deterministic & a script to compare cached assembly (#522) --- include/triton/codegen/analysis/liveness.h | 12 ++-- include/triton/tools/graph.h | 15 +++-- python/triton/tools/compare_asm.py | 76 ++++++++++++++++++++++ 3 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 python/triton/tools/compare_asm.py diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index a95d62a06..12232b654 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -1,12 +1,14 @@ #ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H #define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H -#include -#include -#include #include "triton/codegen/analysis/layout.h" #include "triton/tools/graph.h" +#include "llvm/ADT/MapVector.h" + +#include +#include + namespace triton{ namespace ir{ @@ -42,14 +44,14 @@ struct segment { class liveness { private: - typedef std::map intervals_map_t; + typedef llvm::MapVector intervals_map_t; public: // constructor liveness(layouts *l): layouts_(l){ } // accessors const intervals_map_t& get() const { return intervals_; } - segment get(shared_layout* v) const { return intervals_.at(v); } + segment get(shared_layout* v) const { return intervals_.lookup(v); } // run void run(ir::module &mod); diff --git a/include/triton/tools/graph.h b/include/triton/tools/graph.h index c2ba8d854..69afd5bb3 100644 --- a/include/triton/tools/graph.h +++ b/include/triton/tools/graph.h @@ -3,8 +3,9 @@ #ifndef _TRITON_TOOLS_THREAD_GRAPH_H_ #define _TRITON_TOOLS_THREAD_GRAPH_H_ +#include "llvm/ADT/SetVector.h" + #include -#include #include #include @@ -13,21 +14,21 @@ namespace tools{ template class graph { - typedef std::map> edges_t; + typedef std::map> edges_t; public: typedef std::map> cmap_t; typedef std::map nmap_t; private: - 
void connected_components_impl(node_t x, std::set &nodes, + void connected_components_impl(node_t x, llvm::SetVector &nodes, nmap_t* nmap, cmap_t* cmap, int id) const { if(nmap) (*nmap)[x] = id; if(cmap) (*cmap)[id].push_back(x); - if(nodes.find(x) != nodes.end()) { - nodes.erase(x); + if (nodes.count(x)) { + nodes.remove(x); for(const node_t &y: edges_.at(x)) connected_components_impl(y, nodes, nmap, cmap, id); } @@ -39,7 +40,7 @@ public: cmap->clear(); if(nmap) nmap->clear(); - std::set nodes = nodes_; + llvm::SetVector nodes = nodes_; unsigned id = 0; while(!nodes.empty()){ connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); @@ -59,7 +60,7 @@ public: } private: - std::set nodes_; + llvm::SetVector nodes_; edges_t edges_; }; diff --git a/python/triton/tools/compare_asm.py b/python/triton/tools/compare_asm.py new file mode 100644 index 000000000..e612022bd --- /dev/null +++ b/python/triton/tools/compare_asm.py @@ -0,0 +1,76 @@ +''' +Compare cached triton kernels in 2 directories. 
+
+example:
+python compare_asm.py --dir0=triton-works/ --dir1=triton-fails/ --asm=ttir \
+       --diff-out0=diff-works.ll --diff-out1=diff-fails.ll
+'''
+import argparse
+import os
+import pickle
+
+parser = argparse.ArgumentParser(description="unpickle")
+parser.add_argument('--dir0', dest='dir0', required=True,
+                    help="Triton cache dir 0")
+parser.add_argument('--dir1', dest='dir1', required=True,
+                    help="Triton cache dir 1")
+parser.add_argument('--asm', dest='asm',
+                    choices=['ttir', 'llir', 'ptx', 'cubin'], required=True)
+parser.add_argument('--early-stop', dest='early_stop', action='store_true',
+                    help="Stop after first diff")
+parser.set_defaults(early_stop=True)  # NOTE(review): store_true + True default => flag can never be turned off
+parser.add_argument('--diff-out0', dest='diff_out0', required=True,
+                    help="output file path for kernels in dir0")
+parser.add_argument('--diff-out1', dest='diff_out1', required=True,
+                    help="output file path for kernels in dir1")
+args = parser.parse_args()
+dir0 = args.dir0
+dir1 = args.dir1
+asm = args.asm
+
+dir0_files = {}  # "<kernel name><last 3 key fields>" -> asm mapping of that kernel, for dir0
+dir1_files = {}  # same mapping, for dir1
+for root, _, files in os.walk(dir0):
+    for file in files:
+        if not file.endswith('.lock'):  # every non-.lock file is treated as a pickled cache entry
+            path = os.path.join(root, file)
+            with open(path, 'rb') as f:
+                loaded_file = pickle.load(f)  # NOTE: unpickling can run arbitrary code; trusted dirs only
+                bin = loaded_file['binary']
+                key = loaded_file['key']
+                info = key.split('-')[-3:]  # num_warps, num_stages, signature
+                dict_key = bin.name + '-'.join(info)
+                dir0_files[dict_key] = bin.asm
+
+for root, _, files in os.walk(dir1):
+    for file in files:
+        if not file.endswith('.lock'):  # every non-.lock file is treated as a pickled cache entry
+            path = os.path.join(root, file)
+            with open(path, 'rb') as f:
+                loaded_file = pickle.load(f)  # NOTE: unpickling can run arbitrary code; trusted dirs only
+                bin = loaded_file['binary']
+                key = loaded_file['key']
+                info = key.split('-')[-3:]  # num_warps, num_stages, signature
+                dict_key = bin.name + '-'.join(info)
+                dir1_files[dict_key] = bin.asm
+
+diff_keys = []  # keys present in both dirs whose selected asm differs
+for key in dir0_files:
+    asm0 = dir0_files[key]
+    if key not in dir1_files:
+        continue  # kernels that exist only in dir0 are not reported
+    asm1 = dir1_files[key]
+    if asm0[asm] != asm1[asm]:
+        diff_keys.append(key)
+
+if args.early_stop:  # attribute name must match dest='early_stop' declared above
+    diff_keys = diff_keys[:1]
+if diff_keys:  # output files are only created when at least one mismatch exists
+    with open(args.diff_out0, 'w') as f0, open(args.diff_out1, 'w') as f1:
+        for key in diff_keys:
+            f0.write(f'{asm} mismatch at {key}\n')  # newline keeps the header off the first asm line
+            f0.write(dir0_files[key][asm])
+            f0.write('\n')
+            f1.write(f'{asm} mismatch at {key}\n')  # newline keeps the header off the first asm line
+            f1.write(dir1_files[key][asm])
+            f1.write('\n')