This PR adds an automatic memory-alignment mechanism to the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument, as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. The corresponding `.aligned` and `.multipleof` attributes are then added to the Triton-IR on the fly for all auto-tunable kernels. A cache retains the kernel compiled for each possible configuration. This PR also includes a substantial cleanup of the Python API. The change adds 2–3 µs of overhead, mostly due to accessing integer `#define`s from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this approach is preferred for now.
42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
import argparse
|
|
import sys
|
|
import os
|
|
import inspect
|
|
import triton
|
|
|
|
def run_all(result_dir, with_plots, names):
    """Run every benchmark module located next to this script.

    A file in this script's directory is treated as a benchmark module when
    its name ends in ``.py``, starts with ``bench_`` and, if *names* is
    non-empty, contains *names* as a substring. Each module-level
    ``triton.testing.Mark`` object in such a module is run.

    Results go to ``<result_dir>/<module name without 'bench_'>``. When a
    module defines more than one Mark, each Mark gets its own subdirectory
    named after its attribute (again without ``'bench_'``).

    Args:
        result_dir: Root directory for results; created if it does not exist.
        with_plots: Passed positionally to each Mark's ``run`` method.
        names: Substring filter on module file names; ``''`` runs everything.
    """
    os.makedirs(result_dir, exist_ok=True)
    bench_dir = os.path.dirname(os.path.realpath(__file__))
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for fname in sorted(os.listdir(bench_dir)):
        # skip non python files
        if not fname.endswith('.py'):
            continue
        # skip files not matching the provided name filter
        if names and names not in fname:
            continue
        # skip files that don't start with 'bench_'
        if not fname.startswith('bench_'):
            continue
        print(f'running {fname}...')
        # NOTE(review): importing by bare module name presumably relies on
        # bench_dir being on sys.path (true when this file is run as a script).
        module = __import__(os.path.splitext(fname)[0])
        benchmarks = inspect.getmembers(module, lambda x: isinstance(x, triton.testing.Mark))
        module_dir = os.path.join(result_dir, module.__name__.replace('bench_', ''))
        for name, bench in benchmarks:
            curr_dir = module_dir
            # Disambiguate output only when several Marks share one module.
            if len(benchmarks) > 1:
                curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
            os.makedirs(curr_dir, exist_ok=True)
            bench.run(curr_dir, with_plots)
|
|
|
|
def main(args=None):
    """Parse command-line options and run the benchmark suite.

    Args:
        args: Command-line arguments, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Run the benchmark suite.")
    parser.add_argument("-r", "--result-dir", type=str, default='results', required=False,
                        help="directory that receives benchmark results (default: %(default)s)")
    parser.add_argument("-n", "--names", type=str, default='', required=False,
                        help="only run benchmark files whose name contains this substring")
    parser.add_argument("-p", "--with-plots", dest='with_plots', action='store_true',
                        help="forward with_plots=True to each benchmark's run() call")
    # The former `parser.set_defaults(feature=False)` set an attribute that was
    # never read, so it has been dropped.
    opts = parser.parse_args(args)
    run_all(opts.result_dir, opts.with_plots, opts.names)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: hand the CLI arguments (without the program name) to main().
    main(args=sys.argv[1:])
|