PyTorch中No Weight Decay策略

Oct 25, 2022· · 1 min read

Weight Decay权重衰减机制是一个比较常用的训练策略。 但是在某些场景下,需要在训练的时候关闭WeightDecay。

例如在训练ViT的时候,对于position embedding和class token都是不需要添加WeightDecay的,在训练卷积网络的时候,对于卷积层的bias参数也是可以不添加WeightDecay的。因此需要在创建优化器的时候指明。

# models.py
class ViT(nn.Module):
    ...
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}
    ...

class Model(nn.Module):
    def __init__(self):
        self.encoder = ViT()

    def no_weight_decay(self):
        def append_prefix_no_weight_decay(prefix, module):
            return set(map(lambda x: prefix + x, module.no_weight_decay()))

        nwd_params = append_prefix_no_weight_decay("encoder", self.encoder)

        return nwd_params

# train.py

def train():
    ...
    
    net = Model()

    # 获得不需要添加weightdecay的列表
    no_weight_decay_list = set(net.no_weight_decay())
    
    decay = []
    no_decay = []
    # 遍历并区分参数
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue

        # 对于卷积层中的bias不需要weight decay
        # 对于模型中指明不需要weight decay的部分
        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            no_decay.append(param)
        else:
            decay.append(param)

    # 获得多个param_group,指定不同的weightdecay参数
    parameters = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': 1e-5}
    ]

    # pytorch的优化器允许输入多组参数
    optimizer = optim.Adam(parameters, lr=config['learning_rate'])