@@ -5,13 +5,16 @@ from datetime import datetime
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter

-from libs.vision_libs.utils import draw_bounding_boxes
+from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.dataset_LD import WirePointDataset
+from models.line_detect.line_dataset import LineDataset
+
+from models.line_net.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess
 from tools import utils
 from torchvision import transforms
@@ -39,6 +42,42 @@ def c(x):

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+import matplotlib.pyplot as plt
+from PIL import ImageDraw
+from torchvision.transforms import functional as F
+import torch
+
+
+# low -> high scores map to blue -> yellow -> red (with a 'jet'-style cmap)
+def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
+    """
+    Draw line segments colored by their scores.
+    :param tensor_image: (3, H, W) uint8 image
+    :param lines: (N, 2, 2) one line per row, [[x1, y1], [x2, y2]]
+    :param scores: (N,) per-line score in [0, 1]
+    :param width: line width in pixels
+    :param cmap: matplotlib colormap name, e.g. 'viridis', 'jet', 'coolwarm'
+    :return: (3, H, W) uint8 image with the lines drawn
+    """
+    assert tensor_image.dtype == torch.uint8
+    assert tensor_image.shape[0] == 3
+    assert lines.shape[0] == scores.shape[0]
+
+    # prepare the colormap
+    colormap = plt.get_cmap(cmap)
+    colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8')  # drop the alpha channel
+
+    # convert to PIL for drawing
+    image_pil = F.to_pil_image(tensor_image)
+    draw = ImageDraw.Draw(image_pil)
+
+    for line, color in zip(lines, colors):
+        start = tuple(map(float, line[0][:2].tolist()))
+        end = tuple(map(float, line[1][:2].tolist()))
+        draw.line([start, end], fill=tuple(color), width=width)
+
+    return (F.to_tensor(image_pil) * 255).to(torch.uint8)
+
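+# Minimal usage sketch for the helper above (hypothetical tensors, not part of the pipeline):
+#   img_u8 = torch.zeros(3, 512, 512, dtype=torch.uint8)
+#   lines = torch.rand(8, 2, 2) * 512   # 8 segments in pixel coordinates
+#   scores = torch.rand(8)              # per-segment confidences in [0, 1]
+#   vis = draw_lines_with_scores(img_u8, lines, scores, width=3, cmap='jet')
+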

 class Trainer(BaseTrainer):
     def __init__(self, model=None, **kwargs):
@@ -53,6 +92,7 @@ class Trainer(BaseTrainer):
         self.freeze_config = kwargs['train_params']['freeze_params']
         print(f'freeze_config:{self.freeze_config}')
         self.dataset_path = kwargs['io']['datadir']
+        self.data_type = kwargs['io']['data_type']
         self.batch_size = kwargs['train_params']['batch_size']
         self.num_workers = kwargs['train_params']['num_workers']
         self.logdir = kwargs['io']['logdir']
@@ -66,6 +106,7 @@ class Trainer(BaseTrainer):
         self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
         self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
         self.max_epoch = kwargs['train_params']['max_epoch']
+        self.augmentation = kwargs['train_params']['augmentation']
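+        # note: this flag only gates training-time transforms; the val split in train() below is built with augmentation=False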

     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -146,70 +187,43 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer

-    def writer_predict_result(self, img, result, epoch):
+    def writer_predict_result(self, img, result, epoch, type=1):
         img = img.cpu().detach()
-        img=img[:3]
-        im = img.permute(1, 2, 0)
-        self.writer.add_image("z-ori", (im*255).to(torch.uint8), epoch, dataformats="HWC")
+        im = img.permute(1, 2, 0)  # [512, 512, 3]
+        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")

-        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
                                           colors="yellow", width=1)
-        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
-
-        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-        # print(f'pred[1]:{pred[1]}')
-        heatmaps = result[-2][0]
-        print(f'heatmaps:{heatmaps.shape}')
-        jmap = heatmaps[1: 2].cpu().detach()
-        lmap = heatmaps[2: 3].cpu().detach()
-        self.writer.add_image("z-jmap", jmap, epoch)
-        self.writer.add_image("z-lmap", lmap, epoch)
-        # plt.imshow(lmap)
+
+        # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
         # plt.show()
-        H = result[-1]['wires']
-        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
-        scores = H["score"][0].cpu().numpy()
-        for i in range(1, len(lines)):
-            if (lines[i] == lines[0]).all():
-                lines = lines[:i]
-                scores = scores[:i]
-                break
-
-        # postprocess lines to remove overlapped lines
-        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-
-        for i, t in enumerate([0]):
-            plt.gca().set_axis_off()
-            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-            plt.margins(0, 0)
-            for (a, b), s in zip(nlines, nscores):
-                if s < t:
-                    continue
-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-                plt.scatter(a[1], a[0], **PLTOPTS)
-                plt.scatter(b[1], b[0], **PLTOPTS)
-            plt.gca().xaxis.set_major_locator(plt.NullLocator())
-            plt.gca().yaxis.set_major_locator(plt.NullLocator())
-            plt.imshow((im*255).to(torch.uint8))
-        plt.tight_layout()
-        fig = plt.gcf()
-        fig.canvas.draw()
-
-        width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
-        tmp_img = fig.canvas.tostring_argb()
-        tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
-        tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
-
-        img_rgb = tmp_img_np[:, :, 1:]  # keep the RGB channels, drop the alpha channel
-
-        # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
-        #     fig.canvas.get_width_height()[::-1] + (3,))
-        plt.close()
-
-        img2 = transforms.ToTensor()(img_rgb)
-
-        self.writer.add_image("z-output", (img2*255).to(torch.uint8), epoch)
+
+        self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        if type == 1:
+            keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
+            self.writer.add_image("z-output", keypoint_img, epoch)
+            # print("lines shape:", result['lines'].shape)
+
+        if type == 2:
+            # draw the line segments with the score-colored helper defined above
+            # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
+            print(f"shape of line scores: {result['liness_scores'].shape}")
+            scores = result['liness_scores'].mean(dim=1)  # shape: [N], one mean score per line
+
+            line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'], scores, width=3, cmap='jet')
+
+            self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
+

     def writer_loss(self, losses, epoch, phase='train'):
         try:
@@ -236,8 +250,8 @@ class Trainer(BaseTrainer):

         self.init_params(**kwargs)

-        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
-        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=False, data_type=self.data_type, dataset_type='val')

         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
@@ -247,7 +261,7 @@ class Trainer(BaseTrainer):
         val_collate_fn = utils.collate_fn

         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
         )
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
@@ -257,22 +271,29 @@ class Trainer(BaseTrainer):

         optimizer = torch.optim.Adam(
             filter(lambda p: p.requires_grad, model.parameters()),
-            lr=kwargs['train_params']['optim']['lr']
+            lr=kwargs['train_params']['optim']['lr'],
+            weight_decay=kwargs['train_params']['optim']['weight_decay'],
         )
+        # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
+        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
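+        # ReduceLROnPlateau in 'min' mode scales the LR by its factor (default 0.1) after the
+        # monitored loss has failed to improve for `patience` (here 30) consecutive epochs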

         for epoch in range(self.max_epoch):
             print(f"train epoch:{epoch}")

             model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)

             # ========== Validation ==========
             with torch.no_grad():
                 model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
+            # step the plateau scheduler once per epoch; it tracks a single metric, so use the validation loss
+            scheduler.step(epoch_val_loss)

             if epoch==0:
                 best_train_loss = epoch_train_loss
                 best_val_loss = epoch_val_loss

             self.save_last_model(model,self.last_model_path, epoch, optimizer)
             best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                    best_train_loss,
@@ -288,32 +309,40 @@ class Trainer(BaseTrainer):

         total_loss = 0
         epoch_step = 0
-        global_step = epoch_step * len(data_loader)
+        global_step = epoch * len(data_loader)
         for imgs, targets in data_loader:
             imgs = self.move_to_device(imgs, device)
             targets = self.move_to_device(targets, device)
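+            # model contract as used here: in 'val' the forward pass returns (result, loss_dict);
+            # in 'train' it returns only loss_dict, a dict of named loss terms that is summed below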
             if phase== 'val':
-                result,losses = model(imgs, targets)
+                result, loss_dict = model(imgs, targets)
+                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
+                print(f'val losses:{losses}')
+                print(f'val result:{result}')
             else:
-                losses = model(imgs, targets)
+                loss_dict = model(imgs, targets)
+                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
+                print(f'train losses:{losses}')

-            loss = _loss(losses)
+            # loss = _loss(losses)
+            loss = losses
             total_loss += loss.item()
             if phase == 'train':
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-            self.writer_loss(losses, global_step, phase=phase)
+            self.writer_loss(loss_dict, global_step, phase=phase)
             global_step += 1

             if epoch_step == 0 and phase == 'val':
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
+                print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
-                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
+                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
+            epoch_step += 1

         avg_loss = total_loss / len(data_loader)
         print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')
@@ -358,4 +387,4 @@ class Trainer(BaseTrainer):


 if __name__ == '__main__':
-    print('')
+    print('')
|