# main.py
  1. import math
  2. import os.path
  3. import re
  4. import sys
  5. import PIL.Image
  6. import torch
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. import torchvision.transforms
  10. import torchvision.transforms.functional as F
  11. from torch.utils.data import DataLoader
  12. from torchvision.transforms import v2
  13. from torchvision.utils import make_grid, draw_bounding_boxes
  14. from torchvision.io import read_image
  15. from pathlib import Path
  16. from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
  17. # PyTorch TensorBoard support
  18. from torch.utils.tensorboard import SummaryWriter
  19. import cv2
  20. from sklearn.cluster import DBSCAN
  21. from test.MaskRCNN import MaskRCNNDataset
  22. from tools import utils
  23. import pandas as pd
# Crop surrounding whitespace whenever a figure is saved via plt.savefig.
plt.rcParams["savefig.bbox"] = 'tight'
# Source and destination roots for the Severstal steel-defect dataset.
# NOTE(review): both point at the same directory — presumably intentional,
# since converted images/labels are written into sibling subfolders.
orig_path = r'F:\Downloads\severstal-steel-defect-detection'
dst_path = r'F:\Downloads\severstal-steel-defect-detection'
  27. def show(imgs):
  28. if not isinstance(imgs, list):
  29. imgs = [imgs]
  30. fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
  31. for i, img in enumerate(imgs):
  32. img = img.detach()
  33. img = F.to_pil_image(img)
  34. axs[0, i].imshow(np.asarray(img))
  35. axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  36. plt.show()
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Run one training epoch over `data_loader`.

    Follows the torchvision detection reference loop: optional AMP via
    `scaler`, linear LR warmup on the first epoch, and distributed-aware
    loss logging through `utils.MetricLogger` / `utils.reduce_dict`.
    Returns the populated metric logger.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None
    if epoch == 0:
        # Linear warmup only for the very first epoch, capped at 1000 steps.
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # Detection models take a list of images plus per-image target dicts;
        # non-tensor target entries are passed through unchanged.
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # Mixed precision is active only when a GradScaler was supplied.
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()
        # Abort on NaN/inf loss rather than continuing with a corrupt model.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)
        optimizer.zero_grad()
        if scaler is not None:
            # AMP path: scale, step, then update the scale factor.
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        if lr_scheduler is not None:
            # Warmup scheduler steps per iteration, not per epoch.
            lr_scheduler.step()
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    return metric_logger
def train():
    """Training entry point — not implemented yet. TODO: wire up dataset,
    model, optimizer and call train_one_epoch per epoch."""
    pass
  78. def trans_datasets_format():
  79. # 使用pandas的read_csv函数读取文件
  80. df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
  81. # 显示数据的前几行
  82. print(df.head())
  83. for row in df.itertuples():
  84. # print(f"Row index: {row.Index}")
  85. # print(getattr(row, 'ImageId')) # 输出特定列的值
  86. img_name = getattr(row, 'ImageId')
  87. img_path = os.path.join(orig_path + '/train_images', img_name)
  88. dst_img_path = os.path.join(dst_path + '/images/train', img_name)
  89. dst_label_path = os.path.join(dst_path + '/labels/train', img_name[:-3] + 'txt')
  90. print(f'dst label:{dst_label_path}')
  91. im = cv2.imread(img_path)
  92. # cv2.imshow('test',im)
  93. cv2.imwrite(dst_img_path, im)
  94. img = PIL.Image.open(img_path)
  95. height, width = im.shape[:2]
  96. print(f'cv2 size:{im.shape}')
  97. label, mask = compute_mask(row, img.size)
  98. lbls, ins_masks=cluster_dbscan(mask,img)
  99. with open(dst_label_path, 'a+') as writer:
  100. # writer.write(label)
  101. for ins_mask in ins_masks:
  102. lbl_data = str(label) + ' '
  103. for mp in ins_mask:
  104. h,w=mp
  105. lbl_data += str(w / width) + ' ' + str(h / height) + ' '
  106. # non_zero_coords = np.nonzero(inm.reshape(width,height).T)
  107. # coords_list = list(zip(non_zero_coords[0], non_zero_coords[1]))
  108. # # print(f'mask:{mask[0,333]}')
  109. # print(f'mask pixels:{coords_list}')
  110. #
  111. #
  112. # for coord in coords_list:
  113. # h, w = coord
  114. # lbl_data += str(w / width) + ' ' + str(h / height) + ' '
  115. writer.write(lbl_data + '\n')
  116. print(f'lbl_data:{lbl_data}')
  117. writer.close()
  118. print(f'label:{label}')
  119. # plt.imshow(img)
  120. # plt.imshow(mask, cmap='Reds', alpha=0.3)
  121. # plt.show()
  122. def compute_mask(row, shape):
  123. width, height = shape
  124. print(f'shape:{shape}')
  125. mask = np.zeros(width * height, dtype=np.uint8)
  126. pixels = np.array(list(map(int, row.EncodedPixels.split())))
  127. label = row.ClassId
  128. # print(f'pixels:{pixels}')
  129. mask_start = pixels[0::2]
  130. mask_length = pixels[1::2]
  131. for s, l in zip(mask_start, mask_length):
  132. mask[s:s + l] = 255
  133. mask = mask.reshape((width, height)).T
  134. # mask = np.flipud(np.rot90(mask.reshape((height, width))))
  135. return label, mask
  136. def cluster_dbscan(mask,image):
  137. # 将 mask 转换为二值图像
  138. _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  139. # 将 mask 一维化
  140. mask_flattened = mask_binary.flatten()
  141. # 获取 mask 中的前景像素坐标
  142. foreground_pixels = np.argwhere(mask_flattened == 255)
  143. # 将像素坐标转换为二维坐标
  144. foreground_pixels_2d = np.column_stack(
  145. (foreground_pixels // mask_binary.shape[1], foreground_pixels % mask_binary.shape[1]))
  146. # 定义 DBSCAN 参数
  147. eps = 3 # 邻域半径
  148. min_samples = 10 # 最少样本数量
  149. # 应用 DBSCAN
  150. dbscan = DBSCAN(eps=eps, min_samples=min_samples).fit(foreground_pixels_2d)
  151. # 获取聚类标签
  152. labels = dbscan.labels_
  153. print(f'labels:{labels}')
  154. # 获取唯一的标签
  155. unique_labels = set(labels)
  156. print(f'unique_labels:{unique_labels}')
  157. # 创建一个空的图像来保存聚类结果
  158. clustered_image = np.zeros_like(image)
  159. # print(f'clustered_image shape:{clustered_image.shape}')
  160. # 将每个像素分配给相应的簇
  161. clustered_points=[]
  162. for k in unique_labels:
  163. class_member_mask = (labels == k)
  164. # print(f'class_member_mask:{class_member_mask}')
  165. # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
  166. pixel_indices = foreground_pixels_2d[class_member_mask]
  167. clustered_points.append(pixel_indices)
  168. return unique_labels,clustered_points
  169. def show_cluster_dbscan(mask,image,unique_labels,clustered_points,):
  170. print(f'mask shape:{mask.shape}')
  171. # 将 mask 转换为二值图像
  172. _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  173. # 将 mask 一维化
  174. mask_flattened = mask_binary.flatten()
  175. # 获取 mask 中的前景像素坐标
  176. foreground_pixels = np.argwhere(mask_flattened == 255)
  177. # print(f'unique_labels:{unique_labels}')
  178. # 创建一个空的图像来保存聚类结果
  179. print(f'image shape:{image.shape}')
  180. clustered_image = np.zeros_like(image)
  181. print(f'clustered_image shape:{clustered_image.shape}')
  182. # 为每个簇分配颜色
  183. colors =np.array( [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))])
  184. # print(f'colors:{colors}')
  185. plt.figure(figsize=(12, 6))
  186. for points_coord,col in zip(clustered_points,colors):
  187. for coord in points_coord:
  188. clustered_image[coord[0], coord[1]] = (np.array(col[:3]) * 255)
  189. # # 将每个像素分配给相应的簇
  190. # for k, col in zip(unique_labels, colors):
  191. # print(f'col:{col*255}')
  192. # if k == -1:
  193. # # 黑色用于噪声点
  194. # col = [0, 0, 0, 1]
  195. #
  196. # class_member_mask = (labels == k)
  197. # # print(f'class_member_mask:{class_member_mask}')
  198. # # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
  199. #
  200. # pixel_indices = foreground_pixels_2d[class_member_mask]
  201. # clustered_points.append(pixel_indices)
  202. # # print(f'pixel_indices:{pixel_indices}')
  203. # for pixel_index in pixel_indices:
  204. # clustered_image[pixel_index[0], pixel_index[1]] = (np.array(col[:3]) * 255)
  205. print(f'clustered_points:{len(clustered_points)}')
  206. # print(f'clustered_image:{clustered_image}')
  207. # 显示原图和聚类结果
  208. # plt.figure(figsize=(12, 6))
  209. plt.subplot(131), plt.imshow(image), plt.title('Original Image')
  210. # print(f'image:{image}')
  211. plt.subplot(132), plt.imshow(mask_binary, cmap='gray'), plt.title('Mask')
  212. plt.subplot(133), plt.imshow(clustered_image.astype(np.uint8)), plt.title('Clustered Image')
  213. plt.show()
def test():
    """Smoke-test Mask R-CNN: build an image grid from two sample dogs,
    script the pretrained model, run one dummy forward pass, and log the
    scripted graph to TensorBoard."""
    dog1_int = read_image(str(Path('./assets') / 'dog1.jpg'))
    dog2_int = read_image(str(Path('./assets') / 'dog2.jpg'))
    dog_list = [dog1_int, dog2_int]
    grid = make_grid(dog_list)
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    transforms = weights.transforms()
    # NOTE(review): `images` is prepared but never consumed below.
    images = [transforms(d) for d in dog_list]
    # Dummy input with shape (1, 3, 800, 800).
    dummy_input = torch.randn(1, 3, 800, 800)
    model = maskrcnn_resnet50_fpn_v2(weights=weights, progress=False)
    model = model.eval()
    # Script the model so its graph can be written to TensorBoard.
    scripted_model = torch.jit.script(model)
    # NOTE(review): torchvision detection models normally take a *list* of
    # 3D image tensors — confirm a batched 4D tensor is accepted here.
    output = model(dummy_input)
    print(f'output:{output}')
    writer = SummaryWriter('runs/')
    writer.add_graph(scripted_model, input_to_model=dummy_input)
    writer.flush()
    # torch.onnx.export(models, images, f='maskrcnn.onnx')  # export to .onnx
    # netron.start('AlexNet.onnx')  # inspect the exported model structure
    show(grid)
  236. def test_mask():
  237. name = 'fdb7c0397'
  238. label_path = os.path.join(dst_path + '/labels/train', name + '.txt')
  239. img_path = os.path.join(orig_path + '/train_images', name + '.jpg')
  240. mask = np.zeros((256, 1600), dtype=np.uint8)
  241. df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
  242. # 显示数据的前几行
  243. print(df.head())
  244. points = []
  245. with open(label_path, 'r') as reader:
  246. lines = reader.readlines()
  247. for line in lines:
  248. parts = line.strip().split()
  249. # print(f'parts:{parts}')
  250. class_id = int(parts[0])
  251. x_array = parts[1::2]
  252. y_array = parts[2::2]
  253. for x, y in zip(x_array, y_array):
  254. x = float(x)
  255. y = float(y)
  256. points.append((int(y * 255), int(x * 1600)))
  257. # points = np.array([[float(parts[i]), float(parts[i + 1])] for i in range(1, len(parts), 2)])
  258. # mask_resized = cv2.resize(points, (1600, 256), interpolation=cv2.INTER_NEAREST)
  259. print(f'points:{points}')
  260. # mask[points[:,0],points[:,1]]=255
  261. for p in points:
  262. mask[p] = 255
  263. # cv2.fillPoly(mask, points, color=(255,))
  264. cv2.imshow('mask', mask)
  265. for row in df.itertuples():
  266. img_name = name + '.jpg'
  267. if img_name == getattr(row, 'ImageId'):
  268. img = PIL.Image.open(img_path)
  269. height, width = img.size
  270. print(f'img size:{img.size}')
  271. label, mask = compute_mask(row, img.size)
  272. plt.imshow(img)
  273. plt.imshow(mask, cmap='Reds', alpha=0.3)
  274. plt.show()
  275. cv2.waitKey(0)
  276. def show_img_mask(img_path):
  277. test_img = PIL.Image.open(img_path)
  278. w,h=test_img.size
  279. test_img=torchvision.transforms.ToTensor()(test_img)
  280. test_img=test_img.permute(1, 2, 0)
  281. print(f'test_img shape:{test_img.shape}')
  282. lbl_path=re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
  283. # print(f'lbl_path:{lbl_path}')
  284. masks = []
  285. labels = []
  286. with open(lbl_path, 'r') as reader:
  287. lines = reader.readlines()
  288. # 为每个簇分配颜色
  289. colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(lines))])
  290. print(f'colors:{colors*255}')
  291. mask_points = []
  292. for line ,col in zip(lines,colors):
  293. print(f'col:{np.array(col[:3]) * 255}')
  294. mask = torch.zeros(test_img.shape, dtype=torch.uint8)
  295. # print(f'mask shape:{mask.shape}')
  296. parts = line.strip().split()
  297. # print(f'parts:{parts}')
  298. cls = torch.tensor(int(parts[0]), dtype=torch.int64)
  299. labels.append(cls)
  300. x_array = parts[1::2]
  301. y_array = parts[2::2]
  302. for x, y in zip(x_array, y_array):
  303. x = float(x)
  304. y = float(y)
  305. mask_points.append((int(y * h), int(x * w)))
  306. for p in mask_points:
  307. # print(f'p:{p}')
  308. mask[p] = torch.tensor(np.array(col[:3])*255)
  309. masks.append(mask)
  310. reader.close()
  311. target = {}
  312. # target["boxes"] = masks_to_boxes(torch.stack(masks))
  313. # target["labels"] = torch.stack(labels)
  314. target["masks"] = torch.stack(masks)
  315. print(f'target:{target}')
  316. # plt.imshow(test_img.permute(1, 2, 0))
  317. fig, axs = plt.subplots(2, 1)
  318. print(f'test_img:{test_img*255}')
  319. axs[0].imshow(test_img)
  320. axs[0].axis('off')
  321. axs[1].axis('off')
  322. axs[1].imshow(test_img*255)
  323. for img_mask in target['masks']:
  324. # img_mask=img_mask.unsqueeze(0)
  325. # img_mask = img_mask.expand_as(test_img)
  326. # print(f'img_mask:{img_mask.shape}')
  327. axs[1].imshow(img_mask,alpha=0.3)
  328. # img_mask=np.array(img_mask)
  329. # print(f'img_mask:{img_mask.shape}')
  330. # plt.imshow(img_mask,alpha=0.5)
  331. # mask_3channel = cv2.merge([np.zeros_like(img_mask), np.zeros_like(img_mask), img_mask])
  332. # masked_image = cv2.addWeighted(test_img, 1, mask_3channel, 0.6, 0)
  333. # cv2.imshow('cv2 mask img', masked_image)
  334. # cv2.waitKey(0)
  335. plt.show()
def show_dataset():
    """Load one batch from MaskRCNNDataset and visualize the image, its
    first mask, and its bounding boxes."""
    # NOTE(review): these globals leak debug state to module level; nothing
    # else in this file reads them — consider making them locals.
    global transforms, dataset, imgs
    transforms = v2.Compose([
        # v2.RandomResizedCrop(size=(224, 224), antialias=True),
        # v2.RandomPhotometricDistort(p=1),
        # v2.RandomHorizontalFlip(p=1),
        v2.ToTensor()
    ])
    dataset = MaskRCNNDataset(dataset_path=r'F:\Downloads\severstal-steel-defect-detection', transforms=transforms,
                              dataset_type='train')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)
    imgs, targets = next(iter(dataloader))
    # Inspect sample index 2 of the batch: its first mask and all its boxes.
    mask = np.array(targets[2]['masks'][0])
    boxes = targets[2]['boxes']
    print(f'boxes:{boxes}')
    # Back to HxWxC uint8 for OpenCV-style compositing.
    img = np.array(imgs[2].permute(1, 2, 0)) * 255
    img = img.astype(np.uint8)
    print(f'img shape:{img.shape}')
    print(f'mask:{mask.shape}')
    # Put the mask in the (B, G, R) red channel for blending.
    mask_3channel = cv2.merge([np.zeros_like(mask), np.zeros_like(mask), mask])
    print(f'mask_3channel:{mask_3channel.shape}')
    # NOTE(review): masked_image is computed but never displayed.
    masked_image = cv2.addWeighted(img, 1, mask_3channel, 0.6, 0)
    plt.imshow(imgs[0].permute(1, 2, 0))
    plt.imshow(mask, cmap='Reds', alpha=0.3)
    drawn_boxes = draw_bounding_boxes((imgs[2] * 255).to(torch.uint8), boxes, colors="red", width=5)
    plt.imshow(drawn_boxes.permute(1, 2, 0))
    plt.show()
    cv2.waitKey(0)
  374. def test_cluster(img_path):
  375. test_img = PIL.Image.open(img_path)
  376. w, h = test_img.size
  377. test_img = torchvision.transforms.ToTensor()(test_img)
  378. test_img=(test_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
  379. # print(f'test_img:{test_img}')
  380. lbl_path = re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
  381. # print(f'lbl_path:{lbl_path}')
  382. masks = []
  383. labels = []
  384. with open(lbl_path, 'r') as reader:
  385. lines = reader.readlines()
  386. mask_points = []
  387. for line in lines:
  388. mask = torch.zeros((h, w), dtype=torch.uint8)
  389. parts = line.strip().split()
  390. # print(f'parts:{parts}')
  391. cls = torch.tensor(int(parts[0]), dtype=torch.int64)
  392. labels.append(cls)
  393. x_array = parts[1::2]
  394. y_array = parts[2::2]
  395. for x, y in zip(x_array, y_array):
  396. x = float(x)
  397. y = float(y)
  398. mask_points.append((int(y * h), int(x * w)))
  399. for p in mask_points:
  400. mask[p] = 255
  401. masks.append(mask)
  402. # print(f'masks:{masks}')
  403. labels,clustered_points=cluster_dbscan(masks[0].numpy(),test_img)
  404. print(f'labels:{labels}')
  405. print(f'clustered_points len:{len(clustered_points)}')
  406. show_cluster_dbscan(masks[0].numpy(),test_img,labels,clustered_points)
  407. if __name__ == '__main__':
  408. # trans_datasets_format()
  409. # test_mask()
  410. # 定义转换
  411. # show_dataset()
  412. # test_img_path= r"F:\Downloads\severstal-steel-defect-detection\images\train\0025bde0c.jpg"
  413. test_img_path = r"F:\DevTools\datasets\renyaun\1012\spilt\images\train\2024-09-27-14-32-53_SaveImage.png"
  414. # test_img1_path=r"F:\Downloads\severstal-steel-defect-detection\images\train\1d00226a0.jpg"
  415. show_img_mask(test_img_path)
  416. #
  417. # test_cluster(test_img_path)