Pruning of Deep Spiking Neural Networks through Gradient Rewiring
Yanqi Chen¹,³, Zhaofei Yu¹,²,³*, Wei Fang¹,³, Tiejun Huang¹,²,³ and Yonghong Tian¹,³*
¹ Department of Computer Science and Technology, Peking University
² Institute for Artificial Intelligence, Peking University
³ Peng Cheng Laboratory
{chyq, yuzf12, fwei, tjhuang, yhtian}@pku.edu.cn
* Corresponding author

arXiv:2105.04916v3 [cs.NE] 22 Jun 2021

Abstract

Spiking Neural Networks (SNNs) have attracted great attention due to their biological plausibility and high energy efficiency on neuromorphic chips. As these chips are usually resource-constrained, the compression of SNNs is crucial on the road to their practical use. Most existing methods directly apply pruning approaches developed for artificial neural networks (ANNs) to SNNs, which ignore the differences between ANNs and SNNs and thus limit the performance of the pruned SNNs. Besides, these methods are only suitable for shallow SNNs. In this paper, inspired by synaptogenesis and synapse elimination in the neural system, we propose gradient rewiring (Grad R), a joint learning algorithm of connectivity and weight for SNNs that enables us to seamlessly optimize the network structure without retraining. Our key innovation is to redefine the gradient with respect to a new synaptic parameter, allowing better exploration of network structures by taking full advantage of the competition between pruning and regrowth of connections. The experimental results show that the proposed method achieves the minimal loss of SNN performance on the MNIST and CIFAR-10 datasets so far. Moreover, it reaches a ~3.5% accuracy loss under an unprecedented 0.73% connectivity, which reveals a remarkable structure-refining capability in SNNs. Our work suggests that there exists extremely high redundancy in deep SNNs. Our code is available at https://github.com/Yanqi-Chen/Gradient-Rewiring.

1 Introduction

Spiking Neural Networks (SNNs), as the third generation of Artificial Neural Networks (ANNs) [Maass, 1997], have increasingly aroused researchers' interest in recent years due to their high biological plausibility, temporal information processing capability, and low power consumption on neuromorphic chips [Gerstner and Kistler, 2002; Roy et al., 2019]. As these chips are usually resource-constrained, the compression of SNNs is of great importance to lower the computation cost, make deployment on neuromorphic chips more feasible, and realize the practical application of SNNs.

Researchers have made many efforts on the pruning of SNNs. Generally, most methods borrow existing compression techniques from the field of ANNs and apply them directly to SNNs, such as truncating weak weights below a threshold [Neftci et al., 2016; Rathi et al., 2019; Liu et al., 2020] and using the Lottery Ticket Hypothesis [Frankle and Carbin, 2019] in SNNs [Martinelli et al., 2020]. Some works use other metrics, such as levels of weight development [Shi et al., 2019] and similarity between spike trains [Wu et al., 2019a], as guidance for removing insignificant or redundant weights. Although some spike-based metrics for pruning are introduced in these works, the results on shallow SNNs are not persuasive enough.

There are some attempts to mimic the way the human brain learns in order to enhance the pruning of SNNs. Qi et al. [2018] enabled structure learning of SNNs by adding a connection gate for each synapse. They adopted a biologically plausible local learning rule, Spike-Timing-Dependent Plasticity (STDP), to learn the gated connections, while using Tempotron rules [Gütig and Sompolinsky, 2006] for regular weight optimization. However, the method only applies to SNNs with a two-layer structure. Deng et al. [2019] exploited the ADMM optimization tool to prune the connections and optimize the network weights alternately. Kappel et al. [2015] proposed synaptic sampling to optimize network construction and weights by modeling the spine motility of spiking neurons as Bayesian learning. Bellec et al. went a step further and proposed Deep Rewiring (Deep R) as a pruning algorithm for ANNs [2018a] and Long Short-term Memory Spiking Neural Networks (LSNNs) [2018b], which was then deployed on SpiNNaker 2 prototype chips [Liu et al., 2018]. Deep R is an on-the-fly pruning algorithm, which indicates its ability to reach a target connectivity without retraining. Although these methods attempt to combine learning and pruning in a unified framework, they focus more on pruning than on the rewiring process. In practice, a pruned connection is hard to regrow, limiting the diversity of network structures and thus the performance of the pruned SNNs.

In this paper, inspired by synaptogenesis and synapse elimination in the neural system [Huttenlocher and Nelson, 1994; Petzoldt and Sigrist, 2014], we propose gradient rewiring (Grad R), a joint learning algorithm of connectivity and weight for SNNs, which introduces a new synaptic parameter to unify the representation of synaptic connectivity and synaptic weights. By redefining the gradient with respect to this synaptic parameter, Grad R takes full advantage of the competition between pruning and regrowth of connections to explore the optimal network structure. To show that Grad R can optimize the synaptic parameters of SNNs, we theoretically prove that the angle between the accurate gradient and the redefined gradient is less than 90°. We further derive the learning rule and realize the pruning of deep SNNs. We evaluate our algorithm on the MNIST and CIFAR-10 datasets and achieve state-of-the-art performance. On CIFAR-10, Grad R reaches a classification accuracy of 92.03% (merely 0.81% accuracy loss).

[Figure 1: The deep SNN architecture used on CIFAR-10: Input → Block 1 → Block 2 → Flatten → Dropout → FC 2048 → LIF → FC 100 → LIF → AvgPool1d /10 → Output, where each block consists of (Conv 3×3, 256, BN → LIF) ×3 followed by MaxPool2d /2.]
synaptic parameter θ can describe both the formation and the elimination of a synaptic connection.

Optimizing Synaptic Parameters θ

With the above definition, the gradient descent algorithm and its variants can be directly applied to optimize all synaptic parameters θ of the SNN. Specifically, we assume that the loss function has the form L(x_i, y_i, w), where x_i, y_i denote the i-th input sample and the corresponding target output, and w is the vector of all synaptic weights. The corresponding vectors of all synaptic parameters and signs are θ and s, respectively. Suppose w_k, θ_k, s_k stand for the k-th elements of w, θ, s. Then we have w = s ⊙ ReLU(θ), where ⊙ is the Hadamard product and ReLU is applied element-wise. The gradient descent algorithm would adjust all the synaptic parameters to lower the loss function L, that is:

    ∆θ_k ∝ ∂L/∂θ_k = (∂L/∂w_k)(∂w_k/∂θ_k) = (∂L/∂w_k) s_k H(θ_k),        (3)

where H(·) represents the Heaviside step function. Ideally, if θ_k < 0 during learning, the corresponding synaptic connection is eliminated, and it regrows when θ_k > 0. Nonetheless, under Eq. (3) a pruned connection will never regrow. To see this, suppose that θ_k = θ < 0; then H(θ_k) = 0 and w_k = s_k ReLU(θ_k) = 0, so:

    ∂L/∂θ_k |_{θ_k = θ} = H(θ) s_k (∂L/∂w_k)|_{w_k = 0} = 0,  for θ < 0.        (4)

As the gradient of the loss function L with respect to θ_k is 0, θ_k will stay below 0, and the corresponding synaptic connection will remain eliminated. This might result in a disastrous loss of expressive capability for SNNs. This analysis also explains the results of Deep Rewiring [Bellec et al., 2018a], where a pruned connection is hard to regrow even when noise is added to the gradient: the added noise has no preference for minimizing the loss.

[Figure 2: The comparison of gradient descent on L̄ and L, plotting the value of the loss function against θ. (a) s_k (∂L/∂w_k)|_{w_k=0} < 0. (b) s_k (∂L/∂w_k)|_{w_k=0} > 0. The red arrows specify the direction of gradient descent on L̄ for θ.]

Gradient Rewiring

Our new algorithm adjusts θ in the same way as gradient descent, but for each θ_k it uses a new formula that redefines the gradient when θ_k < 0:

    (∆θ_k)_GradR ∝ ∂L̄/∂θ_k := s_k (∂L/∂w_k)|_{w_k = s_k ReLU(θ_k)},  for θ_k ∈ ℝ.        (5)

For an active connection θ_k = θ > 0, the definition makes no difference, as H(θ_k) = 1. For an inactive connection θ_k = θ < 0, the gradients differ: ∆θ_k = 0 while (∆θ_k)_GradR ∝ s_k (∂L/∂w_k)|_{w_k = 0}. The change of gradient for inactive connections results in a different loss function L̄. As illustrated in Fig. 2(a), when s_k (∂L/∂w_k)|_{w_k=0} < 0, the gradient of L̄ shows a way to jump out of the plateau in L (red arrow). With an appropriate step size, a depleted connection will regrow without increasing the real loss L in this case. It should be noted that the situation in Fig. 2(b) can also occur and push the parameter θ_k even more negative, which can be overcome by constraining the range of the parameter θ (discussed in the subsequent part).
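To make the parametrization w = s ⊙ ReLU(θ) and the redefined gradient of Eq. (5) concrete, the following is a minimal PyTorch sketch. The helper name grad_r_step and the toy loss are our own illustration (without the prior term introduced later), not the released implementation; see the linked repository for the actual code.

```python
import torch

def grad_r_step(theta, sign, grad_w, lr):
    """One Grad R update for a single weight tensor (no prior term yet).

    theta  : synaptic parameters, same shape as the weight
    sign   : fixed +1/-1 signs drawn at initialization
    grad_w : dL/dw evaluated at w = sign * relu(theta)
    """
    # Plain gradient descent on theta (Eq. (3)) would use
    #     dL/dtheta = grad_w * sign * (theta > 0),
    # so a pruned connection (theta < 0) gets zero gradient and can never regrow.
    # Grad R (Eq. (5)) drops the Heaviside factor:
    grad_theta = sign * grad_w
    theta = theta - lr * grad_theta
    weight = sign * torch.relu(theta)  # pruned connections stay exactly zero
    return theta, weight

# Toy usage with an autograd-computed grad_w.
w0 = torch.randn(4)
sign, theta = torch.sign(w0), w0.abs()
weight = (sign * torch.relu(theta)).requires_grad_()
loss = (weight ** 2).sum()             # stand-in for the SNN loss L(x_i, y_i, w)
loss.backward()
theta, weight = grad_r_step(theta, sign, weight.grad, lr=0.1)
```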
Similar to the gradient descent algorithm (Eq. (3)), Grad R can also optimize the synaptic parameters θ of the SNN. An intuitive explanation is that the angle between the accurate gradient ∇_θ L and the redefined gradient ∇_θ L̄ is below 90°, i.e., ⟨∇_θ L, ∇_θ L̄⟩ ≥ 0, which can be derived from Lemma 1 (Appendix A). Further, a rigorous proof is given in Theorem 1.

Theorem 1 (Convergence). Consider a spiking neural network with synaptic parameters θ, synaptic weights w, signs s, and the SoftPlus mapping, a smooth approximation to the ReLU, w_k = (s_k/β) log(1 + exp(βθ_k)) with β ≫ 1, where w_k, θ_k, s_k stand for the k-th elements of w, θ, s, and a loss function L(θ) that is bounded below. If there exists a constant C such that ∀k, θ_k ≥ C, and a Lipschitz constant L such that ‖∇_θ L(θ_1) − ∇_θ L(θ_2)‖ ≤ L‖θ_1 − θ_2‖ for any differentiable points θ_1 and θ_2, then L will decrease until convergence using the gradient ∇_θ L̄, provided the step size η > 0 is small enough.

Proof. Let θ^t be the parameter θ after the t-th update step using ∇_θ L̄ as the direction of each step, that is, θ^{t+1} = θ^t − η ∇_θ L̄(θ^t). With the Lipschitz continuous gradient, we have

    L(θ^{t+1}) − L(θ^t)
      = ∫₀¹ ⟨∇_θ L(θ^t(τ)), θ^{t+1} − θ^t⟩ dτ
      = ⟨∇_θ L(θ^t), θ^{t+1} − θ^t⟩ + ∫₀¹ ⟨∇_θ L(θ^t(τ)) − ∇_θ L(θ^t), θ^{t+1} − θ^t⟩ dτ
      ≤ ⟨∇_θ L(θ^t), θ^{t+1} − θ^t⟩ + ∫₀¹ ‖∇_θ L(θ^t(τ)) − ∇_θ L(θ^t)‖ · ‖θ^{t+1} − θ^t‖ dτ
      ≤ ⟨∇_θ L(θ^t), θ^{t+1} − θ^t⟩ + L ‖θ^{t+1} − θ^t‖² ∫₀¹ τ dτ
      = −η ⟨∇_θ L(θ^t), ∇_θ L̄(θ^t)⟩ + (Lη²/2) ‖∇_θ L̄(θ^t)‖²
      ≤ −η (1 − ALη/2) ⟨∇_θ L(θ^t), ∇_θ L̄(θ^t)⟩,        (6)

where θ^t(τ) is defined by θ^t(τ) = θ^t + τ(θ^{t+1} − θ^t), τ ∈ [0, 1]. The first two inequalities hold by the Cauchy–Schwarz inequality and the Lipschitz continuous gradient, respectively.
The last inequality holds with a constant A > 0 according to Lemma 1. When ‖∇_θ L̄(θ^t)‖ > 0 and η < 2/(AL), the last term in Eq. (6) is strictly negative. Hence, we have a strictly decreasing sequence {L(θ^n)}_{n∈ℕ}. Since the loss function is bounded below, this sequence must converge.

Prior-based Connectivity Control

There is no explicit control of connectivity in the above algorithm, yet we need to control the speed of pruning to target diverse levels of sparsity. Accordingly, we introduce a prior distribution to impose constraints on the range of the parameters. To reduce the number of extra hyperparameters, we adopt the same Laplacian distribution for each synaptic parameter θ_k:

    p(θ_k) = (α/2) exp(−α|θ_k − µ|),        (7)

where α and µ are the scale parameter and the location parameter, respectively. The corresponding penalty term is

    ∂ log p(θ_k)/∂θ_k = −α sign(θ_k − µ),        (8)

which is similar to an ℓ1 penalty, with sign(·) denoting the sign function.

Here we illustrate how the prior influences network connectivity. We assume that all the synaptic parameters θ follow the same prior distribution in Eq. (7). Since a connection with a negative synaptic parameter (θ_k < 0) is pruned, the probability that a connection is pruned is

    ∫_{−∞}^{0} p(θ_k) dθ_k = ∫_{−∞}^{0} (α/2) exp(−α|θ_k − µ|) dθ_k
      = { (1/2) exp(−αµ),      if µ > 0,
          1 − (1/2) exp(αµ),   if µ ≤ 0.        (9)

As all the synaptic parameters follow the same prior distribution, Eq. (9) also represents the network's sparsity. Conversely, given a target sparsity p of the SNN, one can obtain the appropriate location parameter µ from Eq. (9), which is related to the scale parameter α:

    µ = { −(1/α) ln(2p),        if p < 1/2,
           (1/α) ln(2 − 2p),    if p ≥ 1/2.        (10)

Note that when p < 0.5 we have µ > 0, which indicates that the peak of the density function p(θ_k) is located at positive θ_k. As the sign s_k is chosen randomly and kept fixed, the corresponding prior distribution p(w_k) would then be bimodal (explained in Appendix C). Given that the weight distribution (bias excluded) in the convolutional and FC layers of a directly trained SNN is usually zero-mean and unimodal (see Fig. 3), we assume p ≥ 0.5 in the following discussion. The prior encourages the synaptic parameters to gather around µ, where the weights are more likely to be zero, while α controls the speed of pruning. The penalty term also prevents weights from growing too large, alleviating the concern in Fig. 2(b).

[Figure 3: The distribution of the parameters in a trained SNN without pruning, shown for each layer (static conv.0, static conv.1, conv.1, conv.2, conv.4, conv.5, conv.8, conv.9, conv.11, conv.12, conv.14, conv.15, fc.2, fc.4) over the weight range [−2, 2]. Note that static conv.1, conv.2, conv.5, conv.9, conv.12 and conv.15 are all BN layers.]

Algorithm 1: Gradient Rewiring with SGD
Input: input data X, output labels Y, learning rate η, target sparsity p, penalty term α
 1: Initialize weight w;
 2: s ← sign(w), θ ← w ⊙ s;
 3: Calculate µ according to p, α (Eq. (10));
 4: for t in [1, N] do
 5:     for all synaptic parameters θ_k and signs s_k ∈ s do
 6:         for input x_i ∈ X and label y_i ∈ Y do
 7:             θ_k ← θ_k − η (s_k ∂L(x_i, y_i, θ)/∂w_k + α sign(θ_k − µ));
 8:         end
 9:         w_k ← s_k ReLU(θ_k);
10:     end
11: end

The Grad R algorithm is given in Algorithm 1. It is important to notice that the target sparsity p is not guaranteed to be reached, since the gradient of the loss function, i.e., s_k ∂L(x_i, y_i, θ)/∂w_k, also contributes to the posterior distribution (line 7). Consequently, the choice of these hyperparameters is more a reference than a guarantee on the final sparsity.

2.4 Detailed Learning Rule in SNNs

Here we apply Grad R to the SNN in Fig. 1 and derive the detailed learning rule. Specifically, the problem is to compute the gradient of the loss function, s_k ∂L(x_i, y_i, θ)/∂w_k (line 7 in Algorithm 1). We take a single FC layer without bias parameters as an example. Suppose that m_t and u_t represent the membrane potentials of all the LIF neurons after the neural dynamics (charging) and after the trigger of a spike (resetting) at time-step t, S_t denotes the vector of binary spikes of all the LIF neurons at time-step t, W denotes the weight matrix of the FC layer, and 1 denotes a vector composed entirely of ones with the same size as S_t. With these definitions, we can compute the gradient of the output spikes S_t with respect to W by unfolding the network over the simulation time-steps, which is similar to the idea of backpropagation through time (BPTT). We obtain

    ∂S_t/∂W = (∂S_t/∂m_t)(∂m_t/∂W),
    ∂u_t/∂W = (∂u_t/∂S_t)(∂S_t/∂W) + (∂u_t/∂m_t)(∂m_t/∂W),        (11)
    ∂m_{t+1}/∂W = (∂m_{t+1}/∂u_t)(∂u_t/∂W) + (1/τ_m) I_t^⊤,

where I_t denotes the input to the layer at time-step t, and

    ∂u_t/∂S_t = Diag(u_rest 1 − m_t),
    ∂m_{t+1}/∂u_t = 1 − 1/τ_m,        (12)
    ∂u_t/∂m_t = (∂u_t/∂S_t)(∂S_t/∂m_t) + Diag(1 − S_t).

When calculating ∂S_t/∂m_t, the derivative of the spike with respect to the membrane potential after charging, we adopt the surrogate gradient method [Wu et al., 2018; Neftci et al., 2019]. Specifically, we utilize the shifted ArcTan function H₁(x) = (1/π) arctan(πx) + 1/2 as a surrogate for the Heaviside step function H(·), and we have ∂S_t/∂m_t ≈ H₁'(m_t − u_th).
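A hedged PyTorch sketch of this unrolling is given below: a Heaviside spike in the forward pass with the shifted ArcTan derivative H₁'(x) = 1/(1 + (πx)²) in the backward pass, and a hard-reset LIF layer whose charging and reset are chosen to match the derivative structure of Eqs. (11)-(12) (Eqs. (1)-(2) themselves are not reproduced in this excerpt). The names ArcTanSpike and lif_fc_forward are our own; this is an illustration, not the SpikingJelly implementation.

```python
import math
import torch

class ArcTanSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass; H1'(x) = 1 / (1 + (pi*x)^2) in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x >= 0.0).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output / (1.0 + (math.pi * x) ** 2)

def lif_fc_forward(x_seq, weight, tau_m=2.0, u_th=1.0, u_rest=0.0):
    """Unroll a bias-free FC layer of LIF neurons over T time-steps (BPTT-style).

    x_seq  : layer input of shape [T, batch, in_features]
    weight : FC weight of shape [out_features, in_features]
    """
    u = torch.zeros(x_seq.shape[1], weight.shape[0], device=x_seq.device)
    spikes = []
    for x_t in x_seq:
        i_t = x_t @ weight.t()                   # synaptic input I_t
        m_t = u + (i_t - (u - u_rest)) / tau_m   # charging: d m_{t+1} / d u_t = 1 - 1/tau_m
        s_t = ArcTanSpike.apply(m_t - u_th)      # spike with surrogate dS_t/dm_t ~ H1'(m_t - u_th)
        u = s_t * u_rest + (1.0 - s_t) * m_t     # hard reset: d u_t / d S_t = Diag(u_rest - m_t)
        spikes.append(s_t)
    return torch.stack(spikes)                   # output spikes S_t, shape [T, batch, out_features]
```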
3 Experiments

We evaluate the performance of the Grad R algorithm on image recognition benchmarks, namely the MNIST and CIFAR-10 datasets. Our implementation of deep SNNs is based on our open-source SNN framework SpikingJelly [Fang et al., 2020a].

3.1 Experimental Settings

For MNIST, we consider a shallow fully-connected network (Input-Flatten-800-LIF-10-LIF) similar to previous works. On the CIFAR-10 dataset, we use a convolutional SNN with 6 convolutional layers and 2 FC layers (Input-[[256C3S1-BN-LIF]×3-MaxPool2d2S2]×2-Flatten-Dropout-2048-LIF-100-LIF-AvgPool1d10S10, illustrated in Fig. 1) proposed in our previous work [Fang et al., 2020b]. The first convolutional layer acts as a learnable spike encoder that transforms the static input images into spikes. In this way, we avoid the Poisson encoder, which requires a long simulation time to represent the inputs accurately [Wu et al., 2019b]. The dropout mask is initialized at the first time-step and remains constant within an iteration of mini-batch data, guaranteeing a fixed connectivity pattern in each iteration [Lee et al., 2020]. All bias parameters are omitted except for the BatchNorm (BN) layers. Note that pruning weights in BN layers is equivalent to pruning all the corresponding input channels, which would lead to an unnecessary decline in the expressiveness of the SNN; moreover, the peaks of the probability density functions of the weights in BN layers are far from zero (see Fig. 3), which suggests that pruning them is improper. For these reasons, we exclude all BN layers from pruning. The gradients ∂u_t/∂S_t of the reset in the LIF neurons' forward steps (Eq. (2)) are dropped, i.e., manually set to zero, according to [Zenke and Vogels, 2020]. The choice of all hyperparameters is shown in Table 3 and Table 4.
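For orientation, the MNIST network (Input-Flatten-800-LIF-10-LIF) might be assembled roughly as in the sketch below, with a minimal LIF layer and the static image simply repeated over T time-steps. The LIF class, the input encoding, and the omission of the surrogate gradient here are simplifications of ours, not the paper's exact setup; for training, the hard threshold would be replaced by a surrogate spike function such as the one sketched in Section 2.4's example.

```python
import torch
import torch.nn as nn

class LIF(nn.Module):
    """Minimal multi-step LIF layer; input and output shapes are [T, batch, features]."""
    def __init__(self, tau_m=2.0, u_th=1.0, u_rest=0.0):
        super().__init__()
        self.tau_m, self.u_th, self.u_rest = tau_m, u_th, u_rest

    def forward(self, x_seq):
        u, out = torch.zeros_like(x_seq[0]), []
        for x_t in x_seq:
            m_t = u + (x_t - (u - self.u_rest)) / self.tau_m
            s_t = (m_t >= self.u_th).to(x_seq)   # hard threshold; use a surrogate spike for training
            u = s_t * self.u_rest + (1.0 - s_t) * m_t
            out.append(s_t)
        return torch.stack(out)

# Input-Flatten-800-LIF-10-LIF; the static image is repeated over T time-steps in this toy setup.
mnist_snn = nn.Sequential(
    nn.Flatten(start_dim=2),            # [T, batch, 1, 28, 28] -> [T, batch, 784]
    nn.Linear(784, 800, bias=False),
    LIF(),
    nn.Linear(800, 10, bias=False),
    LIF(),
)
T = 8
images = torch.rand(T, 4, 1, 28, 28)    # toy batch of 4 images
out_spikes = mnist_snn(images)          # [T, batch, 10]; firing rates over T act as class scores
```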
3.2 Performance

[Figure 4: Top-1 accuracy (solid line) on the test set and connectivity (dashed line) during training on CIFAR-10, for penalty terms α ∈ {0, 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005} and the unpruned baseline (No Prune). α is the penalty term in Algorithm 1.]

Fig. 4 shows the top-1 accuracy on the test set and the connectivity during training on the CIFAR-10 dataset. The improvement of accuracy and the growth of network sparsity happen simultaneously, as the Grad R algorithm learns the connectivity and weights of the SNN jointly. As training goes into later stages, the growth of sparsity has little impact on the performance: the sparsity would likely continue to increase while the accuracy is maintained, even if we increased the number of training epochs. The effect of the prior is also clearly shown in Fig. 4. Overall, a larger penalty term α brings higher sparsity and more apparent performance degradation. An aggressive pruning setting with a huge α, e.g., α = 0.005, quickly prunes most of the connections and leaves the network few parameters to update (see the brown line).

As most of the previous works do not conduct experiments on datasets containing complicated patterns, such as CIFAR-10, we also test the algorithm on a similar shallow fully-connected network for the classification task on the MNIST dataset. We further implement Deep R as a baseline, which can also be viewed as a pruning-while-learning algorithm. The highest performance on the test set against connectivity is presented in Fig. 5. The performance of Grad R is comparable to Deep R on MNIST and much better than Deep R on CIFAR-10. As illustrated in Fig. 5(b), the performance of Grad R is close to that of the fully connected network when the connectivity is above 10%. Besides, it reaches a classification accuracy of 92.03% (merely 0.81% accuracy loss) at 5.08% connectivity in deep SNNs. We compare our algorithm to other state-of-the-art pruning methods for SNNs, and the results are listed in Table 1 and Table 2, from which we can find that the proposed method outperforms previous works.

[Figure 5: Top-1 accuracies on the test set against connectivity. (a) MNIST. (b) CIFAR-10. A different penalty term α is chosen for each sparsity (connectivity). The accuracy of the unpruned model is shown as a black dashed line.]

Pruning Methods | Training Methods | Arch. | Acc. (%) | Acc. Loss (%) | Connectivity (%)
Threshold-based [Neftci et al., 2016] | Event-driven CD | 2 FC | 95.60 | -0.60 | 26.00
Deep R [Bellec et al., 2018b]¹ | Surrogate Gradient | LSNN² | 93.70 | +2.70 | 12.00
Threshold-based [Shi et al., 2019] | STDP | 1 FC | 94.05 | ≈-19.05³ | ≈30.00
Threshold-based [Rathi et al., 2019] | STDP | 2 layers⁴ | 93.20 | -1.70 | 8.00
ADMM-based [Deng et al., 2019] | Surrogate Gradient | LeNet-5 like | 99.07 | +0.03 | 50.00
 | | | | -0.43 | 40.00
 | | | | -2.23 | 25.00
Online APTN [Guo et al., 2020] | STDP | 2 FC | 90.40 | -3.87 | 10.00
Deep R⁵ | Surrogate Gradient | 2 FC | 98.92 | -0.36 | 37.14
 | | | | -0.56 | 13.30
 | | | | -2.47 | 4.18
 | | | | -9.67 | 1.10
Grad R | Surrogate Gradient | 2 FC | 98.92 | -0.33 | 25.71
 | | | | -0.43 | 17.94
 | | | | -2.02 | 5.63
 | | | | -3.55 | 3.06
 | | | | -8.08 | 1.38

¹ Using sequential MNIST.
² 100 adaptive and 120 regular neurons.
³ A data point of conventional pruning, since the soft-pruning only freezes weights during training and brings no real sparsity.
⁴ The 1st layer is FC and the 2nd layer is a one-to-one layer with lateral inhibition.
⁵ Our implementation of Deep R.

Table 1: Performance comparison between the proposed method and previous work on the MNIST dataset.

Pruning Methods | Training Methods | Arch. | Acc. (%) | Acc. Loss (%) | Connectivity (%)
ADMM-based [Deng et al., 2019] | Surrogate Gradient | 7 Conv, 2 FC | 89.53 | -0.38 | 50.00
 | | | | -2.16 | 25.00
 | | | | -3.85 | 10.00
Deep R¹ | Surrogate Gradient | 6 Conv, 2 FC | 92.84 | -1.98 | 5.24
 | | | | -2.56 | 1.95
 | | | | -3.53 | 1.04
Grad R | Surrogate Gradient | 6 Conv, 2 FC | 92.84 | -0.30 | 28.41
 | | | | -0.34 | 12.04
 | | | | -0.81 | 5.08
 | | | | -1.47 | 2.35
 | | | | -3.52 | 0.73

¹ Experiments on CIFAR-10 were not included in the SNN pruning work of Deep R [Bellec et al., 2018b]; therefore, we implement and generalize Deep R to deep SNNs.

Table 2: Performance comparison between the proposed method and previous work on the CIFAR-10 dataset.

3.3 Behaviors of Algorithm

Ablation study of prior. We test the effect of the prior. The Grad R algorithm without a prior (α = 0) achieves almost the highest accuracy after pruning, but its connectivity drops the most slowly (the red curve in Fig. 4). The performance gap between the unpruned SNN and the SNN pruned without a prior indicates the minimum accuracy loss attainable with Grad R. It also clarifies that a prior is needed if higher sparsity is required.

Distribution drift of parameters. Fig. 6(a) illustrates the distribution of the parameter θ during training (α = 0.005). Since we use a Laplacian distribution with µ < 0 as the prior, the posterior distribution is dragged toward the prior.

[Figure 6: The statistics of the 2nd convolutional layer during training. (a) Distribution of θ across epochs. (b) Number of pruned and regrown connections, together with their difference (#P − #R).]

Competition between pruning and regrowth. A cursory glance at the increasing sparsity over time may mislead one into thinking that the pruning process always dominates the regrowth process. However, Fig. 6(b) sheds light on the fierce competition between these two processes. Pruning clearly outpaces regrowth only in the early stages (the first 200 epochs); afterward, the pruning process wins only by a narrow margin. This competition encourages the SNN to explore different network structures during training.
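The pruning and regrowth counts in Fig. 6(b) are simple functions of the sign changes of θ between snapshots; a small sketch of how such statistics can be computed (our illustration, with our own function name) is:

```python
import torch

def rewiring_stats(theta_prev, theta_curr):
    """Count connections pruned and regrown between two snapshots of theta (cf. Fig. 6(b))."""
    pruned = ((theta_prev > 0) & (theta_curr <= 0)).sum().item()
    regrown = ((theta_prev <= 0) & (theta_curr > 0)).sum().item()
    connectivity = (theta_curr > 0).float().mean().item()  # fraction of active connections
    return pruned, regrown, connectivity
```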
Firing rate and connectivity. Reducing the firing rate is essential for energy-efficient neuromorphic chips (event-driven hardware) [Oster et al., 2007]. Therefore, we explore the impact of synaptic pruning on the firing rate. As shown in Fig. 7, there is a relatively weak correlation between the distribution of the neurons' average firing rates and the connectivity. However, when the connectivity drops to extremely low levels, below 1%, the neural activity is severely suppressed: the peak on the left side shifts significantly, caused by silent neurons, while the right peak, which represents saturated neurons, vanishes.

[Figure 7: The distribution of the average firing rate of post-synaptic neurons in the 1st FC layer under different connectivity (100%, 48.18%, 37.21%, 10.83%, 4.13%, 0.45%). The firing rate of each neuron is averaged over the samples of the whole test dataset and over the time domain.]

4 Conclusion

In this paper, we present Gradient Rewiring, a pruning method for deep SNNs that enables efficient joint learning of connectivity and network weights. It utilizes the competition between pruning and regrowth to strengthen the exploration of network structures, and it achieves the minimal loss of SNN performance on the MNIST and CIFAR-10 datasets so far. Our work reveals that there exists high redundancy in deep SNNs. An interesting future direction is to further investigate and exploit such redundancy to achieve low-power computation on neuromorphic hardware.

A Supporting Lemma

Lemma 1. If there exists a constant C such that ∀k, θ_k ≥ C, then there exists a constant A > 0 such that

    ‖∇_θ L̄‖² ≤ A ⟨∇_θ L, ∇_θ L̄⟩.        (13)

Proof. Since w_k = (s_k/β) log(1 + exp(βθ_k)), we have ∂w_k/∂θ_k = s_k σ(βθ_k), where σ(x) = 1/(1 + exp(−x)) is the sigmoid function. Noting that σ(x) > 0 is monotonically increasing, we have

    ‖∇_θ L̄‖² / ⟨∇_θ L, ∇_θ L̄⟩ = Σ_k (s_k ∂L/∂w_k)² / Σ_k σ(βθ_k)(s_k ∂L/∂w_k)²
      = Σ_k (∂L/∂w_k)² / Σ_k σ(βθ_k)(∂L/∂w_k)²
      ≤ Σ_k (∂L/∂w_k)² / (σ(βC) Σ_k (∂L/∂w_k)²)
      = 1/σ(βC).        (14)

Thus there exists a constant A = 1/σ(βC) such that ‖∇_θ L̄‖² ≤ A ⟨∇_θ L, ∇_θ L̄⟩.

Lemma 1 also suggests that ⟨∇_θ L, ∇_θ L̄⟩ ≥ 0. When ‖∇_θ L̄(θ^t)‖ > 0, the inner product ⟨∇_θ L, ∇_θ L̄⟩ is strictly positive; thus the angle between the accurate gradient ∇_θ L and the redefined gradient ∇_θ L̄ is below 90°.
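A quick numerical sanity check of Lemma 1 (our toy check, not part of the paper) with random ∂L/∂w, random fixed signs, and θ_k ≥ C under the SoftPlus mapping:

```python
import torch

beta, C, k = 2.0, -0.5, 1000
theta = C + torch.rand(k) * 4.0                       # theta_k >= C
sign = (torch.randint(0, 2, (k,)) * 2 - 1).float()    # random fixed +1/-1 signs
grad_w = torch.randn(k)                               # stands in for dL/dw_k

grad_L = grad_w * sign * torch.sigmoid(beta * theta)  # accurate gradient dL/dtheta_k under SoftPlus
grad_Lbar = sign * grad_w                             # redefined gradient (Eq. (5))

inner = torch.dot(grad_L, grad_Lbar)
A = 1.0 / torch.sigmoid(torch.tensor(beta * C))       # the constant from the proof of Lemma 1
assert inner.item() >= 0.0                            # angle between the two gradients is below 90 degrees
assert grad_Lbar.pow(2).sum().item() <= (A * inner).item() + 1e-4   # Eq. (13)
```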
B Detailed Setting of Hyperparameters

The hyperparameters of Grad R and of our implementation of Deep R are listed in Table 3 and Table 4 for reproducing the results in this paper. Note that for Deep R, a zero penalty still introduces low connectivity.

Parameter | Description | CIFAR-10 | MNIST
N | # Epoch | 2048 | 512
- | Batch Size | 16 | 128
T | # Timestep | 8 | 8
τ_m | Membrane Constant | 2.0 | 2.0
u_th | Firing Threshold | 1.0 | 1.0
u_rest | Resting Potential | 0.0 | 0.0
p | Target Sparsity | 95% | 95%
η | Learning Rate | 1e-4 | 1e-4
β₁, β₂ | Adam Decay | 0.9, 0.999 | 0.9, 0.999

Table 3: Choice of hyperparameters for both pruning algorithms.

Dataset | Algorithm | Penalty | Connectivity (%)
CIFAR-10 | Grad R | 0 | 48.35
CIFAR-10 | Grad R | 10⁻⁴ | 37.79
CIFAR-10 | Grad R | 2×10⁻⁴ | 28.41
CIFAR-10 | Grad R | 5×10⁻⁴ | 12.04
CIFAR-10 | Grad R | 10⁻³ | 5.08
CIFAR-10 | Grad R | 2×10⁻³ | 2.35
CIFAR-10 | Grad R | 5×10⁻³ | 0.73
CIFAR-10 | Deep R | 0 | 5.24
CIFAR-10 | Deep R | 5×10⁻⁴ | 1.95
CIFAR-10 | Deep R | 10⁻³ | 1.04
MNIST | Grad R | 5×10⁻³ | 25.71
MNIST | Grad R | 10⁻² | 17.94
MNIST | Grad R | 5×10⁻² | 5.63
MNIST | Grad R | 10⁻¹ | 3.06
MNIST | Grad R | 2×10⁻¹ | 1.38
MNIST | Deep R | 0 | 37.14
MNIST | Deep R | 10⁻² | 13.30
MNIST | Deep R | 5×10⁻² | 4.18
MNIST | Deep R | 2×10⁻¹ | 1.10

Table 4: Choice of penalty term for Grad R (target sparsity p = 0.95) and Deep R (maximum sparsity p = 1 for CIFAR-10, p = 0.99 for MNIST) with the corresponding final connectivity.

C Range of Target Sparsity

Recall that the network weights w are determined by the corresponding signs s and synaptic parameters θ. If the signs are randomly assigned and the distribution of θ_k is given, we can deduce the distribution of w_k. We denote the cumulative distribution function (CDF) of θ_k by F_{θ_k}(θ) := P[θ_k < θ]. Then we obtain

    P[ReLU(θ_k) < θ] = P[θ > max{θ_k, 0}] = { 0,            if θ ≤ 0,
                                              F_{θ_k}(θ),   if θ > 0.        (15)

Since the signs follow a Bernoulli-like distribution (each sign is +1 or −1 with probability 1/2), the CDF of w_k can be derived as

    F_{w_k}(w) := P[w_k < w]
      = P[s_k ReLU(θ_k) < w]
      = Σ_{s ∈ {1, −1}} P[s_k ReLU(θ_k) < w | s_k = s] P[s_k = s]
      = (1/2)(P[ReLU(θ_k) < w] + P[−ReLU(θ_k) < w])
      = (1/2)(P[ReLU(θ_k) < w] + 1 − P[ReLU(θ_k) < −w] − P[ReLU(θ_k) = −w])
      = { (1/2)(1 − F_{θ_k}(−w)),   if w < 0,
          (1/2)(1 − P[θ_k ≤ 0]),    if w = 0,
          (1/2)(1 + F_{θ_k}(w)),    if w > 0.        (16)

Note that F'_{w_k}(w) = (1/2) F'_{θ_k}(|w|), which indicates that the probability density function (PDF) of w_k is an even function. Therefore, the distribution of w is symmetric about w = 0. We want it to be unimodal, which matches our observation of the weight distribution in SNNs (see Fig. 3), so there should not be any peak in F'_{w_k}(w) for w > 0; otherwise, we would get another peak symmetric to it. The requirement that F'_{w_k}(w) = (1/2) F'_{θ_k}(|w|) (w > 0) has no peak is equivalent to requiring that F'_{θ_k}(w) (w > 0) has no peak. As F'_{θ_k}(·) is the prior of θ_k, a Laplacian distribution whose PDF has a single peak at µ, we must constrain µ ≤ 0. This explains why we require the target sparsity p ≥ 0.5.
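The derivation can be checked empirically. The toy Monte Carlo sketch below (ours, not from the paper) samples θ from the Laplacian prior of Eq. (7) with µ taken from Eq. (10), assigns random signs, and confirms that the fraction of pruned connections matches the target sparsity p from Eq. (9) and that the weight distribution is symmetric around zero.

```python
import numpy as np

alpha, p = 1.0, 0.8                      # example scale and target sparsity (p >= 0.5)
mu = np.log(2.0 - 2.0 * p) / alpha       # location parameter from Eq. (10)

rng = np.random.default_rng(0)
theta = rng.laplace(loc=mu, scale=1.0 / alpha, size=1_000_000)   # prior of Eq. (7)
sign = rng.choice([-1.0, 1.0], size=theta.size)                  # random fixed signs
w = sign * np.maximum(theta, 0.0)                                # w_k = s_k ReLU(theta_k)

print((theta < 0).mean())                # ~0.8: matches the target sparsity p via Eq. (9)
print((w > 0).mean(), (w < 0).mean())    # ~0.1 each: symmetric about w = 0, single mode at 0
```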
References

[Bellec et al., 2018a] Guillaume Bellec, David Kappel, Wolfgang Maass, and Robert Legenstein. Deep rewiring: Training very sparse deep networks. In ICLR, 2018.

[Bellec et al., 2018b] Guillaume Bellec, Darjan Salaj, Anand Subramoney, Robert Legenstein, and Wolfgang Maass. Long short-term memory and learning-to-learn in networks of spiking neurons. In NeurIPS, pages 795–805. Curran Associates Inc., 2018.

[Chklovskii et al., 2004] D. B. Chklovskii, B. W. Mel, and K. Svoboda. Cortical rewiring and information storage. Nature, 431(7010):782–788, 2004.

[Deng et al., 2019] Lei Deng, Yujie Wu, Yifan Hu, Ling Liang, Guoqi Li, Xing Hu, Yufei Ding, Peng Li, and Yuan Xie. Comprehensive SNN compression using ADMM optimization and activity regularization. arXiv preprint arXiv:1911.00822, 2019.

[Fang et al., 2020a] Wei Fang, Yanqi Chen, Jianhao Ding, Ding Chen, Zhaofei Yu, Huihui Zhou, Yonghong Tian, and other contributors. SpikingJelly. https://github.com/fangwei123456/spikingjelly, 2020. Accessed: 2020-11-09.

[Fang et al., 2020b] Wei Fang, Zhaofei Yu, Yanqi Chen, Timothée Masquelier, Tiejun Huang, and Yonghong Tian. Incorporating learnable membrane time constant to enhance learning of spiking neural networks. arXiv preprint arXiv:2007.05785, 2020.

[Frankle and Carbin, 2019] Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. In ICLR, 2019.

[Gerstner and Kistler, 2002] Wulfram Gerstner and Werner M. Kistler. Spiking Neuron Models: Single Neurons, Populations, Plasticity. Cambridge University Press, 2002.

[Guo et al., 2020] Wenzhe Guo, Mohammed E. Fouda, Hasan Erdem Yantir, Ahmed M. Eltawil, and Khaled Nabil Salama. Unsupervised adaptive weight pruning for energy-efficient neuromorphic systems. Frontiers in Neuroscience, 14:1189, 2020.

[Gütig and Sompolinsky, 2006] Robert Gütig and Haim Sompolinsky. The tempotron: a neuron that learns spike timing-based decisions. Nature Neuroscience, 9(3):420–428, 2006.

[Huttenlocher and Nelson, 1994] Peter R. Huttenlocher and Ch. A. Nelson. Synaptogenesis, synapse elimination, and neural plasticity in human cerebral cortex. In Threats to Optimal Development: Integrating Biological, Psychosocial, and Social Risk Factors: The Minnesota Symposia on Child Psychology, volume 27, pages 35–54, 1994.

[Kappel et al., 2015] David Kappel, Stefan Habenschuss, Robert Legenstein, and Wolfgang Maass. Network plasticity as Bayesian inference. PLOS Computational Biology, 11(11):1–31, 2015.

[Lee et al., 2020] Chankyu Lee, Syed Shakib Sarwar, Priyadarshini Panda, Gopalakrishnan Srinivasan, and Kaushik Roy. Enabling spike-based backpropagation for training deep neural network architectures. Frontiers in Neuroscience, 14:119, 2020.

[Liu et al., 2018] Chen Liu, Guillaume Bellec, Bernhard Vogginger, David Kappel, Johannes Partzsch, Felix Neumärker, Sebastian Höppner, Wolfgang Maass, Steve B. Furber, Robert Legenstein, and Christian G. Mayr. Memory-efficient deep learning on a SpiNNaker 2 prototype. Frontiers in Neuroscience, 12:840, 2018.

[Liu et al., 2020] Yanchen Liu, Kun Qian, Shaogang Hu, Kun An, Sheng Xu, Xitong Zhan, JJ Wang, Rui Guo, Yuancong Wu, Tu-Pei Chen, et al. Application of deep compression technique in spiking neural network chip. IEEE Transactions on Biomedical Circuits and Systems, 14(2):274–282, 2020.

[Maass, 1997] Wolfgang Maass. Networks of spiking neurons: the third generation of neural network models. Neural Networks, 10(9):1659–1671, 1997.

[Martinelli et al., 2020] Flavio Martinelli, Giorgia Dellaferrera, Pablo Mainar, and Milos Cernak. Spiking neural networks trained with backpropagation for low power neuromorphic implementation of voice activity detection. In ICASSP, pages 8544–8548, 2020.

[Neftci et al., 2016] Emre O. Neftci, Bruno U. Pedroni, Siddharth Joshi, Maruan Al-Shedivat, and Gert Cauwenberghs. Stochastic synapses enable efficient brain-inspired learning machines. Frontiers in Neuroscience, 10:241, 2016.

[Neftci et al., 2019] Emre O. Neftci, Hesham Mostafa, and Friedemann Zenke. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks. IEEE Signal Processing Magazine, 36(6):51–63, 2019.

[Oster et al., 2007] Matthias Oster, Rodney Douglas, and Shih-Chii Liu. Quantifying input and output spike statistics of a winner-take-all network in a vision system. In IEEE International Symposium on Circuits and Systems, pages 853–856, 2007.

[Petzoldt and Sigrist, 2014] Astrid G. Petzoldt and Stephan J. Sigrist. Synaptogenesis. Current Biology, 24(22):1076–1080, 2014.

[Qi et al., 2018] Yu Qi, Jiangrong Shen, Yueming Wang, Huajin Tang, Hang Yu, Zhaohui Wu, and Gang Pan. Jointly learning network connections and link weights in spiking neural networks. In IJCAI, pages 1597–1603, 2018.

[Rathi et al., 2019] Nitin Rathi, Priyadarshini Panda, and Kaushik Roy. STDP-based pruning of connections and weight quantization in spiking neural networks for energy-efficient recognition. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems, 38(4):668–677, 2019.

[Roy et al., 2019] Kaushik Roy, Akhilesh Jaiswal, and Priyadarshini Panda. Towards spike-based machine intelligence with neuromorphic computing. Nature, 575(7784):607–617, 2019.

[Shi et al., 2019] Yuhan Shi, Leon Nguyen, Sangheon Oh, Xin Liu, and Duygu Kuzum. A soft-pruning method applied during training of spiking neural networks for in-memory computing applications. Frontiers in Neuroscience, 13:405, 2019.

[Wu et al., 2018] Yujie Wu, Lei Deng, Guoqi Li, Jun Zhu, and Luping Shi. Spatio-temporal backpropagation for training high-performance spiking neural networks. Frontiers in Neuroscience, 12:331, 2018.

[Wu et al., 2019a] Doudou Wu, Xianghong Lin, and Pangao Du. An adaptive structure learning algorithm for multi-layer spiking neural networks. In 15th International Conference on Computational Intelligence and Security (CIS), pages 98–102, 2019.

[Wu et al., 2019b] Yujie Wu, Lei Deng, Guoqi Li, Jun Zhu, Yuan Xie, and Luping Shi. Direct training for spiking neural networks: Faster, larger, better. AAAI, 33(01):1311–1318, 2019.

[Zenke and Vogels, 2020] Friedemann Zenke and Tim P. Vogels. The remarkable robustness of surrogate gradient learning for instilling complex function in spiking neural networks. bioRxiv, 2020.