一、tensorboard

from tensorboardX import SummaryWriter

...

# Each SummaryWriter object corresponds to one event file
train_log_path = '.'       # saved under the current working directory, i.e. the directory returned by os.getcwd()
valid_log_path = '.'
train_writer = SummaryWriter(train_log_path, filename_suffix='TRAIN')
val_writer = SummaryWriter(valid_log_path, filename_suffix='VAL')

...

for epoch in range(0, max_epochs):
    
    ...

    for batch_idx, batch in enumerate(train_dataloader):
        ...
        loss1=...
        loss2=...
        n_batchsize = len(train_dataloader)
        # global step: one point per batch across all epochs
        step = epoch * n_batchsize + batch_idx
        train_writer.add_scalar('Loss/Step/loss1', loss1, step)
        train_writer.add_scalar('Loss/Step/loss2', loss2, step)
    
    # Average metric values over one epoch
    # NOTE(review): total_loss1/total_loss2 are assumed to be accumulated
    # inside the batch loop above (elided by `...`) — confirm in real code.
    avg_loss1=total_loss1/n_batchsize
    avg_loss2=total_loss2/n_batchsize
    train_writer.add_scalar('Loss/Epoch/loss1', avg_loss1, epoch)
    train_writer.add_scalar('Loss/Epoch/loss2', avg_loss2, epoch)

# NOTE(review): val_writer is never closed — it should also get a close() call.
train_writer.close()
# TensorBoard visualization command (run in a shell, not in Python):
tensorboard --logdir="/your/events_path"

注意点:

1.如果声明了多个SummaryWriter对象(即上述代码中的例子),且都在同一个目录下,那么tensorboard可视化的时候只会显示一个,若想同时可视化,需要放在不同的目录之下。

2.writer.add_scalar的第一个参数tag,在可视化的时候可以自己根据tag命名格式分类,如tag为train/loss,则属于train分组,那么train/iou、train/acc都在此分组下;如tag为epoch,则属于epoch分组。即会根据tag中的‘/’自动分组。

3.建议以分组形式,写在同一个SummaryWriter对象下。

二、logging


import logging
import torch.distributed as dist
import os

def get_logger(log_name, log_level, log_file=None, file_mode='a'):
    '''
    Return a logger configured with a console handler and, on rank 0 only,
    an optional file handler.

    Args:
        log_name: name identifying the logger. logging.getLogger caches
            loggers by name, so the same name always yields the same object.
        log_level: log level threshold, e.g. logging.INFO, logging.DEBUG,
            logging.ERROR, logging.CRITICAL.
        log_file: path of the log file; None disables file output.
        file_mode: open mode of the log file, 'a' to append, 'w' to overwrite.

    Returns:
        logging.Logger: the configured logger instance.
    '''

    logger = logging.getLogger(log_name)
    logger.propagate = False            # do not pass records up to ancestor loggers

    # Fix: logging.getLogger returns the cached logger for a repeated name, so
    # a second call used to stack duplicate handlers and emit every message
    # twice. If this logger is already configured, return it unchanged.
    if logger.handlers:
        return logger

    # Detect a distributed (multi-GPU) run; only rank 0 should write files.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    handlers = []

    # Stream handler: emit records to the console (stderr by default).
    stream_handler = logging.StreamHandler()
    handlers.append(stream_handler)

    if rank == 0 and log_file is not None:
        # Fix: os.path.dirname returns '' for a bare filename, and
        # os.makedirs('') raises — only create the directory when one is
        # actually present. exist_ok=True also removes the exists/makedirs race.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file, file_mode)  # 'a' appends, 'w' overwrites
        handlers.append(file_handler)

    # One shared formatter for all handlers.
    plain_formatter = logging.Formatter(
        "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    )

    formatter = plain_formatter

    # Attach formatter + level to every handler, then register it.
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-zero ranks only report errors so console output is not duplicated
    # once per process.
    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    return logger


if __name__ == "__main__":
    # Available levels, lowest to highest:
    # DEBUG < INFO < WARNING < ERROR < CRITICAL.
    log_name = "Debug"
    log_level = logging.INFO
    file_mode = 'w'                 # 'w' overwrites the log file, 'a' appends
    log_file = './log/train.log'

    # Fix: the original passed a hard-coded 'w' instead of the file_mode
    # variable, and contained bare no-op statements (logging.DEBUG, ...).
    logger = get_logger(log_name, log_level, log_file, file_mode=file_mode)
    logger.info("=> Loading config ...")
    logger.info("=> Start train!")
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐