Making the generated Triton IR deterministic & a script to compare cached assembly (#522)
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
@@ -42,14 +44,14 @@ struct segment {
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||
typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(layouts *l): layouts_(l){ }
|
||||
// accessors
|
||||
const intervals_map_t& get() const { return intervals_; }
|
||||
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||
segment get(shared_layout* v) const { return intervals_.lookup(v); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
|
@@ -3,8 +3,9 @@
|
||||
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
@@ -13,21 +14,21 @@ namespace tools{
|
||||
|
||||
template<class node_t>
|
||||
class graph {
|
||||
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||
typedef std::map<node_t, llvm::SetVector<node_t>> edges_t;
|
||||
|
||||
public:
|
||||
typedef std::map<size_t, std::vector<node_t>> cmap_t;
|
||||
typedef std::map<node_t, size_t> nmap_t;
|
||||
|
||||
private:
|
||||
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||
void connected_components_impl(node_t x, llvm::SetVector<node_t> &nodes,
|
||||
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||
if(nmap)
|
||||
(*nmap)[x] = id;
|
||||
if(cmap)
|
||||
(*cmap)[id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()) {
|
||||
nodes.erase(x);
|
||||
if (nodes.count(x)) {
|
||||
nodes.remove(x);
|
||||
for(const node_t &y: edges_.at(x))
|
||||
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||
}
|
||||
@@ -39,7 +40,7 @@ public:
|
||||
cmap->clear();
|
||||
if(nmap)
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
llvm::SetVector<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty()){
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
@@ -59,7 +60,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<node_t> nodes_;
|
||||
llvm::SetVector<node_t> nodes_;
|
||||
edges_t edges_;
|
||||
};
|
||||
|
||||
|
76
python/triton/tools/compare_asm.py
Normal file
76
python/triton/tools/compare_asm.py
Normal file
@@ -0,0 +1,76 @@
|
||||
'''
|
||||
Compare cached triton kernels in 2 directories.
|
||||
|
||||
example:
|
||||
python compare_asm.py --dir0=triton-works/ --dir1=triton-fails/ --asm=ttir \
|
||||
--diff-out0=diff-works.ll --diff-out1=diff-fails.ll
|
||||
'''
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
parser = argparse.ArgumentParser(description="unpickle")
|
||||
parser.add_argument('--dir0', dest='dir0', required=True,
|
||||
help="Triton cache dir 0")
|
||||
parser.add_argument('--dir1', dest='dir1', required=True,
|
||||
help="Triton cache dir 1")
|
||||
parser.add_argument('--asm', dest='asm',
|
||||
choices=['ttir', 'llir', 'ptx', 'cubin'], required=True)
|
||||
parser.add_argument('--early-stop', dest='early_stop', action='store_true',
|
||||
help="Stop after first diff")
|
||||
parser.set_defaults(early_stop=True)
|
||||
parser.add_argument('--diff-out0', dest='diff_out0', required=True,
|
||||
help="output file path for kernels in dir0")
|
||||
parser.add_argument('--diff-out1', dest='diff_out1', required=True,
|
||||
help="output file path for kernels in dir1")
|
||||
args = parser.parse_args()
|
||||
dir0 = args.dir0
|
||||
dir1 = args.dir1
|
||||
asm = args.asm
|
||||
|
||||
dir0_files = {}
|
||||
dir1_files = {}
|
||||
for root, _, files in os.walk(dir0):
|
||||
for file in files:
|
||||
if not file.endswith('.lock'):
|
||||
path = os.path.join(root, file)
|
||||
with open(path, 'rb') as f:
|
||||
loaded_file = pickle.load(f)
|
||||
bin = loaded_file['binary']
|
||||
key = loaded_file['key']
|
||||
info = key.split('-')[-3:] # num_warps, num_stages, signature
|
||||
dict_key = bin.name + '-'.join(info)
|
||||
dir0_files[dict_key] = bin.asm
|
||||
|
||||
for root, _, files in os.walk(dir1):
|
||||
for file in files:
|
||||
if not file.endswith('.lock'):
|
||||
path = os.path.join(root, file)
|
||||
with open(path, 'rb') as f:
|
||||
loaded_file = pickle.load(f)
|
||||
bin = loaded_file['binary']
|
||||
key = loaded_file['key']
|
||||
info = key.split('-')[-3:] # num_warps, num_stages, signature
|
||||
dict_key = bin.name + '-'.join(info)
|
||||
dir1_files[dict_key] = bin.asm
|
||||
|
||||
diff_keys = []
|
||||
for key in dir0_files:
|
||||
asm0 = dir0_files[key]
|
||||
if key not in dir1_files:
|
||||
continue
|
||||
asm1 = dir1_files[key]
|
||||
if asm0[asm] != asm1[asm]:
|
||||
diff_keys.append(key)
|
||||
|
||||
if args.early_stops:
|
||||
diff_keys = diff_keys[:1]
|
||||
if diff_keys:
|
||||
with open(args.diff_out0, 'w') as f0, open(args.diff_out1, 'w') as f1:
|
||||
for key in diff_keys:
|
||||
f0.write(f'{asm} mismatch at {key}')
|
||||
f0.write(dir0_files[key][asm])
|
||||
f0.write('\n')
|
||||
f1.write(f'{asm} mismatch at {key}')
|
||||
f1.write(dir1_files[key][asm])
|
||||
f1.write('\n')
|
Reference in New Issue
Block a user