Estimating the Generalization in Deep Neural Networks via Sparsity

 
Yang Zhao                              Hao Zhang
Tsinghua University                    Tsinghua University
zhao-yan18@tsinghua.edu.cn             haozhang@tsinghua.edu.cn

Abstract

Generalization is the key capability of deep neural networks (DNNs). However, it is challenging to give a reliable measure of the generalization ability of a DNN from its intrinsic properties alone. In this paper, we propose a novel method for estimating the generalization gap based on network sparsity. Two key quantities are proposed first; they are closely related to the generalization ability and can be calculated directly from the training results alone. A simple linear model involving the two key quantities is then constructed to give an accurate estimation of the generalization gap. By training DNNs with a wide range of generalization gaps on popular datasets, we show that our key quantities and linear model can be efficient tools for estimating the generalization gap of DNNs.

1. Introduction

Deep neural networks (DNNs) achieve great success in many real-world tasks, especially in computer vision, such as image recognition [12, 17, 34] and object detection [31, 32]. A fundamental reason behind this success is that DNNs can acquire an extraordinary ability to generalize to unseen data by training on finite samples, even though they may be heavily overparameterized relative to the number of training samples. At the same time, due to this excessive capacity, DNNs have been shown to easily fit training data with arbitrary random labels, or even pure noise, obviously without generalizing [39]. This motivates a critical problem: for a specific DNN, how can we estimate its generalization ability from its intrinsic properties alone?

Many works attempt to identify attributes of neural networks that have an underlying impact on generalization or that emerge in a well-generalized network. A variety of attributes are considered to be associated with the generalization of DNNs, broadly including the complexity or capacity, the stability or robustness, and the sparsity of DNNs. Based on these attributes, some works try to derive generalization bounds that are as tight as possible, in the meantime hoping to demystify the mechanism of generalization in DNNs. Although such bounds can in principle address the problem above, barriers remain between these generalization bounds and a precise, practical estimation of the generalization error. Further, works such as margin-based methods exploit prior knowledge to solve this problem, but due to their unbounded scales they can fail to give acceptable results at times.

In this paper, we introduce sparsity, a particular property of DNNs, to quantitatively investigate network generalization from the perspective of network units. Two key quantities closely related to both DNN sparsity and generalization are proposed. They are strictly bounded within an appropriate range, which supports an accurate estimation of network generalization. A practical linear model for estimating the generalization gap of DNNs is then built from the two proposed quantities. We empirically find that units in DNNs trained on real data exhibit decreasing sparsity as the fraction of corrupted labels increases. By investigating DNN models with a wide range of generalization gaps, we find that both proposed quantities are highly correlated with the generalization gap in an approximately linear manner. This ensures satisfactory results when estimating the gap of practical networks using training samples and an appropriate linear model. With extensive experiments on various datasets, we show that our linear model can give a reliable estimation of the generalization gap in DNNs, and that our method gives better results than the margin-based method in practice.

2. Related Works

Regarding the topic of estimating generalization in DNNs, a conventional line of work bounds generalization by certain measurements of the complexity or capacity of DNNs, where the VC dimension [4, 21] and Rademacher complexity [6, 30, 37] are typically used. However, this approach appears insufficient, as [39] show that DNNs are able to fit any possible labels both with and without regularization.

After that, bounding generalization via stability or robustness received more attention. Stability investigates the change of the outputs when a given input is perturbed. Generally, to keep the network stable, well-generalized DNNs are expected to lie in a flat landscape around the minima [16, 18, 20]. However, it is still hard to appropriately define this "flatness" of minima, especially in high-dimensional spaces, where [28] argue that the definition provided by [20] cannot capture the generalization behavior well. In addition, [11] point out that sharp minima may also lead to generalization under different definitions of flatness. In contrast to stability, robustness investigates the variation of the outputs with respect to the input space. One typical evaluation of robustness is the margin of the predictions, where well-generalized DNNs are supposed to have large margins to ensure robustness [13, 35]. In particular, [5] use the margin distribution of the outputs and normalize it with the spectral norm. Based on this margin distribution, [19] further use the margin distributions of hidden layers to estimate generalization. Meanwhile, [3] argue that the methods in [5, 29] cannot yet give bounds on sample complexity better than naive parameter counting.

Sparsity of network units is considered an important sign that units are highly specialized [40] and able to perceive the intrinsic common features of samples in the same class, which provides a basis for generalizing to unseen data. This is particularly significant for units in CNNs, since they are generally found to be conceptually aligned with human visual cognition [9, 14]. Moreover, sparsity of the overall network is a fundamental property of DNNs [15] and has been found to be helpful for generalization [7, 24, 25]. To improve generalization, many methods such as dropout regularization have been proposed to induce DNNs to become sparse during training [26, 33, 36, 38]. On the other hand, from the view of network units, [27] claim that generalization should not rely on single directions of units, and further show a generalization improvement by penalizing this reliance through the loss function [23]. Trained units can represent specific natural concepts that give disentangled representations for different classes [8, 9]. [40] show that for CNNs, only several conceptually related units are necessary for classifying a single category. Also, [1] demonstrate that dropout can act as an effective regularizer to prevent memorization.

3. Method

DNNs are bio-inspired artificial neural networks that generally consist of multiple layers, where each layer is a collection of a certain number of units. As in neuroscience, a unit in a DNN generally behaves as a perceptive node and responds to the inputs. Specifically, in CNNs, a unit refers to the activated feature map output by the corresponding convolutional filter. Given an input I, the unit U is expected to present a specific functional preference and to respond strongly if it captures the feature of this input,

    f : I \rightarrow U    (1)

Consider a classification problem with a training dataset D including N classes {D_k}_{k=1}^N, D = D_1 ∪ · · · ∪ D_N. To obtain excellent generalization on testing data, units are trained to perceive the diverse intrinsic natural features hidden in the dataset. For a single class, this leads to sparsity of the units, because only a small group of units, which are substantially function-correlated to the common features of that class, would be highly active. On the contrary, due to its excessive capacity, it is possible for a DNN to simply "memorize" each sample in the class and then "mark" it as belonging to this class. In this situation, many more units are needed, since they cannot give true representations of the class-specific intrinsic features. So, we begin by identifying the group of units that are highly active on samples of a specific class D_j.

3.1. Cumulative Unit Ablation

In general, unit ablation is supposed to be able to connect DNN sparsity with generalization. Single-unit ablation individually checks the performance deterioration caused when each unit is removed from the DNN (removing or ablating a unit is implemented by forcing all elements of the unit to be deactivated; with ReLU activations, for example, the unit is set to all zeros). However, this tells nearly nothing about the collaborative effect of a group of units. Therefore, we introduce cumulative unit ablation to investigate the group effect of units on the network. In cumulative unit ablation, units at a given layer are first arranged into a list R ranked by a certain attribute of each unit,

    R = \langle h(U_0), h(U_1), \cdots, h(U_i) \rangle    (2)

where h(U_i) denotes the given attribute of unit U_i and \langle \cdot \rangle is the sorted list. Since we focus on the group of units that are highly active on D_j, we use the L1 norm of the unit as the target attribute h(U_i) in the cumulative ablation,

    h(U_i) = \sum_{x,y} U_i(x, y)    (3)

In this way, the units in the list R are ordered by their L1-norm values h(U_i).

After ranking, units are removed progressively from the head to the end of the ordered list R, and in the meantime the network's performance is recorded as a characterization of the growing effect of this attribute on the DNN.

Figure 1. Sketch of E(n, D_j) (blue) and E_r(n, D_j) (green), obtained respectively from the descending and the ascending cumulative unit ablation. The two markers indicate the two turning points n_0(D_j) and n_0^r(D_j).

During the implementation, we perform the cumulative unit ablation on D_j twice, according to two different ordered lists: R, sorted by the descending rank of h(U_i), and R_r, sorted by the ascending rank of h(U_i). Correspondingly, two evolution curves of network accuracy with respect to the number of removed units (denoted n) are recorded, where E(n, D_j) denotes the accuracy evolution under the descending rank and E_r(n, D_j) denotes the one under the ascending rank. Fig. 1 illustrates typical E(n, D_j) and E_r(n, D_j) curves.

Compared with units at shallow layers, units at deeper layers are generally considered to perceive "high-level" features [9] and thus to be more representative of a specific class. The highly active units therefore tend to be sparser at deeper layers, and the effect of cumulative unit ablation is more pronounced there. Our investigation thus focuses on the units at a deep layer.

3.2. Turning Points

With the list R in descending rank, units that are highly active on D_j are removed at the beginning of the evolution process. The accuracy therefore decreases continuously, since the network gradually loses its ability to extract the features of this class. Notably, the accuracy may fall below chance level after some critical units are removed and then stay there with only tiny variation until all units in the layer are removed. We mark the minimum number of removed units that completely destroys the function of the DNN on the dataset D_j as the turning point of E(n, D_j),

    n_0(D_j) = \inf \{\, n \mid E(n, D_j) \le \mathrm{acc}_{\mathrm{chance}} \,\}    (4)

where acc_chance is the chance-level accuracy of the target classification. If n_0(D_j) is large, the majority of units contribute positively to D_j, so more critical units have to be deactivated before the response to D_j is completely lost.

On the contrary, with the list R_r in ascending rank, units that are highly active on D_j are preserved at the beginning and removed near the end of the evolution process. The corresponding E_r(n, D_j) in general shows a slight, continual increase at the early stage of the evolution, which lasts for most of the evolution until the accuracy reaches its maximum; after this point, the accuracy drops abruptly to below chance level. Similarly, we mark the maximal point as the turning point of E_r(n, D_j),

    n_0^r(D_j) = \arg\max_n E_r(n, D_j)    (5)

The meaning of n_0^r(D_j) is clearer when viewed from the inverse order of the x-axis. In this view, n_0^r(D_j) becomes M − n_0^r(D_j), where M is the total number of units in the layer: the minimum number of jointly activated units that can produce the best performance on D_j. If n_0^r(D_j) is large, most units are functionally unrelated to D_j, and activating only a small number of critical units already provides the best effect.

3.3. Key Quantities

Generally, units that are highly active on D_j should be sparser for DNNs with better generalization. Therefore, in the process of cumulative unit ablation, the accuracy of well-generalized DNNs is more sensitive to the removal of these "important" units, which makes both E(n, D_j) and E_r(n, D_j) steeper, as shown in Fig. 2.

Figure 2. Comparison of the turning points and the area enclosed by the two curves (painted in gray) between a DNN with good generalization (left) and one with poor generalization (right).
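As a concrete illustration of Eqs. (4) and (5) above, a minimal sketch of how the two turning points might be read off the recorded curves is given below. The fallback used when E(n, D_j) never reaches chance level is our own assumption, not part of the definitions.

```python
import numpy as np

def turning_points(E, Er, acc_chance):
    """Turning points of the two ablation curves (Eqs. (4) and (5)).

    E, Er      : accuracy curves E(n, D_j) and E_r(n, D_j), indexed by the
                 number of ablated units n = 0 .. M.
    acc_chance : chance-level accuracy of the task (e.g. 0.01 for CIFAR100).
    """
    E, Er = np.asarray(E), np.asarray(Er)
    below = np.flatnonzero(E <= acc_chance)
    # Eq. (4): the first n at which E(n, D_j) falls to chance level or below.
    # If that never happens, fall back to M (an assumption of this sketch).
    n0 = int(below[0]) if below.size > 0 else len(E) - 1
    # Eq. (5): the n at which E_r(n, D_j) attains its maximum.
    n0_r = int(np.argmax(Er))
    return n0, n0_r
```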

For DNNs with better generalization, it is expected that n_0(D_j) should be smaller while n_0^r(D_j) should be larger. Since the two values are on the same scale, we can simply combine them as

    \zeta(D_j) = \frac{n_0(D_j) + M - n_0^r(D_j)}{2M}    (6)

where ζ(D_j) lies in the range from 0 to 1, and the smaller the value, the sparser the critical units of the DNN are for the classification of D_j. This is one of the key quantities derived from cumulative unit ablation.

On the other hand, shaped by the two turning points, the area enclosed by the two curves E(n, D_j) and E_r(n, D_j) is also informative. In addition to ζ(D_j), we therefore define a second key quantity,

    \kappa(D_j) = \frac{1}{M} \sum_{n=0}^{M} \left| E_r(n, D_j) - E(n, D_j) \right|    (7)

Here, the coefficient 1/M only scales the area value into the range from 0 to 1. Note that the original training accuracies of different DNNs on D_j may differ, which would bias the comparison; in this situation, a normalization can be applied in practice by dividing by the corresponding training accuracy. Opposite to ζ(D_j), a larger κ(D_j) indicates that the critical units of the DNN are sparser for the classification of D_j.

For the dataset {D_k}_{k=1}^N of all classes, we simply average the two characterizations over the classes to obtain an ensemble effect,

    \zeta(D) = \frac{1}{N} \sum_{j=1}^{N} \zeta(D_j), \qquad \kappa(D) = \frac{1}{N} \sum_{j=1}^{N} \kappa(D_j)    (8)

3.4. Estimating the Generalization Gap in DNNs

The quantities ζ(D) and κ(D) are highly correlated with the generalization of a DNN. In fact, they can be used to estimate the generalization gap via a simple linear model,

    \hat{g}(D) = a \cdot \zeta(D) + b \cdot \kappa(D) + c    (9)

where a, b and c are the parameters of this linear model.

The reason for not using a more complex model is that we find the linear model sufficient in this situation, since the two quantities and the generalization gap present a highly negative linear relationship. It should also be noted that we are not claiming the two quantities are truly linearly related to the generalization gap, even though they show high linear correlation in practice. Note that for clarity, we use the word "model" to refer to this linear model that predicts the generalization gap of DNNs, and the word "network" to refer to a trained neural network.

4. Experimental Results

In the following subsections, we apply our method to the CIFAR100 [22] and ImageNet [10] datasets.

4.1. Experiments on CIFAR100

Dataset and networks. In this subsection, our investigation focuses on the classification task on CIFAR100 using the VGG16 architecture [34]. To obtain networks with a wide range of generalization gaps, we randomly corrupt a certain percentage of the labels of each class in the dataset, as in [39]. Meanwhile, we use different training strategies to obtain networks with diverse generalization gaps, including varying the momentum and common regularizations such as weight decay, batch normalization, and dropout. All networks are trained to reach almost 1.0 accuracy on the training set. We trained 80 networks in total, and their generalization gaps range from 0.294 to 0.987. Additional results with other network architectures can be found in the Appendix.

Key quantities of networks trained with partially randomized labels. We first calculate the two key quantities for networks that are trained with the same training strategy but on datasets with different percentages of randomized labels.

We perform the cumulative unit ablation for these networks on their training datasets. Fig. 3A shows the two evolution curves E(n, D_j) and E_r(n, D_j) on the same class from the datasets with 0, 0.2, 0.4, 0.6, 0.8 and 1.0 fractions of randomized labels, respectively. In the figure, markers with a black border denote the two turning points, and the area enclosed by the two curves is painted gray. As the fraction of randomized labels increases, the first turning point n_0(D_j) gradually increases while the second turning point n_0^r(D_j) decreases; in addition, the enclosed area becomes smaller.

Then, we calculate the two quantities ζ(D) and κ(D) for these networks on all the classes of their corresponding datasets. Fig. 3B shows the scatter plot of the pairs (ζ(D_j), κ(D_j)) for all 100 classes in the datasets. According to our argument, for better generalization, ζ(D) should be smaller while κ(D) should be larger, so the quantity pairs should lie around the top left corner of the scatter plot. As expected, the point group moves regularly from the top left toward the bottom right corner as the fraction of label corruption increases. This confirms that when networks are trained with partially randomized labels, the two sparsity quantities can effectively indicate their generalization ability.
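For reference, a minimal sketch of how ζ(D_j), κ(D_j), and their class averages (Eqs. (6)-(8)) could be computed from the recorded ablation curves is given below. It recomputes the turning points internally and is an illustration under our assumptions, not the authors' implementation.

```python
import numpy as np

def zeta_kappa(E, Er, acc_chance):
    """Key quantities of one class D_j (Eqs. (6) and (7)).
    E and Er are the two ablation curves, indexed by n = 0 .. M."""
    E = np.asarray(E, dtype=float)
    Er = np.asarray(Er, dtype=float)
    M = len(E) - 1                                  # number of units in the layer
    below = np.flatnonzero(E <= acc_chance)
    n0 = int(below[0]) if below.size > 0 else M     # turning point of E,  Eq. (4)
    n0_r = int(np.argmax(Er))                       # turning point of Er, Eq. (5)
    zeta = (n0 + M - n0_r) / (2.0 * M)              # Eq. (6), bounded in [0, 1]
    kappa = float(np.abs(Er - E).sum()) / M         # Eq. (7), scaled enclosed area
    return zeta, kappa

def dataset_quantities(curves_per_class, acc_chance):
    """Class-averaged quantities zeta(D) and kappa(D) of Eq. (8).
    curves_per_class: list of (E, Er) pairs, one pair per class D_j."""
    pairs = [zeta_kappa(E, Er, acc_chance) for E, Er in curves_per_class]
    zetas, kappas = zip(*pairs)
    return float(np.mean(zetas)), float(np.mean(kappas))
```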

Figure 3. Results of networks trained on datasets with partially randomized labels (corruption fractions 0.0, 0.2, 0.4, 0.6, 0.8 and 1.0). (A) Example evolution curves of the accuracy E(n, D_j) and E_r(n, D_j) on a single class. (B) Scatter plots of the two quantities ζ(D) and κ(D) across all the classes in the separately corrupted datasets.

Estimating the generalization of trained networks. Having calculated the two sparsity quantities ζ(D) and κ(D) for all 80 trained networks, we now estimate their generalization gaps.

We begin by marking the quantity pair (ζ(D), κ(D)) of each network in a scatter plot, as shown in Fig. 4A. In this figure, the colors of the points vary progressively from red to purple, indicating the true generalization gap of the networks from small to large. As expected, the points of networks with better generalization ability lie toward the small-ζ(D), large-κ(D) end of the plot, while those with poor generalization ability lie toward the opposite end. Meanwhile, we can clearly see that the two quantities ζ(D) and κ(D) are negatively correlated, but apparently not in a linear manner.

Fig. 4B then gives two scatter plots, one of ζ(D) against the generalization gap (left) and the other of κ(D) against the generalization gap (right). As the generalization gap grows, ζ(D) increases while κ(D) decreases. This confirms that both quantities can indeed provide an efficient indication of the generalization ability of networks.

Next, we randomly split the 80 trained networks into two sets with fractions of 0.9 and 0.1. The set of 72 networks is used as the training networks to build the linear model for estimating the generalization gap via Eq. (9), and the other set of 8 networks is used as the testing networks to check the performance of the fitted model. The estimation is a typical linear regression problem and can simply be solved by least squares.

Fig. 4C illustrates the result of the linear fit. The line in the figure is y = x, which acts as the reference line for a perfect fit. The training points scatter closely around this reference line, showing that the two key quantities and the generalization gap may be highly linearly correlated. Furthermore, the Pearson correlation coefficient between the estimated generalization gaps and the true values of the training networks is 0.979, which confirms the linear relationship between them. The fitted linear model also performs fairly well on the networks in the testing set. In addition, we use the sum of squared residuals (SSR) [2] as a yardstick for checking the predictive performance on the testing networks. SSR is a conventional measurement of the performance of fitted models and is usually used for model selection. For our testing set, it is 0.023, which is very small and indicates that this model can provide excellent predictions in practice.

To check the stability of the estimation, we repeat the above procedure 100 times, each time with a new random split of the same fractions. Fig. 4D presents the statistics of the SSRs over all the testing sets. For the 100 tests, all SSRs stay below 0.035. This verifies the overall effectiveness of the two quantities ζ(D) and κ(D) as indicators of generalization ability and the power of the linear model for estimating the generalization gap of neural networks.
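A minimal sketch of the estimation step described above is given below: a least-squares fit of the linear model in Eq. (9) on the training networks, followed by the SSR on the held-out testing networks. The variable names and the random split in the usage comment are illustrative assumptions.

```python
import numpy as np

def fit_linear_model(zeta, kappa, gap):
    """Least-squares fit of Eq. (9): g_hat(D) = a*zeta(D) + b*kappa(D) + c.
    zeta, kappa, gap: arrays with one entry per training network."""
    zeta = np.asarray(zeta, dtype=float)
    X = np.column_stack([zeta, np.asarray(kappa, dtype=float), np.ones_like(zeta)])
    coef, *_ = np.linalg.lstsq(X, np.asarray(gap, dtype=float), rcond=None)
    return coef                                     # (a, b, c)

def predict_gap(coef, zeta, kappa):
    a, b, c = coef
    return a * np.asarray(zeta) + b * np.asarray(kappa) + c

def test_ssr(coef, zeta, kappa, gap):
    """Sum of squared residuals (SSR) on the held-out testing networks."""
    residual = predict_gap(coef, zeta, kappa) - np.asarray(gap)
    return float(np.sum(residual ** 2))

# Hypothetical usage mirroring the 90/10 split of the 80 networks above,
# assuming zeta_all, kappa_all, gap_all hold the per-network values:
# rng = np.random.default_rng(0)
# idx = rng.permutation(80); train, test = idx[:72], idx[72:]
# coef = fit_linear_model(zeta_all[train], kappa_all[train], gap_all[train])
# print(test_ssr(coef, zeta_all[test], kappa_all[test], gap_all[test]))
```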

Figure 4. (A) Scatter plot of the two sparsity quantities ζ(D) and κ(D) across all 80 networks. The color of each point indicates the generalization gap, where red represents the smallest value and purple the largest. (B) Scatter plots of the generalization gap with respect to ζ(D) (left) and κ(D) (right). (C) Scatter plot of the estimated generalization gap against the true generalization gap, where blue points denote training networks and red points denote testing networks. (D) Histogram of the SSRs of the 100 repeated tests. (E) Scatter plot of the generalization gap estimated by the margin-based method against the true generalization gap.

sults based on the margin distribution. We could see that although the margin-based model can estimate the generalization gap to some extent, it shows a slightly worse linear correlation (Pearson correlation coefficient of 0.75) than our model. When predicting the generalization gap of the testing networks, it also gives a larger SSR of almost 0.5. Accordingly, our method provides a more accurate estimation of the generalization gap in this setting.

We suppose that two possible factors lead to the errors of the margin-based method. The first is that, due to non-linearity, margins in DNNs are actually intractable. Currently, the distance of a sample to the margin is approximated using a first-order Taylor expansion [13], which introduces some error into the model in our calculation. The second factor is that the calculated margins are not bounded in scale: for different models, the margins can differ by several orders of magnitude depending on the training settings, so the corresponding linear model may be ill-conditioned. During the calculation, we found that for some networks (especially those trained with batch normalization) the situation becomes even worse.

4.2. Experiments on ImageNet

Dataset and networks. In this subsection, the classification task is performed on the ImageNet dataset using the same VGG16 architecture. Five networks are trained from scratch and end up with different generalization abilities. Without randomly corrupting labels, here we only alter training strategies such as momentum or dropout to change the networks. Since classification on ImageNet is commonly considered a more difficult task, it is harder for the networks to reach zero training error. Table 1 shows the training and testing accuracies of all 5 networks used in our experiments, where the generalization gap ranges from 0.054 to 0.564.

Results. Similarly, the cumulative unit ablation is first performed for each network on all the classes in the dataset. Fig. 5A shows almost the same tendency of the turning points (n0(Dj) and nr0(Dj)) and accuracy curves (E(n, Dj) and Er(n, Dj)) as the results on CIFAR100.
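To make the cumulative unit ablation behind curves such as E(n, Dj) more concrete, the following is a minimal PyTorch sketch: channels of one chosen layer are zeroed cumulatively and top-1 accuracy is re-measured after each ablation step. The ablation order (descending mean absolute activation), the use of plain top-1 accuracy, and all argument names are illustrative assumptions; the paper's own ordering criterion, the reverse-order curve Er(n, Dj), and the quantities ζ(D) and κ(D) are defined in earlier sections.

```python
import torch

@torch.no_grad()
def cumulative_unit_ablation(model, layer, loader, device="cpu"):
    """Record accuracy E(n) while zeroing the first n channels of `layer`.

    Illustrative assumptions: `layer` is a convolutional module with a
    4-D output, channels are ablated in descending order of mean absolute
    activation, and accuracy is top-1 over the samples in `loader`.
    """
    model.eval().to(device)

    # 1) Rank channels by mean absolute activation (assumed ordering).
    stats = []
    handle = layer.register_forward_hook(
        lambda mod, inp, out: stats.append(out.abs().mean(dim=(0, 2, 3)).cpu())
    )
    for x, _ in loader:
        model(x.to(device))
    handle.remove()
    order = torch.stack(stats).mean(dim=0).argsort(descending=True)

    # 2) Re-evaluate accuracy after each additional ablated channel.
    ablated = []
    def zero_ablated(mod, inp, out):
        if ablated:
            out[:, ablated] = 0.0  # zero the currently ablated channels
        return out
    handle = layer.register_forward_hook(zero_ablated)

    curve = []  # curve[n] approximates E(n) on the data in `loader`
    for n in range(len(order) + 1):
        correct, total = 0, 0
        for x, y in loader:
            pred = model(x.to(device)).argmax(dim=1).cpu()
            correct += (pred == y).sum().item()
            total += y.numel()
        curve.append(correct / total)
        if n < len(order):
            ablated.append(int(order[n]))  # ablate one more unit
    handle.remove()
    return curve
```

For a VGG-style model, `layer` would be the last convolutional module; the experiments in this paper ablate units at the block5 conv3 layer of VGG16 (see Appendix A).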


Figure 5. (A) Example evolution curves of the accuracies E(n, Dj) and Er(n, Dj) for the five networks. (B) Scatter plot between the two
quantities ζ(D) and κ(D) across all the classes in the ImageNet dataset. (C) Scatter plots between different quantities.

Model      Training Acc    Testing Acc    Gap
Model A    0.732           0.657          0.075
Model B    0.730           0.600          0.130
Model C    0.818           0.543          0.275
Model D    0.828           0.444          0.384
Model E    0.978           0.374          0.604

Table 1. Training and testing accuracies of the 5 networks.

Next, ζ(D) and κ(D) are calculated from the two curves across all the classes. Fig. 5B presents the scatter plot of all the quantity pairs (ζ(Dj), κ(Dj)). We could see that as the generalization ability of the networks becomes worse, their quantity pairs gradually move toward the bottom-right. Besides, the quantity pairs of networks with better generalization are more tightly clustered, especially with respect to ζ(D).

Then, Fig. 5C(1) shows a similar scatter plot of the quantity pair (ζ(D), κ(D)) for the five networks. The colors are on the same scale as in Fig. 4A. From the top-right corner to the bottom-left, the generalization gap gradually increases. This matches the regularity in Fig. 4A and again provides evidence for the correlation between the two sparsity quantities and the network generalization gap.

Fig. 5C(2) visualizes the linear relation between ζ(D) and the generalization gap for the five networks. ζ(D) has an extremely strong linear correlation with the generalization gap, where the Pearson correlation coefficient reaches a remarkable 0.998. Even when κ(D) is used instead of ζ(D), the result in Fig. 5C(3) still shows a high degree of linear correlation, with a Pearson correlation of 0.967. This strongly supports that ζ(D) and κ(D) are both effective characterizations of the generalization of DNNs.

Fig. 5C(4) shows that the points for the five networks lie very close to the reference line, indicating that the estimated generalization gap and the true gap are almost equal. The SSR here is only 0.004.
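To make the estimation step concrete, the sketch below fits a least-squares model of the form gap ≈ a·ζ(D) + b·κ(D) + c on a set of networks and computes the SSR of its predictions. The inclusion of an intercept, the helper names, and all numeric values are illustrative assumptions, not the coefficients or data used in the experiments above.

```python
import numpy as np

def fit_gap_model(zeta, kappa, gap):
    """Least-squares fit of gap ~ a*zeta + b*kappa + c (intercept assumed)."""
    X = np.column_stack([zeta, kappa, np.ones_like(np.asarray(zeta, float))])
    coef, *_ = np.linalg.lstsq(X, np.asarray(gap, float), rcond=None)
    return coef  # (a, b, c)

def predict_gap(coef, zeta, kappa):
    X = np.column_stack([zeta, kappa, np.ones_like(np.asarray(zeta, float))])
    return X @ coef

# Placeholder quantities for illustration only.
zeta_tr  = np.array([0.80, 0.65, 0.50, 0.35, 0.20])
kappa_tr = np.array([0.70, 0.60, 0.45, 0.30, 0.25])
gap_tr   = np.array([0.10, 0.25, 0.40, 0.55, 0.70])

coef = fit_gap_model(zeta_tr, kappa_tr, gap_tr)
pred = predict_gap(coef, zeta_tr, kappa_tr)
print("SSR on these networks:", float(np.sum((pred - gap_tr) ** 2)))
```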

5. Conclusion

We propose a method for reliably estimating the generalization ability of DNNs. By characterizing sparsity with a specifically designed cumulative unit ablation, we derive two key quantities, both of which effectively measure the degree of sparsity of the network units at a layer. We empirically show that the two key quantities correlate well with the generalization ability, and meanwhile find that they are strongly linearly correlated with the practical generalization gap, with Pearson correlation coefficients that can exceed 0.98. A simple linear model is therefore built to find the underlying parameters that precisely describe the connection between the two quantities and the generalization gap for a specific network architecture and task. To test the effectiveness of this linear model, we apply the conventional training-testing framework on trained networks with the same architecture but various generalization gaps. Our results show that the generalization gap can be predicted by the linear model both accurately and stably. Meanwhile, we analyze another prediction model based on the collection of margin distributions. We find that, due to the lack of bounded scales, the model built from the current approximation of the margin distribution may be more sensitive to networks trained with certain strategies (such as batch normalization).

The estimation of generalization for DNNs is a fundamental but challenging problem. On the one hand, existing generalization bounds are still too loose to be applied to practical networks on large datasets. On the other hand, for accurate prediction, a major drawback is that such methods may rely heavily on prior knowledge of the input space. In future work, we expect to alleviate this reliance and meanwhile draw possible connections with the current generalization bounds.
References

[1] Alessandro Achille and Stefano Soatto. On the emergence of invariance and disentangling in deep representations. arXiv, abs/1706.01350, 2017. 2
[2] Thomas J. Archdeacon. Correlation and regression analysis: a historian's guide. Univ of Wisconsin Press, 1994. 5
[3] Sanjeev Arora, Rong Ge, Behnam Neyshabur, and Yi Zhang. Stronger generalization bounds for deep nets via a compression approach. In Jennifer G. Dy and Andreas Krause, editors, Proceedings of the 35th International Conference on Machine Learning, ICML 2018, Stockholmsmässan, Stockholm, Sweden, July 10-15, 2018, volume 80 of Proceedings of Machine Learning Research, pages 254–263. PMLR, 2018. 2
[4] Peter L. Bartlett. The sample complexity of pattern classification with neural networks: the size of the weights is more important than the size of the network. IEEE Transactions on Information Theory, 44(2):525–536, 1998. 1
[5] Peter L. Bartlett, Dylan J. Foster, and Matus Telgarsky. Spectrally-normalized margin bounds for neural networks. In Isabelle Guyon, Ulrike von Luxburg, Samy Bengio, Hanna M. Wallach, Rob Fergus, S. V. N. Vishwanathan, and Roman Garnett, editors, Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA, pages 6240–6249, 2017. 2
[6] Peter L. Bartlett and Shahar Mendelson. Rademacher and gaussian complexities: Risk bounds and structural results. Journal of Machine Learning Research, 3(Nov):463–482, 2002. 1
[7] Brian Bartoldson, Ari S. Morcos, Adrian Barbu, and Gordon Erlebacher. The generalization-stability tradeoff in neural network pruning. In Hugo Larochelle, Marc'Aurelio Ranzato, Raia Hadsell, Maria-Florina Balcan, and Hsuan-Tien Lin, editors, Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020. 2
[8] David Bau, Bolei Zhou, Aditya Khosla, Aude Oliva, and Antonio Torralba. Network dissection: Quantifying interpretability of deep visual representations. In 2017 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2017, Honolulu, HI, USA, July 21-26, 2017, pages 3319–3327. IEEE Computer Society, 2017. 2
[9] David Bau, Jun-Yan Zhu, Hendrik Strobelt, Àgata Lapedriza, Bolei Zhou, and Antonio Torralba. Understanding the role of individual units in a deep neural network. Proc. Natl. Acad. Sci. USA, 117(48):30071–30078, 2020. 2, 3
[10] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009. 4
[11] Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, volume 70 of Proceedings of Machine Learning Research, pages 1019–1028. PMLR, 2017. 2
[12] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021. 1
[13] Gamaleldin F. Elsayed, Dilip Krishnan, Hossein Mobahi, Kevin Regan, and Samy Bengio. Large margin deep networks for classification. In Samy Bengio, Hanna M. Wallach, Hugo Larochelle, Kristen Grauman, Nicolò Cesa-Bianchi, and Roman Garnett, editors, Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, December 3-8, 2018, Montréal, Canada, pages 850–860, 2018. 2, 6
[14] Ruth Fong and Andrea Vedaldi. Interpretable explanations of black boxes by meaningful perturbation. arXiv, abs/1704.03296, 2017. 2
[15] Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Training pruned neural networks. arXiv preprint, abs/1803.03635, 2018. 2
[16] Alon Gonen and Shai Shalev-Shwartz. Fast rates for empirical risk minimization of strict saddle problems. In Satyen Kale and Ohad Shamir, editors, Proceedings of the 30th Conference on Learning Theory, COLT 2017, Amsterdam, The Netherlands, 7-10 July 2017, volume 65 of Proceedings of Machine Learning Research, pages 1043–1063. PMLR, 2017. 2
[17] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2016, Las Vegas, NV, USA, June 27-30, 2016, pages 770–778, 2016. 1
[18] Sepp Hochreiter and Jürgen Schmidhuber. Flat minima. Neural Comput., 9(1):1–42, 1997. 2
[19] Yiding Jiang, Dilip Krishnan, Hossein Mobahi, and Samy Bengio. Predicting the generalization gap in deep networks with margin distributions. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019, 2019. 2, 5
[20] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings, 2017. 2
[21] Vladimir Koltchinskii and Dmitry Panchenko. Empirical margin distributions and bounding the generalization error of combined classifiers. The Annals of Statistics, 30(1):1–50, 2002. 1
[22] Alex Krizhevsky et al. Learning multiple layers of features from tiny images. 2009. 4
[23] Matthew L. Leavitt and Ari S. Morcos. Selectivity considered harmful: evaluating the causal impact of class selectivity in DNNs. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021. 2
[24] Shiwei Liu. Learning sparse neural networks for better generalization. In Christian Bessiere, editor, Proceedings of the Twenty-Ninth International Joint Conference on Artificial Intelligence, IJCAI 2020, pages 5190–5191. ijcai.org, 2020. 2
[25] Shiwei Liu, Decebal Constantin Mocanu, and Mykola Pechenizkiy. On improving deep learning generalization with adaptive sparse connectivity. arXiv preprint, abs/1906.11626, 2019. 2
[26] Christos Louizos, Max Welling, and Diederik P. Kingma. Learning sparse neural networks through l0 regularization. arXiv preprint, abs/1712.01312, 2017. 2
[27] Ari S. Morcos, David G. T. Barrett, Neil C. Rabinowitz, and Matthew Botvinick. On the importance of single directions for generalization. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net, 2018. 2
[28] Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nati Srebro. Exploring generalization in deep learning. In Isabelle Guyon, Ulrike von Luxburg, Samy Bengio, Hanna M. Wallach, Rob Fergus, S. V. N. Vishwanathan, and Roman Garnett, editors, Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA, pages 5947–5956, 2017. 2
[29] Behnam Neyshabur, Srinadh Bhojanapalli, and Nathan Srebro. A pac-bayesian approach to spectrally-normalized margin bounds for neural networks. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings, 2018. 2
[30] Behnam Neyshabur, Ryota Tomioka, and Nathan Srebro. Norm-based capacity control in neural networks. In Conference on Learning Theory, pages 1376–1401. PMLR, 2015. 1
[31] Joseph Redmon, Santosh Kumar Divvala, Ross B. Girshick, and Ali Farhadi. You only look once: Unified, real-time object detection. In 2016 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2016, Las Vegas, NV, USA, June 27-30, 2016, pages 779–788. IEEE Computer Society, 2016. 1
[32] Shaoqing Ren, Kaiming He, Ross B. Girshick, and Jian Sun. Faster R-CNN: towards real-time object detection with region proposal networks. CoRR, abs/1506.01497, 2015. 1
[33] Simone Scardapane, Danilo Comminiello, Amir Hussain, and Aurelio Uncini. Group sparse regularization for deep neural networks. Neurocomputing, 241:81–89, 2017. 2
[34] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. In Yoshua Bengio and Yann LeCun, editors, 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015. 1, 4
[35] Jure Sokolic, Raja Giryes, Guillermo Sapiro, and Miguel R. D. Rodrigues. Robust large margin deep neural networks. IEEE Trans. Signal Process., 65(16):4265–4280, 2017. 2
[36] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014. 2
[37] Shizhao Sun, Wei Chen, Liwei Wang, Xiaoguang Liu, and Tie-Yan Liu. On the depth of deep neural networks: A theoretical view. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 30, 2016. 1
[38] Wei Wen, Chunpeng Wu, Yandan Wang, Yiran Chen, and Hai Li. Learning structured sparsity in deep neural networks. In Daniel D. Lee, Masashi Sugiyama, Ulrike von Luxburg, Isabelle Guyon, and Roman Garnett, editors, Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain, pages 2074–2082, 2016. 2
[39] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, 2017. 1, 4
[40] Bolei Zhou, Yiyou Sun, David Bau, and Antonio Torralba. Revisiting the importance of individual units in cnns via ablation. arXiv preprint, abs/1806.02891, 2018. 2
A. Appendix

A.1. Implementation Details of Experiments

Implementation with VGG16 on CIFAR100. In this experiment, the networks we used follow the standard VGG16 architecture. To get networks with a wider range of generalization, we use the following settings when training (a sketch enumerating this configuration grid appears at the end of this appendix):

  • Build the network with or without batch normalization.

  • Use dropout at the fully connected layers with a rate from {0, 0.3, 0.5}.

  • Use the SGD optimizer with momentum from {0, 0.5, 0.9}.

  • Use L2 regularization with a coefficient from {0, 0.0001}.

  • Use a batch size from {128, 256, 512}.

  • Use or do not use data augmentation with random cropping, horizontal flipping, and rotation.

  • Partially corrupt the labels in the training dataset with fractions from {0, 0.2, 0.4, 0.6, 0.8, 1.0}.

Implementation with VGG16 on ImageNet. In this experiment, the networks we used are still the standard VGG16 architecture, and the following settings are used when training:

  • Model A. The hyper-parameters are the same as those in the paper (Simonyan and Zisserman 2015).

  • Model B. The hyper-parameters are the same as those in Model A, except that the data augmentation strategy is not used.

  • Model C. The hyper-parameters are the same as those in Model B, except that the momentum is changed to 0.

  • Model D. The hyper-parameters are the same as those in Model C, except that only the first fully connected layer uses dropout, with a rate of 0.3.

  • Model E. None of the conventional training enhancement techniques is applied; basically, it is Model D without dropout and L2 regularization.

Units in cumulative unit ablation. For our cumulative unit ablation, we choose units at the "block5 conv3" layer, which is the last convolution layer in the VGG16 architecture.

A.2. Experiments with MobileNet on CIFAR100

For this experiment, we use the MobileNet architecture (Howard et al. 2017) for classifying CIFAR100. When training the networks, we still partially corrupt the dataset with the same fractions as used for VGG16. Similarly, to get networks with different generalization gaps, we use the same training strategies as for VGG16, except for dropout and batch normalization.

The results are presented in Fig. 6. As expected, we can see in Fig. 6A that as the fraction of corrupted labels becomes higher, n0(Dj) gradually becomes larger while nr0(Dj) becomes smaller, and meanwhile the area becomes smaller as well. This is the same as for VGG16 (Fig. 3A in the paper). Then, Fig. 6B shows the quantity pair (ζ(D), κ(D)) of each network in a scatter plot. The quantity pairs of networks with better generalization ability mostly lie in the top-right corner, whereas those with poor generalization ability lie in the bottom-left corner. Using our estimation model, Fig. 6C shows the estimated generalization gap of all the networks. We can see that the points are scattered closely around the reference line x = y. Meanwhile, the SSR of this fitting is 0.074, which demonstrates the effectiveness of our model.

A.3. Experiments with ResNet34 on CIFAR100

For this experiment, we change the network architecture to ResNet34 (He et al. 2016) for classifying CIFAR100. When training the networks, all the strategies are the same as those used for MobileNet for obtaining networks with diverse generalization gaps.

Fig. 7 presents the results. All the panels in Fig. 7 have the same meaning as those in Fig. 6, and we obtain results similar to those for VGG16 and MobileNet. The SSR here is 0.122, which again shows the effectiveness of our model.

A.4. Code Release

The code is available at https://github.com/JustNobody0204/generalization-estimating.git
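For convenience, the VGG16-on-CIFAR100 configuration space listed in A.1 can be enumerated programmatically, as in the sketch below. The dictionary keys are illustrative names only and do not correspond to arguments of the released code; only a subset of the full grid is actually trained (80 networks in the CIFAR100 experiments).

```python
from itertools import product

# Option values follow A.1; the key names are illustrative only.
grid = {
    "batch_norm":        [True, False],
    "dropout_rate":      [0.0, 0.3, 0.5],
    "momentum":          [0.0, 0.5, 0.9],
    "l2_coefficient":    [0.0, 0.0001],
    "batch_size":        [128, 256, 512],
    "data_augmentation": [True, False],
    "label_corruption":  [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
}

# Cartesian product of all options; a training run would sample from this.
configs = [dict(zip(grid, values)) for values in product(*grid.values())]
print(f"{len(configs)} candidate training configurations")
```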

Figure 6. Results on MobileNet. (A) Example evolution curves of the accuracies E(n, Dj) and Er(n, Dj) on datasets with label-corruption
fractions in {0, 0.2, 0.4, 0.6, 0.8, 1.0}. (B) Scatter plot between the two key quantities ζ(D) and κ(D) across all the networks. The color of
each point indicates the generalization gap, where red represents the smallest value and purple represents the largest. (C) Scatter plot
between the estimated generalization gap and the true generalization gap.

[Figure image: (A) training accuracy curves E(n) and Er(n) versus the number of units ablated n, one panel per label corruption fraction from 0.0 to 1.0; (B) scatter of ζ(D) versus κ(D) colored by generalization gap; (C) estimated versus true generalization gap.]
Figure 7. Results on ResNet34. The meaning of all panels is the same as in Fig. 6.
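To make panel (C) concrete, the following sketch fits a linear estimator of the generalization gap from measured (ζ, κ) pairs by ordinary least squares and compares the estimates with the corresponding true gaps. The synthetic numbers and the particular functional form a·ζ + b·κ + c are assumptions for illustration; they are not the coefficients or data reported here.

```python
# Hedged sketch: fitting a linear estimator gap ~ a*zeta + b*kappa + c from
# (zeta, kappa, gap) triples, as in panels (B)/(C). The values below are
# synthetic placeholders, not measurements from the paper.
import numpy as np

rng = np.random.default_rng(0)

# Placeholder measurements: one row per trained network.
zeta = rng.uniform(0.3, 0.8, size=30)           # key quantity zeta(D)
kappa = rng.uniform(0.5, 1.0, size=30)          # key quantity kappa(D)
gap = 0.9 * zeta - 0.4 * kappa + 0.5 + rng.normal(0.0, 0.02, size=30)  # synthetic true gap

# Least-squares fit of the linear model.
A = np.column_stack([zeta, kappa, np.ones_like(zeta)])
coef, *_ = np.linalg.lstsq(A, gap, rcond=None)
est = A @ coef                                   # estimated generalization gap

print("coefficients (a, b, c):", np.round(coef, 3))
print("mean absolute estimation error:", np.round(np.abs(est - gap).mean(), 4))
```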
