KVT: k-NN Attention for Boosting Vision Transformers
Pichao Wang*, Xue Wang*, Fan Wang, Min Lin, Shuning Chang, Hao Li, Rong Jin
Alibaba Group
{pichao.wang,xue.wang}@alibaba-inc.com
{fan.w, ming.l, shuning.csn, lihao.lh, jinrong.jr}@alibaba-inc.com
arXiv:2106.00515v2 [cs.CV] 12 Jan 2022

* The first two authors contribute equally.

Abstract

Convolutional Neural Networks (CNNs) have dominated computer vision for years, due to their ability to capture locality and translation invariance. Recently, many vision transformer architectures have been proposed and they show promising performance. A key component in vision transformers is the fully-connected self-attention, which is more powerful than CNNs in modelling long-range dependencies. However, since the current dense self-attention uses all image patches (tokens) to compute the attention matrix, it may neglect the locality of image patches and involve noisy tokens (e.g., cluttered background and occlusion), leading to a slow training process and potential degradation of performance. To address these problems, we propose the k-NN attention for boosting vision transformers. Specifically, instead of involving all the tokens in the attention matrix calculation, we only select the top-k similar tokens from the keys for each query to compute the attention map. The proposed k-NN attention naturally inherits the local bias of CNNs without introducing convolutional operations, as nearby tokens tend to be more similar than others. In addition, the k-NN attention allows for the exploration of long-range correlation and at the same time filters out irrelevant tokens by choosing the most similar tokens from the entire image. Despite its simplicity, we verify, both theoretically and empirically, that k-NN attention is powerful in speeding up training and distilling noise from input tokens. Extensive experiments are conducted using 11 different vision transformer architectures to verify that the proposed k-NN attention can work with any existing transformer architecture to improve its prediction performance.

1. Introduction

Traditional CNNs provide state-of-the-art performance in vision tasks, due to their ability to capture locality and translation invariance, while the transformer [57] is the de-facto standard for natural language processing (NLP) tasks thanks to its advantages in modelling long-range dependencies. Recently, various vision transformers [17, 26, 44, 55, 56, 59, 63, 73-75] have been proposed by building pure or hybrid transformer models for visual tasks. Inspired by the transformer scaling success in NLP tasks, the vision transformer converts an image into a sequence of image patches (tokens), with each patch encoded into a vector. Since self-attention in the transformer is position agnostic, different positional encoding methods [12, 15, 17] have been developed, and in [9, 63] their roles have been replaced by convolutions. Afterwards, all tokens are fed into stacked transformer encoders for feature learning, with an extra CLS token [15, 17, 55] or global average pooling (GAP) [9, 44] for the final feature representation. Compared with CNNs, transformer-based models explicitly exploit global dependencies and demonstrate comparable, sometimes even better, results than highly optimised CNNs [28, 51].

Albeit achieving initial success, vision transformers suffer from slow training. One of the key culprits is the fully-connected self-attention, which takes all the tokens to calculate the attention map. The dense attention not only neglects the locality of image patches, an important feature of CNNs, but also involves noisy tokens in the computation of self-attention, especially in situations of cluttered background and occlusion. Both issues can slow down the training significantly [13, 15]. Recent works [9, 63, 73] try to mitigate this problem by introducing convolutional operators into vision transformers. Despite encouraging results, these studies fail to resolve the problem fundamentally from the transformer structure itself, limiting their success. In this study, we address the challenge by directly attacking its root cause, i.e. the fully-connected self-attention.

To this end, we propose the k-NN attention to replace the fully-connected attention. Specifically, we do not use all the tokens for the attention matrix calculation, but only select the top-k similar tokens from the sequence for each query token to compute the attention map. The proposed k-NN attention not only naturally inherits the local bias of CNNs,
as the nearby tokens tend to be more similar than others, but also builds the long-range dependency by choosing the most similar tokens from the entire image. Compared with the convolution operator, which is an aggregation operation built on the Ising model [46] where the feature of each node is aggregated from nearby pixels, in the k-NN attention the aggregation graph is no longer limited by the spatial location of nodes but is adaptively computed via attention maps; thus, the k-NN attention can be regarded as a relaxed version of local bias. Despite its simplicity, we verify, both theoretically and empirically, that k-NN attention is effective in speeding up training and distilling noisy tokens of vision transformers. Eleven different publicly available vision transformer architectures are adopted to verify the effectiveness of the proposed k-NN attention.

2. Related Work

2.1. Self-attention

Self-attention [57] has demonstrated promising results on NLP-related tasks, and is making breakthroughs in speech and computer vision. For time series modeling, self-attention operates over sequences in a step-wise manner. Specifically, at every time-step, self-attention assigns an attention weight to each previous input element and uses these weights to compute the representation of the current time-step as a weighted sum of the past inputs. Besides the vanilla self-attention, many efficient transformers [54] have been proposed. Among these efficient transformers, sparse attention and local attention are two of the main streams, which are highly related to our work. Sparse attention can be further categorized into data-independent (fixed) sparse attention [2, 10, 31, 76] and content-based sparse attention [14, 37, 48, 52]. Local attention [43-45] mainly considers attending only to a local window. Our work is also content-based attention, but compared with previous works [14, 37, 48, 52], our k-NN attention has its merits for the vision domain. For example, compared with the routing transformer [48], which clusters both queries and keys, our k-NN attention amounts to clustering only the keys by treating each query as a cluster center, making the quantization more continuous, which is a better fit for the image domain; compared with the reformer [37], which adopts complex hashing attention that cannot guarantee each bucket contains both queries and keys, our k-NN attention guarantees that each query has exactly k keys for attention computation. In addition, our k-NN attention is also a generalized local attention, but compared with local attention, our k-NN attention not only enjoys the locality but also retains the ability of global relation mining.

2.2. Transformer for Vision

Transformer [57] is an effective sequence-to-sequence modeling network, and it has achieved state-of-the-art results in NLP tasks with the success of BERT [16]. Due to its great success, it has also been exploited in the computer vision community, and 'Transformer in CNN' has become a popular paradigm [3, 5, 8, 27, 40, 41, 62, 84]. ViT [17] leads the other trend of using the 'CNN in Transformer' paradigm for vision tasks [29, 39, 67, 72]. Even though ViT has proved compelling in visual recognition, it has several drawbacks when compared with CNNs: large training data, fixed position embedding, rigid patch division, coarse modeling of inner-patch features, single scale, unstable training, slow training speed, easily fitting data and poor generalization, shallow & narrow architecture, and quadratic complexity. To deal with these problems, many variants have been proposed [18, 21, 22, 33, 35, 47, 60, 61, 69, 71, 79, 83]. For example, DeiT [55] adopts several training techniques and uses distillation to extend ViT to a data-efficient version; CPVT [12] proposes a conditional positional encoding that is adaptable to arbitrary input sizes; CvT [63], CoaT [68] and Visformer [9] safely remove the position embedding by introducing convolution operations; T2T-ViT [74], CeiT [73], and CvT [63] try to deal with the rigid patch division by introducing convolution operations for patch sequence generation; Focal Transformer [70] makes each token attend to its closest surrounding tokens at fine granularity and to the tokens far away at coarse granularity; TNT [26] proposes pixel embeddings to model inner-patch features; PVT [59], Swin Transformer [44], MViT [19], ViL [77], CvT [63], PiT [30], LeViT [24], CoaT [68], and Twins [11] adopt multi-scale techniques for rich feature learning; DeepViT [82], CaiT [56], and PatchViT [23] investigate the unstable training problem, and propose re-attention, re-scaling and anti-over-smoothing techniques, respectively, for stable training; to accelerate the convergence of training, ConViT [15], PiT [30], CeiT [73], LocalViT [42] and Visformer [9] introduce convolutional bias to speed up training; a conv-stem is adopted in LeViT [24], EarlyConv [64], CMT [25], VOLO [75] and ScaledReLU [58] to improve the robustness of training ViTs; LV-ViT [34] adopts several techniques, including MixToken and Token Labeling, for better training and feature generation; T2T-ViT [74], DeepViT [82] and CaiT [56] try to train deeper vision transformer models; T2T-ViT [74], ViL [77] and CoaT [68] adopt efficient transformers [54] to deal with the quadratic complexity. To further exploit the capacity of vision transformers, OmniNet [53], CrossViT [7] and So-ViT [65] propose dense omnidirectional representations, coarse-fine-grained patch fusion, and cross-covariance pooling of visual tokens, respectively. However, all of these works adopt the fully-connected self-attention, which brings noisy or irrelevant tokens into the computation and slows down the training of the networks. In this paper, we propose an efficient sparse attention, called k-NN attention, for boosting vision transformers. The proposed k-NN attention not only inherits the local bias of CNNs but also achieves the ability of global feature exploitation. It can also speed up training and achieve better performance.
3. k-NN Attention

3.1. Vanilla Attention

For any sequence of length n, the vanilla attention in the transformer is the dot-product attention [57]. Following the standard notation, the attention matrix A ∈ R^{n×n} is defined as

A = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right),

where Q ∈ R^{n×d} denotes the queries, K ∈ R^{n×d} denotes the keys, and d represents the dimension. By multiplying the attention weights A with the values V ∈ R^{n×d}, the new values V̂ are calculated as

\hat{V} = AV.

The intuitive understanding of the attention is a weighted average over the old values, where the weights are defined by the attention matrix A. In this paper, we consider Q, K and V to be generated via linear projections of the input token matrix X:

Q = XW_Q, \quad K = XW_K, \quad V = XW_V,

where X ∈ R^{n×d_m}, W_Q, W_K, W_V ∈ R^{d_m×d}, and d_m is the input token dimension.

One shortcoming of the fully-connected self-attention is that irrelevant tokens, even though assigned smaller weights, are still taken into consideration when updating the representation V̂, making it less resilient to noise in V. This shortcoming motivates us to develop the k-NN attention.

3.2. k-NN Attention

Instead of computing the attention matrix for all the query-key pairs as in vanilla attention, we select the top-k most similar keys and values for each query in the k-NN attention. There are two versions of k-NN attention, as described below.

Slow Version: For the i-th query, we first compute the Euclidean distance against all the keys, then obtain its k nearest neighbors N_i^k and N_i^v from the keys and values, and lastly calculate the scaled dot-product attention as

A_i = \mathrm{softmax}\!\left(\frac{\langle q_i, (k_{j_1}, \ldots, k_{j_l}, \ldots, k_{j_k})\rangle}{\sqrt{d}}\right), \quad k_{j_l} \in N_i^k.

The shape of the final attention matrix is A^{knn} ∈ R^{n×k}, and the new values V̂^{knn} have the same size as V̂. The slow version is the exact definition of k-NN attention, but it is extremely slow, because each query has to compute distances to, and gather, a different set of k keys and values.

Fast Version: As the computation of the Euclidean distance against all the keys for each query is slow, we propose a fast version of k-NN attention. The key idea is to take advantage of matrix multiplication. As in vanilla attention, all the queries and keys are multiplied via the dot product, and then the row-wise top-k elements are selected for the softmax computation. The procedure can be formulated as

\hat{V}^{knn} = \mathrm{softmax}\!\left(\mathcal{T}_k\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)\right)\cdot V,

where \mathcal{T}_k(\cdot) denotes the row-wise top-k selection operator:

[\mathcal{T}_k(A)]_{ij} = \begin{cases} A_{ij}, & A_{ij} \in \text{top-}k(\text{row } i) \\ -\infty, & \text{otherwise.} \end{cases}

The computational overhead of the fast version is negligible, and it does not increase the model size. The source code (four core lines) of the fast version of k-NN attention in PyTorch is shown in Algorithm 1 in the supplementary, and the speed comparisons between the slow and fast versions are given in Section A.2 of the supplementary.
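For concreteness, the row-wise top-k masking described above can be written as a short standalone PyTorch function. This is a minimal sketch of ours (the function name, tensor shapes, and the use of torch.topk/scatter_ are our choices); it mirrors the module given as Algorithm 1 in the supplementary and assumes already-projected q, k, v of shape (batch, heads, n, d).

import torch

def knn_attention(q, k, v, topk):
    # q, k, v: (batch, heads, n, d); topk: number of keys kept per query
    d = q.shape[-1]
    attn = (q @ k.transpose(-2, -1)) / d ** 0.5           # (batch, heads, n, n) scores
    # keep the top-k scores in each row, push the rest to -inf before the softmax
    mask = torch.zeros_like(attn)
    index = attn.topk(topk, dim=-1, largest=True).indices
    mask.scatter_(-1, index, 1.0)
    attn = attn.masked_fill(mask == 0, float('-inf'))
    attn = attn.softmax(dim=-1)                           # each row renormalises over its k entries
    return attn @ v                                       # (batch, heads, n, d)

A drop-in module would simply wrap this function around the usual qkv projection and output projection, exactly as Algorithm 1 does.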
3.3. Theoretical Analysis on k-NN Attention

In this section, we show theoretically that, despite its simplicity, k-NN attention is powerful in speeding up network training and in distilling noisy tokens. All the proofs of the lemmas are provided in the supplementary.

Convergence Speed-up. Compared to CNNs, the fully-connected self-attention is able to capture long-range dependency. However, the price to pay is that the dense self-attention model has to mix each image patch with every other patch in the image, which has the potential to mix irrelevant information together; e.g., foreground patches may be mixed with background patches through the self-attention. This defect can significantly slow down the convergence, as the goal of visual object recognition is to identify the key visual patches relevant to a given class.

To see this, we consider the model with only learnable parameters W_Q, W_K in the attention layers, trained with the Adam optimizer [36]. According to Theorem 4.1 in [36], Adam's convergence is proportional to O(α^{-1}(G_∞ + 1) + αG_∞), where α is the learning rate and G_∞ is an element-wise upper bound on the magnitude of the batch gradient.¹ Let f_i be the loss function corresponding to batch i. Via the chain rule, the gradient w.r.t. W_Q in a self-attention block can be represented as

\nabla_{W_Q} f_i = F_i(\hat{V}^{knn}) \cdot \frac{\partial \hat{V}^{knn}}{\partial W_Q},

where F_i(\hat{V}^{knn}) is a matrix-valued function.

¹ Theorem 4.1 in [36] gives an upper bound on the regret (the gap in loss value between the current parameters and the optimal parameters); one can telescope it to the average regret to reason about Adam's convergence.
Since the possible values of V̂^{knn} form a subset of those of its fully-connected counterpart, the upper bound on the magnitude of F_i(V̂^{knn}) is no larger than that of the full attention. We then introduce the weighted covariance matrix of the patches to characterize the scale of ∂V̂^{knn}/∂W_Q in the following lemma.

Lemma 1 (informal). Let V̂_l^{knn} be the l-th row of V̂^{knn}. We have

\frac{\partial \hat{V}_l^{knn}}{\partial W_Q} \propto \mathrm{Var}_{a_l}(x) \quad \text{and} \quad \frac{\partial \hat{V}_l^{knn}}{\partial W_K} \propto \mathrm{Var}_{a_l}(x),

where Var_{a_l}(x) is the covariance matrix of the patches {x_1, ..., x_n} weighted by the probabilities in the l-th row of the attention matrix. The same is true for V̂ of the fully-connected self-attention.

Since k-NN attention only uses patches with large similarity, its Var_{a_l}(x) will be smaller than that computed from the fully-connected attention. As indicated in Lemma 1, ∂V̂^{knn}/∂W_Q is proportional to the variance Var_{a_l}(x), and thus the scale of ∇_{W_Q} f_i becomes smaller in k-NN attention. Similarly, the scale of ∇_{W_K} f_i is also smaller in k-NN attention. Therefore, the element-wise upper bound on the batch gradient, G_∞, in the Adam analysis is also smaller for k-NN attention. For the same learning rate, the k-NN attention thus yields faster convergence. The effect is particularly significant at the beginning of training: due to the random initialization, we expect relatively small differences in similarities between patches, which essentially makes self-attention behave like a "global average", and it takes multiple iterations for Adam to turn this "global average" into the real function of self-attention. In Table 2 and Figure 2, we numerically verify the training efficiency of k-NN attention as opposed to the fully-connected attention.
Noisy patch distillation. As already mentioned, the fully-connected self-attention model may mix irrelevant patches with relevant ones, particularly at the beginning of training, when the similarities between relevant patches are not significantly larger than those of irrelevant patches. k-NN attention is more effective in identifying noisy patches because it only considers the top-k most similar patches. To formally justify this point, we consider a simple scenario where all the patches are divided into two groups: relevant patches and noisy patches. All patches are sampled independently from unknown distributions. We assume that all relevant patches are sampled from distributions with the same shared mean, which differs from the means of the distributions of the noisy patches. It is important to note that although the distributions of the relevant patches share the same mean, those relevant patches can still look quite different, due to the large variance in stochastic sampling. In the following lemma, we show that the k-NN attention is more effective than the fully-connected attention in distilling noise for the relevant patches.

Lemma 2 (informal). Consider the self-attention for query patch l. Assume that the patches x_i are bounded with mean µ_i for i = 1, 2, ..., n, and let ρ_k be the ratio of noisy patches among all selected patches. Under mild conditions, the following inequality holds with high probability:

\left\| \hat{V}_l^{knn} - \mu_l W_V \right\|_{\infty} \le O\!\left(k^{-1/2} + c_1 \rho_k\right),

where c_1 is a positive number.

In the above lemma, the quantity ‖V̂_l^{knn} − µ_l W_V‖_∞ measures the distance between V̂_l^{knn}, the representation vector updated by the k-NN attention, and its mean µ_l W_V. We now consider two cases: the normal k-NN attention with an appropriately chosen k, and the fully-connected attention with k = n. In the first case, with an appropriately chosen k, most of the selected patches should come from the relevant group, implying a small ρ_k. Combined with the fact that k is decently large, we expect a small upper bound on the distance ‖V̂_l^{knn} − µ_l W_V‖_∞, indicating that k-NN attention is powerful in distilling noise. For the fully-connected attention model, i.e. k = n, it is clear that ρ_n ≈ 1, leading to a large distance between the transformed representation V̂_l and its mean, indicating that the fully-connected attention model is not very effective in distilling noisy patches, particularly when the noise is large.

Besides instances with a low signal-to-noise ratio, instances with a large volume of background can also be hard. In the next lemma, we show that under a proper choice of k, with high probability the k-NN attention is able to select all meaningful patches.

Lemma 3 (informal). Let M_* be the index set containing all patches relevant to query q_l. Under mild conditions, there exists c_2 ∈ (0, 1) such that, with high probability,

\sum_{i=1}^{n} \mathbb{1}\!\left( q_l k_i^{\top} \ge \min_{j \in M_*} q_l k_j^{\top} \right) \le O\!\left(n d^{-c_2}\right).

The above lemma shows that if we select the top O(nd^{-c_2}) elements, then with high probability we are able to eliminate almost all the irrelevant noisy patches without losing any relevant patches. Numerically, we verify that a proper k yields better performance (e.g., Figure 1) and that, for hard instances, k-NN attention gives more accurate attention regions (e.g., Figure 4 and Figure 5).
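The effect described in Lemma 2 can be probed with a small simulation. The sketch below is ours, and all of its statistics (token counts, means, noise scales, identity W_Q and W_K) are arbitrary toy choices rather than anything prescribed by the paper: relevant patches share a mean, background patches share a different mean, and we compare how far the attention output for one relevant query lands from its clean target µ_l W_V under full attention and under row-wise top-k attention. In this toy setting the k-NN output typically ends up closer to the target, in line with the lemma, though the exact numbers depend entirely on the chosen statistics.

import torch

torch.manual_seed(0)
n, d_m, k = 196, 64, 100
n_rel = 120                                    # patches relevant to query l

mu_rel = 0.5 * torch.randn(d_m)                # shared mean of the relevant group
mu_bg = 0.5 * torch.randn(d_m)                 # mean of the noisy/background group
X = 0.3 * torch.randn(n, d_m)                  # sub-gaussian fluctuations around the means
X[:n_rel] += mu_rel
X[n_rel:] += mu_bg

W_v = torch.randn(d_m, d_m) / d_m ** 0.5
V = X @ W_v
target = mu_rel @ W_v                          # mu_l W_V for a relevant query

l = 0                                          # a relevant query patch
scores = (X[l] @ X.t()) / d_m ** 0.5           # identity W_Q, W_K for simplicity

full_out = scores.softmax(0) @ V               # fully-connected attention output

masked = torch.full_like(scores, float('-inf'))
idx = scores.topk(k).indices                   # row-wise top-k, as in Section 3.2
masked[idx] = scores[idx]
knn_out = masked.softmax(0) @ V                # k-NN attention output

print('||V_hat_l - mu_l W_V||_inf, full attention:', (full_out - target).abs().max().item())
print('||V_hat_l - mu_l W_V||_inf, k-NN attention:', (knn_out - target).abs().max().item())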
4. Experiments for Vision Transformers

In this section, we replace the dense attention with k-NN attention in existing vision transformers for image classification, to verify the effectiveness of the proposed method. The recent DeiT [55] and its variants, including T2T-ViT [74], TNT [26], PiT [30], Swin [44], CvT [63], So-ViT [65], Visformer [9], Twins [11], Dino [4] and VOLO [75], are adopted for evaluation. These methods include both supervised methods [9, 11, 26, 30, 44, 55, 63, 65, 74, 75] and a self-supervised method [4]. Ablation studies are provided to further analyze the properties of k-NN attention.

4.1. Experimental Settings

We perform image classification on the standard ILSVRC-2012 ImageNet dataset [49]. It contains 1.3 million images in the training set and 50K images in the validation set, covering 1000 object classes. In our experiments, we follow the experimental settings of the original officially released codes [4, 9, 26, 30, 44, 55, 63, 65, 74, 75]. For a fair comparison, we only replace the vanilla attention with the proposed k-NN attention. Unless otherwise specified, the fast version of k-NN attention is adopted for evaluation. To speed up the slow version, we develop a CUDA version of k-NN attention. As for the value of k, different architectures are assigned different values. For DeiT [55], So-ViT [65], Dino [4], CvT [63], TNT [26], PiT [30] and VOLO [75], which directly split an input image into rigid tokens with no information exchange in the token generation stage, we suppose the irrelevant tokens are easy to filter and tend to assign a smaller k than for the more complicated token generation methods [9, 11, 44, 74]. Specifically, we assign k to approximately n/2 at each scale stage; for the complicated token generation methods [9, 11, 44, 74], we assign a larger k, approximately 2n/3 or 4n/5 at each scale stage.

4.2. Results on ImageNet

Table 1 shows top-1 accuracy results on the ImageNet-1K validation set obtained by replacing the dense attention with k-NN attention in eleven different vision transformer architectures. Both ConvNets and Transformers are listed for comparison. With a budget constraint of around 5M parameters, k-NN attention brings a 0.8% improvement on DeiT-Tiny, for both the fast and the slow version; on So-ViT-7, it also improves by 0.8%; under the 10M constraint, Visformer-Tiny gains 0.4%; with a budget constraint of 40M parameters, the k-NN attention improves the performance by 0.3% on CvT-13 and DeiT-Small, 0.4% on TNT-Small, and 0.5% on T2T-ViT-t-19; on Swin-Tiny, k-NN attention still improves the performance slightly, even though local attention is already adopted in Swin transformers; k-NN also has a positive effect on Dino-Small, a self-supervised vision transformer architecture; under the 80M constraint, it obtains 0.2% and 0.6% increases in performance on Twins-SVT-Base and PiT-Base, respectively; for large input resolutions, on VOLO-D1 and VOLO-D3, k-NN attention improves by 0.2% at 384x384 and 448x448 resolutions, respectively. It is worth noting that on the ImageNet-1k dataset it is very hard to improve the accuracy beyond 85%, but our k-NN attention still consistently improves the performance, even without any increase in model size.

Figure 1. The impact of k on DeiT-Tiny, Visformer-Tiny, CvT-13 and PiT-Base. Panels: (a) DeiT-Tiny, (b) CvT-13, (c) Visformer-Tiny, (d) PiT-Base; each panel plots top-1 accuracy (%) against the per-stage value of k.

4.3. The Impact of Number k

The only parameter of k-NN attention is k, and its impact is analyzed in Figure 1. As shown in the figure, for DeiT-Tiny, k = 100 is the best, where the total number of tokens is n = 196 (14 × 14), meaning that k approximates half of n; for CvT-13, there are three scale stages with n_1 = 3136, n_2 = 784 and n_3 = 196 tokens, and the best results are achieved when k in each stage is assigned to 1600/400/100, which also approximates half of n in each stage; for Visformer-Tiny, there are two scale stages with n_1 = 196 and n_2 = 49 tokens, and the best results are achieved when k in each stage is assigned to 150/45; since there are more than 21 conv layers for token generation, the information in each token is already mixed, making it hard to distinguish the irrelevant tokens, and thus larger values of k are desired; for PiT-Base, there are three scale stages with n_1 = 961, n_2 = 256 and n_3 = 64 tokens, and the optimal values of k also approximate half of n. Please note that we do not perform an exhaustive search for the optimal choice of k; instead, a general rule as below is sufficient: k ≈ n/2 at each scale stage for simple token generation methods, and k ≈ 2n/3 or 4n/5 at each scale stage for complicated token generation methods.
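The rule of thumb stated at the end of Section 4.3 is simple enough to write down directly. The helper below is ours (its name, signature and rounding are our own choices); it merely converts the heuristic, k ≈ n/2 for rigid patch splits and k ≈ 2n/3 to 4n/5 for convolutional tokenizers, into per-stage values, so the printed numbers only approximate the hand-picked settings reported above.

def suggest_topk(stage_tokens, conv_tokenizer=False, ratio_conv=2 / 3):
    # stage_tokens   : token counts n per scale stage, e.g. [3136, 784, 196]
    # conv_tokenizer : True for models whose tokens come from many conv layers
    #                  (e.g. Visformer, Swin, T2T-ViT, Twins), False for rigid patch splits
    # ratio_conv     : 2/3 or 4/5, the larger fraction used for conv tokenizers
    ratio = ratio_conv if conv_tokenizer else 0.5
    return [max(1, round(ratio * n)) for n in stage_tokens]

print(suggest_topk([196]))                           # DeiT-Tiny -> [98]   (paper uses 100)
print(suggest_topk([3136, 784, 196]))                # CvT-13    -> [1568, 392, 98] (paper: 1600/400/100)
print(suggest_topk([196, 49], conv_tokenizer=True))  # Visformer -> [131, 33] (paper: 150/45)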
4.4. Convergence Speed of k-NN Attention

In Table 2, we investigate the convergence speed of k-NN attention. Three methods are included for comparison, i.e. DeiT-Small [55], CvT-13 [63] and T2T-ViT-t-19 [74]. From the table we can see that the convergence speed of k-NN attention is faster than that of the fully-connected attention, especially in the early stage of training. These observations reflect that removing the irrelevant tokens benefits the convergence of neural network training.
Arch.        | Model                                | Input | Params | GFLOPs | Top-1 (%)
ConvNets     | MnasNet-A3 [50]                      | 224²  | 5.2M   | 0.4    | 76.7
ConvNets     | EfficientNet-B0 [51]                 | 224²  | 5.3M   | 0.4    | 77.1
ConvNets     | ShuffleNet [78]                      | 224²  | 5.4M   | 0.5    | 73.7
ConvNets     | MobileNet [32]                       | 224²  | 6.9M   | 0.6    | 74.7
Transformers | DeiT-Tiny [55]                       | 224²  | 5.7M   | 1.3    | 72.2
Transformers | DeiT-Tiny [55] → k-NN Attn           | 224²  | 5.7M   | 1.3    | 73.0
Transformers | DeiT-Tiny [55] → k-NN Attn-slow      | 224²  | 5.7M   | 1.3    | 73.0
Transformers | So-ViT-7 [65]                        | 224²  | 5.5M   | 1.3    | 76.2
Transformers | So-ViT-7 [65] → k-NN Attn            | 224²  | 5.5M   | 1.3    | 77.0
ConvNets     | EfficientNet-B2 [51]                 | 224²  | 9M     | 1.0    | 80.1
ConvNets     | SAN10 [80]                           | 224²  | 11M    | 2.2    | 77.1
ConvNets     | ResNet-18 [28]                       | 224²  | 12M    | 1.8    | 69.8
ConvNets     | LambdaNets [1]                       | 224²  | 15M    | -      | 78.4
Transformers | Visformer-Tiny [9]                   | 224²  | 10M    | 1.3    | 78.6
Transformers | Visformer-Tiny [9] → k-NN Attn       | 224²  | 10M    | 1.3    | 79.0
ConvNets     | EfficientNet-B4 [51]                 | 224²  | 19M    | 4.2    | 82.9
ConvNets     | ResNet-50 [28]                       | 224²  | 25M    | 4.1    | 79.1
ConvNets     | ResNeXt50-32x4d [66]                 | 224²  | 25M    | 4.3    | 79.5
ConvNets     | REDNet-101 [38]                      | 224²  | 25M    | 4.7    | 79.1
ConvNets     | REDNet-152 [38]                      | 224²  | 34M    | 6.8    | 79.3
ConvNets     | ResNet-101 [28]                      | 224²  | 45M    | 7.9    | 79.9
ConvNets     | ResNeXt101-32x4d [66]                | 224²  | 45M    | 8.0    | 80.6
Transformers | CvT-13 [63]                          | 224²  | 20M    | 4.6    | 81.6
Transformers | CvT-13 [63] → k-NN Attn              | 224²  | 20M    | 4.6    | 81.9
Transformers | DeiT-Small [55]                      | 224²  | 22M    | 4.6    | 79.8
Transformers | DeiT-Small [55] → k-NN Attn          | 224²  | 22M    | 4.6    | 80.1
Transformers | TNT-Small [26]                       | 224²  | 24M    | 5.2    | 81.5
Transformers | TNT-Small [26] → k-NN Attn           | 224²  | 24M    | 5.2    | 81.9
Transformers | VOLO-D1 [75]                         | 384²  | 27M    | 22.8   | 85.2
Transformers | VOLO-D1 [75] → k-NN Attn             | 384²  | 27M    | 22.8   | 85.4
Transformers | Swin-Tiny [44]                       | 224²  | 28M    | 4.5    | 81.2
Transformers | Swin-Tiny [44] → k-NN Attn           | 224²  | 28M    | 4.5    | 81.3
Transformers | T2T-ViT-t-19 [74]                    | 224²  | 39M    | 9.8    | 82.2
Transformers | T2T-ViT-t-19 [74] → k-NN Attn        | 224²  | 39M    | 9.8    | 82.7
Transformer (Self-supervised) | Dino-Small [4]!             | 224² | 22M | 4.6 | 76.0
Transformer (Self-supervised) | Dino-Small [4]! → k-NN Attn | 224² | 22M | 4.6 | 76.2
ConvNets     | ResNet-152 [28]                      | 224²  | 60M    | 11.6   | 80.8
ConvNets     | ResNeXt101-64x4d [66]                | 224²  | 84M    | 15.6   | 81.5
Transformers | Twins-SVT-Base [11]                  | 224²  | 56M    | 8.3    | 83.2
Transformers | Twins-SVT-Base [11] → k-NN Attn      | 224²  | 56M    | 8.3    | 83.4
Transformers | PiT-Base [30]                        | 224²  | 74M    | 12.5   | 82.0
Transformers | PiT-Base [30] → k-NN Attn            | 224²  | 74M    | 12.5   | 82.6
Transformers | VOLO-D3 [75]                         | 448²  | 86M    | 67.9   | 86.3
Transformers | VOLO-D3 [75] → k-NN Attn             | 448²  | 86M    | 67.9   | 86.5

Table 1. The k-NN attention performance on the ImageNet-1K validation set. "!" means we pretrain the model for 300 epochs and finetune the pretrained model for 100 epochs for linear evaluation, following the instructions of Dino training and evaluation; "→ k-NN Attn" represents replacing the vanilla attention with the proposed k-NN attention; "→ k-NN Attn-slow" means adopting the slow version.
Epoch | DeiT-S | DeiT-S → k-NN | CvT-13 | CvT-13 → k-NN | T2T-ViT-t-19 | T2T-ViT-t-19 → k-NN
10    | 29.1%  | 31.3%         | 51.4%  | 54.2%         | 0.52%        | 0.68%
30    | 54.4%  | 55.4%         | 65.4%  | 68.1%         | 63.0%        | 63.2%
50    | 60.9%  | 62.0%         | 68.1%  | 70.5%         | 73.8%        | 74.4%
70    | 65.0%  | 65.8%         | 69.9%  | 72.2%         | 76.9%        | 77.3%
90    | 67.7%  | 68.2%         | 71.0%  | 73.0%         | 78.4%        | 78.6%
120   | 69.9%  | 70.7%         | 72.4%  | 73.7%         | 79.7%        | 80.0%
150   | 72.4%  | 72.4%         | 74.4%  | 74.9%         | 80.7%        | 80.9%
200   | 75.5%  | 75.7%         | 77.3%  | 77.7%         | 82.0%        | 82.3%
300   | 79.8%  | 80.0%         | 81.6%  | 81.9%         | 81.3%        | 81.7%

Table 2. Ablation study on the convergence speed of k-NN attention (top-1 accuracy by training epoch).

Figure 2. The properties of k-NN attention on DeiT-Tiny: (a) layer-wise cosine similarity of tokens, (b) layer-wise standard deviation of attention weights, (c) ratio of residual and main branch for the attention layer, (d) ratio of residual and main branch for the FFN layer. Blue and red dotted lines represent the metrics for k-NN attention and the original fully-connected self-attention, respectively.

Figure 3. The nonlocality of DeiT-Tiny (left) and DeiT-Tiny with k-NN attention (right), per layer over training epochs, averaged over all the images from the training set of ImageNet-1k.

4.5. Other properties of k-NN attention

To analyze other properties of k-NN attention, four quantitative metrics are defined as follows.

Layer-wise cosine similarity between tokens: following [23], this metric is defined as

\mathrm{CosSim}(t) = \frac{1}{n(n-1)} \sum_{i \neq j} \frac{t_i^{\top} t_j}{\|t_i\| \, \|t_j\|},

where t_i represents the i-th token in each layer and ‖·‖ denotes the Euclidean norm. This metric reflects the convergence speed of the network.

Layer-wise standard deviation of attention weights: given a token t_i and its softmax attention weights sfm(t_i), the standard deviation of the softmax attention weights, std(sfm(t_i)), is defined as the second metric. For multi-head attention, the standard deviations over all heads are averaged. This metric represents the degree of training stability.

Ratio between the norms of residual activations and main branch: the ratio between the norm of the residual activations and the norm of the activations of the main branch in each layer is defined as ‖f_l(t)‖/‖t‖, where f_l(t) can be the attention layer or the FFN layer. This metric reflects the information preservation ability of the network.

Nonlocality: following [15], the nonlocality is defined by summing, for each query patch i, the distances ‖δ_{ij}‖ to all the key patches j, weighted by their attention scores A_{ij}. The number obtained is averaged over the query patches to obtain the nonlocality metric of head h, which can then be averaged over the attention heads to obtain the nonlocality of the whole layer l:

D_{loc}^{l,h} := \frac{1}{L} \sum_{ij} A_{ij}^{h,l} \|\delta_{ij}\|, \qquad D_{loc}^{l} := \frac{1}{N_h} \sum_{h} D_{loc}^{l,h},

where D_loc is the number of patches between the center of attention and the query patch; the further the attention heads look from the query patch, the higher the nonlocality.
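As a sketch of how these diagnostics can be computed from a model's intermediate activations, the functions below implement the first, second and fourth metric in PyTorch. They are our own minimal versions and make assumptions about the inputs: tokens is the (n, d) token matrix of one layer, attn is that layer's (heads, n, n) post-softmax attention, and coords holds the (n, 2) patch-grid coordinates; hooking these tensors out of a particular model is left to the reader.

import torch

def cosine_similarity_of_tokens(tokens):
    # tokens: (n, d); average pairwise cosine similarity over i != j
    normed = tokens / tokens.norm(dim=-1, keepdim=True)
    sim = normed @ normed.t()                             # (n, n) cosine similarities
    n = tokens.shape[0]
    return (sim.sum() - sim.diagonal().sum()) / (n * (n - 1))

def std_of_attention_weights(attn):
    # attn: (heads, n, n); std over keys, averaged over queries and heads
    return attn.std(dim=-1).mean()

def nonlocality(attn, coords):
    # attn: (heads, n, n); coords: (n, 2) patch positions on the grid
    dist = torch.cdist(coords.float(), coords.float())    # ||delta_ij|| for all patch pairs
    per_head = (attn * dist).sum(dim=-1).mean(dim=-1)     # average over query patches
    return per_head.mean()                                # average over heads

# Toy usage with random stand-ins for one layer of a 14x14 token grid:
n, d, heads = 196, 64, 3
tokens = torch.randn(n, d)
attn = torch.rand(heads, n, n).softmax(dim=-1)
ys, xs = torch.meshgrid(torch.arange(14), torch.arange(14), indexing='ij')
coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1)
print(cosine_similarity_of_tokens(tokens).item())
print(std_of_attention_weights(attn).item())
print(nonlocality(attn, coords).item())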
Comparisons of the four metrics on DeiT-Tiny without the distillation token are shown in Figure 2 and Figure 3. From Figure 2 (a) we can see that with k-NN attention the averaged cosine similarity is larger than with dense self-attention, which reflects that the convergence speed is faster for k-NN attention. Figure 2 (b) shows that the averaged standard deviation of k-NN attention is smoother than that of the fully-connected self-attention, and this smoothness helps make the training more stable.
Figure 2 (c) and (d) show that the ratios between the norms of the residual activations and the main branch are consistent with each other for k-NN attention and dense attention, which indicates that nearly no information is lost in k-NN attention by removing the irrelevant tokens. Figure 3 shows that, with k-NN attention, lower layers tend to focus more on local areas (with more lines being pushed toward the bottom area in Figure 3), while the higher layers still maintain their capability of extracting global information. Additionally, it is also observed that the nonlocality of different layers spreads more evenly, indicating that they can explore a larger variety of dependencies at different ranges.

4.6. Comparisons with temperature in softmax

k-NN attention effectively zeros the bottom n − k tokens out of the attention calculation. How does this compare with introducing a temperature parameter into the softmax over the attention values? We compare our k-NN attention against a temperature t in the softmax, i.e. softmax(attn/t). The performance over t is shown in Table 3. From the table we can see that a small t makes the training crash due to large attention values; the performance increases slightly, to 72.5 (baseline 72.2), when t is assigned appropriate values. The k-NN attention is more robust than the temperature in softmax, and achieves much better performance: 73.0 (k-NN attention) vs. 72.5 (best performance with a temperature in softmax).

t         | 0.05  | 0.1   | 0.25 | 0.75 | dense
Top-1 (%) | crash | crash | 72.0 | 72.5 | 72.2
t         | 2     | 4     | 8    | 16   | k-NN
Top-1 (%) | 72.5  | 72.5  | 72.5 | 72.1 | 73.0

Table 3. The top-1 accuracy (%) over the temperature t in softmax, with the dense baseline and the k-NN attention result listed for reference.
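The two variants compared in Table 3 differ only in how the raw attention scores are post-processed before the softmax, which the minimal sketch below makes explicit. It is ours, not code from the paper; attn stands for the pre-softmax QK^T/sqrt(d) scores.

import torch

def temperature_softmax(attn, t):
    # dense attention with a temperature: every token still contributes
    return (attn / t).softmax(dim=-1)

def topk_softmax(attn, k):
    # k-NN attention: the bottom n-k scores per row are removed before the softmax
    index = attn.topk(k, dim=-1).indices
    masked = torch.full_like(attn, float('-inf')).scatter(-1, index, attn.gather(-1, index))
    return masked.softmax(dim=-1)

attn = torch.randn(1, 3, 196, 196)                         # (batch, heads, n, n) raw scores
print(temperature_softmax(attn, t=0.25).sum(-1).allclose(torch.ones(1, 3, 196)))
print((topk_softmax(attn, k=100) > 0).sum(-1).unique())    # exactly k nonzero weights per row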
4.7. Visualization

Figure 4 visualizes the self-attention heads from the last layer of Dino-Small [4]. We can see that different heads attend to different semantic regions of an image. Compared with dense attention, the k-NN attention filters out most of the irrelevant information from background regions that are similar to the foreground, and successfully concentrates on the most informative foreground regions. Images from different classes are visualized in Figure 5 using the Transformer Attribution method [6] on DeiT-Tiny. It can be seen that the k-NN attention is more concentrated and accurate, especially in situations of cluttered background and occlusion.

Figure 4. Self-attention heads from the last layer (columns: input and heads 0-5; rows: k-NN attention vs. dense attention).

Figure 5. Visualization using Transformer Attribution [6]: (a) dog, (b) wheel, (c) ferret, (d) monitor, (e) hornbill, (f) chain, (g) crib.

4.8. Semantic Segmentation

To verify the effect of k-NN attention on downstream tasks, the widely-used ADE20K [81] dataset is adopted for evaluation. There are in total 25K images in ADE20K, including 20K images for training, 2K images for validation, and 3K images for testing. It covers 150 common foreground categories. We adopt Swin-Tiny [44] and Twins-SVT-Base [11] for comparison due to their well-released codes, and the results are shown in Table 4. From the table we can see that by replacing the vanilla attention with our k-NN attention, the performance of semantic segmentation increases with almost no overhead.

Backbone            | Method  | mIoU
Swin-T              | UPerNet | 44.5
Swin-T-k-NN         | UPerNet | 44.7
Twins-SVT-Base      | UPerNet | 47.4
Twins-SVT-Base-k-NN | UPerNet | 47.9

Table 4. Segmentation results for Swin-Tiny and Twins-SVT-Base with/without k-NN attention on the ADE20K validation set. All the models are pretrained on ImageNet-1k.

5. Conclusion

In this paper, we propose an effective k-NN attention for boosting vision transformers. By selecting the most similar keys for each query to calculate the attention, it screens out the most ineffective tokens. The removal of
irrelevant tokens speeds up the training. We theoretically [16] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina prove its properties in speeding up training, distilling noises Toutanova. Bert: Pre-training of deep bidirectional without losing information, and increasing the performance transformers for language understanding. arXiv preprint by choosing a proper k. Several vision transformers are arXiv:1810.04805, 2018. 2 adopted to verify the effectiveness of the k-NN attention. [17] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, References Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: [1] Irwan Bello. Lambdanetworks: Modeling long-range Transformers for image recognition at scale. arXiv preprint interactions without attention. In ICLR, 2020. 6 arXiv:2010.11929, 2020. 1, 2 [2] Iz Beltagy, Matthew E Peters, and Arman Cohan. Long- [18] Alaaeldin El-Nouby, Hugo Touvron, Mathilde Caron, Piotr former: The long-document transformer. arXiv preprint Bojanowski, Matthijs Douze, Armand Joulin, Ivan Laptev, arXiv:2004.05150, 2020. 2 Natalia Neverova, Gabriel Synnaeve, Jakob Verbeek, et al. [3] Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Xcit: Cross-covariance image transformers. arXiv preprint Usunier, Alexander Kirillov, and Sergey Zagoruyko. End- arXiv:2106.09681, 2021. 2 to-end object detection with transformers. In ECCV, 2020. [19] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao 2 Li, Zhicheng Yan, Jitendra Malik, and Christoph Feicht- [4] Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, enhofer. Multiscale vision transformers. arXiv preprint Julien Mairal, Piotr Bojanowski, and Armand Joulin. arXiv:2104.11227, 2021. 2 Emerging properties in self-supervised vision transformers. [20] Jianqing Fan and Jinchi Lv. Sure independence screening arXiv preprint arXiv:2104.14294, 2021. 4, 5, 6, 8 for ultrahigh dimensional feature space. Journal of the [5] Shuning Chang, Pichao Wang, Fan Wang, Hao Li, and Royal Statistical Society: Series B (Statistical Methodology), Jiashi Feng. Augmented transformer with adaptive graph 70(5):849–911, 2008. 19, 20, 21 for temporal action proposal generation. arXiv preprint [21] Jiemin Fang, Lingxi Xie, Xinggang Wang, Xiaopeng Zhang, arXiv:2103.16024, 2021. 2 Wenyu Liu, and Qi Tian. Msg-transformer: Exchanging local [6] Hila Chefer, Shir Gur, and Lior Wolf. Transformer spatial information by manipulating messenger tokens. arXiv interpretability beyond attention visualization. In CVPR, 2021. preprint arXiv:2105.15168, 2021. 2 8 [22] Peng Gao, Jiasen Lu, Hongsheng Li, Roozbeh Mottaghi, [7] Chun-Fu Chen, Quanfu Fan, and Rameswar Panda. Crossvit: and Aniruddha Kembhavi. Container: Context aggregation Cross-attention multi-scale vision transformer for image network. arXiv preprint arXiv:2106.01401, 2021. 2 classification. arXiv preprint arXiv:2103.14899, 2021. 2 [23] Chengyue Gong, Dilin Wang, Meng Li, Vikas Chandra, and [8] Xin Chen, Bin Yan, Jiawen Zhu, Dong Wang, Xiaoyun Yang, Qiang Liu. Improve vision transformers training by suppress- and Huchuan Lu. Transformer tracking. In CVPR, 2021. 2 ing over-smoothing. arXiv preprint arXiv:2104.12753, 2021. [9] Zhengsu Chen, Lingxi Xie, Jianwei Niu, Xuefeng Liu, 2, 7 Longhui Wei, and Qi Tian. Visformer: The vision-friendly [24] Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre transformer. arXiv preprint arXiv:2104.12533, 2021. 
1, 2, 4, Stock, Armand Joulin, Hervé Jégou, and Matthijs Douze. 5, 6 Levit: a vision transformer in convnet’s clothing for faster [10] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. inference. arXiv preprint arXiv:2104.01136, 2021. 2 Generating long sequences with sparse transformers. arXiv [25] Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, preprint arXiv:1904.10509, 2019. 2 Chunjing Xu, and Yunhe Wang. Cmt: Convolutional [11] Xiangxiang Chu, Zhi Tian, Yuqing Wang, Bo Zhang, Haibing neural networks meet vision transformers. arXiv preprint Ren, Xiaolin Wei, Huaxia Xia, and Chunhua Shen. Twins: arXiv:2107.06263, 2021. 2 Revisiting spatial attention design in vision transformers. [26] Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, arXiv preprint arXiv:2104.13840, 2021. 2, 4, 5, 6, 8 and Yunhe Wang. Transformer in transformer. arXiv preprint [12] Xiangxiang Chu, Bo Zhang, Zhi Tian, Xiaolin Wei, and arXiv:2103.00112, 2021. 1, 2, 4, 5, 6 Huaxia Xia. Do we really need explicit position encodings [27] Liang Han, Pichao Wang, Zhaozheng Yin, Fan Wang, and for vision transformers? arXiv preprint arXiv:2102.10882, Hao Li. Exploiting better feature aggregation for video object 2021. 1, 2 detection. In ACM MM, 2020. 2 [13] Jean-Baptiste Cordonnier, Andreas Loukas, and Martin Jaggi. [28] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. On the relationship between self-attention and convolutional Deep residual learning for image recognition. In CVPR, 2016. layers. In ICLR, 2019. 1 1, 6 [14] Gonçalo M Correia, Vlad Niculae, and André FT Martins. [29] Shuting He, Hao Luo, Pichao Wang, Fan Wang, Hao Li, Adaptively sparse transformers. In EMNLP, pages 2174– and Wei Jiang. Transreid: Transformer-based object re- 2184, 2019. 2 identification. arXiv preprint arXiv:2102.04378, 2021. 2 [15] Stéphane d’Ascoli, Hugo Touvron, Matthew Leavitt, Ari [30] Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Morcos, Giulio Biroli, and Levent Sagun. Convit: Improving Chun, Junsuk Choe, and Seong Joon Oh. Rethinking vision transformers with soft convolutional inductive biases. spatial dimensions of vision transformers. arXiv preprint arXiv preprint arXiv:2103.10697, 2021. 1, 2, 7 arXiv:2103.16302, 2021. 2, 4, 5, 6 9
[31] Jonathan Ho, Nal Kalchbrenner, Dirk Weissenborn, and Tim [47] Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Salimans. Axial attention in multidimensional transformers. Zhou, and Cho-Jui Hsieh. Dynamicvit: Efficient vision arXiv preprint arXiv:1912.12180, 2019. 2 transformers with dynamic token sparsification. arXiv [32] Andrew G Howard, Menglong Zhu, Bo Chen, Dmitry preprint arXiv:2106.02034, 2021. 2 Kalenichenko, Weijun Wang, Tobias Weyand, Marco [48] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Andreetto, and Hartwig Adam. Mobilenets: Efficient Grangier. Efficient content-based sparse attention with convolutional neural networks for mobile vision applications. routing transformers. Transactions of the Association for arXiv preprint arXiv:1704.04861, 2017. 6 Computational Linguistics, 9:53–68, 2021. 2 [33] Zilong Huang, Youcheng Ben, Guozhong Luo, Pei Cheng, [49] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Gang Yu, and Bin Fu. Shuffle transformer: Rethinking Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej spatial shuffle for vision transformer. arXiv preprint Karpathy, Aditya Khosla, Michael Bernstein, et al. Imagenet arXiv:2106.03650, 2021. 2 large scale visual recognition challenge. IJCV, 115(3):211– [34] Zihang Jiang, Qibin Hou, Li Yuan, Daquan Zhou, Xiaojie 252, 2015. 5 Jin, Anran Wang, and Jiashi Feng. Token labeling: [50] Mingxing Tan, Bo Chen, Ruoming Pang, Vijay Vasudevan, Training a 85.5% top-1 accuracy vision transformer with 56m Mark Sandler, Andrew Howard, and Quoc V Le. Mnasnet: parameters on imagenet. arXiv preprint arXiv:2104.10858, Platform-aware neural architecture search for mobile. In 2021. 2 CVPR, 2019. 6 [35] Aditya Jonnalagadda, William Wang, and Miguel P Eckstein. [51] Mingxing Tan and Quoc Le. Efficientnet: Rethinking model Foveater: Foveated transformer for image classification. arXiv scaling for convolutional neural networks. In ICML, 2019. 1, preprint arXiv:2105.14173, 2021. 2 6 [36] Diederik P Kingma and Jimmy Ba. Adam: A method for [52] Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng stochastic optimization. arXiv preprint arXiv:1412.6980, Juan. Sparse sinkhorn attention. In ICML, 2020. 2 2014. 3 [53] Yi Tay, Mostafa Dehghani, Vamsi Aribandi, Jai Gupta, Philip [37] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Pham, Zhen Qin, Dara Bahri, Da-Cheng Juan, and Donald Reformer: The efficient transformer. arXiv preprint Metzler. Omninet: Omnidirectional representations from arXiv:2001.04451, 2020. 2 transformers. arXiv preprint arXiv:2103.01075, 2021. 2 [38] Duo Li, Jie Hu, Changhu Wang, Xiangtai Li, Qi She, Lei [54] Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Zhu, Tong Zhang, and Qifeng Chen. Involution: Inverting Metzler. Efficient transformers: A survey. arXiv preprint the inherence of convolution for visual recognition. arXiv arXiv:2009.06732, 2020. 2 preprint arXiv:2103.06255, 2021. 6 [55] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco [39] Wenhao Li, Hong Liu, Runwei Ding, Mengyuan Liu, and Massa, Alexandre Sablayrolles, and Hervé Jégou. Training Pichao Wang. Lifting transformer for 3d human pose data-efficient image transformers & distillation through estimation in video. arXiv preprint arXiv:2103.14304, 2021. attention. arXiv preprint arXiv:2012.12877, 2020. 1, 2, 4, 5, 2 6 [40] Xiangyu Li, Yonghong Hou, Pichao Wang, Zhimin Gao, [56] Hugo Touvron, Matthieu Cord, Alexandre Sablayrolles, Mingliang Xu, and Wanqing Li. Transformer guided Gabriel Synnaeve, and Hervé Jégou. 
Going deeper with geometry model for flow-based unsupervised visual odometry. image transformers. arXiv preprint arXiv:2103.17239, 2021. Neural Computing and Applications, pages 1–12, 2021. 2 1, 2 [41] Xiangyu Li, Yonghong Hou, Pichao Wang, Zhimin Gao, [57] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Mingliang Xu, and Wanqing Li. Trear: Transformer- Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and based rgb-d egocentric action recognition. arXiv preprint Illia Polosukhin. Attention is all you need. In NIPS, 2017. 1, arXiv:2101.03904, 2021. 2 2, 3 [42] Yawei Li, Kai Zhang, Jiezhang Cao, Radu Timofte, and Luc [58] Pichao Wang, Xue Wang, Hao Luo, Jingkai Zhou, Zhipeng Van Gool. Localvit: Bringing locality to vision transformers. Zhou, Fan Wang, Hao Li, and Rong Jin. Scaled relu arXiv preprint arXiv:2104.05707, 2021. 2 matters for training vision transformers. arXiv preprint [43] Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, arXiv:2109.03810, 2021. 2 Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating [59] Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao wikipedia by summarizing long sequences. In ICLR, 2018. 2 Song, Ding Liang, Tong Lu, Ping Luo, and Ling Shao. [44] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Pyramid vision transformer: A versatile backbone for Zhang, Stephen Lin, and Baining Guo. Swin transformer: dense prediction without convolutions. arXiv preprint Hierarchical vision transformer using shifted windows. arXiv arXiv:2102.12122, 2021. 1, 2 preprint arXiv:2103.14030, 2021. 1, 2, 4, 5, 6, 8 [60] Wenxiao Wang, Lu Yao, Long Chen, Deng Cai, Xiaofei [45] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. He, and Wei Liu. Crossformer: A versatile vision Effective approaches to attention-based neural machine transformer based on cross-scale attention. arXiv preprint translation. In EMNLP, 2015. 2 arXiv:2108.00154, 2021. 2 [46] Sunil Pai. Convolutional neural networks arise from ising [61] Yulin Wang, Rui Huang, Shiji Song, Zeyi Huang, and Gao models and restricted boltzmann machines. 2 Huang. Not all images are worth 16x16 words: Dynamic 10
vision transformers with adaptive sequence length. arXiv bird: Transformers for longer sequences. arXiv preprint preprint arXiv:2105.15075, 2021. 2 arXiv:2007.14062, 2020. 2 [62] Yuqing Wang, Zhaoliang Xu, Xinlong Wang, Chunhua Shen, [77] Pengchuan Zhang, Xiyang Dai, Jianwei Yang, Bin Xiao, Baoshan Cheng, Hao Shen, and Huaxia Xia. End-to-end Lu Yuan, Lei Zhang, and Jianfeng Gao. Multi-scale vision video instance segmentation with transformers. In CVPR, longformer: A new vision transformer for high-resolution 2021. 2 image encoding. arXiv preprint arXiv:2103.15358, 2021. 2 [63] Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang [78] Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, and Jian Sun. Dai, Lu Yuan, and Lei Zhang. Cvt: Introducing convolutions Shufflenet: An extremely efficient convolutional neural to vision transformers. arXiv preprint arXiv:2103.15808, network for mobile devices. In CVPR, 2018. 6 2021. 1, 2, 4, 5, 6 [79] Zizhao Zhang, Han Zhang, Long Zhao, Ting Chen, and Tomas [64] Tete Xiao, Mannat Singh, Eric Mintun, Trevor Darrell, Pfister. Aggregating nested transformers. arXiv preprint Piotr Dollár, and Ross Girshick. Early convolutions help arXiv:2105.12723, 2021. 2 transformers see better. arXiv preprint arXiv:2106.14881, [80] Hengshuang Zhao, Jiaya Jia, and Vladlen Koltun. Exploring 2021. 2 self-attention for image recognition. In CVPR, 2020. 6 [65] Jiangtao Xie, Ruiren Zeng, Qilong Wang, Ziqi Zhou, and [81] Bolei Zhou, Hang Zhao, Xavier Puig, Tete Xiao, Sanja Peihua Li. So-vit: Mind visual tokens for vision transformer. Fidler, Adela Barriuso, and Antonio Torralba. Semantic arXiv preprint arXiv:2104.10935, 2021. 2, 4, 5, 6 understanding of scenes through the ade20k dataset. IJCV, [66] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, and 127(3):302–321, 2019. 8 Kaiming He. Aggregated residual transformations for deep [82] Daquan Zhou, Bingyi Kang, Xiaojie Jin, Linjie Yang, neural networks. In CVPR, 2017. 6 Xiaochen Lian, Qibin Hou, and Jiashi Feng. Deepvit: [67] Tongkun Xu, Weihua Chen, Pichao Wang, Fan Wang, Hao Towards deeper vision transformer. arXiv preprint Li, and Rong Jin. Cdtrans: Cross-domain transformer arXiv:2103.11886, 2021. 2 for unsupervised domain adaptation. arXiv preprint [83] Daquan Zhou, Yujun Shi, Bingyi Kang, Weihao Yu, Zihang arXiv:2109.06165, 2021. 2 Jiang, Yuan Li, Xiaojie Jin, Qibin Hou, and Jiashi Feng. [68] Weijian Xu, Yifan Xu, Tyler Chang, and Zhuowen Tu. Co- Refiner: Refining self-attention for vision transformers. arXiv scale conv-attentional image transformers. arXiv preprint preprint arXiv:2106.03714, 2021. 2 arXiv:2104.06399, 2021. 2 [84] Cheng Zou, Bohan Wang, Yue Hu, Junqi Liu, Qian Wu, Yu [69] Yifan Xu, Zhijie Zhang, Mengdan Zhang, Kekai Sheng, Ke Zhao, Boxun Li, Chenguang Zhang, Chi Zhang, Yichen Wei, Li, Weiming Dong, Liqing Zhang, Changsheng Xu, and Xing et al. End-to-end human object interaction detection with hoi Sun. Evo-vit: Slow-fast token evolution for dynamic vision transformer. In CVPR, 2021. 2 transformer. arXiv preprint arXiv:2108.01390, 2021. 2 [70] Jianwei Yang, Chunyuan Li, Pengchuan Zhang, Xiyang Dai, Bin Xiao, Lu Yuan, and Jianfeng Gao. Focal self-attention for local-global interactions in vision transformers. arXiv preprint arXiv:2107.00641, 2021. 2 [71] Qihang Yu, Yingda Xia, Yutong Bai, Yongyi Lu, Alan Yuille, and Wei Shen. Glance-and-gaze vision transformer. arXiv preprint arXiv:2106.02277, 2021. 2 [72] Zitong Yu, Xiaobai Li, Pichao Wang, and Guoying Zhao. 
Transrppg: Remote photoplethysmography transformer for 3d mask face presentation attack detection. arXiv preprint arXiv:2104.07419, 2021. 2 [73] Kun Yuan, Shaopeng Guo, Ziwei Liu, Aojun Zhou, Fengwei Yu, and Wei Wu. Incorporating convolution designs into visual transformers. arXiv preprint arXiv:2103.11816, 2021. 1, 2 [74] Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Francis EH Tay, Jiashi Feng, and Shuicheng Yan. Tokens- to-token vit: Training vision transformers from scratch on imagenet. arXiv preprint arXiv:2101.11986, 2021. 1, 2, 4, 5, 6 [75] Li Yuan, Qibin Hou, Zihang Jiang, Jiashi Feng, and Shuicheng Yan. Volo: Vision outlooker for visual recognition. arXiv preprint arXiv:2106.13112, 2021. 1, 2, 5, 6 [76] Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big 11
A. Appendix

A.1. Source code of the fast version of k-NN attention in PyTorch

The source code of the fast version of k-NN attention in PyTorch is shown in Algorithm 1. The core of the fast version consists of only four lines, and it can easily be dropped into any architecture that uses fully-connected attention.

Algorithm 1. Code of the fast version of k-NN attention in PyTorch.

import torch
import torch.nn as nn

class kNNAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., topk=100):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.topk = topk

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                   # B, H, N, C
        attn = (q @ k.transpose(-2, -1)) * self.scale      # B, H, N, N
        # the core code block: keep the row-wise top-k scores, mask out the rest
        mask = torch.zeros(B, self.num_heads, N, N, device=x.device, requires_grad=False)
        index = torch.topk(attn, k=self.topk, dim=-1, largest=True)[1]
        mask.scatter_(-1, index, 1.)
        attn = torch.where(mask > 0, attn, torch.full_like(attn, float('-inf')))
        # end of the core code block
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

A.2. Comparisons between the slow version and the fast version

We develop two versions of k-NN attention, a slow version and a fast version. The k-NN attention is exactly defined by the slow version, but its speed is extremely low, as for each query it needs to select a different set of k keys and values, and this procedure is very slow. To speed it up, we developed a CUDA version, but its speed is still lower than that of the fast version. The fast version takes advantage of matrix multiplication and greatly speeds up the computation. The speed comparisons on DeiT-Tiny using 8 V100 GPUs are given in Table 5.

method                  | time per iteration (seconds)
slow version (PyTorch)  | 8192
slow version (CUDA)     | 1.55
fast version (PyTorch)  | 0.45

Table 5. Speed comparisons on DeiT-Tiny for the slow and fast versions.

A.3. Proof

Notations. Throughout this appendix, we denote x_i as the i-th element of vector x, W_{ij} as the element at the i-th row and j-th column of matrix W, and W_j as the j-th row of matrix W. Moreover, we denote x_i as the i-th patch (token) of the inputs, with x_i = X_i.
Proof for Lemma 1 We first give the formal statement of Lemma 1. Lemma 4 (Formal statement of PnLemma 1). Let V̂l knn be the l-th row of the V̂ knn and Varal (x) = Eal [x> x] − Eal [x ]Eal [x] with Eal [x] = t=1 alt xt . Then for any i, j = 1, 2, ..., n, we have > ∂ V̂l > = xli WK,j Varal (x)WV ∝ Varal (x) ∂WQ,ij and ∂ V̂l > = xli WQ,j Varal (x)WV ∝ Varal (x). ∂WK,ij The same is true for V̂ of the fully-connected self-attention. Proof. Let’s first consider the derivative of V̂l over WQ,ij . Via some algebraic computation, we have n n ! ∂ V̂l ∂(al V ) X ∂Tlknn (t) X ∂Tlknn (k1 ) = = alt − alk1 xt WV , (1) ∂WQ,ij ∂WQ,ij t=1 ∂WQ,ij ∂WQ,ij k1 =1 where we denote Tlknn (k) as follow for shorthand: ( > > knn xl WQ WK xk1 , if patch k1 is selected in row l Tl (k1 ) = −∞, otherwise . Let denote set S = {i : patch i is selected in row l} and then we consider the right-hand-side of (1). n > > > > ! X ∂ xi WQ WK xk X ∂ xi WQ WK xk1 (1) = alt − alk1 xt WV ∂WQ,ij ∂WQ,ij t∈S k1 ∈S n ! X X = alt x1i xt WK,j − alk1 xli xk1 WK,j xt WV t∈S k1 ∈S Xn n X X = alt xli xt WK,j xt WV − alt xt WV · alk1 xli xk1 WK,j . (2) t∈S t∈S k1 ∈S | {z } | {z } | {z } (a) (b) (c) P Since al is the l-th row of the attention matrix, we have alt ≥ 0 and t alt = 1. It is possible to treat terms (a), (b) and (c) as the expectation of some quantities over t replicates with probability alt . Then (2) can be further simplified as (2) = Eal [xli xWK,j · xWV ] − Eal [xWK,j ] · Eal [xli xWV ] Eal [WK,j > x> · xWV ] − Eal [WK,j > x> ] · Eal [xWV ] = xli > Eal [x> x] − Eal [x> ] · Eal [xt ] WV = xli WK,j > = xli WK,j Varal (x)WV , (3) where the second equality uses the fact that xt WK,j is a scalar. Combing (1)-(3), we have ∂ V̂l > = xli WK,j Varal (x)WV ∝ Varal (x). (4) ∂WQ,ij Due the symmetric on Q and K, we can follow the similar procedure to show ∂ V̂l > = xli WQ,j Varal (x)WV ∝ Varal (x). (5) ∂WK,ij Finally, by setting k = n, one may verify that equations (4) and (5) also hold for fully-connected self-attention. 13
Proof for Lemma 2 Before given the formal statement of the Lemma 2, we first show the assumptions. Assumption 2 1. The token xi is the sub-gaussian random vector with mean µi and variance (σ 2 /d)I for i = 1, 2, ..., n. 2. µ follows a discrete distribution with finite values µ ∈ V. Moreover, there exist 0 < ν1 , 0 < ν2 < ν4 such that a) T > > kµi k = ν1 , and b) µi WQ WK µi ∈ [ν2 , ν4 ] for all i and |µi WQ WK µj | ≤ ν2 for all µi 6= µj ∈ V. > (ij) > (ij) 3. WV and WQ WK are element-wise bounded with ν5 and ν6 respectively, that is, |WV | ≤ ν5 and |(WQ WK ) | ≤ ν6 , for all i, j from 1 to d. In Assumption 2 we ensure that for a given query patch, the difference between the clustering center and noises are large enough to be distinguished. Lemma 5 (formal statement of Lemma 2). Let patch xi be σ 2 -subgaussian random variable with √ mean µi and there are k1 patches out of all k patches follows the same clustering center of query l. Per Assumption 2, when d ≥ 3(ψ(δ, d) + ν2 + ν4 ), then with probability 1 − 5δ, we have Pk √1 xl WQ W > xi s i=1 exp d kxi WV ψ(δ, d) 2 2d − µl W V ≤ 4 exp √ σν5 log Pk 1 > d dk δ j=1 exp √ xl WQ W xj d K ∞ ν2 − ν4 + ψ(δ, d) ν2 − ν4 + ψ(δ, d) k1 + 8 exp √ − 7 + exp √ kµ1 WV k∞ , d d k q where ψ(δ, d) = 2σν1 ν6 2 log 1δ + 2σ 2 ν6 log dδ . Proof. Without loss of generality, we assume the first k patch are the top-k selected patches. From Assumption 2.1, we can decompose xi = µi + hi , i = 1, 2, ..., k, where hi is the sub-gaussian random vector with zero mean. We then analyze the numerator part. k X 1 > exp √ xl WQ Wk xi xi WV i=1 d (a) (b) z }| { z }| { k k X 1 > > X 1 > > = exp √ µl WQ WK µi µi Wv + exp √ xl WQ WK xi hi Wv i=1 d i=1 d (c) z }| { k X 1 1 + exp √ xl WQ Wk> xi − exp √ µl WQ WK > > µi µi Wv . (6) i=1 d d Below we will bound (a), (b) and (c) separately. Upper bound for (a). Let denote index set S1 = {i : µ1 = µi , i = 1, 2, ..., k}. We then have X 1 > > (a) − exp √ x1 WQ WK xi µ1 WV i∈S1 d ∞ 1 > > ≤(k − |S1 |) max exp √ x1 WQ WK xi kµ1 WV k∞ i d ν2 ≤(k − k1 ) exp √ kµ1 WV k∞ , (7) d 14
You can also read