Effects of Parameter Norm Growth During Transformer Training: Inductive Bias from Gradient Descent
William Merrill∗† Vivek Ramanujan∗‡ Yoav Goldberg∗§ Roy Schwartz¶ Noah A. Smith∗‡
∗ Allen Institute for AI † New York University ‡ University of Washington § Bar Ilan University ¶ Hebrew University of Jerusalem
willm@nyu.edu ramanv@cs.washington.edu {yoavg,noah}@allenai.org roys@cs.huji.ac.il

arXiv:2010.09697v4 [cs.LG] 29 Sep 2021

Abstract

The capacity of neural networks like the widely adopted transformer is known to be very high. Evidence is emerging that they learn successfully due to inductive bias in the training routine, typically a variant of gradient descent (GD). To better understand this bias, we study the tendency for transformer parameters to grow in magnitude (ℓ2 norm) during training, and its implications for the emergent representations within self attention layers. Empirically, we document norm growth in the training of transformer language models, including T5 during its pretraining. As the parameters grow in magnitude, we prove that the network approximates a discretized network with saturated activation functions. Such "saturated" networks are known to have a reduced capacity compared to the full network family that can be described in terms of formal languages and automata. Our results suggest saturation is a new characterization of an inductive bias implicit in GD of particular interest for NLP. We leverage the emergent discrete structure in a saturated transformer to analyze the role of different attention heads, finding that some focus locally on a small number of positions, while other heads compute global averages, allowing counting. We believe understanding the interplay between these two capabilities may shed further light on the structure of computation within large transformers.

1 Introduction

Transformer-based models (Vaswani et al., 2017) like BERT (Devlin et al., 2019), XLNet (Yang et al., 2019), RoBERTa (Liu et al., 2019), and T5 (Raffel et al., 2019) have pushed the state of the art on an impressive array of NLP tasks. Overparameterized transformers are known to be universal approximators (Yun et al., 2020), suggesting their generalization performance ought to rely on useful biases or constraints imposed by the learning algorithm. Despite various attempts to study these biases in transformers (Rogers et al., 2020; Lovering et al., 2021), it remains an interesting open question what they are, or even how to characterize them in a way relevant to the domain of language.
In this work, we take the perspective that thoroughly understanding the dynamics of gradient descent (GD) might clarify the linguistic biases of transformers, and the types of representations they acquire. We start by making a potentially surprising empirical observation (§3): the parameter ℓ2 norm grows proportional to √t (where t is the timestep) during the training of T5 (Raffel et al., 2019) and other transformers. We refer to the phenomenon of growing parameter norm during training as norm growth. Previous work has analyzed norm growth in simplified classes of feedforward networks (Li and Arora, 2019; Ji and Telgarsky, 2020), but, to our knowledge, it has not been thoroughly demonstrated or studied in the more complicated and practical setting of transformers.

Our main contribution is analyzing the effect of norm growth on the representations within the transformer (§4), which control the network's grammatical generalization. With some light assumptions, we prove that any network where the parameter norm diverges during training approaches a saturated network (Merrill et al., 2020): a restricted network variant whose discretized representations are understandable in terms of formal languages and automata. Empirically, we find that internal representations of pretrained transformers approximate their saturated counterparts, but for randomly initialized transformers, they do not. This suggests that the norm growth implicit in training guides transformers to approximate saturated networks, justifying studying the latter (Merrill, 2019) as a way to analyze the linguistic biases of NLP architectures and the structure of their representations.

Past work (Merrill, 2019; Bhattamishra et al., 2020) reveals that saturation permits two useful types of attention heads within a transformer: one that locally targets a small number of positions, and one that attends uniformly over the full sequence, enabling an "average" operation. Empirically, we find that both of these head types emerge in trained transformer language models. These capabilities reveal how the transformer can process various formal languages, and could also suggest how it might represent the structure of natural language. Combined, our theoretical and empirical results shed light on the linguistic inductive biases imbued in the transformer architecture by GD, and could serve as a tool to analyze transformers, visualize them, and improve their performance.

Finally, we discuss potential causes of norm growth in §5. We prove transformers are approximately homogeneous (Ji and Telgarsky, 2020), a property that has been extensively studied in deep learning theory. With some simplifying assumptions, we then show how homogeneity might explain the √t growth observed for T5.¹

¹ Code available at https://github.com/viking-sudo-rm/norm-growth.
2 Background and Related Work

2.1 GD and Deep Learning Theory

A simple case where deep learning theory has studied the generalization properties of GD is matrix factorization (Gunasekar et al., 2017; Arora et al., 2019; Razin and Cohen, 2020). It has been observed that deep matrix factorization leads to low-rank matrix solutions. Razin and Cohen (2020) argued theoretically that this bias of GD cannot be explained as an implicit regularizer minimizing some norm. Rather, they construct cases where all parameter norms diverge during GD.

Similar ideas have emerged in recent works studying feedforward networks. Analyzing biasless ReLU networks with cross-entropy loss, Poggio et al. (2019, 2020) show that the magnitude (ℓ2 norm) of the parameter vector continues to grow during GD, while its direction converges. Li and Arora (2019) present a similar argument for scale-invariant networks, meaning that scaling the parameters by a constant does not change the output. Studying homogeneous networks, Ji and Telgarsky (2020) show that the gradients become aligned as t → ∞, meaning that their direction converges to the parameter direction. This means the norm will grow monotonically with t. The perspective developed by these works challenges the once conventional wisdom that the parameters converge to a finite local minimum during GD training. Rather, it suggests that GD follows a norm-increasing trajectory along which network behavior stabilizes. These analyses motivate investigation of this trajectory-driven perspective of training.

From a statistical perspective, work in this vein has considered the implications of these training dynamics for margin maximization (Poggio et al., 2019; Nacson et al., 2019; Lyu and Li, 2019). While these works vary in the networks they consider and their assumptions, they reach similar conclusions: GD follows trajectories diverging in the direction of a max-margin solution. As margin maximization produces a simple decision boundary, this property suggests better generalization than an arbitrary solution with low training loss. This point of view partially explains why growing norm is associated with better generalization performance.

2.2 NLP and Formal Language Theory

Norm growth has another interpretation for NLP models. Past work characterizes the capacity of infinite-norm networks in terms of formal languages and automata theory. Merrill (2019) and Merrill et al. (2020) propose saturation, a framework for theoretical analysis of the capacity of NLP architectures. A network is analyzed by assuming it saturates its nonlinearities, which means replacing functions like σ and tanh with step functions. This is equivalent to the following definition:

Definition 1 (Saturation; Merrill et al., 2020) Let f(x; θ) be a neural network with inputs x and weights θ. The saturated network sf(x; θ) is²

    sf(x; θ) = lim_{c→∞} f(x; cθ),

where the limit exists, and undefined elsewhere.

² The limit over f is taken pointwise. The range of sf is R.

Saturation reduces continuous neural networks to discrete computational models resembling automata or circuits, making some kinds of formal linguistic analysis easier.
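For intuition, Definition 1 can be approximated numerically by scaling a model's weights by a large finite constant. The sketch below (PyTorch; the helper name and the choice of c are ours, not part of the paper's released code) evaluates f(x; cθ) as a stand-in for sf(x; θ), assuming a generic `model` callable on `inputs`.

```python
import copy
import torch

def approx_saturated_output(model, inputs, c=1000.0):
    """Approximate sf(x; theta) = lim_{c -> inf} f(x; c * theta)
    by evaluating f(x; c * theta) for a large but finite c."""
    scaled = copy.deepcopy(model)          # leave the original model untouched
    with torch.no_grad():
        for p in scaled.parameters():
            p.mul_(c)                      # scale every parameter by c
        scaled.eval()
        return scaled(inputs)
```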
For many common architectures, the saturated capacity is known to be significantly weaker than the full capacity of the network with rational-valued weights (Merrill, 2019), which, classically, is Turing-complete for even simple RNNs (Siegelmann and Sontag, 1992). For example, one can hand-construct an RNN or LSTM encoding a stack in its recurrent memory (Kirov and Frank, 2012). Stacks are useful for processing compositional structure in linguistic data (Chomsky, 1956), e.g., for semantic parsing. However, a saturated LSTM does not have enough memory to simulate a stack (Merrill, 2019). Rather, saturated LSTMs resemble classical counter machines (Merrill, 2019): automata limited in their ability to model hierarchical structure (Merrill, 2020). Experiments suggest that LSTMs trained on synthetic tasks learn to implement counter memory (Weiss et al., 2018; Suzgun et al., 2019a), and that they fail on tasks requiring stacks and other deeper models of structure (Suzgun et al., 2019b). Similarly, Shibata et al. (2020) found that LSTM language models trained on natural language data acquire saturated representations approximating counters.

Recent work extends saturation analysis to transformers (Merrill, 2019; Merrill et al., 2020). Saturated attention heads reduce to generalized hard attention, where the attention scores can tie. In the case of ties, the head output averages the positions with maximal scores.³ While their power is not fully understood, saturated transformers can implement a counting mechanism similarly to LSTMs (Merrill et al., 2020). In practice, Bhattamishra et al. (2020) show transformers can learn tasks requiring counting, and that they struggle when more complicated structural representations are required. Ebrahimi et al. (2020) find that attention patterns of certain heads can emulate bounded stacks, but that this ability falls off sharply for longer sequences.

³ Hahn (2020) identified weaknesses of strictly hard attention, which is weaker than saturated attention.

Thus, the abilities of trained LSTMs and transformers appear to be predicted by the classes of problems solvable by their saturated counterparts. Merrill et al. (2020) conjecture that the saturated capacity might represent a class of tasks implicitly learnable by GD, but it is unclear a priori why this should be the case. This work aims to put this conjecture on more solid theoretical footing: we argue that approximate saturation arises in transformers as a result of norm growth during training.⁴

⁴ This relates to Correia et al. (2019), who modify the transformer to facilitate approximately sparse attention. In contrast, we will show that approximate sparsity (i.e., saturation) arises implicitly in standard transformers.
3 Norm Growth in Transformers

We start with the observation that the parameter ℓ2 norm grows during training for practical transformer language models. We first consider the parameter norm of 104 historical checkpoints from T5-base (Raffel et al., 2019) pretraining, a 220M parameter model, which was trained using the AdaFactor optimizer (Shazeer and Stern, 2018). Further details are in §A.

Fig. 1 shows that the T5 norm follows a √t trend, where t is time in training steps. The top right of Fig. 1 breaks down the growth trend by layer. Generally, the norm grows more quickly in later layers than in earlier ones, although always at a rate proportional to √t.⁵ Next, in the bottom row of Fig. 1, we plot the cosine similarity between each parameter checkpoint θt+1 and its predecessor θt. This rapidly approaches 1, suggesting the "direction" of the parameters (θt/||θt||) converges. The trend in directional convergence looks similar across layers.

⁵ We encourage future works that pretrain new transformer language models to track metrics around norm growth.

We also train smaller transformer language models with 38M parameters on Wikitext-2 (Merity et al., 2016) and the Penn Treebank (PTB; Marcus et al., 1993). We consider two variants of the transformer: pre-norm and post-norm, which vary in the relative order of layer normalization and residual connections (cf. Xiong et al., 2020). Every model exhibits norm growth over training.⁶

⁶ The post-norm transformer achieves 115.79 perplexity on Wikitext-2 and 96.24 on PTB. On the other hand, the pre-norm transformer reaches 66.35 on Wikitext-2 and 26.16 on PTB, slightly outperforming Wang et al. (2019). This is consistent with previous findings (Xiong et al., 2020) showing advantages of pre-norm over post-norm.

Combined, these results provide evidence that the parameter norm of transformers tends to grow over the course of training. In the remainder of this paper, we will discuss the implications of this phenomenon for the linguistic biases of transformers, and then discuss potential causes of the trend rooted in the optimization dynamics.
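Given a sequence of checkpoints, these measurements reduce to a few lines of code. The sketch below (PyTorch/NumPy; the checkpoint-loading interface is hypothetical, and this is our reconstruction rather than the released analysis script) computes the overall ℓ2 norm, the cosine similarity between consecutive checkpoints, and a least-squares fit of the a√t + b trend from Fig. 1.

```python
import numpy as np
import torch
import torch.nn.functional as F

def flatten_params(state_dict):
    # Concatenate all non-scalar parameters into a single vector (cf. Appendix A).
    return torch.cat([p.flatten().float() for p in state_dict.values() if p.dim() > 0])

def norm_trajectory(checkpoints):
    """checkpoints: list of state dicts ordered by training step (hypothetical loader)."""
    vecs = [flatten_params(sd) for sd in checkpoints]
    norms = [v.norm().item() for v in vecs]
    # Cosine similarity between each checkpoint and its predecessor (Fig. 1, bottom).
    direction_sim = [F.cosine_similarity(a, b, dim=0).item() for a, b in zip(vecs, vecs[1:])]
    return norms, direction_sim

def fit_sqrt_trend(steps, norms):
    # Least-squares fit of ||theta_t|| ~ a * sqrt(t) + b (Fig. 1, top left).
    X = np.stack([np.sqrt(steps), np.ones_like(steps, dtype=float)], axis=1)
    (a, b), *_ = np.linalg.lstsq(X, np.asarray(norms, dtype=float), rcond=None)
    return a, b
```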
4 Effect of Norm Growth

§3 empirically documented that the parameter norm grows proportional to √t during T5 pretraining. Now, we move to the main contribution of our paper: the implications of norm growth for understanding transformers' linguistic inductive biases. In particular, Prop. 1 says uniform norm growth across the network guides GD towards saturated networks. Thus, saturation is not just a useful approximation for analyzing networks, but a state induced by training with enough time.

Proposition 1 (Informal) Let θt ∈ Rⁿ be parameters at step t for f(x; θt). If every scalar parameter θt^i diverges at the same rate up to a constant, then f converges pointwise to a saturated network.
[Figure 1: Top: norm growth during T5 pretraining, with a fit of the form a√t + b and coefficient r² = 1.00; the right panel is broken down by layer. Bottom: cosine similarity between subsequent parameter checkpoints, overall (left) and by layer (right).]

The proof is in §B. Prop. 1 assumes not just norm growth, but uniform norm growth, meaning no parameter can asymptotically dominate any other. Notably, uniform growth implies directional convergence. Accepting uniform growth for a given training regimen, we expect transformers to converge to saturated networks with infinite training. Based on §3, the T5 norm appears to grow ∝ √t uniformly across the network, suggesting the uniform growth condition is reasonable. As we will discuss later in §5, we expect the growth trend to depend heavily on the learning rate schedule.

4.1 Saturated Transformers

Having established that norm growth should lead to saturation, we now empirically measure the saturation levels in T5 and other transformer models.

Large transformers are highly saturated. Since ||θt|| empirically grows during training, we expect high cosine similarity between the representations in trained networks and saturated representations. We estimate this as the cosine similarity between f(x; θ) and f(x; cθ) for some large c (in practice, 1,000). We consider the "base" versions of pretrained BERT, RoBERTa, T5, and XLNet (pretrained on masked language modeling), and compute the mean saturation over 100 input sentences from the Brown corpus (Francis and Kučera, 1989). To match standard practice, each sentence is truncated at 512 word pieces. Fig. 2 plots the similarity for each layer of each model. We compare the pretrained transformers against a randomly initialized baseline. For every model type, the similarity is higher for the pretrained network than the randomly initialized network, which, except for T5, is ∼0. For T5 and XLNet, the similarity in the final layer is ≥0.9, whereas, for RoBERTa, the final similarity is 0.65 (although 0.94 in the penultimate layer). For T5 and XLNet, similarity is higher in later layers, which is potentially surprising, as one might expect error to compound with more layers. This may relate to the fact that the norm grows faster for later layers in T5. One question is why the similarity for BERT is lower than these models. As RoBERTa is architecturally similar to BERT besides longer training, we hypothesize that RoBERTa's higher similarity is due to longer pretraining.
[Figure 2: Cosine similarities of the unsaturated and saturated (c = 1,000) transformer representations, by layer. We compare randomly initialized transformers (left) to pretrained ones (right).]

Small transformers reach full saturation. Each of the transformers trained on Wikitext-2 and PTB reached a saturation level of 1.00. It is unclear why these models saturate more fully than the pretrained ones, although it might be because they are smaller.⁷ For our LMs, the feedforward width (512) is less than for T5-base, while the encoder depth and width are the same. Other possible explanations include differences in the initialization scheme, optimizer, and training objective (masked vs. next-word modeling). See §A for full hyperparameters.

⁷ Qualitatively, we observed that ∗-small transformers tended to be more saturated than the ∗-base models.

4.2 Power of Saturated Attention

We have shown that transformer training increases the parameter norm (§3), creating a bias towards saturation (§4.1). Now, we discuss the computational capabilities of saturated transformers, and empirically investigate how they manifest in pretrained transformers. What computation can saturated transformers perform? We review theoretical background about saturated attention, largely developed by Merrill (2019). Let H (sequence length n by model dimension d) be the input representation to a self attention layer. We assume a standard self attention mechanism with key, query, and value matrices K, Q, V.⁸ Saturated attention resembles standard attention where softmax is constrained to a generalization of "argmax" (Merrill, 2019):

    s attn(H; Q, K, V) = arg max(HQK⊤H⊤)HV.

⁸ To simplify presentation, we omit bias terms.

We define this vectorized arg max(A) as

    M(Ai) = {j | aij = max_k aik}
    arg max(Ai)j = 1/|M(Ai)| if j ∈ M(Ai), and 0 otherwise.

Crucially, in the case of ties, arg max(A) returns a uniform distribution over all tied positions. Saturated attention can retrieve the "maximum" value in a sequence according to some similarity matrix. It is also capable of restricted counting (Merrill et al., 2020). Formalizing these observations, we identify two useful computational operations that are reducible to saturated self attention: argmax and mean. Let hi represent the input representation at each time step 1 ≤ i ≤ n.

1. Argmax: Set V = Id. Then the self attention mechanism computes a function recovering the element of H that maximally resembles hi according to a quadratic form M = KQ⊤. If there is a tie for similarity, a uniform average of the maximal entries in H is returned:

    argmax(H; M) = arg max_j hi M hj⊤.

2. Mean: Parameterize the head to attend uniformly everywhere. Then the head computes a function taking a uniform average of values:

    mean(H; V) = (1/n) Σ_{j=1}^{n} V hj.   (1)

These constructions demonstrate some useful computational abilities of saturated transformers. Due to the summation in (1), the mean operation (or near variants of it) can be used to implement counting, which allows recognizing languages like aⁿbⁿcⁿ (Merrill et al., 2020). Empirically, Bhattamishra et al. (2020) find trained networks can learn to recognize counter languages that rely on computing means, failing on more complicated languages like Dyck-2. Our findings partially justify why transformers can learn these languages: they lie within the capacity of saturated transformers.
[Figure 3: Distribution of the number of positions attended to for all heads in the PTB language models. The left plot is pre-norm, and the right is post-norm. Values are averaged over 200 sentences from the development set.]

4.3 Learned Attention Patterns

Recall that the small language models trained in §4.1 reach 1.00 saturation. It follows that we can convert them to saturated transformers (by multiplying θ by a large constant c) without significantly shifting the representations in cosine space. We will evaluate whether the saturated attention heads manifest the argmax and mean constructions from §4.2.

As discussed in §4.2, saturated attention can parameterize both argmax and mean heads. An argmax head should attend to a small number of positions. A mean head, on the other hand, attends uniformly over the full sequence. Are both patterns acquired in practice by our models? We plot the distribution of the number of positions attended to by each head in the saturated PTB models in Fig. 3. The distribution is bimodal, with one mode at 1, and the other around 41, representing the mean sequence length of an 83-length encoder with positional masking to prevent lookahead. The empirical mode around 1 corresponds to heads that are argmax-like. The mode around 41, on the other hand, corresponds to mean-like heads, since it implies uniform attention over the masked sequence. Thus, our analysis suggests that analogs of both types of attention heads theorized in §4.2 are acquired in transformers in practice. In the pre-norm transformer, which performs substantially better, there are also a small number of heads lying between the two modes. We defer the investigation of the function of these heads to future work.
5 Explanation for Norm Growth

We have documented norm growth in T5 and other transformers (§3) and showed how it induces partial saturation in their representations (§4). This section points towards an understanding of why the parameter norm grows over the course of training, grounded in results about norm growth from deep learning theory. We do not analyze specific optimizers directly; instead, we analyze norm growth within simplified models of training dynamics taken from the literature. We then evaluate how these candidate dynamics models fit T5's training.

5.1 Setup

Let δt ∈ Rⁿ denote the optimizer step at time t, i.e., δt = θt+1 − θt. We write ηt for the learning rate at t.⁹ Let ∇θt L denote the gradient of the loss with respect to θt. By GD, we refer to the update δt = −ηt ∇θt L.¹⁰ In contrast, we will use the term gradient flow to refer to its continuous relaxation, specified by an analogous differential equation:

    dθt/dt = −ηt ∇θt L.

⁹ Without loss of generality, the arguments presented here can be seen as applying to an individual parameter in the network, or the vector of all concatenated network parameters.
¹⁰ Note that, in practice, T5 was trained with AdaFactor, whereas the setup in this section assumes simpler optimizers.
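To make the two notions concrete, the discrete GD update and a simple Euler discretization of gradient flow can be written as follows (a toy NumPy-style sketch; `grad_fn` and `lr_fn` are placeholder callables, not part of any training framework used in the paper).

```python
def gd_step(theta, grad_fn, lr_fn, t):
    # Gradient descent: delta_t = -eta_t * grad L(theta_t).
    return theta - lr_fn(t) * grad_fn(theta)

def gradient_flow_step(theta, grad_fn, lr_fn, t, dt=0.01):
    # Euler step of the continuous relaxation d(theta)/dt = -eta_t * grad L(theta_t).
    return theta - dt * lr_fn(t) * grad_fn(theta)
```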
[Figure 4: Approximate cosine similarity of f(x; cθ) to sf(x; θ) for randomly initialized transformers f, as a function of the scaling factor c, for pre-norm and post-norm variants with and without biases. sf(x; θ) is approximated as in Fig. 2.]

5.2 Homogeneity

We will rely on properties of homogeneous networks, a class of architectures well-studied in deep learning theory (Ji and Telgarsky, 2020).

Definition 2 (Homogeneity) A function f(x; θ) is k-homogeneous in θ iff, for all c ≥ 0, f(x; cθ) = cᵏ f(x; θ). We further say that f is homogeneous iff there exists some k such that f is k-homogeneous.

Many common components of modern neural networks are homogeneous (Li and Arora, 2019). Furthermore, as various computations within a neural network preserve homogeneity (§C), some full networks are also homogeneous. An example of a fully homogeneous neural network is a feedforward ReLU network without bias terms.

Why is homogeneity relevant for transformers? Transformers are not homogeneous, but they are almost homogeneous. We formalize this as:

Definition 3 (Approx. homogeneity) A scalar¹¹ function f(x; θ) is approximately k-homogeneous in θ iff there exist d, ρ s.t., for c ≥ 1 and ||θ|| ≥ ρ,

    |f(x; cθ) − cᵏ f(x; θ)| ≤ exp(−d||θ||).

¹¹ A vector function is approximately k-homogeneous if this holds for all its elements.

In other words, as ||θ|| grows, f approximates a homogeneous function with exponentially vanishing error. In §D, we prove transformer encoders without biases are approximately 1-homogeneous. In Fig. 4, we compare the cosine similarity of transformers with and without biases to their saturated variants, as a function of a constant c scaling their weights. An approximately homogeneous function should rapidly approach 1.0 as c increases. We find similar curves for transformers with and without biases, suggesting biasless transformers are similarly homogeneous to transformers with biases.¹²

¹² Lyu and Li (2019) find similar results for feedforward ReLU networks. It is an interesting puzzle why networks with biases appear similarly homogeneous to those without biases.

Since multiplying two homogeneous functions adds their homogeneity, a transformer encoder followed by a linear classifier is approximately 2-homogeneous. A key property of homogeneous functions is Euler's Homogeneity Theorem: the derivative of a k-homogeneous function is (k − 1)-homogeneous. Thus, we will assume the gradients of the linear classifier output are roughly 1-homogeneous, which under simple GD implies:

Assumption 1 Let θt include all encoder and classifier parameters. Let ∝∼ mean "approximately proportional to". For large enough t during transformer training, ||δt|| ∝∼ ηt ||θt||.
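Approximate homogeneity can be probed empirically by comparing f(x; cθ) against cᵏ f(x; θ) for moderate c, in the spirit of Fig. 4. The sketch below (PyTorch, for a generic encoder returning a tensor; the relative-error metric is our choice) is a diagnostic, not the measurement pipeline used in the paper.

```python
import copy
import torch

def homogeneity_gap(model, inputs, k=1.0, c=2.0):
    """Relative error ||f(x; c*theta) - c^k f(x; theta)|| / ||c^k f(x; theta)||.
    For an approximately k-homogeneous network with large ||theta||, this should be small."""
    scaled = copy.deepcopy(model)
    with torch.no_grad():
        for p in scaled.parameters():
            p.mul_(c)
        out = model(inputs)
        out_scaled = scaled(inputs)
        target = (c ** k) * out
        return ((out_scaled - target).norm() / target.norm()).item()
```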
5.3 Aligned Dynamics

We now consider the first candidate dynamics model: aligned dynamics (Ji and Telgarsky, 2020). Analyzing homogeneous networks with an exponential binary classification loss and gradient flow, Ji and Telgarsky (2020) show that the parameters converge in direction, and that the gradients become aligned, meaning that θt⊤ · δt → ||θt|| ||δt||. While it is unclear whether transformers will follow aligned dynamics, we entertain this as one hypothesis. Under Ass. 1, alignment implies

    ||θt|| ≈ Σ_{i=0}^{t} ||δi|| ∝∼ ∫ ηt ||θt|| dt.

With the ηt = 1/√t schedule used by T5 (Raffel et al., 2019), ||θt|| ∝∼ exp(√t) (see §E.1). This is asymptotically faster than the observed √t growth, suggesting an alternate dynamics might be at play.

5.4 Misaligned Dynamics

Our second candidate model of training is misaligned dynamics, which follows largely from Li and Arora (2019). This can be derived by assuming the gradients are misaligned (i.e., θt⊤ · δt = 0), which holds for scale-invariant networks (Li and Arora, 2019) and in expectation for random normal gradients. Misalignment implies (derived in §E.2):

    ||θt||² ∝∼ Σ_{i=0}^{t} ||δi||².   (2)

We show in §E.2 that, with the T5 learning rate (ηt = 1/√t), (2) reduces to ||θt|| ∝∼ √t, as observed empirically for T5. We now further test whether misaligned dynamics are a good fit for T5.
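The misaligned √t prediction is easy to check in a toy simulation: take steps that are (in expectation) orthogonal to θt with magnitude ηt||θt|| under the ηt = 1/√t schedule, and track the norm. The sketch below (NumPy) illustrates the derivation in §E.2; it is not a model of AdaFactor or of the actual T5 run.

```python
import numpy as np

def simulate_misaligned(steps=100_000, dim=256, seed=0):
    """Random, roughly orthogonal updates with ||delta_t|| = eta_t * ||theta_t||,
    eta_t = 1 / sqrt(t). The returned norms grow approximately like sqrt(t)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=dim)
    norms = np.empty(steps)
    for t in range(1, steps + 1):
        eta = 1.0 / np.sqrt(t)
        step = rng.normal(size=dim)                        # random direction, ~orthogonal to theta
        step *= eta * np.linalg.norm(theta) / np.linalg.norm(step)
        theta = theta + step
        norms[t - 1] = np.linalg.norm(theta)
    return norms
```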
[Figure 5: Alignment (cosine similarity of δt and θt) and step size (||δt||) over training.]

5.5 Evaluation

We measure the gradient alignment over the course of training T5. Our alignment metric is the cosine similarity of δt to θt. As shown on the left of Fig. 5, the alignment initially rapidly increases to ∼0.15, and then decays to near 0. This supports the hypothesis that the T5 dynamics are misaligned, since the similarity is never high, and may be approaching 0.

On the right of Fig. 5, we plot step size over training in order to evaluate the validity of Ass. 1. At the beginning of training, a chaotic step size seems reasonable, as it is hard to predict the dynamics before approximate homogeneity takes hold. For large t, Ass. 1 combined with the T5 learning rate schedule predicts step size should be roughly constant.¹³ This is not exactly what we find: for large t, ||δt|| grows gradually with t. However, the absolute change in step size is small: < 20 across 220M parameters. Thus, we believe Ass. 1 is not unreasonable, though it would be interesting to understand what properties of the optimizer can explain the slight growth in step size.¹⁴

¹³ Since ||δt|| ∝∼ ηt ||θt|| = √t/√t = 1.
¹⁴ We believe the sharp drop in ||δt|| at the final step is an artifact of the original recording of these checkpoints.

5.6 Weight Decay

One feature of practical training schemes not considered in this section is weight decay. When applied to standard GD, weight decay can be written δt = −ηt ∇θt L − λθt. Intuitively, it might hinder norm growth if λ is large.¹⁵ In §F, we report preliminary experiments testing the effect of weight decay on norm growth. Indeed, if λ is set too large, weight decay can prevent norm growth, but within the standard range of values for λ, we find norm growth even in the face of weight decay. However, it is possible these results may change if the optimizer or other hyperparameters are varied.

¹⁵ Common wisdom says that weight decay improves generalization by keeping ||θt|| small; however, recent work challenges the assumption that a bias towards small norm is beneficial (Goldblum et al., 2020), suggesting the benefit of weight decay may arise from more subtle effects on the GD trajectory.
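The interaction between weight decay and norm growth can be explored with a toy GD loop implementing δt = −ηt∇θt L − λθt and tracking ||θt||, along the lines of the experiments in §F. This is an illustrative NumPy sketch with placeholder loss and schedule callables; the paper's actual runs used the optimizers described in §A.

```python
import numpy as np

def gd_with_weight_decay(theta, grad_fn, lr_fn, weight_decay, steps):
    """Plain GD with weight decay: theta_{t+1} = theta_t - eta_t * grad - lambda * theta_t.
    Returns the final parameters and the trajectory of ||theta_t||, so the effect of
    lambda on norm growth can be inspected."""
    norms = np.empty(steps)
    for t in range(1, steps + 1):
        theta = theta - lr_fn(t) * grad_fn(theta) - weight_decay * theta
        norms[t - 1] = np.linalg.norm(theta)
    return theta, norms
```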
6 Conclusion

We empirically found that ||θt|| grows ∝ √t during T5 pretraining, a fact that may be caused by the approximate homogeneity of the transformer architecture. We proved that norm growth induces saturation, and then showed empirically that T5 and other large transformers become approximately saturated through their pretraining. Examining highly saturated transformer language models, we found the attention heads largely split between two distinct behaviors that can be roughly interpreted as argmax and mean operations. While we lack a precise formal characterization of "semi-saturated" transformers, we conjecture their capacity resembles that of the saturated models. Thus, we believe further analyzing the capabilities of saturated attention may clarify the linguistic biases that emerge in transformers through training, and the mechanisms they use to represent linguistic structure.

Acknowledgments

We thank Colin Raffel for sharing access to the T5 training checkpoints. Additional thanks to Qiang
Ning, Kyle Richardson, Mitchell Wortsman, Mar- Michael Hahn. 2020. Theoretical limitations of self- tin Lohmann, and other researchers at the Allen attention in neural sequence models. Transactions of the Association for Computational Linguistics, Institute for AI for their comments on the project. 8:156–171. Ziwei Ji and Matus Telgarsky. 2020. Directional con- References vergence and alignment in deep learning. In Ad- vances in Neural Information Processing Systems, Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. volume 33, pages 17176–17186. Curran Associates, 2019. Implicit regularization in deep matrix factor- Inc. ization. Christo Kirov and Robert Frank. 2012. Processing of Satwik Bhattamishra, Kabir Ahuja, and Navin Goyal. nested and cross-serial dependencies: an automaton 2020. On the Ability and Limitations of Transform- perspective on SRN behaviour. Connection Science, ers to Recognize Formal Languages. In Proceed- 24(1):1–24. ings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages Zhiyuan Li and Sanjeev Arora. 2019. An exponential 7096–7116, Online. Association for Computational learning rate schedule for deep learning. In Proc. of Linguistics. ICLR. N. Chomsky. 1956. Three models for the description of Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Man- language. IRE Transactions on Information Theory, dar Joshi, Danqi Chen, Omer Levy, Mike Lewis, 2(3):113–124. Luke Zettlemoyer, and Veselin Stoyanov. 2019. RoBERTa: A robustly optimized BERT pretraining Gonçalo M. Correia, Vlad Niculae, and André F. T. approach. Martins. 2019. Adaptively sparse transformers. In Proceedings of the 2019 Conference on Empirical Ilya Loshchilov and Frank Hutter. 2017. Fixing Methods in Natural Language Processing and the weight decay regularization in adam. CoRR, 9th International Joint Conference on Natural Lan- abs/1711.05101. guage Processing (EMNLP-IJCNLP), pages 2174– Charles Lovering, Rohan Jha, Tal Linzen, and Ellie 2184, Hong Kong, China. Association for Computa- Pavlick. 2021. Predicting inductive biases of pre- tional Linguistics. trained models. In International Conference on Learning Representations. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of Kaifeng Lyu and Jian Li. 2019. Gradient descent maxi- deep bidirectional transformers for language under- mizes the margin of homogeneous neural networks. standing. In Proceedings of the 2019 Conference of the North American Chapter of the Association Mitchell P. Marcus, Beatrice Santorini, and Mary Ann for Computational Linguistics: Human Language Marcinkiewicz. 1993. Building a large annotated Technologies, Volume 1 (Long and Short Papers), corpus of English: The Penn Treebank. Computa- pages 4171–4186, Minneapolis, Minnesota. Associ- tional Linguistics, 19(2):313–330. ation for Computational Linguistics. Stephen Merity, Caiming Xiong, James Bradbury, and Javid Ebrahimi, Dhruv Gelda, and Wei Zhang. 2020. Richard Socher. 2016. Pointer sentinel mixture mod- How can self-attention networks recognize Dyck-n els. languages? In Findings of the Association for Com- putational Linguistics: EMNLP 2020, pages 4301– William Merrill. 2019. Sequential neural networks as 4306, Online. Association for Computational Lin- automata. Proceedings of the Workshop on Deep guistics. Learning and Formal Languages: Building Bridges. William Merrill. 2020. On the linguistic capacity of Winthrop Nelson Francis and Henry Kučera. 1989. real-time counter automata. 
Manual of information to accompany a standard cor- pus of present-day edited American English, for use William. Merrill, Gail Garfinkel Weiss, Yoav Gold- with digital computers. Brown University, Depart- berg, Roy Schwartz, Noah A. Smith, and Eran Yahav. ment of Linguistics. 2020. A formal hierarchy of RNN architectures. In Proc. of ACL. Micah Goldblum, Jonas Geiping, Avi Schwarzschild, Michael Moeller, and Tom Goldstein. 2020. Truth Mor Shpigel Nacson, Suriya Gunasekar, Jason D. Lee, or backpropaganda? an empirical investigation of Nathan Srebro, and Daniel Soudry. 2019. Lexi- deep learning theory. In International Conference cographic and depth-sensitive margins in homoge- on Learning Representations. neous and non-homogeneous deep models. Suriya Gunasekar, Blake Woodworth, Srinadh Bho- Tomaso Poggio, Andrzej Banburski, and Qianli Liao. janapalli, Behnam Neyshabur, and Nathan Srebro. 2019. Theoretical issues in deep networks: Approx- 2017. Implicit regularization in matrix factorization. imation, optimization and generalization.
Tomaso Poggio, Qianli Liao, and Andrzej Banburski. Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Car- 2020. Complexity control by gradient descent in bonell, Ruslan Salakhutdinov, and Quoc V. Le. 2019. deep networks. Nature communications, 11(1):1–5. Xlnet: Generalized autoregressive pretraining for language understanding. Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Chulhee Yun, Srinadh Bhojanapalli, Ankit Singh Wei Li, and Peter J. Liu. 2019. Exploring the limits Rawat, Sashank Reddi, and Sanjiv Kumar. 2020. of transfer learning with a unified text-to-text trans- Are transformers universal approximators of former. sequence-to-sequence functions? In International Conference on Learning Representations. Noam Razin and Nadav Cohen. 2020. Implicit regular- ization in deep learning may not be explainable by A Experimental Details norms. Anna Rogers, Olga Kovaleva, and Anna Rumshisky. We provide experimental details for the small lan- 2020. A primer in BERTology: What we know guage models that we trained. The models were about how BERT works. Transactions of the Associ- trained for 5 epochs, and the best performing model ation for Computational Linguistics, 8:842–866. was selected based on development loss. Reported Noam Shazeer and Mitchell Stern. 2018. Adafactor: metrics were then measured on the held-out test Adaptive learning rates with sublinear memory cost. set. We used our own implementation of the stan- CoRR, abs/1804.04235. dard pre- and post-norm transformer architectures. Chihiro Shibata, Kei Uchiumi, and Daichi Mochihashi. We did not do any hyperparameter search, instead 2020. How LSTM encodes syntax: Exploring con- choosing the following hyperparameters: text vectors and semi-quantization on natural text. • Batch size of 16 • Model dimension of 768 Hava T. Siegelmann and Eduardo D. Sontag. 1992. On the computational power of neural nets. In Proc. of • Feedforward hidden dimension of 512 COLT, pages 440–449. • 12 heads per layer • 12 layers Mirac Suzgun, Yonatan Belinkov, Stuart Shieber, and • AdamW optimizer with default PyTorch hy- Sebastian Gehrmann. 2019a. LSTM networks can perform dynamic counting. In Proceedings of the perparameters Workshop on Deep Learning and Formal Languages: • 0 probability of dropout Building Bridges, pages 44–54. • Default PyTorch initialization Mirac Suzgun, Sebastian Gehrmann, Yonatan Belinkov, Tokenization For Wikitext-2, 3 tokens in the and Stuart M. Shieber. 2019b. Memory-augmented recurrent neural networks can learn generalized whole test dataset were unattested in the training set Dyck languages. (due to capitalization). To make our model compat- ible with unseen tokens, we replaced these tokens Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob with , the same class that appeared for low Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all frequency words at training time, when evaluating you need. the final text perplexity. Due to the small number of tokens that were affected, the impact of this change Chenguang Wang, Mu Li, and Alexander J. Smola. 2019. Language models with transformers. CoRR, should be negligible. abs/1904.09408. Compute We estimate the experiments in this Gail Weiss, Yoav Goldberg, and Eran Yahav. 2018. On paper took several hundred GPU hours on NVIDIA the practical computational power of finite precision A100 GPUs over the course of almost two years of RNNs for language recognition. on-and-off research time. 
Thomas Wolf, Lysandre Debut, Victor Sanh, Julien T5 We used the historical checkpoints of bsl-0, Chaumond, Clement Delangue, Anthony Moi, Pier- ric Cistac, Tim Rault, R’emi Louf, Morgan Funtow- one of five T5-base models that was trained for the icz, and Jamie Brew. 2019. Huggingface’s trans- original paper (Raffel et al., 2019). formers: State-of-the-art natural language process- ing. ArXiv, abs/1910.03771. Measuring Norms As a systematic choice, all measurements of parameter norm include only en- Ruibin Xiong, Yunchang Yang, Di He, Kai Zheng, Shuxin Zheng, Chen Xing, Huishuai Zhang, Yanyan coder parameters that are not scalars. We advise Lan, Liwei Wang, and Tie-Yan Liu. 2020. On layer other researchers to follow the practice of exclud- normalization in the transformer architecture. ing embedding parameters, as embedding param-
Component k Input k Output C.1 Consistency Linear k k+1 We first prove that approximate homogeneity is Bias 1 1 consistent: in other words, if a function is both Affine 0 1 approximately k1 and k2 -homogeneous, then k1 = LayerNorm k 0 k2 . This is an important property for establishing LayerNorm + Affine k 1 approximate homogeneity as a meaningful notion. ReLU k k Sum (k, k) k Lemma 1 Let k1 , k2 ∈ N. Assume that f is both Product (k1 , k2 ) k1 + k2 approximately k1 and k2 -homogeneous. Then k1 = k2 . Table 1: Effects of network components on homogene- ity shown by Li and Arora (2019). We write the “k Proof. If f is both approximately k1 and k2 - Output” homogeneity as a function of the “k Input” ho- homogeneous, then we have vanishing terms 1 mogeneity. These facts can be applied recursively to and 2 such that, for all c, compute the homogeneity of a network. We will show that the same facts hold for approximate homogeneity. f (cθ) = ck1 f (θ) + 1 f (cθ) = ck2 f (θ) + 2 . eters that are infrequently updated may obscure general trends in the network parameters. Subtracting both sides yields B Norm Growth and Saturation 0 = (ck1 − ck2 )f (θ) + 1 − 2 |1 − 2 | Proposition 2 (Formal version of Prop. 1) Let θt ∈ ∴ ck1 − ck2 = . Rn be the parameter vector at train step t for a f (θ) network f (x; θt ). Assume that, as t → ∞, there The right-hand side vanishes exponentially in kθk exists a scalar sequence c(t) → ∞ and fixed vector for all c, whereas the left-hand side grows with c θ0 ∈ (R \ {0})n such that, for all t, θt → θ0 · c(t). unless k1 = k2 . Thus, to satisfy this equation for Then f converges pointwise to a saturated network all c, it must be the case that k1 = k2 . in function space. C.2 Closure Properties Proof. We now prove that effects of various functions on lim f (x; θt ) = lim f (x; θ0 · c(t)). homogeneity explored by Li and Arora (2019) also t→∞ t→∞ translate to approximate homogeneity. Now, since c(t) → ∞ and θ0 · c(t) contains no Lemma 2 ReLU preserves approximate k- indeterminate elements, we can simplify this to homogeneity, i.e., let f : Rn → R be approxi- lim f (x; cθ0 ) = sf (x; θ0 ). mately k-homogeneous. Then ReLU ◦f is approxi- c→∞ mately k-homogeneous. Proof. C Approximate Homogeneity ReLU f (cθ) = ReLU ck f (θ) + In this section, we will further develop the notion ≤ ReLU ck f (θ) + ||. of approximate homogeneity. We will prove that is consistent. In other words, every function can Therefore, have at most one degree k of approximate homo- geneity. Next, we will show that the useful closure ReLU f (cθ) − ReLU ck f (θ) ≤ ||. properties applying to full homogeneity also apply to partial homogeneity. Set 0 = ||, showing ReLU f (θ) is approxi- If f (θ) is approximately k-homogeneous mately k-homogeneous. (cf. Def. 3), then f (cθ) = ck f (θ)+ for some error vector where, for each i, |i | ≤ exp(−dkθk)), for Lemma 3 Let f, g be vector-valued functions of θ. all c and large enough kθk. We use this notation If f and g are approximately k-homogeneous, then throughout this section. f + g is approximately k-homogeneous.
Proof. Proof. Since addition preserves approximate k- homogeneity, mean (and difference to mean), pre- f (cθ) + g(cθ) = ck f (θ) + f + ck g(θ) + g serve approximate k-homogeneity. Letting C = ck , = ck f (θ) + ck g(θ) + 0 , we can write where 0 = f + g . Thus, f (cθ) − µ(f (cθ)) = C f (θ) − µ(f (θ)) + . f (cθ) + g(cθ) − ck f (θ) + g(θ) ≤ 0 . We now apply this to the definition of layer norm to get f (cθ)i − µ(f (cθ)) Lemma 4 Let f, g be vector-valued functions LN(f (cθ))i = of θ. If f and g are approximately kf and kg - kf (cθ) − µ(f (cθ))k homogeneous, then f ·g is approximately (kf +kg )- C f (θ)i − µ(f (θ)) + i = . homogeneous. Ckf (θ) − µ(f (θ))k + Proof. We show that the difference between this and the unscaled layer norm goes to zero. To simplify f (cθ) · g(cθ) = ck f (θ) + f · ck g(θ) + g notation, we now write f = f (θ), µ = µ(f (θ)), = ckf +kg f (θ)g(θ) + ckf f (θ)g and = in the left-hand side below: + ckg g(θ)f + f g . |LN(f (cθ))i − LN(f (θ)i )| We now rewrite the term ckf f (θ)g as C fi − µ + i fi − µ = − θg(x; θ̂) Ckf − µk + kf − µk ≤ exp −d0 kθk . exp(−dkθk) i kf − µk − (fi − µ) = Ckf − µk2 + kf − µk Now, set 0 = min(exp(−dkθk), f g ). i − v = f (cθ)g(cθ) − ckf +kg f (θ)g(θ) ≤ 0 . Ckf − µk + i − v ≤ . The analogous results for linear transformation, for some v ∈ Rn which does not grow with kθk. bias, and affine transformation directly follow from Thus, setting 0 to this final quantity satisfies the the results for sum and product in Lem. 3 and definition of approximate 0-homogeneity, i.e. ap- Lem. 4. proximate scale invariance. Finally, we show that layer norm converts a homogeneous function to approximately scale- C.3 Saturating Activation Functions invariant function. In order to be numerically sta- ble, practical implementations of layer norm utilize We show that the exponentially saturation activa- a small tolerance term so that the denominator is tion functions σ, softmax, and tanh are approxi- never zero. We omit this practical detail from our mately scale-invariant in x, i.e. scaling x has an analysis, instead defining the layer norm LN(x) for exponentially diminishing effect on the output. We x ∈ Rn according to start by analyzing the simpler sigmoid, and then show that the same result holds for softmax. For n completeness, we then present a proof for tanh. 1X µ(x) = xi We use Θ (not θ) in the standard sense of asymp- n i=1 totic notation. xi − µ(x) LN(x)i = . Lemma 6 The scaling error for σ vanishes expo- kx − µ(x)k nentially in the preactivation magnitude, i.e. for all Lemma 5 Let f be approximately k-homogeneous c ≥ 1, for some k. Then, LN(f ) is approximately 0- homogeneous. |σ(cx) − σ(x)| ≤ Θ(exp(−|x|)).
Proof. Assume without loss of generality that x 6= Finally, for completeness, we show that tanh ex- 0, as if this is the case, the error is 0. When x > 0, hibits the same property. The proof is very similar we have to sigmoid, following closely from the definition |σ(cx) − σ(x)| = σ(cx) − σ(x) exp(2x) − 1 tanh(x) = . ≤ 1 − σ(|x|) exp(2x) + 1 1 = Lemma 8 The scaling error for tanh vanishes exp(|x|) + 1 exponentially in the preactivation magnitude, i.e. = Θ(exp(−|x|)). for all c ≥ 1, When x < 0, we have |tanh(cx) − tanh(x)| ≤ exp(−Θ(|x|)). |σ(cx) − σ(x)| = σ(x) − σ(cx) ≤ 1 − σ(|x|) + 0 Proof. = Θ(exp(−|x|)). |tanh(cx) − tanh(x)| ≤ |1 − tanh(x)| = 1 − tanh(|x|) Lemma 7 The elementwise scaling error for exp(2|x|) − 1 =1− softmax vanishes exponentially in the preactiva- exp(2|x|) + 1 tion norm, i.e. for all c ≥ 1, x ∈ Rn s.t. 1 ≤ i ≤ n, exp(2|x|) + 1 − exp(2|x|) + 1 = exp(2|x|) + 1 |softmax(cx)i − softmax(x)i | ≤ exp(−Θ(kxk)). 2 = Proof. The proof closely follows that of Lem. 6, exp(2|x|) + 1 but is more involved. We consider two cases: xi = = exp(−Θ(|x|)). max(x), and xi 6= max(x). Case 1 xi = max(x). Thus, applying these functions to a homoge- |softmax(cx)i − softmax(x)i | neous input produces an output that is approxi- mately scale-invariant in the parameters θ. Thus, = softmax(cx)i − softmax(x)i these functions act similarly to layer norm, which ≤ 1 − softmax(x)i maps homogeneous input to scale-invariant output. exp(xi ) But what happens if the input is approximately ho- =1− P j exp(xj ) mogeneous, rather than strictly homogeneous? In exp(max(x)) this case, we show that the output is approximately ≤1− scale-invariant assuming kθk is sufficiently large. exp(max(x)) + (n − 1) exp(min(x)) 1 Proposition 3 Let f (x; θ) be approximately k- =1− homogeneous in θ. Then the following functions 1 + (n − 1) exp(min(x) − max(x)) 1 are approximately scale-invariant in θ: =1− , 1 + exp(min(x) − max(x) + d) gσ = σ ◦ f for some d ∈ R. As this has the form of σ, gsoftmax = softmax ◦f |softmax(cx)i − softmax(x)i | gtanh = tanh ◦f. = 1 − σ(Θ(kxk)) = exp(−Θ(kxk)). Proof. If f (x; θ) is approximately k-homogeneous, Case 2 xi 6= max(x). then f (x; cθ) = ck f (x; θ) + where kk ≤ |softmax(cx)i − softmax(x)i | exp(−O(kθk)). Crucially, since vanishes for large norm, there is some ρ where, for all θ such = softmax(x)i − softmax(cx)i that ρ < kθk: ≤ 1 − max(softmax(x)) − 0 sgn ck f (x; θ) + = sgn ck f (x; θ) = 1 − softmax(max(x)), arg max ck f (x; θ) + = arg max ck f (x; θ) . which is identical to case 1.
Therefore, for θ such that kθk > ρ, the bounds a self-attention head attn as used in Lem. 6, Lem. 7, and Lem. 8 hold for approximately homogeneous f . Thus, we can K = W kX conclude that the output is approximately scale- Q = W qX invariant. V = W vX A = softmax(QK > / dk ) p D Transformers H = AV, We introduce the notation ∼k-homogeneous to mean approximately k-homogeneous. In this sec- where H is the output tensor. tion, we show that the transformer encoder is ∼1- The multi-head self-attention sublayer computes homogeneous. A transformer Vaswani et al. (2017) several attention heads in parallel and aggregates is made up of three main components: an em- them into a single sequence of vectors. bedding layer, self attention sublayers, and feed- Definition 5 (Multi-head self-attention sublayer) forward sublayers. Since the embedding layer is Let X ∈ RT n be the input. We now define the just a matrix multiplication, it is a 1-homogeneous k-multi-head self-attention sublayer MSAk . First, function of the input. Assuming the self attention we compute k self-attention heads in parallel to and feed-forward sublayers have no bias terms, we produce H1 , · · · , Hk . We then concatenate these show that they approximate functions preserving along the feature axis to form H, and compute the approximate 1-homogeneity. As the full network is sublayer output Y as an initial embedding layer followed by these sub- layers, the final output is ∼1-homogeneous. In the MSAk (X) = LN(W H) + X. main paper, we discuss the connection between homogeneity and norm growth. Finally, the linear sublayer is the other compo- We base our analysis on the HuggingFace imple- nent of the transformer. mentation16 of BERT (Wolf et al., 2019). To aid Definition 6 (Feedforward sublayer) Let X ∈ analysis, we make some simplifying assumptions, RT n be the input. We compute the feedforward which are discussed along with the definitions. We sublayer FF according to later show empirically that homogeneity for the unsimplified versions is similar. FF(X) = LN(W f ReLU(W i X)) + X. D.1 Transformer Definition D.2 Results The transformer encoder is a cascade of alternat- Proposition 4 If X is ∼1-homogeneous in ing multi-head self-attention sublayers and feed- parameters θ, then attn(X; W k , W q , W v ) forward sublayers. Each multi-head self-attention is ∼1-homogeneous in the concatenation of sublayer can be further broken down as an aggre- θ, W k , W q , W v . gation of self-attention heads. Let LN(·) denote a layer norm followed by a learned affine trans- Proof. Consider a self-attention layer receiving a formation. Here we will consider the pre-norm ∼1-homogeneous input matrix X ∈ RT n where transformer variant (Xiong et al., 2020), meaning T is the sequence length. Using the homogene- that LN comes before the residual connection wher- ity rule for multiplication, K, Q, V are each ∼2- ever it appears.17 We will also assume that there homogeneous, as homogeneity is additive over mul- are no biases, making all affine transformations tiplication. By the same argument, QK > is ∼4- into strict linear transformations. homogeneous. In Prop. 3, we show that if the input to softmax is approximately homogeneous, then Definition 4 (Self-attention head) Given parame- the output is approximately scale-invariant. Thus, ters W k , W q , W v and input X ∈ RT n , we define A is approximately 0-homogeneous. Then AV is 16 ∼1-homogeneous. https://huggingface.co/transformers/ _modules/transformers/modeling_bert. 
We show that the multi-head component that ag- html#BertModel 17 The post-norm transformer applies these operations in the gregates multiple heads into a shared representation opposite order. also preserves approximate 1-homogeneity.
Proposition 5 If X is ∼1-homogeneous in param- E.2 Misaligned Case eters θ, then MSA is ∼1-homogeneous in the full First, we derive the sum approximation for kθt k. parameters. We start with the fact that θt+1 = θt + δt and misalignment, i.e., θt> · δt = 0. Proof. Since W h is ∼2-homogeneous, LN(W H) is ∼1-homogeneous. The input X is also ∼1- kθt+1 k2 = (θt + δt ) · (θt + δt ) homogeneous by assumption, meaning that the sum = kθt k2 + θt> δt + kδt k2 is also ∼1-homogeneous. = kθt k2 + kδt k2 t Finally, we turn to analyzing the feedforward 2 X sublayer of the transformer. = kθ0 k + kδi k2 . i=0 Proposition 6 If X is ∼1-homogeneous, then FF(X; W f , W i ) is ∼1-homogeneous in the full Taking the square Proot of both sides, kθt k is roughly parameters. proportional to i=0 kδi k2 . t Next, we show how to solve the integral, simi- larly to §E.1. Proof. Multiplying by each W increases approx- imate homogeneity by 1, and ReLU preserves Z approximate homogeneity. So the input to LN 2 ∝ kθt k ∼ kδt k2 dt is ∼3-homogeneous. Thus, its output is ∼1- Z homogeneous, and adding X preserves approxi- kθt k2 ∝ 2 ∼ ηt kθt k dt 2 mate 1-homogeneity. d kθt k2 ∝ ηt2 kθt k2 ∼ dt Together, these results suggest that the pre-norm dkθt k2 ∝ ηt2 dt transformer output is ∼1-homogeneous, assuming ∼ its input is ∼1-homogeneous. This precondition kθt k2 Z for the input holds in the “base case” of standard 2 ∝ log kθt k ∼ ηt2 dt. embeddings. By induction, we can imagine the output of a biasless pre-norm transformer encoder √ Now, we plug in the ηt = 1/ t learning rate: of any depth to be ∼1-homogeneous. Interestingly, the homogeneity arguments do not 2 ∝ Z √ work out if we instead consider the post-norm trans- log kθt k ∼ (1/ t)2 dt former architecture (Xiong et al., 2020). Z dt ∝ ∼ t E Sum Derivation ∝ log t. ∼ E.1 Aligned Case √ So, in conclusion: kθt k ∝ ∼ t. Assume that kθt k ≈ 0. Then, Z F Weight Decay kθt k ∝ ∼ ηt kθt kdt Weight decay regularizes the loss by the squared `2 d norm, modulated by a decay factor λ. For GD, this kθt k ∝ ∼ ηt kθt k can be written dt dkθt k ∝ ηt dt δt = −ηt ∇θt L − λθt . (3) kθt k ∼ Z log kθt k ∝ Intuitively, the new term −λθt will influence each ∼ ηt dt Z step to point towards 0. Thus, large values of λ kθt k ∝ exp η dt . might intuitively be expected to hinder or prevent ∼ t norm growth. While we leave developing a more √ √ complete theoretical story to future work, here we Plugging in ηt = 1/ t, we get kθt k ∝ ∼ exp t . empirically investigate the interplay of a constant