|
@@ -1,3 +1,4 @@
|
|
|
|
|
+# ??roi_head??????????????
|
|
|
from torch.utils.data.dataset import T_co
|
|
from torch.utils.data.dataset import T_co
|
|
|
|
|
|
|
|
from models.base.base_dataset import BaseDataset
|
|
from models.base.base_dataset import BaseDataset
|
|
@@ -26,9 +27,10 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
|
|
|
|
|
|
|
|
from tools.presets import DetectionPresetTrain
|
|
from tools.presets import DetectionPresetTrain
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def line_boxes1(target):
|
|
def line_boxes1(target):
|
|
|
boxs = []
|
|
boxs = []
|
|
|
- lines = target.cpu().numpy()
|
|
|
|
|
|
|
+ lines = target.cpu().numpy() * 4
|
|
|
|
|
|
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
for i, ((a, b)) in enumerate(lines):
|
|
for i, ((a, b)) in enumerate(lines):
|
|
@@ -73,20 +75,37 @@ class WirePointDataset(BaseDataset):
|
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
|
|
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
- w, h = img.size
|
|
|
|
|
|
|
+ # img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
+ # w, h = img.size
|
|
|
|
|
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
|
|
|
|
+ print(img.shape)
|
|
|
|
|
+ w, h = img.shape[0:2]
|
|
|
|
|
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- if self.transforms:
|
|
|
|
|
- img, target = self.transforms(img, target)
|
|
|
|
|
- else:
|
|
|
|
|
- img = self.default_transform(img)
|
|
|
|
|
|
|
+ # if self.transforms:
|
|
|
|
|
+ # img, target = self.transforms(img, target)
|
|
|
|
|
+ # else:
|
|
|
|
|
+ # img = self.default_transform(img)
|
|
|
|
|
+
|
|
|
|
|
+ # 分离RGB和深度通道
|
|
|
|
|
+ rgb_channels = img[:, :, :3]
|
|
|
|
|
+ depth_channel = img[:, :, 3]
|
|
|
|
|
+
|
|
|
|
|
+ rgb_normalized = rgb_channels.astype(np.float32) / 255.0
|
|
|
|
|
+ depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
|
|
|
|
|
+
|
|
|
|
|
+ # 将归一化后的RGB和深度通道重新组合
|
|
|
|
|
+ normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized)) # 或者使用depth_normalized_fixed_range
|
|
|
|
|
|
|
|
|
|
+ print("Normalized RGBA image shape:", normalized_rgba_image.shape)
|
|
|
|
|
+
|
|
|
|
|
+ img = torch.tensor(normalized_rgba_image,dtype=torch.float32).permute(2,1,0)
|
|
|
|
|
|
|
|
# new_channel = torch.zeros(1, 512, 512)
|
|
# new_channel = torch.zeros(1, 512, 512)
|
|
|
# img=torch.cat((img,new_channel),dim=0)
|
|
# img=torch.cat((img,new_channel),dim=0)
|
|
|
- # print(f'img:{img.shape}')
|
|
|
|
|
|
|
+ print(f'img:{img.shape}')
|
|
|
|
|
+ # print(f'img dtype:{img.dtype}')
|
|
|
return img, target
|
|
return img, target
|
|
|
|
|
|
|
|
def __len__(self):
|
|
def __len__(self):
|
|
@@ -113,7 +132,7 @@ class WirePointDataset(BaseDataset):
|
|
|
ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
|
ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
|
feat = [
|
|
feat = [
|
|
|
- lpre[:, :, :2].reshape(-1, 4) / 512 * use_cood,
|
|
|
|
|
|
|
+ lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
|
|
|
ldir * use_slop,
|
|
ldir * use_slop,
|
|
|
lpre[:, :, 2],
|
|
lpre[:, :, 2],
|
|
|
]
|
|
]
|
|
@@ -145,13 +164,12 @@ class WirePointDataset(BaseDataset):
|
|
|
target = {}
|
|
target = {}
|
|
|
# target["labels"] = torch.stack(labels)
|
|
# target["labels"] = torch.stack(labels)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
target["image_id"] = torch.tensor(item)
|
|
target["image_id"] = torch.tensor(item)
|
|
|
# return wire_labels, target
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
target["wires"] = wire_labels
|
|
|
# target["boxes"] = line_boxes(target)
|
|
# target["boxes"] = line_boxes(target)
|
|
|
target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
|
- target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
|
|
|
|
|
|
|
+ target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
# print(f'target["labels"]:{ target["labels"]}')
|
|
# print(f'target["labels"]:{ target["labels"]}')
|
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
|
if target["boxes"].numel() == 0:
|
|
if target["boxes"].numel() == 0:
|
|
@@ -189,7 +207,6 @@ class WirePointDataset(BaseDataset):
|
|
|
break
|
|
break
|
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
|
|
|
|
|
|
|
|
-
|
|
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
@@ -201,19 +218,21 @@ class WirePointDataset(BaseDataset):
|
|
|
if fn != None:
|
|
if fn != None:
|
|
|
plt.savefig(fn)
|
|
plt.savefig(fn)
|
|
|
|
|
|
|
|
- junc = target['wires']['junc_coords'].cpu().numpy()
|
|
|
|
|
|
|
+ junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
|
juncs = junc[jtyp == 0]
|
|
juncs = junc[jtyp == 0]
|
|
|
junts = junc[jtyp == 1]
|
|
junts = junc[jtyp == 1]
|
|
|
|
|
|
|
|
- lpre = target['wires']["lpre"].cpu().numpy()
|
|
|
|
|
|
|
+ lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
lpre = lpre[vecl_target == 1]
|
|
lpre = lpre[vecl_target == 1]
|
|
|
|
|
|
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def show_img(self, img_path):
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
|
|
+# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
|
|
|
|
|
+# for i in dataset_train:
|
|
|
|
|
+# a = 1
|