SDA: Improving Text Generation with Self Data Augmentation
Ping Yu(1), Ruiyi Zhang(2), Yang Zhao(1), Yizhe Zhang(3), Chunyuan Li(3), Changyou Chen(1)
(1) The State University of New York at Buffalo  (2) Duke University  (3) Microsoft Research, Redmond
pingyu@buffalo.edu

arXiv:2101.03236v1 [cs.CL] 2 Jan 2021

Abstract

Data augmentation has been widely used to improve deep neural networks in many research fields, such as computer vision. However, less work has been done in the context of text, partially due to its discrete nature and the complexity of natural languages. In this paper, we propose to improve the standard maximum likelihood estimation (MLE) paradigm by incorporating a self-imitation-learning phase for automatic data augmentation. Unlike most existing sentence-level augmentation strategies, which are only applied to specific models, our method is more general and can be easily adapted to any MLE-based training procedure. In addition, our framework allows task-specific evaluation metrics to be designed to flexibly control the generated sentences, for example, in terms of controlling vocabulary usage and avoiding nontrivial repetitions. Extensive experimental results demonstrate the superiority of our method on two synthetic and several standard real datasets, significantly improving related baselines.

Introduction

Deep learning models are often considered data-hungry: the more data we have, the better performance we can achieve. In practice, it is usually too laborious and time-consuming to annotate a large amount of training data, and an appropriate data augmentation method can significantly boost the performance of a deep-learning model. Generally speaking, augmenting continuous data such as images is relatively simple, as one can still infer the meaning of an image after it has been flipped, resized, or perturbed with noise. However, for discrete data such as text in the natural-language-processing (NLP) field, data augmentation is more challenging, partly because even a single-word change can alter the meaning of a whole sentence.

Recent advances in NLP advocate adopting pre-trained models to effectively leverage extensive unlabeled data collected from the internet. Such a learning paradigm has achieved remarkable success in generating human-level sentences and articles. Nevertheless, we argue that a large amount of data still benefits the fine-tuning stage, which can be achieved by data augmentation. In fact, data augmentation for text is not new. The traditional way is to replace several words in a sentence with their synonyms (words or phrases with similar meanings) (Zhang, Zhao, and LeCun 2015; Wang and Yang 2015; Kobayashi 2018). However, the effect of such word-level replacement is limited. A better option is sentence-level augmentation, which has been used successfully in many applications. For example, Sennrich, Haddow, and Birch (2015) and Sugiyama and Yoshinaga (2019) used back-translation to generate more training data for neural machine translation (NMT), and Kafle, Yousefhussien, and Kanan (2017) generated question-answer pairs with an LSTM for the Visual Question Answering (VQA) task. A noticeable drawback of these sentence-level text augmentation methods, however, is their lack of universality, in the sense that they are only applicable to specific models.

In this paper, we overcome this limitation by proposing a self-imitation-learning based method to generate augmented text data. Our method contains two learning phases:
i) The MLE Training stage: the model starts from regular MLE-based training until convergence.
ii) The Self Data Augmentation stage: this phase maintains a buffer DB to store the top-k ranked generated samples as self-augmented data. During training, it also updates a discriminator to alternately match the model-generated data with the self-augmented data and with the training data.

The rationality of our self-data-augmentation stage is supported by a proposed self-inverse-reinforcement-learning framework, which reveals that the alternating matching in the discriminator leads to learning a policy that prefers generating higher-quality text. The major contributions of this work are summarized in three aspects:
• We propose a flexible sentence-level text augmentation technique that can be combined with any MLE-based text generation method to improve model performance.
• We show how to control the generated data by defining task-specific evaluation metrics to rank the self-augmented data in the buffer, e.g., controlling vocabulary usage and avoiding nontrivial repetition.
• We evaluate our model on both unconditional and conditional generation tasks, achieving significant improvements over the baselines. We further conduct human evaluations, grammar check evaluations, and memorization evaluations to demonstrate our method's superiority. Ablation studies are conducted to investigate the effects of varying the training-set size and of removing the imitation-learning component.
Preliminaries and Related Works

Text augmentation

The traditional way of text augmentation mainly focuses on word-level operations. Zhang, Zhao, and LeCun (2015) proposed to select a word according to a geometric distribution and replace it with one of its synonyms. Wang and Yang (2015) found synonyms with k-nearest neighbors using word embeddings. Kobayashi (2018) proposed to replace a word with the output of a bidirectional RNN language model. EDA (Wei and Zou 2019) proposed four operations: synonym replacement, random insertion, random swap, and random deletion. As acknowledged, the performance of word-level text augmentation is limited, because words are replaced in isolation without considering the sentence structure.

There have been some attempts at sentence-level text augmentation, but they were designed for specific applications. For machine translation, Sennrich, Haddow, and Birch (2015) designed back translation; Fadaee, Bisazza, and Monz (2017) generated new sentence pairs containing rare words in new, synthetically created contexts; and Zheng et al. (2019) designed a mirror architecture to augment translation pairs from non-parallel data. For the Visual Question Answering (VQA) task, Kafle, Yousefhussien, and Kanan (2017) generated question-answer pairs using an LSTM. It is worth noting that the aforementioned methods are not universal in terms of models. By contrast, our framework is a universal sentence-level method for text augmentation, which can be easily combined with popular MLE-based training methods. Besides, our augmentation process can be easily controlled (e.g., controlling vocabulary usage and avoiding nontrivial repetitions in the generation) by incorporating user-defined metrics.

Text generation as reinforcement learning

Text generation can be cast as a reinforcement-learning problem. To generate a text sequence, neural language models (Mikolov et al. 2010) typically adopt an auto-regressive manner. Given a dataset D = {y^{(i)}_{1:T}}, denote y^{(i)}_{1:t-1} as all tokens before token t, where T is the length of sentence y^{(i)}. The standard MLE approach trains the model p_θ by minimizing

\mathcal{L}_{\mathrm{MLE}}\big(p_\theta, y^{(i)}\big) = -\sum_{t=1}^{T} \log p_\theta\big(y_t^{(i)} \mid y_{1:t-1}^{(i)}\big).

This sequential text generation process was reformulated as reinforcement learning in SeqGAN (Yu et al. 2017). In the following, we omit the sentence index "i" when the context is clear. At timestep t, the state s is the currently produced tokens y_{1:t-1}, and the action a is the next token y_t to be selected. The policy model (or generator), denoted G_θ(y_t | y_{1:t-1}), is usually stochastic and defines the probability of generating the next token given the previous ones. The state transition is deterministic after an action has been chosen, i.e., δ^a_{s,s'} = 1 for the next state s' = y_{1:t} if the current state is s = y_{1:t-1} and the action is a = y_t, and δ^a_{s,s''} = 0 for any other next state s''. In these methods, a discriminator network, denoted D_φ, is usually learned to provide the reward for the REINFORCE algorithm (Williams 1992). The discriminator can only compute a reward for a complete sentence rather than a short-term reward, which is a problem for long-text generation. Yu et al. (2017) designed an N-time Monte Carlo search (MC) to calculate a roll-out reward. Specifically, they apply Monte Carlo search with a roll-out policy G_θ to sample the unknown last T - t tokens, representing the N-time Monte Carlo search as

\{y_{1:T}^{(1)}, \ldots, y_{1:T}^{(N)}\} = \mathrm{MC}^{G_\theta}(y_{1:t}; N), \tag{1}

where y_{1:t} = (y_1, \ldots, y_t) and y^{(n)}_{t+1:T} is sampled from the generator policy G_θ given the current state. To reduce variance and obtain a more accurate assessment of the action value, the roll-out policy is run from the current state to the end of the sequence N times, yielding a batch of output samples. The roll-out reward Q^{G_θ}_{D_φ} is calculated as

Q_{D_\phi}^{G_\theta}(s = y_{1:t-1}, a = y_t) =
\begin{cases}
\frac{1}{N}\sum_{n=1}^{N} D_\phi\big(y_{1:T}^{(n)}\big), & t < T,\\
D_\phi(y_{1:t}), & t = T.
\end{cases} \tag{2}

The objective function w.r.t. the generator parameters θ is

J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}\big[Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t)\big], \tag{3}

and the gradient is calculated through the policy gradient:

\nabla_\theta J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}\big[\nabla_\theta \log G_\theta(y_t \mid y_{1:t-1}) \cdot Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t)\big]. \tag{4}
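To make Eqs. (1)-(2) concrete, the following is a minimal sketch of the N-time Monte Carlo roll-out reward. The callables sample_next (the roll-out policy G_θ) and discriminate (the discriminator D_φ) are hypothetical stand-ins introduced here for illustration, not the authors' implementation.

```python
# Minimal sketch of the roll-out reward in Eqs. (1)-(2), under assumed interfaces:
# sample_next(prefix) returns the next token id; discriminate(sentence) returns a
# probability that the finished sentence is real.
import random

def rollout_reward(prefix, T, N, sample_next, discriminate):
    """Estimate Q_{D_phi}^{G_theta}(s=prefix[:-1], a=prefix[-1]) for t < T."""
    total = 0.0
    for _ in range(N):
        y = list(prefix)
        while len(y) < T:                  # complete the last T - t tokens with the roll-out policy
            y.append(sample_next(y))
        total += discriminate(y)           # D_phi scores the completed sentence
    return total / N                       # average over N Monte Carlo roll-outs

# toy stand-ins so the sketch runs end-to-end
vocab = list(range(100))
print(rollout_reward([3, 17, 42], T=20, N=16,
                     sample_next=lambda y: random.choice(vocab),
                     discriminate=lambda y: random.random()))
```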
LeakGAN (Guo et al. 2018) improved the SeqGAN framework by leaking feature information from the discriminator to the generator and treating it as additional guidance for training the generator.

Self-Improved Text Augmentation

Imitation learning is a framework for training a policy to imitate training data. The recently developed self-imitation learning (Oh et al. 2018) explores the possibility of using generated data to further improve a model. In this paper, we leverage these ideas and propose to use self-augmented data to improve MLE-based models through data augmentation.

Specifically, in text generation we treat an MLE-based model, p_θ(y_t | y_{1:t-1}), as a generator network G_θ(y_t | y_{1:t-1}), which learns a policy to generate the next word given the previous words. Example generators include LSTM-based, CNN-based, and Transformer-based models. On top of this, we add a discriminator network D_φ to provide rewards, making the setting applicable to reinforcement learning. Based on this setting, our self-data-augmentation method consists of two training stages: the first stage is general MLE-based model training (the MLE Training stage), and the second stage further refines the MLE-based model with self-augmented data (the Self Data Augmentation stage).
Note that we typically add the second stage after the MLE training has converged, to further improve the MLE-based model. Below we explain the two stages in detail; a formal justification of the second stage is given in the next section.

The MLE Training stage

In this stage, the generator G_θ(y_t | y_{1:t-1}) and the discriminator D_φ are trained separately. The generator is trained with the standard MLE mechanism, and the discriminator is trained in a GAN framework (Goodfellow et al. 2014). Specifically, the generator is trained with the objective

J_g(\pi) = \sum_{t=1}^{T} \log G_\theta(y_t \mid y_{1:t-1}). \tag{5}

The discriminator acts as a classifier discriminating generated sentences from ground-truth training sentences, with the objective

J_d = \mathbb{E}_{y_{1:T} \sim p_{\mathrm{data}}}\big[\log D_\phi(y_{1:T})\big] + \mathbb{E}_{y_{1:T} \sim G_\theta}\big[\log\big(1 - D_\phi(y_{1:T})\big)\big], \tag{6}

where p_data is the training data distribution and G_θ is the generated-sentence distribution. In a large-model setting, one can use a pre-trained MLE-based model and only train the discriminator with its generated sentences and the training data.
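As an illustration, the following PyTorch sketch shows one optimization step for each of the two Stage-I objectives, Eqs. (5) and (6). The generator and discriminator interfaces (per-step logits over the vocabulary, and a real/fake probability per sentence) are assumptions made for the sketch, not the paper's code.

```python
# Sketch of the Stage-I objectives, Eqs. (5)-(6), under assumed model interfaces.
import torch
import torch.nn.functional as F

def mle_loss(generator, real_tokens):
    """Eq. (5): negative log-likelihood of real sentences under G_theta."""
    inputs, targets = real_tokens[:, :-1], real_tokens[:, 1:]
    logits = generator(inputs)                            # (batch, T-1, vocab)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1))

def discriminator_loss(discriminator, real_tokens, fake_tokens):
    """Eq. (6): binary cross-entropy between training and generated sentences."""
    d_real = discriminator(real_tokens)                   # probabilities in (0, 1)
    d_fake = discriminator(fake_tokens)
    return -(torch.log(d_real + 1e-8).mean()
             + torch.log(1.0 - d_fake + 1e-8).mean())
```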
The Self Data Augmentation stage

Despite its simplicity, an MLE-trained model usually generates low-quality text in practice. A natural idea for improvement is to select some high-quality sentences from the generated data as augmented data to enlarge the original training dataset. However, we note two issues with this approach: i) it is hard to control the percentage of augmented data mixed with the training data; and ii) directly mixing augmented data with training data might have an adverse effect, since some of the augmented data can be of low quality. To overcome these issues, we propose to separate the augmented and training data by using a buffer DB to store the self-augmented data. In this way, i) we can conveniently control the percentage of augmented data at any time, and ii) low-quality augmented data does not directly affect the training data.

To apply the imitation-learning framework in our setting, we need to address the issue that imitation learning assumes a single data source. If we were to mix parts of the training data and the self-augmented data into a minibatch at every epoch, we found training to be unstable, because the data in one small minibatch is not representative enough. To overcome this, we propose to divide the Self Data Augmentation stage into two sub-steps, which alternately update the discriminator using the training data (the imitating training data step) and the self-augmented data in the buffer DB (the imitating self-augmented data step), respectively.

Remark 1  Although these two steps in the Self Data Augmentation stage may seem heuristic, we show that they are actually derived from a more principled view of a self-imitation-learning framework, detailed in the next section.

Before describing the two steps, we first define the reference metric used to select generated sentences for the self-augmented data in the buffer DB. Examples of specific metrics are given in the experiment section.

Definition 1 (Reference Metric)  A reference metric is a function, m : S → R, mapping a sentence s ∈ S to a real value measuring the goodness of the sentence.

The imitating self-augmented data step  In this step, we first update the buffer DB with the generated sentences that have the highest reference-metric values. Specifically, each sentence in the buffer is associated with a reference-metric score. When new data are generated, we replace low-score sentences in the buffer with high-score generated sentences, keeping only the k highest-score sentences, where k is a hyperparameter depending on the size of the training set and the complexity of the task. After this, we assign the self-augmented data label 1 and the currently generated data label 0, and use them to update the discriminator with objective (6) (replacing the data distribution p_data with the self-augmented data distribution p_aug in this setting). As mentioned above, the discriminator provides rewards to the generator that measure the quality of the currently generated data. If a generated sentence is similar to sentences in the buffer DB, the discriminator outputs a larger value, so a larger reward is fed back to the generator, which is then updated in a reinforcement-learning manner. Specifically, given a currently generated sentence y_{1:T} ~ G_θ, a reward Q^{G_θ}_{D_φ} is first calculated with (1) and (2). Finally, the generator is updated with the gradient

\nabla_\theta J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}\big[\nabla_\theta \log G_\theta(y_t \mid y_{1:t-1}) \cdot Q_{D_\phi}^{G_\theta}\big]. \tag{7}

The imitating training data step  This step is the same as the previous step, except that the data fed to the discriminator with label 1 is the training data (instead of the self-augmented data), while the currently generated data again has label 0, through (6). After training the discriminator, the reward is calculated through (1) and (2) and used to update the generator with objective (7). Replacing the augmented data in the imitating self-augmented data step with training data makes the generator imitate the training data.

The whole process  To summarize, the Self Data Augmentation stage contains two steps: the imitating self-augmented data step imitates the self-augmented data, while the imitating training data step imitates the training data. These two steps run alternately. In the beginning, we run more imitating training data steps and gradually increase the proportion of imitating self-augmented data steps. In our experiments, we run six imitating training data steps at the beginning; after updating the discriminator and generator, we run one imitating self-augmented data step. Every ten epochs, we decrease the number of imitating training data steps by one until it reaches one. The whole procedure is described in Algorithm 1.
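The buffer bookkeeping in the imitating-self-augmented-data step amounts to keeping the k best-scoring generated sentences under the chosen reference metric. The sketch below is one simple way to do that; the heap-based implementation and the toy metric are illustrative choices, not necessarily the authors' code.

```python
# Sketch of the top-k self-augmented buffer D_B, ranked by a reference metric m(.).
import heapq

class SelfAugmentedBuffer:
    def __init__(self, k):
        self.k = k
        self.heap = []                         # min-heap of (score, id, sentence)
        self._counter = 0                      # tie-breaker so heapq never compares sentences

    def update(self, sentences, reference_metric):
        for s in sentences:
            item = (reference_metric(s), self._counter, s)
            self._counter += 1
            if len(self.heap) < self.k:
                heapq.heappush(self.heap, item)
            else:                              # evict the current lowest-scoring sentence
                heapq.heappushpop(self.heap, item)

    def sample(self):
        return [s for _, _, s in self.heap]

buffer = SelfAugmentedBuffer(k=3)
buffer.update([["a", "red", "bus"], ["bus", "bus", "bus"], ["a", "dog", "runs"], ["the", "the"]],
              reference_metric=lambda s: len(set(s)) / len(s))   # toy metric for the example
print(buffer.sample())
```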
Algorithm 1: Self Data Augmentation
Require: i) generator network G_θ; ii) discriminator network D_φ; iii) training dataset D_E; iv) reference metric m(s); v) buffer D_B to store self-augmented data.
1: Initialize π_θ, r_φ; initialize the buffer D_B ← ∅
2: # Stage I: MLE Training stage
3: repeat
4:   Train the generator G_θ via MLE with (5)
5:   Train the discriminator D_φ with (6)
6: until convergence
7: # Stage II: Self Data Augmentation stage
8: repeat
9:   Generate a set of sentences W from policy G_θ; update the buffer D_B by choosing the sentences from W with the highest reference-metric values
10:  if imitating self-augmented data step then
11:    Sample sentences S ∼ π_θ and S_B ∼ D_B
12:    Update the discriminator D_φ using S and S_B with (6)
13:  else if imitating training data step then
14:    Sample sentences S ∼ π_θ and S_E ∼ D_E
15:    Update the discriminator D_φ using S and S_E with (6)
16:  end if
17:  Update the reward Q^{G_θ}_{D_φ} with (1) and (2)
18:  Update the generator G_θ with (7)
19: until convergence
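The following Python sketch condenses Stage II of Algorithm 1 into a training loop. All helpers are hypothetical stand-ins for the paper's components: generate(n) samples n sentences from G_θ, train_discriminator(pos, neg) runs Eq. (6) with pos labeled 1 and neg labeled 0, rollout_reward(y) evaluates Eqs. (1)-(2), and policy_gradient_step applies Eq. (7). The schedule mirrors the one described above (six imitating-training-data steps at the start, reduced by one every ten epochs).

```python
# Condensed sketch of Stage II of Algorithm 1, under assumed helper interfaces.
def self_data_augmentation_stage(buffer, buffer_metric, training_data, generate,
                                 train_discriminator, rollout_reward,
                                 policy_gradient_step, epochs=100, init_train_steps=6):
    for epoch in range(epochs):
        # refresh the top-k buffer D_B with the best newly generated sentences
        buffer.update(generate(64), reference_metric=buffer_metric)

        # imitating training data steps: start with six, drop one every ten epochs
        for _ in range(max(1, init_train_steps - epoch // 10)):
            fake = generate(64)
            train_discriminator(pos=training_data.sample(64), neg=fake)
            policy_gradient_step(fake, [rollout_reward(y) for y in fake])

        # one imitating self-augmented data step: the buffer plays the "real" side
        fake = generate(64)
        train_discriminator(pos=buffer.sample(), neg=fake)
        policy_gradient_step(fake, [rollout_reward(y) for y in fake])
```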
Understanding the Self Data Augmentation Stage as IRL with Self Enhancement

In this section, we show that the Self Data Augmentation stage, which alternately updates the discriminator with the two data sources, can be explained as a procedure for solving an inverse reinforcement learning (IRL) problem with self-enhancement. The goal of IRL is to recover the optimal reward function r*(·,·) as well as the corresponding policy π* (Ho and Ermon 2016a):

\{r^*, \pi^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L(\pi, r) - \psi(r)
= \arg\max_{r \in \mathbb{R}^{S \times A}} \sum_{s,a} \rho_E(s,a)\, r(s,a) - \Big[\max_{\pi \in \Pi} H(\pi) + \sum_{s,a} \rho(s,a)\, r(s,a)\Big] - \psi(r), \tag{8}

where ρ(s, a) is the occupancy measure, defined as the stationary joint distribution of (s, a) under policy π, and ψ(·) is a regularizer.

The traditional regularization in IRL only constrains the model capacity of the policy (Ho and Ermon 2016b), which makes it hard to adapt to diverse tasks. Natural languages have intrinsic characteristics based on syntax and semantics. To leverage this prior knowledge, we propose incorporating user-defined evaluation measures into ψ(r), so that the learned policy aligns well with the pre-defined metrics when generating unseen sentences. Specifically, we define ψ(r) in (8) as a mixture of a capacity regularizer ψ_1(r) and a self-guided regularizer ψ_0(r):

\psi(r) = -\lambda \psi_0(r) + \psi_1(r), \tag{9}

with λ > 0 a hyper-parameter balancing the two terms.

Capacity regularizer  It has been shown (Ho and Ermon 2016b) that an adversarial regularizer can provide better generalization. To imitate the expert policy, we adopt the capacity regularizer ψ_1(r) of GAIL (Ho and Ermon 2016b); its specific form is given in the Appendix.

Self-guided regularizer  To promote generating suboptimal demonstrations that are well aligned with diverse tasks, we propose the self-guided regularizer as a weighted linear function over rewards:

\psi_0(r) = \sum_{s,a} q(s,a)\, r(s,a), \tag{10}

where q(s, a) is a probability function constructed via evaluation metrics; its specific form is defined in the Appendix. Generally, this regularizer encourages higher rewards in regions where q(s, a) is large.

By plugging the capacity regularizer and the self-guided regularizer into (8), we formulate our objective as an optimization problem for the optimal reward r* and the optimal policy π*:

\{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} -2H(\pi) - \psi_1(r) + \sum_{s,a} q(s,a)\, r(s,a) + \sum_{s,a} \rho_E(s,a)\, r(s,a) - 2\sum_{s,a} \rho(s,a)\, r(s,a), \tag{11}

where we have set λ = 1 in (9), and the factor "2" is introduced for notational convenience. Problem (11) is called IRL with self enhancement.

Solving (11) with a two-step algorithm  Due to the introduction of q(s, a), solving (11) directly is difficult. We reformulate (11) with an augmented policy π_1:

\min_{\pi_1, \pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big) r(s,a) - H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big) r(s,a), \quad \text{s.t. } r_1 = r,\ \pi_1 = \pi. \tag{12}

Objective (12) now allows us to decompose the original problem into two sub-problems. Adopting ideas from ADMM (Boyd et al. 2011), we decompose (12) into two sub-problems, which serve as the objectives of the imitating self-augmented data step (objective (13)) and the imitating training data step (objective (14)) in our data-augmentation framework:

\min_{\pi_1 \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_1(\pi_1, r_1) \triangleq -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big) r(s,a) + \lambda W_1(\pi, \pi_1), \tag{13}

\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_2(\pi, r) \triangleq -H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big) r(s,a) + \lambda W_1(\pi, \pi_1). \tag{14}

The regularizer W_1(π, π_1) denotes the 1-Wasserstein distance and is introduced to accommodate the constraint π_1 = π in the original problem (12). The two phases (13) and (14) are solved alternately. More details are given in the Appendix.
Experiments

We conduct experiments on two synthetic datasets and two real datasets for the unconditional generation task. In unconditional generation, the input to the generator G_θ is pure Gaussian noise, and the generated sentences should match the training data distribution. We further test our method on a conditional generation task, unsupervised style transfer, detailed in the Appendix.

Experimental settings

Model architectures and baselines  We use two model bases. i) In the first, both the generator and the discriminator are RNN architectures. We test this model base with only MLE training (denoted RNN). For the ablation study, we implement this model base with the Imitation Learning stage but with only the imitating training data step and no augmented data, denoted RNN+RL. Our method combined with this model base is denoted RNN+SDA. Since EDA (Wei and Zou 2019) can be applied with the general MLE training paradigm, we also compare with RNN+EDA. In addition, Li et al. (2019) and Roller et al. (2020) demonstrate the usefulness of unlikelihood training for controlling vocabulary usage and nontrivial repetition, which can be easily combined with MLE-trained models; we therefore compare our SDA method with unlikelihood training, denoted RNN+UN, in the vocabulary-usage-control and repetition-control experiments. ii) The second model base adapts the "leak" architecture from LeakGAN (Guo et al. 2018). The discriminator D_φ consists of a feature extractor and an output layer (a softmax classification layer) whose output lies in [0, 1]. The generator is a hierarchical RNN that uses sentence features to generate the next word. We call this model base with only MLE training LeakRNN, and we implement LeakRNN+RL, LeakRNN+SDA, LeakRNN+EDA, and LeakRNN+UN in the same way as for the first model base.

Datasets  We use two synthetic datasets and several real datasets: the COCO image-caption dataset and the EMNLP2017 WMT dataset* for unconditional generation, and the Yelp review dataset for the conditional generation experiments in the Appendix. The two synthetic datasets are generated following (Yu et al. 2017; Guo et al. 2018). Table 9 in the Appendix summarizes the statistics of all datasets used in this paper.

* http://statmt.org/wmt17/translation-task.html
Synthetic-data experiments

To better verify the effectiveness of our method and track the training process, we adapt the synthetic-data experiments of (Yu et al. 2017; Guo et al. 2018). We use a randomly initialized LSTM as an oracle model to generate the "real" data distribution p(x_t | x_1, ..., x_{t-1}), which plays the role of a human observer for real-world problems. We then use the oracle NLL to track the performance of the model during training, defined as

\mathrm{NLL}_{\mathrm{oracle}} = -\mathbb{E}_{y_{1:T} \sim G_\theta}\Big[\sum_{t=1}^{T} \log G_{\mathrm{oracle}}(y_t \mid y_{1:t-1})\Big], \tag{15}

where G_θ and G_oracle denote our generator and the oracle network, respectively. During training, we use G_θ to generate 100,000 sentences and calculate NLL_oracle every 5 training epochs.

In this experiment, we first run the MLE Training stage for 100 epochs and store the generated data in the buffer DB. We then update our discriminator and generator using the training data and the self-augmented data selected by the reference metric

m(s) \triangleq \mathbb{E}_{y_{1:T} \sim G_\theta}\Big[\sum_{t=1}^{T} \log G_\theta(y_t \mid y_{1:t-1})\Big]. \tag{16}

As shown in Figs. 1a and 1b, both the RNN and LeakRNN models converge before the 100th epoch. In stage 2, imitation learning (both with and without self-augmentation) further improves model performance. However, we observed that RNN+RL makes training less stable, and this phenomenon becomes more obvious as training becomes more complicated. We suspect that learning a policy for text is challenging due to the complexity of text itself; adding our self-augmented step enlarges the training dataset, making the policy easier to learn.

Figure 1: Training curves on the synthetic dataset with sentence length 20 (a) and length 40 (b). The blue curve before epoch 100 is the first model base, RNN; after epoch 100, the solid blue curve is RNN+SDA and the dotted blue curve is RNN+RL. The yellow curve before epoch 100 is the second model base, LeakRNN; after epoch 100, the solid yellow curve is LeakRNN+SDA and the dotted yellow curve is LeakRNN+RL. (c) Test BLEU score against the training dataset: a lower value on the y-axis means the generated sentences achieve higher accuracy, and a lower value on the x-axis indicates a stronger ability to generate new sentences rather than memorize training sentences.
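As an illustration of the tracking metric in Eq. (15), the PyTorch sketch below scores sentences sampled from the trained generator under a fixed oracle language model. The oracle interface (per-step logits from token inputs) is an assumption made for the sketch.

```python
# Sketch of NLL_oracle in Eq. (15), under an assumed oracle LM interface.
import torch
import torch.nn.functional as F

@torch.no_grad()
def nll_oracle(oracle_logits_fn, sampled_tokens):
    """sampled_tokens: (batch, T) token ids drawn from the trained generator."""
    inputs, targets = sampled_tokens[:, :-1], sampled_tokens[:, 1:]
    logits = oracle_logits_fn(inputs)                     # (batch, T-1, vocab)
    log_probs = F.log_softmax(logits, dim=-1)
    token_ll = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -token_ll.sum(dim=1).mean().item()             # average NLL over the batch
```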
Real data experiments

For the real-data experiments, we run both the MLE Training stage and the Self Data Augmentation stage for 100 epochs. At the Imitation Learning stage, generated sentences are stored in the buffer DB. The BLEU score (Papineni et al. 2002) is calculated between sentences in the buffer DB and 1,000 randomly selected samples from the training dataset, and we use the BLEU3 score as the reference metric m(·) to construct the buffer.

In our real-data experiments, directly applying the policy gradient makes the model vulnerable and unstable, consistent with the synthetic results, so the point at which we run the test has a significant impact on the results. Monitoring the BLEU score in real time is time-consuming and laborious due to the large test set, so we examine the BLEU score every ten epochs in the Imitation Learning stage and pick the best results for RNN+RL and LeakRNN+RL. Since our method is stable, we only test the BLEU score after 100 epochs for RNN+SDA and LeakRNN+SDA. In addition to the results presented below, example generated sentences are given in the Appendix.

Short-text generation  The original COCO dataset consists of 20,734 words and 417,126 sentences in total. Considering the large gap between the average length (11) and the maximum length (50) of the sentences in this dataset, we only randomly selected sentences with length less than 32, ending up with a training set of 120,000 sentences and a test set of 10,000 sentences. Results are shown in Table 1. Although we do not pick the best result over training epochs, our self-augmentation method (SDA) still surpasses the other baselines on all metrics, demonstrating the effectiveness of our data-augmentation scheme.

Table 1: BLEU scores on the COCO image-caption test set.
Method | BLEU2 ↑ | BLEU3 ↑ | BLEU4 ↑ | BLEU5 ↑
RNN | 0.800 | 0.599 | 0.367 | 0.201
RNN+RL | 0.819 | 0.601 | 0.370 | 0.211
RNN+EDA | 0.822 | 0.610 | 0.373 | 0.220
RNN+SDA (ours) | 0.826 | 0.612 | 0.378 | 0.222
LeakRNN | 0.916 | 0.789 | 0.594 | 0.404
LeakRNN+RL | 0.921 | 0.797 | 0.602 | 0.416
LeakRNN+EDA | 0.932 | 0.838 | 0.691 | 0.526
LeakRNN+SDA (ours) | 0.943 | 0.856 | 0.738 | 0.599

Long-text generation  To increase the difficulty of text generation, we use whole sentences (with a maximum length of 50) without discarding low-frequency words on the EMNLP2017 WMT dataset. The experimental settings are the same as those used for the COCO dataset. Results are presented in Table 2; our method outperforms the other baselines.

Table 2: BLEU scores on the EMNLP2017 test set.
Method | BLEU2 ↑ | BLEU3 ↑ | BLEU4 ↑ | BLEU5 ↑
RNN | 0.626 | 0.350 | 0.160 | 0.085
RNN+RL | 0.630 | 0.354 | 0.164 | 0.087
RNN+EDA | 0.639 | 0.361 | 0.172 | 0.106
RNN+SDA | 0.641 | 0.368 | 0.176 | 0.109
LeakRNN | 0.914 | 0.692 | 0.461 | 0.280
LeakRNN+RL | 0.916 | 0.707 | 0.469 | 0.289
LeakRNN+EDA | 0.921 | 0.738 | 0.509 | 0.312
LeakRNN+SDA | 0.929 | 0.744 | 0.519 | 0.330

Analysis and Discussion

To further evaluate the effectiveness and usefulness of our method, we carry out the following analyses on the more challenging EMNLP WMT dataset, which has longer sentences and contains much richer sentence structure and syntax than the COCO dataset. This section first reports a human evaluation on Amazon Mechanical Turk, a grammar check evaluation, and a memorization evaluation. We then carry out ablation studies to explore the effect of different training-data sizes and the necessity of imitation learning. Finally, we design two task-specific reference metrics to control vocabulary usage and nontrivial repetition.

Human evaluations

To further evaluate generation quality, we conduct a human evaluation on Amazon Mechanical Turk: 3 judges rate 100 randomly sampled sentences from each model on a scale from 1 to 5. The results in Table 3 indicate the strong performance of our proposed method.

Table 3: Human evaluation results.
Method | LeakRNN | LeakRNN+RL | LeakRNN+SDA | Real
Human scores ↑ | 2.97 | 2.67 | 3.11 | 4.10

Grammar check evaluations

We employ the grammar check tool† for the grammar check evaluation. It is professional software widely adopted for checking English writing in terms of contextual spelling, grammar, punctuation, sentence structure, style, and vocabulary enhancement, measured by the critical issue rate (CIR). The results are shown in Table 4; please refer to the Appendix for more details.

† https://www.grammarly.com/

Table 4: Grammar evaluation results.
Model | Ave Len ↑ | CIR ↓
Training Dataset | 27.526 | 2.583
LeakRNN | 23.990 | 9.880
LeakRNN+RL | 26.280 | 12.929
LeakRNN+SDA | 27.850 | 9.874

Memorization evaluations

This experiment compares the generalization ability of different methods. BigGAN (Brock, Donahue, and Simonyan 2018) carried out a nearest-neighbor analysis on ImageNet (Deng et al. 2009) to show that their generator does not merely memorize data from the training set, which is a direct reflection of a model's generalization ability. Analogously, we compare the BLEU score between generated sentences and the training dataset for LeakRNN, LeakRNN+RL, and our method. Figure 1c shows that our model is better at generating sentences rather than memorizing them directly from the training dataset.

Improvement with different training data sizes

To test our method's performance with different training-set sizes, we split the EMNLP2017 training dataset into 5%, 10%, and 50% subsets and compare our method with the MLE baseline LeakRNN. The results are shown in Table 7. Interestingly, with our data-augmentation scheme, our method boosts performance over the model base more significantly on smaller training sets. We speculate that, since overfitting tends to be more severe when training on smaller datasets, our method, being a data-augmentation method, alleviates the overfitting problem seamlessly.

Table 7: Performance comparison with different training-set sizes.
Size | Model | BLEU2 ↑ | Increase %
5% | LeakRNN | 0.8199 |
5% | LeakRNN+SDA | 0.8537 | 4.12%
10% | LeakRNN | 0.8551 |
10% | LeakRNN+SDA | 0.8746 | 2.28%
50% | LeakRNN | 0.9024 |
50% | LeakRNN+SDA | 0.9181 | 1.74%
100% | LeakRNN | 0.9144 |
100% | LeakRNN+SDA | 0.9290 | 1.60%

The necessity of imitation learning

We conduct experiments to verify the necessity of using imitation learning to mimic both the training data and the self-augmented data. Instead of using imitation learning to update the generator alternately, we mix the training data with 5% of the buffer DB at each epoch and use the mixed data to update the generator after the MLE Training stage. Mixup (Zhang et al. 2017) has been successfully applied to data augmentation for images; we also tried this method, but it did not outperform EDA (Wei and Zou 2019) for text augmentation, so we do not include its results in this article. In our comparison, LeakRNN achieves a 0.916 BLEU2 score, the model whose generator is updated with mixed data achieves 0.919, and our LeakRNN+SDA model achieves 0.943. These results indicate the benefit of using imitation learning.

Vocabulary usage control

It has been observed that generative models tend to generate common words too frequently and rare words too infrequently compared with the human distribution (Holtzman et al. 2018; Welleck et al. 2019; Li et al. 2019; Roller et al. 2020). We first analyze our training data distribution, which contains over 7 million words with 5,722 unique tokens in the dictionary, so each token appears 1,260 times on average. However, 20 tokens occur fewer than ten times, and 43 occur fewer than 20 times.

Our model can be used to control vocabulary usage during training, thanks to the flexibility of defining a specific metric to select generated sentences. To this end, we collect tokens with fewer than 50 occurrences into a special rare-word list C_rare. We define a new reference metric as

m(s) \triangleq \mathrm{BLEU}(s) + 1/o^2, \tag{17}

where o is the number of occurrences in the training dataset of a word in C_rare. For example, the word "project" occurs 20 times in the training dataset, so once a generated sentence contains the word "project", we add 0.0025 to its reference metric; if a generated sentence contains two rare words, we add a 1/o^2 term twice.

As a comparison, we adopt unlikelihood training (Welleck et al. 2019; Li et al. 2019; Roller et al. 2020) and combine it with the LeakRNN model (denoted LeakRNN+UN) to control vocabulary usage with the modified objective

J_g(\pi) = \sum_{t=1}^{T}
\begin{cases}
\log G_\theta(y_t \mid y_{1:t-1}), & y_t \notin C_{\mathrm{rare}},\\
\log G_\theta(y_t \mid y_{1:t-1}) + 1/o, & y_t \in C_{\mathrm{rare}}.
\end{cases}

Table 5 shows that our method can control the appearance of rare words to a certain extent, while also improving the quality of the generated sentences.

Table 5: Vocabulary usage control.
Method | BLEU2 ↑ | Occurrence
LeakRNN | 0.914 | 0.2%
LeakRNN+UN | 0.915 | 0.9%
LeakRNN+SDA | 0.928 | 0.5%

Nontrivial repetition control

A related issue is that generative models also tend to repeat themselves (Holtzman et al. 2019). Roller et al. (2020) use unlikelihood training to penalize 3-gram repetition in generated sentences; we use this model as a baseline, again denoted LeakRNN+UN. We again use Equation (17) as the reference metric, with o now representing the number of repeated 3-gram phrases in a sentence. We adopt the self-BLEU3 score (Zhu et al. 2018) to measure the repetition of 3-gram phrases in the generated sentences.

Table 6 shows that unlikelihood training significantly reduces the repetition rate at the cost of lower generation quality, i.e., it can help control repetition at the risk of saying something that makes less sense. In contrast, our method controls repetition while maintaining generation quality.

Table 6: Repetition control.
Method | BLEU3 ↑ | self-BLEU3 ↓
LeakRNN | 0.914 | 0.838
LeakRNN+UN | 0.912 | 0.800
LeakRNN+SDA | 0.920 | 0.816
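As an illustration of the controllable reference metric in Eq. (17), the sketch below adds a 1/o^2 bonus for every rare word a candidate sentence uses, on top of a BLEU3 score. The use of NLTK's sentence_bleu and the precomputed rare-word count table are assumptions made for this example, not necessarily the authors' implementation.

```python
# Sketch of the reference metric in Eq. (17): BLEU plus a rare-word bonus.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def vocab_control_metric(sentence, references, rare_counts):
    """sentence: list of tokens; references: list of tokenized reference sentences;
    rare_counts: dict mapping each rare word (occurrence < 50) to its count o."""
    bleu = sentence_bleu(references, sentence,
                         weights=(1/3, 1/3, 1/3),              # BLEU3, as in the paper
                         smoothing_function=SmoothingFunction().method1)
    bonus = sum(1.0 / (rare_counts[w] ** 2) for w in sentence if w in rare_counts)
    return bleu + bonus

# e.g. the word "project" occurring o = 20 times contributes 1/20^2 = 0.0025
print(vocab_control_metric(["the", "project", "was", "announced"],
                           [["the", "plan", "was", "announced"]],
                           rare_counts={"project": 20}))
```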
Conclusion

We propose a general and flexible text augmentation method that can be easily incorporated into popular MLE-based methods to boost model performance. Essentially, we add a Self Data Augmentation stage after the standard MLE Training stage. In this stage, our generator alternately imitates training sentences and self-augmented sentences selected by a reference metric. We show formally that doing so encourages the generator to focus on higher-quality sentences via the proposed IRL formulation with self-enhancement. Moreover, by defining different reference metrics, we can emphasize particular characteristics of the generated sentences, e.g., vocabulary usage and nontrivial repetition control. Extensive experiments on both unconditional and conditional text generation verify the effectiveness of the proposed data-augmentation framework for NLP.
References

Ambrosio, L.; Gigli, N.; and Savaré, G. 2005. Gradient Flows in Metric Spaces and in the Space of Probability Measures. Lectures in Mathematics ETH Zürich.
Bahdanau, D.; Brakel, P.; Xu, K.; Goyal, A.; Lowe, R.; Pineau, J.; Courville, A.; and Bengio, Y. 2017. An actor-critic algorithm for sequence prediction. In ICLR.
Boyd, S.; Parikh, N.; Chu, E.; Peleato, B.; and Eckstein, J. 2011. Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers. Foundations and Trends in Machine Learning 3(1): 1–122.
Brock, A.; Donahue, J.; and Simonyan, K. 2018. Large scale GAN training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096.
Caccia, M.; Caccia, L.; Fedus, W.; Larochelle, H.; Pineau, J.; and Charlin, L. 2018. Language GANs falling short. arXiv preprint arXiv:1811.02549.
Deng, J.; Dong, W.; Socher, R.; Li, L.-J.; Li, K.; and Fei-Fei, L. 2009. ImageNet: A Large-Scale Hierarchical Image Database. In CVPR.
Fadaee, M.; Bisazza, A.; and Monz, C. 2017. Data augmentation for low-resource neural machine translation. arXiv preprint arXiv:1705.00440.
Finn, C.; Levine, S.; and Abbeel, P. 2016. Guided Cost Learning: Deep Inverse Optimal Control via Policy Optimization. In ICML.
Goodfellow, I.; Pouget-Abadie, J.; Mirza, M.; Xu, B.; Warde-Farley, D.; Ozair, S.; Courville, A.; and Bengio, Y. 2014. Generative Adversarial Nets. In NIPS.
Guo, J.; Lu, S.; Cai, H.; Zhang, W.; Yu, Y.; and Wang, J. 2018. Long Text Generation via Adversarial Training with Leaked Information. In AAAI.
Ho, J.; and Ermon, S. 2016a. Generative Adversarial Imitation Learning. In Advances in Neural Information Processing Systems 29, 4565–4573.
Ho, J.; and Ermon, S. 2016b. Generative adversarial imitation learning. In NIPS.
Holtzman, A.; Buys, J.; Du, L.; Forbes, M.; and Choi, Y. 2019. The curious case of neural text degeneration. arXiv preprint arXiv:1904.09751.
Holtzman, A.; Buys, J.; Forbes, M.; Bosselut, A.; Golub, D.; and Choi, Y. 2018. Learning to write with cooperative discriminators. arXiv preprint arXiv:1805.06087.
Hu, Z.; Yang, Z.; Liang, X.; Salakhutdinov, R.; and Xing, E. P. 2017. Controllable text generation. In ICML.
Kafle, K.; Yousefhussien, M.; and Kanan, C. 2017. Data augmentation for visual question answering. In Proceedings of the 10th International Conference on Natural Language Generation, 198–202.
Kobayashi, S. 2018. Contextual augmentation: Data augmentation by words with paradigmatic relations. arXiv preprint arXiv:1805.06201.
Li, J.; Jia, R.; He, H.; and Liang, P. 2018. Delete, retrieve, generate: A simple approach to sentiment and style transfer. arXiv preprint arXiv:1804.06437.
Li, M.; Roller, S.; Kulikov, I.; Welleck, S.; Boureau, Y.-L.; Cho, K.; and Weston, J. 2019. Don't Say That! Making Inconsistent Dialogue Unlikely with Unlikelihood Training. arXiv preprint arXiv:1911.03860.
Liu, Q.; Li, L.; Tang, Z.; and Zhou, D. 2018. Breaking the Curse of Horizon: Infinite-Horizon Off-policy Estimation. In NIPS.
Mikolov, T.; Karafiát, M.; Burget, L.; Černocký, J.; and Khudanpur, S. 2010. Recurrent neural network based language model. In Eleventh Annual Conference of the International Speech Communication Association.
Ng, A. Y.; Harada, D.; and Russell, S. J. 1999. Policy Invariance Under Reward Transformations: Theory and Application to Reward Shaping. In ICML.
Ng, A. Y.; and Russell, S. J. 2000. Algorithms for Inverse Reinforcement Learning. In ICML.
Oh, J.; Guo, Y.; Singh, S.; and Lee, H. 2018. Self-imitation learning. arXiv preprint arXiv:1806.05635.
Papineni, K.; Roukos, S.; Ward, T.; and Zhu, W.-J. 2002. BLEU: a method for automatic evaluation of machine translation. In Proceedings of the 40th Annual Meeting of the Association for Computational Linguistics, 311–318.
Ranzato, M.; Chopra, S.; Auli, M.; and Zaremba, W. 2016. Sequence level training with recurrent neural networks. In ICLR.
Rennie, S. J.; Marcheret, E.; Mroueh, Y.; Ross, J.; and Goel, V. 2016. Self-critical sequence training for image captioning. In CVPR.
Roller, S.; Dinan, E.; Goyal, N.; Ju, D.; Williamson, M.; Liu, Y.; Xu, J.; Ott, M.; Shuster, K.; Smith, E. M.; et al. 2020. Recipes for building an open-domain chatbot. arXiv preprint arXiv:2004.13637.
Sennrich, R.; Haddow, B.; and Birch, A. 2015. Improving neural machine translation models with monolingual data. arXiv preprint arXiv:1511.06709.
Shen, T.; Lei, T.; Barzilay, R.; and Jaakkola, T. 2017. Style transfer from non-parallel text by cross-alignment. In NIPS.
Song, J.; Ren, H.; Sadigh, D.; and Ermon, S. 2018. Multi-Agent Generative Adversarial Imitation Learning. In NIPS.
Sudhakar, A.; Upadhyay, B.; and Maheswaran, A. 2019. Transforming delete, retrieve, generate approach for controlled text style transfer. arXiv preprint arXiv:1908.09368.
Sugiyama, A.; and Yoshinaga, N. 2019. Data augmentation using back-translation for context-aware neural machine translation. In Proceedings of the Fourth Workshop on Discourse in Machine Translation (DiscoMT 2019), 35–44.
Wang, W. Y.; and Yang, D. 2015. That's so annoying!!!: A lexical and frame-semantic embedding based data augmentation approach to automatic categorization of annoying behaviors using #petpeeve tweets. In EMNLP, 2557–2563.
Wang, X.; Chen, W.; Wang, Y.-F.; and Wang, W. Y. 2018. No metrics are perfect: Adversarial reward learning for visual storytelling. arXiv preprint arXiv:1804.09160.
Wei, J. W.; and Zou, K. 2019. EDA: Easy data augmentation techniques for boosting performance on text classification tasks. arXiv preprint arXiv:1901.11196.
Welleck, S.; Kulikov, I.; Roller, S.; Dinan, E.; Cho, K.; and Weston, J. 2019. Neural text generation with unlikelihood training. arXiv preprint arXiv:1908.04319.
Williams, R. J. 1992. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8(3-4): 229–256.
Yang, Z.; Hu, Z.; Dyer, C.; Xing, E. P.; and Berg-Kirkpatrick, T. 2018. Unsupervised text style transfer using language models as discriminators. In NeurIPS, 7287–7298.
Yu, L.; Zhang, W.; Wang, J.; and Yu, Y. 2017. SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient. In AAAI.
Zhang, H.; Cisse, M.; Dauphin, Y. N.; and Lopez-Paz, D. 2017. mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412.
Zhang, X.; Zhao, J.; and LeCun, Y. 2015. Character-level convolutional networks for text classification. In NIPS, 649–657.
Zheng, Z.; Zhou, H.; Huang, S.; Li, L.; Dai, X.-Y.; and Chen, J. 2019. Mirror-Generative Neural Machine Translation. In ICLR.
Zhu, Y.; Lu, S.; Zheng, L.; Guo, J.; Zhang, W.; Wang, J.; and Yu, Y. 2018. Texygen: A Benchmarking Platform for Text Generation Models. arXiv preprint arXiv:1802.01886.
Preliminaries about Inverse Reinforcement Learning

An alternative way (Ho and Ermon 2016b; Liu et al. 2018) to write the RL objective (Equation 3) is

J(\pi) = \mathbb{E}_{s_t \sim d_\pi(\cdot),\ y_t \sim G_\theta(y_t \mid y_{1:t-1})}\big[Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t)\big], \quad \text{with } d_\pi(s) \triangleq \lim_{T \to \infty} \Big(\sum_{t=0}^{T} \gamma^{t-1} d_{\pi,t}(s)\Big) \Big/ \Big(\sum_{t=0}^{T} \gamma^{t}\Big), \tag{18}

where d_{π,t}(s) is the distribution of s_t when executing policy π starting from some initial state distribution, and d_π(s) is the state marginal distribution. To unify the symbols of our formulation with those of reinforcement learning, in the following we use r(s_t, a_t) to represent Q^{G_θ}_{D_φ}(y_{1:t-1}, y_t) and a_t ∼ π(·|s_t) to represent y_t ∼ G_θ(y_t | y_{1:t-1}).

The occupancy measure ρ(s, a) is defined as the stationary joint distribution of (s, a) under policy π, i.e., ρ(s, a) ≜ d_π(s) π(a | s). Thus, given an environment, there is a one-to-one correspondence between the policy π and the occupancy measure ρ (Ho and Ermon 2016b). Consequently, RL can be equivalently rewritten as

\pi^* \triangleq \mathrm{RL}(\pi) = \arg\max_{\pi \in \Pi} \mathbb{E}_{(s,a) \sim \rho}[r(s,a)]. \tag{19}

To formulate text generation as RL, we may view the process as an agent taking actions to generate the next words, conditioned on the states of observed words, i.e.,

s_t = w_{<t}, \quad a_t = w_t. \tag{20}

IRL aims to learn the optimal reward function from expert trajectories D_E, which correspond to the training dataset in text generation. Typically, D_E consists of state-action pairs generated from an expert policy π_E, with corresponding occupancy measure ρ_E. The goal of IRL is to recover the optimal reward function r*(·, ·) as well as the corresponding policy π*.‡ Formally,

\{r^*, \pi^*\} \triangleq \mathrm{IRL}(\pi_E) = \arg\max_{r \in \mathbb{R}^{S \times A}} \sum_{s,a} \rho_E(s,a)\, r(s,a) - \Big[\max_{\pi \in \Pi} H(\pi) + \sum_{s,a} \rho(s,a)\, r(s,a)\Big] = \arg\max_{r \in \mathbb{R}^{S \times A}} \min_{\pi \in \Pi} L(\pi, r),

where L(\pi, r) \triangleq -H(\pi) + \sum_{s,a} \rho_E(s,a)\, r(s,a) - \sum_{s,a} \rho(s,a)\, r(s,a).

‡ The original goal of IRL is to find the optimal reward function. In this paper, we generalize the definition to finding both the optimal reward and the optimal policy.

Intuitively, the objective enforces the expert policy π_E to induce higher rewards (the "max" part) than all other policies (the "min" part). Note that this setting is sub-optimal if the given data are noisy, as the learned policy is expected to perform no better than the expert policy (and to equal the expert policy at optimality). Besides, the IRL objective is ill-defined (Ng and Russell 2000; Finn, Levine, and Abbeel 2016) and often admits multiple valid solutions, because the solution space is too flexible. For example, one can assign an arbitrary reward to trajectories not visited by the expert, as long as these trajectories yield lower rewards than the expert trajectories (Song et al. 2018). To address these problems, some constraints are typically placed on the reward function; e.g., a convex reward functional, ψ : R^{S×A} → R, is usually introduced as a regularizer (Ho and Ermon 2016b). Finally, IRL is formulated as

\{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L(\pi, r) - \psi(r). \tag{21}

Self-Improved Inverse RL

In this paper, we advocate that IRL is a more natural framework for our setting. We define the regularizer ψ(r) as a mixture of a capacity regularizer ψ_1(r) and a self-guided regularizer ψ_0(r):

\psi(r) = -\lambda \psi_0(r) + \psi_1(r), \tag{22}

where λ > 0 is a weighting hyper-parameter balancing the two terms.

Capacity Regularizer  It has been shown in (Ho and Ermon 2016b) that an adversarial regularizer can provide better generalization. To imitate the expert policy, we adopt the capacity regularizer ψ_1(r) of GAIL (Ho and Ermon 2016b); its specific form is given later in this appendix.

Self-Guided Regularizer  To promote generating suboptimal demonstrations that are well aligned with diverse tasks, we propose the self-guided regularizer as a weighted linear function over rewards:

\psi_0(r) = \sum_{s,a} q(s,a)\, r(s,a), \tag{23}

where q(s, a) is a probability function constructed via evaluation metrics (defined below). Generally, this regularizer encourages higher rewards in regions where q(s, a) is large. In text generation, the evaluation metrics evaluate certain quality aspects of a generated sentence, such as BLEU or ROUGE scores. Formally, we define the notion of a reference metric.

Definition 2 (Reference Metric)  A reference metric is a function, m : S → R, mapping a sentence s ∈ S to a real value measuring the goodness of the sentence.

For simplicity, we assume that a higher value of m(s) indicates a higher quality of the sentence s. Note that standard IRL with a fixed pool of expert trajectories does not necessarily learn a reward function that matches the reference metric; we conjecture this is because the expert trajectories alone can be noisy and unrepresentative.

To enhance the state space in terms of the reference metric, we focus on a sub-space, S̄ ⊆ S, with high reference-metric values, i.e., S̄ ≜ {s ∈ S : m(s) > m*} for some threshold m*. The concentrated distribution q(s, a) is then defined as

q(s,a) = \frac{1}{Z}\, d_\pi(s)\, \pi(a \mid s)\, \delta(m(s) > m^*) \triangleq \bar d_\pi(s)\, \pi(a \mid s),

where the indicator function δ(m(s) > m*) = 1 if m(s) > m* and 0 otherwise, and Z is the normalizer that makes q a proper distribution. In contrast to the original occupancy measure, ρ(s, a) = d_π(s) π(a | s) for all state-action pairs with s ∈ S, q(s, a) is more concentrated on regions with higher reference-metric values. In other words, q(s, a) is larger than ρ(s, a) because its probability mass is placed only on the sub-space S̄ with high reference-metric values. This method is related to reward shaping (Ng, Harada, and Russell 1999), but in a more principled manner: unlike the traditional reward-shaping strategy, where the policy is directly trained to maximize a certain heuristic and can therefore be noisy, the uncertainty in q(s, a) still allows an automatic and softened exploration-exploitation trade-off.

By plugging (22) and (23) into (21), we formulate our objective as an optimization problem for the optimal reward r* and the optimal policy π*:

\{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} -2H(\pi) - \psi_1(r) + \sum_{s,a} q(s,a)\, r(s,a) + \sum_{s,a} \rho_E(s,a)\, r(s,a) - 2\sum_{s,a} \rho(s,a)\, r(s,a), \tag{24}

where we have set λ = 1 in (22), and the factor "2" is introduced for notational convenience. We reformulate (24) with an augmented policy π_1:

\min_{\pi_1, \pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big) r(s,a) - H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big) r(s,a), \quad \text{s.t. } r_1 = r,\ \pi_1 = \pi. \tag{25}

Objective (25) allows us to decompose the original problem into two sub-problems. Adopting ideas from ADMM (Boyd et al. 2011), we decompose (25) into an imitating self-augmented data step and an imitating training data step:

imitating self-augmented data step:
\min_{\pi_1 \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_1(\pi_1, r_1) \triangleq -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big) r(s,a) + \lambda W_1(\pi, \pi_1), \tag{26}

imitating training data step:
\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_2(\pi, r) \triangleq -H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big) r(s,a) + \lambda W_1(\pi, \pi_1). \tag{27}

The regularizer W_1(π, π_1) above is introduced to accommodate the constraint π_1 = π in the original problem (25). The two phases (26) and (27) are solved alternately. By parametrizing π, π_1, and r with DNNs, they can be updated by solving min-max problems similar to GAIL (Ho and Ermon 2016a). The only difference lies in the W_1 regularizer, for which the gradients with respect to the policy parameters need to be derived. Fortunately, this can be done with the following theorem.

Theorem 1  Let π be parameterized by θ, denoted π_θ. The gradient of W_1(π_θ, π_1) w.r.t. θ can be calculated as

\nabla_\theta W_1(\pi_\theta, \pi_1) = \mathbb{E}_{Y \sim \pi_\theta}\Big[\frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)(Y)\, \nabla_\theta \log \pi_\theta(Y)\Big].

Furthermore, writing W_1(π_θ, π_1) in its dual form,

W_1(\pi_\theta, \pi_1) = \max_{\|D\|_{\mathrm{Lip}} \le 1} \mathbb{E}_{Y \sim \pi_\theta}[D(Y)] - \mathbb{E}_{Y' \sim \pi_1}[D(Y')],

and letting D* denote the optimal discriminator, the gradient can be written as

\nabla_\theta W_1(\pi_\theta, \pi_1) = \mathbb{E}_{Y \sim \pi_\theta}\big[D^*(Y)\, \nabla_\theta \log \pi_\theta(Y)\big] + \mathbb{E}_{Y \sim \pi_\theta,\ Y' \sim \pi_1}\Big[\frac{\partial D^*(Y)}{\partial Y} \nabla_\theta Y - \frac{\partial D^*(Y')}{\partial Y'} \nabla_\theta Y'\Big]. \tag{28}

Theorem 1 applies to the parameterized π_1 as well. It is now clear that, to calculate the gradient of the W_1 term, one i) first draws samples from the two policies to update the discriminator D, and ii) then uses (28) to calculate ∇_θ W_1. This consequently allows us to optimize (26) and (27) by gradient descent. It is worth noting that when the two policies in the W_1 arguments are close enough to each other, the second term on the RHS of (28) becomes relatively small, allowing us to drop it for computational efficiency. The whole algorithm is illustrated in Algorithm 1.
Proof of Theorem 1

The gradient of W_1 can be expanded, based on the definition of the functional derivative, as

\nabla_\theta W_1(\pi_\theta, \pi_1) = \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1) \cdot \nabla_\theta \pi_\theta(Y)\, dY
= \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1) \cdot \frac{\nabla_\theta \pi_\theta(Y)}{\pi_\theta(Y)} \cdot \pi_\theta(Y)\, dY
= \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1) \cdot \nabla_\theta \log \pi_\theta(Y) \cdot \pi_\theta(Y)\, dY
= \mathbb{E}_{Y \sim \pi_\theta}\Big[\frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)(Y) \cdot \nabla_\theta \log \pi_\theta(Y)\Big],

where the first equality follows from the chain rule, with δW_1/δπ_θ(π_θ) defined as the first variation of the functional W_1 (Ambrosio, Gigli, and Savaré 2005).

To prove the second claim, substituting the optimal discriminator D* into the dual form of W_1 and noting that D* depends implicitly on θ, we have

W_1(\pi_\theta, \pi_1) = \mathbb{E}_{Y \sim \pi_\theta}[D^*(Y)] - \mathbb{E}_{Y' \sim \pi_1}[D^*(Y')]

\Rightarrow \frac{\partial W_1(\pi_\theta, \pi_1)}{\partial \theta} = \frac{\partial}{\partial \theta}\int D^*(Y)\, \pi_\theta(Y)\, dY - \mathbb{E}_{Y' \sim \pi_1}\big[\nabla_\theta D^*(Y')\big]
= \int \nabla_\theta \pi_\theta(Y)\, D^*(Y)\, dY + \int \nabla_\theta D^*(Y)\, \pi_\theta(Y)\, dY - \mathbb{E}_{Y' \sim \pi_1}\big[\nabla_\theta D^*(Y')\big]
= \mathbb{E}_{Y \sim \pi_\theta}\big[D^*(Y)\, \nabla_\theta \log \pi_\theta(Y)\big] + \mathbb{E}_{Y \sim \pi_\theta}\big[\nabla_\theta D^*(Y)\big] - \mathbb{E}_{Y' \sim \pi_1}\big[\nabla_\theta D^*(Y')\big]
= \mathbb{E}_{Y \sim \pi_\theta}\big[D^*(Y)\, \nabla_\theta \log \pi_\theta(Y)\big] + \mathbb{E}_{Y \sim \pi_\theta,\ Y' \sim \pi_1}\Big[\frac{\partial D^*(Y)}{\partial Y} \nabla_\theta Y - \frac{\partial D^*(Y')}{\partial Y'} \nabla_\theta Y'\Big],

which completes the proof.

More Details of Imitation Learning

This section describes our solution to the imitation-learning phase (27). Following (Ho and Ermon 2016b), let D(s, a) ≜ exp(−r(s, a)); then (27) can be reformulated as

\{\pi, r\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} -H(\pi) + \lambda W_1(\pi_1, \pi) + \mathbb{E}_{\pi}[\log D(s,a)] + \mathbb{E}_{\pi_E}[\log(1 - D(s,a))]. \tag{29}

As in the standard GAIL algorithm, (29) can be solved by alternating between the following two optimization problems.

Reward update  We first parameterize r with φ, implemented as a deep neural network (DNN) that takes a state-action pair as input and outputs a real value, i.e., r_φ(·, ·) ≜ −log D(·, ·). The reward network can then be optimized as in the standard GAIL framework with the objective

\phi^* = \arg\max_{\phi}\ \mathbb{E}_{\pi}[\log D_\phi(s,a)] + \mathbb{E}_{\pi_E}[\log(1 - D_\phi(s,a))], \tag{30}

where the expectations over π and π_E are approximated by samples from the current policy π (initialized with the π_1 learned in the self-improvement phase) and from the training data, respectively.

Policy update  Given the learned reward network r_φ, the policy can be optimized by standard policy-gradient-based methods with the objective

\pi = \arg\max_{\pi \in \Pi}\ H(\pi) + \mathbb{E}_{\pi}[r_\phi(s,a)] - \lambda W_1(\pi_1, \pi). \tag{32}

Capacity Regularizer  It has been shown in (Ho and Ermon 2016b) that an adversarial regularizer can provide better generalization. To imitate the expert policy, we adopt the regularizer of GAIL (Ho and Ermon 2016b), which defines ψ_1 as

\psi_1(r) \triangleq \begin{cases}\mathbb{E}_{\pi_E}[g(r(s,a))], & \text{if } r(s,a) \ge 0,\\ +\infty, & \text{otherwise,}\end{cases}
\qquad \text{where } g(x) = \begin{cases}-x - \log(1 - e^{x}), & \text{if } x < 0,\\ +\infty, & \text{otherwise.}\end{cases}

Unsupervised Style Transfer

Our framework can also be extended to the task of style transfer. The key changes are: i) define the generator as a conditional distribution, instead of the unconditional distribution used for text generation; ii) let the buffer DB store two bins of sentences, each corresponding to one style; and iii) include an additional reconstruction loss to preserve the reversibility of the learned policy. We apply the method to non-parallel text style transfer, where the goal is to transfer a sentence in one style (e.g., positive) to a similar sentence with a different style (e.g., negative). For a fair comparison, we use the same data and split as (Shen et al. 2017). To measure whether the original test sentences have been transferred to the desired sentiment, we follow the settings of (Shen et al. 2017) and employ a pretrained CNN classifier, which achieves an accuracy of 97.4% on the validation set, to evaluate the transferred sentences. The results are presented in Table 8. Note that both (Li et al. 2018) and (Sudhakar, Upadhyay, and Maheswaran 2019) are supervised approaches. Our method obtains much higher accuracy than the other methods, including the controllable text generation method (Hu et al. 2017).