PyTorch中查看模型的参数信息的几种方式
【代码】PyTorch中查看模型的参数信息的几种方式。
·
如下 :
# model是实例化的模型对象
print(model)
print('*********************************************************************')
for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
print(param_tensor, '\t', model.state_dict()[param_tensor].size())
print('*********************************************************************')
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
param = parameter.numel()
table.add_row([name, param])
total_params += param
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
count_parameters(model)
print('*********************************************************************')
for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple
'''
是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
'''
print(para[0], '\t', para[1].size())
print('*********************************************************************')
# 总参数个数
print('总参数个数 = ',sum(p.numel() for p in model.parameters() if p.requires_grad))
print('*********************************************************************')
params = list(model.parameters())
# 网络层数
print('网络层数 = ',params.__len__())
更多推荐
所有评论(0)