[FRONTEND][BACKEND] Small fixes to multiple_of, num_programs, axisinfo; enable block-sparse tests (#927)

This commit is contained in:
Philippe Tillet
2022-11-29 20:00:34 +01:00
committed by GitHub
parent 66c36c4378
commit 9bb54402b3
8 changed files with 229 additions and 17 deletions

View File

@@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end())); AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
} }
} }
// TODO: refactor & complete binary ops
// Addition // Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) { if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) { auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
@@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy); newContiguity, newDivisibility, newConstancy);
} }
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops // TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) { if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };

View File

@@ -2074,7 +2074,7 @@ struct GetNumProgramsOpConversion
Location loc = op->getLoc(); Location loc = op->getLoc();
assert(op.axis() < 3); assert(op.axis() < 3);
Value blockId = rewriter.create<::mlir::gpu::BlockDimOp>( Value blockId = rewriter.create<::mlir::gpu::GridDimOp>(
loc, rewriter.getIndexType(), dims[op.axis()]); loc, rewriter.getIndexType(), dims[op.axis()]);
auto llvmIndexTy = getTypeConverter()->getIndexType(); auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(

View File

@@ -187,6 +187,7 @@ void init_triton_ir(py::module &&m) {
/* issue a warning */ /* issue a warning */
} }
}) })
.def("get_context", &mlir::Value::getContext)
.def("replace_all_uses_with", .def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) { [](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue); self.replaceAllUsesWith(newValue);
@@ -335,6 +336,16 @@ void init_triton_ir(py::module &&m) {
return funcs[0]; return funcs[0];
}); });
m.def("make_attr",
[](const std::vector<int> &values, mlir::MLIRContext &context) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(
{static_cast<int64_t>(values.size())},
mlir::IntegerType::get(&context, 32)),
values)
.cast<mlir::Attribute>();
});
m.def( m.def(
"parse_mlir_module", "parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) { [](const std::string &inputFilename, mlir::MLIRContext &context) {

View File

@@ -0,0 +1,188 @@
import pytest
import torch
import triton
# TODO: float32 fails
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=384):
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
c_shape = (Z, H, M, N)
shape = {
"sdd": (M, N),
"dsd": (a_shape[2], a_shape[3]),
"dds": (b_shape[2], b_shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.requires_grad_().retain_grad()
b_ref.requires_grad_().retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
# triton result
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.requires_grad_().retain_grad()
b_tri.requires_grad_().retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
configs = [
(16, 256),
(32, 576),
(64, 1871),
(128, 2511),
]
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
# initialize layout
# make sure each row has at least one non-zero element
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
if is_dense:
layout[:] = 1
else:
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_attention_fwd_bwd(
block,
dtype,
input_scale=1.0,
scale=1 / 8.0,
n_ctx=256,
batch_size=2,
n_heads=2,
):
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
# Triton:
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
attn_mask = torch.tril(attn_mask, diagonal=0)
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])
def triton_attention(
layout,
block: int,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
a = sparse_dot_dsd_nn(w, value)
return a

View File

@@ -1515,7 +1515,7 @@ class CudaUtils(object):
} }
} }
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); } #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
static PyObject* loadBinary(PyObject* self, PyObject* args) { static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name; const char* name;
@@ -1530,7 +1530,6 @@ class CudaUtils(object):
CUmodule mod; CUmodule mod;
int32_t n_regs = 0; int32_t n_regs = 0;
int32_t n_spills = 0; int32_t n_spills = 0;
Py_BEGIN_ALLOW_THREADS;
// create driver handles // create driver handles
CUDA_CHECK(cuModuleLoadData(&mod, data)); CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
@@ -1548,7 +1547,6 @@ class CudaUtils(object):
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
} }
Py_END_ALLOW_THREADS;
if(PyErr_Occurred()) { if(PyErr_Occurred()) {
return NULL; return NULL;

View File

@@ -1117,16 +1117,16 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
## ##
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values): if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values") raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.multiple_of(values) x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x return x
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values): if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values") raise ValueError("Shape of input to max_contiguous does not match the length of values")
x.handle.max_contiguous(values) x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
return x return x

View File

@@ -34,12 +34,12 @@ def sparsify_tensor(x, mask, block):
return ret return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None): def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
if data is None: if data is None:
data = torch.randn(shape, dtype=torch.float32, device=device) data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
ref_ret = data ref_ret = data
ref_ret = ref_ret * alpha + beta ref_ret = ref_ret * alpha + beta
ref_ret = ref_ret.half().float() ref_ret = ref_ret.half().to(dtype)
if trans: if trans:
ref_ret = ref_ret.t().requires_grad_() ref_ret = ref_ret.t().requires_grad_()
ref_ret = ref_ret.detach().requires_grad_() ref_ret = ref_ret.detach().requires_grad_()

View File

@@ -917,9 +917,9 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.ntid.x // CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.ntid.y // CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.ntid.z // CHECK: nvvm.read.ptx.sreg.nctaid.z
%blockdimx = tt.get_num_programs {axis=0:i32} : i32 %blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32 %blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32 %blockdimz = tt.get_num_programs {axis=2:i32} : i32