KVT: k-NN Attention for Boosting Vision Transformers


Pichao Wang*, Xue Wang*, Fan Wang, Min Lin, Shuning Chang, Hao Li, Rong Jin
Alibaba Group
{pichao.wang,xue.wang}@alibaba-inc.com
{fan.w, ming.l, shuning.csn, lihao.lh, jinrong.jr}@alibaba-inc.com
* The first two authors contributed equally.
arXiv:2106.00515v2 [cs.CV] 12 Jan 2022

Abstract

Convolutional Neural Networks (CNNs) have dominated computer vision for years, due to their ability in capturing locality and translation invariance. Recently, many vision transformer architectures have been proposed and they show promising performance. A key component in vision transformers is the fully-connected self-attention, which is more powerful than CNNs in modelling long-range dependencies. However, since the current dense self-attention uses all image patches (tokens) to compute the attention matrix, it may neglect the locality of image patches and involve noisy tokens (e.g., cluttered background and occlusion), leading to a slow training process and potential degradation of performance. To address these problems, we propose the k-NN attention for boosting vision transformers. Specifically, instead of involving all the tokens for attention matrix calculation, we only select the top-k similar tokens from the keys for each query to compute the attention map. The proposed k-NN attention naturally inherits the local bias of CNNs without introducing convolutional operations, as nearby tokens tend to be more similar than others. In addition, the k-NN attention allows for the exploration of long-range correlation and at the same time filters out irrelevant tokens by choosing the most similar tokens from the entire image. Despite its simplicity, we verify, both theoretically and empirically, that k-NN attention is powerful in speeding up training and distilling noise from input tokens. Extensive experiments are conducted using 11 different vision transformer architectures to verify that the proposed k-NN attention can work with any existing transformer architecture to improve its prediction performance.

1. Introduction

Traditional CNNs provide state-of-the-art performance in vision tasks, due to their ability in capturing locality and translation invariance, while the transformer [57] is the de-facto standard for natural language processing (NLP) tasks thanks to its advantages in modelling long-range dependencies. Recently, various vision transformers [17, 26, 44, 55, 56, 59, 63, 73-75] have been proposed by building pure or hybrid transformer models for visual tasks. Inspired by the transformer scaling success in NLP tasks, a vision transformer converts an image into a sequence of image patches (tokens), with each patch encoded into a vector. Since self-attention in the transformer is position agnostic, different positional encoding methods [12, 15, 17] have been developed, and in [9, 63] their roles have been replaced by convolutions. Afterwards, all tokens are fed into stacked transformer encoders for feature learning, with an extra CLS token [15, 17, 55] or global average pooling (GAP) [9, 44] for the final feature representation. Compared with CNNs, transformer-based models explicitly exploit global dependencies and demonstrate comparable, sometimes even better, results than highly optimised CNNs [28, 51].

Albeit achieving initial success, vision transformers suffer from slow training. One of the key culprits is the fully-connected self-attention, which takes all the tokens to calculate the attention map. The dense attention not only neglects the locality of image patches, an important feature of CNNs, but also involves noisy tokens in the computation of self-attention, especially in situations of cluttered background and occlusion. Both issues can slow down the training significantly [13, 15]. Recent works [9, 63, 73] try to mitigate this problem by introducing convolutional operators into vision transformers. Despite encouraging results, these studies fail to resolve the problem fundamentally from the transformer structure itself, limiting their success. In this study, we address the challenge by directly attacking its root cause, i.e., the fully-connected self-attention.

To this end, we propose the k-NN attention to replace the fully-connected attention. Specifically, we do not use all the tokens for attention matrix calculation, but only select the top-k similar tokens from the sequence for each query token to compute the attention map. The proposed k-NN attention not only naturally inherits the local bias of CNNs
as the nearby tokens tend to be more similar than others, but also builds long-range dependency by choosing the most similar tokens from the entire image. Compared with the convolution operator, which is an aggregation operation built on the Ising model [46] where the feature of each node is aggregated from nearby pixels, in the k-NN attention the aggregation graph is no longer limited by the spatial location of nodes but is adaptively computed via attention maps; thus, the k-NN attention can be regarded as a relaxed version of local bias. Despite its simplicity, we verify, both theoretically and empirically, that k-NN attention is effective in speeding up training and distilling noisy tokens of vision transformers. Eleven different available vision transformer architectures are adopted to verify the effectiveness of the proposed k-NN attention.

2. Related Work

2.1. Self-attention

Self-attention [57] has demonstrated promising results on NLP related tasks, and is making breakthroughs in speech and computer vision. For time-series modeling, self-attention operates over sequences in a step-wise manner. Specifically, at every time-step, self-attention assigns an attention weight to each previous input element and uses these weights to compute the representation of the current time-step as a weighted sum of the past inputs. Besides the vanilla self-attention, many efficient transformers [54] have been proposed. Among these efficient transformers, sparse attention and local attention form one of the main streams, which are highly related to our work. Sparse attention can be further categorized into data-independent (fixed) sparse attention [2, 10, 31, 76] and content-based sparse attention [14, 37, 48, 52]. Local attention [43-45] mainly considers attending only to a local window. Our work is also content-based attention, but compared with previous works [14, 37, 48, 52], our k-NN attention has its merits for the vision domain. For example, compared with the routing transformer [48], which clusters both queries and keys, our k-NN attention amounts to clustering only the keys by assigning each query as the cluster center, making the quantization more continuous, which better fits the image domain; compared with the reformer [37], which adopts complex hashing attention that cannot guarantee each bucket contains both queries and keys, our k-NN attention can guarantee that each query has exactly k keys for attention computation. In addition, our k-NN attention is also a generalized local attention, but compared with local attention, it not only enjoys the locality but also retains the ability of global relation mining.

2.2. Transformer for Vision

The transformer [57] is an effective sequence-to-sequence modeling network, and it has achieved state-of-the-art results in NLP tasks with the success of BERT [16]. Due to its great success, it has also been exploited in the computer vision community, and 'Transformer in CNN' has become a popular paradigm [3, 5, 8, 27, 40, 41, 62, 84]. ViT [17] leads the other trend of using the 'CNN in Transformer' paradigm for vision tasks [29, 39, 67, 72]. Even though ViT has proved compelling in visual recognition, it has several drawbacks when compared with CNNs: large training data, fixed position embedding, rigid patch division, coarse modeling of inner-patch features, single scale, unstable training process, slow training speed, easy overfitting with poor generalization, shallow and narrow architecture, and quadratic complexity. To deal with these problems, many variants have been proposed [18, 21, 22, 33, 35, 47, 60, 61, 69, 71, 79, 83]. For example, DeiT [55] adopts several training techniques and uses distillation to extend ViT to a data-efficient version; CPVT [12] proposes a conditional positional encoding that is adaptable to arbitrary input sizes; CvT [63], CoaT [68] and Visformer [9] safely remove the position embedding by introducing convolution operations; T2T-ViT [74], CeiT [73], and CvT [63] try to deal with the rigid patch division by introducing convolution operations for patch sequence generation; Focal Transformer [70] makes each token attend to its closest surrounding tokens at fine granularity and to the tokens far away at coarse granularity; TNT [26] proposes the pixel embedding to model inner-patch features; PVT [59], Swin Transformer [44], MViT [19], ViL [77], CvT [63], PiT [30], LeViT [24], CoaT [68], and Twins [11] adopt multi-scale techniques for rich feature learning; DeepViT [82], CaiT [56], and PatchViT [23] investigate the unstable training problem, and propose the re-attention, re-scale and anti-over-smoothing techniques respectively for stable training; to accelerate the convergence of training, ConViT [15], PiT [30], CeiT [73], LocalViT [42] and Visformer [9] introduce convolutional bias to speed up the training; a conv-stem is adopted in LeViT [24], EarlyConv [64], CMT [25], VOLO [75] and ScaledReLU [58] to improve the robustness of training ViTs; LV-ViT [34] adopts several techniques including MixToken and Token Labeling for better training and feature generation; T2T-ViT [74], DeepViT [82] and CaiT [56] try to train deeper vision transformer models; T2T-ViT [74], ViL [77] and CoaT [68] adopt efficient transformers [54] to deal with the quadratic complexity; to further exploit the capacity of vision transformers, OmniNet [53], CrossViT [7] and So-ViT [65] propose dense omnidirectional representations, coarse-fine-grained patch fusion, and cross-covariance pooling of visual tokens, respectively. However, all of these works adopt the fully-connected self-attention, which brings noisy
or irrelevant tokens into the computation and slows down the training of the networks. In this paper, we propose an efficient sparse attention, called k-NN attention, for boosting vision transformers. The proposed k-NN attention not only inherits the local bias of CNNs but also achieves the ability of global feature exploitation. It can also speed up the training and achieve better performance.

3. k-NN Attention

3.1. Vanilla Attention

For any sequence of length n, the vanilla attention in the transformer is the dot-product attention [57]. Following the standard notation, the attention matrix $A \in \mathbb{R}^{n \times n}$ is defined as:
$$A = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right),$$
where $Q \in \mathbb{R}^{n \times d}$ denotes the queries, $K \in \mathbb{R}^{n \times d}$ denotes the keys, and $d$ represents the dimension. By multiplying the attention weights $A$ with the values $V \in \mathbb{R}^{n \times d}$, the new values $\hat{V}$ are calculated as:
$$\hat{V} = AV.$$
The intuitive understanding of the attention is a weighted average over the old values, where the weights are defined by the attention matrix $A$. In this paper, we consider $Q$, $K$ and $V$ to be generated via linear projections of the input token matrix $X$:
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V,$$
where $X \in \mathbb{R}^{n \times d_m}$, $W_Q, W_K, W_V \in \mathbb{R}^{d_m \times d}$, and $d_m$ is the input token dimension.

One shortcoming of fully-connected self-attention is that irrelevant tokens, even though assigned smaller weights, are still taken into consideration when updating the representation $V$, making it less resilient to noise in $V$. This shortcoming motivates us to develop the k-NN attention.
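The following is a minimal PyTorch sketch of this vanilla attention (not the authors' code; single head, no batching, for illustration only):

```python
import torch

def vanilla_attention(X, W_Q, W_K, W_V):
    """Dot-product self-attention of Sec. 3.1.
    X: (n, d_m) input token matrix; W_Q, W_K, W_V: (d_m, d) projections."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V            # queries, keys, values, each (n, d)
    d = Q.shape[-1]
    A = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)  # (n, n) attention matrix
    return A @ V                                   # new values V_hat, (n, d)
```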
3.2. k-NN Attention

Instead of computing the attention matrix for all the query-key pairs as in vanilla attention, we select the top-k most similar keys and values for each query in the k-NN attention. There are two versions of k-NN attention, as described below.

Slow Version: For the i-th query, we first compute the Euclidean distance against all the keys, then obtain its k nearest neighbors $\mathcal{N}_i^k$ and $\mathcal{N}_i^v$ from the keys and values, and lastly calculate the scaled dot-product attention as:
$$A_i = \mathrm{softmax}\left(\frac{\langle q_i, (k_{j_1}, \ldots, k_{j_l}, \ldots, k_{j_k}) \rangle}{\sqrt{d}}\right), \quad k_{j_l} \in \mathcal{N}_i^k.$$
The shape of the final attention matrix is $A^{knn} \in \mathbb{R}^{n \times k}$, and the new values $\hat{V}^{knn}$ have the same size as the values $\hat{V}$. The slow version is the exact definition of k-NN attention, but it is extremely slow because each query has to compute distances and gather its own, different set of k keys.

Fast Version: As the computation of the Euclidean distance against all the keys for each query is slow, we propose a fast version of k-NN attention. The key idea is to take advantage of matrix multiplication operations. As in vanilla attention, all the queries and keys are multiplied via the dot product, and then the row-wise top-k elements are selected for the softmax computation. The procedure can be formulated as:
$$\hat{V}^{knn} = \mathrm{softmax}\left(\mathcal{T}_k\!\left(\frac{QK^\top}{\sqrt{d}}\right)\right) \cdot V,$$
where $\mathcal{T}_k(\cdot)$ denotes the row-wise top-k selection operator:
$$[\mathcal{T}_k(A)]_{ij} = \begin{cases} A_{ij} & A_{ij} \in \text{top-}k(\text{row } i) \\ -\infty & \text{otherwise.} \end{cases}$$
The computational overhead of the fast version is negligible, and it does not increase the model size. The source code (four lines) of the fast version of k-NN attention in PyTorch is shown in Algorithm 1 in the supplementary, and the speed comparison between the slow and fast versions is illustrated in Section A.2 of the supplementary.
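Since the official four-line implementation lives in the supplementary, here is a hedged sketch of how the fast version can be realized with a row-wise top-k mask in PyTorch (a single-head module for brevity; the names and masking strategy are our own, not necessarily the authors' exact code):

```python
import torch
import torch.nn as nn

class KNNAttention(nn.Module):
    """Minimal single-head k-NN attention: keep only the top-k scores per query."""
    def __init__(self, dim, k):
        super().__init__()
        self.k = k
        self.scale = dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x):                                   # x: (batch, n, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.scale       # (batch, n, n) scores
        # Row-wise top-k selection T_k: entries below the k-th largest score
        # in each row are set to -inf, so softmax assigns them zero weight.
        kth = attn.topk(self.k, dim=-1).values[..., -1:]    # k-th largest score per row
        attn = attn.masked_fill(attn < kth, float('-inf'))
        attn = attn.softmax(dim=-1)
        return attn @ v                                     # (batch, n, dim)
```

In the paper this operator simply replaces the vanilla attention inside each existing architecture, with k chosen per scale stage as described in Sec. 4.1.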
3.3. Theoretical Analysis on k-NN Attention

In this section, we show theoretically that, despite its simplicity, k-NN attention is powerful in speeding up network training and in distilling noisy tokens. All the proofs of the lemmas are provided in the supplementary.

Convergence speed-up. Compared to CNNs, the fully-connected self-attention is able to capture long-range dependency. However, the price to pay is that the dense self-attention model has to mix each image patch with every other patch in the image, which has the potential to mix irrelevant information together; e.g., foreground patches may be mixed with background patches through the self-attention. This defect can significantly slow down convergence, as the goal of visual object recognition is to identify the key visual patches relevant to a given class.

To see this, we consider the model with only learnable parameters $W_Q$, $W_K$ in the attention layers, trained with the Adam optimizer [36]. According to Theorem 4.1 in [36], Adam's convergence is proportional to $O\!\left(\alpha^{-1}(G_\infty + 1) + \alpha G_\infty\right)$, where $\alpha$ is the learning rate and $G_\infty$ is an element-wise upper bound on the magnitude of the batch gradient.¹ (¹ Theorem 4.1 in [36] gives an upper bound on the regret, i.e., the gap in loss value between the current-step parameters and the optimal parameters; one can telescope it to the average regret to obtain Adam's convergence.) Let $f_i$ be the loss function corresponding to batch $i$. Via the chain rule, the gradient w.r.t. $W_Q$ in a self-attention block can be represented as $\nabla_{W_Q} f_i = F_i(\hat{V}^{knn}) \cdot \frac{\partial \hat{V}^{knn}}{\partial W_Q}$, where
$F_i(\hat{V}^{knn})$ is a matrix-valued function. Since the possible values of $\hat{V}^{knn}$ form a subset of those of its fully-connected counterpart, the upper bound on the magnitude of $F_i(\hat{V}^{knn})$ is no larger than that of full attention. We then introduce the weighted covariance matrix of patches to characterize the scale of $\frac{\partial \hat{V}^{knn}}{\partial W_Q}$ in the following lemma.

Lemma 1 (informal). Let $\hat{V}_l^{knn}$ be the $l$-th row of $\hat{V}^{knn}$. We have
$$\frac{\partial \hat{V}_l^{knn}}{\partial W_Q} \propto \mathrm{Var}_{a_l}(x) \quad \text{and} \quad \frac{\partial \hat{V}_l^{knn}}{\partial W_K} \propto \mathrm{Var}_{a_l}(x),$$
where $\mathrm{Var}_{a_l}(x)$ is the covariance matrix of the patches $\{x_1, \ldots, x_n\}$ weighted by the probabilities in the $l$-th row of the attention matrix. The same holds for $\hat{V}$ of the fully-connected self-attention.
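The paper states Lemma 1 informally; assuming the standard definition of an attention-weighted covariance (our reading, not spelled out in this section), $\mathrm{Var}_{a_l}(x)$ would be

$$\mathrm{Var}_{a_l}(x) = \sum_{i=1}^{n} a_{l,i}\,(x_i - \bar{x}_l)(x_i - \bar{x}_l)^\top, \qquad \bar{x}_l = \sum_{i=1}^{n} a_{l,i}\, x_i,$$

where $a_{l,i}$ is the $i$-th entry of the $l$-th row of the attention matrix. Under this reading, restricting the row to its top-k entries drops the low-similarity patches, which echoes the paper's argument below that the variance, and hence the gradient scale, shrinks.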
                                                                       k = n. In the first case, with appropriately chosen k, we
Since k-NN attention only uses patches with large similarity, its $\mathrm{Var}_{a_l}(x)$ will be smaller than that computed from the fully-connected attention. As indicated in Lemma 1, $\frac{\partial \hat{V}^{knn}}{\partial W_Q}$ is proportional to the variance $\mathrm{Var}_{a_l}(x)$, and thus the scale of $\nabla_{W_Q} f_i$ becomes smaller in k-NN attention. Similarly, the scale of $\nabla_{W_K} f_i$ is also smaller in k-NN attention. Therefore, the element-wise upper bound on the batch gradient $G_\infty$ in the Adam analysis is also smaller for k-NN attention. For the same learning rate, the k-NN attention yields faster convergence. This is particularly significant at the beginning of training: due to the random initialization, we expect relatively small differences in similarities between patches, which essentially makes self-attention behave like a "global average". It takes multiple iterations for Adam to turn this "global average" into the real function of self-attention. In Table 2 and Figure 2, we numerically verify the training efficiency of k-NN attention as opposed to the fully-connected attention.

Noisy patch distillation. As mentioned before, the fully-connected self-attention model may mix irrelevant patches with relevant ones, particularly at the beginning of training, when similarities between relevant patches are not significantly larger than those between irrelevant patches. k-NN attention is more effective in identifying noisy patches because it only considers the top-k most similar patches. To formally justify this point, we consider a simple scenario where all the patches are divided into two groups, the group of relevant patches and the group of noisy patches. All the patches are sampled independently from unknown distributions. We assume that all relevant patches are sampled from distributions with the same shared mean, which is different from the means of the distributions for noisy patches. It is important to note that although the distributions for the relevant patches share the mean, those relevant patches can look quite different, due to the large variance in stochastic sampling. In the following lemma, we show that the k-NN attention is more effective than the fully-connected attention in distilling noise for the relevant patches.

Lemma 2 (informal). Consider the self-attention for query patch $l$. Assume the patches $x_i$ are bounded with mean $\mu_i$ for $i = 1, 2, \ldots, n$, and let $\rho_k$ be the ratio of noisy patches among all selected patches. Under mild conditions, the following inequality holds with high probability:
$$\left\| \hat{V}_l^{knn} - \mu_l W_V \right\|_\infty \leq O(k^{-1/2} + c_1 \rho_k),$$
where $c_1$ is a positive number.

In the above lemma, the quantity $\| \hat{V}_l^{knn} - \mu_l W_V \|_\infty$ measures the distance between $\hat{V}_l^{knn}$, the representation vector updated by the k-NN attention, and its mean $\mu_l W_V$. We now consider two cases: the normal k-NN attention with appropriately chosen $k$, and fully-connected attention with $k = n$. In the first case, with appropriately chosen $k$, most of the selected patches should come from the relevant group, implying a small $\rho_k$. Combined with the fact that $k$ is decently large, we expect a small upper bound on the distance $\| \hat{V}_l^{knn} - \mu_l W_V \|_\infty$, indicating that k-NN attention is powerful in distilling noise. For the fully-connected attention model, i.e. $k = n$, it is clear that $\rho_n \approx 1$, leading to a large distance between the transformed representation $\hat{V}_l$ and its mean, indicating that the fully-connected attention model is not very effective in distilling noisy patches, particularly when the noise is large.

Besides instances with a low signal-to-noise ratio, instances with a large volume of background can also be hard. In the next lemma, we show that under a proper choice of $k$, with high probability the k-NN attention is able to select all meaningful patches.

Lemma 3 (informal). Let $\mathcal{M}^*$ be the index set containing all patches relevant to query $q_l$. Under mild conditions, there exists $c_2 \in (0, 1)$ such that, with high probability,
$$\sum_{i=1}^{n} \mathbb{1}\!\left(q_l k_i^\top \geq \min_{j \in \mathcal{M}^*} q_l k_j^\top\right) \leq O(n d^{-c_2}).$$

The above lemma shows that if we select the top $O(n d^{-c_2})$ elements, then with high probability we will eliminate almost all of the irrelevant noisy patches without losing any relevant patches. Numerically, we verify that a proper $k$ gains better performance (e.g., Figure 1) and that for hard instances k-NN attention gives more accurate attention regions (e.g., Figure 4 and Figure 5).

4. Experiments for Vision Transformers

In this section, we replace the dense attention with k-NN attention in existing vision transformers for image classification to verify the effectiveness of the proposed method. The recent DeiT [55] and its variants, including T2T-ViT [74], TNT [26], PiT [30], Swin [44], CvT [63], So-ViT [65], Visformer [9], Twins [11], Dino [4] and
VOLO [75], are adopted for evaluation. These methods include both supervised methods [9, 11, 26, 30, 44, 55, 63, 65, 74, 75] and a self-supervised method [4]. Ablation studies are provided to further analyze the properties of k-NN attention.

4.1. Experimental Settings

We perform image classification on the standard ILSVRC-2012 ImageNet dataset [49]. It contains 1.3 million images in the training set and 50K images in the validation set, covering 1000 object classes. In our experiments, we follow the experimental settings of the original officially released codes [4, 9, 26, 30, 44, 55, 63, 65, 74, 75]. For a fair comparison, we only replace the vanilla attention with the proposed k-NN attention. Unless otherwise specified, the fast version of k-NN attention is adopted for evaluation. To speed up the slow version, we develop a CUDA version of k-NN attention. As for the value of k, different architectures are assigned different values. For DeiT [55], So-ViT [65], Dino [4], CvT [63], TNT [26], PiT [30] and VOLO [75], which directly split an input image into rigid tokens so that there is no information exchange in the token generation stage, we suppose the irrelevant tokens are easy to filter, and tend to assign a smaller k compared with the more complicated token generation methods [9, 11, 44, 74]. Specifically, we assign k to approximately n/2 at each scale stage; for the complicated token generation methods [9, 11, 44, 74], we assign a larger k, approximately 2n/3 or 4n/5 at each scale stage.
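As a concrete illustration of this rule of thumb, the snippet below sketches how one might pick k per stage (a hypothetical helper, not from the paper's code; the exact values the authors use are listed in Sec. 4.3 and Figure 1):

```python
import math

def choose_k(num_tokens: int, mixed_tokens: bool = False) -> int:
    """Rule of thumb from the paper: k ~ n/2 for rigid patch splitting,
    k ~ 2n/3 to 4n/5 when conv layers have already mixed token information."""
    ratio = 2.0 / 3.0 if mixed_tokens else 0.5
    return max(1, math.ceil(num_tokens * ratio))

# e.g. a DeiT-style stage with n = 196 tokens -> k = 98 (the paper uses k = 100)
print(choose_k(196), choose_k(196, mixed_tokens=True))
```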
4.2. Results on ImageNet

Table 1 shows top-1 accuracy results on the ImageNet-1K validation set obtained by replacing the dense attention with k-NN attention in eleven different vision transformer architectures. Both ConvNets and Transformers are listed for comparison. With a budget constraint of around 5M parameters, k-NN attention reports a 0.8% improvement on DeiT-Tiny, for both the fast version and the slow version; on So-ViT-7, it also improves by 0.8%; under the 10M constraint, Visformer-Tiny gains 0.4%; with a budget constraint of 40M parameters, the k-NN attention improves the performance by 0.3% on CvT-13 and DeiT-Small, 0.4% on TNT-Small, and 0.5% on T2T-ViT-t-19; on Swin-Tiny, k-NN attention still improves the performance slightly even though local attention is already adopted in Swin transformers; k-NN attention also has a positive effect on Dino-Small, a self-supervised vision transformer architecture; under the 80M constraint, it gets 0.2% and 0.6% increases in performance on Twins-SVT-Base and PiT-Base, respectively; for large input resolutions, on VOLO-D1 and VOLO-D3, k-NN attention improves by 0.2% at 384x384 and 448x448 resolutions. It is worth noting that on the ImageNet-1k dataset it is very hard to improve the accuracy beyond 85%, but our k-NN attention can still consistently improve the performance without any increase in model size.

[Figure 1. The impact of k on DeiT-Tiny, Visformer-Tiny, CvT-13 and PiT-Base. Panels (a)-(d) plot top-1 accuracy (%) against the per-stage choice of k for each architecture.]

4.3. The Impact of Number k

The only parameter of k-NN attention is k, and its impact is analyzed in Figure 1. As shown in the figure, for DeiT-Tiny, k = 100 is the best, where the total number of tokens is n = 196 (14 x 14), meaning that k approximates half of n; for CvT-13, there are three scale stages with numbers of tokens n1 = 3136, n2 = 784 and n3 = 196, and the best results are achieved when the k in each stage is assigned to 1600/400/100, which also approximates half of n in each stage; for Visformer-Tiny, there are two scale stages with numbers of tokens n1 = 196 and n2 = 49, and the best results are achieved when k in each stage is assigned to 150/45, as there are more than 21 conv layers for token generation and the information in each token is already mixed, making it hard to distinguish the irrelevant tokens, so larger values of k are desired; for PiT-Base, there are three scale stages with numbers of tokens n1 = 961, n2 = 256 and n3 = 64, and the optimal values of k also approximate half of n. Please note that we do not perform an exhaustive search for the optimal choice of k; instead, a general rule is sufficient: k ≈ n/2 at each scale stage for simple token generation methods, and k ≈ 2n/3 or 4n/5 at each scale stage for complicated token generation methods.

4.4. Convergence Speed of k-NN Attention

In Table 2, we investigate the convergence speed of k-NN attention. Three methods are included for comparison, i.e. DeiT-Small [55], CvT-13 [63] and T2T-ViT-t-19 [74]. From the table we can see that the convergence speed of k-NN attention is faster than that of fully-connected attention, especially in the early stages of training. These observations reflect that removing the irrelevant tokens benefits the convergence of neural network training.
Arch.              Model                                Input   Params   GFLOPs   Top-1 (%)
ConvNets           MnasNet-A3 [50]                      224²    5.2M     0.4      76.7%
                   EfficientNet-B0 [51]                 224²    5.3M     0.4      77.1%
                   ShuffleNet [78]                      224²    5.4M     0.5      73.7%
                   MobileNet [32]                       224²    6.9M     0.6      74.7%
Transformers       DeiT-Tiny [55]                       224²    5.7M     1.3      72.2%
                   DeiT-Tiny [55] → k-NN Attn           224²    5.7M     1.3      73.0%
                   DeiT-Tiny [55] → k-NN Attn-slow      224²    5.7M     1.3      73.0%
                   So-ViT-7 [65]                        224²    5.5M     1.3      76.2%
                   So-ViT-7 [65] → k-NN Attn            224²    5.5M     1.3      77.0%
ConvNets           EfficientNet-B2 [51]                 224²    9M       1.0      80.1%
                   SAN10 [80]                           224²    11M      2.2      77.1%
                   ResNet-18 [28]                       224²    12M      1.8      69.8%
                   LambdaNets [1]                       224²    15M      -        78.4%
Transformers       Visformer-Tiny [9]                   224²    10M      1.3      78.6%
                   Visformer-Tiny [9] → k-NN Attn       224²    10M      1.3      79.0%
ConvNets           EfficientNet-B4 [51]                 224²    19M      4.2      82.9%
                   ResNet-50 [28]                       224²    25M      4.1      79.1%
                   ResNeXt50-32x4d [66]                 224²    25M      4.3      79.5%
                   REDNet-101 [38]                      224²    25M      4.7      79.1%
                   REDNet-152 [38]                      224²    34M      6.8      79.3%
                   ResNet-101 [28]                      224²    45M      7.9      79.9%
                   ResNeXt101-32x4d [66]                224²    45M      8.0      80.6%
Transformers       CvT-13 [63]                          224²    20M      4.6      81.6%
                   CvT-13 [63] → k-NN Attn              224²    20M      4.6      81.9%
                   DeiT-Small [55]                      224²    22M      4.6      79.8%
                   DeiT-Small [55] → k-NN Attn          224²    22M      4.6      80.1%
                   TNT-Small [26]                       224²    24M      5.2      81.5%
                   TNT-Small [26] → k-NN Attn           224²    24M      5.2      81.9%
                   VOLO-D1 [75]                         384²    27M      22.8     85.2%
                   VOLO-D1 [75] → k-NN Attn             384²    27M      22.8     85.4%
                   Swin-Tiny [44]                       224²    28M      4.5      81.2%
                   Swin-Tiny [44] → k-NN Attn           224²    28M      4.5      81.3%
                   T2T-ViT-t-19 [74]                    224²    39M      9.8      82.2%
                   T2T-ViT-t-19 [74] → k-NN Attn        224²    39M      9.8      82.7%
Transformer        Dino-Small [4]!                      224²    22M      4.6      76.0%
(Self-supervised)  Dino-Small [4]! → k-NN Attn          224²    22M      4.6      76.2%
ConvNets           ResNet-152 [28]                      224²    60M      11.6     80.8%
                   ResNeXt101-64x4d [66]                224²    84M      15.6     81.5%
Transformers       Twins-SVT-Base [11]                  224²    56M      8.3      83.2%
                   Twins-SVT-Base [11] → k-NN Attn      224²    56M      8.3      83.4%
                   PiT-Base [30]                        224²    74M      12.5     82.0%
                   PiT-Base [30] → k-NN Attn            224²    74M      12.5     82.6%
                   VOLO-D3 [75]                         448²    86M      67.9     86.3%
                   VOLO-D3 [75] → k-NN Attn             448²    86M      67.9     86.5%

Table 1. The k-NN attention performance on the ImageNet-1K validation set. "!" means we pretrain the model for 300 epochs and finetune the pretrained model for 100 epochs for linear evaluation, following the instructions for Dino training and evaluation; "→ k-NN Attn" denotes replacing the vanilla attention with the proposed k-NN attention; "→ k-NN Attn-slow" denotes adopting the slow version.

Epoch   DeiT-S   DeiT-S → k-NN   CvT-13   CvT-13 → k-NN   T2T-ViT-t-19   T2T-ViT-t-19 → k-NN
 10     29.1%    31.3%           51.4%    54.2%            0.52%          0.68%
 30     54.4%    55.4%           65.4%    68.1%           63.0%          63.2%
 50     60.9%    62.0%           68.1%    70.5%           73.8%          74.4%
 70     65.0%    65.8%           69.9%    72.2%           76.9%          77.3%
 90     67.7%    68.2%           71.0%    73.0%           78.4%          78.6%
120     69.9%    70.7%           72.4%    73.7%           79.7%          80.0%
150     72.4%    72.4%           74.4%    74.9%           80.7%          80.9%
200     75.5%    75.7%           77.3%    77.7%           82.0%          82.3%
300     79.8%    80.0%           81.6%    81.9%           81.3%          81.7%

Table 2. Ablation study on the convergence speed of k-NN attention: top-1 accuracy (%) at different training epochs.

[Figure: layer-wise statistics of DeiT-Tiny vs. DeiT-Tiny with k-NN attention, plotted against layer depth: (a) layer-wise cosine similarity of tokens (avg. cosine similarity), (b) layer-wise standard deviation of attention weights (avg. standard deviation), together with the residual/main-branch ratios of the attention and FFN blocks.]

Layer-wise cosine similarity between tokens: following [23], this metric is defined as
$$\mathrm{CosSim}(t) = \frac{1}{n(n-1)} \sum_{i \neq j} \frac{t_i^\top t_j}{\|t_i\| \, \|t_j\|},$$
where $t_i$ represents the $i$-th token in each layer and $\|\cdot\|$ denotes the Euclidean norm. This metric reflects the convergence speed of the network.

Layer-wise standard deviation of attention weights: given a token $t_i$ and its softmax attention weights $\mathrm{sfm}(t_i)$, the standard deviation of the softmax attention weights, $\mathrm{std}(\mathrm{sfm}(t_i))$, is defined as the second metric. For multi-head attention, the standard deviations over all heads are averaged. This metric represents the degree of training stability.
                                                                               8            10                     12                                                                  2               4           6
                                                                                                                                                                                                               Layer depth
                                                                                                                                                                                                                                8               10               12
                                                                                                                                                                                                                                                                                       Ratio between the norms of residual activations and
                                   (c) Ratio of residual and                                                                                             (d) Ratio of residual and                                                                                                 main branch: The ratio between the norm of the residual
                                     main branch for attn                                                                                                   main branch for ffn                                                                                                    activations and the norm of the activations of the main branch
                                                                                                                                                                                                                                                                                   in each layer is defined as kfl (t)k/ktk, where fl (t) can be
 Figure 2. The properties of k-NN attention. Blue and red dotted                                                                                                                                                                                                                   the attention layer or the FFN layer. This metric denotes the
 lines represent the metrics for k-NN attetion and the original fully-                                                                                                                                                                                                             information preservation ability of the network.
 connected self-attention, respectively.                                                                                                                                                                                                                                               Nonlocality: following [15], the nonlocality is defined
                                                                                                                                                                                                                                                                                   by summing, for each query patch i, the distances kδij k to
                                                           DeiT-Tiny                                                    DeiT-Tiny with k-NN attention
                                                                                                                                                                                                                                                                                   all the key patches j weighted by their attention score Aij .
                                                                                                                                                                                                                                                                 Layer   1
                               7                                                                                   7                                                                                                                                             Layer   2         The number obtained over the query patch is averaged to
                                                                                                                                                                                                                                                                 Layer   3
                               6                                                                                   6                                                                                                                                                               obtain the nonlocality metric of head h, which can the be
                                                                                                    Non-locality
Non-locality

                                                                                                                                                                                                                                                                 Layer   4
                                                                                                                                                                                                                                                                 Layer   5
                               5                                                                                   5                                                                                                                                             Layer   5         averaged over the attention heads to obtain the nonlocality
                                                                                                                                                                                                                                                                 Layer   7
                               4                                                                                   4                                                                                                                                             Layer   8         of the whole layer l:
                                                                                                                                                                                                                                                                 Layer   9
                               3                                                                                                                                                                                                                                 Layer   10                l,h     1 X h,l              l       1 X l,h
                                                                                                                   3
                                                                                                                                                                                                                                                                 Layer   11              Dloc  :=        Aij kδij k , Dloc :=          Dloc ,
                               2                                                                                                                                                                                                                                 Layer   12                        L ij                        Nh
                                     0   50          100      150 200          250         300                          0                                         50 100 150 200 250 300                                                                                                                                            h
                                                             Epochs                                                                                                     Epochs                                                                                                     where Dloc is the number of patches between the center of
 Figure 3. The nonlocality of DeiT-Tiny. It is plotted averaged over
                                                                                                                                                                                                                                                                                   attention and the query patch; the further the attention heads
 all the images from training set of ImageNet-1k.                                                                                                                                                                                                                                  look from the query patch, the higher the nonlocality.
                                                                                                                                                                                                                                                                                      Comparisons of the four metrics on DeiT-tiny without
                                                                                                                                                                                                                                                                                   distillation token are shown in Figure 2 and Figure 3. From
 4.5. Other properties of k-NN attention                                                                                                                                                                                                                                           Figure 2 (a) we can see that by using k-NN attention,
                                                                                                                                                                                                                                                                                   the averaged cosine similarity is larger than that of using
   To analyze other properties of k-NN attention, four                                                                                                                                                                                                                             dense self-attention, which reflects that the convergence
 quantitative metrics are defined as follows.                                                                                                                                                                                                                                      speed is faster for k-NN attention. Figure 2 (b) shows
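For concreteness, the first two metrics can be computed per layer with a few lines of PyTorch; the sketch below is illustrative (the tensor shapes and function names are our own assumptions, not part of the released implementation).

import torch

def layerwise_cosine_similarity(tokens):
    # tokens: (n, d) token embeddings of one layer; returns CosSim(t) over all pairs i != j
    n = tokens.shape[0]
    normed = torch.nn.functional.normalize(tokens, dim=-1)   # t_i / ||t_i||
    sims = normed @ normed.T                                  # pairwise cosine similarities
    off_diag = sims.sum() - sims.diagonal().sum()             # drop the i == j terms
    return off_diag / (n * (n - 1))

def attention_weight_std(attn):
    # attn: (heads, n, n) softmax attention weights; std over keys, averaged over heads and queries
    return attn.std(dim=-1).mean()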

Comparisons of the four metrics on DeiT-Tiny without the distillation token are shown in Figure 2 and Figure 3. From Figure 2 (a) we can see that, by using k-NN attention, the averaged cosine similarity is larger than that of dense self-attention, which reflects that the convergence speed is faster for k-NN attention. Figure 2 (b) shows that the averaged standard deviation of k-NN attention is smoother than that of fully-connected self-attention, and this smoothness helps make the training more stable. Figure 2 (c) and (d) show that the ratios between the norms of residual activations and the main branch are consistent with each other for k-NN attention and dense attention, which indicates that there is nearly no information lost in k-NN attention by removing the irrelevant tokens. Figure 3 shows that, with k-NN attention, lower layers tend to focus more on local areas (with more lines pushed toward the bottom of Figure 3), while the higher layers still maintain their capability of extracting global information. Additionally, it is also observed that the non-locality of different layers is spread more evenly, indicating that they can explore a larger variety of dependencies at different ranges.

4.6. Comparisons with temperature in softmax

k-NN attention effectively zeros the bottom N − k tokens out of the attention calculation. How does this compare with introducing a temperature parameter into the softmax over the attention values? We compare our k-NN attention with a temperature t in the softmax, i.e., softmax(attn/t). The performance over t is shown in Table 3. From the table we can see that a small t makes the training crash due to the resulting large attention values; the performance increases slightly to 72.5 (baseline 72.2) when t is assigned appropriate values. The k-NN attention is more robust than the temperature in softmax, and achieves much better performance: 73.0 (k-NN attention) vs. 72.5 (best performance for temperature in softmax).

t          0.05    0.1    0.25   0.75
Top-1 (%)  crash   crash  72.0   72.5
t          2       4      8      16
Top-1 (%)  72.5    72.5   72.5   72.1

Table 3. The top-1 (%) over the temperature t in softmax.
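As a point of reference, the two alternatives compared here differ only in how the raw attention logits are post-processed before the softmax. The sketch below is illustrative (single-tensor form, our own function names); the top-k masking mirrors the fast k-NN attention code in Algorithm 1 of the appendix.

import torch

def softmax_with_temperature(attn, t):
    # baseline of Table 3: softmax(attn / t); a small t sharpens the distribution
    return torch.softmax(attn / t, dim=-1)

def knn_softmax(attn, topk):
    # k-NN attention: keep only the top-k logits per query and mask the rest to -inf
    index = torch.topk(attn, k=topk, dim=-1).indices
    mask = torch.full_like(attn, float('-inf')).scatter_(-1, index, 0.0)
    return torch.softmax(attn + mask, dim=-1)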
4.7. Visualization

Figure 4 visualizes the self-attention heads from the last layer of Dino-Small [4]. We can see that different heads attend to different semantic regions of an image. Compared with dense attention, the k-NN attention filters out most irrelevant information from background regions that are similar to the foreground, and successfully concentrates on the most informative foreground regions. Images from different classes are visualized in Figure 5 using the Transformer Attribution method [6] on DeiT-Tiny. It can be seen that the k-NN attention is more concentrated and accurate, especially in situations of cluttered background and occlusion.

Figure 4. Self-attention heads from the last layer (rows: dense attention, k-NN attention; columns: input, head0–head5).

Figure 5. Visualization using Transformer Attribution [6] (rows: input, dense attn, k-NN attn): (a) dog, (b) wheel, (c) ferret, (d) monitor, (e) hornbill, (f) chain, (g) crib.

4.8. Semantic Segmentation

To verify the effects of k-NN attention on downstream tasks, the widely-used ADE20K [81] is adopted for evaluation. There are 25K images in total in ADE20K, including 20K images for training, 2K images for validation, and 3K images for testing. It covers 150 different common foreground categories. We adopt Swin-Tiny [44] and Twins-SVT-Base [11] for comparison because their codes are well released, and the results are shown in Table 4. From the table we can see that, by replacing the vanilla attention with our k-NN attention, the performance of semantic segmentation increases with almost no overhead.

Backbone              Method    mIoU
Swin-T                UPerNet   44.5
Swin-T-k-NN           UPerNet   44.7
Twins-SVT-Base        UPerNet   47.4
Twins-SVT-Base-k-NN   UPerNet   47.9

Table 4. Segmentation results for Swin-Tiny and Twins-SVT-Base with/without k-NN attention on the ADE20K validation set. All the models are pretrained on ImageNet-1k.

5. Conclusion

In this paper, we propose an effective k-NN attention for boosting vision transformers. By selecting the most similar keys for each query to calculate the attention, it screens out the most ineffective tokens.
The removal of irrelevant tokens speeds up the training. We theoretically prove its properties in speeding up training, distilling noises without losing information, and increasing the performance by choosing a proper k. Several vision transformers are adopted to verify the effectiveness of the k-NN attention.
References

[1] Irwan Bello. Lambdanetworks: Modeling long-range interactions without attention. In ICLR, 2020. 6
[2] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020. 2
[3] Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko. End-to-end object detection with transformers. In ECCV, 2020. 2
[4] Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, and Armand Joulin. Emerging properties in self-supervised vision transformers. arXiv preprint arXiv:2104.14294, 2021. 4, 5, 6, 8
[5] Shuning Chang, Pichao Wang, Fan Wang, Hao Li, and Jiashi Feng. Augmented transformer with adaptive graph for temporal action proposal generation. arXiv preprint arXiv:2103.16024, 2021. 2
[6] Hila Chefer, Shir Gur, and Lior Wolf. Transformer interpretability beyond attention visualization. In CVPR, 2021. 8
[7] Chun-Fu Chen, Quanfu Fan, and Rameswar Panda. Crossvit: Cross-attention multi-scale vision transformer for image classification. arXiv preprint arXiv:2103.14899, 2021. 2
[8] Xin Chen, Bin Yan, Jiawen Zhu, Dong Wang, Xiaoyun Yang, and Huchuan Lu. Transformer tracking. In CVPR, 2021. 2
[9] Zhengsu Chen, Lingxi Xie, Jianwei Niu, Xuefeng Liu, Longhui Wei, and Qi Tian. Visformer: The vision-friendly transformer. arXiv preprint arXiv:2104.12533, 2021. 1, 2, 4, 5, 6
[10] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019. 2
[11] Xiangxiang Chu, Zhi Tian, Yuqing Wang, Bo Zhang, Haibing Ren, Xiaolin Wei, Huaxia Xia, and Chunhua Shen. Twins: Revisiting spatial attention design in vision transformers. arXiv preprint arXiv:2104.13840, 2021. 2, 4, 5, 6, 8
[12] Xiangxiang Chu, Bo Zhang, Zhi Tian, Xiaolin Wei, and Huaxia Xia. Do we really need explicit position encodings for vision transformers? arXiv preprint arXiv:2102.10882, 2021. 1, 2
[13] Jean-Baptiste Cordonnier, Andreas Loukas, and Martin Jaggi. On the relationship between self-attention and convolutional layers. In ICLR, 2019. 1
[14] Gonçalo M Correia, Vlad Niculae, and André FT Martins. Adaptively sparse transformers. In EMNLP, pages 2174–2184, 2019. 2
[15] Stéphane d'Ascoli, Hugo Touvron, Matthew Leavitt, Ari Morcos, Giulio Biroli, and Levent Sagun. Convit: Improving vision transformers with soft convolutional inductive biases. arXiv preprint arXiv:2103.10697, 2021. 1, 2, 7
[16] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018. 2
[17] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020. 1, 2
[18] Alaaeldin El-Nouby, Hugo Touvron, Mathilde Caron, Piotr Bojanowski, Matthijs Douze, Armand Joulin, Ivan Laptev, Natalia Neverova, Gabriel Synnaeve, Jakob Verbeek, et al. Xcit: Cross-covariance image transformers. arXiv preprint arXiv:2106.09681, 2021. 2
[19] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, and Christoph Feichtenhofer. Multiscale vision transformers. arXiv preprint arXiv:2104.11227, 2021. 2
[20] Jianqing Fan and Jinchi Lv. Sure independence screening for ultrahigh dimensional feature space. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 70(5):849–911, 2008. 19, 20, 21
[21] Jiemin Fang, Lingxi Xie, Xinggang Wang, Xiaopeng Zhang, Wenyu Liu, and Qi Tian. Msg-transformer: Exchanging local spatial information by manipulating messenger tokens. arXiv preprint arXiv:2105.15168, 2021. 2
[22] Peng Gao, Jiasen Lu, Hongsheng Li, Roozbeh Mottaghi, and Aniruddha Kembhavi. Container: Context aggregation network. arXiv preprint arXiv:2106.01401, 2021. 2
[23] Chengyue Gong, Dilin Wang, Meng Li, Vikas Chandra, and Qiang Liu. Improve vision transformers training by suppressing over-smoothing. arXiv preprint arXiv:2104.12753, 2021. 2, 7
[24] Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, and Matthijs Douze. Levit: a vision transformer in convnet's clothing for faster inference. arXiv preprint arXiv:2104.01136, 2021. 2
[25] Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu, and Yunhe Wang. Cmt: Convolutional neural networks meet vision transformers. arXiv preprint arXiv:2107.06263, 2021. 2
[26] Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, and Yunhe Wang. Transformer in transformer. arXiv preprint arXiv:2103.00112, 2021. 1, 2, 4, 5, 6
[27] Liang Han, Pichao Wang, Zhaozheng Yin, Fan Wang, and Hao Li. Exploiting better feature aggregation for video object detection. In ACM MM, 2020. 2
[28] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016. 1, 6
[29] Shuting He, Hao Luo, Pichao Wang, Fan Wang, Hao Li, and Wei Jiang. Transreid: Transformer-based object re-identification. arXiv preprint arXiv:2102.04378, 2021. 2
[30] Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, and Seong Joon Oh. Rethinking spatial dimensions of vision transformers. arXiv preprint arXiv:2103.16302, 2021. 2, 4, 5, 6

[31] Jonathan Ho, Nal Kalchbrenner, Dirk Weissenborn, and Tim Salimans. Axial attention in multidimensional transformers. arXiv preprint arXiv:1912.12180, 2019. 2
[32] Andrew G Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, and Hartwig Adam. Mobilenets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861, 2017. 6
[33] Zilong Huang, Youcheng Ben, Guozhong Luo, Pei Cheng, Gang Yu, and Bin Fu. Shuffle transformer: Rethinking spatial shuffle for vision transformer. arXiv preprint arXiv:2106.03650, 2021. 2
[34] Zihang Jiang, Qibin Hou, Li Yuan, Daquan Zhou, Xiaojie Jin, Anran Wang, and Jiashi Feng. Token labeling: Training a 85.5% top-1 accuracy vision transformer with 56m parameters on imagenet. arXiv preprint arXiv:2104.10858, 2021. 2
[35] Aditya Jonnalagadda, William Wang, and Miguel P Eckstein. Foveater: Foveated transformer for image classification. arXiv preprint arXiv:2105.14173, 2021. 2
[36] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014. 3
[37] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020. 2
[38] Duo Li, Jie Hu, Changhu Wang, Xiangtai Li, Qi She, Lei Zhu, Tong Zhang, and Qifeng Chen. Involution: Inverting the inherence of convolution for visual recognition. arXiv preprint arXiv:2103.06255, 2021. 6
[39] Wenhao Li, Hong Liu, Runwei Ding, Mengyuan Liu, and Pichao Wang. Lifting transformer for 3d human pose estimation in video. arXiv preprint arXiv:2103.14304, 2021. 2
[40] Xiangyu Li, Yonghong Hou, Pichao Wang, Zhimin Gao, Mingliang Xu, and Wanqing Li. Transformer guided geometry model for flow-based unsupervised visual odometry. Neural Computing and Applications, pages 1–12, 2021. 2
[41] Xiangyu Li, Yonghong Hou, Pichao Wang, Zhimin Gao, Mingliang Xu, and Wanqing Li. Trear: Transformer-based rgb-d egocentric action recognition. arXiv preprint arXiv:2101.03904, 2021. 2
[42] Yawei Li, Kai Zhang, Jiezhang Cao, Radu Timofte, and Luc Van Gool. Localvit: Bringing locality to vision transformers. arXiv preprint arXiv:2104.05707, 2021. 2
[43] Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. In ICLR, 2018. 2
[44] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. arXiv preprint arXiv:2103.14030, 2021. 1, 2, 4, 5, 6, 8
[45] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attention-based neural machine translation. In EMNLP, 2015. 2
[46] Sunil Pai. Convolutional neural networks arise from ising models and restricted boltzmann machines. 2
[47] Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, and Cho-Jui Hsieh. Dynamicvit: Efficient vision transformers with dynamic token sparsification. arXiv preprint arXiv:2106.02034, 2021. 2
[48] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics, 9:53–68, 2021. 2
[49] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. Imagenet large scale visual recognition challenge. IJCV, 115(3):211–252, 2015. 5
[50] Mingxing Tan, Bo Chen, Ruoming Pang, Vijay Vasudevan, Mark Sandler, Andrew Howard, and Quoc V Le. Mnasnet: Platform-aware neural architecture search for mobile. In CVPR, 2019. 6
[51] Mingxing Tan and Quoc Le. Efficientnet: Rethinking model scaling for convolutional neural networks. In ICML, 2019. 1, 6
[52] Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng Juan. Sparse sinkhorn attention. In ICML, 2020. 2
[53] Yi Tay, Mostafa Dehghani, Vamsi Aribandi, Jai Gupta, Philip Pham, Zhen Qin, Dara Bahri, Da-Cheng Juan, and Donald Metzler. Omninet: Omnidirectional representations from transformers. arXiv preprint arXiv:2103.01075, 2021. 2
[54] Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Metzler. Efficient transformers: A survey. arXiv preprint arXiv:2009.06732, 2020. 2
[55] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. arXiv preprint arXiv:2012.12877, 2020. 1, 2, 4, 5, 6
[56] Hugo Touvron, Matthieu Cord, Alexandre Sablayrolles, Gabriel Synnaeve, and Hervé Jégou. Going deeper with image transformers. arXiv preprint arXiv:2103.17239, 2021. 1, 2
[57] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017. 1, 2, 3
[58] Pichao Wang, Xue Wang, Hao Luo, Jingkai Zhou, Zhipeng Zhou, Fan Wang, Hao Li, and Rong Jin. Scaled relu matters for training vision transformers. arXiv preprint arXiv:2109.03810, 2021. 2
[59] Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, and Ling Shao. Pyramid vision transformer: A versatile backbone for dense prediction without convolutions. arXiv preprint arXiv:2102.12122, 2021. 1, 2
[60] Wenxiao Wang, Lu Yao, Long Chen, Deng Cai, Xiaofei He, and Wei Liu. Crossformer: A versatile vision transformer based on cross-scale attention. arXiv preprint arXiv:2108.00154, 2021. 2
[61] Yulin Wang, Rui Huang, Shiji Song, Zeyi Huang, and Gao Huang. Not all images are worth 16x16 words: Dynamic vision transformers with adaptive sequence length. arXiv preprint arXiv:2105.15075, 2021. 2

[62] Yuqing Wang, Zhaoliang Xu, Xinlong Wang, Chunhua Shen, Baoshan Cheng, Hao Shen, and Huaxia Xia. End-to-end video instance segmentation with transformers. In CVPR, 2021. 2
[63] Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, and Lei Zhang. Cvt: Introducing convolutions to vision transformers. arXiv preprint arXiv:2103.15808, 2021. 1, 2, 4, 5, 6
[64] Tete Xiao, Mannat Singh, Eric Mintun, Trevor Darrell, Piotr Dollár, and Ross Girshick. Early convolutions help transformers see better. arXiv preprint arXiv:2106.14881, 2021. 2
[65] Jiangtao Xie, Ruiren Zeng, Qilong Wang, Ziqi Zhou, and Peihua Li. So-vit: Mind visual tokens for vision transformer. arXiv preprint arXiv:2104.10935, 2021. 2, 4, 5, 6
[66] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, and Kaiming He. Aggregated residual transformations for deep neural networks. In CVPR, 2017. 6
[67] Tongkun Xu, Weihua Chen, Pichao Wang, Fan Wang, Hao Li, and Rong Jin. Cdtrans: Cross-domain transformer for unsupervised domain adaptation. arXiv preprint arXiv:2109.06165, 2021. 2
[68] Weijian Xu, Yifan Xu, Tyler Chang, and Zhuowen Tu. Co-scale conv-attentional image transformers. arXiv preprint arXiv:2104.06399, 2021. 2
[69] Yifan Xu, Zhijie Zhang, Mengdan Zhang, Kekai Sheng, Ke Li, Weiming Dong, Liqing Zhang, Changsheng Xu, and Xing Sun. Evo-vit: Slow-fast token evolution for dynamic vision transformer. arXiv preprint arXiv:2108.01390, 2021. 2
[70] Jianwei Yang, Chunyuan Li, Pengchuan Zhang, Xiyang Dai, Bin Xiao, Lu Yuan, and Jianfeng Gao. Focal self-attention for local-global interactions in vision transformers. arXiv preprint arXiv:2107.00641, 2021. 2
[71] Qihang Yu, Yingda Xia, Yutong Bai, Yongyi Lu, Alan Yuille, and Wei Shen. Glance-and-gaze vision transformer. arXiv preprint arXiv:2106.02277, 2021. 2
[72] Zitong Yu, Xiaobai Li, Pichao Wang, and Guoying Zhao. Transrppg: Remote photoplethysmography transformer for 3d mask face presentation attack detection. arXiv preprint arXiv:2104.07419, 2021. 2
[73] Kun Yuan, Shaopeng Guo, Ziwei Liu, Aojun Zhou, Fengwei Yu, and Wei Wu. Incorporating convolution designs into visual transformers. arXiv preprint arXiv:2103.11816, 2021. 1, 2
[74] Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Francis EH Tay, Jiashi Feng, and Shuicheng Yan. Tokens-to-token vit: Training vision transformers from scratch on imagenet. arXiv preprint arXiv:2101.11986, 2021. 1, 2, 4, 5, 6
[75] Li Yuan, Qibin Hou, Zihang Jiang, Jiashi Feng, and Shuicheng Yan. Volo: Vision outlooker for visual recognition. arXiv preprint arXiv:2106.13112, 2021. 1, 2, 5, 6
[76] Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. arXiv preprint arXiv:2007.14062, 2020. 2
[77] Pengchuan Zhang, Xiyang Dai, Jianwei Yang, Bin Xiao, Lu Yuan, Lei Zhang, and Jianfeng Gao. Multi-scale vision longformer: A new vision transformer for high-resolution image encoding. arXiv preprint arXiv:2103.15358, 2021. 2
[78] Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, and Jian Sun. Shufflenet: An extremely efficient convolutional neural network for mobile devices. In CVPR, 2018. 6
[79] Zizhao Zhang, Han Zhang, Long Zhao, Ting Chen, and Tomas Pfister. Aggregating nested transformers. arXiv preprint arXiv:2105.12723, 2021. 2
[80] Hengshuang Zhao, Jiaya Jia, and Vladlen Koltun. Exploring self-attention for image recognition. In CVPR, 2020. 6
[81] Bolei Zhou, Hang Zhao, Xavier Puig, Tete Xiao, Sanja Fidler, Adela Barriuso, and Antonio Torralba. Semantic understanding of scenes through the ade20k dataset. IJCV, 127(3):302–321, 2019. 8
[82] Daquan Zhou, Bingyi Kang, Xiaojie Jin, Linjie Yang, Xiaochen Lian, Qibin Hou, and Jiashi Feng. Deepvit: Towards deeper vision transformer. arXiv preprint arXiv:2103.11886, 2021. 2
[83] Daquan Zhou, Yujun Shi, Bingyi Kang, Weihao Yu, Zihang Jiang, Yuan Li, Xiaojie Jin, Qibin Hou, and Jiashi Feng. Refiner: Refining self-attention for vision transformers. arXiv preprint arXiv:2106.03714, 2021. 2
[84] Cheng Zou, Bohan Wang, Yue Hu, Junqi Liu, Qian Wu, Yu Zhao, Boxun Li, Chenguang Zhang, Chi Zhang, Yichen Wei, et al. End-to-end human object interaction detection with hoi transformer. In CVPR, 2021. 2

A. Appendix
A.1. Source code of the fast version of k-NN attention in PyTorch

The source code of the fast version of k-NN attention in PyTorch is shown in Algorithm 1. The core of the fast version consists of only four lines (marked as the core code block in the listing), and it can easily be plugged into any architecture that uses fully-connected attention.

Algorithm 1 Code of the fast version of k-NN attention in PyTorch.

import torch
import torch.nn as nn

class kNNAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., topk=100):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.topk = topk

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]               # each: B, H, N, C // H
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, H, N, N
        # the core code block: keep only the top-k logits per query, mask the rest to -inf
        mask = torch.zeros(B, self.num_heads, N, N, device=x.device, requires_grad=False)
        index = torch.topk(attn, k=self.topk, dim=-1, largest=True)[1]
        mask.scatter_(-1, index, 1.)
        attn = torch.where(mask > 0, attn, torch.full_like(attn, float('-inf')))
        # end of the core code block
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
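A small usage sketch (with illustrative shapes roughly matching DeiT-Tiny) shows how the layer can stand in for a standard fully-connected attention layer:

import torch

# assuming the kNNAttention class from Algorithm 1 is in scope
attn_layer = kNNAttention(dim=192, num_heads=3, topk=100)
tokens = torch.randn(2, 197, 192)   # (batch, num_patches + cls token, dim)
out = attn_layer(tokens)
print(out.shape)                    # torch.Size([2, 197, 192])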

A.2. Comparisons between the slow version and the fast version

We develop two versions of k-NN attention, a slow version and a fast version. The k-NN attention is exactly defined by the slow version, but its speed is extremely slow, since for each query it needs to select a different set of k keys and values, and this gathering procedure is costly. To speed it up, we developed a CUDA version, but its speed is still slower than that of the fast version. The fast version takes advantage of matrix multiplication and greatly speeds up the computation. The speed comparisons on DeiT-Tiny using 8 V100 GPUs are illustrated in Table 5.

Table 5. The speed comparisons on DeiT-Tiny for the slow and fast versions.

method                   time per iteration (second)
slow version (pytorch)   8192
slow version (CUDA)      1.55
fast version (pytorch)   0.45
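For contrast, a minimal single-head sketch of the slow, gather-based variant described above (our own illustrative code, not the released implementation) selects a different set of k keys and values for every query before the softmax:

import torch

def knn_attention_slow(q, k, v, topk):
    # q, k, v: (N, C) single-head tensors
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale              # (N, N)
    top_scores, idx = torch.topk(scores, k=topk, dim=-1)    # per-query top-k logits and indices
    weights = torch.softmax(top_scores, dim=-1)             # (N, topk)
    v_sel = v[idx]                                          # (N, topk, C) gathered values per query
    return (weights.unsqueeze(-1) * v_sel).sum(dim=1)       # (N, C)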

A.3. Proof

Notations. Throughout this appendix, we denote $x_i$ as the $i$-th element of vector $x$, $W_{ij}$ as the element at the $i$-th row and $j$-th column of matrix $W$, and $W_j$ as the $j$-th row of matrix $W$. Moreover, we denote $x_i$ as the $i$-th patch (token) of the inputs with $x_i = X_i$.

Proof for Lemma 1. We first give the formal statement of Lemma 1.

Lemma 4 (Formal statement of Lemma 1). Let $\hat{V}^{knn}_l$ be the $l$-th row of $\hat{V}^{knn}$ and $\mathrm{Var}_{a_l}(x) = \mathbb{E}_{a_l}[x^\top x] - \mathbb{E}_{a_l}[x^\top]\,\mathbb{E}_{a_l}[x]$ with $\mathbb{E}_{a_l}[x] = \sum_{t=1}^{n} a_{lt}\, x_t$. Then for any $i, j = 1, 2, \ldots, n$, we have
$$\frac{\partial \hat{V}_l}{\partial W_{Q,ij}} = x_{li}\, W_{K,j}^\top\, \mathrm{Var}_{a_l}(x)\, W_V \propto \mathrm{Var}_{a_l}(x)$$
and
$$\frac{\partial \hat{V}_l}{\partial W_{K,ij}} = x_{li}\, W_{Q,j}^\top\, \mathrm{Var}_{a_l}(x)\, W_V \propto \mathrm{Var}_{a_l}(x).$$
The same is true for $\hat{V}$ of the fully-connected self-attention.

Proof. Let us first consider the derivative of $\hat{V}_l$ over $W_{Q,ij}$. Via some algebraic computation, we have
$$\frac{\partial \hat{V}_l}{\partial W_{Q,ij}} = \frac{\partial (a_l V)}{\partial W_{Q,ij}} = \sum_{t=1}^{n} a_{lt}\left(\frac{\partial T_l^{knn}(t)}{\partial W_{Q,ij}} - \sum_{k_1=1}^{n} a_{lk_1}\frac{\partial T_l^{knn}(k_1)}{\partial W_{Q,ij}}\right) x_t W_V, \qquad (1)$$
where we denote $T_l^{knn}(k_1)$ as follows for shorthand:
$$T_l^{knn}(k_1) = \begin{cases} x_l W_Q W_K^\top x_{k_1}^\top, & \text{if patch } k_1 \text{ is selected in row } l, \\ -\infty, & \text{otherwise.} \end{cases}$$
Let us denote the set $S = \{i : \text{patch } i \text{ is selected in row } l\}$ and then consider the right-hand side of (1):
$$(1) = \sum_{t\in S} a_{lt}\left(\frac{\partial\, (x_l W_Q W_K^\top x_t^\top)}{\partial W_{Q,ij}} - \sum_{k_1\in S} a_{lk_1}\frac{\partial\, (x_l W_Q W_K^\top x_{k_1}^\top)}{\partial W_{Q,ij}}\right) x_t W_V$$
$$= \sum_{t\in S} a_{lt}\left(x_{li}\, x_t W_{K,j} - \sum_{k_1\in S} a_{lk_1}\, x_{li}\, x_{k_1} W_{K,j}\right) x_t W_V$$
$$= \underbrace{\sum_{t\in S} a_{lt}\, x_{li}\, x_t W_{K,j}\, x_t W_V}_{(a)} \;-\; \underbrace{\sum_{t\in S} a_{lt}\, x_t W_V}_{(b)} \cdot \underbrace{\sum_{k_1\in S} a_{lk_1}\, x_{li}\, x_{k_1} W_{K,j}}_{(c)}. \qquad (2)$$
Since $a_l$ is the $l$-th row of the attention matrix, we have $a_{lt} \ge 0$ and $\sum_t a_{lt} = 1$. It is therefore possible to treat the terms (a), (b) and (c) as expectations of some quantities over the token index $t$ drawn with probability $a_{lt}$. Then (2) can be further simplified as
$$(2) = \mathbb{E}_{a_l}[x_{li}\, x W_{K,j} \cdot x W_V] - \mathbb{E}_{a_l}[x W_{K,j}] \cdot \mathbb{E}_{a_l}[x_{li}\, x W_V]$$
$$= x_{li}\left(\mathbb{E}_{a_l}[W_{K,j}^\top x^\top \cdot x W_V] - \mathbb{E}_{a_l}[W_{K,j}^\top x^\top] \cdot \mathbb{E}_{a_l}[x W_V]\right)$$
$$= x_{li}\, W_{K,j}^\top \left(\mathbb{E}_{a_l}[x^\top x] - \mathbb{E}_{a_l}[x^\top] \cdot \mathbb{E}_{a_l}[x]\right) W_V$$
$$= x_{li}\, W_{K,j}^\top\, \mathrm{Var}_{a_l}(x)\, W_V, \qquad (3)$$
where the second equality uses the fact that $x_t W_{K,j}$ is a scalar.

Combining (1)-(3), we have
$$\frac{\partial \hat{V}_l}{\partial W_{Q,ij}} = x_{li}\, W_{K,j}^\top\, \mathrm{Var}_{a_l}(x)\, W_V \propto \mathrm{Var}_{a_l}(x). \qquad (4)$$
Due to the symmetry between $Q$ and $K$, we can follow a similar procedure to show
$$\frac{\partial \hat{V}_l}{\partial W_{K,ij}} = x_{li}\, W_{Q,j}^\top\, \mathrm{Var}_{a_l}(x)\, W_V \propto \mathrm{Var}_{a_l}(x). \qquad (5)$$
Finally, by setting $k = n$, one may verify that equations (4) and (5) also hold for fully-connected self-attention.
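As an illustrative sanity check on the identity in (4) (not part of the paper), the snippet below compares the autograd Jacobian of $\hat{V}_l$ with the closed form on random data; the sizes are arbitrary, and $W_{K,j}$ is taken as the $j$-th column of $W_K$, which is how the index enters the derivative under the conventions used in the code.

import torch

torch.manual_seed(0)
n, d, k, l = 6, 4, 3, 0          # illustrative sizes: tokens, dim, kept keys, query row

x = torch.randn(n, d)
WQ, WK, WV = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

def attention_weights(WQ):
    logits = x[l] @ WQ @ WK.T @ x.T                 # T_l(t) = x_l W_Q W_K^T x_t^T
    mask = torch.full_like(logits, float('-inf'))
    mask[torch.topk(logits, k).indices] = 0.0       # keep only the top-k logits (the set S)
    return torch.softmax(logits + mask, dim=-1)     # a_l

def v_hat_row(WQ):
    return attention_weights(WQ) @ x @ WV           # \hat{V}_l = sum_t a_{lt} x_t W_V

jac = torch.autograd.functional.jacobian(v_hat_row, WQ)   # shape (d, d, d): d \hat{V}_l / d W_Q[i, j]

with torch.no_grad():
    a = attention_weights(WQ)
    Ex = a @ x                                      # E_{a_l}[x]
    Exx = x.T @ torch.diag(a) @ x                   # E_{a_l}[x^T x]
    var = Exx - torch.outer(Ex, Ex)                 # Var_{a_l}(x)
    for i in range(d):
        for j in range(d):
            closed = x[l, i] * (WK[:, j] @ var @ WV)   # x_{li} W_{K,j} Var_{a_l}(x) W_V
            assert torch.allclose(jac[:, i, j], closed, atol=1e-4)
print("gradient identity of Lemma 1 holds on this example")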

Proof for Lemma 2. Before giving the formal statement of Lemma 2, we first state the assumptions.

Assumption 2.
1. The token $x_i$ is a sub-gaussian random vector with mean $\mu_i$ and variance $(\sigma^2/d)\, I$ for $i = 1, 2, \ldots, n$.
2. $\mu$ follows a discrete distribution with finitely many values $\mu \in \mathcal{V}$. Moreover, there exist $0 < \nu_1$ and $0 < \nu_2 < \nu_4$ such that a) $\|\mu_i\| = \nu_1$, and b) $\mu_i W_Q W_K^\top \mu_i^\top \in [\nu_2, \nu_4]$ for all $i$, and $|\mu_i W_Q W_K^\top \mu_j^\top| \le \nu_2$ for all $\mu_i \ne \mu_j \in \mathcal{V}$.
3. $W_V$ and $W_Q W_K^\top$ are element-wise bounded by $\nu_5$ and $\nu_6$ respectively, that is, $|W_V^{(ij)}| \le \nu_5$ and $|(W_Q W_K^\top)^{(ij)}| \le \nu_6$ for all $i, j$ from $1$ to $d$.

Assumption 2 ensures that, for a given query patch, the difference between the clustering center and the noise is large enough to be distinguished.
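For intuition, the token model of Assumption 2.1 can be simulated directly; the snippet below (illustrative sizes only) draws each token as a clustering center plus isotropic Gaussian noise, i.e., the decomposition $x_i = \mu_i + h_i$ used in the proof.

import torch

n, d, sigma = 8, 64, 0.5                                             # illustrative: tokens, dim, noise scale
centers = torch.nn.functional.normalize(torch.randn(3, d), dim=-1)   # finite set V of unit-norm centers (nu_1 = 1)
assign = torch.randint(0, 3, (n,))                                   # each token picks one center mu_i
mu = centers[assign]
h = (sigma / d ** 0.5) * torch.randn(n, d)                           # noise with covariance (sigma^2 / d) I
x = mu + h                                                           # tokens x_i = mu_i + h_i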
Lemma 5 (formal statement of Lemma 2). Let patch $x_i$ be a $\sigma^2$-sub-gaussian random variable with mean $\mu_i$, and suppose that $k_1$ patches out of all $k$ selected patches follow the same clustering center as query $l$. Per Assumption 2, when $\sqrt{d} \ge 3(\psi(\delta, d) + \nu_2 + \nu_4)$, then with probability $1 - 5\delta$ we have
$$\left\| \frac{\sum_{i=1}^{k} \exp\!\left(\frac{1}{\sqrt{d}}\, x_l W_Q W_K^\top x_i^\top\right) x_i W_V}{\sum_{j=1}^{k} \exp\!\left(\frac{1}{\sqrt{d}}\, x_l W_Q W_K^\top x_j^\top\right)} - \mu_l W_V \right\|_\infty
\le 4 \exp\!\left(\frac{\psi(\delta, d)}{\sqrt{d}}\right) \sigma \nu_5 \sqrt{\frac{2}{dk}\log\frac{2d}{\delta}}
+ \left( 8 \exp\!\left(\frac{\nu_2 - \nu_4 + \psi(\delta, d)}{\sqrt{d}}\right) - 7 + \exp\!\left(\frac{\nu_2 - \nu_4 + \psi(\delta, d)}{\sqrt{d}}\right) \frac{k_1}{k} \right) \|\mu_1 W_V\|_\infty,$$
where $\psi(\delta, d) = 2\sigma\nu_1\nu_6\sqrt{2\log\frac{1}{\delta}} + 2\sigma^2\nu_6\log\frac{d}{\delta}$.

Proof. Without loss of generality, we assume the first $k$ patches are the top-$k$ selected patches. From Assumption 2.1, we can decompose $x_i = \mu_i + h_i$, $i = 1, 2, \ldots, k$, where $h_i$ is a sub-gaussian random vector with zero mean. We then analyze the numerator part:
$$\sum_{i=1}^{k} \exp\!\left(\frac{1}{\sqrt{d}}\, x_l W_Q W_K^\top x_i^\top\right) x_i W_V
= \underbrace{\sum_{i=1}^{k} \exp\!\left(\frac{1}{\sqrt{d}}\, \mu_l W_Q W_K^\top \mu_i^\top\right) \mu_i W_V}_{(a)}
+ \underbrace{\sum_{i=1}^{k} \exp\!\left(\frac{1}{\sqrt{d}}\, x_l W_Q W_K^\top x_i^\top\right) h_i W_V}_{(b)}
+ \underbrace{\sum_{i=1}^{k} \left[\exp\!\left(\frac{1}{\sqrt{d}}\, x_l W_Q W_K^\top x_i^\top\right) - \exp\!\left(\frac{1}{\sqrt{d}}\, \mu_l W_Q W_K^\top \mu_i^\top\right)\right] \mu_i W_V}_{(c)}. \qquad (6)$$
Below we will bound (a), (b) and (c) separately.

Upper bound for (a). Let us denote the index set $S_1 = \{i : \mu_1 = \mu_i,\ i = 1, 2, \ldots, k\}$. We then have
$$\left\| (a) - \sum_{i\in S_1} \exp\!\left(\frac{1}{\sqrt{d}}\, x_1 W_Q W_K^\top x_i^\top\right) \mu_1 W_V \right\|_\infty
\le (k - |S_1|)\, \max_i \exp\!\left(\frac{1}{\sqrt{d}}\, x_1 W_Q W_K^\top x_i^\top\right) \|\mu_1 W_V\|_\infty
\le (k - k_1)\, \exp\!\left(\frac{\nu_2}{\sqrt{d}}\right) \|\mu_1 W_V\|_\infty, \qquad (7)$$
