A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation


Authors: Robin Brügger, Christian F. Baumgartner, Ender Konukoglu

Computer Vision Lab, ETH Zürich, Switzerland
{baumgartner, ender.konukoglu}@vision.ee.ethz.ch

Abstract. One of the key drawbacks of 3D convolutional neural networks for segmentation is their memory footprint, which necessitates compromises in the network architecture in order to fit into a given memory budget. Motivated by the RevNet for image classification, we propose a partially reversible U-Net architecture that reduces memory consumption substantially. The reversible architecture allows us to exactly recover each layer's outputs from those of the subsequent layer, eliminating the need to store activations for backpropagation. This alleviates the biggest memory bottleneck and enables very deep (theoretically infinitely deep) 3D architectures. On the BraTS challenge dataset, we demonstrate substantial memory savings. We further show that the freed memory can be used for processing the whole field-of-view (FOV) instead of patches. Increasing network depth led to higher segmentation accuracy while growing the memory footprint only by a very small fraction, thanks to the partially reversible architecture.

Keywords: Reversible neural network · CNN · U-Net · Dice loss

1 Introduction

The design of 3D segmentation networks is often severely limited by GPU memory consumption. The main issue is training: the output of each layer (referred to as activations for the remainder of this paper) needs to be stored for the backward pass. This problem is exacerbated for deeper networks and when feeding the network with larger fields of view. Alleviating this issue has the potential to yield accuracy gains, especially considering the recent success of 3D architectures [4,9].
While the memory problem can always be mitigated by better hardware, the proposed method is a more cost-effective and sustainable solution. Furthermore, our method allows for the design of 3D segmentation networks with depths that might not have been possible before on any hardware.

Several strategies for reducing the memory consumption of segmentation CNNs have been proposed in the literature. One possible strategy is processing the volume in patches. Kamnitsas et al. (DeepMedic) [6] used patches sampled at two different scales, of size 25x25x25 and 19x19x19 voxels. They noted that contextual information from a large area is beneficial for accuracy. This is also evidenced by a trend towards bigger patch sizes. For instance, the winner of the BraTS 2018 challenge [9] used a large patch size of 160x192x128, which covers almost the entire FOV, but required a GPU with 32 GB of memory for training.

A second method to reduce the memory consumption is to train with small batch sizes. For example, both the V-Net [8] and the No-New-Net [4] used a batch size of two. The winner of the BraTS 2018 challenge even used a batch size of one [9]. Small batch sizes do not work well with the commonly used Batch Normalization, but Group Normalization (GN) can be used instead [11].

The potential memory savings of patch-based approaches and small batch sizes are limited: small patches lack global context, and the smallest possible batch size is 1. On the other hand, most of the memory during training of a 3D neural network is taken up by the need to store activations. Reducing the need to store activations can therefore lead to substantial memory gains well beyond the other strategies. For a classification network as well as an RNN, Chen et al. [2] proposed to store these values only at certain intervals.
When they are needed during backpropagation, the missing activations can be recomputed by an additional forward pass. Assuming a simple, non-branching feed-forward network, storing the activations at every $\sqrt{n}$-th layer of an $n$-layer network reduces the memory requirement from $O(n)$ to $O(\sqrt{n})$, while doubling the computational cost of the forward passes. Recursively applying this technique allows for as little as $O(\log n)$ memory for the activations.

Fig. 1: Reversible block. (a) Forward computation. (b) Backward computation.

Reversible layers, proposed by Gomez et al. [3], take the idea of not storing activations even further. During the backward pass, each layer's activations can be calculated from the following layer's activations. Therefore, no intermediate activations need to be saved, allowing for $O(1)$ memory cost for the activations. However, for this to work, the number of inputs of a layer needs to match the number of outputs.

The basic building element of a reversible neural network is the reversible block, illustrated in Fig. 1. In the forward computation, it takes inputs $x_1$ and $x_2$ of equal shape and outputs $y_1$ and $y_2$. The forward computation can be expressed with the following equations:

$$y_1 = x_1 + F(x_2), \qquad y_2 = x_2 + G(y_1) \tag{1}$$

$F$ and $G$ can be arbitrary functions, as long as the output shape is the same as the input shape. An example for $F$ and $G$ is a convolutional layer that maintains the channel size and the number of channels, followed by an activation function. The reversible block partitions the layer's input into two groups, one flowing along the top path in Fig. 1(a) and one along the bottom path. Among different partitioning strategies, Gomez et al. [3] found that partitioning the channels works best. Following [3], we partition the channels of the activation maps into two groups for all reversible blocks in this work.
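A minimal plain-Python sketch of Eq. (1), and of its inverse given as Eq. (2) below, may make the mechanics concrete. It assumes a single voxel whose channel vector is a Python list, split into two halves as described above. The functions `F` and `G` are hypothetical smooth toy functions standing in for the convolution-plus-activation layers used in the paper, and no gradients are computed; only the reconstruction of inputs from outputs is shown. Chaining several blocks previews the reversible sequence discussed next: the input is recovered from the final output alone.

```python
import math

# Hypothetical toy stand-ins for F and G. Eq. (1) only requires that they
# preserve shape; in the paper they are learned layers.
def F(v):
    return [math.tanh(0.7 * t + 0.1) for t in v]

def G(v):
    return [0.5 * math.sin(t) for t in v]

def split_channels(x):
    """Partition the channel list into two equal groups (x1, x2)."""
    half = len(x) // 2
    return x[:half], x[half:]

def block_forward(x1, x2, f, g):
    """Eq. (1): y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def block_inverse(y1, y2, f, g):
    """Eq. (2): x2 = y2 - G(y1), then x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2

def sequence_forward(x, blocks):
    """Run a chain of reversible blocks; only the final output must be kept."""
    x1, x2 = split_channels(x)
    for f, g in blocks:
        x1, x2 = block_forward(x1, x2, f, g)
    return x1 + x2

def sequence_inverse(y, blocks):
    """Recover the chain's input from its output, undoing blocks in reverse."""
    y1, y2 = split_channels(y)
    for f, g in reversed(blocks):
        y1, y2 = block_inverse(y1, y2, f, g)
    return y1 + y2

x = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1]  # one voxel with 6 channels
# The same toy F and G are reused in every block for brevity only;
# real blocks each have their own parameters.
blocks = [(F, G)] * 4
y = sequence_forward(x, blocks)
x_rec = sequence_inverse(y, blocks)
print(max(abs(a - b) for a, b in zip(x, x_rec)))  # floating-point round-off only
```

The reconstruction is exact up to floating-point round-off, which is why no intermediate activations inside the chain need to be stored.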
In the backward pass, the derivatives of the parameters of $F$ and $G$, as well as the reversible block's original inputs, are calculated. The design of the reversible block allows us to recover $x_1$ and $x_2$ given only $y_1$ and $y_2$ using Equation 2, which is what makes it reversible:

$$x_2 = y_2 - G(y_1), \qquad x_1 = y_1 - F(x_2) \tag{2}$$

Multiple reversible blocks can be chained together to form a sequence of arbitrary length. Because every block is reversible, the entire sequence is reversible as well. We refer to this as a reversible sequence (RSeq) and illustrate it in Fig. 2. An RSeq is an important building block for our partially reversible architecture because the activations only have to be saved at the end of an RSeq. Hence, an entire RSeq has almost the same memory footprint as a single block; the only difference is the memory required to store the extra weights, which is a very small fraction of the activations for 3D convolutional layers. The length of the sequence can therefore be varied (almost) free of memory constraints. This allows us to create a family of architectures that vary in the number of blocks per RSeq.

Fig. 2: Reversible sequence

Blumberg et al. [1] used reversible blocks to create a very deep network for the task of super-resolution. They employed an architecture without skip connections and mainly focused on accuracy gains. We, on the other hand, propose a partially reversible U-Net with the main aim of enabling deep and powerful 3D segmentation networks on commodity hardware.

2 Methods

2.1 Partially reversible architecture

In this section, we describe our novel partially reversible U-Net architecture. We chose the No-New-Net by Isensee et al. [4] as a non-reversible starting point and extended it. The No-New-Net won second place in the BraTS 2018 challenge and was trained on a graphics card with 12 GB of memory, a commonly used and widely available type of hardware.
As mentioned before, a key constraint for a reversible block is that the number of input units must be equal to the number of output units. There are two issues to overcome when adapting reversible blocks to a U-Net architecture. First, spatial downsampling and upsampling, which are an integral part of the U-Net [10], typically change the number of output units. Second, the branching and merging nature of the U-Net makes it impractical to model the entire network as a single reversible sequence.

While it is possible to create a reversible spatial downsampling operation, as demonstrated by Jacobsen et al. [5], in 3D networks it leads to a prohibitively high number of channels on the lower resolution levels (a $d^3$-fold increase for $d$-fold downsampling) and does not solve the problem of the branching and merging network structure. We therefore forgo the idea of a fully reversible architecture and employ a single reversible sequence per resolution level in both the encoder and the decoder, while using traditional non-reversible operations for down- and upsampling as well as for the skip connections. This partially reversible architecture still realizes large memory savings over a traditional U-Net, because activations only need to be saved at the end of each reversible sequence and for the non-reversible components. It also retains the possibility to make the network (almost) arbitrarily deep by increasing the number of reversible blocks in any reversible sequence. This leads to a family of reversible architectures whose members have very similar memory requirements while varying substantially in depth. Figure 3 shows our architecture.

Fig. 3: Partially reversible architecture

2.2 Non-reversible baseline equivalent

For a fair comparison, we construct a non-reversible baseline with the same number of parameters as the partially reversible architecture with only one reversible block per reversible sequence. To achieve this, we make the following changes:

– Replace each reversible block with the layers GN–LeakyReLU–Conv–GN–LeakyReLU–Conv, which is effectively a concatenation of F and G.
– Change the number of channels on the different resolution levels to [30, 60, 120, 240, 480], following the argumentation of Gomez et al. [3]. The smaller number of channels on the higher resolution levels is justified by the fact that the channels are partitioned into two groups in the reversible architectures.
– 1x1x1 convolutions are not needed to change the number of channels, so we omit them.

The non-reversible baseline and the reversible architecture with a single reversible block per reversible sequence both have approximately 12.5M parameters.

2.3 Memory analysis

The bulk of the memory consumption when training a non-reversible neural network can be attributed to three categories. First, the activations of the layers ($M_{A_l}$). Second, the network parameters and the memory related to them, such as the per-parameter state of some optimizers like Adam ($M_{P_l}$). Finally, the derivatives of the activations during backpropagation ($M_{D_l}$). Here $M_{A_l}$, $M_{P_l}$ and $M_{D_l}$ denote the respective memory requirements of the $l$-th layer. The tensors in the first two categories need to be allocated for the entire forward and backward pass. In contrast, memory allocated for the derivatives of the activations can be freed immediately after backpropagation has processed all direct predecessors of a layer.
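A minimal sketch of this bookkeeping, anticipating Eqs. (3) and (4) below, makes the scaling explicit. The per-layer costs are hypothetical units chosen only to show the trend, not measurements, and each block is treated as a single layer, which understates the activations a real non-reversible block (GN, LeakyReLU and convolution layers) would store.

```python
def total_memory_nonrev(m_a, m_p, m_d):
    """Eq. (3): every layer keeps its activations (M_A) and parameter memory
    (M_P); only the largest activation-derivative buffer (M_D) is live."""
    return sum(m_a) + sum(m_p) + max(m_d)

def total_memory_prev(m_n, m_s, m_p, m_b):
    """Eq. (4): activations of non-reversible layers (M_N), one stored output
    per reversible sequence (M_S), parameter memory (M_P), and the largest
    reversible-backward buffer (M_B)."""
    return sum(m_n) + sum(m_s) + sum(m_p) + max(m_b)

# Hypothetical toy: one resolution level made of k blocks. Activations cost
# 100 units per block, parameters 1 unit; an ordinary backward buffer costs
# 100 units, a reversible one 150 units (larger, as argued after Eq. (4)).
for k in range(1, 5):
    nonrev = total_memory_nonrev([100] * k, [1] * k, [100] * k)
    prev = total_memory_prev([], [100], [1] * k, [150] * k)
    print(k, nonrev, prev)
# 1 201 251
# 2 302 252
# 3 403 253
# 4 504 254
```

With these made-up costs a single reversible block is actually more expensive, because its backward buffer is larger; from two blocks on, the non-reversible total grows by 101 units per block, while the partially reversible total grows by only 1.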
Assuming a non-branching network, the total memory requirement for training can therefore be expressed with the following equation:

$$M^{\text{Total}}_{\text{NonRev}} = \sum_{l \in \text{Layers}} \left( M_{A_l} + M_{P_l} \right) + \max_{l \in \text{Layers}} M_{D_l} \tag{3}$$

For a partially reversible architecture, we distinguish between the activations of the non-reversible layers ($M_{N_l}$) and the activations at the end of each reversible sequence ($M_{S_i}$). We still need the parameter-related memory ($M_{P_l}$). The backpropagation of a reversible block is more complex and needs more memory than just the derivatives of the activations ($M_{B_l}$). Again assuming a non-branching network, the total memory needed for a partially reversible architecture can be calculated with the following equation:

$$M^{\text{Total}}_{\text{PRev}} = \sum_{l \in \text{NonRevL.}} M_{N_l} + \sum_{i \in \text{RSeq}} M_{S_i} + \sum_{l \in \text{Layers}} M_{P_l} + \max_{l \in \text{Layers}} M_{B_l} \tag{4}$$

Each reversible sequence replaces several non-reversible layers, leading to a lower total memory consumption. Furthermore, making partially reversible networks deeper by adding more blocks to any reversible sequence only grows the number of $M_{P_l}$ terms, whereas a non-reversible network also needs additional $M_{A_l}$ terms for saving the activations. This comes at the expense of $M_{B_l}$ being larger than $M_{D_l}$ due to the more complex backpropagation process. However, since only the maximum $M_{B_l}$ enters Eq. (4), this drawback is easily outweighed by having to save fewer activations.

2.4 Training procedure

We demonstrate the advantages of the reversible neural network on the BraTS 2018 dataset [7]. We implemented the reversible block and the reversible sequence using PyTorch, seamlessly integrating them into the autograd framework. The result is RevTorch¹, an open-source library that can be used to create (partially) reversible neural networks. The code used to conduct the experiments in this paper is public as well².
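The per-image standardization, learning-rate schedule, loss and stopping rule described in the Preprocessing and Training paragraphs that follow can be sketched in plain Python. Several details are assumptions, not taken from the text: non-zero voxels serve as the brain mask, epochs are counted from 0 so the first decay applies from epoch 250 on, the soft Dice formulation and its smoothing constant `eps` are one common choice among several, and "has not increased for 60 epochs" is read as "60 epochs since the best moving average". All function names are hypothetical.

```python
import statistics

def standardize_nonzero(voxels):
    """Zero mean / unit variance computed over non-zero voxels only; zero
    voxels (assumed to lie outside the brain) stay zero. Apply per modality."""
    brain = [v for v in voxels if v != 0]
    mu, sd = statistics.fmean(brain), statistics.pstdev(brain)
    return [(v - mu) / sd if v != 0 else 0.0 for v in voxels]

def learning_rate(epoch, base=1e-4, milestones=(250, 400, 550), factor=5.0):
    """Initial rate 1e-4, divided by 5 at each milestone (0-based epochs)."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base / factor ** drops

def soft_dice_loss(pred, target, eps=1e-6):
    """1 - soft Dice for one region; pred holds probabilities, target 0/1."""
    inter = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * inter + eps) / (sum(pred) + sum(target) + eps)

def total_loss(preds, targets):
    """Unweighted sum of the Dice losses of the three regions (WT, TC, ET)."""
    return sum(soft_dice_loss(p, t) for p, t in zip(preds, targets))

def should_stop(val_dice, window=30, patience=60):
    """True once the 30-epoch moving average of the validation Dice has not
    reached a new maximum for 60 epochs."""
    if len(val_dice) < window:
        return False
    avgs = [statistics.fmean(val_dice[i - window + 1:i + 1])
            for i in range(window - 1, len(val_dice))]
    best = max(range(len(avgs)), key=avgs.__getitem__)  # first maximum
    return len(avgs) - 1 - best >= patience
```

Taking the first index of the maximum means that a later epoch merely matching the best moving average does not count as an increase.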
Preprocessing. We standardize each image individually to have zero mean and unit variance (based on non-zero voxels only). To avoid an initial bias of the network towards one modality, we standardize each modality independently. We set all voxels outside of the brain to zero. We also crop all images to 160x192x160, which fits all but a single brain.

Augmentation. We use common training data augmentation strategies, including random rotation, random scaling, random elastic deformations, random flips and small shifts in intensity. Because the z-axis of the BraTS dataset has been interpolated, we avoid any augmentation that is not independent of the z-axis. This means we rotate, scale and apply elastic deformations only within the planes perpendicular to the z-axis. This way, we do not further degrade the image quality.

Training. We employ a random 80%/20% split of the training data provided by the organizers of the BraTS challenge for training and validation. We train using the Adam optimizer with an initial learning rate of $10^{-4}$. The learning rate is decreased by a factor of 5 after 250, 400 and 550 epochs. The batch size is one. We output the nested tumor regions whole tumor (WT), tumor core (TC) and enhancing core (ET) directly. The loss is the unweighted sum of the Dice losses of each of the three regions. Training is stopped once the 30-epoch moving average of the Dice score on the validation split has not increased for 60 epochs. We regularize using a weight decay of $10^{-5}$. We do not employ any post-processing. For inference, we use the weights that achieved the best Dice score on the validation split. Thanks to the fully convolutional nature of our architecture, we always segment an entire volume in a single pass at test time.

3 Experiments and Results

We evaluated all models on the online validation dataset provided by the organizers of the BraTS challenge.
Note that this dataset is distinct from the validation split described above. Table 1 shows our results. All metrics were computed using the online evaluation platform. #Encoder and #Decoder correspond to the number of reversible blocks in each reversible sequence in the encoder and decoder path, respectively. The memory consumption was measured using PyTorch's max_memory_allocated(). For all performance metrics, the mean over all images in the dataset is reported.

Table 1: Results on the BraTS 2018 validation dataset (Dice in %; HD95 is the 95th-percentile Hausdorff distance)

Architecture | #Encoder | #Decoder | Patch size | Memory (MB) | Dice ET | Dice WT | Dice TC | HD95 ET | HD95 WT | HD95 TC
Baseline | – | – | 128³ | 6646 | 79.29 | 89.99 | 82.02 | 4.57 | 5.53 | 8.88
Reversible | 1 | 1 | 128³ | 4138 | 79.02 | 90.24 | 83.92 | 3.19 | 5.33 | 8.33
Reversible | 1 | 1 | full | 9436 | 78.99 | 90.39 | 84.40 | 3.16 | 4.67 | 6.39
Reversible | 2 | 1 | full | 9529 | 79.66 | 90.68 | 85.47 | 3.35 | 4.64 | 7.38
Reversible | 3 | 1 | full | 9620 | 80.33 | 91.01 | 86.21 | 2.58 | 4.58 | 6.84
Reversible | 4 | 1 | full | 9713 | 80.56 | 90.61 | 85.71 | 3.35 | 5.61 | 7.83
Isensee et al. (2018) baseline | – | – | 128³ | – | 79.59 | 90.80 | 84.32 | 3.12 | 4.79 | 8.16
Isensee et al. (2018) final submission | – | – | 128³ | – | 80.87 | 91.26 | 86.34 | 2.41 | 4.27 | 6.52

¹ https://github.com/RobinBruegger/RevTorch
² https://github.com/RobinBruegger/PartiallyReversibleUnet

The baseline in row 1 is not reversible and was trained with the patch-based approach mentioned in the introduction to address the memory issue. The reversible equivalent in the 2nd row, which was also trained with patches of the same size for direct comparison, reduced the memory consumption by more than one third without lowering the performance. The memory reduction realized by our architecture allows the processing of the whole FOV instead of patches; the corresponding results are shown in the 3rd row. We believe this is beneficial because it allows the network to use more context than the patch-based approach.
While this did not improve the Dice scores substantially, it had a positive effect on the Hausdorff distance.

As discussed, a key strength of the partially reversible architecture is that we can increase the depth of the network with a very small additional memory requirement, coming from the increased number of parameters. Comparing rows 3 to 6 in Table 1, we observe that the additional memory required when making the network deeper is minimal. In contrast, a non-reversible equivalent of the architecture with four reversible blocks per reversible sequence in the encoder would need over 20 GB of memory for the activations alone. We also observe that the segmentation performance increased when going from one to three reversible blocks per reversible sequence in the encoder, before receding again with four reversible blocks.

For reference, Table 1 also shows the results reported by Isensee et al. [4]. The score most directly comparable to our method is their baseline. It uses the cross-entropy loss instead of the Dice loss during training, employs test-time augmentation, and is an ensemble of five models trained on different splits of the training data. Their final submission additionally employed post-processing and co-training with additional public and institutional data. Our architecture with three reversible blocks per sequence achieved similar segmentation performance with a single model trained on 80% of the training data and without any of the additional strategies mentioned above.

The reported memory savings do come at a cost. We observed a 50% increase in training time for a reversible architecture over its non-reversible equivalent. However, considering the gains in performance and the opportunities for large-scale architectures, we believe the longer training time is acceptable.
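The two memory claims above can be checked directly against the reported numbers. The sketch below only does arithmetic on the values copied from Table 1; nothing is re-measured.

```python
# Peak training memory in MB, copied from Table 1 (rows 1-6).
baseline_patch = 6646
reversible_patch = 4138
full_fov_by_encoder_blocks = {1: 9436, 2: 9529, 3: 9620, 4: 9713}

# Row 1 vs. row 2: same 128^3 patches, reversible vs. non-reversible.
reduction = 1 - reversible_patch / baseline_patch
print(f"patch-based memory reduction: {reduction:.1%}")  # 37.7%

# Rows 3-6: cost of each additional encoder block per reversible sequence.
blocks = sorted(full_fov_by_encoder_blocks)
extras = [full_fov_by_encoder_blocks[b] - full_fov_by_encoder_blocks[a]
          for a, b in zip(blocks, blocks[1:])]
print(extras)  # [93, 91, 93]

growth = full_fov_by_encoder_blocks[4] / full_fov_by_encoder_blocks[1] - 1
print(f"1 -> 4 encoder blocks: +{growth:.1%} memory")  # +2.9%
```

A 37.7% reduction is consistent with "more than one third", and each added encoder block costs only about 90 MB, under 3% in total from one to four blocks.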
4 Discussion

We have demonstrated an extremely memory-efficient, partially reversible U-Net architecture for the segmentation of volumetric images. We achieve competitive performance compared to current state-of-the-art architectures on the same memory budget. We do this without ensembling, test-time augmentation, post-processing or co-training with additional data. Applying these techniques to our architecture may further improve performance. We have demonstrated that our architecture allows arbitrarily deep networks with minimal additional memory requirements. We believe our contribution will allow more researchers to design and investigate large-scale 3D network architectures, even if they do not have access to expensive, highly specialized hardware with massive amounts of memory.

References

1. Blumberg, S.B., Tanno, R., Kokkinos, I., Alexander, D.C.: Deeper image quality transfer: Training low-memory neural networks for 3D images. Medical Image Computing and Computer Assisted Intervention (MICCAI) 2018
2. Chen, T., Xu, B., Zhang, C., Guestrin, C.: Training deep nets with sublinear memory cost (2016)
3. Gomez, A.N., Ren, M., Urtasun, R., Grosse, R.B.: The reversible residual network: Backpropagation without storing activations. In: Advances in Neural Information Processing Systems 30, pp. 2214–2224. Curran Associates, Inc. (2017)
4. Isensee, F., Kickingereder, P., Wick, W., Bendszus, M., Maier-Hein, K.H.: No new-net (2018)
5. Jacobsen, J.H., Smeulders, A.W., Oyallon, E.: i-RevNet: Deep invertible networks. In: International Conference on Learning Representations (2018)
6. Kamnitsas, K., Ledig, C., Newcombe, V.F., Simpson, J.P., Kane, A.D., Menon, D.K., Rueckert, D., Glocker, B.: Efficient multi-scale 3D CNN with fully connected CRF for accurate brain lesion segmentation. Medical Image Analysis 36, 61–78 (2017), http://www.sciencedirect.com/science/article/pii/S1361841516301839
7. Menze, B.H., et al.: The multimodal brain tumor image segmentation benchmark (BraTS). IEEE Transactions on Medical Imaging 34(10), 1993–2024 (Oct 2015)
8. Milletari, F., Navab, N., Ahmadi, S.: V-Net: Fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV). pp. 565–571 (Oct 2016)
9. Myronenko, A.: 3D MRI brain tumor segmentation using autoencoder regularization (2018)
10. Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional networks for biomedical image segmentation. In: Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. pp. 234–241. Springer International Publishing, Cham (2015)
11. Wu, Y., He, K.: Group normalization (2018)
