GitHub链接:https://github.com/LiangYang666/prototypical-networks/tree/handbag
支持多gpu分布式训练,支持高版本pytorch1.x

EpisodicBatchSampler 抽样器

n_episodes: Number of episodes or equivalently batch size
n_way: Number of classes to sample
n_samples: Number of samples per episode (Usually n_query + n_support)
  • 部分参数
    • n_episodesbatch size
    • n_way 为每个batch中,抽取的训练类别数
    • n_samples为每一类中抽取的样本数量, 一般为查询集+支撑集数量
  • 每个batch采样生成的数据格式
    • Batch format: (c_i_1, c_j_1, ..., c_n_way_1, c_i_2, c_j_2, ... , c_n_way_2, ..., c_n_way_n_samples)
    • 不考虑图片维度,输出为一维,化为二维可看成,每一行n_way个,一共有n_samples

损失计算

  • class_prototypes,其shape(n_way_train, 512),为每一个类别其n_support个图经过网络生成的特征向量平均值
  • model(data_query), 其shape为(n_support*n_way_train, 512 )
  • logits, 其shape为(n_support*n_way_train, n_way), 为根据支撑集生成的特征向量,输出最有可能的类别
  • loss, 即交叉熵
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐