PyTorch中No Weight Decay策略
Weight Decay权重衰减机制是一个比较常用的训练策略。 但是在某些场景下,需要在训练的时候关闭WeightDecay。
例如在训练ViT的时候,对于position embedding和class token都是不需要添加WeightDecay的,在训练卷积网络的时候,对于卷积层的bias参数也是可以不添加WeightDecay的。因此需要在创建优化器的时候指明。
# models.py
class ViT(nn.Module):
...
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
...
class Model(nn.Module):
def __init__(self):
self.encoder = ViT()
def no_weight_decay(self):
def append_prefix_no_weight_decay(prefix, module):
return set(map(lambda x: prefix + x, module.no_weight_decay()))
nwd_params = append_prefix_no_weight_decay("encoder", self.encoder)
return nwd_params
# train.py
def train():
...
net = Model()
# 获得不需要添加weightdecay的列表
no_weight_decay_list = set(net.no_weight_decay())
decay = []
no_decay = []
# 遍历并区分参数
for name, param in net.named_parameters():
if not param.requires_grad:
continue
# 对于卷积层中的bias不需要weight decay
# 对于模型中指明不需要weight decay的部分
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
# 获得多个param_group,指定不同的weightdecay参数
parameters = [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': 1e-5}
]
# pytorch的优化器允许输入多组参数
optimizer = optim.Adam(parameters, lr=config['learning_rate'])