test_datasets.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. import functools
  2. import math
  3. import os.path
  4. import re
  5. import sys
  6. import PIL.Image
  7. import torch
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import torchvision.transforms
  11. import torchvision.transforms.functional as F
  12. from torch.utils.data import DataLoader
  13. from torchvision import transforms
  14. from torchvision.transforms import v2
  15. from torchvision.utils import make_grid, draw_bounding_boxes, draw_segmentation_masks
  16. from torchvision.io import read_image
  17. from pathlib import Path
  18. from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
  19. from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
  20. # PyTorch TensorBoard support
  21. from torch.utils.tensorboard import SummaryWriter
  22. import cv2
  23. from sklearn.cluster import DBSCAN
  24. from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
  25. from tools import utils
  26. import pandas as pd
# Save figures with a tight bounding box.
plt.rcParams["savefig.bbox"] = 'tight'
# Root of the source dataset: 'train.csv' and 'train_images' are read from here.
orig_path = r'F:\Downloads\severstal-steel-defect-detection'
# Root under which 'images/train' and 'labels/train' are written and read;
# currently the same directory as orig_path.
dst_path = r'F:\Downloads\severstal-steel-defect-detection'
  30. def show(imgs):
  31. if not isinstance(imgs, list):
  32. imgs = [imgs]
  33. fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
  34. for i, img in enumerate(imgs):
  35. img = img.detach()
  36. img = F.to_pil_image(img)
  37. axs[0, i].imshow(np.asarray(img))
  38. axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  39. plt.show()
  40. def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
  41. model.train()
  42. metric_logger = utils.MetricLogger(delimiter=" ")
  43. metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
  44. header = f"Epoch: [{epoch}]"
  45. lr_scheduler = None
  46. if epoch == 0:
  47. warmup_factor = 1.0 / 1000
  48. warmup_iters = min(1000, len(data_loader) - 1)
  49. lr_scheduler = torch.optim.lr_scheduler.LinearLR(
  50. optimizer, start_factor=warmup_factor, total_iters=warmup_iters
  51. )
  52. for images, targets in metric_logger.log_every(data_loader, print_freq, header):
  53. images = list(image.to(device) for image in images)
  54. targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
  55. with torch.cuda.amp.autocast(enabled=scaler is not None):
  56. loss_dict = model(images, targets)
  57. losses = sum(loss for loss in loss_dict.values())
  58. # reduce losses over all GPUs for logging purposes
  59. loss_dict_reduced = utils.reduce_dict(loss_dict)
  60. losses_reduced = sum(loss for loss in loss_dict_reduced.values())
  61. loss_value = losses_reduced.item()
  62. if not math.isfinite(loss_value):
  63. print(f"Loss is {loss_value}, stopping training")
  64. print(loss_dict_reduced)
  65. sys.exit(1)
  66. optimizer.zero_grad()
  67. if scaler is not None:
  68. scaler.scale(losses).backward()
  69. scaler.step(optimizer)
  70. scaler.update()
  71. else:
  72. losses.backward()
  73. optimizer.step()
  74. if lr_scheduler is not None:
  75. lr_scheduler.step()
  76. metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
  77. metric_logger.update(lr=optimizer.param_groups[0]["lr"])
  78. return metric_logger
  79. def train():
  80. pass
  81. def trans_datasets_format():
  82. # 使用pandas的read_csv函数读取文件
  83. df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
  84. # 显示数据的前几行
  85. print(df.head())
  86. for row in df.itertuples():
  87. # print(f"Row index: {row.Index}")
  88. # print(getattr(row, 'ImageId')) # 输出特定列的值
  89. img_name = getattr(row, 'ImageId')
  90. img_path = os.path.join(orig_path + '/train_images', img_name)
  91. dst_img_path = os.path.join(dst_path + '/images/train', img_name)
  92. dst_label_path = os.path.join(dst_path + '/labels/train', img_name[:-3] + 'txt')
  93. print(f'dst label:{dst_label_path}')
  94. im = cv2.imread(img_path)
  95. # cv2.imshow('test',im)
  96. cv2.imwrite(dst_img_path, im)
  97. img = PIL.Image.open(img_path)
  98. height, width = im.shape[:2]
  99. print(f'cv2 size:{im.shape}')
  100. label, mask = compute_mask(row, img.size)
  101. lbls, ins_masks=cluster_dbscan(mask,img)
  102. with open(dst_label_path, 'a+') as writer:
  103. # writer.write(label)
  104. for ins_mask in ins_masks:
  105. lbl_data = str(label) + ' '
  106. for mp in ins_mask:
  107. h,w=mp
  108. lbl_data += str(w / width) + ' ' + str(h / height) + ' '
  109. # non_zero_coords = np.nonzero(inm.reshape(width,height).T)
  110. # coords_list = list(zip(non_zero_coords[0], non_zero_coords[1]))
  111. # # print(f'ins:{ins[0,333]}')
  112. # print(f'ins pixels:{coords_list}')
  113. #
  114. #
  115. # for coord in coords_list:
  116. # h, w = coord
  117. # lbl_data += str(w / width) + ' ' + str(h / height) + ' '
  118. writer.write(lbl_data + '\n')
  119. print(f'lbl_data:{lbl_data}')
  120. writer.close()
  121. print(f'label:{label}')
  122. # plt.imshow(img)
  123. # plt.imshow(ins, cmap='Reds', alpha=0.3)
  124. # plt.show()
  125. def compute_mask(row, shape):
  126. width, height = shape
  127. print(f'shape:{shape}')
  128. mask = np.zeros(width * height, dtype=np.uint8)
  129. pixels = np.array(list(map(int, row.EncodedPixels.split())))
  130. label = row.ClassId
  131. # print(f'pixels:{pixels}')
  132. mask_start = pixels[0::2]
  133. mask_length = pixels[1::2]
  134. for s, l in zip(mask_start, mask_length):
  135. mask[s:s + l] = 255
  136. mask = mask.reshape((width, height)).T
  137. # ins = np.flipud(np.rot90(ins.reshape((height, width))))
  138. return label, mask
  139. def cluster_dbscan(mask,image):
  140. # 将 ins 转换为二值图像
  141. _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  142. # 将 ins 一维化
  143. mask_flattened = mask_binary.flatten()
  144. # 获取 ins 中的前景像素坐标
  145. foreground_pixels = np.argwhere(mask_flattened == 255)
  146. # 将像素坐标转换为二维坐标
  147. foreground_pixels_2d = np.column_stack(
  148. (foreground_pixels // mask_binary.shape[1], foreground_pixels % mask_binary.shape[1]))
  149. # 定义 DBSCAN 参数
  150. eps = 3 # 邻域半径
  151. min_samples = 10 # 最少样本数量
  152. # 应用 DBSCAN
  153. dbscan = DBSCAN(eps=eps, min_samples=min_samples).fit(foreground_pixels_2d)
  154. # 获取聚类标签
  155. labels = dbscan.labels_
  156. print(f'labels:{labels}')
  157. # 获取唯一的标签
  158. unique_labels = set(labels)
  159. print(f'unique_labels:{unique_labels}')
  160. # 创建一个空的图像来保存聚类结果
  161. clustered_image = np.zeros_like(image)
  162. # print(f'clustered_image shape:{clustered_image.shape}')
  163. # 将每个像素分配给相应的簇
  164. clustered_points=[]
  165. for k in unique_labels:
  166. class_member_mask = (labels == k)
  167. # print(f'class_member_mask:{class_member_mask}')
  168. # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
  169. pixel_indices = foreground_pixels_2d[class_member_mask]
  170. clustered_points.append(pixel_indices)
  171. return unique_labels,clustered_points
  172. def show_cluster_dbscan(mask,image,unique_labels,clustered_points,):
  173. print(f'ins shape:{mask.shape}')
  174. # 将 ins 转换为二值图像
  175. _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  176. # 将 ins 一维化
  177. mask_flattened = mask_binary.flatten()
  178. # 获取 ins 中的前景像素坐标
  179. foreground_pixels = np.argwhere(mask_flattened == 255)
  180. # print(f'unique_labels:{unique_labels}')
  181. # 创建一个空的图像来保存聚类结果
  182. print(f'image shape:{image.shape}')
  183. clustered_image = np.zeros_like(image)
  184. print(f'clustered_image shape:{clustered_image.shape}')
  185. # 为每个簇分配颜色
  186. colors =np.array( [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))])
  187. # print(f'colors:{colors}')
  188. plt.figure(figsize=(12, 6))
  189. for points_coord,col in zip(clustered_points,colors):
  190. for coord in points_coord:
  191. clustered_image[coord[0], coord[1]] = (np.array(col[:3]) * 255)
  192. # # 将每个像素分配给相应的簇
  193. # for k, col in zip(unique_labels, colors):
  194. # print(f'col:{col*255}')
  195. # if k == -1:
  196. # # 黑色用于噪声点
  197. # col = [0, 0, 0, 1]
  198. #
  199. # class_member_mask = (labels == k)
  200. # # print(f'class_member_mask:{class_member_mask}')
  201. # # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
  202. #
  203. # pixel_indices = foreground_pixels_2d[class_member_mask]
  204. # clustered_points.append(pixel_indices)
  205. # # print(f'pixel_indices:{pixel_indices}')
  206. # for pixel_index in pixel_indices:
  207. # clustered_image[pixel_index[0], pixel_index[1]] = (np.array(col[:3]) * 255)
  208. print(f'clustered_points:{len(clustered_points)}')
  209. # print(f'clustered_image:{clustered_image}')
  210. # 显示原图和聚类结果
  211. # plt.figure(figsize=(12, 6))
  212. plt.subplot(131), plt.imshow(image), plt.title('Original Image')
  213. # print(f'image:{image}')
  214. plt.subplot(132), plt.imshow(mask_binary, cmap='gray'), plt.title('Mask')
  215. plt.subplot(133), plt.imshow(clustered_image.astype(np.uint8)), plt.title('Clustered Image')
  216. plt.show()
  217. def test():
  218. dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
  219. dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
  220. dog_list = [dog1_int, dog2_int]
  221. grid = make_grid(dog_list)
  222. weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
  223. transforms = weights.transforms()
  224. images = [transforms(d) for d in dog_list]
  225. # 假设输入图像的尺寸为 (3, 800, 800)
  226. dummy_input = torch.randn(1, 3, 800, 800)
  227. model = maskrcnn_resnet50_fpn_v2(weights=weights, progress=False)
  228. model = model.eval()
  229. # 使用 torch.jit.script
  230. scripted_model = torch.jit.script(model)
  231. output = model(dummy_input)
  232. print(f'output:{output}')
  233. writer = SummaryWriter('runs/')
  234. writer.add_graph(scripted_model, input_to_model=dummy_input)
  235. writer.flush()
  236. # torch.onnx.export(models,images, f='maskrcnn.onnx') # 导出 .onnx 文
  237. # netron.start('AlexNet.onnx') # 展示结构图
  238. show(grid)
  239. def test_mask():
  240. name = 'fdb7c0397'
  241. label_path = os.path.join(dst_path + '/labels/train', name + '.txt')
  242. img_path = os.path.join(orig_path + '/train_images', name + '.jpg')
  243. mask = np.zeros((256, 1600), dtype=np.uint8)
  244. df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
  245. # 显示数据的前几行
  246. print(df.head())
  247. points = []
  248. with open(label_path, 'r') as reader:
  249. lines = reader.readlines()
  250. for line in lines:
  251. parts = line.strip().split()
  252. # print(f'parts:{parts}')
  253. class_id = int(parts[0])
  254. x_array = parts[1::2]
  255. y_array = parts[2::2]
  256. for x, y in zip(x_array, y_array):
  257. x = float(x)
  258. y = float(y)
  259. points.append((int(y * 255), int(x * 1600)))
  260. # points = np.array([[float(parts[i]), float(parts[i + 1])] for i in range(1, len(parts), 2)])
  261. # mask_resized = cv2.resize(points, (1600, 256), interpolation=cv2.INTER_NEAREST)
  262. print(f'points:{points}')
  263. # ins[points[:,0],points[:,1]]=255
  264. for p in points:
  265. mask[p] = 255
  266. # cv2.fillPoly(ins, points, color=(255,))
  267. cv2.imshow('ins', mask)
  268. for row in df.itertuples():
  269. img_name = name + '.jpg'
  270. if img_name == getattr(row, 'ImageId'):
  271. img = PIL.Image.open(img_path)
  272. height, width = img.size
  273. print(f'img size:{img.size}')
  274. label, mask = compute_mask(row, img.size)
  275. plt.imshow(img)
  276. plt.imshow(mask, cmap='Reds', alpha=0.3)
  277. plt.show()
  278. cv2.waitKey(0)
  279. def show_img_mask(img_path):
  280. test_img = PIL.Image.open(img_path)
  281. w,h=test_img.size
  282. print(f'test_img size:{test_img.size}')
  283. test_img=torchvision.transforms.ToTensor()(test_img)
  284. test_img=test_img.permute(1, 2, 0)
  285. print(f'test_img shape:{test_img.shape}')
  286. lbl_path=re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
  287. # print(f'lbl_path:{lbl_path}')
  288. masks = []
  289. labels = []
  290. polygons=read_labels(lbl_path,test_img.shape)
  291. # print(f'polygons data:{polygons}')
  292. masks=create_mask_from_polygons(polygons,test_img.shape)
  293. # print(f'polygons shape:{polygons.shape}')
  294. labels =[item[0] for item in polygons]
  295. print(f'labels:{labels}')
  296. target = {}
  297. # target["boxes"] = masks_to_boxes(torch.stack(masks))
  298. # target["labels"] = torch.stack(labels)
  299. target["masks"] = torch.stack(masks)
  300. print(f'target:{target}')
  301. # plt.imshow(test_img.permute(1, 2, 0))
  302. fig, axs = plt.subplots(2, 1)
  303. print(f'test_img:{test_img*255}')
  304. axs[0].imshow(test_img)
  305. axs[0].axis('off')
  306. axs[1].axis('off')
  307. axs[1].imshow(test_img*255)
  308. for img_mask in target['masks']:
  309. # img_mask=img_mask.unsqueeze(0)
  310. # img_mask = img_mask.expand_as(test_img)
  311. # print(f'img_mask:{img_mask.shape}')
  312. axs[1].imshow(img_mask,alpha=0.3)
  313. # img_mask=np.array(img_mask)
  314. # print(f'img_mask:{img_mask.shape}')
  315. # plt.imshow(img_mask,alpha=0.5)
  316. # mask_3channel = cv2.merge([np.zeros_like(img_mask), np.zeros_like(img_mask), img_mask])
  317. # masked_image = cv2.addWeighted(test_img, 1, mask_3channel, 0.6, 0)
  318. # cv2.imshow('cv2 ins img', masked_image)
  319. # cv2.waitKey(0)
  320. plt.show()
  321. def create_mask_from_polygons(polygons, image_shape):
  322. """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
  323. colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
  324. masks=[]
  325. for polygon_data ,col in zip(polygons,colors):
  326. mask = np.zeros(image_shape[:2], dtype=np.uint8)
  327. # 将多边形顶点转换为 NumPy 数组
  328. _,polygon=polygon_data
  329. pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
  330. # 使用 OpenCV 的 fillPoly 函数填充多边形
  331. print(f'color:{col[:3]}')
  332. cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
  333. mask=torch.from_numpy(mask)
  334. masks.append(mask)
  335. return masks
  336. def read_labels(lbl_path,shape):
  337. """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
  338. polygons = []
  339. w, h = shape[:2]
  340. with open(lbl_path, 'r') as f:
  341. lines = f.readlines()
  342. for line in lines:
  343. parts = line.strip().split()
  344. class_id = int(parts[0])
  345. # 假设多边形顶点从第2个元素开始,且已经归一化
  346. polygon = [float(coord) for coord in parts[1:]]
  347. # # 将归一化坐标转换为像素坐标
  348. # polygon = [int(polygon[i] * image_shape[1] if i % 2 == 0 else polygon[i] * image_shape[0]) for i in
  349. # range(len(polygon))]
  350. points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2) # 读取点坐标
  351. # print(f'points :{points}')
  352. points[:, 0] *= h
  353. points[:, 1] *= w
  354. # 将轮廓坐标重新组织为 (x, y) 对
  355. # polygon = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon), 2)]
  356. # polygons.append((class_id, polygon))
  357. polygons.append((class_id, points))
  358. return polygons
  359. def show_dataset():
  360. global transforms, dataset, imgs
  361. transforms = v2.Compose([
  362. # v2.RandomResizedCrop(size=(224, 224), antialias=True),
  363. # v2.RandomPhotometricDistort(p=1),
  364. # v2.RandomHorizontalFlip(p=1),
  365. v2.ToTensor()
  366. ])
  367. dataset = MaskRCNNDataset(dataset_path=r'\\192.168.50.222\share\rlq\datasets\bangcai2', transforms=transforms,
  368. dataset_type='train')
  369. dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)
  370. for imgs, targets in dataloader:
  371. masks=targets[0]['masks']
  372. boxes = targets[0]['boxes']
  373. print(f'boxes:{boxes}')
  374. # ins[ins == 255] = 1
  375. # img = np.array(imgs[2].permute(1, 2, 0)) * 255
  376. show_boxes_masks( imgs, boxes,masks)
  377. def show_boxes_masks(imgs, boxes,masks):
  378. img = np.array(imgs[0])
  379. img = img.astype(np.uint8)
  380. masks=masks.to(torch.bool)
  381. print(f'masks shape:{masks.shape}')
  382. print(f'img shape:{img.shape}')
  383. print(f'img shape:{img.shape}')
  384. # print(f'ins:{ins.shape}')
  385. # mask_3channel = cv2.merge([np.zeros_like(masks[0]), np.zeros_like(masks[0]), masks[0]])
  386. # print(f'mask_3channel:{mask_3channel.shape}')
  387. img_tensor = torch.tensor(imgs[0], dtype=torch.uint8)
  388. boxed_img = draw_bounding_boxes(img_tensor, boxes).permute(1, 2, 0).contiguous()
  389. masked_img = draw_segmentation_masks(img_tensor, masks).permute(1, 2, 0).contiguous()
  390. plt.imshow(imgs[0].permute(1, 2, 0))
  391. # plt.imshow(ins, cmap='Reds', alpha=0.5)
  392. plt.imshow(masked_img, cmap='Reds', alpha=0.3)
  393. plt.imshow(boxed_img, cmap='Greens', alpha=0.5)
  394. plt.show()
  395. cv2.waitKey(0)
  396. def test_cluster(img_path):
  397. test_img = PIL.Image.open(img_path)
  398. w, h = test_img.size
  399. test_img = torchvision.transforms.ToTensor()(test_img)
  400. test_img=(test_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
  401. # print(f'test_img:{test_img}')
  402. lbl_path = re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
  403. # print(f'lbl_path:{lbl_path}')
  404. masks = []
  405. labels = []
  406. with open(lbl_path, 'r') as reader:
  407. lines = reader.readlines()
  408. mask_points = []
  409. for line in lines:
  410. mask = torch.zeros((h, w), dtype=torch.uint8)
  411. parts = line.strip().split()
  412. # print(f'parts:{parts}')
  413. cls = torch.tensor(int(parts[0]), dtype=torch.int64)
  414. labels.append(cls)
  415. x_array = parts[1::2]
  416. y_array = parts[2::2]
  417. for x, y in zip(x_array, y_array):
  418. x = float(x)
  419. y = float(y)
  420. mask_points.append((int(y * h), int(x * w)))
  421. for p in mask_points:
  422. mask[p] = 255
  423. masks.append(mask)
  424. # print(f'masks:{masks}')
  425. labels,clustered_points=cluster_dbscan(masks[0].numpy(),test_img)
  426. print(f'labels:{labels}')
  427. print(f'clustered_points len:{len(clustered_points)}')
  428. show_cluster_dbscan(masks[0].numpy(),test_img,labels,clustered_points)
if __name__ == '__main__':
    # Manual entry point: uncomment the call you want to run.
    # trans_datasets_format()
    # test_mask()
    # Define the transforms
    show_dataset()
    # test_img_path= r"F:\Downloads\severstal-steel-defect-detection\images\train\0025bde0c.jpg"
    # test_img_path = r"F:\DevTools\datasets\renyaun\1012\spilt\images\train\2024-09-23-10-03-03_SaveImage.png"
    # test_img_path=r"\\192.168.50.222\share\rlq\datasets\bangcai2\images\train\frame_000068.jpg"
    # show_img_mask(test_img_path)
    #
    # test_cluster(test_img_path)