【深度学习】元学习原型网络实现细节
GitHub链接:https://github.com/LiangYang666/prototypical-networks/tree/handbag支持多gpu分布式训练,支持高版本pytorch1.xEpisodicBatchSampler 抽样器n_episodes: Number of episodes or equivalently batch sizen_way: Number of
·
GitHub链接:https://github.com/LiangYang666/prototypical-networks/tree/handbag
支持多gpu分布式训练,支持高版本pytorch1.x
EpisodicBatchSampler 抽样器
n_episodes: Number of episodes or equivalently batch size
n_way: Number of classes to sample
n_samples: Number of samples per episode (Usually n_query + n_support)
- 部分参数
n_episodes
为batch size
n_way
为每个batch
中,抽取的训练类别数n_samples
为每一类中抽取的样本数量, 一般为查询集+支撑集数量
- 每个batch采样生成的数据格式
Batch format: (c_i_1, c_j_1, ..., c_n_way_1, c_i_2, c_j_2, ... , c_n_way_2, ..., c_n_way_n_samples)
- 不考虑图片维度,输出为一维,化为二维可看成,每一行
n_way
个,一共有n_samples
行
损失计算
class_prototypes
,其shape
为(n_way_train, 512)
,为每一个类别其n_support个图经过网络生成的特征向量平均值model(data_query)
, 其shape为(n_support*n_way_train, 512 )
logits
, 其shape为(n_support*n_way_train, n_way)
, 为根据支撑集生成的特征向量,输出最有可能的类别loss
, 即交叉熵
更多推荐
所有评论(0)