From 50ff1aea86ef6a4f49668e865b5865842a67bb4e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 25 Feb 2021 14:49:47 -0500 Subject: [PATCH] [DOCS] Added Python 02-fused-softmax.ipynb tutorial --- python/tutorials/02-fused-softmax.ipynb | 308 ++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 python/tutorials/02-fused-softmax.ipynb diff --git a/python/tutorials/02-fused-softmax.ipynb b/python/tutorials/02-fused-softmax.ipynb new file mode 100644 index 000000000..47a159208 --- /dev/null +++ b/python/tutorials/02-fused-softmax.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "induced-zoning", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "median-malaysia", + "metadata": {}, + "source": [ + "Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "precise-professor", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Compute the row-wise softmax of x \\in R^{M \\times N}\n", + "def naive_softmax(x):\n", + " # read MN elements ; write M elements\n", + " x_max = torch.max(x, axis=1)[0]\n", + " # read 2MN elements ; write MN elements\n", + " z = x - x_max[:, None]\n", + " # read MN elements ; write MN elements\n", + " numerator = torch.exp(x)\n", + " # read MN elements ; write M elements\n", + " denominator = torch.sum(numerator, axis=1)\n", + " # read 2MN elements ; write MN elements\n", + " ret = numerator / denominator[:, None]\n", + " # in total: read 7MN elements ; wrote 3MN + 2M elements\n", + " return ret" + ] + }, + { + "cell_type": "markdown", + "id": "gorgeous-monday", + "metadata": {}, + "source": [ + "When implemented naively in pytorch, computing $y$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\n", + "\n", + "Instead, we want to write a custom \"fused\" pytorch operators that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory." + ] + }, + { + "cell_type": "markdown", + "id": "identical-conditions", + "metadata": {}, + "source": [ + "# Writing the Compute Kernel" + ] + }, + { + "cell_type": "markdown", + "id": "prepared-apparatus", + "metadata": {}, + "source": [ + "Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:\n", + "\n", + "```c\n", + "__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){\n", + " // row index\n", + " int m = get_program_id(0);\n", + " // column indices\n", + " int n [BLOCK] = 0 ... BLOCK;\n", + " // the memory address of all the elements\n", + " // that we want to load can be computed as follows\n", + " float* px [BLOCK] = X + m*stride_xm + n;\n", + " // because BLOCK has to be a power of two\n", + " // (per Triton-C specs), it is important\n", + " // to guard each memory operation with predicates\n", + " // or we will read out of bounds\n", + " bool check[BLOCK] = n < N;\n", + " float x [BLOCK] = check ? *px : -F32_INFINITY;\n", + " // syntax for reduction in Triton is:\n", + " // x[..., OPERATOR, ...]\n", + " // ^\n", + " // index\n", + " // The operators currently supported are {min, max, +}\n", + " float z [BLOCK] = x - x[max];\n", + " // The exponential in Triton is fast but approximate \n", + " // (i.e., like __expf in CUDA)\n", + " float num [BLOCK] = exp(z);\n", + " float denom = num[+];\n", + " // The result of the reduction is now stored in y\n", + " float y [BLOCK] = num / denom;\n", + " // We write it back\n", + " float* py [BLOCK] = Y + m*stride_ym + n;\n", + " *?(check)py = y; \n", + "}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "forbidden-wednesday", + "metadata": {}, + "source": [ + "# Writing the Torch bindings" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "former-pottery", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import triton\n", + "\n", + "# source-code for Triton compute kernel\n", + "_src = \"\"\"\n", + "__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n", + " int m = get_program_id(0);\n", + " int n [BLOCK] = 0 ... BLOCK;\n", + " float* px [BLOCK] = X + m*stride_xm + n;\n", + " bool check[BLOCK] = n < N;\n", + " float x [BLOCK] = check ? *px : -F32_INFINITY;\n", + " float z [BLOCK] = x - x[max];\n", + " float num [BLOCK] = exp(z);\n", + " float denom = num[+];\n", + " float y [BLOCK] = num / denom;\n", + " float* py [BLOCK] = Y + m*stride_ym + n;\n", + " *?(check)py = y; \n", + "}\n", + "\"\"\"\n", + "\n", + "# We need to make sure that BLOCK is the smallest power of two\n", + "# greater than the number of rows N of the input matrix.\n", + "# Different values of BLOCK will result in different kernels\n", + "def next_power_of_2(n):\n", + " n -= 1\n", + " n |= n >> 1\n", + " n |= n >> 2\n", + " n |= n >> 4\n", + " n |= n >> 8\n", + " n |= n >> 16\n", + " n += 1\n", + " return n\n", + "\n", + "_kernels = dict()\n", + "def make_kernel(N, device):\n", + " BLOCK = next_power_of_2(N)\n", + " key = (BLOCK, device)\n", + " if key not in _kernels:\n", + " defines = {'BLOCK': BLOCK}\n", + " _kernels[key] = triton.kernel(_src, device=device, defines=defines)\n", + " return _kernels[key]\n", + "\n", + "class _softmax(torch.autograd.Function):\n", + " \n", + " @staticmethod\n", + " def forward(ctx, x):\n", + " # constraints of the op\n", + " assert x.dtype == torch.float32\n", + " y = torch.empty_like(x)\n", + " # *create launch grid*:\n", + " # here we just launch a grid of M programs\n", + " M, N = y.shape\n", + " grid = lambda opt: (M, )\n", + " # *launch kernel*:\n", + " kernel = make_kernel(N, y.device)\n", + " kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid = grid)\n", + " return y\n", + " \n", + "softmax = _softmax.apply" + ] + }, + { + "cell_type": "markdown", + "id": "exclusive-salvation", + "metadata": {}, + "source": [ + "# Writing a Unit Test" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "pretty-prospect", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.0054, 0.0004, 0.0007, ..., 0.0004, 0.0019, 0.0018],\n", + " [0.0008, 0.0014, 0.0006, ..., 0.0023, 0.0019, 0.0012],\n", + " [0.0009, 0.0003, 0.0001, ..., 0.0010, 0.0001, 0.0012],\n", + " ...,\n", + " [0.0003, 0.0003, 0.0036, ..., 0.0002, 0.0003, 0.0013],\n", + " [0.0025, 0.0008, 0.0004, ..., 0.0016, 0.0007, 0.0005],\n", + " [0.0003, 0.0026, 0.0004, ..., 0.0005, 0.0009, 0.0005]],\n", + " device='cuda:0')\n", + "tensor([[0.0054, 0.0004, 0.0007, ..., 0.0004, 0.0019, 0.0018],\n", + " [0.0008, 0.0014, 0.0006, ..., 0.0023, 0.0019, 0.0012],\n", + " [0.0009, 0.0003, 0.0001, ..., 0.0010, 0.0001, 0.0012],\n", + " ...,\n", + " [0.0003, 0.0003, 0.0036, ..., 0.0002, 0.0003, 0.0013],\n", + " [0.0025, 0.0008, 0.0004, ..., 0.0016, 0.0007, 0.0005],\n", + " [0.0003, 0.0026, 0.0004, ..., 0.0005, 0.0009, 0.0005]],\n", + " device='cuda:0')\n", + "True\n" + ] + } + ], + "source": [ + "x = torch.randn(1823, 781, device='cuda')\n", + "y_tri = softmax(x)\n", + "y_ref = torch.softmax(x, axis=1)\n", + "print(y_tri)\n", + "print(y_ref)\n", + "print(torch.allclose(y_tri, y_ref))" + ] + }, + { + "cell_type": "markdown", + "id": "regular-andrew", + "metadata": {}, + "source": [ + "Seems to work!" + ] + }, + { + "cell_type": "markdown", + "id": "polished-australia", + "metadata": {}, + "source": [ + "# Writing a Benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "chubby-audit", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAABCZUlEQVR4nO3dd3wUdfrA8c+zJYVUSEJLQu9Krwo2bICIFct5guXk9NSz9zsPz7vf2dt5p2LvvSGWUwFBQaqAVKVICTUEQgKkbfb7++M7WZIQUiCbzSbP29e+dnZ2Zuf5xmWfmfk2McaglFJKAbhCHYBSSqn6Q5OCUkqpAE0KSimlAjQpKKWUCtCkoJRSKsAT6gCORHJysmnXrl2ow1BKqbCycOHCncaYlIreC+uk0K5dOxYsWBDqMJRSKqyIyIZDvae3j5RSSgVoUlBKKRWgSUEppVSAJgWllFIBmhSUUkoFBC0piEi6iEwXkRUislxEbnDWNxORb0RktfPc1FkvIvKUiKwRkZ9FpF+wYlNKKVWxYF4p+IBbjDE9gCHAtSLSA7gTmGqM6QxMdV4DjAQ6O48JwDNBjE0ppVQFgtZPwRizFdjqLOeKyEogFTgLONHZ7FXgO+AOZ/1rxo7lPUdEEkWklfM5Kgzl5BcxbeUOCnzFnNqjJc1iIkId0hHbva+QvKJiYiI8NIl043XX7R1YYww+vyGvqJj8wmLyipyHs1xQ5KfAV0y+81zg85NfVEyxH6K8LqK9bqK8bqK8LqK8bmIiPfRKSyDS467Tcqj6q046r4lIO6AvMBdoUeqHfhvQwllOBTaV2i3DWVcmKYjIBOyVBG3atAle0A3cmh25fLF0GwnRXhKbeGkWE0HTJhEkNvHStEkELhGKjcFvDH6/odhv8BuI9LqIi/QgIhV+7p68Ir5dsZ0vlm7l+9U7KSz2A3DPx8sY1jmZMb1bc9pRLYmNrN/9Jo0xbNmTz/LNe1i+JYflW+zz1j35ZbbzuoUmER5iItxER9gf3EiP66BnEcHnNxT7/RQ7f0+f81xU7KeouOyzz1kuLPY7rw8s1/YUKK0TorjxlC6c2y8VTx0nOVX/BP1fpojEAh8CNxpjckr/mBhjjIjU6CtujJkETAIYMGCAzhB0mB79+le+XLbtsPaNiXDTKjGaVglRziOa+GgvP6zO5Ic1OykqNrROiOLSY9oyqmcrorwuPluylc+WbOHm95YQ6VnK8G7NGXF0S7xuF3vyisjeX0R2XiE5znKkx0XHlFg6No+lU/NY2iY1qdHZbE5+Eau25rJiyx7W7dxHoc/5MXaSnM9vE16hz9gz6iI/+b5i8ovsWfaevCL25BUB4BLokBLLoPbNOKp1PPFRXvYVFrO/wMf+Ivu8r7CY/YU+50zdnp3n5vsCZ+0AHpfgcgkel+Au9fC6XUR5XcRFefC4XER4xHl24XXb90seEc7raCcJRXvtIyrCTZTnwBVApMdFZKnE5BLIL/I75bMx5RUVsyMnn2dnrOX2D3/muZlrufW0row4uuUhk35jsLfAx7rMvazL3Eduga/ijcyBpO53lu3JEwzv1pyeaQnVOtbufYXc9dFSMvcWkJoYTWrT6MBzWmI0SbGRFDrfpzzn/13JFWFa02g6t4irxZJbEsyZ10TEC0wB/meMecxZ9wtwojFmq4i0Ar4zxnQVkeec5bfLb3eozx8wYIDRYS5qzu839L3/G07u3px7RnVn9/5Cdu8vYte+QrL3F5K9vwi/AbcLXCK4xP54uVxCXqGPrXvy2Zqdz9Y9eWzZk8/OvQUYA6mJ0ZzRqxUjj25Jn/TEg35YjDH8tHE3kxdv4fOlW9m5t7DM+xFuFwlNvCREe8krLGZzdl7gPbdLaNOsCe2SmpAQ7aVJpD07bxLhISbSTXSEh6y9BazcmsOKrTls2nVg37goD9Fety2DU5aSH+iSH+SSH8+SH9bYKA9dW8TRo3UC3VvF0SSifl/ZHAljDP9bvo1Hvv6VNTv20istgdtO78qwTsk1Sg5FxX72FfjIyfORk19kH3k+cvOLyM33lbnNVfpHrthviPC48bqFSM+BBOhxy4EfxMCtMj/5hcVEel20ToimdWI0rROjSE20yylxkYGEnOMcN9eJo8C5ai2voKiYDVn7WZu5l7WZe9meU3BEf88It4t/nduT8/qnVbrdpl37Gf/yPDJ259GvTSJbnH9TRcXV+03+4wkduGtk98OKUUQWGmMGVPhesJKC2G/Tq8AuY8yNpdY/DGQZYx4QkTuBZsaY20XkDOA6YBQwGHjKGDOosmNoUjg8yzbvYfS/f+CJC/twdt/UI/68Qp+f3fsLaR4XWe0fEV+xnxVbc/C6XSQ6iSDa6y6z//5CH+sy97E2cy8btmaSsP4LUnfPZw49+bJ4ELsKPeQVFQe2F4H2STF0bx1Pj1bx9HCeaxJXY+Yr9vPRos08+e1qNmfn0SE5hiaRbtwuF24Bj8uFy2UTdF5hMfsLi9lb4As8F/oq/tEtzSXYqxvnVlu0141LhKJiP4XFfgp99hZZoc9Pkd8Q6XEFti9dH5JX5GdLdh6ZuUf2A14iPspDx+axdEiOpUNKjL1KTYkhscmh68HKX/m5RNhX4OPat35i9tosrjmxI7ed1hWX6+Dv3vIte7js5fkUFBXz4mUDGdiuGQDFfsOO3Hw2785jc3Yeu/YVEulxEx3hck5YDvwNWiVE0zIh6rDKG6qkMAz4HlgKlHxb7sbWK7wHtAE2ABcYY3Y5SeRpYASwH7jcGFPpL74mhcPz3Iy1/OvLVcy9+2RaxB/el6pOGAMb58DiN2D5J1C4FyJi7XNkAvS6AH/fS8lLOop9BT5iozwN+oy+rhT4inlr7kZmr80K1H3YW25+/H4oNoZor5smEW5iI22Fe0ykh5gIDzGRHuKjPMRHe4mP8hIf7SE+ykuc8//G65ZaTdAFvmK27clnc3YeW7LzycwtoEmEm7ioA8eNc54jvS6Eg4/tdQsJ0d5ai6uo2M/fJi/nrbkbOa1HCx6/sA8xperQfli9k6vfWEh8lIdXrxgUlFtAVQlJUqgLmhQOz7iX5rElO49vbz4hdEHszYQ3zwOXB2JbQlwLiHUecS1hxwpY9CbsWmsTwVFnQ5/fQ/pg2DgbFr4KKz6F4gJo3Rf6jYNuoyEmxV4yKBVCxhhemb2e+6esoGvLeF4cP4DWidF8smgzt76/hE7NY3nl8kGHfaZ/pDQpqIBCn5/e933NBQPSuO+so0MXyOK34JNroO1QyMuGvdtgf1bZbdoOhT6XQI+zIDL24M/I2w0/v2cTxI7ldl1UIiR3gZQu9jm5CyR1hqZtwe0NdqmUKmP6Lzu4/q1FRHndjOndmpdm/cYxHZJ4blx/4qNC932sLCnotXYjs2jjbvKKijm2U3JoA9kwC6Kbwvgp4HKaQfoKYd8OyN0OMUnQtF3lnxHdFAb/EQZNgC0/waZ5sPNX2LkaVn8Di944sK3LA4ltIakTJHW0j2Yd7THiWoG3Ht9GU2HrpK7N+ehPx3Llq/N5adZvnNm7NY+M7VWv+4VoUmhkZq3NwiUwpENSaANZP8teCbhKtYv3REBCmn3UhAik9reP0vJ2w841NlHsWgtZayBrHaz/Hor2l922SRLEt4b4VPsc2wI8keDy2isMlwfcEfbR6WSICXFSVWGjS4s4Jl87jLm/ZXFaj5YVVjzXJ5oUGpnZa3bSMzWBhOgQ3krJ2QK7f4NBVwX3ONFNIX2gfZRmDORutUkie5ONJ2ezfd6zGTLmH3wrq7SU7vDHGTZpKFUNTWMiGHF0q1CHUS2aFBqRfQU+Fm/K5qrjO4Q2kA2z7XPboaE5vohzVdD60NsU+6C4EPxFdtlfBMVFsGkufHglfP8onHR33cWsVB3RpNCIzFu/C5/fMLRjPahPiIiDlj1DG0dl3B77KC8x3dZXfP8odD+zfpdBqcOgA500IrPX7CTC7WJAu6ahDWT9LGgzBFz1t7KtUiP+BdHN4NNr7VWEUg2IJoXyfptpW680QLPWZNGvbSJR3hD+GO/NhJ2/QLsQ3TqqDU2awRmPwNYlMPupyrfN3gQvj4JPr7NNb5Wq5zQplJaxEF4/B766s+ptw8yufYWs2JoT+ltHG0Ncn1BbepxlH989AJm/VLxNxgJ4fjhsWWz7Zfz3GHvrSal6TJNCiYJcW4Ho99n27v7iqvcJIz+uta1pQt8/YTZ4m0CrPqGNozaMegQimtirgPLfl6Uf2CsEbzRcNQ3+8C1EJcCb58Mn1+pVg6q3NCmU+OI2yN4AA66AghzYvizUEdWqWWt3EhvpoXc1h/QNmvWzIG2g7ZMQ7mKbw4gHIWMezJtk1xkD0/9lTzBS+9mE0LybXf7jDDjuFljytnPV8G1o41eqAtr6CODn9+0/1BPuhL6/hwUvwYYfoVXvoByu2G+46d3FjOrZss7aLs9es5PB7ZuFdhKVvN022Takppy9LoBlH8K390GHE2HGQ7D8I+j9OzjzibJ9GTyRcPK9doymT66xYz91HWUrrYsL7ThOxUXgKwBTDMdeD51OCVXJVCOlVwq7foMpN0H6EDj+Nh6eu4/90a0O3PsOgik/b2Hyki08O2Nd0I5R2ubsPNZn7Q/9raONcwADbY8NbRy1SQRGP257PT97HCz/GE65D87+76E7t6X2gz/OhGE3w5ZFsO4722FuxypbMZ2fDZm/wgdXwJ6MuiyNUo38SqG4CD78A4gLznuer1ft5D/T1zIgvisnbfjR3gqo5RE3fcV+nvx2NSKweFM2Gbv3k9a0Sa0eo7zZa3YCMLRTiIe22DDLDhORWuE4XOErIRVGPQz/uxvOfAq6j656H08knPI3+6hI1lqbZD75E1z6SdnhQJQKosb9TfvuAdi8AM58gn3RrZk4eTki8M2+jnZgtl21fyb/yeItrNu5j7tGdgPgy6WHNyVmTcxem0VSTARdQzBuexnrZ9mE0BAHn+t9Edy2tnoJoTqSOsLp/4TfZhyor1CqDjTepPDb97ZXat/fw9Hn8vg3v7I1J597RnVnnr+r3WZD7d5CKir289TU1RzVOp6rjuvA0anxfL70kLON1gpjDLPW7OSYjkmhnX2sINe26w/n/glVqe2/b//LoPPp8O3fDt3sValaFrSkICIvicgOEVlWal0fEZkjIotFZIGIDHLWi4g8JSJrRORnEekXrLgA2L8LPppgz8ZGPsTyLXt4efZ6Lh7UhsuObcf2iLbscyfAxh9r9bAfLsxg46793HJye2TWE4xrtydwCylY1mbuZUduAUNDXZ+waa6tPG1I9QnBJgJj/m2b8H50lR1a/FCy1sJ74+DH/2gva3VEgnml8Ap2as3SHgLuM8b0Ae51XgOMBDo7jwnAM0GMC9ZOs6NgnvcixZ4m3P3xMpo28XLH6d3wuF0M6ZjMT3St1SuFAl8x/562hsFpkZy06Eb4diJjtj0NBPcW0qw1tn9CyDutbZhth59OHxzaOMJNXAs480l7lTXzoYPf9/th7iR4dhj88qWt15h0glOpr1TNBS0pGGNmArvKrwbineUEYIuzfBbwmrHmAIkiEry2mj3PhxuWQOs+vDV3A0s2ZfPX0T1IaGKHkx7WKZkZ+Z3t8M65tfOD/d78TeRnb+N5/0Rk7VRodxxRm3/ktJa5Qb2FNGvNTtKaRtMmKbiV2VVaP8t2WIuICW0c4ajHGNvE9ftHYdP8A+t3b4DXxsCXt9krsBuWwIVv2Ka/L51ux2baV8kQ4EpVoK7rFG4EHhaRTcAjwF3O+lRgU6ntMpx1BxGRCc6tpwWZmZmHH0l8K3bk5PPQV78wrFMyY3ofGEZ5aKdk5tdivUJ+UTEfT/uBKTH3E5ezBi56C857AcTNNXGzgnYLadOu/fy4Liv0VwlFebB5YcOuTwi2kQ/YCYA+ngAFe2HhK/DMsXYIjTOfgks+sEOBdz8Trp0Hx/4ZlrwDT/e305X6/dU7TuF+WP4JvHspvDkWtv4cxEKp+qiuk8I1wE3GmHTgJuDFmn6AMWaSMWaAMWZASkrKEQVz35QVFBT7+cfZR5ephO2YEsPO2G4USFSt1Cv875svmVR4F8mePGT8Z9B1pJ2cvutIeu38Ai++Wr2FZIzh7XkbGfHETDBw0aD0Wvvsw5Ix385HEO7jHYVSVAKc86ztV/P0APjsBtvf4U+zof/4spXckbFw2v3wx+/thECf/dkmkMl/hnnP246Z+TkHtvcVwi9fwYdXwSOd4f3x9vbT5oX2VtSXd5bdXjVodd1PYTxwg7P8PvCCs7wZKP3LleasC5rvftnB5z9v5eZTu9AuuewtDRFhSOeWLF7RiUEbZnMkbUoKVn3NqfOuZJ8nAe9Vn0Ny5wNv9huHe9UULk9eyZSlybUy+c2OnHzu+PBnpv+SybEdk3h4bG9SE6OP+HOPyIbZgNjhstXhazcMht0Ec5+14y4NuLLy/gstesDlX9je+ovfgpWT4adXD7xfMmf15gWQv8fOVHf0efbRbpgd7mXq3+3xVnwCp/8fHHVO7beyUvVKXSeFLcAJwHfAcKBkjOrJwHUi8g4wGNhjjAnajfa8wmL++ukyOqTE8McTKv4hHtopiR9/7sKg7R/bwcuiE2t+oJ9ew/PZTaz1p1J0wXuklE4IYIcwiGvN770zmLSp5xF3ZJvy8xb+8sky8gqLmXhmD8Yd065+zAe7/gc7GU1UiMddaghOvhdOvKv6Y0eJQJ/f2YcxdsrR7ctg21LYvtzOX911lE0EHU60PbNLRDe1vbX7/B4+vwk+uBwWvW4TUlLHoBRPhV7QkoKIvA2cCCSLSAbwN+Aq4EkR8QD52JZGAF8Ao4A1wH7g8mDFBTB5yWY27crj7auGEOmpeG6BoZ2S+cDfDcHYUVO7nFb9AxT74Ju/wpz/Mo9evNzm70zq0e3g7Vxu6Pt70mc+TCoX8uXSbYd1tbA9J59/fr6SyUu20Ds9kccu6E3HlFjbY/vTm6D/5ZDWv+oPCgZfob19NOCK0By/oRE5/MEERWzv64RU6HJ69fdL6w9XTYf5L8K0++G/Q+x36tjr7Ux0qkEJWlIwxlx8iLcO+nUyxhjg2mDFUt4FA9Lp3CKOfm0OPQNZi/gocpP64Mt149n4Y/WTQl62HbNm7VSmJZ7LVdvO4cPT+x56+36XIjMf5tqmc3l3aedqJ4Wc/CK+WrqNT5dsZvbaLNwi3HJqF645seOBQe9+fs+e2e1aZ28jVIffbwd06zjcTiZzpLb8BL587Z8Q7lxuGDzBtoSadj8seNE+el0IQ2+ElC6hjlDVkkY59pGIVJoQSvTvnMbyhe3ouWF29Wrkd66Bty/C7F7Py81u4u9bBvKXM7rTJz3x0PsktoGOwzkzYxp/2TSKTbv2k96s4ltIBb5ipq/K5NPFm5m6ageFPj9tk5pw/fDOnNs3tWzdiL8YfngMXF475lDGwupdLfzyuR32uesZcPFb1Sl15TbMss9tNCk0CHEt4az/2BGFZ/8bfnrN1ld0PxOOuxlaV3ICpMJC4x3mohqGdUpmbnE32PwTFOVXvvHaafDCcPz7dzEx8f+4f+tAHji3J384rhpn/v3GEVewjeNdS/hyWcVVKdNX7WDYg9O5+o2FzF+/i98NasMn1w7lu1tPrLCynBWfQtYae084MgFmP1l1HH6/HQ/K5bXJYdXnVe9Tmf27YMEr0LIXxIR4MD5VuxLTYdRDcONSmwzWzYBJJ8LbF9da3x4VGpoUKjG4QzMWmm64/IX2NsihzH8B3jgfX2xrrop8kDe3pfHvi/ty0aA21TtQ11HQJJkJsT/webmmqXmFxfz1k2Vc/sp8kmIieOXygcy562QmjjmKPumJFY9nZAx8/xgkdbYVjAOvgJWfVT3A36opthJy9OPQ/Cj44nbbJv5w+P12KJG92+y8Aqphik2xld83LYXhf7UnR88ce+QnFCpkNClUIi7KS37rQfbFoTqxLXoTPr+F/HbDOTv/Xn7YGcvz4wYwulfrireviCcC+lzM4KL5bNm0nk27bEe2ZZv3cObTP/D6nA38YVh7Pr1uKCd2bV71RDmrv4btzhmcyw2Dr7ZDTPz4n0Pv4/fDjAehWUfofbFNDDkZ8N2/ql+O0r5/BNZ8AyMegNQQVXKruhOVAMffaueJiE+Fd35n+0UU7gt1ZKqGNClUoXeXDvzqT6XotwqSwq9fw+TryUs/ntO3XMX6XDevXTGIk7o1r/mB+o3HbXyc557J50u38uyMtZzz31nk5hfxxpWD+cvoHodsKVWGMTDzEUhoAz3H2nVxLe0MYYvePPSwByVXCSfcAW4PtBlsR+mc84xtvlgTa6fB9P+zlZDa6qhxSekKf5gKQ2+w9Q3PHmc7wamwoUmhCsM6JTPf382O8ll6cvaMhfD+eAqSezBy6wRyioS3rxrC4A6Hee88uTO0Hcq4yJk89NVKHvhyFad0b8FXNxzPsM41GKZi/fd2zuChfy7b5vzYP4MvD+Y/f/A+JXUJSZ1se/USJ//NtlWfclP1h0nYkwEfXAnNu9urDe3o1Ph4IuDUv8P4z2zLsxdPg5kP6+itYUKTQhX6pCeyxNUdr2+vPZMGO0zxW2PxNUnhvJybyTFRvDPhGHqmHWHnrH7jaO3fwgkRv/LQ+b347yX9aBpTwzbpMx+B2BbQ99Ky61O6QpcRdsKWwnLjLK36DHYsh+Nvt1cJJZo0s71YM+bDT69UfWxfIbw33vaPuOB1HfyusWt/HFwzC3qcBdP+Af8ZZOezru4JhgoJTQpViPC4KE4/xr7Y8CPkbofXz8FvYFzhHWwsiOG1KwbRtWUtzGrW4yxMZDwv9VzBBQPSaz4pTsYCO1PXMddVPLvZsX+2Q4YvKdXU1O+H7x60Vwk9zz94n14XQPvj4duJsHdH5cf/+h47ZMLZ/4HkTjWLXTVM0U3hvBfh4nfsFKQfXAHPHQ+//s/e6lT1jiaFaujerQcZJpn8lV/BW2MxezO5Tu5kyb4kXr1iEEen1tLwDd5opPdFyPIPbcudms62NfMR+4/wUPfx2x5rK31nP33gVljJVcIJd9hK6fJE4IzH7Ein/7vn0Mf++X17FXLMdfbMUKkSInYQyKt/gHOfh8JceOsCeGmEHVJd1SuNsvNaTQ3rnMz8r7tyzobpGHHz1yZ/YVpOOq9ePpC+1egEVyPD/2ont1/wku2R3OMsOP42aHl05fttWwa/fgkn3m1HyayIiL1aeH+8rVjudmbFdQnlJXe2A7HNeBC6nQExKbZ5a+nHjpXQ5hg4ZeJhF101cC63vfI86hxbCT3jIXhlFPQ4G855rmHO3R2GNClUQ9cWcbzr7c05/lk8GXM972V356XxAw+/UrkyUfF2wvZhN8Oc/9hZtVZ8YnsYH3+rHS65Ij88BhFxdiiCynQ/E5q2h1lPgfHDjhX27K2iq4TSht0MS9+3CaWEywtN20KzDvYW09AbylZuK1URtxcGXmmbPv/4NEz/J+Rn23lGtB4q5MSE8X29AQMGmAULFtTJsW54awGLlv7MFmnJc5f25+TuLerkuOTttolhzn/tP5wmSbYdeEKa85xq24h/fou9Cjj1vqo/c97z8MWt0CTZ3m66dm7VSQHs7az1P0Cz9jYRxKeVrZhW6nAsfsvOEpc+BH73rj0xUkElIguNMQMqfE+TQvV8tWwr17+9iCcv6suonsGbKfSQCnLtP54dKyFnM+zZbDuX5e+x73ub2OkYY6vRR6JwPzxxtK10PvcF6DU2uLErVZVlH9pJflr3gd9/aE9WVNBoUqgleYXFREdU44y6LhXk2jHyPZHQtF3195v3vJ3o/ZL3q3eVoFSwrfoc3r/MNp++9BOICfE0sg2YJgWlVHhY/S28e4k9wRn3qe2Nr2pdZUkhaE1SReQlEdkhIsvKrb9eRFaJyHIReajU+rtEZI2I/CIiNZgBRCnVYHQ+xV69Zm+Cl0fZEYpVnQpmP4VXgBGlV4jIScBZQG9jzFHAI876HsBFwFHOPv8VEb2noVRj1P54uPRjW1/2/En2llLW2lBH1WgELSkYY2YCu8qtvgZ4wBhT4GxT0kX2LOAdY0yBMeY37LScg4IVm1KqnmszGP68yA698uv/7BAZn98KezNDHVmDV9c9mrsAx4nIXBGZISIDnfWpwKZS22U465RSjVVUPAy/xyaHfuNsh86n+tgOlwW5oY6uwarrpOABmgFDgNuA96SGA/yIyAQRWSAiCzIz9axBqQYvrqUdcffaedDpZDvHxzNDbS9+VevqOilkAB8Zax7gB5KBzUB6qe3SnHUHMcZMMsYMMMYMSElJCXrASql6IrkTXPAaXP4VFBfCi6fC8o9DHVWDU9dJ4RPgJAAR6QJEADuBycBFIhIpIu2BzsC8Oo5NKRUO2h4DE2ZAy562Evrb+8rOdaKOSDCbpL4N/Ah0FZEMEbkSeAno4DRTfQcY71w1LAfeA1YAXwHXGmP0/7JSqmJxLWD8FDs74A+PwVsX2iFh1BHTzmtKqfC24CX44nZITLeD6jXvHuqI6r2QdF5TSqk6MeAKO/VnwV544RQ7hIveTjpsmhSUUuGv7TEw4TtIG2BHAH7hFNiyONRRhSVNCkqphiEh1Q6kd96LsCfD9ob+8k7Izwl1ZGGlysHwRSQKGA0cB7QG8oBlwOdOBbFSStUPInau8U6nwLT7Ye6zdpKqEQ/YWQxrOu95I1TplYKI3AfMAo4B5gLPYVsJ+YAHROQbEekV9CiVUqomohPhjEfhD1Pt9LHvj7ctlHK2hDqyeq/S1kcicoYx5vNK3m8OtDHGhKQJkLY+UkpVqdgH856Dqffb+c9H/B/0uaRRXzUcduujihKCiLhEJN55f0eoEoJSSlWL2wPHXAvXzIKWR9upP98839Y7qINUq6JZRN4SkXgRicHWJ6wQkduCG5pSStWipI62w9vIh2HDbPjvMfDTaxDGfbWCobqtj3oYY3KAs4EvgfbApcEKSimlgsLlgsET4JrZ0Ko3TL4e3jgXdq4OdWT1RnWTgldEvNikMNkYUwRoelVKhadm7WHcZBj1CGyaD/8dYntF78sKdWQhV92k8BywHogBZopIW0Ab/yqlwpfLBYOugj//ZOdrmP88PNUXZj0FvoJQRxcyVbU+OgaYY8pt5MyB4DbG+IIcX6W09ZFSqtbsWAlf/xXWfAOJbeGUiXDUOQ2yldKRjH00DlgoIu+IyGUi0hLAGdk0pAlBKaVqVfPu8PsP7PzQEbHwweXw8shGN5lPVU1SrzHG9AMmAk2BV0TkRxH5PxE5XkTcdRGkUkrVmY7D4erv4cynYOev8Nzx8NVdjWa4jGrVKRhjVhljHjfGjACGAz8AY7G9nJVSqmFxuaH/eLhuga1vmPMMPD0Qln7Q4Juw1mhAPBFpAhwFzDfGXH+oe1JKKdUgNGkGZz4BV021c0V/eCW8NgYyfwl1ZEFT1dhHY0RkvYj8JCKjgOXA08BSERlfxb4vicgOZ5a18u/dIiJGRJKd1yIiT4nIGhH5WUT6HUGZlFKqdqX2h6umwRmPwdaf4Zlj4fNbYe+OUEdW66q6UrgfOA34I3YgvJONMUOAXsCtVez7CjCi/EoRSXc+c2Op1SOx8zJ3BiYAz1QjdqWUqjsuNwy8Eq5fCP3G2xnfnuwD0/8FBbmhjq7WVJUU/MaYX40x84HfjDHrwI55hB0p9ZCMMTOBXRW89ThwO2U7v50FvOa0apoDJIpIq+oWQiml6kxMMox+DK6dB51PgRkP2P4N854HX2GooztiVSUFl4g0FZEkwO8sNxORZtXY9yAichaw2RizpNxbqcCmUq8znHUVfcYEEVkgIgsyMzNrGoJSStWO5E5wwWt2eO7krnbGt/8MgvWzQh3ZEalqkp0EYCFQ0nvjp1Lv1agK3qmkvht76+iwGWMmAZPAdl4r/35RUREZGRnk5+cfyWHCVlRUFGlpaXi93lCHolTjkDYALpsCa76FL26D9y+D6+bbOR3CUKVJwRjTrhaP1RE7kN4S2yGaNOAnERkEbAbSS22b5qyrsYyMDOLi4mjXrh3SAHsiVsYYQ1ZWFhkZGbRv3z7U4SjVeIhA51Nh7Ct2GtBvJ9pWS2GoqtZHbhGJLfV6iNNp7XgRiavJgYwxS40xzY0x7ZxkkwH0M8ZsAyYD45xWSEOAPcaYrTUvDuTn55OUlNToEgKAiJCUlNRor5KUCrnWfWDwNbDwZdgYnt24qqoXeBD4U6nXbwO3AX8F/lLZjiLyNvAj0FVEMkTkyko2/wJYB6wBni93zBprjAmhRGMuu1L1wkl3Q3waTLkRiotCHU2NVVWncDIwsNTrbGPMmc6AeN9XtqMx5uIq3m9XatkA11YRi1JK1X+RsXDGI/D2RTD7KTjullBHVCNVtj4qN/DdHRD4EY+teJfGLSsriz59+tCnTx9atmxJampq4HVhoW2uNnnyZB544AEAPvnkE1asWBHKkJVSta3rSOh+Jsx4CHatC3U0NVJVUogoXXdgjPkaQEQSgKhgBhaukpKSWLx4MYsXL+bqq6/mpptuCryOiIjA5/MxZswY7rzzTkCTglIN1siHwOWFKTcferwkY2DLYti5pk5Dq0xVt4+eB94VkauNMRsBnAl2ngFeCHZwDcVll11GVFQUixYtYujQofTq1YsFCxbwu9/9jsmTJzNjxgz+8Y9/8OGHH5Kbm8vVV1/N/v376dixIy+99BJNmzblxBNPZPDgwUyfPp3s7GxefPFFjjvuuFAXTSl1KPGt4eR74cvb7EB6vcYeeM8YWP01fP8obHIqpJM6Q7dR0PUM28zVFZpBqKtqkvqYiOwHfhCRGGf1XuABY0y9H4rivs+Ws2JL7Q5326N1PH8786ga75eRkcHs2bNxu9288sorABx77LGMGTOG0aNHc/755wPQq1cv/v3vf3PCCSdw7733ct999/HEE08A4PP5mDdvHl988QX33Xcf3377bW0VSykVDAOvhJ/fga/uhE4nQ1QCLP8Yfngcti+DhDYw8mG77S+fw4//gVlPQkwKdDkduo+BTqfUaYKo6koBY8yzwLMlt5GMMQ1nkI86NHbsWNzuyv/H7tmzh+zsbE444QQAxo8fz9ixB84uzj33XAD69+/P+vXrgxarUqqWuNww+gmYdCK8Nw5yNts6huQucPaz0PN8cDsdTQdPgLxs2wlu1eewYjIsesMmjgGX2yG8Y5KDHnKlSUFEfg+8ZYzxV5QMRKQj0MoY80OwAjwSh3NGHywxMTFVb1SFyMhIANxuNz6fTnynVFho1QuOuda2RGrVBy54HbqNtnNElxedaBNFz/PtOEq/fmnHVJp6H3z3ABx9rp1XOrV/0MKt6kohCVgkIguxw11kYiuYOwEnADuBO4MWXSMQFxdHbq7NtwkJCTRt2pTvv/+e4447jtdffz1w1aCUCmOnTITeF0HzHtWf89kTAT3Oso8dK2H+C7DkHVjyNrTuB8ffZusgallV03E+CfTDdlpLwfZb6IcdguJSY8x5xpjVtR5VI3LRRRfx8MMP07dvX9auXcurr77KbbfdRq9evVi8eDH33ntvqENUSh0plxtaHFX9hFBe8+5wxqNw80pbB1G4N2hNXcWE8dRyAwYMMAsWLCizbuXKlXTv3j1EEdUP+jdQqoEzxvaW9kQc1u4isvBQM2fWePhrpZRSISZy2AmhKpoUlFJKBWhSUEopFVBlPwUAEYkEzgPald7HGPP34ISllFIqFKqVFIBPgT3YZqkFwQtHKaVUKFU3KaQZY0YENRKllFIhV92kMFtEehpjlgY1mgYgKyuLk08+GYBt27bhdrtJSUkBYN68eURE1KzFwMSJE4mNjeXWW2+t9ViVUqq8qoa5WAoYZ7vLRWQd9vaRYKdV6FXJvi8Bo4EdxpijnXUPA2cChcBa4HJjTLbz3l3AlUAx8GdjzP+OrGihUTJ0NtTsB724uLjKsZGUUirYqmp9NBr7Iz4SO7TFac7rkvWVeQUof8vpG+BoJ5n8CtwFICI9gIuAo5x9/isiDeYXcurUqfTt25eePXtyxRVXUFBgq2XatWvHHXfcQb9+/Xj//ff56quv6NevH7179w5cbQCsWLGCE088kQ4dOvDUU0+FqhhKqUagqqGzNwCIyOvGmEtLvycirwOXVrij3XemiLQrt+7rUi/nAOc7y2cB7xhjCoDfRGQNMAg7x/Ph+/JO2FbLd7xa9oSRD1R78/z8fC677DKmTp1Kly5dGDduHM888ww33ngjYK8sfvrpJzIzM+nXrx8zZ86kffv27Nq1K/AZq1atYvr06eTm5tK1a1euueYavF5v7ZZLKaWofj+FMsONOmfxRzpM3xXAl85yKrCp1HsZzrqDiMgEEVkgIgsyMzOPMITgKy4upn379nTp0gWww2HPnDkz8P6FF14IwJw5czj++ONp3749AM2aNQtsc8YZZxAZGUlycjLNmzdn+/btdVgCpVRjUlWdwl3A3UC0iJTMViPYOoFJh3tQEbkH8AFv1nRfY8ykkmMPGDCg8oGbanBGHyrVGVK7ZMhs0GGzlVLBVdUoqf8yxsQBDxtj4p1HnDEmyRhz1+EcUEQuw9ZJXGIOjMa3GUgvtVmasy7sud1u1q9fz5o1dg7WQw2HPWTIEGbOnMlvv/0GUOb2kVJK1ZWqrhT6OYvvl1oOMMb8VJODicgI4HbgBGPM/lJvTQbeEpHHgNZAZ2BeTT67voqKiuLll19m7Nix+Hw+Bg4cyNVXX33QdikpKUyaNIlzzz0Xv99P8+bN+eabb0IQsVKqMat06GwRme4sRgEDgCXY20e9gAXGmGMq2fdt4EQgGdgO/A3b2igSyHI2m2OMudrZ/h5sPYMPuNEY82X5zyxPh86umP4NlFKVqWzo7KpaH53kfMBHQL+SzmsicjQwsYp9L65g9YuVbP9P4J+VfaZSSqngqm7ro66lezMbY5YBeiqqlFINTHWHufhZRF4A3nBeXwL8HJyQlFJKhUp1k8LlwDXADc7rmcAzQYlIKaVUyFQrKRhj8oHHnYdSSqkGqqomqe8ZYy4oNTBeGZUNiKeUUir8VFXRXHK7qGQAvPIPVQER4ZZbbgm8fuSRR5g4cWKl+0yePJkHHqj/PbCVUg1bVT2atzqLpwARxpgNpR/BDy88RUZG8tFHH7Fz585q7zNmzBjuvPPOIEallFJVq26T1DbAcyKyTkTeF5HrRaRPEOMKax6PhwkTJvD44wdXwXz22WcMHjyYvn37csoppwQGt3vllVe47rrr2LNnD23btsXv9wOwb98+0tPTKSoqYu3atYwYMYL+/ftz3HHHsWrVqjotl1Kq4atuRfPfAEQkGrgKuA14AqjXcx48OO9BVu2q3R/Obs26ccegO6rc7tprr6VXr17cfvvtZdYPGzaMOXPmICK88MILPPTQQzz66KOB9xMSEujTpw8zZszgpJNOYsqUKZx++ul4vV4mTJjAs88+S+fOnZk7dy5/+tOfmDZtWq2WTynVuFUrKYjIX4ChQCywCLgV+D6IcYW9+Ph4xo0bx1NPPUV0dHRgfUZGBhdeeCFbt26lsLAwMFR2aRdeeCHvvvsuJ510Eu+88w5/+tOf2Lt3L7Nnz2bs2LGB7Uom61FKqdpS3X4K52LHJPocmAH86EyIU69V54w+mG688Ub69evH5ZdfHlh3/fXXc/PNNzNmzBi+++67Ciugx4wZw913382uXbtYuHAhw4cPZ9++fSQmJgam+lRKqWCoVp2CMaYftrJ5HnAqsFREfghmYA1Bs2bNuOCCC3jxxQNDPu3Zs4fUVDt/0KuvvlrhfrGxsQwcOJAbbriB0aNH43a7iY+Pp3379rz//vsAGGNYsmRJ8AuhlGpUqpUUnAHwLgHGAxdi5zrQm9nVcMstt5RphTRx4kTGjh1L//79SU5OPuR+F154IW+88UZgZjaAN998kxdffJHevXtz1FFH8emnnwY1dqVU41Pp0NmBjUSmYOsQvgfmG2OKgh1YdejQ2RXTv4FSqjKHPXR2CWPM6NoNSSmlVH1U1TAXFQ5vUaKyYS5E5CVsT+gdxpijnXXNgHeBdsB64AJjzG4REeBJYBSwH7isprO6KaWUOnJV1SmUDG/xlfO4xHl84Twq8wowoty6O4GpxpjOwFTnNcBI7BScnYEJHOEIrNW5JdZQNeayK6WOXFXDXJQMZ3GqMeZ2Y8xS53EncFoV+84Eys8+fxZQ0uTmVeDsUutfM9YcIFFEWtWwLICdEzkrK6tR/jgaY8jKyiIqKirUoSilwlR1+ymIiAw1xsxyXhxL9YfIKK1FqfGUtgEtnOVUYFOp7TKcdVspR0QmYK8maNOmzUEHSEtLIyMjg8zMzMMIL/xFRUWRlpYW6jCUUmGquknhSuAlEUkABNgNXHEkBzbGGBGp8em8MWYSMAls66Py73u93gp7CSullKpadVsfLQR6O0kBY8yewzzedhFpZYzZ6twe2uGs3wykl9ouzVmnlFKqDlV37KNI4DxsqyGPbSwExpi/1/B4k7Ed4B5wnj8ttf46EXkHGAzsKXWbSSmlVB2p7u2jT4E9wEKgWmMeicjbwIlAsohkAH/DJoP3RORKYANwgbP5F9jmqGuwTVIvP+gDlVJKBV11k0KaMaZ889JKGWMuPsRbJ1ewrQGurcnnK6WUqn3VbUE0W0R6BjUSpZRSIVfdK4VhwGUi8hv29pFgT/AP2aNZKaVU+KluUhgZ1CiUUkrVC9VtkroBQESaA9pdVimlGqjqzqcwRkRWA79hZ15bD3wZxLiUUkqFQHUrmu8HhgC/GmPaY1sQzQlaVEoppUKiukmhyBiTBbhExGWMmQ5UOEGDUkqp8FXdiuZsEYkFZgJvisgOYF/wwlJKKRUK1b1SOAvb0/gm7LwKa7HzLCillGpAqtv6qOSqwC8inwNZpjFOWKCUUg1cpVcKIjJERL4TkY9EpK+ILAOWYUc7rdGwF0oppeq/qq4UngbuBhKAacBIY8wcEekGvI29laSUUqqBqKpOwWOM+doY8z6wzZkqE2PMquCHppRSqq5VlRT8pZbzyr2ndQpKKdXAVHX7qLeI5GAHwIt2lnFe63AXSinVwFSaFIwx7roKRCmlVOhVt59CrRKRm0RkuYgsE5G3RSRKRNqLyFwRWSMi74pIRChiU0qpxqzOk4KIpAJ/BgYYY44G3MBFwIPA48aYTsBu4Mq6jk0ppRq7kFwpYG9bRYuIB2gCbAWGAx84778KnB2a0JRSqvGq86RgjNkMPAJsxCaDPcBCINsY43M2ywBSK9pfRCaIyAIRWZCZmVkXISulVKMRittHTbFjKbUHWgMxQLV7RxtjJhljBhhjBqSkpAQpSqWUapxCcfvoFOA3Y0ymMaYI+AgYCiQ6t5MA0oDNIYhNKaUatVAkhY3AEBFpIiKCnbBnBTAdON/ZZjzwaQhiU0qpRi0UdQpzsRXKPwFLnRgmAXcAN4vIGiAJeLGuY1NKqcauupPs1CpjzN+Av5VbvQ4YFIJwlFJKOULVJFUppVQ9pElBKaVUgCYFpZRSAZoUlFJKBWhSUEopFaBJQSmlVIAmBaWUUgGaFJRSSgVoUlBKKRWgSUEppVSAJgWllFIBmhSUUkoFaFJQSikVoElBKaVUgCYFpZRSASFJCiKSKCIfiMgqEVkpIseISDMR+UZEVjvPTUMRm1JKNWahulJ4EvjKGNMN6A2sBO4EphpjOgNTnddKKaXqUJ0nBRFJAI7HmW7TGFNojMkGzgJedTZ7FTi7rmNTSqnGLhRXCu2BTOBlEVkkIi+ISAzQwhiz1dlmG9AiBLEppVSjFoqk4AH6Ac8YY/oC+yh3q8gYYwBT0c4iMkFEFojIgszMzKAHq5RSjUkokkIGkGGMmeu8/gCbJLaLSCsA53lHRTsbYyYZYwYYYwakpKTUScBKKdVY1HlSMMZsAzaJSFdn1cnACmAyMN5ZNx74tK5jU0qpxs4TouNeD7wpIhHAOuBybIJ6T0SuBDYAF4QoNqWUarRCkhSMMYuBARW8dXIdh6KUUqoU7dGslFIqQJOCUkqpAE0KSimlAjQpKKWUCtCkoJRSKiBUTVKVUmHKGIPf+O1yuYEHDAZjDAa7TcmyHaQARMQ+I2Vel3xm6f38+PGbso9iU4wxhoTIBBIiE6od7+a9m8nKzzro80qOVSmx8YoIgf9EDoq3ZLl8vKXjLimzILjEZf8OAhjK7F/673ggDDmwLEKnxE70SOpRrb9BTWhSUGGl5Mej5Aej2F980D++Yn+xfS617PP7KnwdHxFPelw6sRGxRxTXb3t+4+v1X1PoL8QYE4inon/kJT8OJcvl4/MZny0X/sBgL86eAIFyl2zn8/vwGR8+v+/Aj7Xzo1pyrJK/TZnj+H0YDINbDeZ33X7HwJYDAz/SFVm3Zx1vrXyLz9Z+xn7f/iP6e9WG1NhUujfrTo+kHnRPss/NopqxM28ny3YuY+nOpSzfuZzlWcvJLsgOdbi17oqjr9CkoOqOMYaC4gLyffnkF+cHnguKCw78EJX6MfL5fRQUF5DnywtsW7Jc5C8K/DgCZZajPdHERsQS440hxhtDrDeWJt4mFBUXsX3/drbt28b2/dvZvm872/ZvY2fezsC+talZVDPaxLWhTXwb0uPS6ZDQgaGpQ4nxxlS63479O3hmyTN8vPpjik0xLnHhwmWfSz1KnyGXnCkCuMSFW9x4XJ4yzy6XXV+yT4mSfUu287q8RHmiDuznnH0GzmpLPXtdXtzixu1yB45VUFzANxu+YerGqXRK7MTF3S5mdIfRNPE2Cfy/mrV5Fm+ufJNZW2bhdXkZ0W4EbeLbHIiJsomkpLwlZ8MuKXuXOpAcnSRnjMEt7rJn0M5yyfqSspU8MvMyWZG1gpVZK/l247eBz47zxpFblBuIo2NiR4a3Gc5RSUfRKqZVmc8r/Vy+DIFYyyXz0jGLHChf6b91+VhLvy7zeSXLzmeJSOC7E/i8UldSpWMqKWswSJWXTvXYgAEDzIIFC0IdRlgyxrAzbycbcjawKXcTG3M3sjFnIxtzN5KRm8Heor1HfAy3uIn2RON1eQP/0F24Av+YDIY8Xx77CvfhM74KPyPaE03LmJa0aNKCljEtSYlOwev2lvmHVnrZI57Aj17gudSPrdvlDmzjEhfZBdlszNlY5m+wff/2wLGHtxnO6A6jGdJqCB7XgXOo3MJcXl72Mq+veB2f8XFh1wu5qudVJEUnHfHfra7l+/L58rcveXvV26zctZI4bxxndTqL1NhU3v3lXdbnrCclOoULu17I+V3Or3dlzCnM4Zddv7AiawUbcjbQLr4dRycfTbdm3QLJTZUlIguNMRV1INak0BjlFOZwz/f38F3Gd4F1HvGQGpdKelw66XHpxEfEE+WJItoTTaQ70i67o/G6vXhcHrwub5mzW4/LQ5Tbbh/liSLKE4XX5a1WPMYYCv2F7C3cy/6i/eQW5eJ1eWkZ05JYb2yltzSCId+Xz/Ks5Xy+7nO+Wv8VuYW5JEcnM7L9SM5ofwY/7fiJST9PIrsgm5HtR3J9n+tJj0+v0xiDwRjDkswlvLXyLb7Z8A0+46NXci8u6X4Jp7Y9Fa+7ev8/Vf2nSUEFrNuzjhum3UBGbgZX9bqK3im9aRPXhlaxrcqcCSursLiQmRkz+WztZ8zcPBOf317RDGk1hJv63xSUe7r1Qeb+TLILsunctHOoQ1FBoElBATBj0wzu/P5OItwRPHrCowxoWeF3Qh1Cdn420zdNp3Vsawa3GhzqcJQ6bJUlBT01bASMMTy/9HmeXvQ03Zp148mTnqRVbKtQhxV2EqMSOafzOaEOQ6mg0qTQwO0v2s9fZv2FbzZ8w6j2o5h47ESiPdGhDkspVU81yqSwfs96/jn3n0w8diKpsamhDqfWGGPIys9i9e7VrMlew+rdq5m/bT5b9m3hlv63MP6o8XVeaauUCi+NMils3ruZZTuXMfazsfxj6D8Y3mZ4qEM6bNv2bWPKuinM3jKbNbvXsLtgd+C9ZlHN6Ny0M3895q8c2/rYEEaplAoXIatoFhE3sADYbIwZLSLtgXeAJGAhcKkxprCyzziSiuZNOZu4deatrMhawe+7/56b+98csiZ3ewr2sDFnIxtyN7AhZwN7C/fSpWkXeiT1oGNix4NaBe0r2se3G77ls7WfMW/bPAzG9ups1p3OTTvTKbETnRI71bv25Eqp+qG+VjTfAKwE4p3XDwKPG2PeEZFngSuBZ4J18PT4dF4f+TqPLniUN1a+weIdi3n4hIdJi0ur1ePk+/LJzMskc39m4HlH3g4y92eSkZvBhpwNZc7uBSHSHUl+cT4Ake5IujbtSvek7nRO7MyizEVM2ziNPF8e6XHpXNPnGkZ3GE16XPi3k1dKhV5IrhREJA14FfgncDNwJpAJtDTG+ETkGGCiMeb0yj6ntpqkfrvhW+6ddS8A9w+9n+FthpNdkM2GnA1szN1on52eroEu6eWGHijyFx0Y4sF55PnyKuyp63V5SYlOITUulTZxbWgX34428fY5LS4Nj8vD+pz1rMhaEXis2rWKfUX7iIuIY0S7EYzpOIbeKb21jkApVWP1rp+CiHwA/AuIA24FLgPmGGM6Oe+nA18aY46uYN8JwASANm3a9N+wYUOtxJSRm8FtM25jWdayMuOngB1DpXVMa1rFtrLjuJQbpMwYQ4Q7gih3VKAXcMlzE08TkqOTad6kOSlNUmge3ZyEyIQa/5j7jZ/NezfTvElzIt2RtVJmpVTjVK9uH4nIaGCHMWahiJxY0/2NMZOASWCvFGorrrS4NF4b+RovLXuJzLxM2sS1oW18W9rEtyEtNi3kXfxd4tJbREqpoAtFncJQYIyIjAKisHUKTwKJIuIxxviANGBzXQfmdXv5Y+8/1vVhlVKq3qjzmdeMMXcZY9KMMe2Ai4BpxphLgOnA+c5m44FP6zo2pZRq7OrTdJx3ADeLyBpss9QXQxyPUko1OiHtvGaM+Q74zlleBwwKZTxKKdXY1acrBaWUUiGmSUEppVSAJgWllFIBmhSUUkoFaFJQSikVENbTcYpIJrABSAZ2hjic2qJlqX8aSjmg4ZSloZQDQlOWtsaYlIreCOukUEJEFhxqHI9wo2WpfxpKOaDhlKWhlAPqX1n09pFSSqkATQpKKaUCGkpSmBTqAGqRlqX+aSjlgIZTloZSDqhnZWkQdQpKKaVqR0O5UlBKKVULNCkopZQKCPukICIjROQXEVkjIneGOp6KiMhLIrJDRJaVWtdMRL4RkdXOc1NnvYjIU055fhaRfqX2Ge9sv1pExoegHOkiMl1EVojIchG5IRzLIiJRIjJPRJY45bjPWd9eROY68b4rIhHO+kjn9Rrn/XalPusuZ/0vIlLpnOLBJCJuEVkkIlOc12FZFhFZLyJLRWSxiCxw1oXV98s5fqKIfCAiq0RkpYgcEzblMMaE7QNwA2uBDkAEsAToEeq4KojzeKAfsKzUuoeAO53lO4EHneVRwJeAAEOAuc76ZsA657mps9y0jsvRCujnLMcBvwI9wq0sTjyxzrIXmOvE9x5wkbP+WeAaZ/lPwLPO8kXAu85yD+c7Fwm0d76L7hB9x24G3gKmOK/DsizAeiC53Lqw+n45MbwK/MFZjgASw6Ucdf7lreU//DHA/0q9vgu4K9RxHSLWdpRNCr8ArZzlVsAvzvJzwMXltwMuBp4rtb7MdiEq06fAqeFcFqAJ8BMwGNur1FP+uwX8DzjGWfY420n571vp7eq4DGnAVGA4MMWJLVzLsp6Dk0JYfb+ABOA3nIY84VaOcL99lApsKvU6w1kXDloYY7Y6y9uAFs7yocpUr8rq3Hboiz3LDruyOLdbFgM7gG+wZ8bZxs4RXj6mQLzO+3uwswOGvByOJ4DbAb/zOonwLYsBvhaRhSIywVkXbt+v9kAm8LJzS+8FEYkhTMoR7kmhQTD2NCBs2gaLSCzwIXCjMSan9HvhUhZjTLExpg/2LHsQ0C20ER0eERkN7DDGLAx1LLVkmDGmHzASuFZEji/9Zph8vzzY28XPGGP6Avuwt4sC6nM5wj0pbAbSS71Oc9aFg+0i0grAed7hrD9UmepFWUXEi00IbxpjPnJWh2VZAIwx2cB07C2WRBEpmaK2dEyBeJ33E4As6kc5hgJjRGQ98A72FtKThGdZMMZsdp53AB9jE3a4fb8ygAxjzFzn9QfYJBEW5Qj3pDAf6Oy0tIjAVpxNDnFM1TUZKGlNMB57f75k/TinRcIQYI9zyfk/4DQRaeq0WjjNWVdnRESAF4GVxpjHSr0VVmURkRQRSXSWo7H1IiuxyeH8Q5SjpHznA9OcM73JwEVOi572QGdgXp0UwmGMucsYk2aMaYf9/k8zxlxCGJZFRGJEJK5kGfu9WEaYfb+MMduATSLS1Vl1MrAibMpRV5UvQazUGYVtBbMWuCfU8RwixreBrUAR9iziSux93KnAauBboJmzrQD/ccqzFBhQ6nOuANY4j8tDUI5h2Even4HFzmNUuJUF6AUscsqxDLjXWd8B+0O4BngfiHTWRzmv1zjvdyj1Wfc45fsFGBni79mJHGh9FHZlcWJe4jyWl/x7Drfvl3P8PsAC5zv2Cbb1UFiUQ4e5UEopFRDut4+UUkrVIk0KSimlAjQpKKWUCtCkoJRSKkCTglJKqQBNCkrVIhExIvJoqde3isjEEIakVI1oUlCqdhUA54pIcqgDUepwaFJQqnb5sHPu3hTqQJQ6HJoUlKp9/wEuEZGEUAeiVE1pUlCqlhk7cuxrwJ9DHYtSNaVJQangeAI7xlVMiONQqkY0KSgVBMaYXdgpMa8MdSxK1YQmBaWC51FAWyGpsKKjpCqllArQKwWllFIBmhSUUkoFaFJQSikVoElBKaVUgCYFpZRSAZoUlFJKBWhSUEopFfD/R+QAXwKOaiwAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "M = 4096\n", + "Ns = [128*i for i in range(2, 50)]\n", + "tri_ms = []\n", + "ref_ms = []\n", + "def_ms = []\n", + "for N in Ns:\n", + " x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n", + " gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n", + " tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]\n", + " ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]\n", + " def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]\n", + "plt.xlabel('N')\n", + "plt.ylabel('Bandwidth (GB/s)')\n", + "plt.plot(Ns, tri_ms, label = 'Triton')\n", + "plt.plot(Ns, ref_ms, label = 'Torch')\n", + "plt.plot(Ns, def_ms, label = 'Naive')\n", + "plt.legend()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}