|
@@ -97,12 +97,20 @@ def show_line(img, pred, epoch, writer):
|
|
plt.tight_layout()
|
|
plt.tight_layout()
|
|
fig = plt.gcf()
|
|
fig = plt.gcf()
|
|
fig.canvas.draw()
|
|
fig.canvas.draw()
|
|
- image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
|
- fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
|
- plt.close()
|
|
|
|
- img2 = transforms.ToTensor()(image_from_plot)
|
|
|
|
|
|
+ width, height = fig.get_size_inches() * fig.get_dpi() # 获取图像尺寸
|
|
|
|
+ tmp_img = fig.canvas.tostring_argb()
|
|
|
|
+ tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
|
|
|
|
+ tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
|
|
|
|
|
|
- writer.add_image("output", img2, epoch)
|
|
|
|
|
|
+ img_rgb = tmp_img_np[:, :, 1:] # 提取RGB部分,忽略Alpha通道
|
|
|
|
+
|
|
|
|
+ # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
|
|
|
|
+ # fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
|
+ # plt.close()
|
|
|
|
+
|
|
|
|
+ img2 = transforms.ToTensor()(img_rgb)
|
|
|
|
+
|
|
|
|
+ writer.add_image("z-output", img2, epoch)
|
|
|
|
|
|
|
|
|
|
def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
|
|
def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
|
|
@@ -282,8 +290,8 @@ if __name__ == '__main__':
|
|
for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
pred = model(move_to_device(imgs, device))
|
|
pred = model(move_to_device(imgs, device))
|
|
|
|
|
|
- pred_ = box_line_(pred) # 将box与line对应
|
|
|
|
- show_(imgs, pred_, epoch, writer)
|
|
|
|
|
|
+ # pred_ = box_line_(pred) # 将box与line对应
|
|
|
|
+ # show_(imgs, pred_, epoch, writer)
|
|
|
|
|
|
if batch_idx == 0:
|
|
if batch_idx == 0:
|
|
show_line(imgs[0], pred, epoch, writer)
|
|
show_line(imgs[0], pred, epoch, writer)
|