数据配置文件

数据增强配置如下:

  1. gt_sampling 数据增强方法配置:

    • NAME: gt_sampling,使用 gt_sampling 数据增强方法。
    • USE_ROAD_PLANE: True,使用路面平面信息。
    • DB_INFO_PATH: kitti_dbinfos_train.pkl,用于过滤数据的数据库信息路径。
    • PREPARE: 数据准备配置,包括:
      • filter_by_min_points: [‘Car:5’, ‘Pedestrian:5’, ‘Cyclist:5’],根据最小点数过滤类别为 Car、Pedestrian 和 Cyclist 的目标。
      • filter_by_difficulty: [-1],根据困难程度过滤目标,此处的 -1 表示不过滤。
    • SAMPLE_GROUPS: [‘Car:15’,‘Pedestrian:15’, ‘Cyclist:15’],采样目标的数量,每个类别分别采样 15 个目标。
    • NUM_POINT_FEATURES: 4,点云特征的数量。
    • DATABASE_WITH_FAKELIDAR: False,是否使用虚拟激光雷达数据。
    • REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0],移除额外的宽度。
    • LIMIT_WHOLE_SCENE: False,是否限制整个场景。
  2. random_world_flip 数据增强方法配置:

    • NAME: random_world_flip,使用 random_world_flip 数据增强方法。
    • ALONG_AXIS_LIST: [‘x’],沿着 x 轴进行随机翻转。
  3. random_world_rotation 数据增强方法配置:

    • NAME: random_world_rotation,使用 random_world_rotation 数据增强方法。
    • WORLD_ROT_ANGLE: [-0.78539816, 0.78539816],世界旋转的角度范围,这里表示 -45 度到 45 度之间的随机旋转。
  4. random_world_scaling 数据增强方法配置:

    • NAME: random_world_scaling,使用 random_world_scaling 数据增强方法。
    • WORLD_SCALE_RANGE: [0.95, 1.05],世界缩放的范围,表示将点云缩放到 0.95 倍到 1.05 倍之间的随机比例。

这些数据增强方法可以帮助增加数据的多样性和鲁棒性,提升模型在不同场景下的泛化能力。具体的数据增强方法和参数可以根据任务需求和数据特点进行调整和优化。

模块讲解

代码在OpenPCDet/pcdet/datasets/augmentor/data_augmentor.py

gt_sampling 模块

涉及到的代码太多,后续分开章节讲

random_world_flip 模块

  def random_world_flip(self, data_dict=None, config=None):
        if data_dict is None:
            return partial(self.random_world_flip, config=config)
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y']
            gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                gt_boxes, points, return_flip=True
            )
            data_dict['flip_%s'%cur_axis] = enable
            if 'roi_boxes' in data_dict.keys():
                num_frame, num_rois,dim = data_dict['roi_boxes'].shape
                roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable
                )
                data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
 gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                gt_boxes, points, return_flip=True
            )
            

这行代码使用了 Python 中的 getattr() 函数和字符串格式化操作 % 来动态地调用函数。

  • getattr() 函数是 Python 中的一个内置函数,它用于获取对象的属性或者方法。在这个例子中,它使用了字符串格式化操作 % 构建了一个函数名,然后通过 getattr() 来获取这个函数并调用它。

    • getattr(object, name) 接受两个参数:object 是一个对象(在这里可能是 augmentor_utils 模块),name 是一个字符串,表示对象的属性名或者方法名。
  • %s 是 Python 字符串格式化的一种方式,用于将字符串中的 %s 替换为后续传入的字符串参数,这里是 cur_axis

所以,这一行代码的作用是动态地调用 augmentor_utils 模块中的类似 random_flip_along_x() 或者 random_flip_along_y() 这样的函数,并将 gt_boxespoints 作为参数传递给这个函数进行处理,同时将返回的结果分别赋值给 gt_boxespointsenable 这三个变量。

def random_flip_along_x(gt_boxes, points, return_flip=False, enable=None):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
    Returns:
    """
    if enable is None:
        enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
    if enable:
        gt_boxes[:, 1] = -gt_boxes[:, 1]
        gt_boxes[:, 6] = -gt_boxes[:, 6]
        points[:, 1] = -points[:, 1]
        
        if gt_boxes.shape[1] > 7:
            gt_boxes[:, 8] = -gt_boxes[:, 8]
    if return_flip:
        return gt_boxes, points, enable
    return gt_boxes, points

这个函数 random_flip_along_x 实现了沿着 x 轴进行翻转操作。让我解释一下这个函数的功能:

  • 参数

    • gt_boxes:一个形状为 (N, 7 + C) 的数组,包含了一些框的坐标信息和其他属性。
    • points:一个形状为 (M, 3 + C) 的数组,可能包含一些点的信息。
    • return_flip:一个布尔值,默认为 False,表示是否返回翻转后的结果。
    • enable:一个布尔值,用于控制是否进行翻转。如果为 None,则随机选择是否进行翻转。
  • 返回值

    • 如果 return_flipTrue,则返回翻转后的 gt_boxespoints 以及 enable 标志。
    • 否则,只返回翻转后的 gt_boxespoints
  • 函数逻辑

    • 首先,如果 enableNone,则通过随机选择的方式确定是否进行翻转操作,这里是以 50% 的概率进行翻转。
    • 如果 enableTrue,则执行翻转操作:
      • gt_boxespoints 中的与 y 轴相关的坐标(索引为 1 的位置)取负,实现了沿 x 轴的翻转。
      • 如果 gt_boxes 中的列数大于 7,还会将第 8 列的内容也取负。
    • 最后根据 return_flip 参数决定是否返回翻转后的结果。

这个函数实现了在给定条件下对 gt_boxespoints 进行沿 x 轴的翻转操作,并根据需要返回结果。

同理 random_flip_along_y 实现了沿着 y 轴进行翻转操作,区别在于取到的数据的维度

 gt_boxes[:, 0] = -gt_boxes[:, 0]
        gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
        points[:, 0] = -points[:, 0]

        if gt_boxes.shape[1] > 7:
            gt_boxes[:, 7] = -gt_boxes[:, 7]

random_world_rotation 模块

核心代码在于/OpenPCDet/pcdet/datasets/augmentor/augmentor_utils.py的global_rotation中

def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C),
        rot_range: [min, max]
    Returns:
    """
    if noise_rotation is None: 
        noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
    points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
    gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
    gt_boxes[:, 6] += noise_rotation
    if gt_boxes.shape[1] > 7:
        gt_boxes[:, 7:9] = common_utils.rotate_points_along_z(
            np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[np.newaxis, :, :],
            np.array([noise_rotation])
        )[0][:, 0:2]

    if return_rot:
        return gt_boxes, points, noise_rotation
    return gt_boxes, points

用于进行全局旋转操作。

  • 参数

    • gt_boxes:一个形状为 (N, 7 + C) 的数组,包含了一些框的坐标信息和其他属性。
    • points:一个形状为 (M, 3 + C) 的数组,可能包含一些点的信息。
    • rot_range:一个包含两个元素的列表,表示旋转角度的范围 [min, max]
    • return_rot:一个布尔值,默认为 False,表示是否返回旋转后的结果。
    • noise_rotation:一个浮点数,表示旋转的角度。如果为 None,则随机选择一个在 rot_range 范围内的角度。
  • 返回值

    • 如果 return_rotTrue,则返回旋转后的 gt_boxespoints 以及使用的旋转角度 noise_rotation
    • 否则,只返回旋转后的 gt_boxespoints
  • 函数逻辑

    • 首先,如果 noise_rotationNone,则从指定的 rot_range 范围内随机选择一个角度作为旋转角度。
    • 使用 common_utils.rotate_points_along_z() 函数对 pointsgt_boxes 进行沿着 Z 轴的旋转操作,将它们旋转了 noise_rotation 度。
    • 对于 gt_boxes 中的特定列(位置 0 到 2 和位置 6),进行了旋转变换,同时如果 gt_boxes 中的列数大于 7,还对位置 7 到 8 的内容进行了相应的旋转。
    • 最后根据 return_rot 参数决定是否返回旋转后的结果。

这个函数的作用是对 gt_boxespoints 进行全局的旋转操作,并根据需要返回结果。

random_world_scaling 模块

核心代码在于/OpenPCDet/pcdet/datasets/augmentor/augmentor_utils.py的global_scaling中

def global_scaling(gt_boxes, points, scale_range, return_scale=False):
    """
    Args:
        gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
        points: (M, 3 + C),
        scale_range: [min, max]
    Returns:
    """
    if scale_range[1] - scale_range[0] < 1e-3:
        return gt_boxes, points
    noise_scale = np.random.uniform(scale_range[0], scale_range[1])
    points[:, :3] *= noise_scale
    gt_boxes[:, :6] *= noise_scale
    if gt_boxes.shape[1] > 7:
        gt_boxes[:, 7:] *= noise_scale
        
    if return_scale:
        return gt_boxes, points, noise_scale
    return gt_boxes, points

用于对给定的 gt_boxespoints 进行全局缩放操作。以下是这个函数的功能和实现细节:

  • 参数

    • gt_boxes:一个形状为 (N, 7) 的数组,包含了一些框的坐标信息和尺寸。
    • points:一个形状为 (M, 3 + C) 的数组,可能包含一些点的信息。
    • scale_range:一个包含两个元素的列表,表示缩放的范围 [min, max]
    • return_scale:一个布尔值,默认为 False,表示是否返回缩放后的结果。
  • 返回值

    • 如果 return_scaleTrue,则返回缩放后的 gt_boxespoints 以及使用的缩放比例 noise_scale
    • 否则,只返回缩放后的 gt_boxespoints
  • 函数逻辑

    • 首先,如果指定的缩放范围非常小(小于 1e-3),则直接返回原始的 gt_boxespoints,不做缩放操作。
    • 否则,从指定的 scale_range 范围内随机选择一个缩放比例 noise_scale
    • points 中的前三列(即位置 0 到 2)乘以 noise_scale,实现了对点的坐标进行缩放。
    • gt_boxes 中的前六列(即位置 0 到 5)乘以 noise_scale,对边界框的坐标和尺寸进行缩放。如果 gt_boxes 中的列数大于 7,还对位置 7 到末尾的内容进行了相应的缩放。
    • 最后根据 return_scale 参数决定是否返回缩放后的结果。

这个函数的作用是对 gt_boxespoints 进行全局的缩放操作,并根据需要返回结果。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐