|
|
@@ -27,6 +27,7 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
|
|
|
|
|
|
from tools.presets import DetectionPresetTrain
|
|
|
|
|
|
+
|
|
|
def line_boxes1(target):
|
|
|
boxs = []
|
|
|
lines = target.cpu().numpy() * 4
|
|
|
@@ -74,20 +75,37 @@ class WirePointDataset(BaseDataset):
|
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- w, h = img.size
|
|
|
+ # img = PIL.Image.open(img_path).convert('RGB')
|
|
|
+ # w, h = img.size
|
|
|
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # NOTE(review): returns None on read failure; assumes 4-channel (BGR+depth) input — confirm
|
|
|
+ # print(img.shape)  # debug print silenced
|
|
|
+ h, w = img.shape[0:2]  # cv2 shape is (height, width) — previous line had w/h swapped vs. the PIL img.size path
|
|
|
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- if self.transforms:
|
|
|
- img, target = self.transforms(img, target)
|
|
|
- else:
|
|
|
- img = self.default_transform(img)
|
|
|
+ # if self.transforms:
|
|
|
+ # img, target = self.transforms(img, target)
|
|
|
+ # else:
|
|
|
+ # img = self.default_transform(img)
|
|
|
+
|
|
|
+ # 分离RGB和深度通道
|
|
|
+ rgb_channels = img[:, :, :3]  # NOTE(review): cv2 loads BGR order, the replaced PIL path was RGB — confirm downstream expectation
|
|
|
+ depth_channel = img[:, :, 3]
|
|
|
+
|
|
|
+ rgb_normalized = rgb_channels.astype(np.float32) / 255.0
|
|
|
+ depth_normalized = (depth_channel - depth_channel.min()) / max(depth_channel.max() - depth_channel.min(), 1e-8)  # epsilon guards constant-depth images
|
|
|
+
|
|
|
+ # 将归一化后的RGB和深度通道重新组合
|
|
|
+ normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized)) # 或者使用depth_normalized_fixed_range
|
|
|
+
|
|
|
+ # print("Normalized RGBA image shape:", normalized_rgba_image.shape)  # debug print silenced
|
|
|
|
|
|
+ img = torch.tensor(normalized_rgba_image, dtype=torch.float32).permute(2, 0, 1)  # HWC -> CHW; permute(2,1,0) would transpose the spatial axes
|
|
|
|
|
|
# new_channel = torch.zeros(1, 512, 512)
|
|
|
# img=torch.cat((img,new_channel),dim=0)
|
|
|
- # print(f'img:{img.shape}')
|
|
|
+ # print(f'img:{img.shape}')  # keep debug print commented out, as it was before this patch
|
|
|
+ # print(f'img dtype:{img.dtype}')
|
|
|
return img, target
|
|
|
|
|
|
def __len__(self):
|
|
|
@@ -146,13 +164,12 @@ class WirePointDataset(BaseDataset):
|
|
|
target = {}
|
|
|
# target["labels"] = torch.stack(labels)
|
|
|
|
|
|
-
|
|
|
target["image_id"] = torch.tensor(item)
|
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
|
# target["boxes"] = line_boxes(target)
|
|
|
target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
|
- target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
|
|
|
+ target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
# print(f'target["labels"]:{ target["labels"]}')
|
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
|
if target["boxes"].numel() == 0:
|
|
|
@@ -190,7 +207,6 @@ class WirePointDataset(BaseDataset):
|
|
|
break
|
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
|
|
|
|
|
|
-
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
@@ -214,11 +230,9 @@ class WirePointDataset(BaseDataset):
|
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
|
|
|
-
|
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
|
|
|
|
-
|
|
|
# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
|
|
|
# for i in dataset_train:
|
|
|
-# a = 1
|
|
|
+# a = 1
|