Paper Title
Improved Deep Neural Network Generalization Using m-Sharpness-Aware Minimization
Paper Authors
Paper Abstract
Modern deep learning models are over-parameterized, where the optimization setup strongly affects the generalization performance. A key element of reliable optimization for these systems is the modification of the loss function. Sharpness-Aware Minimization (SAM) modifies the underlying loss function to guide descent methods towards flatter minima, which arguably have better generalization abilities. In this paper, we focus on a variant of SAM known as mSAM, which, during training, averages the updates generated by adversarial perturbations across several disjoint shards of a mini-batch. Recent work suggests that mSAM can outperform SAM in terms of test accuracy. However, a comprehensive empirical study of mSAM is missing from the literature -- previous results have mostly been limited to specific architectures and datasets. To that end, this paper presents a thorough empirical evaluation of mSAM on various tasks and datasets. We provide a flexible implementation of mSAM and compare the generalization performance of mSAM to the performance of SAM and vanilla training on different image classification and natural language processing tasks. We also conduct careful experiments to understand the computational cost of training with mSAM, its sensitivity to hyperparameters, and its correlation with the flatness of the loss landscape. Our analysis reveals that mSAM yields superior generalization performance and flatter minima, compared to SAM, across a wide range of tasks without significantly increasing computational costs.
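As a rough illustration of the procedure described in the abstract, the following is a minimal PyTorch-style sketch of one mSAM update: the mini-batch is split into m disjoint shards, each shard computes its own adversarial perturbation and the gradient at the perturbed weights, and the per-shard gradients are averaged into a single update. This is a sketch under stated assumptions, not the paper's implementation; the function name msam_step and the hyperparameters rho (perturbation radius) and m (number of shards) are illustrative.

```python
import torch

def msam_step(model, loss_fn, batch_x, batch_y, optimizer, rho=0.05, m=4):
    """One mSAM update (illustrative sketch, not the paper's code):
    average sharpness-aware gradients over m disjoint shards of a mini-batch."""
    params = [p for p in model.parameters() if p.requires_grad]
    avg_grads = [torch.zeros_like(p) for p in params]

    # Split the mini-batch into m disjoint shards.
    for x, y in zip(batch_x.chunk(m), batch_y.chunk(m)):
        # 1) Gradient of the shard loss at the current weights w.
        model.zero_grad()
        loss_fn(model(x), y).backward()
        grads = [p.grad.detach().clone() for p in params]

        # 2) Shard-specific adversarial perturbation: w + rho * g / ||g||.
        grad_norm = torch.norm(torch.stack([g.norm() for g in grads]))
        eps = [rho * g / (grad_norm + 1e-12) for g in grads]
        with torch.no_grad():
            for p, e in zip(params, eps):
                p.add_(e)

        # 3) Gradient of the shard loss at the perturbed weights.
        model.zero_grad()
        loss_fn(model(x), y).backward()
        with torch.no_grad():
            for p, e, a in zip(params, eps, avg_grads):
                a.add_(p.grad / m)  # accumulate the averaged update
                p.sub_(e)           # restore the original weights

    # 4) Apply the averaged sharpness-aware gradient with the base optimizer.
    for p, a in zip(params, avg_grads):
        p.grad = a
    optimizer.step()
```

Note that setting m=1 makes the single shard the full mini-batch, recovering standard SAM. Each shard requires two backward passes, but over only 1/m of the data, so the total gradient computation stays close to SAM's, consistent with the abstract's claim that mSAM does not significantly increase computational costs.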