Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考
作者:小教学发布时间:2023-10-04分类:程序开发学习浏览:88
导读:采用三次多项式拟合生成的anchor特征点,在给定的polyfit_draw函数中,degree参数代表了拟合多项式的度数。具体来说,当我们使用np.polyfit函数进行数据...
采用三次多项式拟合生成的anchor特征点,在给定的polyfit_draw
函数中,degree
参数代表了拟合多项式的度数。
具体来说,当我们使用np.polyfit
函数进行数据点的多项式拟合时,我们需要指定一个度数。这个度数决定了多项式的复杂度。例如:
-
degree = 1
:线性拟合,也就是最简单的直线拟合。拟合的多项式形式为 f(y)=ax+b。 -
degree = 2
:二次多项式拟合。拟合的多项式形式为 f(y)=ax2+bx+c。 -
degree = 3
:三次多项式拟合。拟合的多项式形式为 f(y)=ax3+bx2+cx+d。
...以此类推。
度数越高,多项式越复杂,可以更准确地拟合数据点,但也更容易过拟合(即模型过于复杂,过于依赖训练数据,对新数据的适应性差)。
import torch, os, cv2
from utils.dist_utils import dist_print
import torch, os
from utils.common import merge_config, get_model
import tqdm
import torchvision.transforms as transforms
from data.dataset import LaneTestDataset
def pred2coords(pred, row_anchor, col_anchor, local_width = 1, original_image_width = 1640, original_image_height = 590):
batch_size, num_grid_row, num_cls_row, num_lane_row = pred['loc_row'].shape
batch_size, num_grid_col, num_cls_col, num_lane_col = pred['loc_col'].shape
max_indices_row = pred['loc_row'].argmax(1).cpu()
# n , num_cls, num_lanes
valid_row = pred['exist_row'].argmax(1).cpu()
# n, num_cls, num_lanes
max_indices_col = pred['loc_col'].argmax(1).cpu()
# n , num_cls, num_lanes
valid_col = pred['exist_col'].argmax(1).cpu()
# n, num_cls, num_lanes
pred['loc_row'] = pred['loc_row'].cpu()
pred['loc_col'] = pred['loc_col'].cpu()
coords = []
row_lane_idx = [1,2]
col_lane_idx = [0,3]
for i in row_lane_idx:
tmp = []
if valid_row[0,:,i].sum() > num_cls_row / 2:
for k in range(valid_row.shape[1]):
if valid_row[0,k,i]:
all_ind = torch.tensor(list(range(max(0,max_indices_row[0,k,i] - local_width), min(num_grid_row-1, max_indices_row[0,k,i] + local_width) + 1)))
out_tmp = (pred['loc_row'][0,all_ind,k,i].softmax(0) * all_ind.float()).sum() + 0.5
out_tmp = out_tmp / (num_grid_row-1) * original_image_width
tmp.append((int(out_tmp), int(row_anchor[k] * original_image_height)))
coords.append(tmp)
for i in col_lane_idx:
tmp = []
if valid_col[0,:,i].sum() > num_cls_col / 4:
for k in range(valid_col.shape[1]):
if valid_col[0,k,i]:
all_ind = torch.tensor(list(range(max(0,max_indices_col[0,k,i] - local_width), min(num_grid_col-1, max_indices_col[0,k,i] + local_width) + 1)))
out_tmp = (pred['loc_col'][0,all_ind,k,i].softmax(0) * all_ind.float()).sum() + 0.5
out_tmp = out_tmp / (num_grid_col-1) * original_image_height
tmp.append((int(col_anchor[k] * original_image_width), int(out_tmp)))
coords.append(tmp)
return coords
def polyfit_draw(img, coords, degree=3, color=(144, 238, 144), thickness=2):
"""
对车道线坐标进行多项式拟合并在图像上绘制曲线。
:param img: 输入图像
:param coords: 车道线坐标列表
:param degree: 拟合的多项式的度数
:param color: 曲线的颜色
:param thickness: 曲线的宽度
:return: 绘制了曲线的图像
"""
if len(coords) == 0:
return img
x = [point[0] for point in coords]
y = [point[1] for point in coords]
# 对点进行多项式拟合
coefficients = np.polyfit(y, x, degree)
poly = np.poly1d(coefficients)
ys = np.linspace(min(y), max(y), 100)
xs = poly(ys)
for i in range(len(ys) - 1):
start_point = (int(xs[i]), int(ys[i]))
end_point = (int(xs[i+1]), int(ys[i+1]))
cv2.line(img, start_point, end_point, color, thickness)
return img
if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
args, cfg = merge_config()
cfg.batch_size = 1
print('setting batch_size to 1 for demo generation')
dist_print('start testing...')
assert cfg.backbone in ['18','34','50','101','152','50next','101next','50wide','101wide']
if cfg.dataset == 'CULane':
cls_num_per_lane = 18
elif cfg.dataset == 'Tusimple':
cls_num_per_lane = 56
else:
raise NotImplementedError
net = get_model(cfg)
state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
compatible_state_dict = {}
for k, v in state_dict.items():
if 'module.' in k:
compatible_state_dict[k[7:]] = v
else:
compatible_state_dict[k] = v
net.load_state_dict(compatible_state_dict, strict=False)
net.eval()
img_transforms = transforms.Compose([
transforms.Resize((int(cfg.train_height / cfg.crop_ratio), cfg.train_width)),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
if cfg.dataset == 'CULane':
splits = ['test0_normal.txt']
datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, 'list/test_split/'+split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
img_w, img_h = 1570, 660
elif cfg.dataset == 'Tusimple':
splits = ['test.txt']
datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
img_w, img_h = 1280, 720
else:
raise NotImplementedError
for split, dataset in zip(splits, datasets):
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle = False, num_workers=1)
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
print(split[:-3]+'avi')
vout = cv2.VideoWriter('4.'+'avi', fourcc , 30.0, (img_w, img_h))
for i, data in enumerate(tqdm.tqdm(loader)):
imgs, names = data
imgs = imgs.cuda()
with torch.no_grad():
pred = net(imgs)
vis = cv2.imread(os.path.join(cfg.data_root,names[0]))
coords = pred2coords(pred, cfg.row_anchor, cfg.col_anchor, original_image_width = img_w, original_image_height = img_h)
for lane in coords:
# for coord in lane:
# cv2.circle(vis,coord,1,(0,255,0),-1)
# vis = draw_lanes(vis, coords)
# polyfit_draw(vis, lane)
vis = polyfit_draw(vis, lane) # 对每一条车道线都使用polyfit_draw函数
vout.write(vis)
vout.release()
ps:
优化前
优化后
显存利用情况
- 程序开发学习排行
-
- 1鸿蒙HarmonyOS:Web组件网页白屏检测
- 2HTTPS协议是安全传输,为啥还要再加密?
- 3HarmonyOS鸿蒙应用开发——数据持久化Preferences
- 4记解决MaterialButton背景颜色与设置值不同
- 5鸿蒙HarmonyOS实战-ArkUI组件(RelativeContainer)
- 6鸿蒙HarmonyOS实战-ArkUI组件(Stack)
- 7鸿蒙HarmonyOS实战-ArkUI组件(GridRow/GridCol)
- 8[Android][NDK][Cmake]一文搞懂Android项目中的Cmake
- 9鸿蒙HarmonyOS实战-ArkUI组件(mediaquery)
- 最近发表
-
- WooCommerce最好的WordPress常用插件下载博客插件模块的相关产品
- 羊驼机器人最好的WordPress常用插件下载博客插件模块
- IP信息记录器最好的WordPress常用插件下载博客插件模块
- Linkly for WooCommerce最好的WordPress常用插件下载博客插件模块
- 元素聚合器Forms最好的WordPress常用插件下载博客插件模块
- Promaker Chat 最好的WordPress通用插件下载 博客插件模块
- 自动更新发布日期最好的WordPress常用插件下载博客插件模块
- WordPress官方最好的获取回复WordPress常用插件下载博客插件模块
- Img to rss最好的wordpress常用插件下载博客插件模块
- WPMozo为Elementor最好的WordPress常用插件下载博客插件模块添加精简版