[codegen] fixed issue in double buffering pointer update
@@ -1577,7 +1577,8 @@ void selection::run(ir::module &src, Module &dst) {
         offset->addIncoming(next_offset, llvm_inc_block);
       }
       else {
-        offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
+        unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+        offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*num_bytes)), llvm_inc_block);
       }
       ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
     }
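Note on the fix above: the double-buffer offset was previously computed with a hardcoded 4-byte element size, which is only correct for 4-byte scalar types; for fp16 tiles the pointer PHI wrapped at half the intended distance. A minimal standalone sketch of the corrected arithmetic (names are illustrative, not the codegen's API):

#include <cassert>

// Double buffering splits one shared-memory allocation into two halves,
// and the pointer PHI advances by the half-buffer size *in elements*:
// (total bytes) / (2 * bytes per element).
unsigned double_buffer_offset(unsigned alloc_bytes, unsigned elem_bits) {
  unsigned num_bytes = elem_bits / 8;      // fp16 -> 2, fp32 -> 4
  return alloc_bytes / (2 * num_bytes);    // elements per half-buffer
}

int main() {
  assert(double_buffer_offset(8192, 32) == 1024); // fp32: old "/(2*4)" agreed
  assert(double_buffer_offset(8192, 16) == 2048); // fp16: old code was 2x off
}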

@@ -72,7 +72,7 @@ handle<T>::~handle(){
   try{
     if(has_ownership_ && h_ && h_.unique())
       _delete(*h_);
-  }catch(const exception::cuda::deinitialized&){
+  }catch(const exception::cuda::base&){
     // order of destruction for global variables
     // is not guaranteed
   }
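Widening the catch from cuda::deinitialized to cuda::base lets the destructor tolerate any driver error raised during static teardown, since a destructor must never let an exception escape. A minimal sketch of the pattern with stand-in types (not the actual driver classes):

#include <exception>
#include <memory>

namespace exception { namespace cuda { struct base : std::exception {}; } }
void release(int /*raw handle*/) { /* e.g. a cuModuleUnload-style call */ }

struct handle {
  std::shared_ptr<int> h_;
  bool has_ownership_ = true;
  ~handle() {
    try {
      if (has_ownership_ && h_ && h_.use_count() == 1)  // the diff's h_.unique()
        release(*h_);
    } catch (const exception::cuda::base&) {
      // Destruction order of globals is not guaranteed: the driver may
      // already be torn down, so any CUDA error here is swallowed.
    }
  }
};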

@@ -26,6 +26,7 @@
 #include "triton/driver/error.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Verifier.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"

@@ -240,7 +241,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-  // std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

@@ -250,8 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
   try{
     dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
   }catch(exception::cuda::base const &){
+#ifdef TRITON_LOG_PTX_ERROR
     std::cerr << "Compilation Failed! Log: " << std::endl;
     std::cerr << errbuf << std::endl;
+#endif
     throw;
   }
 }
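For context, the two CU_JIT options above follow the standard driver-API pattern for capturing the JIT log: option i in opt pairs with value i in a parallel void* array, and the buffer size is passed by value through the pointer slot. A hedged sketch of the full setup (the buffer size and error handling here are assumptions, not this file's exact code):

#include <cuda.h>
#include <cstdint>
#include <iostream>

CUmodule load_ptx(const char* ptx) {
  static char errbuf[8192];                     // assumed log capacity
  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                        CU_JIT_ERROR_LOG_BUFFER};
  void* optval[] = {(void*)(uintptr_t)sizeof(errbuf), (void*)errbuf};
  CUmodule mod = nullptr;
  // On failure, the driver has filled errbuf with the PTX assembler log.
  if (cuModuleLoadDataEx(&mod, ptx, 2, opt, optval) != CUDA_SUCCESS)
    std::cerr << "Compilation Failed! Log:" << std::endl << errbuf << std::endl;
  return mod;
}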

@@ -12,6 +12,7 @@
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
 #include "triton/driver/module.h"
+#include "triton/driver/error.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/ir/print.h"

@@ -166,6 +167,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
     bin = make_bin(*ir, stream->context(), opt);
   }catch(const std::runtime_error& e) {
     return;
+  }catch(const driver::exception::cuda::invalid_ptx& e) {
+    return;
   }
   // benchmark
   ir::function *tmp = ir->get_function_list()[0];

@@ -178,6 +181,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
     }
   };
   _parallel_loop_nest<std::string>(space, benchmark, 1);
+  if(!ret)
+    throw std::runtime_error("could not find valid option in provided space");
   return *ret;
 }
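These two hunks make the autotuner robust to configurations that fail to build: a config whose PTX the driver rejects (e.g. one that over-allocates shared memory) is now skipped rather than fatal, and the new guard reports when no point in the search space compiled at all. A simplified standalone sketch of that control flow (illustrative names, not the runtime's API):

#include <functional>
#include <optional>
#include <stdexcept>
#include <vector>

struct config { int TM, TN, TK, num_warps; };

std::optional<config> autotune(const std::vector<config>& space,
                               const std::function<void(const config&)>& compile,
                               const std::function<double(const config&)>& bench_ms) {
  std::optional<config> best;
  double best_ms = 1e300;
  for (const config& c : space) {
    try { compile(c); }                              // may fail, like invalid PTX
    catch (const std::runtime_error&) { continue; }  // skip, don't abort
    double ms = bench_ms(c);
    if (ms < best_ms) { best_ms = ms; best = c; }
  }
  if (!best)  // mirrors the new guard in the diff above
    throw std::runtime_error("could not find valid option in provided space");
  return best;
}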

@@ -47,7 +47,7 @@ class CMakeBuild(build_ext):
     tf_libs = 'tensorflow_framework'

     cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                   '-DBUILD_EXAMPLES=OFF',
                   '-DBUILD_TESTS=OFF',
                   '-DBUILD_PYTHON_MODULE=ON',
                   '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
                   '-DTF_INCLUDE_DIRS=' + tf_include_dirs,

@@ -160,7 +160,8 @@ void gen_register_op(std::ostream &os, const std::string &name,
     std::string name = arg->get_name();
     auto tolower = [](char c) { return std::tolower(c);};
     std::transform(name.begin(), name.end(), name.begin(), tolower);
-    os << "  .Input(\"" << name << ": " << to_tf_scalar_ty(arg->get_type()) << "\")\n";
+    os << "  .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
+    os << "  .Input(\"" << name << ": T" << i << "\")\n";
   }
   for(size_t i = 0; i < outputs.size(); i++){
     std::string name = outputs[i];

@@ -10,11 +10,6 @@
 #include "cuda/cublas.h"

-
-struct perf_t {
-  double triton;
-  double cublas;
-};

 namespace drv = triton::driver;
 namespace rt = triton::runtime;
@@ -22,6 +17,14 @@ inline size_t ceil(size_t x, size_t y) {
   return (x + y - 1) / y;
 };

+inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
+  return [M, N](const rt::function::options_t& x) {
+    return rt::grid_t{ceil(M, x.D<int>("TM")),
+                      ceil(N, x.D<int>("TN"))};
+  };
+}
+
+
 std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef half_float::half NumericT;
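The new grid(M, N) helper above returns a closure that maps a compilation config to launch-grid dimensions by ceiling division, so each benchmark no longer re-declares the lambda inline. A worked check of what it produces (shape and tile values assumed for illustration):

#include <cstdio>

inline size_t ceil_div(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  // What grid(8192, 8192) evaluates to on a config with TM=128, TN=64:
  std::printf("%zu x %zu\n", ceil_div(8192, 128), ceil_div(8192, 64)); // 64 x 128
}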

@@ -33,9 +36,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   int32_t ldb = BT ? N : K;
   int32_t ldc = M;
   // create inputs
-  auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
   auto da = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
   auto db = std::unique_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
+  auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
   // create options
   rt::function::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});

@@ -47,11 +50,6 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
   opt.defines.push_back({"TK", {"32"}});
   opt.num_warps = {1, 2, 4, 8};
-  // create grid
-  auto grid = [&](const rt::function::options_t& x) {
-    return rt::grid_t{ceil(M, x.D<int>("TM")),
-                      ceil(N, x.D<int>("TN"))};
-  };
   // create function
   rt::function function(src::dot, opt);
   // benchmark available libraries

@@ -68,7 +66,7 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
     result.push_back(tflops(cublas_ms));
   }
   // triton
-  double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream);
+  double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream);
   result.push_back(tflops(triton_ms));
   // done
   return result;
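The tflops conversion used throughout these benchmarks counts 2*M*N*K floating-point operations for a GEMM (one multiply plus one add per inner-product term). A worked example with an assumed timing:

#include <cstdio>

int main() {
  double M = 8192, N = 8192, K = 8192;
  double flops = 2.0 * M * N * K;        // ~1.0995e12 operations
  double ms = 10.0;                      // assumed measured runtime
  double tflops = flops / (ms * 1e-3) / 1e12;
  std::printf("%.1f TFLOPS\n", tflops);  // ~110 TFLOPS
}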

@@ -80,11 +78,25 @@ int main() {
   triton::driver::stream* stream = triton::driver::stream::create(context);
   // shapes to benchmark
   typedef std::tuple<bool, bool, int, int, int> config_t;
-  std::vector<config_t> configs = {
-    config_t{false, true, 512, 512, 512},
-    config_t{false, true, 2048, 2048, 2048},
-    config_t{false, true, 8192, 8192, 8192}
-  };
+  std::vector<config_t> configs;
+  for(auto x: std::vector<std::array<bool, 2>>{{false, false},
+                                               {false, true},
+                                               {true, false}}){
+    std::vector<config_t> tmp = {
+      config_t{x[0], x[1], 8192, 8192, 8192}
+      // config_t{x[0], x[1], 16, 2048, 2048},
+      // config_t{x[0], x[1], 32, 2048, 2048},
+      // config_t{x[0], x[1], 64, 2048, 2048},
+      // config_t{x[0], x[1], 128, 2048, 2048},
+      // config_t{x[0], x[1], 7000, 2048, 2048},
+      // config_t{x[0], x[1], 16, 4096, 4096},
+      // config_t{x[0], x[1], 32, 4096, 4096},
+      // config_t{x[0], x[1], 64, 4096, 4096},
+      // config_t{x[0], x[1], 128, 4096, 4096},
+      // config_t{x[0], x[1], 7000, 4096, 4096},
+    };
+    configs.insert(configs.end(), tmp.begin(), tmp.end());
+  }
   // does the work
   bool AT, BT;
   int32_t M, N, K;

@@ -30,21 +30,21 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
   float xc[TM, TN] = 0;
 #ifdef AT
   TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
-  bool checka[TK, TM] = rka[:, newaxis] < K;
-  TYPE a[TK, TM] = checka ? *pa : 0;
+  bool checka[TK, TM] = rka[:, newaxis] < TK;
+  TYPE a[TK, TM] = *pa;
 #else
   TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
-  bool checka[TM, TK] = rka[newaxis, :] < K;
-  TYPE a[TM, TK] = checka ? *pa : 0;
+  bool checka[TM, TK] = rka[newaxis, :] < TK;
+  TYPE a[TM, TK] = *pa;
 #endif
 #ifdef BT
   TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
-  bool checkb[TN, TK] = rkb[newaxis, :] < K;
-  TYPE b[TN, TK] = checkb ? *pb : 0;
+  bool checkb[TN, TK] = rkb[newaxis, :] < TK;
+  TYPE b[TN, TK] = *pb;
 #else
   TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
-  bool checkb[TK, TN] = rkb[:, newaxis] < K;
-  TYPE b[TK, TN] = checkb ? *pb : 0;
+  bool checkb[TK, TN] = rkb[:, newaxis] < TK;
+  TYPE b[TK, TN] = *pb;
 #endif
   for(int k = K; k > 0; k = k - TK){
     xc = USEA @ USEB + xc;

@@ -60,8 +60,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
 #endif
     checka = k > TK;
     checkb = k > TK;
-    a = checka ? *pa : 0;
-    b = checkb ? *pb : 0;
+    a = *pa;
+    b = *pb;
   }
   int rxc[TM] = ridx * TM + (0 ... TM);
   int ryc[TN] = ridy * TN + (0 ... TN);

@@ -70,8 +70,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
   bool checkc0[TM] = rxc < M;
   bool checkc1[TN] = ryc < N;
   bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  *?(checkc) pc = c;
+  *pc = c;
 }
 )";

 }
 }
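This kernel change drops all load masking (checka ? *pa : 0 becomes *pa) and the predicated store *?(checkc), so every tile access is unconditional. That is only safe when the problem tiles exactly, i.e. M, N, K are multiples of TM, TN, TK, which the power-of-two shapes exercised in this commit satisfy. A hypothetical host-side guard for that assumption (not part of this commit):

#include <cassert>

void check_exact_tiling(int M, int N, int K, int TM, int TN, int TK) {
  assert(M % TM == 0 && "unmasked C stores need M to be a multiple of TM");
  assert(N % TN == 0 && "unmasked C stores need N to be a multiple of TN");
  assert(K % TK == 0 && "unmasked A/B loads need K to be a multiple of TK");
}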

@@ -9,6 +9,9 @@
 #include "src/dot.h"
 #include "cuda/cublas.h"

+namespace drv = triton::driver;
+namespace rt = triton::runtime;
+
 template<class T>
 void diff(const std::vector<T>& x, const std::vector<T>& y){
   for(size_t i = 0; i < x.size(); i++)

@@ -44,16 +47,44 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
     cpu_ref<T, false, false>(c, a, b, M, N, K);
 }

-struct perf_t {
-  double triton;
-  double cublas;
-};
+inline size_t ceil(size_t x, size_t y) {
+  return (x + y - 1) / y;
+};

-namespace drv = triton::driver;
-namespace rt = triton::runtime;
+inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
+  return [M, N](const rt::function::options_t& x) {
+    return rt::grid_t{ceil(M, x.D<int>("TM")),
+                      ceil(N, x.D<int>("TN"))};
+  };
+}

-perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
+namespace aux{
+template<std::size_t...> struct seq{};
+
+template<std::size_t N, std::size_t... Is>
+struct gen_seq : gen_seq<N-1, N-1, Is...>{};
+
+template<std::size_t... Is>
+struct gen_seq<0, Is...> : seq<Is...>{};
+
+template<class Ch, class Tr, class Tuple, std::size_t... Is>
+void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
+  using swallow = int[];
+  (void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get<Is>(t)), 0)...};
+}
+} // aux::
+
+template<class Ch, class Tr, class... Args>
+auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
+    -> std::basic_ostream<Ch, Tr>&
+{
+  os << "(";
+  aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
+  return os << ")";
+}
+
+
+bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
   typedef half_float::half NumericT;
   std::string ty = "half";
   size_t dt_nbytes = sizeof(NumericT);
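The aux:: block added above is the classic pre-C++14 index-sequence idiom: gen_seq<N> expands to seq<0, ..., N-1>, and the braced swallow array evaluates one << per tuple element in order, comma-separating all but the first. It exists so the new test loop can print a config tuple directly. Usage, assuming the definitions above are in scope:

#include <iostream>
#include <tuple>

int main() {
  auto c = std::make_tuple(false, true, 128, 128, 128, 16, 16, 32, 4);
  std::cout << "Testing " << c << std::endl;  // (0, 1, 128, 128, 128, 16, 16, 32, 4)
}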

@@ -71,12 +102,12 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
     hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
   for(size_t i = 0; i < hc.size(); i++)
     hc[i] = static_cast<NumericT>((double)0);
-  drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
-  drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
-  drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
-  stream->write(da, true, 0, ha);
-  stream->write(db, true, 0, hb);
-  stream->write(dc, true, 0, hc);
+  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
+  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
+  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
+  stream->write(&*da, true, 0, ha);
+  stream->write(&*db, true, 0, hb);
+  stream->write(&*dc, true, 0, hc);
   stream->synchronize();
   // run
   rt::function::options_space_t opt;

@@ -85,81 +116,50 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
     opt.defines.push_back({"AT", {""}});
   if(BT)
     opt.defines.push_back({"BT", {""}});
-  opt.defines.push_back({"TM", {"16", "32", "64", "128"}});
-  opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
-  opt.defines.push_back({"TK", {"32"}});
-  opt.num_warps = {1, 2, 4, 8};
+  opt.defines.push_back({"TM", {std::to_string(TM)}});
+  opt.defines.push_back({"TN", {std::to_string(TN)}});
+  opt.defines.push_back({"TK", {std::to_string(TK)}});
+  opt.num_warps = {nwarp};
   rt::function function(src::dot, opt);

-  auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
-  auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D<int>("TM")), ceil(N, x.D<int>("TN")), 1}; };
-
-  auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
-  perf_t res;
-  res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
-  NumericT alpha(static_cast<double>(1));
-  NumericT beta(static_cast<double>(0));
-  cublasGemmAlgo_t fastest;
-  cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, da, lda, db, ldb, &beta, dc, ldc, &fastest);
-  res.cublas = tflops(triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
-                                                              &alpha, da, lda, db, ldb, &beta, dc, ldc, nullptr, fastest); },
-                                           stream));
-
+  try {
+    function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);
+  } catch (const std::runtime_error& e) {
+    return true;
+  }
   // test
-  // stream->read(dc, true, 0, hc);
-  // std::vector<NumericT> rc(hc.size());
-  // cpu_ref(AT, BT, M, N, K, rc, ha, hb);
-  // for(size_t i = 0; i < M*N; i++)
-  //   if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
-  //     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-  //     exit(EXIT_FAILURE);
-  //   }
-  // std::cout << hc[0] << " " << std::endl;
-  // std::cout << "Pass!" << std::endl;
-
-  // clean-up
-  delete dc;
-  delete da;
-  delete db;
-  return res;
+  stream->read(&*dc, true, 0, hc);
+  std::vector<NumericT> rc(hc.size());
+  cpu_ref(AT, BT, M, N, K, rc, ha, hb);
+  for(size_t i = 0; i < M*N; i++)
+    if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2)
+      return false;
+  return true;
 }
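do_test now accepts a result when every element is finite and within 1e-2 relative error of the CPU reference, a reasonable tolerance for half-precision inputs. A standalone sketch of the predicate (it mirrors the check above, including dividing by max(hc, rc) rather than a symmetric denominator):

#include <algorithm>
#include <cmath>
#include <vector>

template<class T>
bool allclose(const std::vector<T>& x, const std::vector<T>& ref, double tol = 1e-2) {
  for (size_t i = 0; i < x.size(); i++) {
    double a = x[i], b = ref[i];
    if (std::isinf(a) || std::isnan(a)) return false;    // reject non-finite
    if (std::abs(a - b) / std::max(a, b) > tol) return false;
  }
  return true;
}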

 int main() {
-  struct config_t{
-    bool AT;
-    bool BT;
-    int32_t M;
-    int32_t N;
-    int32_t K;
-
-    std::string repr() {
-      std::ostringstream oss;
-      oss << AT << " " << BT << " " << M << " " << N << " " << K;
-      return oss.str();
-    }
-
-    perf_t perf(triton::driver::stream *stream){
-      return do_bench(stream, AT, BT, M, N, K);
-    }
-  };
-  // shapes to benchmark
-  std::vector<config_t> configs = {
-    // {false, false, 8192, 512, 512},
-    {false, true, 128, 128, 128}
-    // {false, true, 128, 128, 128},
-    // {false, false, 128, 128, 128},
-    // {true, false, 128, 128, 128},
-    // {true, true, 128, 128, 128}
-    // {false, true, 32768, 256, 512}
-    // {true, false, 8192, 512, 512},
-    // {true, true, 8192, 512, 512}
-  };
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context);
+  // shapes to benchmark
+  typedef std::tuple<bool, bool, int, int, int, int, int, int, int> config_t;
+  std::vector<config_t> configs;
+  for(bool AT: std::array<bool, 2>{false, true})
+  for(bool BT: std::array<bool, 2>{false, true})
+  for(int TM: std::vector<int>{16, 128})
+  for(int TN: std::vector<int>{16, 128})
+  for(int TK: std::vector<int>{16, 32})
+  for(int nwarps: std::vector<int>{1, 2, 4, 8}){
+    configs.push_back(config_t{AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
+  }
   // does the work
-  for(config_t c: configs){
-    perf_t perf = c.perf(stream);
-    std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
+  bool AT, BT;
+  int M, N, K, TM, TN, TK, nwarp;
+  for(const auto& c: configs){
+    std::tie(AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
+    std::cout << "Testing " << c << " ... " << std::flush;
+    if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
+      std::cout << " Pass! " << std::endl;
+    else
+      std::cout << " Fail! " << std::endl;
   }
 }