[BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69)
* [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering * Clean code Co-authored-by: goostavz <gzhu@nvidia.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
@@ -309,10 +309,10 @@ public:
|
|||||||
PatternBenefit benefit = 1)
|
PatternBenefit benefit = 1)
|
||||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||||
|
|
||||||
SmallVector<Value, 4>
|
SmallVector<Value>
|
||||||
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
|
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
SmallVector<Value, 4> results(elems);
|
SmallVector<Value> results(elems);
|
||||||
for (unsigned i = 0; i < elems; ++i) {
|
for (unsigned i = 0; i < elems; ++i) {
|
||||||
Type type =
|
Type type =
|
||||||
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
|
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
|
||||||
@@ -710,7 +710,7 @@ struct StoreOpConversion
|
|||||||
|
|
||||||
PtxIOInstr asmStoreInstr("st");
|
PtxIOInstr asmStoreInstr("st");
|
||||||
asmStoreInstr.predicate(maskElems[vecIdx], "b");
|
asmStoreInstr.predicate(maskElems[vecIdx], "b");
|
||||||
asmStoreInstr.global().v(width).b(nWords);
|
asmStoreInstr.global().b(width).v(nWords);
|
||||||
|
|
||||||
llvm::SmallVector<std::string> asmArgs;
|
llvm::SmallVector<std::string> asmArgs;
|
||||||
|
|
||||||
@@ -970,7 +970,10 @@ struct LoadOpConversion
|
|||||||
unsigned numElems = getElemsPerThread(blockedLayout, shape);
|
unsigned numElems = getElemsPerThread(blockedLayout, shape);
|
||||||
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
|
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
|
||||||
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
|
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
|
||||||
auto otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
|
SmallVector<Value> otherVals;
|
||||||
|
if (other != nullptr) {
|
||||||
|
otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
|
||||||
|
}
|
||||||
unsigned nbits = elemTy.isa<FloatType>()
|
unsigned nbits = elemTy.isa<FloatType>()
|
||||||
? elemTy.cast<FloatType>().getWidth()
|
? elemTy.cast<FloatType>().getWidth()
|
||||||
: elemTy.cast<IntegerType>().getWidth();
|
: elemTy.cast<IntegerType>().getWidth();
|
||||||
@@ -1039,31 +1042,33 @@ struct LoadOpConversion
|
|||||||
asmOss << ", $" << n_words + 2;
|
asmOss << ", $" << n_words + 2;
|
||||||
asmOss << ";";
|
asmOss << ";";
|
||||||
SmallVector<Value> others;
|
SmallVector<Value> others;
|
||||||
for (size_t ii = 0; ii < n_words; ii++) {
|
if (other != nullptr) {
|
||||||
size_t size = width / nbits;
|
for (size_t ii = 0; ii < n_words; ii++) {
|
||||||
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
size_t size = width / nbits;
|
||||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
||||||
for (size_t s = 0; s < size; s++) {
|
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||||
Value falseVal = otherVals[i + ii * size + s];
|
for (size_t s = 0; s < size; s++) {
|
||||||
Value sVal = createIndexAttrConstant(
|
Value falseVal = otherVals[i + ii * size + s];
|
||||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
Value sVal = createIndexAttrConstant(
|
||||||
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
|
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||||
sVal);
|
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
|
||||||
|
sVal);
|
||||||
|
}
|
||||||
|
v = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, IntegerType::get(getContext(), width), v);
|
||||||
|
asmOss << "\n ";
|
||||||
|
asmOss << "@!$" << n_words << " mov.u" << width;
|
||||||
|
asmOss << " $" << ii << ", ";
|
||||||
|
std::ios_base::fmtflags flags(asmOss.flags());
|
||||||
|
if (otherIsSplatConstInt)
|
||||||
|
asmOss << "0x" << std::hex << splatVal;
|
||||||
|
else {
|
||||||
|
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
||||||
|
others.push_back(v);
|
||||||
|
}
|
||||||
|
asmOss.flags(flags);
|
||||||
|
asmOss << ";";
|
||||||
}
|
}
|
||||||
v = rewriter.create<LLVM::BitcastOp>(
|
|
||||||
loc, IntegerType::get(getContext(), width), v);
|
|
||||||
asmOss << "\n ";
|
|
||||||
asmOss << "@!$" << n_words << " mov.u" << width;
|
|
||||||
asmOss << " $" << ii << ", ";
|
|
||||||
std::ios_base::fmtflags flags(asmOss.flags());
|
|
||||||
if (otherIsSplatConstInt)
|
|
||||||
asmOss << "0x" << std::hex << splatVal;
|
|
||||||
else {
|
|
||||||
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
|
||||||
others.push_back(v);
|
|
||||||
}
|
|
||||||
asmOss.flags(flags);
|
|
||||||
asmOss << ";";
|
|
||||||
}
|
}
|
||||||
// ---
|
// ---
|
||||||
// create inline ASM signature
|
// create inline ASM signature
|
||||||
|
@@ -258,9 +258,9 @@ void parse_args(py::list &args, py::list do_not_specialize,
|
|||||||
|
|
||||||
void parse_args(py::list &args, py::list &arg_names, std::string ¶ms,
|
void parse_args(py::list &args, py::list &arg_names, std::string ¶ms,
|
||||||
size_t ¶ms_size, py::dict constants) {
|
size_t ¶ms_size, py::dict constants) {
|
||||||
char *params_ptr = params.data();
|
|
||||||
|
|
||||||
size_t len = PyList_Size(args.ptr());
|
size_t len = PyList_Size(args.ptr());
|
||||||
|
params.reserve(8 * len); // 8 max bytes by argument
|
||||||
|
char *params_ptr = params.data();
|
||||||
for (int i = 0; i < len; i++) {
|
for (int i = 0; i < len; i++) {
|
||||||
py::object arg = args[i];
|
py::object arg = args[i];
|
||||||
auto arg_ptr = arg.ptr();
|
auto arg_ptr = arg.ptr();
|
||||||
|
@@ -1,7 +1,12 @@
|
|||||||
|
import torch
|
||||||
|
from torch.testing import assert_allclose
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
import triton.runtime as runtime
|
||||||
|
|
||||||
NUM_WARPS = 4
|
NUM_WARPS = 4
|
||||||
|
BLOCK_SIZE = 256
|
||||||
|
|
||||||
# triton kernel
|
# triton kernel
|
||||||
|
|
||||||
@@ -22,6 +27,31 @@ def test_vecadd_no_scf():
|
|||||||
z_ptrs = z_ptr + offset
|
z_ptrs = z_ptr + offset
|
||||||
tl.store(z_ptrs, z)
|
tl.store(z_ptrs, z)
|
||||||
|
|
||||||
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
|
ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
|
||||||
|
|
||||||
print(ret)
|
torch.zeros([10], device=torch.device('cuda'))
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
|
||||||
|
device=device,
|
||||||
|
constants={"BLOCK_SIZE_N": BLOCK_SIZE},
|
||||||
|
num_warps=NUM_WARPS,
|
||||||
|
num_stages=3)
|
||||||
|
grid = lambda META: (1, )
|
||||||
|
|
||||||
|
x = torch.randn((256,), device='cuda', dtype=torch.float32)
|
||||||
|
y = torch.randn((256,), device='cuda', dtype=torch.float32)
|
||||||
|
z = torch.empty((256,), device=x.device, dtype=x.dtype)
|
||||||
|
runtime.launch_kernel(fn=kernel,
|
||||||
|
binary=binary,
|
||||||
|
grid=grid,
|
||||||
|
num_warps=NUM_WARPS,
|
||||||
|
num_stages=3,
|
||||||
|
x_ptr=x,
|
||||||
|
stride_xn=x.stride(0),
|
||||||
|
y_ptr=y,
|
||||||
|
stride_yn=y.stride(0),
|
||||||
|
z_ptr=z,
|
||||||
|
stride_zn=z.stride(0),
|
||||||
|
BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
|
||||||
|
golden_z = x + y
|
||||||
|
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||||
|
@@ -29,7 +29,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
|
|||||||
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
%vs = tt.splat %a : (f32) -> tensor<128xf32>
|
||||||
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
%mask = tt.splat %true : (i1) -> tensor<128xi1>
|
||||||
|
|
||||||
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };",
|
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
|
||||||
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||||
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
|
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
|
||||||
|
|
||||||
|
@@ -183,9 +183,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-LABEL: basic_store
|
// CHECK-LABEL: basic_store
|
||||||
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
|
||||||
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
|
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user