[BACKEND] Make flash attention forward pass work (#928)

This also simplifies BroadcastOp codegen
This commit is contained in:
Philippe Tillet
2022-11-30 11:13:24 +01:00
committed by GitHub
parent 4e6a8209ed
commit 6461254fb5
7 changed files with 326 additions and 205 deletions

View File

@@ -11,6 +11,7 @@
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -115,6 +116,10 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
// we load LLVM because the frontend uses LLVM.undef for
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
});
// .def(py::init([](){
// mlir::MLIRContext context;
@@ -350,6 +355,7 @@ void init_triton_ir(py::module &&m) {
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// initialize registry
// note: we initialize llvm for undef
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
@@ -1243,7 +1249,14 @@ void init_triton_ir(py::module &&m) {
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
});
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())

View File

@@ -594,11 +594,8 @@ class CodeGenerator(ast.NodeVisitor):
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
# We can use any value because the variable isn't a constexpr
# but use a distinctive value (of the right type) to ease debugging
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
self.visit(init_node)
iv = self.builder.create_undef(self.builder.get_int32_ty())
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -1014,6 +1011,7 @@ def ty_to_cpp(ty):
"u32": "uint32_t",
"u64": "uint64_t",
"fp32": "float",
"f32": "float",
}[ty]
@@ -1044,6 +1042,7 @@ def generate_launcher(constants, signature):
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]
@@ -1343,7 +1342,31 @@ def make_hash(fn, **kwargs):
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}
mlir_arg_type_pattern = r'%\w+: ([^,\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
@@ -1354,6 +1377,27 @@ def compile(fn, **kwargs):
context = _triton.ir.context()
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
capability = torch.cuda.get_device_capability()
capability = capability[0]*10 + capability[1]
# build compilation stages
stages = {
"ast" : (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
@@ -1368,13 +1412,15 @@ def compile(fn, **kwargs):
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
name, ir = os.path.basename(fn).split(".")
assert ir == "ttgir"
asm[ir] = _triton.ir.parse_mlir_module(fn, context)
function = asm[ir].get_single_function()
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
_, ir = os.path.basename(fn).split(".")
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir], signature)
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = 2
first_stage = list(stages.keys()).index(ir)
# cache manager
so_path = make_stub(name, signature, constants)
@@ -1384,58 +1430,42 @@ def compile(fn, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# initialize compilation params
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
name, ext = os.path.basename(fn).split(".")
# load metadata if any
metadata = None
if fn_cache_manager.has_file(f'{name}.json'):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
# build compilation stages
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, compute_capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, compute_capability))
}
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and \
ir in metadata["ctime"] and \
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
# return handle to compiled kernel

View File

@@ -405,7 +405,7 @@ class constexpr:
return constexpr(self.value != other.value)
def __bool__(self):
return constexpr(bool(self.value))
return bool(self.value)
def __neg__(self):
return constexpr(-self.value)

View File

@@ -32,7 +32,7 @@ def _fwd_kernel(
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
@@ -50,7 +50,7 @@ def _fwd_kernel(
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk += tl.dot(q, k)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
@@ -195,6 +195,7 @@ def _bwd_kernel(
tl.store(dk_ptrs, dk)
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
@@ -205,7 +206,7 @@ class _attention(torch.autograd.Function):
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -224,6 +225,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
@@ -268,13 +270,13 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -283,19 +285,69 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
# ref_out.backward(dout)
# ref_dv, v.grad = v.grad.clone(), None
# ref_dk, k.grad = k.grad.clone(), None
# ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# print(ref_out)
# print(tri_out)
# tri_out.backward(dout)
# tri_dv, v.grad = v.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
# triton.testing.assert_almost_equal(ref_dv, tri_dv)
# triton.testing.assert_almost_equal(ref_dk, tri_dk)
# triton.testing.assert_almost_equal(ref_dq, tri_dq)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'],
line_names=['Triton'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['fwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
bench_flash_attention.run(save_path='.', print_data=True)