Openpcdet 系列 Pointpillar代码数据增强模块详解
Openpcdet 系列 Pointpillar代码数据增强模块详解
数据配置文件
数据增强配置如下:
-
gt_sampling
数据增强方法配置:- NAME: gt_sampling,使用 gt_sampling 数据增强方法。
- USE_ROAD_PLANE: True,使用路面平面信息。
- DB_INFO_PATH: kitti_dbinfos_train.pkl,用于过滤数据的数据库信息路径。
- PREPARE: 数据准备配置,包括:
- filter_by_min_points: [‘Car:5’, ‘Pedestrian:5’, ‘Cyclist:5’],根据最小点数过滤类别为 Car、Pedestrian 和 Cyclist 的目标。
- filter_by_difficulty: [-1],根据困难程度过滤目标,此处的 -1 表示不过滤。
- SAMPLE_GROUPS: [‘Car:15’,‘Pedestrian:15’, ‘Cyclist:15’],采样目标的数量,每个类别分别采样 15 个目标。
- NUM_POINT_FEATURES: 4,点云特征的数量。
- DATABASE_WITH_FAKELIDAR: False,是否使用虚拟激光雷达数据。
- REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0],移除额外的宽度。
- LIMIT_WHOLE_SCENE: False,是否限制整个场景。
-
random_world_flip
数据增强方法配置:- NAME: random_world_flip,使用 random_world_flip 数据增强方法。
- ALONG_AXIS_LIST: [‘x’],沿着 x 轴进行随机翻转。
-
random_world_rotation
数据增强方法配置:- NAME: random_world_rotation,使用 random_world_rotation 数据增强方法。
- WORLD_ROT_ANGLE: [-0.78539816, 0.78539816],世界旋转的角度范围,这里表示 -45 度到 45 度之间的随机旋转。
-
random_world_scaling
数据增强方法配置:- NAME: random_world_scaling,使用 random_world_scaling 数据增强方法。
- WORLD_SCALE_RANGE: [0.95, 1.05],世界缩放的范围,表示将点云缩放到 0.95 倍到 1.05 倍之间的随机比例。
这些数据增强方法可以帮助增加数据的多样性和鲁棒性,提升模型在不同场景下的泛化能力。具体的数据增强方法和参数可以根据任务需求和数据特点进行调整和优化。
模块讲解
代码在OpenPCDet/pcdet/datasets/augmentor/data_augmentor.py
gt_sampling 模块
涉及到的代码太多,后续分开章节讲
random_world_flip 模块
def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_flip, config=config)
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for cur_axis in config['ALONG_AXIS_LIST']:
assert cur_axis in ['x', 'y']
gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
gt_boxes, points, return_flip=True
)
data_dict['flip_%s'%cur_axis] = enable
if 'roi_boxes' in data_dict.keys():
num_frame, num_rois,dim = data_dict['roi_boxes'].shape
roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable
)
data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)
data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
return data_dict
gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
gt_boxes, points, return_flip=True
)
这行代码使用了 Python 中的 getattr()
函数和字符串格式化操作 %
来动态地调用函数。
-
getattr()
函数是 Python 中的一个内置函数,它用于获取对象的属性或者方法。在这个例子中,它使用了字符串格式化操作%
构建了一个函数名,然后通过getattr()
来获取这个函数并调用它。getattr(object, name)
接受两个参数:object
是一个对象(在这里可能是augmentor_utils
模块),name
是一个字符串,表示对象的属性名或者方法名。
-
%s
是 Python 字符串格式化的一种方式,用于将字符串中的%s
替换为后续传入的字符串参数,这里是cur_axis
。
所以,这一行代码的作用是动态地调用 augmentor_utils
模块中的类似 random_flip_along_x()
或者 random_flip_along_y()
这样的函数,并将 gt_boxes
和 points
作为参数传递给这个函数进行处理,同时将返回的结果分别赋值给 gt_boxes
、points
和 enable
这三个变量。
def random_flip_along_x(gt_boxes, points, return_flip=False, enable=None):
"""
Args:
gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C)
Returns:
"""
if enable is None:
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
if enable:
gt_boxes[:, 1] = -gt_boxes[:, 1]
gt_boxes[:, 6] = -gt_boxes[:, 6]
points[:, 1] = -points[:, 1]
if gt_boxes.shape[1] > 7:
gt_boxes[:, 8] = -gt_boxes[:, 8]
if return_flip:
return gt_boxes, points, enable
return gt_boxes, points
这个函数 random_flip_along_x
实现了沿着 x 轴进行翻转操作。让我解释一下这个函数的功能:
-
参数:
gt_boxes
:一个形状为(N, 7 + C)
的数组,包含了一些框的坐标信息和其他属性。points
:一个形状为(M, 3 + C)
的数组,可能包含一些点的信息。return_flip
:一个布尔值,默认为False
,表示是否返回翻转后的结果。enable
:一个布尔值,用于控制是否进行翻转。如果为None
,则随机选择是否进行翻转。
-
返回值:
- 如果
return_flip
为True
,则返回翻转后的gt_boxes
、points
以及enable
标志。 - 否则,只返回翻转后的
gt_boxes
和points
。
- 如果
-
函数逻辑:
- 首先,如果
enable
是None
,则通过随机选择的方式确定是否进行翻转操作,这里是以 50% 的概率进行翻转。 - 如果
enable
为True
,则执行翻转操作:- 将
gt_boxes
和points
中的与 y 轴相关的坐标(索引为 1 的位置)取负,实现了沿 x 轴的翻转。 - 如果
gt_boxes
中的列数大于 7,还会将第 8 列的内容也取负。
- 将
- 最后根据
return_flip
参数决定是否返回翻转后的结果。
- 首先,如果
这个函数实现了在给定条件下对 gt_boxes
和 points
进行沿 x 轴的翻转操作,并根据需要返回结果。
同理
random_flip_along_y
实现了沿着 y 轴进行翻转操作,区别在于取到的数据的维度
gt_boxes[:, 0] = -gt_boxes[:, 0]
gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
points[:, 0] = -points[:, 0]
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7] = -gt_boxes[:, 7]
random_world_rotation 模块
核心代码在于/OpenPCDet/pcdet/datasets/augmentor/augmentor_utils.py的global_rotation中
def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
"""
Args:
gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C),
rot_range: [min, max]
Returns:
"""
if noise_rotation is None:
noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
gt_boxes[:, 6] += noise_rotation
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7:9] = common_utils.rotate_points_along_z(
np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[np.newaxis, :, :],
np.array([noise_rotation])
)[0][:, 0:2]
if return_rot:
return gt_boxes, points, noise_rotation
return gt_boxes, points
用于进行全局旋转操作。
-
参数:
gt_boxes
:一个形状为(N, 7 + C)
的数组,包含了一些框的坐标信息和其他属性。points
:一个形状为(M, 3 + C)
的数组,可能包含一些点的信息。rot_range
:一个包含两个元素的列表,表示旋转角度的范围[min, max]
。return_rot
:一个布尔值,默认为False
,表示是否返回旋转后的结果。noise_rotation
:一个浮点数,表示旋转的角度。如果为None
,则随机选择一个在rot_range
范围内的角度。
-
返回值:
- 如果
return_rot
为True
,则返回旋转后的gt_boxes
、points
以及使用的旋转角度noise_rotation
。 - 否则,只返回旋转后的
gt_boxes
和points
。
- 如果
-
函数逻辑:
- 首先,如果
noise_rotation
是None
,则从指定的rot_range
范围内随机选择一个角度作为旋转角度。 - 使用
common_utils.rotate_points_along_z()
函数对points
和gt_boxes
进行沿着 Z 轴的旋转操作,将它们旋转了noise_rotation
度。 - 对于
gt_boxes
中的特定列(位置 0 到 2 和位置 6),进行了旋转变换,同时如果gt_boxes
中的列数大于 7,还对位置 7 到 8 的内容进行了相应的旋转。 - 最后根据
return_rot
参数决定是否返回旋转后的结果。
- 首先,如果
这个函数的作用是对 gt_boxes
和 points
进行全局的旋转操作,并根据需要返回结果。
random_world_scaling 模块
核心代码在于/OpenPCDet/pcdet/datasets/augmentor/augmentor_utils.py的global_scaling中
def global_scaling(gt_boxes, points, scale_range, return_scale=False):
"""
Args:
gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
points: (M, 3 + C),
scale_range: [min, max]
Returns:
"""
if scale_range[1] - scale_range[0] < 1e-3:
return gt_boxes, points
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7:] *= noise_scale
if return_scale:
return gt_boxes, points, noise_scale
return gt_boxes, points
用于对给定的 gt_boxes
和 points
进行全局缩放操作。以下是这个函数的功能和实现细节:
-
参数:
gt_boxes
:一个形状为(N, 7)
的数组,包含了一些框的坐标信息和尺寸。points
:一个形状为(M, 3 + C)
的数组,可能包含一些点的信息。scale_range
:一个包含两个元素的列表,表示缩放的范围[min, max]
。return_scale
:一个布尔值,默认为False
,表示是否返回缩放后的结果。
-
返回值:
- 如果
return_scale
为True
,则返回缩放后的gt_boxes
、points
以及使用的缩放比例noise_scale
。 - 否则,只返回缩放后的
gt_boxes
和points
。
- 如果
-
函数逻辑:
- 首先,如果指定的缩放范围非常小(小于 1e-3),则直接返回原始的
gt_boxes
和points
,不做缩放操作。 - 否则,从指定的
scale_range
范围内随机选择一个缩放比例noise_scale
。 - 将
points
中的前三列(即位置 0 到 2)乘以noise_scale
,实现了对点的坐标进行缩放。 - 将
gt_boxes
中的前六列(即位置 0 到 5)乘以noise_scale
,对边界框的坐标和尺寸进行缩放。如果gt_boxes
中的列数大于 7,还对位置 7 到末尾的内容进行了相应的缩放。 - 最后根据
return_scale
参数决定是否返回缩放后的结果。
- 首先,如果指定的缩放范围非常小(小于 1e-3),则直接返回原始的
这个函数的作用是对 gt_boxes
和 points
进行全局的缩放操作,并根据需要返回结果。
更多推荐
所有评论(0)