ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks

Jungmin Kwon 1, Jeongseop Kim 1, Hyunseo Park 1, In Kwon Choi 1

1 Samsung Research, Seoul, Republic of Korea. Correspondence to: Jeongseop Kim.

arXiv:2102.11600v3 [cs.LG] 29 Jun 2021. Proceedings of the 38th International Conference on Machine Learning, PMLR 139, 2021. Copyright 2021 by the author(s).

Abstract

Recently, learning algorithms motivated by the sharpness of the loss surface as an effective measure of the generalization gap have shown state-of-the-art performance. Nevertheless, sharpness defined in a rigid region with a fixed radius is sensitive to parameter re-scalings that leave the loss unaffected, which weakens the connection between sharpness and generalization gap. In this paper, we introduce the concept of adaptive sharpness, which is scale-invariant, and propose the corresponding generalization bound. We suggest a novel learning method, adaptive sharpness-aware minimization (ASAM), utilizing the proposed generalization bound. Experimental results on various benchmark datasets show that ASAM contributes to significant improvement of model generalization performance.

1. Introduction

Generalization of deep neural networks has recently been studied with great importance to address the shortfalls of pure optimization, which yields models with no guarantee on generalization ability. To understand the generalization phenomenon of neural networks, many studies have attempted to clarify the relationship between the geometry of the loss surface and the generalization performance (Hochreiter et al., 1995; McAllester, 1999; Keskar et al., 2017; Neyshabur et al., 2017; Jiang et al., 2019). Among the many measures proposed to derive generalization bounds, loss surface sharpness and minimization of the derived generalization bound have proven effective in attaining state-of-the-art performance in various tasks (Hochreiter & Schmidhuber, 1997; Mobahi, 2016; Chaudhari et al., 2019; Sun et al., 2020; Yue et al., 2020).

In particular, Sharpness-Aware Minimization (SAM) (Foret et al., 2021), a learning algorithm based on a PAC-Bayesian generalization bound, achieves state-of-the-art generalization performance on various image classification tasks by minimizing the sharpness of the loss landscape, which is correlated with the generalization gap. Foret et al. (2021) also suggest a new sharpness calculation strategy which is computationally efficient, since it requires only a single gradient ascent step, in contrast to more complex generalization measures such as sample-based or Hessian-based approaches.

However, even sharpness-based learning methods including SAM, as well as some sharpness measures, suffer from sensitivity to model parameter re-scaling. Dinh et al. (2017) point out that parameter re-scaling which does not change the loss function can change sharpness values, so this property may weaken the correlation between sharpness and generalization gap. We call this phenomenon the scale-dependency problem.

To remedy the scale-dependency problem of sharpness, many studies have been conducted recently (Liang et al., 2019; Yi et al., 2019; Karakida et al., 2019; Tsuzuku et al., 2020). However, those previous works are limited to proposing generalization measures which do not suffer from the scale-dependency problem and do not provide sufficient investigation into combining a learning algorithm with the measures.

To this end, we introduce the concept of a normalization operator which is not affected by any scaling operator that does not change the loss function. The operator varies depending on the way of normalizing, e.g., element-wise and filter-wise. We then define adaptive sharpness of the loss function: sharpness whose maximization region is determined by the normalization operator.
We prove that adap- derived generalization bound have proven to be effec- tive sharpness remains the same under parameter re-scaling, tive in attaining state-of-the-art performances in various i.e., scale-invariant. Due to the scale-invariant property, tasks (Hochreiter & Schmidhuber, 1997; Mobahi, 2016; adaptive sharpness shows stronger correlation with gener- Chaudhari et al., 2019; Sun et al., 2020; Yue et al., 2020). alization gap than sharpness does. 1 Samsung Research, Seoul, Republic of Korea. Correspon- dence to: Jeongseop Kim . Motivated by the connection between generalization met- rics and loss minimization, we propose a novel learning Proceedings of the 38 th International Conference on Machine method, adaptive sharpness-aware minimization (ASAM), Learning, PMLR 139, 2021. Copyright 2021 by the author(s).
which adaptively adjusts its maximization region and thus acts uniformly under parameter re-scaling. ASAM minimizes the corresponding generalization bound using adaptive sharpness to generalize to unseen data, avoiding the scale-dependency issue from which SAM suffers.

The main contributions of this paper are summarized as follows:

• We introduce adaptive sharpness of the loss surface, which is invariant to parameter re-scaling. In terms of rank statistics, adaptive sharpness shows a stronger correlation with generalization than sharpness does, which means that adaptive sharpness is a more effective measure of the generalization gap.

• We propose a new learning algorithm using adaptive sharpness which helps alleviate the side effects of scale-dependency in the training procedure by adjusting the maximization region with respect to the weight scale.

• We empirically show consistent improvement of generalization performance on image classification and machine translation tasks using various neural network architectures.

The rest of this paper is organized as follows. Section 2 briefly describes the previous sharpness-based learning algorithm. In Section 3, we introduce adaptive sharpness, a scale-invariant measure of the generalization gap, after explaining the scale-dependent property of sharpness. In Section 4, the ASAM algorithm is introduced in detail using the definition of adaptive sharpness. In Section 5, we evaluate the generalization performance of ASAM for various models and datasets. We provide the conclusion and future work in Section 6.

2. Preliminary

Let us consider a model f : X → Y parametrized by a weight vector w and a loss function l : Y × Y → R₊. Given a sample S = {(x_1, y_1), . . . , (x_n, y_n)} drawn i.i.d. from a data distribution D, the training loss can be defined as L_S(w) = (1/n) Σ_{i=1}^n l(y_i, f(x_i; w)). Then, the generalization gap between the expected loss L_D(w) = E_{(x,y)∼D}[l(y, f(x; w))] and the training loss L_S(w) represents the ability of the model to generalize to unseen data.
Sharpness-Aware Minimization (SAM) (Foret et al., 2021) aims to minimize the following PAC-Bayesian generalization error upper bound

    L_D(w) ≤ max_{‖ε‖_p ≤ ρ} L_S(w + ε) + h(‖w‖₂² / ρ²)    (1)

for some strictly increasing function h. The domain of the max operator, called the maximization region, is an ℓ_p ball with radius ρ for p ≥ 1. Here, the sharpness of the loss function L_S is defined as

    max_{‖ε‖₂ ≤ ρ} L_S(w + ε) − L_S(w).    (2)

Because of the monotonicity of h in Equation 1, the h term can be substituted by an ℓ₂ weight decay regularizer, so the sharpness-aware minimization problem can be defined as the following minimax optimization

    min_w max_{‖ε‖_p ≤ ρ} L_S(w + ε) + (λ/2) ‖w‖₂²

where λ is a weight decay coefficient.

SAM solves the minimax problem by iteratively applying the following two-step procedure for t = 0, 1, 2, . . . :

    ε_t = ρ ∇L_S(w_t) / ‖∇L_S(w_t)‖₂
    w_{t+1} = w_t − α_t (∇L_S(w_t + ε_t) + λ w_t)    (3)

where α_t is an appropriately scheduled learning rate. This procedure can be obtained by a first-order approximation of L_S and the dual norm formulation as

    ε_t = arg max_{‖ε‖_p ≤ ρ} L_S(w_t + ε)
        ≈ arg max_{‖ε‖_p ≤ ρ} ε⊤ ∇L_S(w_t)
        = ρ sign(∇L_S(w_t)) |∇L_S(w_t)|^{q−1} / ‖∇L_S(w_t)‖_q^{q−1}

and

    w_{t+1} = arg min_w L_S(w + ε_t) + (λ/2) ‖w‖₂²
            ≈ arg min_w (w − w_t)⊤ ∇L_S(w_t + ε_t) + (λ/2) ‖w‖₂²
            ≈ w_t − α_t (∇L_S(w_t + ε_t) + λ w_t)

where 1/p + 1/q = 1, | · | denotes the element-wise absolute value and sign(·) denotes the element-wise signum function. It is experimentally confirmed that the above two-step procedure produces the best performance when p = 2, which results in Equation 3.

As can be seen from Equation 3, SAM estimates the point w_t + ε_t at which the loss is approximately maximized around w_t in a rigid region with a fixed radius by performing gradient ascent, and then performs gradient descent at w_t using the gradient at the approximate maximum point w_t + ε_t.
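As a small sanity check on the dual-norm solution above (not from the paper; the helper name and toy gradient are illustrative), the following NumPy sketch evaluates the closed-form maximizer ε of the linearized loss ε⊤g over the ℓ_p ball and verifies that it attains the dual-norm value ρ‖g‖_q.

```python
import numpy as np

def sam_epsilon(grad, rho=0.05, p=2.0):
    """Closed-form maximizer of eps.grad over the ball ||eps||_p <= rho."""
    if p == 2.0:
        return rho * grad / np.linalg.norm(grad)          # eps = rho * g / ||g||_2
    if np.isinf(p):
        return rho * np.sign(grad)                        # eps = rho * sign(g)
    q = p / (p - 1.0)                                     # dual exponent, 1/p + 1/q = 1
    return rho * np.sign(grad) * np.abs(grad) ** (q - 1) / np.linalg.norm(grad, q) ** (q - 1)

g = np.array([0.3, -1.2, 0.05])                           # toy gradient vector
for p, q in [(2.0, 2.0), (np.inf, 1.0)]:
    eps = sam_epsilon(g, rho=0.05, p=p)
    # eps.g should equal rho * ||g||_q, the dual-norm value of the linearized objective,
    # and eps must lie inside the maximization region.
    assert np.isclose(eps @ g, 0.05 * np.linalg.norm(g, q))
    assert np.linalg.norm(eps, p) <= 0.05 + 1e-12
```

For p = 2 this reduces to the normalized gradient step used in Equation 3.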
3. Adaptive Sharpness: Scale-Invariant Measure of Generalization Gap

In Foret et al. (2021), it is experimentally confirmed that the sharpness defined in Equation 2 is strongly correlated with the generalization gap. They also show that SAM helps find minima with lower sharpness than other learning strategies and contributes to effectively lowering the generalization error.

However, Dinh et al. (2017) show that sharpness defined in a rigid spherical region with a fixed radius can have a weak correlation with the generalization gap due to the non-identifiability of rectifier neural networks, whose parameters can be freely re-scaled without affecting the output. If we assume that A is a scaling operator on the weight space that does not change the loss function, as shown in Figure 1(a), the interval of the loss contours around Aw becomes narrower than that around w but the size of the maximization region remains the same, i.e.,

    max_{‖ε‖₂ ≤ ρ} L_S(w + ε) ≠ max_{‖ε‖₂ ≤ ρ} L_S(Aw + ε).

Thus, neural networks with weights w and Aw can have arbitrarily different values of the sharpness defined in Equation 2, although they have the same generalization gaps. This property of sharpness is a main cause of the weak correlation between generalization gap and sharpness, and we call it scale-dependency in this paper.

Figure 1. Loss contours and three types of maximization regions: (a) the sphere ‖ε‖₂ ≤ ρ (Foret et al., 2021), (b) the cuboid ‖T_w⁻¹ε‖_∞ ≤ ρ (Keskar et al., 2017) and (c) the ellipsoid ‖T_w⁻¹ε‖₂ ≤ ρ (this paper). w = (1, 1) and w′ = (3, 1/3) are parameter points before and after multiplying the scaling operator A = diag(3, 1/3), shown as dots and triangles, respectively. The blue contour line has the same loss as at w, and the red contour line has a loss equal to the maximum value of the loss in each type of region centered on w. ε* and ε′* are the ε and ε′ which maximize the loss perturbed from w and w′, respectively.

To solve the scale-dependency of sharpness, we introduce the concept of adaptive sharpness. Prior to explaining adaptive sharpness, we first define the normalization operator. The normalization operator that cancels out the effect of A can be defined as follows.

Definition 1 (Normalization operator). Let {T_w, w ∈ R^k} be a family of invertible linear operators on R^k. Given a weight w, if T_{Aw}⁻¹ A = T_w⁻¹ for any invertible scaling operator A on R^k which does not change the loss function, we say T_w⁻¹ is a normalization operator of w.

Using the normalization operator, we define adaptive sharpness as follows.

Definition 2 (Adaptive sharpness). If T_w⁻¹ is the normalization operator of w in Definition 1, the adaptive sharpness of w is defined by

    max_{‖T_w⁻¹ε‖_p ≤ ρ} L_S(w + ε) − L_S(w)    (4)

where 1 ≤ p ≤ ∞.

Adaptive sharpness in Equation 4 has the following properties.

Theorem 1. For any invertible scaling operator A which does not change the loss function, the values of adaptive sharpness at w and Aw are the same:

    max_{‖T_w⁻¹ε‖_p ≤ ρ} L_S(w + ε) − L_S(w) = max_{‖T_{Aw}⁻¹ε‖_p ≤ ρ} L_S(Aw + ε) − L_S(Aw)

where T_w⁻¹ and T_{Aw}⁻¹ are the normalization operators of w and Aw in Definition 1, respectively.

Proof. From the assumption, it suffices to show that the
first terms of both sides are equal. By the definition of the normalization operator, we have T_{Aw}⁻¹ A = T_w⁻¹. Therefore,

    max_{‖T_{Aw}⁻¹ε‖_p ≤ ρ} L_S(Aw + ε) = max_{‖T_{Aw}⁻¹ε‖_p ≤ ρ} L_S(w + A⁻¹ε)
                                        = max_{‖T_{Aw}⁻¹Aε′‖_p ≤ ρ} L_S(w + ε′)
                                        = max_{‖T_w⁻¹ε′‖_p ≤ ρ} L_S(w + ε′)

where ε′ = A⁻¹ε.

By Theorem 1, adaptive sharpness defined in Equation 4 is scale-invariant, as are the training loss and generalization loss. This property makes the correlation of adaptive sharpness with the generalization gap stronger than that of the sharpness in Equation 2.

Figures 1(b) and 1(c) show how a re-scaled weight vector can have the same adaptive sharpness value as the original weight vector. It can be observed that the boundary line of each region centered on w′ is in contact with the red line. This implies that the maximum loss within each region centered on w′ is maintained when ‖T_{w′}⁻¹ε′‖_p ≤ ρ is used as the maximization region. Thus, in this example, adaptive sharpness in Figures 1(b) and 1(c) has the scale-invariant property, in contrast to the sharpness of the spherical region shown in Figure 1(a).

The question that can be asked here is what kind of operators T_w can be considered normalization operators, i.e., satisfy T_{Aw}⁻¹ A = T_w⁻¹ for any A which does not change the loss function. One of the conditions for a scaling operator A that does not change the loss function is that it should be node-wise scaling, which corresponds to row-wise or column-wise scaling in fully-connected layers and channel-wise scaling in convolutional layers. The effect of such node-wise scaling can be canceled using the inverses of the following operators:

• element-wise:

    T_w = diag(|w_1|, . . . , |w_k|)

  where w = [w_1, w_2, . . . , w_k],

• filter-wise:

    T_w = diag(concat(‖f_1‖₂ 1_{n(f_1)}, . . . , ‖f_m‖₂ 1_{n(f_m)}, |w_1|, . . . , |w_q|))    (5)

  where w = concat(f_1, f_2, . . . , f_m, w_1, w_2, . . . , w_q).

Here, f_i is the i-th flattened weight vector of a convolution filter and w_j is the j-th weight parameter not included in any filter; m is the number of filters and q is the number of other weight parameters in the model. If there is no convolutional layer in the model (i.e., m = 0), then q = k and the two normalization operators are identical. Note that we use T_w + ηI_k rather than T_w, for sufficiently small η > 0, for stability. η is a hyper-parameter controlling the trade-off between adaptivity and stability.
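To make the two operators concrete, here is a hedged PyTorch sketch (not the authors' code; the toy model and the helper name are assumptions for illustration) that builds the diagonal of T_w + ηI for every trainable parameter, either element-wise or filter-wise.

```python
import torch
import torch.nn as nn

def tw_diagonal(model, eta=0.01, filter_wise=False):
    """Diagonal entries of T_w + eta*I for all trainable parameters, flattened.
    Element-wise: |w_j| + eta for every parameter.
    Filter-wise:  ||f_i||_2 + eta repeated over each convolution filter f_i,
                  and |w_j| + eta for parameters outside any filter."""
    diags = []
    for p in model.parameters():
        if filter_wise and p.dim() == 4:             # conv weight: (out, in, kH, kW)
            flat = p.view(p.size(0), -1)             # one row per filter f_i
            norms = flat.norm(dim=1, keepdim=True)   # ||f_i||_2
            diags.append((norms.expand_as(flat) + eta).reshape(-1))
        else:
            diags.append(p.abs().reshape(-1) + eta)  # element-wise |w_j|
    return torch.cat(diags)

# Example usage on a toy convolutional model: both diagonals have one entry per weight.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
print(tw_diagonal(model, filter_wise=False).shape, tw_diagonal(model, filter_wise=True).shape)
```

Since only the diagonal of T_w is ever needed, applying T_w or T_w² to a gradient amounts to an element-wise multiplication, with no extra memory beyond one vector per parameter tensor.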
Figure 2. Scatter plots of the generalization gap against sharpness and adaptive sharpness, with their rank correlation coefficients τ: (a) sharpness (p = 2), τ = 0.174; (b) adaptive sharpness (p = 2), τ = 0.636; (c) sharpness (p = ∞), τ = 0.257; (d) adaptive sharpness (p = ∞), τ = 0.616.

To confirm that adaptive sharpness actually has a stronger correlation with the generalization gap than sharpness, we compare rank statistics that capture how adaptive sharpness and sharpness change with respect to the generalization gap. For the correlation analysis, we vary 4 hyper-parameters: mini-batch size, initial learning rate, weight decay coefficient and dropout rate. As can be seen in Table 1, the Kendall rank correlation coefficient (Kendall, 1938) of adaptive sharpness is greater than that of sharpness regardless of the value of p. Furthermore, we compare granulated coefficients (Jiang et al., 2019) with respect to different hyper-parameters to measure the effect of each hyper-parameter separately. In Table 1, the coefficients of adaptive sharpness are higher for most hyper-parameters and for the average as well. The scatter plots in Figure 2 also show the stronger correlation of adaptive sharpness.
The difference in the correlation behaviors of adaptive sharpness and sharpness provides evidence that the scale-invariant property helps strengthen the correlation with the generalization gap. The experimental details are described in Appendix B.

Table 1. Rank statistics for sharpness and adaptive sharpness.

                           p = 2                          p = ∞
                  sharpness   adaptive sharpness   sharpness   adaptive sharpness
τ (rank corr.)    0.174       0.636                0.257       0.616
mini-batch size   0.667       0.696                0.777       0.817
learning rate     0.563       0.577                0.797       0.806
weight decay      −0.297      0.534                −0.469      0.656
dropout rate      0.102       −0.092               0.161       0.225
Ψ (avg.)          0.259       0.429                0.316       0.626

Although there are various normalization methods other than the normalization operators introduced above, this paper covers only element-wise and filter-wise normalization operators. Node-wise normalization can also be viewed as a normalization operator. Tsuzuku et al. (2020) suggest a node-wise normalization method for obtaining normalized flatness. However, the method requires the parameter to be at a critical point. Also, in the case of node-wise normalization using unit-invariant SVD (Uhlmann, 2018), there is a concern that the speed of the optimizer can be degraded due to the significant additional cost of the scale-direction decomposition of weight tensors. Therefore node-wise normalization is not covered in this paper. In the case of layer-wise normalization using the spectral norm or Frobenius norm of weight matrices (Neyshabur et al., 2017), the condition T_{Aw}⁻¹ A = T_w⁻¹ is not satisfied. Therefore, it cannot be used for adaptive sharpness, so we do not cover it in this paper.
Meanwhile, even though all weight parameters including biases can have scale-dependency, there remains more to consider when applying normalization to the biases. For the bias parameters of rectifier neural networks, there also exists translation-dependency in sharpness, which weakens the correlation with the generalization gap as well. Using arguments similar to those in the proof of Theorem 1, it can be derived that the diagonal elements of T_w corresponding to biases must be replaced by constants to guarantee translation-invariance, which induces an adaptive sharpness that corresponds to not applying bias normalization. We compare the generalization performance based on adaptive sharpness with and without bias normalization in Section 5.

There are several previous works which are closely related to adaptive sharpness. Li et al. (2018), who suggest a methodology for visualizing the loss landscape, is related to adaptive sharpness. In that study, filter-wise normalization, which is equivalent to the definition in Equation 5, is used to remove scale-dependency from the loss landscape and make comparisons between loss functions meaningful. In spite of their empirical success, Li et al. (2018) do not provide theoretical evidence explaining how filter-wise scaling contributes to scale-invariance and correlation with generalization. In this paper, we clarify how filter-wise normalization relates to generalization by proving the scale-invariant property of adaptive sharpness in Theorem 1. Also, the sharpness suggested in Keskar et al. (2017) can be regarded as a special case of adaptive sharpness which uses p = ∞ and the element-wise normalization operator. Jiang et al. (2019) confirm experimentally that the adaptive sharpness suggested in Keskar et al. (2017) shows a higher correlation with the generalization gap than sharpness which does not use the element-wise normalization operator. This experimental result implies that Theorem 1 is also practically validated.

Therefore, it seems that the sharpness with p = ∞ suggested by Keskar et al. (2017) could also be used directly for learning as it is, but a problem arises in terms of generalization performance. Foret et al. (2021) confirm experimentally that the generalization performance with sharpness defined in the square region ‖ε‖_∞ ≤ ρ is worse than when SAM is performed with sharpness defined in the spherical region ‖ε‖₂ ≤ ρ. We conduct performance comparison tests for p = 2 and p = ∞, and experimentally reveal that p = 2 is more suitable for learning, as in Foret et al. (2021). The experimental results are shown in Section 5.

4. Adaptive Sharpness-Aware Minimization

In the previous section, we introduced a scale-invariant measure called adaptive sharpness to overcome the limitation of sharpness. As with sharpness, we can obtain a generalization bound using adaptive sharpness, which is presented in the following theorem.

Theorem 2. Let T_w⁻¹ be the normalization operator on R^k. If L_D(w) ≤ E_{ε_i∼N(0,σ²)}[L_D(w + ε)] for some σ > 0, then with probability 1 − δ,

    L_D(w) ≤ max_{‖T_w⁻¹ε‖₂ ≤ ρ} L_S(w + ε) + h(‖w‖₂² / (η²ρ²))    (6)

where h : R₊ → R₊ is a strictly increasing function, n = |S| and ρ = √k σ (1 + √(log n / k)) / η.

Note that Theorem 2 still holds for p > 2 due to the monotonicity of the p-norm, i.e., if 0 < r < p, then ‖x‖_p ≤ ‖x‖_r for any x ∈ R^n. If T_w is an identity operator, Equation 6 reduces to Equation 1. The proof of Equation 6 is described in detail in Appendix A.1.
The right hand side of Equation 6, i.e., the generalization bound, can be expressed using adaptive sharpness as

    ( max_{‖T_w⁻¹ε‖_p ≤ ρ} L_S(w + ε) − L_S(w) ) + L_S(w) + h(‖w‖₂² / (η²ρ²)).

Since h(‖w‖₂² / (η²ρ²)) is a strictly increasing function of ‖w‖₂², it can be substituted with an ℓ₂ weight decay regularizer. Therefore, we can define the adaptive sharpness-aware minimization problem as

    min_w max_{‖T_w⁻¹ε‖_p ≤ ρ} L_S(w + ε) + (λ/2) ‖w‖₂².    (7)

To solve the minimax problem in Equation 7, it is necessary to find the optimal ε first. Analogous to SAM, we can approximate the optimal ε that maximizes L_S(w + ε) using a first-order approximation as

    ε̃_t = arg max_{‖ε̃‖_p ≤ ρ} L_S(w_t + T_{w_t} ε̃)
         ≈ arg max_{‖ε̃‖_p ≤ ρ} ε̃⊤ T_{w_t} ∇L_S(w_t)
         = ρ sign(∇L_S(w_t)) |T_{w_t} ∇L_S(w_t)|^{q−1} / ‖T_{w_t} ∇L_S(w_t)‖_q^{q−1}

where ε̃ = T_{w_t}⁻¹ ε. Then, the two-step procedure for adaptive sharpness-aware minimization (ASAM) is expressed as

    ε_t = ρ T_{w_t} sign(∇L_S(w_t)) |T_{w_t} ∇L_S(w_t)|^{q−1} / ‖T_{w_t} ∇L_S(w_t)‖_q^{q−1}
    w_{t+1} = w_t − α_t (∇L_S(w_t + ε_t) + λ w_t)

for t = 0, 1, 2, . . . . In particular, if p = 2,

    ε_t = ρ T_{w_t}² ∇L_S(w_t) / ‖T_{w_t} ∇L_S(w_t)‖₂,

and if p = ∞,

    ε_t = ρ T_{w_t} sign(∇L_S(w_t)).

In this study, experiments are conducted on ASAM for both p = ∞ and p = 2. The ASAM algorithm with p = 2 is described in detail in Algorithm 1. Note that the SGD (Nesterov, 1983) update marked with † in Algorithm 1 can be combined with momentum or replaced by the update of another optimization scheme such as Adam (Kingma & Ba, 2015).

Algorithm 1 ASAM algorithm (p = 2)
  Input: loss function l, training dataset S := ∪_{i=1}^n {(x_i, y_i)}, mini-batch size b, radius of maximization region ρ, weight decay coefficient λ, scheduled learning rate α, initial weight w_0.
  Output: trained weight w
  Initialize weight w := w_0
  while not converged do
    Sample a mini-batch B of size b from S
    ε := ρ T_w² ∇L_B(w) / ‖T_w ∇L_B(w)‖₂
    w := w − α (∇L_B(w + ε) + λw)    †
  end while
  return w
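To make the p = 2 update in Algorithm 1 concrete, here is a minimal PyTorch-style sketch of one ASAM step using the element-wise operator T_w = diag(|w|) + ηI; the function and argument names are illustrative assumptions, not the authors' released implementation.

```python
import torch

def asam_step(model, loss_fn, x, y, base_opt, rho=0.5, eta=0.01):
    """One ASAM update (Algorithm 1, p = 2) with element-wise T_w = diag(|w|) + eta*I.
    Weight decay is assumed to be handled by `base_opt`
    (e.g., torch.optim.SGD(..., momentum=0.9, weight_decay=5e-4))."""
    # First pass: gradient of the mini-batch loss at the current weights w.
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()
    params = [p for p in model.parameters() if p.grad is not None]

    # ||T_w grad||_2 computed jointly over all parameters.
    t_grad_norm = torch.sqrt(sum((((p.abs() + eta) * p.grad) ** 2).sum() for p in params))

    # Ascent step: eps = rho * T_w^2 grad / ||T_w grad||_2.
    eps_list = []
    with torch.no_grad():
        for p in params:
            eps = rho * (p.abs() + eta) ** 2 * p.grad / (t_grad_norm + 1e-12)
            p.add_(eps)
            eps_list.append(eps)

    # Second pass: gradient at the perturbed point w + eps.
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()

    # Restore w and take the descent step (the update marked with dagger in Algorithm 1).
    with torch.no_grad():
        for p, eps in zip(params, eps_list):
            p.sub_(eps)
    base_opt.step()
```

Each ASAM step therefore costs two forward/backward passes per mini-batch, the same overhead as SAM; only the construction of the perturbation differs.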
5. Experimental Results

In this section, we evaluate the performance of ASAM. We first show how SAM and ASAM operate differently in a toy example. We then compare the generalization performance of ASAM with other learning algorithms for various model architectures and datasets: CIFAR-10, CIFAR-100 (Krizhevsky et al., 2009), ImageNet (Deng et al., 2009) and IWSLT'14 DE-EN (Cettolo et al., 2014). Finally, we show how robust ASAM is to label noise.

5.1. Toy Example

As mentioned in Section 4, sharpness varies under parameter re-scaling even if the loss function remains the same, while adaptive sharpness does not. To illustrate this, we consider a simple loss function L(w) = |w_1 ReLU(w_2) − 0.04| where w = (w_1, w_2) ∈ R². Figure 3 presents the trajectories of SAM and ASAM with two different initial weights, w_0 = (0.2, 0.05) and w_0 = (0.3, 0.033). The red line represents the set of minimizers of the loss function L, i.e., {(w_1, w_2); w_1 w_2 = 0.04, w_2 > 0}. As seen in Figure 1(a), sharpness is maximized when w_1 = w_2 within the same loss contour line, and therefore SAM tries to converge to (0.2, 0.2). Here, we use ρ = 0.05 as in Foret et al. (2021). On the other hand, adaptive sharpness remains the same along the same contour line, which implies that ASAM converges to the point on the red line near the initial point, as can be seen in Figure 3.

Figure 3. Trajectories of SAM and ASAM in the (w_1, w_2) plane.

Since SAM uses a fixed radius in a spherical region for minimizing sharpness, it may cause undesirable results depending on the loss surface and the current weight. If w_0 = (0.3, 0.033), SAM even fails to converge to the valley with ρ = 0.05, whereas ASAM converges no matter which w_0 is used if ρ < √2. In other words, an appropriate ρ for SAM depends on the scales of w along the training trajectory, whereas ρ of ASAM does not.

5.2. Image Classification: CIFAR-10/100 and ImageNet

To confirm the effectiveness of ASAM, we conduct comparison experiments with SAM using the CIFAR-10 and CIFAR-100 datasets. We use the same data split as the original paper (Krizhevsky et al., 2009). The hyper-parameters used in this test are described in Table 3. Before the comparison tests with SAM, there are three factors to be chosen in the ASAM algorithm:

• normalization scheme: element-wise vs. filter-wise
• p-norm: p = ∞ vs. p = 2
• bias normalization: with vs. without

First, we perform the comparison test for filter-wise and element-wise normalization using the WideResNet-16-8 model (Zagoruyko & Komodakis, 2016) and illustrate the results in Figure 4. As can be seen, both test accuracies are comparable across ρ, and element-wise normalization provides slightly better accuracy at ρ = 1.0.

Figure 4. Test accuracy curves obtained from the ASAM algorithm using a range of ρ with different factors: element-wise normalization with p = ∞, element-wise normalization with p = 2 with and without bias normalization (BN), and filter-wise normalization with p = 2.

Similarly, Figure 4 shows how much test accuracy varies with the maximization region for adaptive sharpness. It can be seen that p = 2 gives better test accuracies than p = ∞, which is consistent with Foret et al. (2021). We also observe in Figure 4 that bias normalization does not contribute to the improvement of test accuracy. Therefore, we decide to use the element-wise normalization operator with p = 2, and not to employ bias normalization, in the remaining tests.

Table 2. Maximum test accuracies for SGD, SAM and ASAM on the CIFAR-10 dataset.

Model              SGD          SAM          ASAM
DenseNet-121       91.00±0.13   92.00±0.17   93.33±0.04
ResNet-20          93.18±0.21   93.56±0.15   93.82±0.17
ResNet-56          94.58±0.20   95.18±0.15   95.42±0.16
VGG19-BN*          93.87±0.09   94.60        95.07±0.05
ResNeXt29-32x4d    95.84±0.24   96.34±0.30   96.80±0.06
WRN-28-2           95.13±0.16   95.74±0.08   95.94±0.05
WRN-28-10          96.34±0.12   96.98±0.04   97.28±0.07
PyramidNet-272†    98.44±0.08   98.55±0.05   98.68±0.08

* Some runs completely failed, giving 10% accuracy (success rate: SGD 3/5, SAM 1/5, ASAM 3/5).
† The PyramidNet-272 architecture is tested 3 times for each learning algorithm.

Table 3. Maximum test accuracies for SGD, SAM and ASAM on the CIFAR-100 dataset.

Model              SGD          SAM          ASAM
DenseNet-121       68.70±0.31   69.84±0.12   70.60±0.20
ResNet-20          69.76±0.44   71.06±0.31   71.40±0.30
ResNet-56          73.12±0.19   75.16±0.05   75.86±0.22
VGG19-BN*          71.80±1.35   73.52±1.74   75.80±0.27
ResNeXt29-32x4d    79.76±0.23   81.48±0.17   82.30±0.11
WRN-28-2           75.28±0.17   77.25±0.35   77.54±0.14
WRN-28-10          81.56±0.13   83.42±0.04   83.68±0.12
PyramidNet-272†    88.91±0.12   89.36±0.20   89.90±0.13

* Some runs completely failed, giving 10% accuracy (success rate: SGD 5/5, SAM 4/5, ASAM 4/5).
† The PyramidNet-272 architecture is tested 3 times for each learning algorithm.

As ASAM has a hyper-parameter ρ to be tuned, we first conduct a grid search over {0.00005, 0.0001, 0.0002, . . . , 0.5, 1.0, 2.0} to find appropriate values of ρ.
We use ρ = 0.5 for CIFAR-10 and ρ = 1.0 for CIFAR-100, because these values give moderately good performance across various models. We set ρ for SAM to 0.05 for CIFAR-10 and 0.1 for CIFAR-100, as in Foret et al. (2021). η for ASAM is set
to 0.01. We set the mini-batch size to 128, and the m-sharpness strategy suggested by Foret et al. (2021) is not employed. The number of epochs is set to 200 for SAM and ASAM and 400 for SGD. Momentum and the weight decay coefficient are set to 0.9 and 0.0005, respectively. Cosine learning rate decay (Loshchilov & Hutter, 2016) is adopted with an initial learning rate of 0.1. Also, random resize, padding by four pixels, normalization and random horizontal flip are applied for data augmentation, and label smoothing (Müller et al., 2019) is adopted with a factor of 0.1.

Using these hyper-parameters, we compare the best test accuracies obtained by SGD, SAM and ASAM for various rectifier neural network models: VGG (Simonyan & Zisserman, 2015), ResNet (He et al., 2016), DenseNet (Huang et al., 2017), WideResNet (Zagoruyko & Komodakis, 2016), and ResNeXt (Xie et al., 2017).

For PyramidNet-272 (Han et al., 2017), we additionally apply some recent techniques: AutoAugment (Cubuk et al., 2019), CutMix (Yun et al., 2019) and ShakeDrop (Yamada et al., 2019). We employ the m-sharpness strategy with m = 32. The initial learning rate and mini-batch size are set to 0.05 and 256, respectively. The number of epochs is set to 900 for SAM and ASAM and 1800 for SGD. We choose ρ for SAM as 0.05, as in Foret et al. (2021), and ρ for ASAM as 1.0 for both CIFAR-10 and CIFAR-100. Every entry in the tables represents the mean and standard deviation of 5 independent runs. In both the CIFAR-10 and CIFAR-100 cases, ASAM generally surpasses SGD and SAM, as can be seen in Table 2 and Table 3.

For the sake of evaluation at a larger scale, we compare the performance of SGD, SAM and ASAM on ImageNet; the results for ResNet-50 are reported in Table 4.

Table 4. Top-1 and Top-5 maximum test accuracies for SGD, SAM and ASAM on the ImageNet dataset using ResNet-50.

        SGD          SAM          ASAM
Top1    75.79±0.22   76.39±0.03   76.63±0.18
Top5    92.62±0.04   92.97±0.07   93.16±0.18

We also evaluate ASAM on IWSLT'14 DE-EN, a dataset for a machine translation task. In this test, we adopt the Transformer architecture (Vaswani et al., 2017) and the Adam optimizer as the base optimizer of SAM and ASAM instead of SGD. The learning rate, β1 and β2 for Adam are set to 0.0005, 0.9 and 0.98, respectively. The dropout rate and weight decay coefficient are set to 0.3 and 0.0001, respectively. Label smoothing is adopted with a factor of 0.1. We choose ρ = 0.1 for SAM and ρ = 0.2 for ASAM as a result of a grid search over {0.005, 0.01, 0.02, . . . , 0.5, 1.0, 2.0} using the validation dataset. The results of the experiments are obtained from 3 independent runs.

As can be seen in Table 5, we observe an improvement in BLEU score on IWSLT'14 as well when using Adam+ASAM instead of Adam or Adam+SAM.

Table 5. BLEU scores for Adam, Adam+SAM and Adam+ASAM on the IWSLT'14 DE-EN dataset using Transformer.

BLEU score    Adam       Adam+SAM    Adam+ASAM
Validation    35.34±
Table 6. Maximum test accuracies of ResNet-32 models trained on CIFAR-10 with label noise.

Noise rate    SGD          SAM          ASAM
0%            94.50±0.11   94.80±0.12   94.88±0.12
20%           91.32±0.23   92.94±0.12   93.21±0.10
40%           87.68±0.05   90.62±0.18   90.89±0.13
60%           82.50±0.30   86.58±0.30   87.41±0.16
80%           68.35±0.85   69.92±0.98   67.69±1.34

6. Conclusions

In this paper, we have introduced adaptive sharpness, a measure with a scale-invariant property that improves the training path in weight space by adjusting the maximization region with respect to the weight scale. We have also confirmed that this property, which ASAM shares, contributes to improved generalization performance. The superior performance of ASAM is notable from the comparison tests conducted against SAM, which is currently the state-of-the-art learning algorithm in many image classification benchmarks. In addition to its contribution as a learning algorithm, adaptive sharpness can serve as a generalization measure with a stronger correlation with the generalization gap, benefiting from its scale-invariant property. Therefore, adaptive sharpness has the potential to be a metric for the assessment of neural networks. We have also suggested the condition a normalization operator must satisfy for adaptive sharpness, but we did not cover all the normalization schemes which satisfy the condition, so this area could be further investigated for better generalization performance in future work.

7. Acknowledgements

We would like to thank Kangwook Lee, Jaedeok Kim and Yonghyun Ryu for their support of our machine translation experiments. We also thank our other colleagues at Samsung Research - Joohyung Lee, Chiyoun Park and HyunJoo Jung - for their insightful discussions and feedback.

References

Cettolo, M., Niehues, J., Stüker, S., Bentivogli, L., and Federico, M. Report on the 11th IWSLT evaluation campaign, IWSLT 2014. In Proceedings of the International Workshop on Spoken Language Translation, Hanoi, Vietnam, volume 57, 2014.

Chatterji, N., Neyshabur, B., and Sedghi, H. The intriguing role of module criticality in the generalization of deep networks. In International Conference on Learning Representations, 2019.

Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124018, 2019.

Cubuk, E. D., Zoph, B., Mane, D., Vasudevan, V., and Le, Q. V. AutoAugment: Learning augmentation strategies from data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 113–123, 2019.

Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255. IEEE, 2009.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. In International Conference on Machine Learning, pp. 1019–1028. PMLR, 2017.

Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Han, D., Kim, J., and Kim, J. Deep pyramidal residual networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5927–5935, 2017.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Hochreiter, S., Schmidhuber, J., et al. Simplifying neural nets by discovering flat minima. In Advances in Neural Information Processing Systems, pp. 529–536, 1995.

Huang, G., Liu, Z., Van Der Maaten, L., and Weinberger, K. Q. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708, 2017.
Jiang, L., Huang, D., Liu, M., and Yang, W. Beyond synthetic noise: Deep learning on controlled noisy labels. In International Conference on Machine Learning, pp. 4804–4815. PMLR, 2020.

Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. In International Conference on Learning Representations, 2019.
Karakida, R., Akaho, S., and Amari, S.-i. The normalization method for alleviating pathological sharpness in wide neural networks. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.

Kendall, M. G. A new measure of rank correlation. Biometrika, 30(1/2):81–93, 1938.

Keskar, N. S., Nocedal, J., Tang, P. T. P., Mudigere, D., and Smelyanskiy, M. On large-batch training for deep learning: Generalization gap and sharp minima. In 5th International Conference on Learning Representations, ICLR 2017, 2017.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In ICLR (Poster), 2015.

Krizhevsky, A., Nair, V., and Hinton, G. CIFAR-10 and CIFAR-100 datasets. 2009. URL https://www.cs.toronto.edu/~kriz/cifar.html.

Laurent, B. and Massart, P. Adaptive estimation of a quadratic functional by model selection. Annals of Statistics, pp. 1302–1338, 2000.

Li, H., Xu, Z., Taylor, G., Studer, C., and Goldstein, T. Visualizing the loss landscape of neural nets. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 6391–6401. Curran Associates Inc., 2018.

Liang, T., Poggio, T., Rakhlin, A., and Stokes, J. Fisher-Rao metric, geometry, and complexity of neural networks. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 888–896. PMLR, 2019.

Loshchilov, I. and Hutter, F. SGDR: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Mobahi, H. Training recurrent neural networks by diffusion. arXiv preprint arXiv:1601.04114, 2016.

Müller, R., Kornblith, S., and Hinton, G. E. When does label smoothing help? In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.

Nesterov, Y. E. A method for solving the convex programming problem with convergence rate O(1/k²). In Dokl. Akad. Nauk SSSR, volume 269, pp. 543–547, 1983.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.

Rooyen, B. v., Menon, A. K., and Williamson, R. C. Learning with symmetric label noise: the importance of being unhinged. In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 1, pp. 10–18, 2015.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. In 3rd International Conference on Learning Representations, ICLR 2015, 2015.

Sun, X., Zhang, Z., Ren, X., Luo, R., and Li, L. Exploring the vulnerability of deep neural networks: A study of parameter corruption. CoRR, abs/2006.05620, 2020.

Tsuzuku, Y., Sato, I., and Sugiyama, M. Normalized flat minima: Exploring scale invariant definition of flat minima for neural networks using PAC-Bayesian analysis. In International Conference on Machine Learning, pp. 9636–9647. PMLR, 2020.

Uhlmann, J. A generalized matrix inverse that is consistent with respect to diagonal transformations. SIAM Journal on Matrix Analysis and Applications, 39(2):781–800, 2018.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In NIPS, 2017.

Xie, S., Girshick, R., Dollár, P., Tu, Z., and He, K. Aggregated residual transformations for deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1492–1500, 2017.
Yamada, Y., Iwamura, M., Akiba, T., and Kise, K. ShakeDrop regularization for deep residual learning. IEEE Access, 7:186126–186136, 2019.

Yi, M., Meng, Q., Chen, W., Ma, Z.-m., and Liu, T.-Y. Positively scale-invariant flatness of ReLU neural networks. arXiv preprint arXiv:1903.02237, 2019.

Yue, X., Nouiehed, M., and Kontar, R. A. SALR: Sharpness-aware learning rates for improved generalization. arXiv preprint arXiv:2011.05348, 2020.

Yun, S., Han, D., Oh, S. J., Chun, S., Choe, J., and Yoo, Y. CutMix: Regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6023–6032, 2019.
Zagoruyko, S. and Komodakis, N. Wide residual networks. CoRR, abs/1605.07146, 2016.
A. Proofs

A.1. Proof of Theorem 2

We first introduce the following concentration inequality.

Lemma 1. Let {ε_i, i = 1, . . . , k} be independent normal variables with mean 0 and variance σ_i². Then

    P( Σ_{i=1}^k ε_i² ≥ k σ_max² (1 + √(log n / k))² ) ≤ 1/√n

where σ_max = max{σ_i}.

Proof. From Lemma 1 in Laurent & Massart (2000), for any x > 0,

    P( Σ_{i=1}^k ε_i² ≥ Σ_{i=1}^k σ_i² + 2 √(Σ_{i=1}^k σ_i⁴ x) + 2 σ_max² x ) ≤ exp(−x).

Since

    Σ_{i=1}^k σ_i² + 2 √(Σ_{i=1}^k σ_i⁴ x) + 2 σ_max² x ≤ σ_max² (k + 2√(kx) + 2x) ≤ σ_max² (√k + √(2x))²,

plugging in x = (1/2) log n proves the lemma.

Theorem 3. Let T_w⁻¹ be a normalization operator of w on R^k. If L_D(w) ≤ E_{ε_i∼N(0,σ²)}[L_D(w + ε)] for some σ > 0, then with probability 1 − δ,

    L_D(w) ≤ max_{‖T_w⁻¹ε‖₂ ≤ ρ} L_S(w + ε) + √( (1/(n−1)) [ k log(1 + (‖w‖₂²/(η²ρ²)) (1 + √(log n / k))²) + 4 log(n/δ) + O(1) ] )

where n = |S| and ρ = √k σ (1 + √(log n / k)) / η.

Proof. The idea of the proof is given in Foret et al. (2021). From the assumption, adding a Gaussian perturbation on the weight space does not improve the test error. Moreover, from Theorem 3.2 in Chatterji et al. (2019), the following generalization bound holds under the perturbation:

    E_{ε_i∼N(0,σ²)}[L_D(w + ε)] ≤ E_{ε_i∼N(0,σ²)}[L_S(w + ε)] + √( (1/(n−1)) [ (1/4) k log(1 + ‖w‖₂²/(kσ²)) + log(n/δ) + C(n, σ, k) ] ).

Therefore, the left hand side of the statement can be bounded as

    L_D(w) ≤ E_{ε_i∼N(0,σ²)}[L_S(w + ε)] + √( (1/(n−1)) [ (1/4) k log(1 + ‖w‖₂²/(kσ²)) + log(n/δ) + C ] )
           ≤ (1 − 1/√n) max_{‖T_w⁻¹ε‖₂ ≤ ρ} L_S(w + ε) + 1/√n + √( (1/(n−1)) [ (1/4) k log(1 + ‖w‖₂²/(kσ²)) + log(n/δ) + C ] )
           ≤ max_{‖T_w⁻¹ε‖₂ ≤ ρ} L_S(w + ε) + √( (1/(n−1)) [ k log(1 + ‖w‖₂²/(kσ²)) + 4 log(n/δ) + 4C ] )

where the second inequality follows from Lemma 1 and ‖T_w⁻¹‖₂ ≤ 1/η.
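As a quick numerical sanity check of Lemma 1 above (not part of the paper; the dimensions and variance are illustrative choices), the following Python sketch estimates the tail probability by Monte Carlo and compares it against the 1/√n bound.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n, sigma = 1000, 10000, 0.1                       # illustrative dimension, sample count, std
threshold = k * sigma**2 * (1 + np.sqrt(np.log(n) / k))**2

eps = rng.normal(0.0, sigma, size=(20000, k))        # 20000 Monte Carlo draws of the perturbation
tail = np.mean((eps**2).sum(axis=1) >= threshold)    # empirical P(sum eps_i^2 >= threshold)

print(f"empirical tail = {tail:.5f}, bound 1/sqrt(n) = {1/np.sqrt(n):.5f}")
# The empirical tail probability should come out well below the bound 1/sqrt(n) = 0.01.
```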
B. Correlation Analysis

To capture the correlation between the generalization measures, i.e., sharpness and adaptive sharpness, and the actual generalization gap, we utilize the Kendall rank correlation coefficient (Kendall, 1938). Formally, given the observed set of pairs of a measure and a generalization gap S = {(m_1, g_1), . . . , (m_n, g_n)}, the Kendall rank correlation coefficient τ is given by

    τ(S) = (2 / (n(n−1))) Σ_{i<j} sign(m_i − m_j) sign(g_i − g_j).

Since τ is the difference between the proportion of concordant pairs, i.e., pairs with m_i > m_j and g_i > g_j, among all pairs of points and the proportion of discordant pairs, i.e., pairs which are not concordant, the value of τ lies in the range [−1, 1].

While the rank correlation coefficient aggregates the effects of all the hyper-parameters, the granulated coefficient (Jiang et al., 2019) can consider the correlation with respect to each hyper-parameter separately. If Θ = ∏_{i=1}^N Θ_i is the Cartesian product of the hyper-parameter spaces Θ_i, the granulated coefficient with respect to Θ_i is given by

    ψ_i = (1/|Θ_{−i}|) Σ_{θ_1∈Θ_1} · · · Σ_{θ_{i−1}∈Θ_{i−1}} Σ_{θ_{i+1}∈Θ_{i+1}} · · · Σ_{θ_N∈Θ_N} τ( ∪_{θ_i∈Θ_i} {(m(θ), g(θ))} )

where Θ_{−i} = Θ_1 × · · · × Θ_{i−1} × Θ_{i+1} × · · · × Θ_N. Then the average Ψ = Σ_{i=1}^N ψ_i / N indicates whether the correlation exists across all hyper-parameters.

We vary 4 hyper-parameters, mini-batch size, initial learning rate, weight decay coefficient and dropout rate, to produce different models. It is worth mentioning that changing only one or two hyper-parameters for correlation analysis may cause spurious correlation (Jiang et al., 2019). For each hyper-parameter, we use the 5 different values in Table 7, which gives 5⁴ = 625 configurations in total.

Table 7. Hyper-parameter configurations.

mini-batch size   32, 64, 128, 256, 512
learning rate     0.0033, 0.01, 0.033, 0.1, 0.33
weight decay      5e−7, 5e−6, 5e−5, 5e−4, 5e−3
dropout rate      0, 0.125, 0.25, 0.375, 0.5

Using the above hyper-parameter configurations, we train WideResNet-28-2 models on the CIFAR-10 dataset. We use SGD as the optimizer and set momentum to 0.9. We set the number of epochs to 200 and adopt cosine learning rate decay (Loshchilov & Hutter, 2016). Also, random resize, padding by four pixels, normalization and random horizontal flip are applied for data augmentation, and label smoothing (Müller et al., 2019) is adopted with a factor of 0.1. Using the model parameters with training accuracy higher than 99.0% among the generated models, we calculate sharpness and adaptive sharpness with respect to the generalization gap.

To calculate adaptive sharpness, we fix the normalization scheme to element-wise normalization. We calculate adaptive sharpness and sharpness with both p = 2 and p = ∞. We conduct a grid search over {5e−6, 1e−5, 5e−5, . . . , 5e−1, 1.0} to obtain, for sharpness and adaptive sharpness, the ρ which maximizes the correlation with the generalization gap. As a result of the grid search, we select 1e−5 and 5e−4 as ρ for sharpness with p = 2 and p = ∞, respectively, and 5e−1 and 5e−3 as ρ for adaptive sharpness with p = 2 and p = ∞, respectively. To calculate the maximizers of each loss function for the computation of sharpness and adaptive sharpness, we follow the m-sharpness strategy suggested by Foret et al. (2021) with m set to 8.
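For concreteness, here is a small NumPy sketch of the Kendall τ defined above and of the granulated coefficient ψ_i obtained by averaging τ over slices in which only the i-th hyper-parameter varies. It is illustrative only (scipy.stats.kendalltau could be used instead), and the measure(θ) and gap(θ) callables are assumed to return the precomputed sharpness value and generalization gap for a hyper-parameter configuration θ.

```python
import numpy as np
from itertools import product

def kendall_tau(m, g):
    """tau(S) = 2/(n(n-1)) * sum_{i<j} sign(m_i - m_j) * sign(g_i - g_j)."""
    m, g = np.asarray(m, float), np.asarray(g, float)
    n = len(m)
    s = sum(np.sign(m[i] - m[j]) * np.sign(g[i] - g[j])
            for i in range(n) for j in range(i + 1, n))
    return 2.0 * s / (n * (n - 1))

def granulated_psi(measure, gap, grids, i):
    """psi_i: average of tau over all settings of the other hyper-parameters,
    varying only the i-th one (grids is a list of value lists, as in Table 7)."""
    others = [g for j, g in enumerate(grids) if j != i]
    taus = []
    for fixed in product(*others):
        thetas = [fixed[:i] + (v,) + fixed[i:] for v in grids[i]]
        taus.append(kendall_tau([measure(t) for t in thetas],
                                [gap(t) for t in thetas]))
    return float(np.mean(taus))
```

Ψ is then the mean of ψ_i over the N hyper-parameters, matching the definition above.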