Ver código fonte

4080 train rgbd

lstrlq 6 meses atrás
pai
commit
ce3d797fa4

+ 3 - 0
libs/vision_libs/models/detection/transform.py

@@ -197,7 +197,10 @@ class GeneralizedRCNNTransform(nn.Module):
             return image, target
 
         bbox = target["boxes"]
+        print(f'bbox:{bbox}')
+        print(f'image.shape[-2:]:{image.shape},,,{image.shape[-2:]}')
         bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+
         target["boxes"] = bbox
 
         if "keypoints" in target:

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/lm/1-dataset/a_dataset
+  datadir: /data/share/lm/1-dataset/a_dataset
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 2 - 2
models/line_detect/train_demo.py

@@ -7,9 +7,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    # model=linenet_resnet50_fpn()
+    model=linenet_resnet50_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    model=linenet_resnet18_fpn()
+    # model=linenet_resnet18_fpn()
     # model=linenet_resnet101_fpn_v2()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')

+ 71 - 100
models/line_detect/trainer.py

@@ -5,16 +5,13 @@ from datetime import datetime
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter
 
-from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
+from libs.vision_libs.utils import draw_bounding_boxes
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.line_dataset import LineDataset
-
-from models.line_net.dataset_LD import WirePointDataset
+from models.line_detect.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess
 from tools import utils
 from torchvision import transforms
@@ -42,42 +39,6 @@ def c(x):
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-import matplotlib.pyplot as plt
-from PIL import ImageDraw
-from torchvision.transforms import functional as F
-import torch
-
-
-# 由低到高蓝黄红
-def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
-    """
-    根据得分对线段着色并绘制
-    :param tensor_image: (3, H, W) uint8 图像
-    :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
-    :param scores: (N,) 每条线的得分,范围 [0, 1]
-    :param width: 线宽
-    :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
-    :return: (3, H, W) uint8 画好线的图像
-    """
-    assert tensor_image.dtype == torch.uint8
-    assert tensor_image.shape[0] == 3
-    assert lines.shape[0] == scores.shape[0]
-
-    # 准备色图
-    colormap = plt.get_cmap(cmap)
-    colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8')  # 去掉 alpha 通道
-
-    # 转为 PIL 画图
-    image_pil = F.to_pil_image(tensor_image)
-    draw = ImageDraw.Draw(image_pil)
-
-    for line, color in zip(lines, colors):
-        start = tuple(map(float, line[0][:2].tolist()))
-        end = tuple(map(float, line[1][:2].tolist()))
-        draw.line([start, end], fill=tuple(color), width=width)
-
-    return (F.to_tensor(image_pil) * 255).to(torch.uint8)
-
 
 class Trainer(BaseTrainer):
     def __init__(self, model=None, **kwargs):
@@ -92,7 +53,6 @@ class Trainer(BaseTrainer):
             self.freeze_config = kwargs['train_params']['freeze_params']
             print(f'freeze_config:{self.freeze_config}')
             self.dataset_path = kwargs['io']['datadir']
-            self.data_type = kwargs['io']['data_type']
             self.batch_size = kwargs['train_params']['batch_size']
             self.num_workers = kwargs['train_params']['num_workers']
             self.logdir = kwargs['io']['logdir']
@@ -106,7 +66,6 @@ class Trainer(BaseTrainer):
             self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
             self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
             self.max_epoch = kwargs['train_params']['max_epoch']
-            self.augmentation= kwargs['train_params']["augmentation"]
 
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -187,43 +146,70 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer
 
-
-
-
-
-    def writer_predict_result(self, img, result, epoch,type=1):
+    def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
-        im = img.permute(1, 2, 0)  # [512, 512, 3]
+        img=img[:3,:,:]
+        im = img.permute(1, 2, 0)
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
 
-        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
                                           colors="yellow", width=1)
-
-        # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
+        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+        # print(f'pred[1]:{pred[1]}')
+        heatmaps = result[-2][0]
+        print(f'heatmaps:{heatmaps.shape}')
+        jmap = heatmaps[1: 2].cpu().detach()
+        lmap = heatmaps[2: 3].cpu().detach()
+        self.writer.add_image("z-jmap", jmap, epoch)
+        self.writer.add_image("z-lmap", lmap, epoch)
+        # plt.imshow(lmap)
         # plt.show()
-
-        self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
-
-
-        if type==1:
-            keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
-
-            self.writer.add_image("z-output", keypoint_img, epoch)
-        # print("lines shape:", result['lines'].shape)
-
-
-        if type==2:
-            # 用自己写的函数画线段
-            # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
-            print(f"shape of linescore:{result['liness_scores'].shape}")
-            scores = result['liness_scores'].mean(dim=1)  # shape: [31]
-
-            line_image = draw_lines_with_scores((img * 255).to(torch.uint8),  result['lines'],scores, width=3, cmap='jet')
-
-            self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
-
-
-
+        H = result[-1]['wires']
+        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+        scores = H["score"][0].cpu().numpy()
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        for i, t in enumerate([0]):
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.tight_layout()
+            fig = plt.gcf()
+            fig.canvas.draw()
+
+            width, height = fig.get_size_inches() * fig.get_dpi()  # 获取图像尺寸
+            tmp_img = fig.canvas.tostring_argb()
+            tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
+            tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
+
+            img_rgb = tmp_img_np[:, :, 1:]  # 提取RGB部分,忽略Alpha通道
+
+            # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
+            #     fig.canvas.get_width_height()[::-1] + (3,))
+            plt.close()
+
+            img2 = transforms.ToTensor()(img_rgb)
+
+            self.writer.add_image("z-output", img2, epoch)
 
     def writer_loss(self, losses, epoch, phase='train'):
         try:
@@ -250,8 +236,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path,augmentation=False, data_type=self.data_type, dataset_type='val')
+        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
+        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
@@ -261,7 +247,7 @@ class Trainer(BaseTrainer):
         val_collate_fn = utils.collate_fn
 
         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler,  num_workers=self.num_workers, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
         )
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
@@ -271,29 +257,22 @@ class Trainer(BaseTrainer):
 
         optimizer = torch.optim.Adam(
             filter(lambda p: p.requires_grad, model.parameters()),
-            lr=kwargs['train_params']['optim']['lr'],
-            weight_decay=kwargs['train_params']['optim']['weight_decay'],
-
+            lr=kwargs['train_params']['optim']['lr']
         )
-        # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
-        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
 
         for epoch in range(self.max_epoch):
             print(f"train epoch:{epoch}")
 
             model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
-            scheduler.step(epoch_train_loss)
 
             # ========== Validation ==========
             with torch.no_grad():
                 model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
-                scheduler.step(epoch_val_loss)
 
             if epoch==0:
                 best_train_loss = epoch_train_loss
                 best_val_loss = epoch_val_loss
 
-
             self.save_last_model(model,self.last_model_path, epoch, optimizer)
             best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                    best_train_loss,
@@ -309,40 +288,32 @@ class Trainer(BaseTrainer):
 
         total_loss = 0
         epoch_step = 0
-        global_step = epoch * len(data_loader)
+        global_step = epoch_step * len(data_loader)
         for imgs, targets in data_loader:
             imgs = self.move_to_device(imgs, device)
             targets = self.move_to_device(targets, device)
             if phase== 'val':
-                result,loss_dict = model(imgs, targets)
-                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
 
-                print(f'val losses:{losses}')
-                print(f'val result:{result}')
+                result,losses = model(imgs, targets)
             else:
-                loss_dict = model(imgs, targets)
-                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
-                print(f'train losses:{losses}')
+                losses = model(imgs, targets)
 
-            # loss = _loss(losses)
-            loss=losses
+            loss = _loss(losses)
             total_loss += loss.item()
             if phase == 'train':
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-            self.writer_loss(loss_dict, global_step, phase=phase)
+            self.writer_loss(losses, global_step, phase=phase)
             global_step += 1
 
             if epoch_step == 0 and phase == 'val':
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
-                print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
-                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
-                epoch_step+=1
+                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
 
         avg_loss = total_loss / len(data_loader)
         print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')

+ 2 - 1
models/line_net/dataset_LD.py

@@ -92,7 +92,8 @@ class WirePointDataset(BaseDataset):
         rgb_channels = img[:, :, :3]
         depth_channel = img[:, :, 3]
 
-        rgb_normalized = rgb_channels.astype(np.float32) / 255.0
+        # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
+        rgb_normalized = rgb_channels
         depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
 
         # 将归一化后的RGB和深度通道重新组合