Compositional Generalization with Tree Stack Memory Units
Authors: Forough Arabshahi, Zhichu Lu, Pranay Mundra
Forough Arabshahi∗, Carnegie Mellon University, Pittsburgh, PA 15213, farabsha@cs.cmu.edu
Zhichu Lu∗, Johns Hopkins University, Baltimore, MD 21218, zhicul@cs.cmu.edu
Pranay Mundra, University of Washington, Seattle, WA, pranay99@uw.edu
Sameer Singh, University of California Irvine, Irvine, CA 92697, sameer@uci.edu
Animashree Anandkumar, California Institute of Technology, Pasadena, CA 91125, anima@caltech.edu

ABSTRACT
We study compositional generalization, viz., the problem of zero-shot generalization to novel compositions of concepts in a domain. Standard neural networks fail to a large extent on compositional learning. We propose Tree Stack Memory Units (Tree-SMU) to enable strong compositional generalization. Tree-SMU is a recursive neural network with Stack Memory Units (SMUs), a novel memory-augmented neural network whose memory has a differentiable stack structure. Each SMU in the tree architecture learns to read from its stack and to write to it by combining the stacks and states of its children through gating. The stack helps capture long-range dependencies in the problem domain, thereby enabling compositional generalization. Additionally, the stack preserves the ordering of each node's descendants, thereby retaining locality on the tree. We demonstrate strong empirical results on two mathematical reasoning benchmarks. We use four compositionality tests to assess the generalization performance of Tree-SMU and show that it enables accurate compositional generalization compared to strong baselines such as Transformers and Tree-LSTMs.

1 INTRODUCTION
Despite the impressive performance of deep learning in the last decade, systematic compositional generalization is mostly out of reach for standard neural networks (Hupkes et al., 2020).
In a compositional domain, a set of concepts can be combined in novel ways to form new instances. Compositional generalization is defined as zero-shot generalization to such novel compositions. Lake & Baroni (2018) recently showed that a variety of recurrent neural networks fail spectacularly on tasks requiring compositional generalization. Lample & Charton (2019) presented similar results for the popular transformer architecture (Vaswani et al., 2017) in the symbolic mathematical domain. In particular, they showed that transformers can generalize nearly perfectly when the training and test distributions are identical. However, this generalization is brittle, and their performance degrades significantly even under slight distributional shifts. Neuro-symbolic models, on the other hand, hold promise for achieving compositional generalization (Lamb et al., 2020). They integrate symbolic domain knowledge within neural architectures. A popular example is a recursive (tree-structured) neural network whose nodes correspond to different symbolic concepts and whose tree represents their composition. Neuro-symbolic models can achieve better generalization since the neural component is relieved of the additional burden of learning symbolic knowledge from scratch. Moreover, in many domains, such as mathematical equations, compositional supervision is readily available, and neuro-symbolic models can directly incorporate it (Allamanis et al., 2017; Arabshahi et al., 2018; Evans et al., 2018). Recursive neural networks, however, still fall short when it comes to compositional generalization, especially on instances that are much more complex (of larger tree depth) than the training data. This is because of error propagation along the tree and the failure of recursive networks to capture long-range dependencies effectively. There are currently no error-correction mechanisms in recursive neural networks to overcome this. We address this issue in this paper by designing a novel recursive neural network (Tree-SMU) with stack memory units. To evaluate the compositional generalization of different neural network models, we use the tests proposed by Hupkes et al. (2020). The tests evaluate a model's ability to (1) generalize systematically to novel compositions, (2) generalize to deeper compositions than it was trained on, (3) generalize to shallower compositions than it was trained on, and (4) learn similar representations for semantically equivalent compositions. Mathematical reasoning is an excellent test bed for such evaluations of compositional generalization, since we can construct arbitrarily deep or shallow compositions using a given set of primitive functions in mathematics.

[Figure 1 panels: (a) Tree-LSTM; (b) Tree-SMU; (c) Soft Push and Pop.]
Figure 1: Model architecture of Tree-LSTM vs. Tree-SMU (Tree Stack Memory Unit). Compared to an LSTM node, an SMU has increased memory capacity in the form of a differentiable stack. The black arrows over the stacks in Fig. 1b represent our soft push/pop operation, shown in Fig. 1c. As depicted in Fig. 1c, the children's stacks and states jointly determine the content of the parent stack. The children's stacks are combined using the gating mechanism in Equation 1 and then used to fill the parent stack through the push and pop gating operations given in Equations 5 and 6.

∗ Equal contribution.
Moreover, there has recently been growing interest in the problem of mathematical reasoning (Lee et al., 2019; Saxton et al., 2019; Lample & Charton, 2019; Arabshahi et al., 2018; Loos et al., 2017; Allamanis et al., 2017). Furthermore, it is difficult (and often ambiguous) to accurately measure compositionality in other domains such as natural language or vision; the benchmarks that do measure compositionality in these domains are synthetic datasets, which are far from free-form natural language or real images (Lake & Baroni, 2018; Johnson et al., 2017; Hupkes et al., 2020).

Figure 2: Expression tree for 2x + (y + z) − 2z. Because the equation is evaluated bottom-up, having access to the descendants 2 and z at the tree root enables the model to correctly evaluate the expression.

Summary of Results. In this paper:
1. We propose a novel recursive neural network architecture, Tree Stack Memory Unit (Tree-SMU), to enable compositional generalization in the domain of mathematical reasoning.
2. We evaluate the generalization of Tree-SMU on four different compositionality tests. We show that Tree-SMU consistently outperforms powerful baselines such as transformers, tree transformers (Shiv & Quirk, 2019), and Tree-LSTMs in compositional generalization.
3. We observe that improved zero-shot generalization is also correlated with improved sample efficiency during training.

We propose augmenting the memory of recursive neural networks with a stack, thereby enabling them to capture long-range dependencies.

Example: how stacks can capture non-local dependencies. Consider evaluating the mathematical expression 2x + (y + z) − 2z in Fig. 2. At the subtract node (root), the model must have access to the representations of 2 and z to correctly simplify the expression to 2x + y − z.
However, the nodes of a Tree-RNN or Tree-LSTM only have access to the states of their direct children, as imposed by the tree structure. Thus, the representations of 2 and z will be mixed together by the time they reach the subtract node. Since the models do not learn perfect representations, the mixed representation of 2 and z is inaccurate, and the models have difficulty correctly evaluating the expression. On the other hand, a neural model with an extended memory capacity, such as a Tree-SMU, can store intermediate representations at each node. If the memory is correctly used and trained, the subtract node in Fig. 2 will have access to the original, unmixed representations of 2 and z and can correctly evaluate the equation.

We propose Tree Stack Memory Unit (Tree-SMU), a tree-structured neural network whose nodes (SMUs) have an extended memory capacity with the structure of a differentiable stack (Fig. 1). The stack memory is built into the architecture of each SMU node, which learns to read from (pop) and write to (push) it. The nodes push the gated combination of their children's states and stacks onto their own stacks (Fig. 1c). Since the intermediate results of the descendants are stored in the children's stacks, the current node also gets access to its descendants' states if it chooses to push. This error-correction mechanism allows the tree model to capture long-range dependencies. We choose the stack data structure (over, say, a queue) because it preserves the ordering of each node's descendants. Therefore, this design retains locality while allowing global representations to be learned. It is worth noting that although Tree-SMU has a larger memory capacity, its number of parameters is about the same as that of a Tree-LSTM. Moreover, we show that merely increasing the model size of a Tree-LSTM does not improve compositional generalization, since it leads to overfitting to the training data distribution.
Thus, a dedicated error-correction mechanism such as that of Tree-SMU is needed to capture long-range dependencies effectively for compositional generalization.

We test our model on two mathematical reasoning tasks, viz., equation verification and equation completion (Arabshahi et al., 2018). Using t-SNE visualization, we show that Tree-SMU learns a much smoother representation of mathematical expressions, whereas Tree-LSTM is more sensitive to irrelevant syntactic details. On equation completion, Tree-SMU achieves 6% better top-5 accuracy than Tree-LSTM. On the equation verification task, compared to Tree-LSTM, Tree-SMU improves accuracy by 7 percentage points when generalizing to shallower equations and by 2 percentage points when generalizing to deeper equations; on deeper equations, it improves accuracy by 18.5 points compared to transformers and by 17.5 points compared to tree transformers. For generalization at the same tree depth, we obtain an accuracy improvement of 6.8 points compared to transformers and 5.5 points compared to tree transformers, with slightly (<1 point) better accuracy than Tree-LSTM. Thus, we demonstrate that Tree-SMU achieves state-of-the-art performance across the board on the different compositionality tests.

2 BACKGROUND AND NOTATION
A recursive neural network is a tree-structured network in which each node is a neural network block. The tree's structure is often imposed by the data. For example, in mathematical equation verification, the structure is that of the input equation (e.g., Figure 2). For simplicity, we present the formulation of binary recursive neural networks. However, all the formulations can be trivially extended to n-ary recursive neural networks and child-sum recursive neural networks. We denote matrices with bold uppercase letters, vectors with bold lowercase letters, and scalars with non-bold letters.
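As a concrete instance of a tree imposed by the data, the expression of Figure 2, 2x + (y + z) − 2z, can be encoded as a nested tuple and evaluated bottom-up, which is the order in which a recursive network computes its node states. The sketch below is a hypothetical illustration, not the authors' code: the tuple encoding and the names `FIG2_TREE`, `evaluate`, and `postorder` are our own. It checks numerically that the root simplifies to 2x + y − z and lists the post-order in which node representations would be computed.

```python
import random

# Hypothetical encoding: a leaf is a variable name (str) or a number;
# an internal node is (operator, left_child, right_child).
FIG2_TREE = ("-",
             ("+", ("*", 2, "x"), ("+", "y", "z")),
             ("*", 2, "z"))

OPS = {"+": lambda a, b: a + b,
       "-": lambda a, b: a - b,
       "*": lambda a, b: a * b}


def evaluate(node, env):
    """Evaluate bottom-up: both children first, then the parent's operator."""
    if isinstance(node, str):
        return env[node]
    if isinstance(node, (int, float)):
        return node
    op, left, right = node
    return OPS[op](evaluate(left, env), evaluate(right, env))


def postorder(node):
    """Node labels in the order a recursive network would compute their states."""
    if not isinstance(node, tuple):
        return [str(node)]
    op, left, right = node
    return postorder(left) + postorder(right) + [op]


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(5):
        env = {v: rng.randint(-9, 9) for v in "xyz"}
        # The root must cancel the z inside (y + z) against the 2 and z of 2z.
        assert evaluate(FIG2_TREE, env) == 2 * env["x"] + env["y"] - env["z"]
    print(postorder(FIG2_TREE))
```

The post-order list makes the difficulty visible: by the time the root `-` is computed, the leaf `z` inside `(y + z)` has already been merged into two intermediate states.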
All the nodes (blocks) of a recursive neural network have a state, $\mathbf{h}_j \in \mathbb{R}^n$, and an input, $\mathbf{i}_j \in \mathbb{R}^{2n}$, where $n$ is the hidden dimension, $j \in [0, N-1]$, and $N$ is the number of nodes in the tree. We label the children of node $j$ with $c_{j1}$ and $c_{j2}$. We have $\mathbf{i}_j = [\mathbf{h}_{c_{j1}}; \mathbf{h}_{c_{j2}}]$, where $[\cdot\,;\cdot]$ indicates concatenation. If the block is a leaf node, $\mathbf{i}_j$ is the external input of the network. For example, in the equation tree shown in Figure 2, all the terminal nodes (leaves) are external neural network inputs (represented, for example, with one-hot vectors). For simplicity, we assume that non-terminal nodes do not have an external input and only take inputs from their children. However, if such inputs exist, we can easily handle them by additionally concatenating them with the children's states at each node's input. The way $\mathbf{h}_j$ is computed from $\mathbf{i}_j$ depends on the neural network block's architecture. For example, in a Tree-RNN and a Tree-LSTM, $\mathbf{h}_j$ is computed by passing $\mathbf{i}_j$ through a feedforward neural network and an LSTM network, respectively. Each node in a Tree-LSTM also has a 1-D memory vector (Fig. 1a).

3 TREE STACK MEMORY UNIT (TREE-SMU)
In this section, we introduce Tree Stack Memory Units (Tree-SMU), which consist of SMU nodes that incorporate a differentiable stack as memory (Fig. 1b). Therefore, this structure has an increased memory capacity compared to a Tree-LSTM. Each SMU learns to read from and write to its stack, and Tree-SMU learns to propagate information up the tree and to fill the stack of each node using its children's states and stacks (Fig. 1c). The states of the descendants are stored in the children's stacks. Therefore, globally, the stack gives each node of Tree-SMU indirect access to the states of its descendants. This is an error-correction mechanism that captures long-range dependencies. Locally, the stack preserves the order of each node's descendants.
If a queue were used instead, the descendants' states would be stored in the reverse order. Therefore, all the nodes would see the states of the tree leaves on top of their memory. This destroys locality and hurts compositional generalization. Despite its increased memory capacity, Tree-SMU has a similar number of parameters to Tree-LSTM for the same hidden state dimension. The stack memory is filled and emptied using our proposed soft push and pop gates. Finally, a gated combination of the children's stacks, along with the concatenated states of the children, determines the content of the parent stack.

Each SMU node $j$ has a stack $\mathbf{S}_j \in \mathbb{R}^{p \times n}$, where $p$ is the stack size. A stack is a LIFO data structure, and the network only interacts with it through its top (in a soft way). We use the notation $\mathbf{S}_j[i] \in \mathbb{R}^n$ to refer to memory row $i$ for $i \in \{0, \ldots, p-1\}$, where $i = 0$ indicates the stack top. For each node $j$, the children's stacks are combined using the equations below:

$$\mathbf{S}_{c_j}[i] = \mathbf{f}_{j1} \odot \mathbf{S}_{c_{j1}}[i] + \mathbf{f}_{j2} \odot \mathbf{S}_{c_{j2}}[i], \tag{1}$$

$$\mathbf{f}_{j1} = \sigma\big(\mathbf{U}^{(f)}_{j1}\mathbf{i}_j + \mathbf{b}^{(f)}_{j1}\big), \qquad \mathbf{f}_{j2} = \sigma\big(\mathbf{U}^{(f)}_{j2}\mathbf{i}_j + \mathbf{b}^{(f)}_{j2}\big), \tag{2}$$

where the matrices $\mathbf{U}^{(f)}_{j1}$ and $\mathbf{U}^{(f)}_{j2}$ and the vectors $\mathbf{b}^{(f)}_{j1}$ and $\mathbf{b}^{(f)}_{j2}$ are trainable weights and biases, similar to the forget gates of a Tree-LSTM. The push and pop gates are element-wise operators given below:

$$\mathbf{a}^{\text{push}}_j = \sigma\big(\mathbf{A}^{(\text{push})}_j\mathbf{i}_j + \mathbf{b}^{(\text{push})}_j\big), \tag{3}$$

$$\mathbf{a}^{\text{pop}}_j = \sigma\big(\mathbf{A}^{(\text{pop})}_j\mathbf{i}_j + \mathbf{b}^{(\text{pop})}_j\big), \tag{4}$$

where $\mathbf{A}^{(\text{push})}_j$, $\mathbf{A}^{(\text{pop})}_j$, $\mathbf{b}^{(\text{push})}_j$, and $\mathbf{b}^{(\text{pop})}_j$ are trainable weights and biases, and the gates $\mathbf{a}^{\text{push}}_j, \mathbf{a}^{\text{pop}}_j \in \mathbb{R}^n$ are normalized element-wise to sum to 1.
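The combination and gating steps above, together with the stack update and the k = 1 output stated next (Eqs. 5-7, 10, 11), can be sketched for a single SMU node in plain Python. This is a minimal, hypothetical illustration rather than the authors' implementation: the helper names, the random initialization, reading rows below the stack bottom as zeros in Eq. (6), and interpreting "normalized element-wise to sum to 1" as dividing each gate by the element-wise sum of all gates are our assumptions.

```python
import math
import random

# Vectors are lists of floats, matrices are lists of rows, and a stack is a
# list of p rows of length n with row 0 as the top.


def affine(W, b, x):
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def normalize(*gates):
    """Rescale gates element-wise so they sum to 1 (our reading of the text)."""
    totals = [sum(vals) for vals in zip(*gates)]
    return [[g / t for g, t in zip(gate, totals)] for gate in gates]


def combine_child_stacks(S1, S2, f1, f2):
    """Eq. (1): row-wise gated sum of the two children's stacks."""
    return [[a * x + b * y for a, b, x, y in zip(f1, f2, r1, r2)]
            for r1, r2 in zip(S1, S2)]


def soft_stack_update(Sc, u, a_push, a_pop, a_noop=None):
    """Eqs. (5)-(6); with a_noop given, Eqs. (13)-(14)."""
    p, n = len(Sc), len(u)

    def row(i):
        # Rows outside the stack read as zeros (assumption at the bottom).
        return Sc[i] if 0 <= i < p else [0.0] * n

    S = []
    for i in range(p):
        pushed = u if i == 0 else row(i - 1)  # push shifts every row down
        popped = row(i + 1)                   # pop shifts every row up
        r = [gp * a + gq * b for gp, gq, a, b in zip(a_push, a_pop, pushed, popped)]
        if a_noop is not None:                # no-op keeps row i in place
            r = [x + g * c for x, g, c in zip(r, a_noop, row(i))]
        S.append(r)
    return S


def init_params(n, seed=0):
    """Hypothetical random weights; every gate reads i_j of size 2n."""
    rng = random.Random(seed)
    P = {}
    for name in ("f1", "f2", "push", "pop", "u", "o"):
        P["W_" + name] = [[rng.uniform(-0.5, 0.5) for _ in range(2 * n)]
                          for _ in range(n)]
        P["b_" + name] = [0.0] * n
    return P


def smu_node(h1, h2, S1, S2, P):
    """One SMU node with k = 1; returns (h_j, S_j)."""
    i = h1 + h2                                           # i_j = [h_c1; h_c2]
    f1 = sigmoid(affine(P["W_f1"], P["b_f1"], i))         # Eq. (2)
    f2 = sigmoid(affine(P["W_f2"], P["b_f2"], i))
    Sc = combine_child_stacks(S1, S2, f1, f2)             # Eq. (1)
    a_push, a_pop = normalize(
        sigmoid(affine(P["W_push"], P["b_push"], i)),     # Eq. (3)
        sigmoid(affine(P["W_pop"], P["b_pop"], i)))       # Eq. (4)
    u = [math.tanh(a) for a in affine(P["W_u"], P["b_u"], i)]  # Eq. (7)
    S = soft_stack_update(Sc, u, a_push, a_pop)           # Eqs. (5)-(6)
    o = sigmoid(affine(P["W_o"], P["b_o"], i))            # Eq. (11)
    h = [g * math.tanh(s) for g, s in zip(o, S[0])]       # Eq. (10)
    return h, S


if __name__ == "__main__":
    n, p = 2, 3
    empty = [[0.0] * n for _ in range(p)]
    h, S = smu_node([0.3, -0.1], [0.5, 0.2], empty, empty, init_params(n))
    print(h, S)
```

With the push gate at 1, every row shifts down and $\mathbf{u}_j$ lands on top; with the pop gate at 1, every row shifts up; intermediate gate values interpolate between the two, which is what keeps the stack differentiable.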
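The stack is initialized with 0s, and its update equations are

$$\mathbf{S}_j[0] = \mathbf{a}^{\text{push}}_j \odot \mathbf{u}_j + \mathbf{a}^{\text{pop}}_j \odot \mathbf{S}_{c_j}[1], \tag{5}$$

$$\mathbf{S}_j[i] = \mathbf{a}^{\text{push}}_j \odot \mathbf{S}_{c_j}[i-1] + \mathbf{a}^{\text{pop}}_j \odot \mathbf{S}_{c_j}[i+1], \quad i = 1, \ldots, p-1, \tag{6}$$

where $\mathbf{u}_j$ is given below for trainable weights and biases $\mathbf{U}^{(u)}_j$ and $\mathbf{b}^{(u)}_j$:

$$\mathbf{u}_j = \tanh\big(\mathbf{U}^{(u)}_j\mathbf{i}_j + \mathbf{b}^{(u)}_j\big). \tag{7}$$

If $k > 1$, the output state is computed by looking at the top-$k$ stack elements, as shown below:

$$\mathbf{p}_j = \sigma\big(\mathbf{U}^{(p)}_j\mathbf{i}_j + \mathbf{b}^{(p)}_j\big), \tag{8}$$

$$\mathbf{h}_j = \mathbf{o}_j \odot \tanh\big(\mathbf{p}_j\,\mathbf{S}_j[0:k-1]\big), \tag{9}$$

where $\mathbf{p}_j \in \mathbb{R}^{1 \times k}$, and $\mathbf{U}^{(p)}_j \in \mathbb{R}^{k \times 2n}$ and $\mathbf{b}^{(p)}_j$ are trainable weights and biases; $\mathbf{S}_j[0:k-1]$ denotes the top-$k$ rows of the stack, and $k$ is a problem-dependent tuning parameter. For $k = 1$:

$$\mathbf{h}_j = \mathbf{o}_j \odot \tanh\big(\mathbf{S}_j[0]\big), \tag{10}$$

where $\mathbf{o}_j$ is given below for trainable weights and biases $\mathbf{U}^{(o)}_j$ and $\mathbf{b}^{(o)}_j$:

$$\mathbf{o}_j = \sigma\big(\mathbf{U}^{(o)}_j\mathbf{i}_j + \mathbf{b}^{(o)}_j\big). \tag{11}$$

Additional stack operation: No-Op. We can additionally add another stack operation called no-op. No-op is the state in which the network neither pushes to the stack nor pops from it and keeps the stack in its previous state. The no-op gate $\mathbf{a}^{\text{no-op}}_j \in \mathbb{R}^n$ and the stack update equations are given below:

$$\mathbf{a}^{\text{no-op}}_j = \sigma\big(\mathbf{A}^{(\text{no-op})}_j\mathbf{i}_j + \mathbf{b}^{(\text{no-op})}_j\big), \tag{12}$$

$$\mathbf{S}_j[0] = \mathbf{a}^{\text{push}}_j \odot \mathbf{u}_j + \mathbf{a}^{\text{pop}}_j \odot \mathbf{S}_{c_j}[1] + \mathbf{a}^{\text{no-op}}_j \odot \mathbf{S}_{c_j}[0], \tag{13}$$

$$\mathbf{S}_j[i] = \mathbf{a}^{\text{push}}_j \odot \mathbf{S}_{c_j}[i-1] + \mathbf{a}^{\text{pop}}_j \odot \mathbf{S}_{c_j}[i+1] + \mathbf{a}^{\text{no-op}}_j \odot \mathbf{S}_{c_j}[i]. \tag{14}$$

4 EXPERIMENTAL SETUP
In this section, we present our benchmark tasks, our compositional generalization tests, and the experimental details. We briefly define the tasks below.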
[Figure 3 panels: (a) Equation Verification Experiment, accuracy vs. test depth (8-19) for LSTM, Transformer, Tree-Transformer, Tree-RNN, Tree-LSTM, and Tree-SMU; (b) Sample efficiency, accuracy vs. test depth for Tree-LSTM and Tree-SMU trained on the full training set, 90% of it, and 50% of it.]
Figure 3: Productivity test: breakdown of model accuracy across test depths for Tree-SMU and the baselines. The percentage in Fig. 3b indicates the percentage of the training data used to train each model.

Mathematical Equation Verification. In this task, the inputs are symbolic and numeric mathematical equations from trigonometry and linear algebra, and the goal is to verify their correctness. For example, the symbolic equation $1 + \tan^2(\theta) = \frac{1}{\cos^2(\theta)}$ is correct, whereas the numeric equation $\sin(\frac{\pi}{2}) = 0.5$ is incorrect. This domain contains 29 trigonometric and algebraic functions.

Mathematical Equation Completion. In this task, the input is a mathematical equation that has a blank in it. The goal is to find a value for the blank such that the equation holds. For example, for the equation $\sin^2\theta + \underline{\quad} = 1$, the blank is $\cos^2\theta$.

We used the code released by Arabshahi et al. (2018) to generate more data of significantly higher depth. We generate one dataset for training and validation, and another dataset with a different seed and generation hyperparameter (to encourage deeper expressions) for testing. All the equation verification models are run only once on this test set. The training and validation data have 40k equations combined, and their depth distribution is shown in Fig. 6 in the Appendix. The test dataset has around 200k equations of depth 1-19 with a fairly balanced depth distribution of about 10k equations per depth (apart from shallow expressions of depth 1 and 2).
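Three of the compositionality tests described next (localism, productivity, and systematicity) are defined by which equation depths are used for training and which for testing. A minimal sketch of that bookkeeping follows. The nested-tuple encoding, the depth convention (a leaf has depth 0, so depth counts edges on the longest root-to-leaf path), and the names `depth`, `SPLITS`, `select`, and `make_split` are our assumptions, since the paper does not specify them; the depth ranges are taken from the test definitions in this section.

```python
from collections import Counter

# Hypothetical encoding: a leaf is a string; an internal node is a tuple
# (function_name, child_1, ..., child_m).


def depth(tree):
    """Edges on the longest root-to-leaf path (a leaf has depth 0, assumed)."""
    if isinstance(tree, str):
        return 0
    return 1 + max(depth(child) for child in tree[1:])


# (train depth range, test depth range), both inclusive.
SPLITS = {
    "localism": ((5, 13), (1, 4)),
    "productivity": ((1, 7), (8, 19)),
    "systematicity": ((1, 7), (1, 7)),  # test equations come from the test set
}


def select(equations, depth_range):
    lo, hi = depth_range
    return [eq for eq in equations if lo <= depth(eq) <= hi]


def make_split(test_name, train_pool, test_pool):
    """Filter the training and test pools by the depth ranges of one test."""
    train_range, test_range = SPLITS[test_name]
    return select(train_pool, train_range), select(test_pool, test_range)


if __name__ == "__main__":
    pythagoras = ("=", ("+", ("pow", ("sin", "theta"), "2"),
                        ("pow", ("cos", "theta"), "2")), "1")
    wrong = ("=", ("sin", ("div", "pi", "2")), "0.5")
    print(depth(pythagoras), depth(wrong))  # 4 3
    print(Counter(depth(eq) for eq in [pythagoras, wrong]))
```

Keeping the depth ranges in one table makes it explicit that localism and productivity are mirror images of each other, while systematicity keeps the depth range fixed and changes only the data source.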
Compositionality Tests. To evaluate the compositional generalization of our proposed model, we use four of the five tests proposed by Hupkes et al. (2020).

Localism states that the meaning of a complex expression derives from the meanings of its constituents and how they are combined. To see whether models are able to capture localism, we train the models on deep expressions (depth 5-13) and test them on shallow expressions (depth 1-4).

Productivity has to do with the infinite nature of compositionality. For example, we can construct arbitrarily deep (potentially infinite) equations using a finite number of functions, symbols, and numbers, and ideally our model should generalize to them. To test for productivity, we train our models on equations of depth 1-7 and test them on equations of depth greater than 7.

Substitutivity has to do with preserving the meaning of equivalent expressions: replacing a sub-expression with an equivalent form should not change the meaning of the expression as a whole. To test for substitutivity, we look at t-SNE plots (Maaten & Hinton, 2008) of the learned representations of sub-expressions and show that equivalent expressions (e.g., $y \times 0$ and $(1 \times x) \times 0$) map close to each other in the vector space.

Systematicity refers to generalizing to recombinations of known parts and rules that form other complex expressions. To assess this, we train the model on training equations of depth 1-7 and test it on test-set equations of depth 1-7.

Baselines. We use several baselines, listed below, to validate our experiments. All the recursive models (including Tree-SMU) perform binary classification by optimizing the cross-entropy loss at the output, which is the root (equality node). At the root, the models compute the dot product of the output embeddings of the left and right sub-trees.
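The root computation just described, a dot product between the left and right sub-tree embeddings trained with a cross-entropy loss, can be sketched in its simplest form. The logistic sigmoid that turns the dot product into a probability and all function names are our assumptions; the text only states that the dot product is computed and that the cross-entropy loss is optimized.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def equation_probability(h_left, h_right):
    """P(equation is correct) from the root's two sub-tree embeddings (assumed sigmoid)."""
    return sigmoid(dot(h_left, h_right))


def binary_cross_entropy(prob, label, eps=1e-12):
    """Loss for label 1 (correct equation) or 0 (incorrect equation)."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


if __name__ == "__main__":
    left, right = [0.5, -0.2, 0.1], [0.4, -0.1, 0.3]  # hypothetical embeddings
    prob = equation_probability(left, right)
    print(prob, binary_cross_entropy(prob, 1))
```

Scoring with a dot product keeps the root parameter-free, so all learning happens in the sub-tree encoders that produce the two embeddings.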
The inputs of the recursive networks are the terminal nodes (leaves) of the equations, which consist of symbols, representing variables in the equation, and numbers. The leaves of the recursive networks are embedding layers that embed the symbols and numbers in the equation. The parameters of nodes that have the same functionality are shared. For example, all the addition functions use the same set of parameters. The recurrent models read each equation sequentially, as a token sequence that includes parentheses.

[Figure 4 panels: (a) Top-1 accuracy and (b) Top-5 accuracy vs. test depth (8-13) for Tree-RNN, Tree-LSTM, and Tree-SMU.]
Figure 4: Productivity test for equation completion. Top-K accuracy broken down by test-data depth.

Majority Class: a classifier that always predicts the majority class.
LSTM: a Long Short-Term Memory network (Hochreiter & Schmidhuber, 1997).
Transformer: the transformer architecture of Vaswani et al. (2017).
Tree Transformer: a transformer with tree-structured positional encodings (Shiv & Quirk, 2019).
Tree-RNN: a recursive neural network whose nodes are 2-layer feed-forward networks.
Tree-LSTM: the Tree-LSTM network (Tai et al., 2015) presented in Section 2.

Evaluation metrics. For equation verification, we use binary classification accuracy. For equation completion, we use top-K accuracy (Arabshahi et al., 2018), which is the percentage of samples for which there is at least one correct match for the blank among the top-K predictions.

Implementation Details. The models are implemented in PyTorch (Paszke et al., 2017). All the tree models use the Adam optimizer (Kingma & Ba, 2014) with $\beta_1 = 0.9$, $\beta_2 = 0.999$, learning rate 0.001, and weight decay 0.00001, with batch size 32.
We do a grid search for all the models to find the optimal hidden dimension and dropout rate in the ranges [50, 55, 60, 80, 100, 120] and [0, 0.1, 0.15, 0.2, 0.25], respectively. Tree-SMU's stack size is tuned in the range [1, 2, 3, 4, 5, 7, 14], and the best accuracy is obtained for stack size 2. All the models are run with three different seeds. We report the average and sample standard deviation over these three seeds. We choose the models based on the best accuracy on the validation data (containing equations of depths 1-7).¹

5 RESULTS AND DISCUSSION
5.1 COMPOSITIONALITY TEST RESULTS
Localism. As shown in Table 2, Tree-SMU significantly outperforms both Tree-RNN and Tree-LSTM, achieving near-perfect accuracy (98.86%) on the equation verification test set. This indicates that the stack memory is able to effectively capture and preserve locality, as claimed in the introduction.

Productivity. The productivity test evaluates the stack's capability to capture global long-range dependencies. In this test, the model is evaluated on zero-shot generalization to deeper mathematical expressions. This is a harder task than localism. We run the productivity test on both the equation completion and the equation verification tasks.

Equation Verification: Table 1 shows the overall accuracy of all the models on equations of depth 8-19. The breakdown of accuracy vs. equation depth is shown in Fig. 3a. As can be seen, Tree-SMU consistently outperforms all the baselines at all depths. This indicates that the stack effectively captures global long-range dependencies, allowing the model to compositionally generalize to deeper and harder mathematical expressions. However, since this task is harder, the improvement margin on this test is smaller than on the localism test.
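The equation-completion results discussed next are reported with the top-K accuracy defined in Section 4: a sample counts as a hit if any of the K highest-ranked candidate fillers is correct. A minimal sketch follows, assuming that a blank may have several correct fillers and that candidates are ranked by a model score; the `score` argument is a hypothetical stand-in for the trained verification model, and the function names are our own.

```python
def rank_candidates(candidates, score):
    """Sort candidate fillers from highest to lowest model score."""
    return sorted(candidates, key=score, reverse=True)


def top_k_accuracy(ranked_predictions, correct_sets, k):
    """Percentage of samples with at least one correct filler among the top k."""
    hits = sum(
        any(pred in correct for pred in ranked[:k])
        for ranked, correct in zip(ranked_predictions, correct_sets)
    )
    return 100.0 * hits / len(ranked_predictions)


if __name__ == "__main__":
    # Hypothetical scores for fillers of sin^2(theta) + _ = 1.
    scores = {"tan(theta)": 0.9, "cos^2(theta)": 0.8, "sin(theta)": 0.1}
    ranked = rank_candidates(list(scores), scores.get)
    correct = {"cos^2(theta)"}
    print(ranked)
    print(top_k_accuracy([ranked], [correct], 1), top_k_accuracy([ranked], [correct], 2))
```

Because a sample only needs one correct filler in its top K, top-5 accuracy is always at least as high as top-1 accuracy, which is why the two panels of Figure 4 differ in level.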
Equation Completion: To perform equation completion, we use the model trained on equation verification to predict blank fillers for an input equation that has a blank. The test data, which is different from the equation verification test data, is generated by randomly substituting sub-trees of depth 1 or 2 in equations of depth 8-13 with a blank leaf. To make predictions for the blank, we generate candidates of depth 1 and 2 from the data vocabulary and use the trained models to rank the candidates. Figures 4a and 4b show the top-1 and top-5 accuracy of the recursive models. As can be seen, the performance of Tree-SMU is consistently better at all depths than that of Tree-LSTM and Tree-RNN. The performance of the recurrent models was poor, so they are not shown in these plots.

¹ The code and data for reproducing all the experiments will be released upon acceptance. We have implemented a dynamic batching algorithm that achieves a runtime comparable with Transformers.

[Figure 5 panels: (a) Tree-SMU learned representations; (b) Tree-LSTM learned representations.]
Figure 5: Substitutivity test: learned representations of Tree-SMU and Tree-LSTM colored by expressions' class label. Both tree models are able to cluster equivalent expressions. However, Tree-SMU learns a much smoother representation per class label, whereas Tree-LSTM is more sensitive to irrelevant syntactic details. For example, the two sub-clusters of class 1 in Fig. 5b group expressions raised to the power of 0 (top sub-cluster) and 1 raised to the power of an expression (bottom sub-cluster).

Substitutivity. The t-SNE plots of the representations learned by Tree-LSTM and Tree-SMU are shown in Fig. 5. Each point on the t-SNE plot is a mathematical expression with depth ranging from 1 to 3. The hyperparameters for both t-SNE plots are the same. We have highlighted two of the clustered expressions with their equivalence class labels, viz., 0 and 1.
These include expressions that evaluate to these numbers (e.g., $1 \times (x \times 0)$ is in the 0 cluster). As can be seen in the plots, both models are able to form clusters of equivalence classes, although they are not directly minimizing these class losses. However, Tree-SMU learns a much smoother representation per equivalence class. Moreover, Tree-LSTM forms sub-clusters that are sensitive to irrelevant syntactic details. For example, the top sub-cluster of 1 in Fig. 5b is the group of expressions raised to the power of 0 (e.g., $x^0$), and the bottom sub-cluster is the group of 1 raised to the power of an expression (e.g., $1^{1 \times y}$). Tree-RNN is even more sensitive to these syntactic details (Fig. 7, Appendix). For example, Tree-RNN learns distinct sub-clusters for equivalence class 0 that separate expressions multiplied by 0 from the right (e.g., $\frac{1}{2} \times 0$) from expressions multiplied by 0 from the left (e.g., $0 \times -1$). Another irrelevant syntactic detail captured by Tree-RNN is sub-expression depth. Therefore, if we substitute a sub-expression with a semantically equivalent one, Tree-RNN's representation of the expression as a whole is likely to change. Tree-SMU, in contrast, passes the substitutivity test.

Systematicity. The systematicity test results are shown in Table 1 (Systematicity Test column). As can be seen, the accuracy of Tree-LSTM is comparable with that of Tree-SMU. This compositional generalization test is easier than the previous ones because the training and test datasets are more similar in this case. Together with the localism and productivity results, this suggests that Tree-SMU is more robust than Tree-LSTM to changes in the compositionality of the test set. It also indicates that Tree-LSTM is more likely to learn irrelevant artifacts in the data that hurt its performance more under data distribution shifts.

5.2 SAMPLE EFFICIENCY
Finally, Figure 3b shows that Tree-SMU also has better sample efficiency than Tree-LSTM. This figure shows the performance on the entire productivity test data for equation verification, but the models are trained on a sub-sample of the training data. The sub-sampling percentage is shown in the figure. As shown, the accuracy of Tree-SMU trained on 50% of the data is comparable to (at some depths slightly lower than) the accuracy of Tree-LSTM trained on the full dataset. However, the performance of Tree-LSTM degrades significantly when it is trained on 50% of the training data. Moreover, even when Tree-SMU is trained using 90% of the data, it still outperforms Tree-LSTM.

Table 1: Productivity and systematicity tests for equation verification: overall model accuracy (%) on the train and test datasets.

Approach | Train (Depths 1-7) | Validation (Depths 1-7) | Systematicity Test (Depths 1-7) | Productivity Test (Depths 8-19)
Majority Class | 58.12 | 56.67 | 60.40 | 51.71
LSTM | 85.62 | 79.48 ± 4.53 | 79.24 ± 0.006 | 68.36 ± 0.42
Transformer | 81.26 | 74.24 ± 0.02 | 76.45 ± 0.42 | 61.05 ± 1.53
Tree Transformer | 84.08 | 74.90 ± 0.02 | 77.80 ± 0.46 | 62.12 ± 1.06
Tree-RNN | 99.11 | 89.45 ± 0.08 | 72.41 ± 0.01 | 68.95 ± 0.24
Tree-LSTM | 99.86 | 93.05 ± 0.12 | 83.06 ± 0.01 | 77.58 ± 0.19
Tree-SMU | 99.59 | 93.52 ± 0.28 | 83.29 ± 0.01 | 79.57 ± 0.16

Table 2: Localism test for equation verification: overall model accuracy (%) on the train and test datasets.

Approach | Train (Depths 5-13) | Validation (Depths 5-13) | Test (Depths 1-4)
Majority Class | 55.08 | 56.45 | 60.88
Tree-RNN | 99.07 | 88.33 ± 0.37 | 85.36 ± 0.35
Tree-LSTM | 99.79 | 91.70 ± 0.28 | 91.58 ± 0.11
Tree-SMU | 95.04 | 92.78 ± 0.09 | 98.86 ± 0.07

6 RELATED WORK
Recursive neural networks have been used to model compositional data in many applications, e.g., natural scene classification (Socher et al., 2011), sentiment classification, semantic relatedness and syntactic parsing (Tai et al., 2015; Socher et al., 2011), neural programming, and logic (Allamanis et al., 2017; Zaremba et al., 2014; Evans et al., 2018). In all these problems, there is an inherent compositional structure nested in the data. Recursive neural networks integrate this symbolic domain knowledge into their architecture and, as a result, achieve significantly better generalization performance. However, we show that standard recursive neural networks do not generalize in a zero-shot manner to unseen compositions, mainly because of error propagation in the network. Therefore, we propose a novel recursive neural network, Tree Stack Memory Unit, that shows strong zero-shot generalization compared with our baselines.

Tree Stack Memory Unit is a recursive network with an extended structured memory. Recently, there have been attempts to provide recurrent neural models with a global memory that plays the role of a working memory, to which information can be written and from which it can be read (Graves et al., 2014; Jason Weston, 2015; Grefenstette et al., 2015; Joulin & Mikolov, 2015). Memory networks and their differentiable counterpart (Jason Weston, 2015; Sukhbaatar et al., 2015) store instances of the input data in an external memory that can later be read through their recurrent neural network architecture. Neural Programmer-Interpreters augment their underlying recurrent LSTM core with a key-value style memory and additionally enable read and write operations for accessing it (Reed & De Freitas, 2015; Cai et al., 2017). Neural Turing Machines (Graves et al., 2014) define soft read and write operations so that a recurrent controller unit can access this memory.
Another line of research augments recurrent neural networks with specific data structures such as stacks and queues (Das et al., 1992; Dyer et al., 2015; Sun et al., 2017; Joulin & Mikolov, 2015; Grefenstette et al., 2015; Mali et al., 2019). These works provide an external memory for the neural network to access. Our proposed model instead integrates the memory within the cell architecture rather than treating it as an external element. Our design experiments showed that recursive neural networks do not learn to use the memory if it is an external component. A trivial extension of works such as Joulin & Mikolov (2015) to tree-structured neural networks therefore does not work in practice. Much effort has gone into augmenting recurrent neural networks. Yet, to the best of our knowledge, there has been no attempt to increase the memory capacity of recursive networks, which would allow them to extrapolate to harder problem instances. Inspired by recent attempts to augment recurrent neural networks with stacks, we therefore propose Tree Stack Memory Units, a recursive neural network whose cells contain differentiable stacks. We propose novel soft push and pop operations that fill the memory of each Stack Memory Unit using the stacks and states of its children. Note that a trivial extension of stack-augmented recurrent neural networks such as Joulin & Mikolov (2015) yields the stack-augmented Tree-RNN structure presented in the Appendix. Our experiments show that this trivial extension does not work very well.

In a parallel research direction, an episodic memory was proposed for question answering (Kumar et al., 2016). This differs from the symbolic way of defining memory. Another line of work comprises graph memory networks (Pham et al., 2018) and tree memory networks (Fernando et al., 2018), which construct a memory with a specific structure.
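For comparison with the stack-augmented recurrent networks above, the sketch below shows the soft push/pop update of Joulin & Mikolov (2015) for a sequence model. A push probability a_push blends two moves: shifting every entry down and placing a new value on top (push), or shifting every entry up (pop, with weight 1 - a_push). This is not the Tree-SMU update, whose gates combine the stacks of a node's children. Several details are our own assumptions: the function name, the fixed stack depth, the omission of a no-op action, and the zero padding read when popping from the bottom.

```python
from typing import List

Vector = List[float]


def soft_stack_update(stack: List[Vector], push_value: Vector, a_push: float) -> List[Vector]:
    """One soft push/pop step on a fixed-depth stack (stack[0] is the top).

    a_push lies in [0, 1] and a_pop = 1 - a_push. A push drops the bottom
    entry; a pop reads zeros into the bottom slot (assumption of this sketch).
    """
    a_pop = 1.0 - a_push
    depth = len(stack)
    zeros = [0.0] * len(push_value)
    new_stack = []
    for i in range(depth):
        # Value that a push would move into slot i: the new value for the top,
        # otherwise the entry one slot above.
        above = push_value if i == 0 else stack[i - 1]
        # Value that a pop would move into slot i: the entry one slot below.
        below = stack[i + 1] if i + 1 < depth else zeros
        new_stack.append([a_push * u + a_pop * v for u, v in zip(above, below)])
    return new_stack
```

With a_push = 1 the step is an exact push, and with a_push = 0 it is an exact pop. Values in between blend the two, which keeps the operation differentiable with respect to a_push.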
Graph and tree memory networks differ from our proposed recursive neural network, whose increased memory capacity comes from increasing the memory capacity of each cell in the recursive architecture.

7 CONCLUSIONS

In this paper, we study the problem of zero-shot generalization of neural networks to novel compositions of domain concepts. This problem is referred to as compositional generalization, and it remains a challenge for state-of-the-art neural networks such as Transformers and Tree-LSTMs. We propose the Tree Stack Memory Unit (Tree-SMU) to enable compositional generalization. Tree-SMU is a novel recursive neural network with a larger memory capacity than other recursive neural networks such as Tree-LSTM. Each node in Tree-SMU has a built-in differentiable stack memory, and the SMU learns to read from and write to this memory using its soft push and pop gates. The stack gives each node indirect access to its descendants, which acts as an error correction mechanism. We show that Tree-SMU achieves strong compositional generalization for mathematical reasoning compared to baselines such as Transformers, Tree Transformers and Tree-LSTMs.

REFERENCES

Miltiadis Allamanis, Pankajan Chanthirasegaran, Pushmeet Kohli, and Charles Sutton. Learning continuous semantic representations of symbolic expressions. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 80–88. JMLR.org, 2017.
Forough Arabshahi, Sameer Singh, and Animashree Anandkumar. Combining symbolic expressions and black-box function evaluations in neural programs. International Conference on Learning Representations (ICLR), 2018.
Jonathon Cai, Richard Shin, and Dawn Song. Making neural programming architectures generalize via recursion. arXiv preprint, 2017.
Sreerupa Das, C Lee Giles, and Guo-Zheng Sun.
Learning context-free grammars: Capabilities and limitations of a recurrent neural network with an external stack memory. In Proceedings of The Fourteenth Annual Conference of Cognitive Science Society. Indiana University, pp. 14, 1992.
Chris Dyer, Miguel Ballesteros, Wang Ling, Austin Matthews, and Noah A Smith. Transition-based dependency parsing with stack long short-term memory. In Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pp. 334–343, 2015.
Richard Evans, David Saxton, David Amos, Pushmeet Kohli, and Edward Grefenstette. Can neural networks understand logical entailment? International Conference on Learning Representations (ICLR), 2018.
Tharindu Fernando, Simon Denman, Aaron McFadyen, Sridha Sridharan, and Clinton Fookes. Tree memory networks for modelling long-term temporal dependencies. Neurocomputing, 304:64–81, 2018.
Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.
Edward Grefenstette, Karl Moritz Hermann, Mustafa Suleyman, and Phil Blunsom. Learning to transduce with unbounded memory. In Advances in Neural Information Processing Systems, pp. 1828–1836, 2015.
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.
Dieuwke Hupkes, Verna Dankers, Mathijs Mul, and Elia Bruni. Compositionality decomposed: How do neural networks generalise? Journal of Artificial Intelligence Research, 67:757–795, 2020.
Jason Weston, Sumit Chopra, and Antoine Bordes. Memory networks. 2015.
Justin Johnson, Bharath Hariharan, Laurens van der Maaten, Li Fei-Fei, C Lawrence Zitnick, and Ross Girshick. CLEVR: A diagnostic dataset for compositional language and elementary visual reasoning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2901–2910, 2017.
Armand Joulin and Tomas Mikolov. Inferring algorithmic patterns with stack-augmented recurrent nets. In Advances in Neural Information Processing Systems, pp. 190–198, 2015.
Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
Ankit Kumar, Ozan Irsoy, Peter Ondruska, Mohit Iyyer, James Bradbury, Ishaan Gulrajani, Victor Zhong, Romain Paulus, and Richard Socher. Ask me anything: Dynamic memory networks for natural language processing. In International Conference on Machine Learning, pp. 1378–1387, 2016.
Brenden Lake and Marco Baroni. Generalization without systematicity: On the compositional skills of sequence-to-sequence recurrent networks. In International Conference on Machine Learning, pp. 2873–2882, 2018.
Luis Lamb, Artur Garcez, Marco Gori, Marcelo Prates, Pedro Avelar, and Moshe Vardi. Graph neural networks meet neural-symbolic computing: A survey and perspective. arXiv preprint arXiv:2003.00330, 2020.
Guillaume Lample and François Charton. Deep learning for symbolic mathematics. arXiv preprint arXiv:1912.01412, 2019.
Dennis Lee, Christian Szegedy, Markus N Rabe, Sarah M Loos, and Kshitij Bansal. Mathematical reasoning in latent space. arXiv preprint, 2019.
Sarah Loos, Geoffrey Irving, Christian Szegedy, and Cezary Kaliszyk. Deep network guided proof search. EPiC Series in Computing, 46:85–105, 2017.
Laurens van der Maaten and Geoffrey Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9(Nov):2579–2605, 2008.
Ankur Mali, Alexander Ororbia, and C Lee Giles. The neural state pushdown automata. arXiv preprint arXiv:1909.05233, 2019.
Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch. 2017.
Trang Pham, Truyen Tran, and Svetha Venkatesh. Graph memory networks for molecular activity prediction.
arXiv preprint, 2018.
Scott Reed and Nando De Freitas. Neural programmer-interpreters. arXiv preprint arXiv:1511.06279, 2015.
David Saxton, Edward Grefenstette, Felix Hill, and Pushmeet Kohli. Analysing mathematical reasoning abilities of neural models. International Conference on Learning Representations (ICLR), 2019.
Vighnesh Shiv and Chris Quirk. Novel positional encodings to enable tree-based transformers. In Advances in Neural Information Processing Systems, pp. 12081–12091, 2019.
Richard Socher, Cliff C Lin, Chris Manning, and Andrew Y Ng. Parsing natural scenes and natural language with recursive neural networks. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 129–136, 2011.
Sainbayar Sukhbaatar, Jason Weston, Rob Fergus, et al. End-to-end memory networks. In Advances in Neural Information Processing Systems, pp. 2440–2448, 2015.
Guo-Zheng Sun, C Lee Giles, Hsing-Hen Chen, and Yee-Chun Lee. The neural network pushdown automaton: Model, stack and learning simulations. arXiv preprint, 2017.
Kai Sheng Tai, Richard Socher, and Christopher D Manning. Improved semantic representations from tree-structured long short-term memory networks. arXiv preprint arXiv:1503.00075, 2015.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008, 2017.
Wojciech Zaremba, Karol Kurach, and Rob Fergus. Learning to discover efficient mathematical identities. In Advances in Neural Information Processing Systems, pp. 1278–1286, 2014.

APPENDIX

DATA DISTRIBUTION

A data generation strategy for these tasks was presented in Arabshahi et al. (2018). We use it to generate symbolic mathematical equations of depth up to 13.
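All splits in this paper are defined by equation depth, for example training on depths 1-7 and testing productivity on depths 8-19. The sketch below shows one way to compute the depth of an equation tree. It assumes that depth counts the edges on the longest root-to-leaf path, with the `=` node as root, and it uses a nested-tuple encoding of our own. Under these assumptions, sec(x + π) = (−1 × sec(sec(x))) has depth 4, matching its entry in Table 3. This section does not restate the paper's exact convention. The split helper is hypothetical.

```python
from typing import Tuple, Union

# Hypothetical encoding: a leaf is a string ("x", "1", "pi"); an internal node
# is a tuple (operator, child_1, ..., child_k) with at least one child.
Expr = Union[str, Tuple]


def depth(expr: Expr) -> int:
    """Edges on the longest root-to-leaf path (assumed convention; a leaf has depth 0)."""
    if isinstance(expr, str):
        return 0
    _op, *children = expr
    return 1 + max(depth(child) for child in children)


def productivity_split(d: int) -> str:
    """Hypothetical helper: map a depth to the Table 1 productivity split."""
    if 1 <= d <= 7:
        return "train/validation"
    if 8 <= d <= 19:
        return "productivity test"
    return "unused"


# sec(x + pi) = (-1 * sec(sec(x))), the second example in Table 3
example = ("=", ("sec", ("+", "x", "pi")), ("*", "-1", ("sec", ("sec", "x"))))
print(depth(example), productivity_split(depth(example)))  # 4 train/validation
```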
We generate 41,894 equations of different depths (Figure 6 gives the number of equations at each depth). This dataset is approximately balanced, with 56% correct and 44% incorrect equations in total. It is used for training and validation in the equation verification experiments. The test data is generated with a different seed and different hyper-parameters, so that the data generator produces a depth-balanced dataset. The test dataset has about 200k equations of depths 1-19, with roughly 10,000 equations per depth (excluding depths 1 and 2). Equation verification models are evaluated only once on this large test set, and those results are reported in the paper. Example equations from the training dataset are shown in Table 3.

Figure 6: Number of equations in the train/validation 40k data broken down by their depth. (Bar values for depths 1-13: 21, 355, 2542, 7508, 9442, 7957, 6146, 3634, 1999, 1124, 677, 300, 189.)

Table 3: Examples of generated equations in the dataset
Example | Label | Depth
(√1 × 1 × y) + x = (1 × y) + x | Correct | 4
sec(x + π) = (−1 × sec(sec(x))) | Incorrect | 4
y × 1 1 × (3 + (−1 × 4 0 × 1)) + x 1 = y × 2 0 × (2 + x) | Correct | 8
q 1 + (−1 × (cos(y + x)) √ csc(2)) × (cos(y + x)) − 1 = tan(y 1 + x) | Incorrect | 8
2 − 1 + − 1 2 × − 1 × q 1 + (−1 × sin 2 (√ 4 × (π + (x × −1)))) + cos √ 4 (x) = 1 | Correct | 13
cos(y^1 + x) + z^w = cos(x) × cos(0 + y) + −1 × √(1 + −1 × cos^2(y + 2π)) × sin(x) + z^w | Correct | 13
sin √ 4 − 1 π + (−1 × sec csc 2 (x) − 1 + sin 2 (1 + (−1 × 1) + x + 2 − 1 π) × x) = cos(0 + x) | Incorrect | 13

EXTENDED RESULTS AND DISCUSSION

Productivity test. Table 4 is an extended version of Table 1 with additional evaluation metrics, precision (Prec) and recall (Rcl), for the binary classification problem of equation verification. These metrics are shown alongside the accuracy metrics reported in Table 1.
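The metrics in Table 4 can be computed as below. This sketch assumes that label 1 (a "correct" equation) is the positive class, which the text does not state explicitly. It reports values in percent to match the tables.

```python
from typing import Sequence, Tuple


def acc_prec_rcl(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[float, float, float]:
    """Accuracy, precision and recall in percent; label 1 is the positive class (assumed)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # correct equations accepted
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # incorrect equations accepted
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # correct equations rejected
    acc = 100.0 * sum(1 for t, p in pairs if t == p) / len(pairs)
    prec = 100.0 * tp / (tp + fp) if tp + fp else 0.0
    rcl = 100.0 * tp / (tp + fn) if tp + fn else 0.0
    return acc, prec, rcl


acc, prec, rcl = acc_prec_rcl([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
```

Precision penalizes accepting incorrect equations, and recall penalizes rejecting correct ones. A model can therefore have high precision and low recall if it rarely predicts "correct".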
Stack size ablation. In this experiment, we train Tree-SMU with stack sizes p ∈ {1, 2, 3, 7, 14}. The model is trained on equations of depths 1-7 from the training set and evaluated on equations of depths 1-13 from the validation set. The experiment shows how many of the stack rows the model uses for equations of different depths. The results are shown in Table 5. We say that a stack row is used if the L2 norm of that row exceeds a threshold τ = 0.001. As the table shows, more of the stack rows are used as the depth of the equations increases. This might indicate that the SMU cells use the extended memory capacity to capture long-range dependencies and overcome error propagation.

Table 4: Productivity Test for Equation Verification: overall accuracy (Acc), precision (Prec) and recall (Rcl) of the models on train and test datasets. Train and validation sets: depths 1-7; test set: depths 8-19.
Approach | Train Acc | Train Prec | Train Rcl | Validation Acc | Validation Prec | Validation Rcl | Test Acc | Test Prec | Test Rcl
Majority Class | 58.12 | - | - | 56.67 | - | - | 51.71 | - | -
LSTM | 85.62 | 82.14 | 97.15 | 79.48 ± 4.53 | 76.18 ± 5.01 | 93.71 ± 0.80 | 68.36 ± 0.42 | 72.87 ± 7.01 | 60.25 ± 14.89
Transformer | 81.26 | 78.36 | 93.67 | 76.45 ± 0.42 | 73.63 ± 0.80 | 91.11 ± 1.93 | 61.05 ± 1.53 | 59.96 ± 1.50 | 60.35 ± 7.28
Tree Transformer | 84.08 | 80.93 | 95.03 | 77.80 ± 0.46 | 74.97 ± 0.87 | 91.36 ± 1.25 | 62.12 ± 1.06 | 60.63 ± 1.54 | 63.68 ± 2.93
Tree-RNN | 99.11 | 98.92 | 99.56 | 89.45 ± 0.08 | 88.47 ± 0.27 | 93.57 ± 0.40 | 68.95 ± 0.24 | 81.82 ± 1.42 | 46.60 ± 1.74
Tree-LSTM | 99.86 | 99.80 | 99.96 | 93.05 ± 0.12 | 89.67 ± 1.32 | 98.36 ± 0.51 | 77.58 ± 0.19 | 79.50 ± 0.22 | 72.69 ± 0.25
Tree-SMU | 99.59 | 99.31 | 99.98 | 93.52 ± 0.28 | 91.08 ± 0.49 | 98.19 ± 0.10 | 79.57 ± 0.16 | 80.56 ± 2.11 | 76.63 ± 3.15

Table 5: Average stack usage for different stack sizes, broken down by data depth (columns are data depths 3-13)
Stack Size | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13
1 | 0.05 | 0.19 | 0.41 | 0.63 | 0.80 | 0.89 | 0.89 | 0.91 | 0.97 | 0.89 | 0.82
2 | 0.05 | 0.31 | 0.74 | 1.16 | 1.45 | 1.63 | 1.69 | 1.66 | 1.77 | 1.73 | 1.54
3 | 0.05 | 0.31 | 0.93 | 1.55 | 1.97 | 2.18 | 2.32 | 2.23 | 2.35 | 2.39 | 2.13
7 | 0.05 | 0.31 | 0.93 | 1.73 | 2.47 | 2.82 | 3.21 | 3.00 | 3.14 | 3.51 | 3.31
14 | 0.05 | 0.31 | 0.93 | 1.73 | 2.47 | 2.82 | 3.21 | 3.05 | 3.23 | 3.76 | 3.53

t-SNE plots. Finally, Figure 7 shows t-SNE visualizations for Tree-RNN. The representations learned by Tree-RNN are even more sensitive to irrelevant syntactic details than those of Tree-LSTM and Tree-SMU. These variations in the learned representations of equivalent sub-expressions result in errors that propagate as expressions get deeper.

Figure 7: Substitutivity Test: learned representations of Tree-RNN. The model is able to cluster equivalent expressions. However, Tree-RNN is very sensitive to irrelevant syntactic details. For example, the two sub-clusters of class 0 separate expressions multiplied by 0 on the left (left sub-cluster) from those multiplied by 0 on the right (right sub-cluster).
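The stack-usage criterion from the stack size ablation above (a row counts as used when its norm exceeds τ = 0.001) can be sketched as follows. The sketch assumes the Euclidean (L2) norm. It also assumes that the average is taken over a list of stacks, for instance the root stacks of all equations of one depth; the text does not say exactly which stacks the paper averages over. The function names are hypothetical.

```python
import math
from typing import List

Stack = List[List[float]]  # p rows, each a vector


def rows_used(stack: Stack, tau: float = 1e-3) -> int:
    """Number of stack rows whose L2 norm exceeds tau."""
    return sum(1 for row in stack if math.sqrt(sum(v * v for v in row)) > tau)


def average_usage(stacks: List[Stack], tau: float = 1e-3) -> float:
    """Mean number of used rows over a collection of stacks (e.g. one per equation)."""
    return sum(rows_used(s, tau) for s in stacks) / len(stacks)


stacks = [
    [[0.5, 0.0], [0.0, 0.0]],    # one used row
    [[1.0, 1.0], [0.0, 0.002]],  # two used rows (0.002 > tau)
]
print(average_usage(stacks))  # 1.5
```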