From 8bedcce9befbbe95d8fe0a082718edc4050e2831 Mon Sep 17 00:00:00 2001 From: Szymon Sidor Date: Thu, 2 Sep 2021 22:02:40 -0700 Subject: [PATCH] [LANG] Added seeded random number generation - philox (#261) --- docs/conf.py | 5 +- .../getting-started/tutorials/random_bits.png | Bin 0 -> 42044 bytes docs/python-api/triton.language.rst | 15 +- python/setup.py | 2 +- .../test_core.py} | 0 python/test/language/test_random.py | 198 +++++++++++++++++ .../test/{ => operators}/test_blocksparse.py | 0 .../{ => operators}/test_cross_entropy.py | 0 python/test/{ => operators}/test_matmul.py | 0 python/test/{ => runtime}/test_comm.py | 0 python/triton/__init__.py | 2 +- python/triton/language/__init__.py | 4 +- python/triton/language/core.py | 1 + python/triton/language/random.py | 208 ++++++++++++++++++ python/tutorials/01-vector-add.py | 2 +- python/tutorials/04-low-memory-dropout.py | 164 ++++++++++++++ 16 files changed, 595 insertions(+), 6 deletions(-) create mode 100644 docs/getting-started/tutorials/random_bits.png rename python/test/{test_language.py => language/test_core.py} (100%) create mode 100644 python/test/language/test_random.py rename python/test/{ => operators}/test_blocksparse.py (100%) rename python/test/{ => operators}/test_cross_entropy.py (100%) rename python/test/{ => operators}/test_matmul.py (100%) rename python/test/{ => runtime}/test_comm.py (100%) create mode 100644 python/triton/language/random.py create mode 100644 python/tutorials/04-low-memory-dropout.py diff --git a/docs/conf.py b/docs/conf.py index 1107ef171..67a14f47a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,7 +66,7 @@ def setup(app): import sys import os sys.path.insert(0, os.path.abspath('../python/')) -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] autosummary_generate = True # Sphinx 
gallery @@ -78,6 +78,9 @@ sphinx_gallery_conf = { 'filename_pattern': '', 'ignore_pattern': r'__init__\.py', 'within_subsection_order': FileNameSortKey, + 'reference_url': { + 'sphinx_gallery': None, + } } # Add any paths that contain templates here, relative to this directory. diff --git a/docs/getting-started/tutorials/random_bits.png b/docs/getting-started/tutorials/random_bits.png new file mode 100644 index 0000000000000000000000000000000000000000..198f90a5ebb004b74a1bdc68f2d76fec0107dbf9 GIT binary patch literal 42044 zcmd42g?$_xXHgXLg>MbEck|bH3+Hw1%1jJ`Obw3JMColA^2@3JMwo1qBrY`w23KN0+$* z1qILERz^lcNk)cI!xd<4>tKa~q8Od3k7c0!i#*pjIWaMeBr533L>6P0^3wu2qQqD8 z*h-8AkPIRn2Lp04la440Ieocjrnro;G4sQP^A8r%a)uPRF@+bbFY=GNFS;%R-1fjW zW3bZYOuti<{F&YqWj9Snl#3V8`nxEi*^J+>$;;`8f(i1`IoA8PhKr9|TR)*#4iSM6v8*@P(k1eQu;J->>8%7HXBs?_4gmDhaAn@v9Wf{{@^55WXN_9uE z1sPC1)OL0%DVq?~(KV*O!`niYpjvaWevFf}Ow1b?riN5>-pf7L;Y{DXD0$2)hOQ;2 zS5lstqtocVi2ba;OK9)SVWahhM4Rc)myZz0D|>vG1cUqvQIwZ*-=b8s7K-})g|TJ{ zd6_(k&)yX=zKswvNv0vKnzi{D)7wVWI}@Ttw&zC5^ft)5K!ZLlzj=W(>sifB_^kDi zkV6xkQ+@VS|KKQ?nJv0%7xbLObp}S+rST;LjSthJ63xXG58@t`vm^Xe(=LlRoLOxo zC{BgA=JOiPrgy;$X4$?3vX4VQw`Oiwd@h3h>qeRR1m)vy_hLz&>t6S$UHCkn_nOoW zzk4cb8>H}(fmA#(R*I6kWQfj>`^yWKG9TY>P^R;WHX*sMG^Nc0Tw!g*%3YWc`il^Z zYZEbH25vPn$z1l)g!ejq(G=E~^8MwTJ?X)pexMcvVLq$13Q}UEi82;B64}VHdqH~^ zg!d7Z$yv+(iS@Uu1jY3?sLw(qnMlx}>z#>kO{SMY_dKCSGc!*=-XxG;jSr%#GmW8Z zGE%H^+}kKU>0$|G(|8V#^rCn0rv!3444s2&7$yRzYYSM7sxJ9}{Pch_J# zp$m(j6N#bJ%UEYe)1a~~qu-lvNl!Ge{dWEB!cqD2KKKDh<$wG7x93L+er5W11guMs zLMeZwZfQSi&fY)VmsyNbeiyWH-c~?lpH}b$1^iA(n}2tV;`u~0KBt$uPUOeWc=PLd zx#K7j!s=?a$4dJj#{M8N;Y0-(AxNxA-HK(0dxo2iJ5SdV`m-E56gQChB8QdKCi=qA z;m(7ozLli=k~b|YHj4vg2pD&GeVw*|z?aV010U@6h|@D3939!iN@DLZen`8uJ}z}1 zIMn=@!|+$nGab_(#L*Ek@nQPW5=Opk8}Z|r0h(w=0?ip3oRKmtIEj(`MTibBs-7Xn zTpOMWs=yCg4Ls!@{#xqq=vhDXYDte#AEfC%VGQ)xI-{CAnd>oa^dT~7!gHf5huj#6h%l96~;8FD3%=a?YUtTh9;vxqK_hbJ(^*pgLFkAAgOoIWt(;? 
zW?7z|nQo9~n>40CM~Co@wBy7^8R3w8(S)qK*w=71dE80pK1m~Om8@4$)4t0UnlGKB z44`m@^2cjFR_l0Sa|ueU*uJS+6_tq5pUSoT-$^QasMp=PW17r!<`uG|1f0qLVCXx) z$PK$Z|8>Xd&(}?`7CpDgc3ycW(}{~v&L{r-G#*0|EuA+aQ7+@%OV%&M&7=gxYCT=Q z)x6bTkYvgl7YZ)NmO~si-8U^ZpKO|0kk+N%Wz4HGC*di@e?!sYj8+(8BTJc0_DqRS zs!m)~5~u!3ZB6Y&y+LhDO%m5eZAIPjI^ND?(z|5x{Vc0-%Hv?&M^qnbA9f#_wqXD0 zWwk%W`x*}Hs9$82-&zqAXKTp_%e5)^RF;?9793Ox8)<(x{AO6)rhfT-S7tOp$Fu@H zRfbMMS0i1+OCzhSr5tGc`YS$HMwsgU$LYYnFz|8P0$IMdET;W`CP337p z-gik|@H<3QAPuSIB5&J>$*$IJ;4c1dDHE&AyiAx3Ies~ntJNS@Xu*@7C+|W@Lc5;K z;Y(1}tH5>~@$^Df{&6`PWPCsRx?V1OgF;`I++$sSA7!6r2I%S*gRXY;rz& zWYT0QeB~Hv$!N*AHJUeiVK~w>Z||~}Y;jnqwA3zJJ9%t@E z!a71GuCjDpw0-e3w8oLd5m15@yh>Wn>F7g<=PTz5Pv89!_ju3TBa=OtUF*zQ4WUW=W^?c>ka&_DEa_@>w$6N>a#`%iK*Ws#l=hqqS zPTP-$h0ECVmh_Rz{>t8^A<-4Tu3J?G@jRitJ_(zEr`=554FNO(lJ`E)orja#hvBtf zk^`W0#3ga_^RVa5ao)5ah=1VDlT45g;;0ZV)6SEBq=(>E;Q2mtpyI?{rH6fX93Qhj zwFcw3JWUFx#ex1f-x}F|9d1!TT)R-fC%i&4qNDMuMmT)$joG9q8$#i4uS`g=(8*5bS63qFrJ#x$A^ z{+qyT$K~zTHF;8nfw*eQ$JeWG0_9Xz(nzlGFj_V zzw-D({(gz|LDfvW?L9fm61T$k*3LE5_pB6N_*3mXWz$^xgS4FV_?K;Vtdc{ORFy8F zQS-$@+i{#;Tu^4b`kp#4*IdGK>~o~RfEyIe-iXw>Z%J{I$nEI~is$gm@(sp8{I#|)`{ z9>?X&rN)SQ`MQ!SvG&fus+mPAd&wHostz+^tvau+^Lx1?Ne3c_)_Oi~%Td-8)#5A( zM{;0y6N|-r$6{5h1-;GLkSg^S7NF>5>^P6^Qe1Po6QW+GzR1xAHkP8hs;k@a*^8@6 zzLIQ)u4CY&yOYgbJs?hdG_7l-tN7uEL|-=rxCG&A3HlA+MVG*`BQRnx0Y$l#eIqZ= z(Oc4GtFW(BuL6pTT{ydif_e+qvO!CepH>ZAWD$c?MPG_86K?1@#Aa7sI39}Sdca&R zA-f%H{n@Rau2;X(v^7h|atR~@Ht`-t>KagOXvU#AqW-IU{wtU@xJ?w|JoNV~m$x!p zW?W6_ojfyuo37;vk;3c!)O?PCng}!TfH=RCpVt{jRf5S*%k9E{te3cYs`gskfZ%Rz z=h9oJN%5au$q$RS!}OGrWw2g&LYJeNqXK<|MCaYhjmeb(*Z@!7V4lBI9^%VQ<8kIf zL&cucIF0B}QEo8hL*Zlbl@-wDJWm8v4YImDTnF@kOpmFjXW&LR?Wd!gEfGMk&P(sP zWxBb?^?=ofRYDf=2ys!#dw;P@_MPd5RrWlkykYP>$o7u<`q~U&2FUxuwsnob`A}Ps znTgW5fQrIJhVn$LMYF$REaWkyY^>%H8^-osgy$xf0_9DuHoV{ME6O}pmMZafgenv_ zj}lMA>$aza+^3`?D<&34{asjYcXtabN+LDN9rMxA(aOZ!trZR4&07@J6kQvQVxo;q zbGI^*6*J*(Y#3w3<$;$ld>-8jgMNJg4ghu}iUMReI+5B(jFo|swW=!0D`XlQ1p}2D z1rwP>J()R6JtClML`*7?5}Jvjme3;B-(8NBn+ z{#6a1w!hJr%Q@;9I=X}vf@ 
z(x0`}F>p6fRS~fSI&zzT1X@^edpkP+rGp~oErLusTDhAudOJEexrum-GykK62r~UQ zn}?b4A4S~l#hDFMH5g@pu2zh1xp}#HnI&)-85zY~KU#}u$;$t;I`W@5^Cx$AXAvGA zFE1}{F9B|#s|^pIu&^)>FFy}IKNqqDmz$51ySX=)liTzEBl0ggvQ}=EuC~tZwm>Jw zzjVzlfFADR%*=lo{rma9c3OGc{)dy3+dsR7>>$tIH#~gYygdJ=jjSs6H&;Z%*4xU# zP}bHFIc7*668vv@#r{$L|MTWQ9RI7P!GCHB@(TTT&HsAy|E{U)X5}gabVRarm-x?g z{j>6azx-!KF`mCY|F4nwzhVAIE^?eDaKw22J!cX)rx3AU$Z4dtl~vb4#>hqXk0%8A z`w|)c#>nt~MBhdd`CbY|Nmg3N8}&HAXP&GtHPH0t6#W?{OAIIL2M9#giXzzObJQ={ zt^O!$O>zNrp?nJq_?>Moaoh)MX6#q7DVeoTxlsiKG)Kz#C}&?5M)lK4hYne6Vsk4Z z#>y+Ebyr$|pq1NIvp!wKmS$#VFtNycaG(Fc=0*FA@~>z@h4^JH zqf^PqrIo;cW&gld#4sRollg!_i7FpN{LkpqEVehsISsIr$?2fe5(Vs)cOt1iZJV!(ue{4q zO>U)9o(+ouYIHmxpc z`rWvN8ztlLXnb7aL8Z}d=2brD@o|$`#dGP;x*BNnp!ie8+S%*87odj}wOld7ksMJ$ zUdJyv-kJ&D17*J(wJAHSb(m`iUt#k-z(x#n9B<;IHh_#G8u;OM)y-q9k zD+?H-=BkYe2RF*XK3?yY<^{eoYH^|))Cyvx=i~^d17j-aelH6N^su*a%whBDu)U5j zOXY}K^S@Mex$W%eD+F+nbax089H6MeZl9nY`|y!6N5b?R()ICM}GZ$fPo@G z;=XLS&i?}{*!k1AJLRtSVLZ?H9K=HaTkIS{YH@uyb5TwrVOq#!oUyr^1+)&t*Ll-s zwNx|8b!lwu?h32o>fW3sj~pG#5>)3j>ohpNALd@*qn1NlpDgKrW{@=doX2bCF4C#J zY|%GSAk(+lWWOO+M4B>{hb1X$;5vvAf9a8qPxCeiBTY+&;6Y9a@68vGE0~_(^MlyI z@}o%C1I=`%*m0*6QOAz_ih<`Tf2m9)L5YNpeoJS`@aaNBY3SUq_l}It`x8uiC0S0c zMvb3>TjE!5%Ajjt5XjL4Z>}C2l@J^VE)2lBG{!I>;O35HhTiwMJK zpQ!K5x-QgPjhWMlwtE`I!;cn+UCK`sMMP{y-TsG5{2DiwhbzxG)Jc*b}a=M+xhE4wf>O z8aS)z=omuq-Jmdx%JmJw9|-bOe!t&ow_Y5S@j*Pl`X==oY{tp&u7Y1Pn^q%Q7S|f# z%N<@d292Lw^5M8N{F=XI@GzB&#`+Zk->b}I zJIULhAgB{mv)*zY9J;}+m}>(ceM|jr|8SrF=8jglkF)<{{)%)PPd$|4n-uAofe5U0^1FoC`RPT@!)Ea|*1m05gdtMnZCdDwi1s3bi?jx@UNaL9+{@gc?g)z_#4eG7rI%QV`9iY2Rw7FglKfmWt+mxC<&g1rpbZY-XB&L=-zEHAogkfj_dR!+9{|m zgTVd$@;1Y1Tm?R7J?=!_?$>rUKZ4-g`oH1NPSb=+-oBc{Bq%$W&~{-z)A}|ImOQI; zB%F(}?BjCk)%eZ~Z^Npy*w_ArVix0mf31BJE9QGLcXqg!hw1X_3c?+>1#}p+JH&Wa zX%4>)!R3znRueD0OL@@X=*RY;5OBThHeMW$i(}CtQSBKtyE`+rRbyCUh1{n>N2*M2cj?c&PX1G-p_#vPz(!kf_h*?jS>NB*eCJ{ zI@oNlXi$$Q5gp(MQLeB~+SJW8r*&G5;RuOT(xf%w=ki|~B^UhmaW*WF!>;q`RR_`L z-12%p@cR5^K$RB*6tLUkuum$Qv<0-7*k!i#PhRZ~@MEFY)P}-%MktAm0jc$%jQiQA 
zN2#i4M2AeQIYA-`R{#T3v5<@il{+~_@{qeBj+)mU?u-Z zeH^A29$#Vo^Sq_Pd6&Ck3Wl>cHOw7XQ)A|*i{nu5gSgN2`=dmL4rWZdo;+u}ptPXn z4`#bZ*IEpr-2ItisgJF5Af<`di|G-({uD@*hWTg3Ypp88N%I4@P*?mZ*3NN*My8jL`7%=KbnyhXqfS z!In`V5h5cOI{LXk8Iu%DZ#Jp$0Cxynv6Mgz!_vZpQ_yKa0>R=wLq{kG*=}e`5MHl2nRJs4$Z4fACSwp0G}N)^{m7_ zM$j47bpJPPTW%J9wBATUeSfF}o{u){t;kuiayr02#PU-8z8|>a+;C@Z{nk=5GIY1DKz8pzP z_pPb(YBQE)OZ}ttXVO*%XkXKYX4Hh`p&z?k_u)b+HI|`Q z;idHP1`ad=JgvsDOM%IEo?u_^vz?Bq&HeIP$KuYpwE)v;Su(u=t}a8<1cQ2mO1AG1oA=##>rUq^kZZqFwU z8Cn~(hlun5w$l&aIxUxRE27JJUD@6cZJCDQeBOT@x8H-?emzTNwFQlo{2ok?4-9;D z_BCMtTZ8xW`s3J&4N@hdCE7gqDW$QCo#7QTnTLyS9|~{s<5ZhS$supmd=@-0NnJ3P zd-lY0$cFX5Y|oL6OCWkf*m=Hx^-X-#;_bVRX3O1Q_B-tghIL=yET7GzQMUx!G=jP} z_*$*I`pyOIpCogEHbTPt37Q_0On(jx4RO<6lkIHXyEtyyfu{_|#c%x=tRTWk7<{%E z`eiaX%(p~)=7SJHFjSSo0F82v%eeWArP5Z=J6lQUHYwHT6Ie{MH0SfAyUY(}$qG_W zbY4^5g`Q>%L-TvB@ijvXiUs3LFm=gQeqEBaV59?TD+R`cWG-lK_|XY7qhg(#5AeZ1 zg`ZbH+I6rojND{#rF$OGf*fBT0!8;(JR5t{&cw0&^^u1Q@KarbOurR-y+^jZXsMdw z!Y`)mgXz+AZP{bD&y2Hlx$XX)vlke3V<>VKmdexB0e2gbYRDy#%Gf+K5G%Bo?Pd}B zj&0ew7ypo$js!!TjqP$IBk5N0&BLX zC9Zb; z_3cA(tcE@2#W)ClSzl4>aisUWzUqiMq|=JMr_*vEzJ6#kC0Rxa$Kr^4+==Fa@L0Y2 z!q`tU(M2n&n{ck^Ri!dKCOwKByS>-L9JFu>z&2;Kd)I?H{9 z2hFvvGC1ubU3$PJ{^%RUF%!3L;5HwGCPrfAyYKqcAyY2@PM+ycY%!xEgM^Ot=SDF9 zhLt5*sK@E5X_{52{E`qse7Zl=2aJ?Oh{YrtNa zA*+rUlbK6g_N5pEx*oFyl<-!pAlEKYVV~09b{n457lOD9CNixd!fjD2BO zc*oTYj;9mVE34Trr~hvL@^G+0Uc_s~LqF(7E631wpTi5UuK2kULXI3_H;AD_O`FC= zH2u}H%DBY|M-~7TvT$8CM`j%*?p0N<qx2peYWUJ!g|!1{KmzmB_8%@a-U% zkYb4zA@qb0kQmpoUW=s8Bt_W) zontrPxUN%Ev3A4N0|Rdw@_u!G>$bDJB0JzwNss7G7Ka!~ z4z(M4#IZ!JxG?BF!N>mv%CKB%_xQ?q7$_Llms7FT{qQHnr(D5JL=t6hJ4O9QWypWg zwzMrSgc{3CGH{mHVcp|+wwF!kGYSr?tqxqC04$uZhX>rH%Am=$c(4f+C@tN@c^MdO zY{bEUJn`{>eN@h{D$1WGHu-lf9Ysy`;QDJaJ6S<`ikK^11pHA7a8UnIqzT!Eton#g z5LAP19h=nh!}6g7_Na0}GPPNXtO!Q4`j62ZYBMWg>IMQ zZG1QFm`xcKY86NxFxApU@3{bg9`V(U;7MF)^1STvzfJS()|W#wxCCcjuI$bM|H>!9 z&dolI;28a-z?rv~SzLk!A_Hs3!ix1z27&Aku|+jEe^!p%PZo6#S*i76p-&DgUADwT 
zMz|4if~$%bt~vR95S>5+o6=f)7C?oJ%O_mfd9|(g9(F{{ouGtl2n@a3PU;!~*VqVmlcAv>?7pc+2j8|l; zuqRN4-4aTpUl3@30#tr)T{$Le*xvArT%J43+p4p6UB6Da0>he$W5=Se8*Ohb(* zx59o)55u6Bsf@D+4(s$j+e!J_-T-;*JM>r7?aR?CJzVY%Gb@*8GQ-yh&{wj3m9Ld z*zVyQu>OXxAxuFv!<6t(-_` zz@F$DyAOAlKikvU9ZhJltoyH#JE=nK){uESilHCS_tw#9d0M+Q-8Xhhe|MNKii)+A zjR2hdXBtXG0^p};41D(SVB+A4$&Yg;ZmLQTu@yBAIkHQ8Ir>I`DucRr9zus~kIj$i zB$G2C`qxA>bWzvE7h7>mtZ#r-G}JT|y5q@El?-0#P=y4@sPFBc6_)W{bfMK2-AXiMW_UaFfO#FqB9_ArseDZDFh@^?+_4yM(?ebm zsSB#716JP4EN^~|RQWb6r(sJA3e_mj3*ZK$-rTPGorat3HyBF+nq&j4UQ;=3VNL&( z=!GF+9f};8uHVR+Wr^y#PQ8viO5lo8`ZPi3Pc04+n&UXJ?THVf%&ZVII@*5leoEudY& zr;zEho zC~xGIa0xvYAw8#3nN3=B9L?J&p9P53r~wtdEzejq(|B3V&MV!D)0TkX73M{gt02_J zoLbxg*zRbS%~M-33fVMn(7{x3am@B36=OcyD{-_*v7;6RLtZN>#nf0GL9u@L;El_u zI7Ih}Xz1}o;6b>OTW=A)W&NYIR8);(JXk{N@OW=qDdqhlU|#zQ0NbQU>*hbF5M|1m z)xV2R6CP+re7l~B<>@XvGBx>os!xEs90y9OZJSj0ZTi*i3a@PVpv-Nf^RW~qZY zjh2fbZ@$yVZGt^F;BH1l^la+w?&SoK{^HvX!veFExC(^NCEOmlkGN{bO+^E?r-2BV zViz^GO;)5Ynp98p>|uY}iJ-Alt$RFp!7cG4{uv#|bG5FU$Zo+;N%uGrIqedL54&l+ z{>(#-jdwkSzY|%lwc9XlW-4^0<@6l$L_ad;!v;eP#$;4bpUz^)Qa7l~Z~jIp`{BP4 zPJtFIf|TjdFdTa_Jy#lkT)}hn)pIJwZi7mOAneQDzX;LfLVnqm(rR4xYPY?4O3z>5 zkmMINEcZroudN|!Uyahf5=Skq*E4(Q&2yP3^p9k-VXWo@8ox z>qF5|HLl9(nlRZ_sxfPyy(1z){3b)vM;A1m7@dD%iEH7G{OnLZLO`yaC!J zEifc%LzZLF{awb3?l9Y|E8a`ue!r;jN?Me=5UCuP!btE+vYD~@l%7|pcv1#-db_+8vd0738A_#haGs4G z%MtC6VL3*jwh8QL7^g2+^zw@$^D%%oJ1&V?T_iXVGfF#1iszrNF3V*9J-~AAJsD4I)b*@!2TMdKf=#>|D@~Ao+WbxE6p)`qi71W=U|jFXHhwhugIMrQ`My zZGccetmnCcT>wi!d&}{+>r`S$g^95sB-`p8hc7$Hj3*LLOE&q{iFg1bj7c?2;B*bA zXI~JH4E~ZqoEs*-2gSMWs*gg~zqQ3p@0ATIFK&+)rhgy!i^*Y#q3Y7qZ^okA#k_u@ z_l1gI(*b$Nl7xykeKvi#Uf_Q6YFF?LwK6&6yyeAf50`Z?{0RKT8mAV#&|u?Neg3e9 z6<~EBd6@qfN@g?FmQ)koc~&_+b?5PUNmaFYNoQ4+u4RB|shyyPM-20R3R8VNodx!o^Hhy|HBb3oL@2^qs0=wAbFYf?C@a6UOlVKK_5$z9CDCEq@|Vd=mNt>t(GSU zU&tFhJ)EhGW37)6{Zd@h?#1J~7+Ke0KUW=HBp;Fi=_Kh>Ug{@-Tv0BLG2%R^{oU+dTfEQFWL6XJ}1t9y58aDB(OBr!jQnv3pGj#B0^Nx`3h5C zs7=o^!eOxzPiT`PaW>nV2q^{NfZLqV?Q9g|ZgPv{ll48j}wU~H{x-&q_FOA%bGmIQd&x=K5_jQgf 
zyy!~U!2yvs1SBGki?3VMSaX~cCy-!s9~u<7)2CV`<@dw(ec9BqDLHu5vc=GzX+%HI z0uhw`+RAf(I#`o#Y>5~FPVKYhf!H79#(c<=w$Br^&<56CzMOHQrv15oU8j7xPO+x} zFvYqbg6M96BJmSyYSPv^mZ$XPjm@D#dmdA`aZDvgO>@XiV6FxLL=CYpxP~;@r^8CK zh#6_o9kBU!wz?6pafG`wTy(j)qi3nnR(lrLd8=oL35swcmx^c3%*LQ1KU7M(YNCH)P$nXEt!rn!#u9)QQAi zd3*Wx*8Y_zv~{IK)mVl++Q4AVdn?W=c&W&SP~3SL8@QmsIL?;mU;kYU`fZ!YdnZ*N z9XgQp7CCwwxq4&2WoWlStyp|qaddZh1N& z&~Ux2Rmv$Jwu88+PY}CMJpa>iigUM^BQ}QEK}X#;C+YILy*X2V6B>l$*hu%d?SapH zc7Ie8%M)<$Uc=2X_MUp;_3f!mpnam2dg^50!_^=3q&^6_N|3{;&BOJF)EUBfupv{P zT3gIoig>sZ(Efm72~j8?!ZXzuU!@Gl!QP=X9&qJ7Ra0;$IfLtMB2OZfC{lHZNLl!h z0s>teQas#JEP|hj;(w;R41R4^QRFM7adkzqbgi*;$oI-g5^240H~M~U>a{{9G*5jA z`{M%>~34698)tVH%*NZ&L3OsU)^b(=s^3<>L5T9750p5?D3Ksy!7LujsX zQDh^6PSn$3LEjUOsSHp8R9~_Ja?`q*r5(UPc@JzF8Cu?t7DU;5rcy+u{2$TYn>JX- z1=$LVkp;zuEKzK$aG7?{b2U5qX|A6|v<|PqZv-d1{Jetg2gyzAO=>?Qms!;9DhS|y zJkN*hmV`M5V}6W8z`eIJ>c)$Mhr1sGb&;dq{F@A!&-3ZaZ&}^0eq&*19}jru;@v6i zV&^~RPV8Z<&)vCPfO+M!PXyf}bTE;MeOOdj#m2^5DTw!Lp~p`zN=PhET4{|0>`Z2| zYRP4@mrxq8VUZdvW4yuPUA|u_3BF1essjSlq2HeOMyMd|t`TuGkPoyncof06c-!sN zG&GnasNcDHHxTCcuO}(Ex5+5?-tTl|9LBM{Ps85#<*8(y2Ig`*b#9b8M@w~&jE|a{ zcrI93AjLSV9U4!P_}*aA=jS84HW>b{M%fY{%kFjs$BB9O0Q0u~Bk!c*`A+aHRSGm; z_TvE!aalLLknncYb5Sp1naEu3G2uvP)4si{c0BPZx+P?CKiFsq@sTUzOgx;7LodoE z2++7zAQQlbU0>w+^JKvOVGSMY#>Tfd^o9e9iC8Ya)Er$56o&gQId(n13se`Y ztMR@jo6zSE`zlm{URj33Fu{h)-dp(3zKPe}*Pn)T8)*dukU@y!nDq2ZYl!u0{qKL3 z7J3}b9!4TS9ROOh3YoToj^8fyA~UJEyNOS~6B02WN$<$y@Tm+1@J=jMyM8|*l(>@$ zoE#e^?Bq`k){b9a_LwnocKYc5p7e))Yx`V5L;)wDyYp+`jEC7_Ql(u2|5xTTN?r4u zndkbO6lOnupJ>lu9TP-Qy%8b$c=R>J&W!dVwiOea#|iD)h@heEt+55|lP&}nwV;f1 z@+I3?f%icTU**D%Tbh~y9gos%iJ@-gZwJD+fB_BFcW-VKe6g1hPuafI)ciw-54_vx zPbkVcqBMQk=V}y&JnSY^6f!j~95r&`n^4?t_*s=zRp9?9y?qu30rt%CR=Q7eE zuh#En%RliyK6&xPZ_B^`56rLKeBNsFW>j#^WVY#bOUf5UJ^aPju2#;5Pg(nFIw$o` zaD*cy$qaN&GuJ+G&0t%M@`-3arn7_f#LFZ+A1nilE{`_%dS7z{x=84>sPU?oJGL|; z(u|#=J%c^IAmK8)Br0e2V3^Sf=zOye)#gpt(P3o`cMRwoD0xEP{0)D|yt_rBYC2n# z(n9E^UD#@3V>(h3SF$17{zf2r%!}3E3aOYv>OPH^)A8!+zdu7;>}Ki@yLy4($L2ll z=v|S{ofd}${+lL9` 
zu$oQxeLP`UEr_X7vVI?+ON*FDOZL8B{e3;1emH280`AY<ZL%KHXZU=aBth~9X+S$d*O$vu@RZOJYdmCf>8G*@)5Eu zGs*vPF3@OT7|npV8?e|N$#gG!b@UhdG`S!|NX>8!O>-j(W4TEP#m=NOzxuA$xechf ztXl%JUgcXI3YWOhK$9O(`dXU7Eb_VHQ#%0jB8}yXW%9{tN6}RBu>m*do=)c_c-%vR z$l=S$a&5L{PQM<$jdTFrSIOG5xtg63Hz(e zTWl+?w_uNi%wH@w6r%3F8|W+xAZqch_b0=!ofts+61-ZU%o%A}MoiEn!BG5C(%4ni z)Ebdl+7qvVcsjHJ&#ldYwPwLh2B`Sa_l9Q}vO+msfYgrpnIj))I^+~B(@ z^*U~%F-LD7cWIMxTQ+iaGlX2VOTCits3A9r= zWv6_le`fghFt#MU{w&AT;15JtiXY$0jHJyy(+ON=qbAem-NKXqOlj^t)#_b4)%@wS zp#?u8P$g>@> z%pH-?l@2ezt5on~??Cmt>zSeiiNmX-xp7Em&>dC|ZNVj+6*KXWi69X|(H5(k(cFE%WG~I$lv0mF7m!mV_MhK~Y_mO}ZF<61eR(ibnI0O)9F6H) zoW;3TO>t%LLNji8w;Us7lr7q%VZ0x!C0elUq^+&M&0KYjvt=R)-pdlz`cAT2g=<{vs&XvqD-DqiV;Rd+I$b@j))d}ovaNhV?cpsL!wtPod0Q^Lz z4>5iH(BSf-dg}~8_Hcc^qyrs|1K`m9AE5UfY`0Ie;OqkHQ$mH^A39SuD>onAf?Xvc zYi3UmD@*2`2Z;RKLt-S8Fzdcrp{a$$*3n_0rR~%s-Fg0ciNv5h6S>oOFMtOi;m0dV z6~C+f4Y3xh_ZI1xc}lN6@LsO1bzmnvBdL5(FVIhKQBziThu(R7v$sKAceQbVG{(nH z!PX4J=`FlR?C(5=8>@TBy)^)o$teodB$X+95d#D1n=5l2pDeKtI5xo4FbSlH(*xUV zx~s-o0=R8?&DUxPihV>cKB>`yS29HJ20YyKuj%fC zCK)zU0zM}ha^>d)lFm@a3K$4B02(2KE9uK$*dzTM9@1fY(THz%`1YZrNKubYz6qQxB)`odiEkX& zcQdZq1m4}vd%SanT|d_XfTIChrEzc$+ZvW*Nv;60{vf1Koi1PFhNJP4~KPSKZcGH%T_?y3o+bXH3f=k}r8=${TD-2O@wXgtIJnrOo0kIF>}jGQUpqgN>d= zrh|TE><-kd!LdON%|f0h!Oo5GecNE85X$a~Z9Hm>8=3-t?;d_a>r^AUWi~NBSj~ob z&=w_3KX%P)Ento+3f69AnT1(Yf zjd<{rSl{J$f#>3QS3W~)n9DDeMcIXfxHcB34lA|pkWzhF(Hx+1jdxhKe42K=(VjQL z0FV0fboGIQGyW=Ssn{YIBncl{skhuCsCyheK`TZ5S)LoD;xas;Z(SK0i;O=M1|=mFQ^mhr?$bBOl+;d?;Xc+Q>-VBy0?qy3qSU=!{p&j@`1 zAq50!L3p)(ROIRE)uBj4_(b|RH`b2wou7=1C=p~;*m8x|`#CvUC>9|(ub%O9d$m4c z1b6E?T0>}y9SQ2M(|&yVo-n;D#!KTA|9cMA8xB95{bnp|_;siQIPkmO`E4ST4CMa$ z$YVwqCSx-Sn`O@G^I(0L8EJni1aH^)-50?X^=Q(oUGOQyARxmspMq$RYg`<%I<)&h zaWy^sbVm{?&9?Ehy_x1-py^~=XimJK5l^WuZ04DZV7CSXfbL5j?9W3MmqZPCFH;hUjA~ryp zs~!pq;{?R2%f;NrnA|X4fIkKQ$-14Qrpw+-N?%m?fK*E8!Po?PaI>eE-!8Qtuv>2` zeOMT}T*;PzCk;nJw4kHlzE5*RNL0AdCTjHDtb*qC;K>@Z6%ukru?ipV4tRk4V$0)P z4SaAMR9dV2aaaiuHgs5psr=&OA|oFzEe-pD#o*7_eUV!8VJm#1JYRW 
z=Dn#Q83?+Q-j|LvCj4+Sw|W6SeBEU%3-1w?)j z3O}StvY*HOYFd&a!^MdI$17W=wkCWT@P+A^PTaTg?mPkXJz)-Uh(z)!kVtAW52u|l zrwg)vDdeCu&kTX`8%o9ujwR>1FtI9_cqexKn|gOHRM`=(0ppGsqS-U}H2Drjd`tm= zY!){YBQfg7 z_n)3b0OG+uR?kGp8s(>alNw9ajThnR$GV5V zDep=5=y{sFG57-87kC^(Ur9ZDI?+?X`S{8Gf>~!P#r8OMibBNaVq43jBI^+eMW*IO z-w>75fCdw%>Aa1*oT{6I%?8^&jte7L1hmwbF}qRS`YeOxy{ zPZ~+*X^>;;aoLuQ^bU%)$@PhVR%s8Zi=Mm|zM(mpJ4H+cB3;Q4>7|C1=6~Wq+}+q> zHyvw_@ufYS5hd0)l*q4t7R^@3;h9JVT7}UusOxJ*24Lic3Mc|U6Z;v&08==vsLZ$f zd5;IX$a_5^xQvo2?iCU49jo}`_m1mH1qQBF5F+5V9lPSM%gcA8CS~_O%M0TDG)<8x zfsfAPtKzHj`02qgiK_`2wC*D_KZ?FLp2R|jWkrxbD;+Cqc14UxPgfimT%OhiLfmZ$ zYe*wXg#AGc^pGT!7>6ii4F&4gJxg;TIU}<#OHL zimJ=8!A86YksfXV=w~8+;$4$vB=Zto@lzG!0r6Dj$UEar@BGlgf^@ z()Qkl&$1t{c}--UoG18Lzuo5Uw2WIJ+IT^7?rBE^U_wtT!5KCx*mMWM7bVhsfB(i( zDcGb0PAg!o@{96>7y^y>-fS01Kv8;eFjoT>dN@w9IQ(6{%wb~uEMk@%+E?y>R`Tb# zVVqna5b);qRXDSrF;>zV5-N;FMbB7UwBR;!M5<<=&Qa%2Z6DQ0i{#@h9^#%DkA6d{ zJ_0&B6@{Epq=o3G-cJZ<2e-mVe+nKsxZvFd{jK~xTCvQ(c`YnxRH=Ccs1~JIMm++2 z_hZ(lnKSAr`$M^fN@BRN)GW*;|0+F$0Z$k3qVQ; zjWeqEe52?;_W}O;)6xXVatE}=>Kgg_Fp1G68>NQhhzR2@ds&>J&em8zW=$K z5QZcoM9UyC^gkRm_ONRitKCAN7yo@{;RD7`BoW$pLI$h<;aKK|9H}|s9)F+z?feFk z2*E&P0{{PT{4Igt_P4_DvG26$s z4{)@}3f)SnMDY-Bq=~^|B465Q*20&?c?JQ$>>5y^ZrG9Rg>gq!R)6J0f*F>c=j6|c zP=nPapt6mQOK;yPtNldIjE~G;QxEycV$-m3Q>%pTyP!w6)SwaDw@T-t?)w^bo+m~= zGxE~DZ%QagPWceRtBtDwq=ot!(uwct&u_3;i1&2As>S+x@cb`h3%v+x4HGIZCj2~qB=Ap(fTNZHrHbBqpB7N(s^nfSc<%#@qF?iO0Dq0 zWu?z$r6ZQ~{)#x=#5MZw1q}iDSe_i>K?=0IfE%j$Q@c958Rfcw2hTcE2d{xKx)0z7 zxl*+c;13-Gf2Cmb8^42$_kHV3p_d-0CRrL76Oxj6aA<*UB_s~APl%q10jLld^BajQkz#KK zqw9isEA))=Ltcd2szz=bopO4Nm^?3j8KsoM8Y6nzIh~@)BAM&G`LqrRSfpS++o5&c zu8l`>ZhL{Ui?KR{j<#=d+Wp{E=Y4K*o^A@#)Qy>ds~!3!Q;K?i&)_u=!E5ztC9jro z$DjwnT+Wdv$rzoPmE_xAa}%V~SbQ#3=?ETb{hC}XMYJhsfdyS55pk=~g_Npd&bJ1J zC7+FSZj+*J>i<8?z4t%W{~!OKmAxW+D?(;vbCOU9p~yNmW$%5Cl9^H2yA(pmR>rYs zD3rZd_Bb5paK`81{rP-M?8U!N{6&*yoL$K(FE-EY_1<7@-rOyJ_jpQ=z|k&Rx< z+6N=iys$u3$wZgy%Ge?{IGE&otK8=l`@>!J(7fV|%aAzdIhxIzby2x60+UdT^ri2E 
zVBX*}U%JcVy}Qt}fKpZB=1%VJJg|8}o12bxbL-siPZn-kdsZjfk-Qczsjn$7es&wjU7C_q?w~qW_-9AQC2#Ot zIuu4lanrp1=jP<$HvvK!wrYaVsU#$^5SU~8Fey)k`=5yR$9jwV7N8o9W+&+8WKLGD zv`XTqv8OxT*q&~c^5bVcP=>Aw$hVo)mPfnssNY$rGQTQ(%?z)51fK+DiBlhG0)JF| zc_CkjlaPL%{qC*o@nm(E(M>dGt?2QutXkCHZ#JD>k*rqE`*AHTP9r=RTOL)jQ_0)P zB-;^5Q=A+e*oWutWq-K?%`NG(m7aL&fEO}8(D1c(3&e~%+p8ZzWoe2v& zE9~cp=3!c5XGFTCe`etEZ!PsLV+L}%oei&IehP;$iroM(E!hGKiTe{*LfP5lS@lQm zA8^Y$vSUS$?9Sb;l&>-^HvWdUZi!YMMH+M2lD`Zp-BDw?P0T5s>CulwMNA!a_CH?(((G#y`lCNRXR2U0`qQj?s z$>cSJD6A!6^wrRYXsHZ%{etd|IO7}fcT!NC>+nOitxpBF?Ge|}`TpE99@nm1(Oq4W z%Olib=#~jTwkveLt>^eLlWyI?<1r&}6q>UM7k0K;eElOOo6N`VZ@z{vnyNll*wf!n(te|HR`vJq#&0anm&J>3ZFSRT*!UVcT&%;%~W`f?Xxs_D!?e2Xv9) z(t5*2`YLO8JyPr-jw0O<)%6tsi_C&fYXweQi%fRwNGoYnT**-_M-sm?z4)kji7v1< zuoG8q7J*AVncfFVuhs-CdpUo1LQ@Inc08iN&79U^%hkFZC-xcM^HIXXT@KG~PnWvA=JF}Y&b>8vSr{4CJ8`MRPsy71!NDl$S+2<;s z&@?*BDBgCY`%?IC#Ghcuc1u^;hOgT#djdHES$O%KQe7;z{UpOK@i!(zQy(KMTOUyG zI&~$BoL85eh;4fHcz*%v^Ywc{UgsrzKLMEGJN0J`EzZT4;F67tuY`74J=_a%X^V)e zs&~6`64JnxYae?!g6Cw#axXw;ROg+kQ6oz9sR{)v%(X`$=J#1{cy>Gn)z9 zS8bVsq!$Ui6FlL8r}AC4bna!@)wiX~EJ_5ID`k-Vqt(ozD8` z=j_P4U1j*Tn;!V#4w)u)u26^dIbz)KEIgKU@wFYttU&wNUO_1K+9L zm%9xEq_Xek=;3H}dVCkPx1j{$U%s23wEjucMbDnWaIMoG@fxvIAOph@r=az+p&aVE zYctz@z|?R#hNIO??Bi&K<=HE?-lF#Gh}N?kRqg>1NgQ_uf;MskRzE3t%!!C6@mExf zKgr<^6_|ZK_Jh{bJd`|IPOd3nU!p9&Y41nbmzKBuT!iiwuyXqGH5Bt1P4O0Y1%EK` zZIE`)NY=trGndH`vBd===TpT$%zKI24vOyFdfk@InW6Ie>#~hia6`i6A$dd6{+~h0 z?J1G6yQ(4#FDIaIZuB2K`?rP_(X$P0Ox{O3PM+^MZHQ3Ow;J&auHz z#B)OQPF8<#iMYQ*kWsIv=N^Lf{^H_Lc0SzupH#W|R1h-ukbzCQ!vpcunr!~C#ON)q zVD1Xke#8|xN%R#@(3DKfTJObH(%)4EAw_EzG@V&t_&@uvLK!7DT~@H%3S7j}!smim zFBLcpgT;e4G__N*kfNua1YSz>CyGVFnvkP;?m~CSN>CdRCxzf;VdvM)3l?wVWaNJvA*~FEj(BTH1{CluR znQpdnU)Mijn_e|?JQ0eKHfXemH)lcP&ilMd5@QWw;|x?wD(aXMf3Sy6Ax?` z{RXeF6~e8yrk8@{P`!@mD@PGpLT(|{8lU(;ePR*g>bk8YE0FGd@R9*}S(f;x6t!Bs zN|ZPZAJ;$XTYJ@cc=*2}Nh62umhNM77|R9fBa`Qvubh9jBcGL6uRq!RLuez}N3 zSJaW#Z``cUNUG)Kd^YZveCH}v`?1$|Pg6~-yxb=C!DszxCC^v87CMn<|Msz*lCyck 
ze6Y(@H(`k$9v=E<=xsi{Nx6QZnm^IE;CO_nm0jV>>jmJ=?bD-B)af!soQx zxQ1wD3>o8$hwcHBZX(JBE!~1aZ|*iMz?#0kFT~}In^0b_eFnw8gAs=*?l_%`Q!n=k z2VtHZG9PNe4lQ%N(*y`yi-golz$LWH4J!;?2SX^T7w+l=j6e?=k^?{ZvAkka-QHRJ z{2Tv~v9K?ItZF2)fN3qw<$_uZQVtprxFI`)VWalBsQiYpD`NiHbyxiEhaMDX zWXaJeGsej+A_mghY_;dTk|3uYQXfQfG~^7NPTz0Hg;_Bop1OW*K%bAkkt}7iwG^yF z^t}qiT~-++wYMb~4t*)UPo1b{?Dw~Yg@aEa4(PNpd&c3p^^p;k?I{qRH z;C_d#P`2V#jCehytCfooul_k6j`E%Yo|L&*8x#|`=eo-RLLPdD>7-R5$!?hU;5!WB zkYRjhCzm*J}v+Bkql(%qH-tnog1MljO6MHYWJ|?Xf(ldJg)? zmhZ}5Jy#o!dF5sUs?|MdSCTFHxbj!G&X#}wUts0K&cv0@G=hy{jrDi9*?4;KDeZ19Ab{?|4Pb001h%vfTrIK4UmLpj(6(J+vBguh#UN{|fi4_16#kJr1Bx@D8~Ei1o2L7n3BxZmUS(fRT4TM#;Eyg$kHaIbHuc%SX4i zaylCXEX+mGlk0J0URG@P&g*v5UQtgQbDe{HjC^FIq_222aQ&IFOi8LLRPRnm5Kbnb zBh?>P@XzA6cdT?ZYBc#C5+ zXQc9JgbTf}aD#h_f06xY6$uGd&*qjiu698<>!^mD;r8MB{aF@CypYyD{Nv+gr!d?b z7lh~R)t(>Np=`Nx0HaVb9JDU}OCX@Hwtt~MDr*FG`Wc-MaC{YNL>AFVyEz5X(oi;2 zoPZH6tTVt>>|(<3qSbt$;q9}L)KWgc2>hi4=`L>`+l5FY+%pEPTUpnPmXtf)4Uh_L z7f6iZZ-B%ETJxZ*FTrlNA%12^cRuo%u6r?t7I|!S6cOJ1k?j1AGCQGDvfFWF%&lQ= zefs@EVwnv>zUMar)XO#YnV^4GD|#-D5x0TE&&Kiv9PqCfAlrdl-H<_9htv_Q<|}3I zf^`X!E#*AmsX{%oC+=gM?iVxH(3{i%x=gNE|%HP{x_kg3zRffJ6I)m@Sj{sTd z{e}dz$|8O9>aReF){I`A6NYb?lU{;6erm*v4I!510;?OA+AZ}J);g(Z3|bqXUH16J zX8R)hw^L>~Ka(Np3@vlL3sT2#y&{?91*!)EQbNQUe-EWm`R*eKFlRC^W%kw}UD_=7 zanBBPgMSL11}T`NTrlvNReWx~mMq*ao5aA)fmRtvReY41qUdx0gSKn!$1#O=m&?(f*6LBG4k za-uzyR2zVR7k`!j8oi1nmuU`iz%bRGdB^2xKk}e0f?oB;NX=jYUt*`o1>W#K>#-?= z(@63HOhopB3NRytTo;Q$DLF5;37l(WyAHtMxNT?|a)wiR- z6zs7D4V(Pkh1x(eKO^jWxKMt)i)cR1CzP&pm=8wjq9pX!4qHc76;>XI12z?yCj%yFY4|oH7rTB#yv1Va^>d)~_*L$<1n=jqa!%PBnuazJaox*X>AFSkYL)~@d~(le=(+~KA%^OP6> zuEp)}($NB>>z~p(CWq-lAV^O6R9NCYV94wb0zrUJYd)JUYrCq~ZL$NVjavfUqJMLV zA6>p`_onYz25#BT=cv;jE`Pz;nDxL8OWrs9jOGdME9Lfh8W$d{B{+%iKzG|nUi%<| zuYa}V_8W^7-}|o3oPCtJpv``gxA4EIT}p2&qdBFG4Y5Nb`PmB8V^B}8`r0pSxb?uJwgd?Q7~=qu7Mu~gd_I(g9Fpp;^u%R4~$qxd;{GU6RCbo5F7$KH<^fD1a!wPq~b?HiomJetv6 zu2ivqHC<)VdD6?$wfgVN%ZYpR(!=UsC1$tfy?CjYt9QS=5MBdT6^T zGRY)EGqc{)5&xS8Zlp+C0#_2E$WX2Uy$y)6)_#EnM%nRE}z 
z^oaUwO4VmCEXc3$f@bh-#JW*8n3P+rl@Paf{RS$?gj#mJsV9B|j_*yD-RbRnH>qC8 zEs<8ZpIG%1BsfLp4)hvo-VH0EYaHcbAI20;rj$_Jq}!dSzuun&3kdyhpG$z7|%q)2L3Q5eaN#~;y3c+WH8&u~<8OXac~7fdUWYm-yc zX(u03o?)ne84NAUvs0o`Z24-35LcQ*0tkI05Ul9q!vDXFuH7)oU(6b5a8&O}7++s` zV)%y--wIxZS`yst*&HJ&i}$5fixQBISYjbSd~U1tcXpDpu! zS`9|2l*86sM*2NyI#U7-K2 zyqh$=O2%rjNSVI<8H#|(a8KnqWPb=X8U~_*x5mc#_4|vR^bp^*zm;%0Zbjpc_nuGH zD6p8r-s=#}wD>&Q6zFLC2e`;E8LJgrif@k#=vkWtUm7jA?Ti5a$@oLGCTJc^VBc5N z?GgZpszRxa)5R|Hgpo}`*&ahKYz`GaXJk1Q9APSNSwjt7P^Vxte&RMb3R#7R0`Q9c zkQtQ=*t0?r71tbET3Y{a_jx!YsPvMU4c4&Mm??eFN)gd7kGz~g+KErh}0&4YCbx4*hV|%lIpMbk5iP1pj`y8 z{x<)=KaUNUwU;yeZ8d1c$*M zsyePc4AiVWf`txJ?rUsaDsG)7mxC8%!&g5zKw7 zRDaj*yV0nK=BF&%Xn7(o9C{^*uJ!prz4+R34sZcizXnY$-(cWDq)sN*l3lxL`Qx^s z_&0J>6-L9#Zv)gwUTNF=pKIkku@sYZ=?w{=x%^{RdLbrsR+u^m(Y3&}Q&SN{Jzhig z)4)DxYu$~uQ`N>l$u_`LW7loP&cY-SnkE8xkb{y=$${HhK_H#IDbo3i?aU7P9w@zfhRKJ-R1rdgoVr8FNOFVxc0xRDorHr2=gJzdf?c${)lO%T#2Ra^g-Owa`q#)_ z2>$o;>DrQJC&;hc9z;jy!%8+ao}HEe{#y(!slOXC%1+H9;n{hpcF16WEgIxqPQr?U zhk(#{!!kTW_RndQ&?UY_AfwJ$Sgv|uSbF{3C5iv3+hFZBFAUutzO1C}6XT|zm6JTM z$uFBKM5J!x7_nv7qFd~k{0&VKr=ydQx4tXtSGBHyPC8$0o^IdZW#si)qJfq=E+fPz~D0Iw+I6D%D>-%_v^!2!5u_T?Fo}T8-{qt_KB9HBcYa*&#aYR9L;=G7ye(U^kdKSYlfYZ2;z%#@!53wLZzek;q@_^lRBSd*3h; zamd&-5Ct39ViFk9F%O|(Cx6WRH`~cjL^2i`SUf}=F@T?RTL$S&@us$PJY+5I9X!k zM8vyH5gNwgcMj}u8WRhZ%d)yq2{(l?-MhlP zf;kzX&S23QNV{TQ?!3gq?YA}lFkEdw2S(2&&$q^Unvz!im7ZR#?fS#6SmFABVhCWX9}WU7&|87 zaNYb&Vb-Q}mSYt4e!-}+FLdPgO%clxLf5ub$IbaZ0YTx(4NBqux2|)M8IT^ z&OyTC#!QNsB9ETObEM0PVGrDiGMKSbP15@N2YJXwBHsWekOh5?K#AKrx>3+PKU!6*Z>8|{mz3l5% zpG;$bZMXgl9~pmLS9m}zV`nWxs}87y5Ay0`q1hPF0n4p04V z0kod`ukK>8f+)X7TU@ujyJsyhCkNC~Ok1PXQ75+)n~Pe2!>$Qia1(Rvz%Ad6(H%)I zSNVpc)hssK@$lRoe}YiIIaA~A^Rbs`; z;3><`e8|El*Ev}lkR@`|-Gc_B7RmQ76Ie5-{rNQSy5&w@(q~W8pWc&nRfF8M)XEZX zu-Y(Ad;D|nvcbqiBQCeHjK-eSzhIp#7Gl`{k$Peqg<6?v3 zF&o+tru1K4S5n~R1CmEN=<^x)w(Rc{Ax_^_Y!RO3e+8IfP_ya#t{6A%91hAJ_I;!}a zlDs?#|L&@C_!w=>QEP%}K5bR(HQA4JN7voFQs7xTaBZ+KwNjM_kfI-_o>+v2YI7?b z} 
z4X792VvBb6DQ0V#2@{q9md1xa*{udD$K~43dfnl`RyCMuRY&#><5%@IP$Umr4%;(` z4z5Bv{rT34{7KE6(5*FS(~ja}k+k?bnv@ zz{8nqm#Cn*MB2!JkDEvtD+r=?`unGb%eUfHeVxn1g?{3>LVgGbxxEAQl&_Xqy=i;X zG@rxwL82>IW2Lk(dqvapGM3;tdztL#i{UX7-;SU(nC>bjdE#*v z#AFoqNJse{SEua*I=z)h52IJI2&|!UHsPJ1mZRk}(HnD86Axf+i4jM|ihl}V^FWsz z+hwa;b}~5c{j>dcLHCTE*E<1*n(@94Iot#v#rM@VNdz91Z-1U$=DrHzVBDsm|9FQTBCsLxNg{*&3_UuXZ|VPHF_s zP46sd)-)s1q2Si9b-6d#QxG&vdX~BCA=_SJe4bsd7dRr{wTa!>U~nl+8on&rZ0I3n zQqFvDVv4`J@D#tjWnnVU$`U#Cx_F67^rlq4Q71Zr8AwnbU*Dzvyl`W!cM7y|YKcE6>h) z^`a9W1#eUHbUs*0^ zLjKRG)9DA6aRW8~+5{Uwf{T}7nE(TUa&@XR+~*Hn^tQtb_(yKrfIiY2l*y(8?INGG zTi7FBh&(I*yi?u3QoCLMcG^$6SM%=BomBVI8+6jkcQy;&+~fkCS0rJ`17`FhhAEX- z^>=RQK!138cY|y6-yn{Wv%LMY?;ClX<|FTEr$0X=6roSn^moVB(HFwWr>8g&R5o)W ztm~YQ$?Fct!?JerYrc}mdVS=QDxBanhKJ#{r_-QV&*RnP$FPLF!Jk!gYBn8ZjUwc0 z;I(L}8tl7Rfv*)}cNG~3U-K9IE+RX4t%`6V(ndLs%Y+=CVK3B9V)4d;`)yYcAXv?7 zQh*zJD)?2u2*nyqb68YcnEj+vLWka~o}a1W-DnH#=hroCH<>5-!gpMJEpA5qu|NGSXMFO zGOdim&?vr#aFkvs`c&ZPF@3F7#6!2a^|4k*-Wi+GqzFuL16B8XD3|)?Qu?uOINk>1 z$7TBXmoVZtIyzNkP!usW&5PXbo3b*1bjEp`Z3LZ2_UFGdktA?y|CF3ra4z~Vy5mtl zUBYAQaW06Vn^VvEEZsFFYX0z)A=ILUthY_JwyVnY!~Ib%;U$Y(A^z8&i2bd$-Dgm| zgF8FDH7un1%e3+FO3(9c&qxuARZ?2-5;=N={9lRA3 zk(GS!3}~l#;(51%_4(=k{yS*LK-$XZJ`Vy6H`Y7wgZ}O-%$?(b`SVB~*bqr6lyE%~ z3f6xw73jqJ@s6c&N>P*VZAKx$Z^A4@L9~Vf#U=GWMxy*jKiqz)e06WHby;3)k6?~J z4HCFyL_?WjZKdo=cPtTiw=kT6HOXH!o})bo_f?^>TnxBbtLh2y zH#ePTUj@sxYotNA^96#$cR!5E)&n4}nV|h1eug=ob3Slb?AESE_>)<`&?9;?_(dm8 zq=$Nu;SEUYn{J=8OH!w?i|Kyb9#O}hZdFygDHWC&-hVjvAoMMSIA&j{Z3>}s)F-&y zj2uQn*O64})*ud%c)s?Fq4_HcUWpj6qA)gfYixv(Dr3#$GNIJ*RH|HxzF)w~L~y)q z67rR9dxJ{okEr4d_>-vWYksDJfc$y?d}cp=4WAewh!9H!H91V4p(0cm zK)0q54LEWYwiS|IAT6)8+(Ry+*VOMG&D4t@oL`M$NX%=q)E=K`u!aq4>HPMVJ)*{T z2;~PX(#MzB!`?)BW3JqT&4k179BUKPTicxB3vOTY#dSpwfAxVEReM!ZJM%aH9toZF zG<0+1ef8x9Swnj1P|_kO>K>K+Z3=viI#*=ePn*L(HXuT5?oyU0kr+Q^#>QyNFnMb1 zrK1@V^ogn-6*w4saj7z5-G%;)57MD0W?@_h=L`UuH!dOXsvI3WFquJkt-y%RFzwpq zEa~QgcL8A(=hTi$zf@AtvctbetJ3vXGqm$26Y*c!Z+oWE2ji>mN1qhO57EgL-xMW! 
zwf)>bmj1Kuj=LHz_Tm}(Vf>+~ygWUhM`}{|vRs&K@{=+X6WChgWxB)d-$^$SPsU`f zfn2hr6CXc=pQr}KKpG?J5iP&tk4?wVjr9H{ZLP8`^Gijs&E5VklKM7g(?=H>SnWEk;z#EKK}4zrpRi!;WxgQ*Y)){Se6PD>LMP`z^AeH~7d8@AKCj)My@NH_zudKzn{v4LeVUYrS z7|YAp4$bQDT^!Kftzlv*Mop8l2w(1#W#sn68=Fp37-h2doIbt*&IDXb>6Mx#bbcWB zaKXkm_VjAprI`hLrhv^9k>v!r1z((IZ7$kAj_^zW=q&j=PCtiqd8lFj7;SE!QKQh6 z@vdpM^R=?D)Xxg_l{38zdfinkCABxWzY6>kfXlZBl)LJr-F)X)znvtn zn;RTde%Ij7i7zJg$a)&kN8GgwB5NkwGCCWNT0&p=8Gq|>;a@XBzL-Ij7=g$`XO0S3 zJ(j*6(QnJ5NLHBie2^EDFyN#+fmmD|kY6>5I)QXn8k5sWU(ah5-uNqDR}mrH>%17& z^h+p5SbX=kE{JD;7T;x%b;0pTsB;*@Gv<3=PEpwQr6%9JO3W$i>gmaDc!#~E0Vjwc zr{lWu`0twB;|{Znc69z*lRt2HHMzDX$N-q+NM+OZ%Am(q0T#y`g=7A~%? z%$3bY+SZnUbh!am@WRuz&zS=AC53oj5|Lb{;5vV_n}?djb8?c$fEIKX=8Sa_CZf>nY!LJxExF zX-lx5?bDGmDN`mRi+V)tM*(BCC_0r#;xQ4D%eE*;O#(kkzN_);de$8mW3ro%h+0+; zY;rla@78^yUp?UA8B>_+RhjFw3*tQueUspc^mu z*@zQKw-BiM@;7)wy9F3oZvvvUu#}oBV2c7fA6%9}%u8-qpuPw5~r=?yw2rF(DjbhF@-Q)yV~M(ZvO0eiBm2`>zI8_Gf^Oid4|9RfrV;v z?s~;k$xc#lvucKA! zTui+3muY*YmkfuRpYRgPUk(L?in3ng*_YHi9X!<7yjQXEtR`BgQqFy&TY(0MA{ z;Pp3-ds+VXTJ3+)s4GdAl)7SzW2pQ70~0X_6Ynte#!3_q{pax8K+;KmJ!K7xyZ?ot z1fy`EQxfV$UR+|VEuZ+QV#|*G?t;lnmB&A?Q_#O#`>w5UT{98uB<;0)pD8No zzqHf~s+Sk`u25jU{4soZmLIU-mwqub^0o3rUl_-2``)&%Hr*pVidr(j)I1eP6Q+q4 zX_TCEP;MKd@`yA(UEHsIKGxL{Nj-dQm{tWPx>hj->C^bU%cu|_c=$s~MGMz%oxO2e zpi-`vw6zc8ndVOe`>I6W{rI7hRRC;rTA-TN6mIvEUZU(B=n4;sCi;Uh{#JfeIaZ`$POW^HyM{q+c+wNZph zdr>GOX%3qF%h`jY2RLN8qI$7Hi52l=w&&T5S*4ASCC-(=98+aOa?C&y=XIXH5O$jD z{P*iKDd>l{vu?n__%F;J-CxC!7h(|Tm)8+;9N{80D{rTG#zT|oiKGqj1plA_*en0B zVPOK0W0Zc~;g!{)Gr32)*K%q_kC>NbuE25o{FENn;mlmPGWyiuDFk<%jD!+Spn8^; zxj;>yk3<4P5OWH-mIN?eq7 z@}(UF4d(|ti)`Z4HvZwhB!DoX80dMFCwC_bRo5M*=9j71kZ*YZ7J8T_rJS$opGK&5 zbh9^Q9 z{iZAHj(%|0-}Ni`-@Vw3)z#IP2S&&S62q4R1P+>dSPgF73=mm7|qx$3&Co9Z>r8y%%UhHGK9=SFU+chX7#^iS=Tec;0ya;h35U{&^a@s;7pdQXjou zLBqnrdPtXBy?{QIE%SJ)_EmSatlQ}Lcl7d`+_k~<&*_B zOX6r$NBab6b1FU2GVAxx`Kvk&Drk=JOkgvr37wLmM}P?VbP`_eQh7Gn4{nQh)*C`+ z9S1Ndr9IELCd)I32836+57vWHJ#7;qda 
zV8RZ&S#v)nDry&!crhVze#NtL<^?L|VZT;ZK|Ph$DYrtqzxoj)q<>6cgt7f|XNeeiJO zo&pqF0q8v|?dr8FI3(mZAfPTl5dL z{4|_CH&O~9AOy^a3U>M^{f(V}%$q?iX+yhna~kI4Y|XkEsCSG5dFJ+@C{!@Ja)VvQ zCOD@5(3mpiQ#4zV7V?H^q4948mf_cwOkCJtg@GuQ{%AFiZ(0g=9@4MK&Iv8&dpG*M zCW%2wXNDdWR+w{vEOk&Lmm8jlg@%vjoS_cX9A92v|M!@2vfAMdQj%~a`JSn}GKq5t z4%DQag`sTnr7tGG%ixnXadDx~J7%zs@0FV8KiZKy<*mwPz}FjYhiddk?s_oSib}eI zt6gCb%e9-ycdgMSmG7Ends$VGHWR!Tmr$y0%Hcha1$x=yfk_gBq(-umw%f34TjRi8<`s3(ewB-DI|{L58ZbA(74H{^Q4?<$ zz*N=$QGXs~O>-5Q4WakyvOnh}FYSL=!?IJU5@ot6WCuS`@6a{RQr~&vE&9&zlmQCb z`&KIyaj9G;rYV9(mflv;CGx9*&QG*x-^%`U4nI^^o?q6w#4>UTHg?z>6#`!Eb<1$z zml>qL`%dRNX-|ph?cHYTC!ZLBL@uGgSDc1!QyJjhVm8g_4Py%T@cfrYSvF0fq}%Uu zzpHX}AXQF$%rg-yVbS<&?G*SFzu8||8|wo(yf;DRD)c7FHob^tM^d3n4JG_u-BpRn zK0|wxVSlo!^Z96PY?-o~V62^5XCf!vi~A-PUXj9<4Hae|K^XC^8yrwE0pghkj{s}B zyWJ%0hAmpKPjr!0Z=$e8D!(4{EWWlmYfimPv?ld(%x+2lY<1j7-|r$pMPIFtpN?vT zE}6Xae$KB9PxuCL9zB~B-z|h^{xp%-5LW1w;WG-bE?p8xd_NwCdqU^+4HqkK|B5+{ zqw>u_*++*>$ILb>vqGsy_|JE3q!XHF`ap}9J7aQcLuxb6Y0muR0QSOgtwf)g!AaCv z`NFAQE!UhB*57BC{dkjU zFgT_zSqu{fE~>9A-Oi7WnuNK=Ip--61sg8MT@2_9mt5}^aKOrD@S|+DqjD%!U;4wI z1kRzcCS&idx~E}bn_RYM6w&bl3bKzrnlScqtk-<(z=32E;Rbh##4T-o_^S-ij#Yi; zne&zcyN$q?-Hw9+h!_n3OZv<>IkuxWJsHz+Ey8Ny4(_YC3+p3i=7_CYIG0ct@86OD zL5xV~NfwBTOl6v1?`nz;8^E%pkckN)TlGR!F?#HlAnOn5R4d{2C$aU4L^f4rn? 
zVOwVlA9_part?*!Szt)Urf5)FA4z-t?8_dY^Iu7S5Vz><@pHVpnaC67KcBGO!wGgB zO&Pw@+VcmJ8tt8S&{*YlcOK)}wL=jK((&}(u(Wwx!aL>XLJ*)26tOztE;*UF>*M|Y z&`-kSP*BQ$2Qz%I9+#yC1E-Gp&H_pFrfft7r8Nf3R1N~aYh-HR2C*!ZtHjAR=jMFjlu1@PsPy2N zRQ=*(nkmAs|50jth3%v|YYbr1sMONp5kI{I!jntPqyJvpE$Wy0_$eE0 z786X0CSC0B=Cd`h!2LHr7qjE@$n+Xx9 zryuD{fBMto@VoeJ4GtN8;8SWekY{3rDt7UO7}A9$N-K5se|WR&Arlx=#{56?4qY|j zV@S;NLdBV&wiUw@23}lEJ{f9ynVjo!gTBr4MPPhrZ0Q*N)!|9;m0Hwk$E%3P?uQoL zA&3PkaO_=D=R3FsK}<_6$CEym<3j!VablCDj?}mh6kX4Q!OR-~ThIqx+8 zeB;9r8KStj9Cc%>fA(E1O9fOF5nF?B5CitE1_F}vukn@g1n+Qs4jC+O=&QwpV5GK5 zTMO|aj)|hHw}&6Jn|m>$!PC5oxMH{4wzZ-o@%@^&Bly3H7YD2IYG!>W;tgL)FmIt2 zv?C*s`YP41jD^1NUM@8h-vH`ElNX;x_P@W zCs|LipXBney!ENsTjV$m*P82!uBe6VfFYKa)9Og^TOZ{?JX$N4>d~Q2&sEOmGq6wQ zgql`HJ=oh1`Bvi|kxh2Qr<`4pcgx6DUy2MK+g}Q*7^mnGU*)dUF>>B7sT2>Nu2~I& z>+FgQM*B|Tqk2oPsQ=BvU5_@Mx^hgf+v_Ad!z&*(`Kh+;<&8*n5GF>o%fLw*#yUC-Q!Ne#0QIK{>KZj=BrMNb{#P14?h~0FkEi`hl%cx9|_9k_Zna`M_&|I~u{V z(0H^{e#42gb=(&=a6wJTcRk(z_7{Rt=&DZFY>9HGTybBSKM<5lDyThzv0kRYGZr~DI=lIw6+ILW25Z#q#6 zzL+x2a~-ii@ARL~5Ps2BDcN^;U#lwn-uhA1{|tWGwD5eR+t;mZ;B$_NRnT=~{D?9# ze*2P*1@RK7F#JKL#M`tU5ICZ;Bm$X5fCcaKCCoiM#iH^k6|~SZma~~9Z2{!a{Y}2y za=AT;%#$Vu-dXrzS9y2kWTf(xd;=*XV6CJptKTQBr0a04=)tq>prWz0e2n;&KC)|_ zsD?`iqI@sPr=UCh^9<9SfVa!9EH3V2ETFJD>}d&TvsL48|Vzo^$c7 zS%B4U=gZ&e5-nrN#k7k|`pS%K(UR-YZ5J7jf!~TQ$7MvXH2|M@7Btw|P`@b?ZCSml zcXfVi_d{sr0yZ#Sj771l9?PiT=oBp?F|Lzdk>zmjRF$~6OVS$#K-tl&uYIWQhwpmu z-F)Ix^>+twT7O}$M)1*l{1O|rE+!*cYoZSg(`pHzZ{;nk%Z29)p?N;^g0=&+TC=71 zNa&Rp*Pp{JqA`x%)iOdoa@7J-gKMlgrbD!cO;?Nh2J8Bl#GfjA=XD(E1le%5V>1^l zV$xBqzzaEi#jA(T>-qN>LAi@3iq!J+;^Ps2R1ebUN^5`BIbL8`e>)qhg-gEjj)$n0 z8mSAl@xAP{R*EUi@E^K`nn=G`Q`giZKv6G0lqKtzYhEfZfqpwQ6?_yc9OhLk9fc+A z6Y!EQd7-yBh!JOhDlNXI-YN?Ycvl8COK^C!1AuCutSg8%+AzNKvkCrV?@k<}>d|EW}5(fi}%f_q4YGh?>ommR>c{%Xf?v-tfm`XKwX$a>^o#EN>y6 zsfmp}+LU_P%$XODMT(usyff@_#o5I?jPe?x^J4rzRh?%%)&Kv-5e~&ch$16^eFy;t@s>)2aXwrtsZgGA&c3 z{qNeZODfvJ$T&fucZ}ZWor|sjwx_x<1%6qg9LQIgIQ|GOAJKItq(ip}ZPR!5IW_Oy zNCGPuKX#3GMBwVJJaVhc1Q& 
zy~vvqbm*N$=HYaG%o5#)dM*fI4TQ#ej^}bqlna&X{ltO=F$_+ReumM-cF;M>2}Z^1 zfYT#fMT1tz2Fsd{nvjJWNDESO#iJNIu{~yT(~KJiTOi3)9d4NtV_=%gxAJax_K!4T z@1^l++BI5kjR=oO4u{onR>lK;IpRR`mTyTWa-Hci7f<&Kr&71s*wG+ItV$6vFAcbI zRg47MHzH>Hjssz{s_xW>u?NY-O+5tYM0l5wi{AxKT`Y zv>y9Uwv7%$Y#YwUOj1DhL_L|+x7-`7pAoS^2XA#lX0lc6yHg|IZ~<|(7`3nhg z;13hvf0wPYH7<1~SE&-Bb@lF_Z=O2OuNKl>bCtVT?Q90#sEFq&?icT8c9NgI1Vi80 zimaAQzlM3gYj>)80V1ZC?GKG2?p;oW7t}ny>GE4PY93>7kpa*&oJ9s zxQ%O9~}7w_Ehk$QBU87a*q6XgFGaz3LKiaxIxoL%&iAM z0m)ZUPHR<{Lmkv6n*KV*gDNxS?C!zmCBJ!{W(MM=Tl;XC@f z2p_P8Dbn_iCnlabp-IChbdXcf>yyUK5nxu?X#-xFgiHu45`x=X!$=kvQ3+f*(seP>Ps+LhW>gV67X-;2N6=J`5B;1IriS zm}qAEqSQOzKZ{$m1@fx7Oi+{IfQCZzgNi>&b|dEB;pQsqK<2M&4yU6kK7y*3hpCrx z4}zp0^U;*9<{uvp$OmRy-ASo8^%q3pUe}#VI?c{_2!vD3 z=a|Zal)c1_Fk@7lzPhI>cKn4-$QTE92A~vYxj^CU{;^>IgbfKbd3?BRS{_ z0vD9RNL4kD3%dzTAS=BDkXDsW^0rdml1`tvuMiw~e|E5sL20EWz&gYEHgD}`ER*98 z_L7H_Mbl!Ao;L6Jp1vCflCYn6I)0KSUg~`@HbRwEOFfryFLyy4;ppmOnSx_+4OukE zAz80mZxK68OwN5fLkZHcrilJg$N7I<{JZ|^j$b*Mq6xks6&}On0CLQHnV-W9G+n?D z!Riuh=J~R&Cz1SArYd8ULx?7Cd{A$SAa~iDA8*`jR}3Jz)uzUiy!?dmIghDJmR4q1 z8tR#u^uj0FHc)>Gr!fpM@L<=$N<~D}ad52ogDl>NpRgfDxdvlQNJe~^D6PD_1}X#E zJ<(`s?WWtyHSc$6U&5$ov+iXQwRCz|tZM$1BB(AmBSIM-feB0+1>ZinLk}FU5R&g_ zi{FVzFlTRKzp zogT0fx(C7|z*}fl#`-$%+dhKt5Eg{!?mKA{ zjq0S^JBqCmt?xK=wNW#Q`_6uvQhL#R6YpwT8X(+J3kI(ZAB{-fvGDuIeSLIB_m}*} zUV02DV^x;RnPJXK?>5JtdSg}y1L1bTH!E#s?>D`BfzJ>5IVOlR8BkVc4W8W-!FsKC4->SKy0CaJyew6>Q7kzv z_XAB4jZgT}7HA?GT!9!>Hoo&dW4W)qMb(dmBafGi^biGHSyB3>@m=KcT~F~Bk7bJ& zt|O5^qw>50_jfu+;+XBu`;{GKP9&JX3av0;@T#@LBD<#@u9zyfpRt#m`(I=-Vprm} zo!q{>k(`=i>zIkz=;|yz(gY))W$$nMwb_RGzl!N(9-IbG?~-M@$X%Q+ZKr(=`-i9h z;nQQVwuqD-j+Nh|qtav+Vp7kAVb@4Cc^=>^|rG5^C(sqAt8;ZPqE3HwCy9eKZmlQ-1 ztctG+F%vCQdr9kyn-?2B2%G~uCpF}D`Di(={G4(;M?ms<)x@0r=?J4mq^(z~L)WJ_ zufjr^q-ujxL$2>^ZZ4VEP8KuYV63f67xbKgL@UQU+6Kx4M+nsiIlv?s?B`R}uzlEW zY8&!GfCUpHM+pN>vnQK$G`_}TuEmXBr%u%YAy1j7ln6f{_b@WkOO|%c$xWsn`+;7| z_;R2wVZym=P~H1crfV}|H%Z#HoU8fYU7}@C8b`a5Wwa@P?M3!%;=U-HwZW@+(U`l& 
z5)Wqz!PFr(W=^pD%SrAHK(r)kT`Ng~jav9ruxnZdcGKL5y^L{~k_=Y~0ad#s`}31% zLlttxf=u;Ab<1e91M~h5x5Rs2y7fTv z<(Y?%K(6%P^jp(ckEtQ+sc(sk2?(gH9we?pusObJ3cuWWKhY z`+L|ll=c7<<&nb8YQrHuD>9 zu+T2N>U`g*?se!0K*yOGijJ~IvciY>dhg|MiY`zyxNMd3Rjfa2fxBPO<=pTJ^2g|D zQZL@eWw_E;7J)e6o1!n3*it@i<~_@_STT@aTVoNqTYM8)HTHx;qCLx!hRsJAfjJJB z3K@6>jo`uAR1FGg17D103>s_ZaD3-F`qeagx%4qk#IQ6?+jFJn_N=?Ll~mq2J(g0p znVwgfsQGNN(r0+U-sg$ZLkc6}Lz9~)QVh&U)Wn$LB(9;7otAN^gt}*i+AJciRb&J^N{`ouI9?+wR-J0l#!KTc z4MN6j>4*Fp+!yOI4XS&+=v1r)tnK}*NABhzsvCU|LpeXj^A-9wP@&l6CEU$cuWWr*V-f5@&jo!<doWpUS9B@is@y?K4Uyho2_}y7o zD{sezI(J#FS-$0L#`Q_-E+FGgbIETj_RQUQPzwUc;<#>JWW+g~l1WJ zhp)&lKT9cW2##aDY`_yaSTEz%OLFl7nKjHthW|rkJi+FkSCa5xm7O6r=sq zD=>fYRPq4$&8=7FIi*L7@rgP_ca@Uzy(T8nS8#Q#b5P)7z_RE!xPtUREg~qEfPkC_ z@6nY-iONRg*J~p;KHLJh+t!MjCn+qk>1e&pY4A~me(TuhC|KtQz|SPjU0(d-Dg;=A ztPnD0zwIq?725owy!!OECNxSQb8;FDkxIP~V}7%m+Cty^YO2fcVp&+^8R^3>E%rC#%ncl`0R6;1Q`UpY3(|BQ`EmaKi`ObC%oa6K0ZA5v zYGBqMC1cfM?y~i#eLN%gWLLzCIr`3(?2upz zKN}lo#f!~2#Jx`JeI(&0-ay6>Id{<|a!c=WK)mfoT9472WG;+Q=aFfbNHH=H*JC^3 zy&e8n(F!S0A-x1)6h@0ajkEuvkLCE#iqq#t7vubjo5ND?zQ}Z=I)tBEW8tNF9hrbY zPuP=!mX*+wEK4Vi2DSspev}*$vlcuLr3-)|)*@dZ?h6V5jUPIwK)nUSQ% zPkc}dLLIJp#mhebSQ?Bc2k_zOkxQVW$sm5>dJS=^?oa(!?KN^lJOBGN9V8EMJx*HCsn3JoLIOeR97LGRuk6TV0GlC z4tx#X6OC~@?@mHkP^Un)e#4(MUB$leEIQB0hP_exDD-xC_~C<>YlG|-OC3+Moc=rq z(cLsIAi?M7s<(Cm+58nQ1;fSHpR~zNMO6VIlk-2S9QacQew4C3GnJsIQg>FKunJqUxwXct zypV}%heR9TmC!yKbj_Yx6o;f_C>tSYB>4E@+WdPkR1idy76i+Q<8GaY(Az3utdRnd zx^;ufpyaBeDM+y#RhePpw=z5Ag$t&kPO;ANZ7J-uH}1i#7Paq#=X$#ajvcxUa%E7= z`gWuRWcy#tarc{g)Ph758AN4?B7sky5@{&41t3CY9b?`abUVE_GSw+YwBXbq^IwGb zm%K;-bUizZXW(%~s2Ueb#d*<#u>%2#6Y77TSr(6kuo2_~EzEYYlaD9fjD-dc+s$1q+YWm z?CKaoh!f3Qr+}>i!6~;M<;WChof%EOkl2+3Y~K8MF^MMa+J+65!zN<)G3zL$uUJHU zBr{>qz1tT{0_O4Q%4REy81_;CzPc<1@dV>>g(xLFa33pVBJRH6MRcfUL8FBH`UF=x z@<-|LME)zrZ9i|`AE;f`05mKR-zKQagRJpIR&TTRh)(N>9KFGja6;xP8H{58MvQAwIBq(`XUH9N1}grs+|o+rd11<-XF?kDZ) z&F~JF#VL5nsPywo$Ak8$_HwUS4zfKC=1&N~+O*r|ChMOQJoljF(P}glG;t4K)F0jq 
ztgVeU1R&GjnyWw12=T7na_ffy>?3LOW(Z$Q5s_!RRewoQ;w$w>-MgUQDohWSd;tM` zK+Yv-{OIwzf0M^YksE`)pPq&(MJHD*oNLAV>D3>PB~lw;jo;>m;c;h*gfx#p_FK^q z^?7&6#BsEt!r4&9b^iKVM^;_f)bu|fb8`IWY;X)B2fuK=c0Z?y%9pJ6G^&cMe{lFE zX^L4ZS*FUk&{X%>805_q9zKp+xH!m<8>fSg8XGC1)j#Y0vet#e-Ma1YUClv407}!K zGZNRn;`7-oUsI-SjkqqXREC>F%WL36uQIpg!skpczUlOiCA#mA?+IZ9=?z-|)72zC zv`Ub&evC4E1$mvwD+&oS(*y?Mt3>GEzm8d^6|ce$O*GxF4kVs(eR$~g+o#WI;WSp3 zWSJq(^>{I~4p z9p;Q)RMl4qW6M(x*co6(i3OY2pOy#zBBfwt>r8nSBba4H2}t>1x6PDA+i(2wyM$M6 z`OXOkYE%{3)~Or17EY;Mvk&$-D&2>P*twudUt6pL!a*lfxU^wz@Z}#8x&)lycu9NyPvU6uHV1> ziHmq?U*(bF1(CZ}$Zz16ZEdd*NXba?M zBv|Yf#`cW8ln-gq+h&b}yg?S`vLjXuKH}VokWSW^qU*tT=t%81+(n#ECh(dfW3f92u&1Fuk-3-=J46!XNoDV~9 zalOWK@GTg>geFASE;2qCPCyY3gCE(5x#DqnU%3^~J2K=hwk(Kb`p(_#b86J)7Ef7b zIC5QT^3_os_$CQFKfY?L81Po3F$?;eWdIY;Cz)q#^-Z$}0nxO&&?I|lnkKQ&r=_j9 z*ri%N jKH>4u|IecH5M8p?qaY%#ndR05;H9Rj`JnjzOTYgE9&YM= literal 0 HcmV?d00001 diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 4cb437faf..1f05ce8a6 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -121,6 +121,19 @@ Comparison ops minimum maximum +.. _Random Number Generation: + +Random Number Generation +------------------------- + +.. 
autosummary:: + :toctree: generated + :nosignatures: + + randint4x + randint + rand + randn Compiler Hint Ops ------------------- @@ -129,4 +142,4 @@ Compiler Hint Ops :toctree: generated :nosignatures: - multiple_of + multiple_of \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index 308ffa966..2965f167b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -126,7 +126,7 @@ setup( author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", long_description="", - packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"], + packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], install_requires=["torch"], package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, diff --git a/python/test/test_language.py b/python/test/language/test_core.py similarity index 100% rename from python/test/test_language.py rename to python/test/language/test_core.py diff --git a/python/test/language/test_random.py b/python/test/language/test_random.py new file mode 100644 index 000000000..6c15a7588 --- /dev/null +++ b/python/test/language/test_random.py @@ -0,0 +1,198 @@ +import torch +import triton +import triton.language as tl +import pytest +import scipy.stats +import numpy as np + +from numpy.random import Philox + +##################################### +## Reference Philox Implementation +##################################### + +class PhiloxConfig: + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + 
PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + ret0 = key[0] + self._config.PHILOX_KEY_A + ret1 = key[1] + self._config.PHILOX_KEY_B + return np.array([ret0, ret1], dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + 
self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +## Unit Tests +##################################### + +BLOCK = 1024 + +# test generation of random uint32 +@pytest.mark.parametrize('size, seed', + [(size, seed) for size in ['10', '4,53', '10000']\ + for seed in [0, 42, 124, 54]] +) +def test_randint(size, seed, device='cuda'): + size = list(map(int, size.split(','))) + @triton.jit + def kernel(X, N, seed): + offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + # triton result + x = torch.empty(size, dtype=torch.int32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK),) + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=PHILOX_32) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + +# test conversion of random uint32 into random float in [0, 1] +def test_uint32_to_uniform_float(): + @triton.jit + def kernel(SRC, TGT, N, **meta): + pid = tl.program_id(0) + offset = pid * BLOCK + tl.arange(0, BLOCK) + src = tl.load(SRC + offset) + tgt = tl.random.uint32_to_uniform_float(src) + tl.store(TGT + offset, tgt, mask=offset < N) + + def run(source): + target = -torch.ones(source.shape, dtype=torch.float32, device=source.device) + N = source.numel() + grid = lambda meta: (triton.cdiv(N, BLOCK),) + kernel[grid](source, target, N) + return target + + # check range of edge values + n = 100 + source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda() + target = run(source).tolist() + assert target == sorted(target) + assert all(0.0 <= num < 1.0 for num in target) + # check distribution is uniform + source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda() + target 
= run(source).tolist() + assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01 + +# test uniform PRNG +@pytest.mark.parametrize('size, seed', + [(size, seed) for size in [1000000]\ + for seed in [0, 42, 124, 54]] +) +def test_rand(size, seed, device='cuda'): + @triton.jit + def kernel(X, N, seed): + offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK),) + kernel[grid](x, N, seed) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + +# test normal PRNG +@pytest.mark.parametrize('size, seed', + [(size, seed) for size in [1000000]\ + for seed in [0, 42, 124, 54]] +) +def test_randn(size, seed, device='cuda'): + @triton.jit + def kernel(X, N, seed): + offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK),) + kernel[grid](x, N, seed) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 diff --git a/python/test/test_blocksparse.py b/python/test/operators/test_blocksparse.py similarity index 100% rename from python/test/test_blocksparse.py rename to python/test/operators/test_blocksparse.py diff --git a/python/test/test_cross_entropy.py b/python/test/operators/test_cross_entropy.py similarity index 100% rename from python/test/test_cross_entropy.py rename to python/test/operators/test_cross_entropy.py diff --git a/python/test/test_matmul.py b/python/test/operators/test_matmul.py similarity index 100% rename from python/test/test_matmul.py rename to python/test/operators/test_matmul.py diff --git a/python/test/test_comm.py b/python/test/runtime/test_comm.py similarity index 100% rename from 
python/test/test_comm.py rename to python/test/runtime/test_comm.py diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 7694b9ec9..3f08b5133 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -9,4 +9,4 @@ from . import code_gen from . import testing from . import ops # version -__version__ = '1.0.0' \ No newline at end of file +__version__ = '1.0.0' diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index ab1daeb41..b96260c51 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,2 +1,4 @@ from . import core -from .core import * \ No newline at end of file +from . import random +from .core import * +from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index ff243bc5e..22cd717e7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -648,6 +648,7 @@ def cdiv(x, div): """ return (x + div - 1) // div + @triton.jit def minimum(x, y): """ diff --git a/python/triton/language/random.py b/python/triton/language/random.py new file mode 100644 index 000000000..913073679 --- /dev/null +++ b/python/triton/language/random.py @@ -0,0 +1,208 @@ +import triton +import triton.language as tl + + +# Notes +# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations. +# 2. multiply_low_high is currently inefficient. +# 3. Even though technically philox sampling outputs int, in many places we pretend they were actually uints e.g.
uint_to_uniform_float + + +@triton.jit +def PHILOX_KEY_A(): + # 0x9E3779B9 + return -1640531527 + + +@triton.jit +def PHILOX_KEY_B(): + # 0xBB67AE85 + return -1150833019 + + +@triton.jit +def PHILOX_ROUND_A(): + # 0xD2511F53 + return -766435501 + + +@triton.jit +def PHILOX_ROUND_B(): + # 0xCD9E8D57 + return -845247145 + + +@triton.jit +def hacky_to_uint64(x): + return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) + + +@triton.jit +def multiply_low_high(a, b): + return ( + a * b, + ((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32) + ) + + +@triton.jit +def single_round(c0, c1, c2, c3, k0, k1): + A = PHILOX_ROUND_A() + B = PHILOX_ROUND_B() + lo0, hi0 = multiply_low_high(A, c0) + lo1, hi1 = multiply_low_high(B, c2) + + return ( + hi1 ^ c1 ^ k0, + lo1, + hi0 ^ c3 ^ k1, + lo0, + ) + + +@triton.jit +def raise_key(k0, k1): + return ( + k0 + PHILOX_KEY_A(), + k1 + PHILOX_KEY_B(), + ) + + +@triton.jit +def philox_f(c0, c1, c2, c3, k0, k1): + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + k0, k1 = raise_key(k0, k1) + c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) + return c0, c1, c2, c3 + + + +@triton.jit +def uint32_to_uniform_float(x): + """ + Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). 
+ This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly + covers all the possible values it can take. + """ + mantissa = x & 0x7fffff + exp = 127 + res = mantissa | (exp << 23) + return res.to(tl.float32, bitcast=True) - 1.0 + + +@triton.jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = tl.sqrt(-2.0 * tl.log(u1)) + return r * tl.cos(th), r * tl.sin(th) + + +@triton.jit +def randint4x(seed, offset): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + z = 0 + return philox_f(offset, z, z, z, seed, z) + + +@triton.jit +def randint(seed, offset): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset) + return ret + + +@triton.jit +def rand(seed, offset): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)` + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset) + return uint32_to_uniform_float(source) + + +@triton.jit +def randn(seed, offset): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + + :param seed: The seed for generating random numbers. 
+ :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@triton.jit +def rand4x(seed, offsets): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)` + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + u3 = uint32_to_uniform_float(i3) + u4 = uint32_to_uniform_float(i4) + return u1, u2, u3, u4 + + +@triton.jit +def randn4x(seed, offset): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. 
+ """ + u1, u2, u3, u4 = rand4x(seed, offset) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 4baa951c1..e0847ae86 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -43,7 +43,7 @@ def add_kernel( y = tl.load(y_ptr + offsets, mask=mask) output = x + y # Write x + y back to DRAM - tl.store(output_ptr + offsets, output) + tl.store(output_ptr + offsets, output, mask=mask) # %% diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py new file mode 100644 index 000000000..d988746a7 --- /dev/null +++ b/python/tutorials/04-low-memory-dropout.py @@ -0,0 +1,164 @@ +""" +Low-Memory Dropout +================= + +In this tutorial, you will write a memory-efficient implementation of dropout whose state +will be composed of a single int32 seed. This differs from more traditional implementations of dropout, +whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about: + +- The limitations of naive implementations of Dropout with PyTorch +- Parallel pseudo-random number generation in Triton +""" + +# %% +# Baseline +# ------------- +# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance +# of deep neural networks in low-data regime (i.e. regularization). +# +# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the +# output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input. +# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available. +# +# At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would +# increase the norm of the output (which can be a bad thing, e.g. 
it can lead to artificial decrease +# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which +# keeps the norm consistent regardless of the dropout probability. +# +# Let's first take a look at the baseline implementation. + + +import tabulate +import torch +import triton +import triton.language as tl + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + **meta, +): + BLOCK_SIZE = meta['BLOCK_SIZE'] + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! + output = tl.where(x_keep, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + +# Input tensor +x = torch.randn(size=(10,)).cuda() +# Dropout mask +p = 0.5 +x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda() +# +output = dropout(x, x_keep=x_keep, p=p) +print(tabulate.tabulate([ + ["input"] + x.tolist(), + ["keep mask"] + x_keep.tolist(), + ["output"] + output.tolist() +])) + +# %% +# Seeded dropout +# ------------- +# Above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly +# we need to store the dropout mask for backpropagation. Secondly, dropout state management can get +# very tricky when using recompute/checkpointing (e.g. 
see all the notes about `preserve_rng_state` in +# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation +# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management +# of persisting randomness across multiple invocations of the kernel. +# +# Pseudorandom number generation in Triton is simple! In this tutorial we will use the +# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` +# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides +# other :ref:`random number generation strategies <Random Number Generation>`. +# +# .. note:: +# Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_). +# +# Let's put it all together. + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + **meta, +): + # compute memory offsets of elements handled by this instance + BLOCK_SIZE = meta['BLOCK_SIZE'] + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +x = torch.randn(size=(10,)).cuda() +# Compare this to the baseline - dropout mask is never instantiated!
+output = seeded_dropout(x, p=0.5, seed=123) +output2 = seeded_dropout(x, p=0.5, seed=123) +output3 = seeded_dropout(x, p=0.5, seed=512) + +print(tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist() +])) + +# %% +# Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same! +# If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you +# to explore the `triton/language/random.py` module! + +# %% +# Exercises +# ------------- +# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row. +# 2. Add support for striding. +# 3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed. + +# %% +# References +# -------------- +# +# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011 +# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014