模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。

1、最常见的问题是键值多了或者少了 module.

此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.

1)可以通过:

model = nn.DataParallel(model)
   

将模型的键值加上module.

2) 也可以通过遍历模型的键对值修改键值。

   如:加载模型时删除多余的module.  代码如下


   
  1. state_dict = torch.load(load_path)
  2. for key, param in state_dict.items():
  3. if key.startswith( 'module.'): #键值包含‘module.’ 则删除
  4. state_dict[key[ 7:]] = param
  5. state_dict.pop(key)
  6. net.load_state_dict(state_dict)

2、详解load_state_dict(state_dict, False)的False参数

很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。

如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。

该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析。

1)模型包含网络的部分参数

比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。

2)模型完全不包含网络的参数

情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。

3)再介绍一个False使用场景

比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。

综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。

 

3、只要参数尺寸相同,就能加载

比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。


   
  1. state_dict = torch.load(load_path)
  2. new_state_dict = []
  3. for key, param in state_dict.items():
  4. if 'conv9' in key: # 如果找到conv9对应的参数,将其键值替换为网络的键
  5. new_state_dict[key.replace( 'conv9', 'conv1')] = param
  6. net.load_state_dict(new_state_dict)

更多参考

https://blog.csdn.net/t20134297/article/details/110533007
https://blog.csdn.net/binbinczsohu/article/details/107943806
https://blog.csdn.net/yangwangnndd/article/details/100207686
https://blog.csdn.net/jacke121/article/details/91390803
https://blog.csdn.net/G_inkk/article/details/119927235

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐