pytorch默认只保存最后一层的输出,中间层输出默认不保存,要提取中间层网络输出值,需要使用回调函数register_forward_hook(),通过传入处理函数,便可以提取和保存特点网络层的输出值。

class ActivationData():
	#网络输出值
	outputs = None
	def __init__(self,layer):
		#在模型的layer_num层上注册回调函数,并传入处理函数hook_fn
		self.hook = layer.register_forward_hook(self.hook_fn)

	def hook_fn(self,module,input,output):
	self.outputs = output.cpu()

	def remove(self):
	#由回调句柄调用,用于将回调函数从网络层删除
	self.hook.remove()
#获取第二个卷积层
conv_out = ActivationOutputData(model.conv2)
#传入图片
o = model(img)
#移除回调函数
conv_out.remove()
#输出图片
for i in range(16):
	ax.imshow(conv_out.outputs[0][i].detach().numpy())

资料来源:《pytorch深度学习实战 从新手小白到数据科学家》

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐