diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 000000000..bdb9e1ce7
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,45 @@
+cmake_minimum_required(VERSION 2.8)
+project(triton)
+include(CTest)
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
+
+# Options
+option(BUILD_TESTS "Build C++ Triton tests" ON)
+option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
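+# Example configure invocation (illustrative; LLVM_ROOT_DIR is the optional
+# hint honored by cmake/FindLLVM.cmake):
+#   cmake -DBUILD_PYTHON_MODULE=ON -DLLVM_ROOT_DIR=/path/to/llvm ..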
+
+# LLVM
+find_package(LLVM REQUIRED)
+include_directories(${LLVM_INCLUDE_DIRS})
+add_definitions(${LLVM_DEFINITIONS})
+
+# Default build type
+if(NOT CMAKE_BUILD_TYPE)
+ message(STATUS "Default build type: Release")
+ set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+# Compiler flags
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+
+# Tests
+if(BUILD_TESTS)
+ message(STATUS "Adding C++ tests")
+ add_subdirectory(tests)
+endif()
+
+# Python module
+if(BUILD_PYTHON_MODULE)
+ message(STATUS "Adding Python module")
+ # PyBind11 wrapper source file
+ file(GLOB_RECURSE PYTHON_SRC python/src/bindings.cc)
+ include_directories(python/src/ ${PYTHON_INCLUDE_DIRS})
+endif()
+
+
+# Triton
+file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
+link_directories(${LLVM_LIBRARY_DIRS})
+add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
+target_link_libraries(triton ${LLVM_LIBRARIES})
+
diff --git a/LICENSE b/LICENSE
new file mode 100755
index 000000000..464fb143d
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,26 @@
+/* Copyright 2018-2019 Philippe Tillet
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files
+* (the "Software"), to deal in the Software without restriction,
+* including without limitation the rights to use, copy, modify, merge,
+* publish, distribute, sublicense, and/or sell copies of the Software,
+* and to permit persons to whom the Software is furnished to do so,
+* subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be
+* included in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+*/
+
+
+// The compiler front-end is based on a modified version of WGTCC
+// https://github.com/wgtdkp/wgtcc
+// Copyright (c) 2016 wgtdkp
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..8a5fe3e98
--- /dev/null
+++ b/README.md
@@ -0,0 +1,36 @@
+# Triton
+
+This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives.
+
+The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please cite us if you use our work!
+
+
+The main features of Triton at the moment are:
+- **PyTriton**: A Python API for writing custom operations for Triton-C compute-kernels. PyTriton automatically generates and just-in-time compiles Tensorflow and PyTorch bindings.
+- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level using numpy-like extensions of the C language.
+- **Triton-IR**: An intermediate-representation for optimizing multi-dimensional array operations in linear algebra programs.
+- **Triton-JIT**: An optimizing just-in-time compiler for Triton-C, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)) and PTX (e.g., [ISAAC](https://github.com/ptillet/isaac)). This includes transparent support for mixed-precision and Tensor Cores.
+
+
+
+
+## Installation
+
+Triton is a fairly self-contained package: it uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and an LLVM-based code-generator. However, it currently still requires LLVM 8.0+ for PTX code generation.
+
+```
+sudo apt-get install llvm-8-dev
+git clone https://github.com/ptillet/triton.git;
+cd triton/python/;
+python setup.py develop;
+cd examples;
+python dot.py
+```
+
+## Tutorials
+
+- [The PyTriton API](https://github.com/ptillet/triton/blob/master/docs/pytriton.md)
+- [The Triton-C language](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
+- The Triton-IR representation (coming soon...)
+- The Triton-JIT compiler (coming soon...)
+
diff --git a/cmake/FindLLVM.cmake b/cmake/FindLLVM.cmake
new file mode 100644
index 000000000..30ebcbd89
--- /dev/null
+++ b/cmake/FindLLVM.cmake
@@ -0,0 +1,166 @@
+# - Find LLVM headers and libraries.
+# This module locates LLVM and adapts the llvm-config output for use with
+# CMake.
+#
+# A given list of COMPONENTS is passed to llvm-config.
+#
+# The following variables are defined:
+# LLVM_FOUND - true if LLVM was found
+# LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers.
+# LLVM_HOST_TARGET - Target triple used to configure LLVM.
+# LLVM_INCLUDE_DIRS - Directory containing LLVM include files.
+# LLVM_LDFLAGS - Linker flags to add when linking against LLVM
+# (includes -LLLVM_LIBRARY_DIRS).
+# LLVM_LIBRARIES - Full paths to the library files to link against.
+# LLVM_LIBRARY_DIRS - Directory containing LLVM libraries.
+# LLVM_ROOT_DIR - The root directory of the LLVM installation.
+# llvm-config is searched for in ${LLVM_ROOT_DIR}/bin.
+# LLVM_VERSION_MAJOR - Major version of LLVM.
+# LLVM_VERSION_MINOR - Minor version of LLVM.
+# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
+# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
+#
+# Note: The variable names were chosen in conformance with the official CMake
+# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
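+#
+# Example usage (illustrative; "mytool" is a placeholder target):
+#   list(APPEND CMAKE_MODULE_PATH "/path/to/the/directory/containing/this/file")
+#   find_package(LLVM REQUIRED)
+#   include_directories(${LLVM_INCLUDE_DIRS})
+#   target_link_libraries(mytool ${LLVM_LIBRARIES})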
+
+# Try suffixed versions to pick up the newest LLVM install available on Debian
+# derivatives.
+# We also want a user-specified LLVM_ROOT_DIR to take precedence over the
+# system default locations such as /usr/local/bin. Executing find_program()
+# multiple times is the approach recommended in the docs.
+set(llvm_config_names llvm-config-9 llvm-config-9.0 llvm-config90
+ llvm-config-8 llvm-config-8.0 llvm-config80
+ llvm-config)
+find_program(LLVM_CONFIG
+ NAMES ${llvm_config_names}
+ PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH
+ DOC "Path to llvm-config tool.")
+find_program(LLVM_CONFIG NAMES ${llvm_config_names})
+
+# Prints a warning/failure message depending on the required/quiet flags. Copied
+# from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed.
+macro(_LLVM_FAIL _msg)
+ if(LLVM_FIND_REQUIRED)
+ message(FATAL_ERROR "${_msg}")
+ else()
+ if(NOT LLVM_FIND_QUIETLY)
+ message(STATUS "${_msg}")
+ endif()
+ endif()
+endmacro()
+
+
+if(NOT LLVM_CONFIG)
+ if(NOT LLVM_FIND_QUIETLY)
+ message(WARNING "Could not find llvm-config (LLVM >= ${LLVM_FIND_VERSION}). Try manually setting LLVM_CONFIG to the llvm-config executable of the installation to use.")
+ endif()
+else()
+ macro(llvm_set var flag)
+ if(LLVM_FIND_QUIETLY)
+ set(_quiet_arg ERROR_QUIET)
+ endif()
+ set(result_code)
+ execute_process(
+ COMMAND ${LLVM_CONFIG} --${flag}
+ RESULT_VARIABLE result_code
+ OUTPUT_VARIABLE LLVM_${var}
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ${_quiet_arg}
+ )
+ if(result_code)
+ _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code}')")
+ else()
+ if(${ARGV2})
+ file(TO_CMAKE_PATH "${LLVM_${var}}" LLVM_${var})
+ endif()
+ endif()
+ endmacro()
+ macro(llvm_set_libs var flag)
+ if(LLVM_FIND_QUIETLY)
+ set(_quiet_arg ERROR_QUIET)
+ endif()
+ set(result_code)
+ execute_process(
+ COMMAND ${LLVM_CONFIG} --${flag} ${LLVM_FIND_COMPONENTS}
+ RESULT_VARIABLE result_code
+ OUTPUT_VARIABLE tmplibs
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ${_quiet_arg}
+ )
+ if(result_code)
+ _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code}')")
+ else()
+ file(TO_CMAKE_PATH "${tmplibs}" tmplibs)
+ string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs})
+ endif()
+ endmacro()
+
+ llvm_set(VERSION_STRING version)
+ llvm_set(CXXFLAGS cxxflags)
+ llvm_set(HOST_TARGET host-target)
+ llvm_set(INCLUDE_DIRS includedir true)
+ llvm_set(ROOT_DIR prefix true)
+ llvm_set(ENABLE_ASSERTIONS assertion-mode)
+
+ # The LLVM version string _may_ contain a git/svn suffix, so cut that off
+ string(SUBSTRING "${LLVM_VERSION_STRING}" 0 5 LLVM_VERSION_BASE_STRING)
+
+ # Versions below 4.0 do not support components debuginfomsf and demangle
+ if(${LLVM_VERSION_STRING} MATCHES "^3\\..*")
+ list(REMOVE_ITEM LLVM_FIND_COMPONENTS "debuginfomsf" index)
+ list(REMOVE_ITEM LLVM_FIND_COMPONENTS "demangle" index)
+ endif()
+ # Versions below 8.0 not supported
+ if(${LLVM_VERSION_STRING} MATCHES "^[3-7]\\..*")
+ message(FATAL_ERROR "LLVM version below 8.0 not supported")
+ endif()
+
+ llvm_set(LDFLAGS ldflags)
+ # In LLVM 3.5+, the system library dependencies (e.g. "-lz") are accessed
+ # using the separate "--system-libs" flag.
+ llvm_set(SYSTEM_LIBS system-libs)
+ string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}")
+ llvm_set(LIBRARY_DIRS libdir true)
+ llvm_set_libs(LIBRARIES libs)
+ # LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0
+ # but code for it is not in shared library
+ if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen")
+ if (NOT "${LLVM_LIBRARIES}" MATCHES "LLVMTableGen")
+ set(LLVM_LIBRARIES "${LLVM_LIBRARIES};-lLLVMTableGen")
+ endif()
+ endif()
+
+ # Versions below 4.0 do not support llvm-config --cmakedir
+ if(${LLVM_VERSION_STRING} MATCHES "^3\\..*")
+ set(LLVM_CMAKEDIR ${LLVM_LIBRARY_DIRS}/cmake/llvm)
+ else()
+ llvm_set(CMAKEDIR cmakedir)
+ endif()
+
+ llvm_set(TARGETS_TO_BUILD targets-built)
+ string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD})
+endif()
+
+# Remove some clang-specific flags for gcc.
+if(CMAKE_COMPILER_IS_GNUCXX)
+ string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
+ string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
+ string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
+ string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
+endif()
+
+# Remove gcc-specific flags for clang.
+if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+ string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
+endif()
+
+string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" )
+string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" )
+
+
+# Use the default CMake facilities for handling QUIET/REQUIRED.
+include(FindPackageHandleStandardArgs)
+
+find_package_handle_standard_args(LLVM
+ REQUIRED_VARS LLVM_ROOT_DIR LLVM_HOST_TARGET
+ VERSION_VAR LLVM_VERSION_STRING)
diff --git a/docs/pytriton.md b/docs/pytriton.md
new file mode 100644
index 000000000..0e04e6246
--- /dev/null
+++ b/docs/pytriton.md
@@ -0,0 +1,243 @@
+# The PyTriton API
+
+
+## Table of Contents
+
+1. [Motivations](#motivations)
+2. [Triton Functions](#pytriton-function)
+    1. [Creation of Triton Kernels](#creation-triton-kernels)
+    2. [Usage of Triton Kernels](#usage-triton-kernels)
+3. [Integration with Automatic Differentiation](#autodiff)
+    1. [Basics](#autodiff-basics)
+    2. [Convenience](#autodiff-convenience)
+
+
+## Motivations
+
+
+The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and Tensorflow. One of the main advantages of PyTriton is that it is framework agnostic: any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required, as will be shown in this tutorial.
+
+Consider for example the following piece of code:
+
+```python
+import numpy as np
+import triton
+
+def run_tf():
+    M, N, K = 128, 128, 128
+    a = tf.placeholder(tf.float32, shape=[M, K])
+    b = tf.placeholder(tf.float32, shape=[N, K])
+    c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
+    da, db = tf.gradients(c, [a, b])
+    # Run
+    ha = np.random.rand(M, K).astype(np.float32)
+    hb = np.random.rand(K, N).astype(np.float32)
+    sess = tf.InteractiveSession()
+    sess.run(tf.global_variables_initializer())
+    result = sess.run([da], feed_dict = {a: ha, b: hb})
+
+def run_torch():
+    M, N, K = 128, 128, 128
+    a = torch.randn(M, K).cuda()
+    b = torch.randn(K, N).cuda()
+    a.requires_grad_(True)
+    b.requires_grad_(True)
+    c = triton.ops.dot(a, b, False, True)
+    c.backward(torch.ones_like(c))
+    da = a.grad.clone()
+    db = b.grad.clone()
+
+## Run on tensorflow
+# import tensorflow as tf
+# run_tf()
+
+## Run on pytorch
+# import torch
+# run_torch()
+```
+
+PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:
+
+1. The imported frameworks are detected (a minimal sketch of this step is shown right after this list).
+2. C++ binding code for Tensorflow or PyTorch is generated, compiled and cached.
+3. The corresponding custom-op is automatically loaded from the generated .so file, and a framework-agnostic wrapper is created.
+4. The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable.
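+
+For intuition, step 1 amounts to something along the lines of the following sketch (purely illustrative; the actual detection logic inside PyTriton may differ):
+
+```python
+import sys
+
+def detect_frameworks():
+    # A framework shows up in sys.modules only once the user has imported it.
+    frameworks = []
+    if 'tensorflow' in sys.modules:
+        frameworks.append('tensorflow')
+    if 'torch' in sys.modules:
+        frameworks.append('torch')
+    return frameworks
+```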
+
+
+The remainder of this tutorial will show you how to re-implement the above `triton.ops.dot` operation from scratch.
+
+## PyTriton Functions
+
+The PyTriton API provides a `triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class:
+
+```python
+import triton
+
+# Entry point
+class _dot(triton.function):
+
+    @staticmethod
+    # Forward Pass
+    def forward(ctx, *args):
+        ...
+
+    @staticmethod
+    # Backward Pass
+    def backward(ctx, dy):
+        ...
+```
+
+### Creation of Triton Kernels
+
+
+PyTriton also provides a `triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework binding code. For our dot operation, we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md):
+
+```python
+src = """
+__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
+                    int M, int N, int K,
+                    int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
+  // prologue
+  int pm = get_program_id(0);
+  int pn = get_program_id(1);
+  int rm[TM] = pm * TM + 0 ... TM;
+  int rn[TN] = pn * TN + 0 ... TN;
+  int rk[TK] = 0 ... TK;
+  float c[TM, TN] = 0;
+  // pointers to operands
+  TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
+  TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
+  // prefetch operands
+  TYPE a[SHAPE_A] = (*pa);
+  TYPE b[SHAPE_B] = (*pb);
+  // reduction loop
+  for(int k = K; k > 0; k -= TK){
+    c += USE_A @ USE_B;
+    pa = pa + TK * STRIDE_AK;
+    pb = pb + TK * STRIDE_BK;
+    a = *pa;
+    b = *pb;
+  }
+  // epilogue
+  int rcm[TM] = pm * TM + 0 ... TM;
+  int rcn[TN] = pn * TN + 0 ... TN;
+  TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
+  *pc = c;
+}
+"""
+
+kernel = triton.kernel(src, ['C'])
+```
+
+Note that the second argument to the `triton.kernel` constructor indicates which of the operands our kernel function should return. Here, we only return `C`.
+
+At this point, `kernel` is a callable object which has the same signature as the `dot` function in our source code, except that pointers are treated as tensors:
+```
+[tensor, tensor, tensor, int, int, int, int, int, int]
+```
+
+### Usage of Triton Kernels
+
+However, in practice only `A` and `B` are provided by the user, and all the other `int` arguments can be derived from these operands alone. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors, and then returns the result of a call to `kernel`:
+
+```python
+    @staticmethod
+    def _call(a, b, transpose_a, transpose_b):
+        # extract shapes
+        shape_a = triton.shape(a)
+        shape_b = triton.shape(b)
+        M, Ka = shape_a[0], shape_a[1]
+        Kb, N = shape_b[0], shape_b[1]
+        # transpose shapes
+        if transpose_a:
+            M, Ka = Ka, M
+        if transpose_b:
+            Kb, N = N, Kb
+        # contiguous dimensions
+        lda = M if transpose_a else Ka
+        ldb = Kb if transpose_b else N
+        ldc = N
+        # data-type
+        dtype = a.dtype
+        # allocate output
+        c = triton.empty([M, N], dtype = dtype)
+        # compute
+        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
+        # macros -- not necessary but makes kernel source-code simpler
+        macros = {# handle A transposition
+                  'USE_A'       : '^a'         if transpose_a else 'a',
+                  'STRIDE_AK'   : 'lda'        if transpose_a else '1',
+                  'STRIDE_AM'   : '1'          if transpose_a else 'lda',
+                  'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
+                  'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
+                  'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
+                  # handle B transposition
+                  'USE_B'       : '^b'         if transpose_b else 'b',
+                  'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
+                  'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
+                  'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
+                  'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
+                  'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
+        return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
+                           AT = transpose_a, BT = transpose_b, TYPE = dtype,
+                           TM = [32, 64, 128], TN = [32, 64, 128], TK = [8], **macros)
+
+```
+
+While this code should be mostly self-explanatory, there are a few noteworthy things worth pointing out:
+
+- `triton.shape` provides a framework-agnostic way to retrieve the shape of a tensor
+
+- `triton.empty` creates an empty tensor of the specified dimensions
+
+- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of compilation options `opt`, whose compile-time definitions can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimensions our kernel was compiled with. We also provide a helper `triton.cdiv` for ceil divisions (a short sketch is given after this list).
+
+- `macros` provides a dictionary of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named arguments to `_dot.kernel`. We recall that lists can be supplied to the preprocessor, in which case an auto-tuning procedure will be triggered. Here, the values of `TM` and `TN` are both tuned between 32, 64 and 128.
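+
+For reference, the ceil-division used to size the launch grid is simply the following (a trivial sketch of what a helper like `triton.cdiv` computes):
+
+```python
+def cdiv(a, b):
+    # smallest integer >= a / b
+    return (a + b - 1) // b
+
+# e.g. an output with M = 1000 rows tiled with TM = 128 needs cdiv(1000, 128) = 8 programs
+```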
+
+## Compatibility with Automatic Differentiation
+
+At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
+
+### Basics
+
+PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework using a single, common API inspired by PyTorch. It consists of two static methods `forward` and `backward` that take a context as their first input:
+
+```python
+    @staticmethod
+    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
+        ctx.save_for_backward(a, b)
+        ctx.t_a = transpose_a
+        ctx.t_b = transpose_b
+        return _dot._call(a, b, transpose_a, transpose_b)
+
+    @staticmethod
+    def backward(ctx, dy):
+        a, b = ctx.saved_tensors
+        t_a, t_b = ctx.t_a, ctx.t_b
+        if not t_a and not t_b:
+            da = _dot._call(dy, b, False, True)
+            db = _dot._call(a, dy, True, False)
+        elif not t_a and t_b:
+            da = _dot._call(dy, b, False, False)
+            db = _dot._call(dy, a, True, False)
+        elif t_a and not t_b:
+            da = _dot._call(b, dy, False, True)
+            db = _dot._call(a, dy, False, False)
+        elif t_a and t_b:
+            da = _dot._call(b, dy, True, True)
+            db = _dot._call(dy, a, True, True)
+        else:
+            assert False
+        return da, db, None, None
+```
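+
+For intuition, in the untransposed case `C = A @ B`, the chain rule gives `dA = dY @ B^T` and `dB = A^T @ dY`, which is exactly what the first branch above computes via `_call(dy, b, False, True)` and `_call(a, dy, True, False)`; the remaining branches apply the same reasoning to the transposed variants.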
+
+### Convenience
+
+As in PyTorch, a callable operation can then be created using the `apply` method of our `triton.function` class. We wrap it as a module variable for convenience:
+
+```python
+dot = _dot.apply
+```
+And that's it! Our custom op is now created and ready to be used with both PyTorch and Tensorflow.
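+
+For example, with PyTorch (an illustrative end-to-end sketch; it assumes the `_dot` class above is defined in scope and that a CUDA device is available):
+
+```python
+import torch
+
+a = torch.randn(128, 64, device='cuda', requires_grad=True)
+b = torch.randn(64, 256, device='cuda', requires_grad=True)
+c = dot(a, b)                   # dispatches to _dot.forward
+c.backward(torch.ones_like(c))  # dispatches to _dot.backward
+print(a.grad.shape, b.grad.shape)
+```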
diff --git a/docs/triton-c.md b/docs/triton-c.md
new file mode 100644
index 000000000..6222169b5
--- /dev/null
+++ b/docs/triton-c.md
@@ -0,0 +1,436 @@
+# The Triton-C Programming Language
+
+## Table of Contents
+1. [Motivations](#motivations)
+2. [Vector Addition](#vector-addition)
+    1. [Differences with CUDA](#differences-with-cuda)
+    2. [Advantages over CUDA](#advantages-over-cuda)
+        1. [Vectorization](#vectorization)
+        2. [Parameterization](#parameterization)
+        3. [Auto-Tuning](#auto-tuning)
+3. [Matrix Transposition](#matrix-transposition)
+    1. [Compute Kernel](#trans-compute-kernel)
+    2. [The __multipleof Attribute](#trans-multipleof)
+    3. [Conditional Dereferencing](#conditional-dereferencing)
+4. [Matrix Multiplication](#matrix-multiplication)
+    1. [Compute Kernel](#matmul-compute-kernel)
+    2. [Optimizations](#optimizations)
+        1. [Pre-Fetching](#pre-fetching)
+        2. [Rematerialization](#rematerialization)
+    3. [Fused Transpositions and Auto-Tuning](#fused-trans-autotuning)
+
+## Motivations
+
+In C and C++, arrays and pointers have similar semantics. Indeed, there is no native way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole:
+
+```c
+// C99
+float x[16][8] = {3.14};
+float y[16][8] = {5.17};
+// z = x + y
+float z[16][8];
+#pragma unroll
+for(int i = 0; i < 16; i++)
+  #pragma unroll
+  for(int j = 0; j < 8; j++)
+    z[i][j] = x[i][j] + y[i][j];
+```
+
+While it does not seem like a big deal at first sight, there are two issues with this:
+
+- **Ergonomics**: Of course, it is possible to simplify the above code using functions in C
+```c
+float z[16][8];
+add(z, x, y, 16, 8);
+```
+but this would be semantically different as the loops could no longer be unrolled, their bounds now being dynamic arguments of the `add` function. This can be mitigated using template metaprogramming (and operator overloading) in C++:
+
+```c
+// C++
+template<typename T, int M, int N>
+class matrix;
+
+matrix<float, 16, 8> x = {3.14};
+matrix<float, 16, 8> y = {5.17};
+matrix<float, 16, 8> z = x + y;
+```
+
+While this is better and now equivalent to our initial code snippet, the syntax is not quite as ergonomically satisfying as what native syntactic support could provide:
+```c
+// Triton-C
+float x[16, 8] = 3.14;
+float y[16, 8] = 5.17;
+// float z[8, 8] = x + y; // doesn't compile -- incompatible shapes!
+float z[16, 8] = x + y;
+float u[16] = z[:, +]; // sum along the second axis
+float v[16, 32] = u[:, newaxis]; // broadcasting along the second axis
+```
+which is valid _Triton-C_.
+
+
+- **Portability**: One other issue with our initial C program is that it is not portable. While it will run well on a single CPU thread, the operation `z = x + y` would underutilize a GPU Streaming Processor as it would execute on a single thread only. For this reason, it would have to be rewritten in CUDA as follows:
+```c
+// CUDA
+// Launch on a block of 16 x 8 threads
+float x = 3.14;
+float y = 5.17;
+float z = x + y;
+```
+In Triton-C, the same code can be used across many different platforms (only CPUs and GPUs are supported at the moment). Furthermore, Triton-C is single-threaded, hence easier to write than CUDA.
+
+- **Performance**: Another issue with our initial C code snippet is its performance. Although the loops are unrolled, the program does not carry any data-flow information pertaining to array operations. This issue gets more and more problematic as programs get increasingly complex, eventually culminating in matrix multiplication being remarkably hard to optimize.
+
+ This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning. For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and almost impossible for others.
+
+ Triton addresses this issue by relying on **Triton-IR**, an LLVM-like IR for array operations, and **Triton-JIT**, an optimizing compiler for Triton-IR. These two systems are, however, beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
+
+
+_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can conceptually think of Triton-IR as a dialect for MLIR, and Triton-C as a frontend for it. I would like to integrate Triton-IR into MLIR in the future; If you're interested in making this a thing, let me know._
+
+
+
+## Vector Addition
+
+### Differences with CUDA
+
+Let's start off by looking at a simple example. Vector addition, in its most trivial Triton-C implementation, can be written as follows:
+
+```c
+// Triton-C
+// launched on a grid of (N / 32) programs of 1 thread each
+__global__ void add(int N, float *a, float *b, float* c) {
+  int id = get_program_id(0);
+  int off[32] = id * 32 + (0 ... 32);
+  *(c + off) = *(a + off) + *(b + off);
+}
+```
+For reference, here is an equivalent CUDA kernel (NVCC will generate the same PTX for it as Triton-JIT does for the above code):
+
+```c
+// CUDA
+// launched on a grid of (N / 32) programs of 32 threads each
+__global__ void add(int N, float *a, float *b, float *c) {
+  int off = blockIdx.x * 32 + threadIdx.x;
+  c[off] = a[off] + b[off];
+}
+```
+
+As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA:
+
+- **The programming model is different**.
+While Triton-C and CUDA both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
+ Therefore, `get_program_id({0, 1, 2})` is equivalent to `blockIdx.{x, y, z}`, but there is no such thing as `blockDim` and `threadIdx`.
+
+- **The semantics of arrays is different**.
+In the above Triton-C kernel, `off` is an array of 32 consecutive integers: `int off[32] = {id * 32 + 0, id * 32 + 1, ..., id * 32 + 31}`.
+ As a result, the statement `c + off` implicitly broadcasts `c` and creates an array of 32 pointers. This could also be done explicitly as follows:
+```c
+float* c_broadcast[32] = c;
+float* c_ptr[32] = c_broadcast + off; // c_ptr = c + off
+```
+
+- **The semantics of the subscript operator is different**.
+In C/CUDA, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
+
+### Advantages over CUDA
+
+At this point, the advantages of Triton-C over CUDA may not be obvious, but they should become increasingly clear as this tutorial progresses. First and foremost, the purpose of this subsection is to show how Triton can be used to optimize vector additions by automatically taking care of load/store vectorization, code parameterization and auto-tuning -- all of which require nontrivial implementation efforts in CUDA.
+
+#### Vectorization
+
+On some hardware architectures, vectorizing load/store operations can lead to better memory utilization and, in turn, noticeable performance gains. In general, 128-bit memory transactions are favored, leading to the following CUDA kernel:
+```c
+// CUDA
+// launched on a grid of (N / 128) programs of 32 threads each
+__global__ void add(int N, float4 *a, float4 *b, float4 *c) {
+  int off = blockIdx.x * 32 + threadIdx.x;
+  c[off] = a[off] + b[off];
+}
+```
+Or, for half-precision inputs:
+```c
+// CUDA
+// launched on a grid of (N / 256) programs of 32 threads each
+__global__ void add(int N, half8 *a, half8 *b, half8 *c) {
+  int off = blockIdx.x * 32 + threadIdx.x;
+  c[off] = a[off] + b[off];
+}
+```
+
+Now this is a bit annoying, because as a programmer you have to keep track of not only the ideal vector size for each data-type (which might change in future GPU architectures), but also of how many elements are processed in each thread-block -- and adjust the grid size of the kernel accordingly! Not to mention that you may want to tune the thread-block size as well.
+
+In Triton-C, this is not a problem: the compiler will automatically figure out when and where vectorization should be used, without requiring any change to the source code.
+
+#### Parameterization
+
+Specifically, the Triton compiler would refuse to 4-way vectorize our above compute kernel because it would require the array `int off[32]` to be distributed over 8 threads, which is less than a warp. Fortunately, it turns out that this problem can be easily solved using preprocessor directives to _parameterize_ our kernel:
+```c
+// Triton-C
+// launched on a grid of (N / SIZE) programs of 1 thread each
+__global__ void add(int N, TYPE* a, TYPE* b, TYPE* c) {
+  int id = get_program_id(0);
+  int off[SIZE] = id * SIZE + (0 ... SIZE);
+  *(c + off) = *(a + off) + *(b + off);
+}
+// Not vectorized when compiled with -DSIZE=32 -DTYPE=float
+// 4-Vectorized when compiled with -DSIZE=128 -DTYPE=float
+// 8-Vectorized when compiled with -DSIZE=256 -DTYPE=half
+```
+Now, `TYPE` and `SIZE` are preprocessor macros which can be specified at compile-time, thereby giving the Triton compiler enough information to vectorize when beneficial without requiring any additional code modification.
+
+
+#### Auto-Tuning
+
+As it turns out, different input vector lengths `N` may require different values of `SIZE` to perform optimally. Fortunately, the Triton preprocessor also accepts lists of possible definitions for macros, in which case an auto-tuning procedure will be launched every time new input sizes are encountered. For example, compiling the above kernel with the option `-DSIZE=[32, 64, 128, 256] -DTYPE=float`
+will result in the parameter `SIZE` being automatically tuned every time a new value of `N` is encountered.
+
+_Note: Tuning our reference CUDA kernel would be much more cumbersome, as template metaprogramming would have to be used to ensure that proper vector types would be used_
+
+
+## Matrix Transposition
+
+Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/), but this comes at the cost of simplicity and -- more importantly -- interferes with auto-tuning.
+
+### Compute Kernel
+
+In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory in order to enable shared memory stores/loads. Therefore, an optimal Triton kernel for this operation would look like:
+
+```c
+// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
+__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
+  // extract program ID
+  int pidm = get_program_id(0); //(1)
+  int pidn = get_program_id(1); //(2)
+  // create 1D ranges along the matrix's two axes
+  int rm[TM] = pidm * TM + 0 ... TM; //(3)
+  int rn[TN] = pidn * TN + 0 ... TN; //(4)
+  // create 2D array of pointers
+  TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
+  TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
+  // write back using the transposition operator '^'
+  *py = ^(*px); //(7)
+}
+```
+
+At a high level, this kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and writes the resulting `TN x TM` tile to the output matrix `Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
+
+- Statements (1) and (2) extract the coordinates of the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TM:3TM-1]` holds the values:
+```
+pidm = 2
+pidn = 1
+```
+
+- Statements (3) and (4) construct the ranges of indices:
+```
+rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
+rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
+```
+
+which will be used in statements (5) and (6) to construct tiles of pointers
+
+- Statement (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
+```
+│ X + (pidm*TM + 0)      + (pidn*TN + 0)*ldx,  ...,  ...,  X + (pidm*TM + 0)      + (pidn*TN + TN - 1)*ldx │
+│          ⋮                                                        ⋮                                      │
+│          ⋮                                                        ⋮                                      │
+│ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx,  ...,  ...,  X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx │
+```
+- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:
+```
+│ Y + (pidn*TN + 0)      + (pidm*TM + 0)*ldy,  ...,  ...,  Y + (pidn*TN + 0)      + (pidm*TM + TM - 1)*ldy │
+│          ⋮                                                        ⋮                                      │
+│          ⋮                                                        ⋮                                      │
+│ Y + (pidn*TN + TN - 1) + (pidm*TM + 0)*ldy,  ...,  ...,  Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy │
+```
+- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
+
+### The __multipleof Attribute
+
+The memory loads and stores in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate, because in Deep Learning tensor dimensions can easily be made nice powers of two (batch sizes and layer widths are flexible), so in practice the strides are usually suitably aligned -- the compiler just cannot assume it.
+
+For this reason, Triton provides a `__multipleof(N)` attribute for variables that are guaranteed to always be a multiple of N. In the case of matrix transposition, vector loads can be enabled by modifying the function's signature as follows:
+
+```c
+__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) {
+// ...
+}
+```
+
+### Conditional Dereferencing
+
+You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
+```c
+// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
+__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
+  // ...
+  // create bounds-checking mask
+  bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
+  bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
+  // conditional write-back using the conditional dereferencing operator '*?()'
+  *?(checky)py = ^(*?(checkx)px); //(7)
+}
+```
+
+Here, statement (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
+
+
+## Matrix Multiplication
+
+The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processor macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
+
+_Note: Bounds-checking is omitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
+
+### Compute Kernel
+
+Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
+
+```c
+// Triton-C
+// launched on a grid of (M / TM) x (N / TN) programs
+__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
+                    int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
+  // prologue
+  int pm = get_program_id(0); //(1)
+  int pn = get_program_id(1); //(2)
+  int rm[TM] = pm * TM + 0 ... TM; //(3)
+  int rn[TN] = pn * TN + 0 ... TN; //(4)
+  int rk[TK] = 0 ... TK; //(5)
+  // initialize accumulator
+  float c[TM, TN] = 0; //(6)
+  // pointers to operands
+  TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
+  TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
+  // reduction loop
+  for(int k = K; k > 0; k -= TK){
+    // fetch operands
+    TYPE a[TM, TK] = *pa; //(9)
+    TYPE b[TK, TN] = *pb; //(10)
+    // matrix-multiply accumulate
+    c += a @ b; //(11)
+    // increment pointers
+    pa = pa + TK * 1; //(12)
+    pb = pb + TK * ldb; //(13)
+  }
+  // epilogue
+  TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
+  *pc = c; //(15)
+}
+```
+Here, each kernel instance produces a `TM x TN` tile of the output matrix C as follows:
+
+- Statements (1) - (2) fetch the id of the current program instance.
+- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix `C`
+- Statement (5) constructs a range of indices along the reduction axis: `rk = [0, 1, ..., TK - 1]`
+- Statement (6) initializes a `TM x TN` array of accumulators to hold the result of `A[rm, :] x B[:, rn]`
+- Statements (7) - (8) initialize arrays of pointers `pa` and `pb` to the operands `A` and `B` using logic similar to that of the above transposition kernel
+- Statements (9) - (10) load tiles of operands by dereferencing `pa` and `pb`
+- Statement (11) updates the accumulator array using Triton-C's matrix multiplication operator '@'
+- Statements (12) - (13) update `pa` and `pb`
+- Statement (14) creates an array of pointers `pc` to the result matrix `C`
+- Statement (15) writes back the accumulator to `C`
+
+Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:
+
+- Automatic coalescing of load/store operations
+- Automatic vectorization of load/store operations
+- Stashing `a` and `b` to shared memory
+- Automatic allocation of shared memory
+- Automatic synchronization of shared memory
+- Automatic padding of shared memory to avoid bank conflicts
+- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0 (see the example right after this list)
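+
+For instance (an illustrative invocation that combines the macro conventions used throughout this tutorial), the kernel can be compiled for half-precision inputs, which makes it eligible for tensor cores since every candidate `TK` below is a multiple of 4:
+
+```c
+// FP16 inputs + auto-tuned tile sizes
+-DTYPE=half -DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
+```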
+
+### Optimizations
+
+Nonetheless, there are two important optimizations that the Triton compiler does not perform automatically at the moment, yet which are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
+
+#### Pre-Fetching
+
+The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:
+
+```c
+// pre-fetch operands
+TYPE a[TM, TK] = *pa; //(9)
+TYPE b[TK, TN] = *pb; //(10)
+for(int k = K; k > 0; k -= TK){
+  c += a @ b;
+  pa = pa + TK * 1;
+  pb = pb + TK * ldb;
+  // don't prefetch last iteration
+  bool check = k > TK;
+  // pre-fetch operands
+  a = check ? *pa : 0;
+  b = check ? *pb : 0;
+}
+```
+
+Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.
+
+#### Rematerialization
+
+[Rematerialization](https://en.wikipedia.org/wiki/Rematerialization) is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only `rm` and `rn` need to be rematerialized, leading to the following epilogue:
+
+```c
+// epilogue
+int rcm[TM] = pm * TM + 0 ... TM;
+int rcn[TN] = pn * TN + 0 ... TN;
+TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
+*pc = c;
+```
+
+### Fused Transpositions and Auto-Tuning
+
+It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. This is also what is done in the [PyTriton](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py) implementation of matrix-multiplication. Fortunately, this can be done by using pre-processor macros for tile shapes and broadcasting directives, leading to the following kernel:
+
+```c
+// Triton-C
+// launched on a grid of (M / TM) x (N / TN) programs
+__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
+                    int M, int N, int K,
+                    int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
+  // prologue
+  int pm = get_program_id(0);
+  int pn = get_program_id(1);
+  int rm[TM] = pm * TM + 0 ... TM;
+  int rn[TN] = pn * TN + 0 ... TN;
+  int rk[TK] = 0 ... TK;
+  float c[TM, TN] = 0;
+  // pointers to operands
+  TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
+  TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
+  // prefetch operands
+  TYPE a[SHAPE_A] = (*pa);
+  TYPE b[SHAPE_B] = (*pb);
+  // reduction loop
+  for(int k = K; k > 0; k -= TK){
+    c += USE_A @ USE_B;
+    pa = pa + TK * STRIDE_AK;
+    pb = pb + TK * STRIDE_BK;
+    a = *pa;
+    b = *pb;
+  }
+  // epilogue
+  int rcm[TM] = pm * TM + 0 ... TM;
+  int rcn[TN] = pn * TN + 0 ... TN;
+  TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
+  *pc = c;
+}
+```
+
+All matrix multiplication variants can then be retrieved using the following compilation options:
+```c
+// A is not transposed
+-DUSE_A=a -DSTRIDE_AK=1 -DSTRIDE_AM=lda -DBROADCAST_AK=newaxis,: -DBROADCAST_AM=:,newaxis -DSHAPE_A=TM,TK
+// A is transposed
+-DUSE_A=^a -DSTRIDE_AK=lda -DSTRIDE_AM=1 -DBROADCAST_AK=:,newaxis -DBROADCAST_AM=newaxis,: -DSHAPE_A=TK,TM
+// B is not transposed
+-DUSE_B=b -DSTRIDE_BK=ldb -DSTRIDE_BN=1 -DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
+// B is transposed
+-DUSE_B=^b -DSTRIDE_BK=1 -DSTRIDE_BN=ldb -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
+```
+
+Auto-tuning can also be handled using pre-processor macros:
+```c
+// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
+-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
+```
+
diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h
new file mode 100644
index 000000000..647db3984
--- /dev/null
+++ b/include/triton/codegen/analysis/align.h
@@ -0,0 +1,77 @@
+#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
+#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
+
+#include