from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lcnn.config import M


class MultitaskHead(nn.Module):
    def __init__(self, input_channels, num_class):
        super(MultitaskHead, self).__init__()

        m = int(input_channels / 4)
        heads = []
        # One small conv branch per output map listed in the config,
        # e.g. head_size [[2], [1], [2]] yields three branches.
        for output_channels in sum(M.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(M.head_size, []))

    def forward(self, x):
        # Concatenate every branch's output along the channel dimension.
        return torch.cat([head(x) for head in self.heads], dim=1)
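

# Usage sketch for MultitaskHead, under an assumed configuration: the common
# L-CNN head layout M.head_size == [[2], [1], [2]] (junction map, line map,
# junction offset), hence num_class == 5. In practice these values come from
# the YAML config loaded into M; `_demo_multitask_head` is a hypothetical
# helper, not part of the training pipeline.
def _demo_multitask_head():
    M.head_size = [[2], [1], [2]]  # assumed layout; normally set by the config
    head = MultitaskHead(input_channels=256, num_class=5)
    x = torch.randn(2, 256, 128, 128)
    y = head(x)
    assert y.shape == (2, 5, 128, 128)  # per-branch maps concatenated on dim=1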


class MultitaskLearner(nn.Module):
    def __init__(self, backbone):
        super(MultitaskLearner, self).__init__()
        self.backbone = backbone
        head_size = M.head_size
        self.num_class = sum(sum(head_size, []))
        # Cumulative channel offsets used to slice the concatenated head output.
        self.head_off = np.cumsum([sum(h) for h in head_size])

    def forward(self, input_dict):
        image = input_dict["image"]
        target_b = input_dict["target_b"]
        # During training `aaa` is the backbone loss; during validation it is the boxes.
        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])
        result = {"feature": feature}
        batch, channel, row, col = outputs[0].shape

        T = input_dict["target"].copy()
        n_jtyp = T["junc_map"].shape[1]

        # switch to CNHW
        for task in ["junc_map"]:
            T[task] = T[task].permute(1, 0, 2, 3)
        for task in ["junc_offset"]:
            T[task] = T[task].permute(1, 2, 0, 3, 4)

        offset = self.head_off
        loss_weight = M.loss_weight
        losses = []
        for stack, output in enumerate(outputs):
            # (N, C, H, W) -> (C, N, H, W), then slice per-task channel blocks.
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                if input_dict["mode"] == "testing":
                    return result

            L = OrderedDict()
            L["jmap"] = sum(
                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
            )
            L["lmap"] = (
                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
                .mean(2)
                .mean(1)
            )
            L["joff"] = sum(
                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
                for i in range(n_jtyp)
                for j in range(2)
            )
            for loss_name in L:
                L[loss_name].mul_(loss_weight[loss_name])
            losses.append(L)
        result["losses"] = losses
        result["aaa"] = aaa
        return result
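

# Usage sketch for MultitaskLearner.forward, under assumed settings: the head
# layout above, one junction type (n_jtyp == 1), a 128x128 output grid, and
# illustrative loss weights. The backbone below is a stand-in; the real
# backbone is expected to take (image, target_b, mode) and return
# (outputs, feature, aaa) as the forward pass above does.
# `_demo_multitask_learner` is a hypothetical helper.
def _demo_multitask_learner():
    class _FakeBackbone(nn.Module):
        def forward(self, image, target_b, mode):
            outputs = [torch.randn(2, 5, 128, 128)]   # one hourglass stack
            feature = torch.randn(2, 256, 128, 128)
            aaa = torch.tensor(0.0)                   # loss in train, boxes in val
            return outputs, feature, aaa

    M.head_size = [[2], [1], [2]]                             # assumed layout
    M.loss_weight = {"jmap": 8.0, "lmap": 0.5, "joff": 0.25}  # illustrative
    model = MultitaskLearner(_FakeBackbone())
    input_dict = {
        "image": torch.randn(2, 3, 512, 512),
        "target_b": None,
        "mode": "training",
        "target": {
            "junc_map": torch.zeros(2, 1, 128, 128),        # (N, n_jtyp, H, W)
            "junc_offset": torch.zeros(2, 1, 2, 128, 128),  # (N, n_jtyp, 2, H, W)
            "line_map": torch.zeros(2, 128, 128),           # (N, H, W)
        },
    }
    result = model(input_dict)
    # result["preds"] holds jmap/lmap/joff; result["losses"] has one dict per stack.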


def l2loss(input, target):
    return ((target - input) ** 2).mean(2).mean(1)


def cross_entropy_loss(logits, positive):
    # Two-class cross entropy with the class axis on dim 0; returns per-image loss.
    nlogp = -F.log_softmax(logits, dim=0)
    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
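

# Sanity-check sketch: cross_entropy_loss is a per-image two-class cross
# entropy with the class axis on dim 0, so it should agree with F.cross_entropy
# after moving the class axis to dim 1. `_demo_cross_entropy_equivalence` is a
# hypothetical helper, not part of the training pipeline.
def _demo_cross_entropy_equivalence():
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 8, 8)                   # (classes=2, N, H, W)
    positive = torch.randint(0, 2, (4, 8, 8)).float()  # binary junction labels
    ours = cross_entropy_loss(logits, positive)        # per-image, shape (4,)
    ref = (
        F.cross_entropy(
            logits.permute(1, 0, 2, 3),                # -> (N, classes, H, W)
            positive.long(),
            reduction="none",
        )
        .mean(2)
        .mean(1)
    )
    assert torch.allclose(ours, ref, atol=1e-6)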


def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    logp = torch.sigmoid(logits) + offset
    loss = torch.abs(logp - target)
    if mask is not None:
        # Weight each image so that masked (junction) pixels contribute equally
        # regardless of how many of them the image contains.
        w = mask.mean(2, True).mean(1, True)
        w[w == 0] = 1
        loss = loss * (mask / w)
    return loss.mean(2).mean(1)
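

# Sketch of sigmoid_l1_loss behaviour: logits are squashed to (0, 1), shifted
# by `offset` (the learner passes -0.5 so predicted offsets live in
# (-0.5, 0.5)), and an L1 distance is taken; the mask restricts the loss to
# junction pixels and renormalizes by their per-image density.
# `_demo_sigmoid_l1_loss` is a hypothetical helper.
def _demo_sigmoid_l1_loss():
    logits = torch.zeros(4, 8, 8)         # sigmoid(0) - 0.5 == 0.0
    target = torch.full((4, 8, 8), 0.25)  # ground-truth offsets
    mask = torch.zeros(4, 8, 8)
    mask[:, 2, 3] = 1.0                   # a single junction pixel per image
    loss = sigmoid_l1_loss(logits, target, offset=-0.5, mask=mask)
    # Per image: |0.0 - 0.25| at the masked pixel, renormalized by density 1/64.
    assert torch.allclose(loss, torch.full((4,), 0.25))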