@@ -1,9 +1,9 @@
 # ??roi_head??????????????
-from torch import dtype
 from torch.utils.data.dataset import T_co
+from libs.vision_libs.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 from models.base.base_dataset import BaseDataset
-
+from torchvision.transforms import functional as F
 import glob
 import json
 import math
@@ -70,6 +70,7 @@ class WirePointDataset(BaseDataset):
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
+        self.transform = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
         # self.default_transform = DefaultTransform()

     def __getitem__(self, index) -> T_co:
@@ -77,35 +78,61 @@ class WirePointDataset(BaseDataset):
         lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')

         # img = PIL.Image.open(img_path).convert('RGB')
-        # w, h = img.size
         img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
-        print(img.shape)
-        w, h = img.shape[0:2]
+        img_rgb = img[:, :, :3]
+        print(f'img shape:{img.shape}')
+        img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)
+        # img = np.array(img, copy=True)
+        # img = self.default_transform(img)
+        # print(f'pil img:{img.dtype}')
+        # w, h = img.size
+        # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+        # cv2.imshow('img', img)
+        # cv2.waitKey(1000000)
+        # print(img.shape)
+        w, h = img.shape[0:2]
+        # w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+
+        print(f'self.default_transform:{self.default_transform}')
+
         # if self.transforms:
         #     img, target = self.transforms(img, target)
         # else:
         #     img = self.default_transform(img)

-        # Split the RGB and depth channels
-        rgb_channels = img[:, :, :3]
-        depth_channel = img[:, :, 3]
-
-        # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
-        rgb_normalized = rgb_channels
-        depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min()) * 255
-
-        # Recombine the normalized RGB and depth channels
-        normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # or use depth_normalized_fixed_range
+        # # Split the RGB and depth channels
+        # rgb_channels = img[:, :, :3]
+        # depth_channel = img[:, :, 3]
+        #
+        # rgb_normalized = rgb_channels / 255
+        # depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min()) * 255
+        #
+        # # Recombine the normalized RGB and depth channels
+        # normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # or use depth_normalized_fixed_range
+        #
+        # print("Normalized RGBA image shape:", normalized_rgba_image.shape)
+        #
+        # img = torch.tensor(normalized_rgba_image, dtype=torch.uint8).permute(2, 1, 0)

-        print("Normalized RGBA image shape:", normalized_rgba_image.shape)
-
-        img = torch.tensor(normalized_rgba_image, dtype=torch.float32).permute(2, 1, 0)
+        # cv2.imshow('img', img[:3].permute(1, 2, 0).numpy().astype(np.uint8))
+        # cv2.waitKey(10000000)
+        # plt.imshow(img[:3].permute(1, 2, 0).numpy())
+        # plt.show()

         # new_channel = torch.zeros(1, 512, 512)
         # img = torch.cat((img, new_channel), dim=0)
+        img = np.dstack((img_rgb, img[:, :, 3]))
+
+        img = torch.as_tensor(img).permute(2, 0, 1)
+        img = self.default_transform(img)
+
+        # img = F.convert_image_dtype(img, torch.float)
         print(f'img:{img.shape}')
         # print(f'img dtype:{img.dtype}')
         return img, target
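
The new self.transform attribute comes from the Mask R-CNN weight enum's preprocessing preset. Below is a minimal sketch of what that call returns and how it is applied, assuming the vendored libs.vision_libs package mirrors the upstream torchvision API; in upstream torchvision the detection preset essentially converts a uint8 image tensor to float32 and rescales it to [0, 1].

import torch
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
preprocess = weights.transforms()  # detection preprocessing preset bundled with the weights

# Placeholder CxHxW uint8 image standing in for a dataset sample.
img_uint8 = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)
img_float = preprocess(img_uint8)  # float32, values scaled to [0, 1]
print(img_float.dtype, img_float.shape)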
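
The rewritten __getitem__ reads a 4-channel image with OpenCV, converts the first three channels from BGR to RGB, re-attaches the depth channel, and permutes to CxHxW. The sketch below condenses that flow into one self-contained helper, assuming a 4-channel (BGR + depth) image on disk; the name load_rgbd_tensor is illustrative, not part of the codebase.

import cv2
import numpy as np
import torch

def load_rgbd_tensor(img_path: str) -> torch.Tensor:
    """Illustrative helper: read a BGR+depth image and return a 4xHxW tensor."""
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)       # HxWx4 array
    rgb = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2RGB)   # reorder colour channels
    rgbd = np.dstack((rgb, img[:, :, 3]))                  # stack depth back as channel 4
    return torch.as_tensor(rgbd).permute(2, 0, 1)          # HxWxC -> CxHxW

Note that if the depth channel is stored in a wider dtype than the colour channels (e.g. uint16), np.dstack promotes the whole array to that dtype, so any later float conversion should account for the larger value range.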