RenLiqiang 3 months ago
parent
commit
a46b48e57f
41 changed files with 6004 additions and 3142 deletions
  1. +26 -0       .gitignore
  2. +77 -0       aaa.py
  3. +5 -7        config/wireframe.yaml
  4. +0 -4        lcnn/__init__.py
  5. +0 -1110     lcnn/box.py
  6. +0 -9        lcnn/config.py
  7. +0 -378      lcnn/dataset_tool.py
  8. +0 -209      lcnn/metric.py
  9. +0 -9        lcnn/models/__init__.py
  10. +0 -48      lcnn/models/base/base_dataset.py
  11. +876 -0     lcnn/models/detection/ROI_heads.py
  12. +846 -0     lcnn/models/detection/faster_rcnn.py
  13. +336 -0     lcnn/models/detection/transform.py
  14. +8 -51      lcnn/models/fasterrcnn_resnet50.py
  15. +0 -201     lcnn/models/hourglass_pose.py
  16. +0 -276     lcnn/models/line_vectorizer.py
  17. +0 -118     lcnn/models/multitask_learner.py
  18. +0 -182     lcnn/models/resnet50.py
  19. +0 -87      lcnn/models/resnet50_pose.py
  20. +0 -126     lcnn/models/unet.py
  21. +0 -77      lcnn/postprocess.py
  22. +0 -2       lcnn/trainer.py
  23. +0 -101     lcnn/utils.py
  24. +124 -0     models/base/base_detection_net.py
  25. +1 -6       models/dataset_tool.py
  26. +0 -0       models/ins_detect/__init__.py
  27. +143 -0     models/ins_detect/maskrcnn.py
  28. +93 -0      models/ins_detect/maskrcnn_dataset.py
  29. +31 -0      models/ins_detect/train.yaml
  30. +220 -0     models/ins_detect/trainer.py
  31. +312 -0     models/keypoint/keypoint_dataset.py
  32. +0 -0       models/line_detect/__init__.py
  33. +26 -141    models/line_detect/dataset_LD.py
  34. +24 -0      models/line_detect/line_head.py
  35. +912 -0     models/line_detect/line_net.py
  36. +324 -0     models/line_detect/line_predictor.py
  37. +1177 -0    models/line_detect/roi_heads.py
  38. +0 -0       models/obj_detect/__init__.py
  39. +111 -0     predict.py
  40. +18 -0      readme.md
  41. +314 -0     train——line_rcnn.py

+ 26 - 0
.gitignore

@@ -1,5 +1,31 @@
 .idea
 *.pt
+*.log
+*.onnx
 runs
+logs
+log
+
+/tensorboard/
+logs/
+tensorboard_logs/
+summaries/
+events.out.tfevents.*
+
+# If you have a specific directory for your runs, you can ignore it directly
+/runs/
+
+# Ignore checkpoint files if you don't want to track them
+checkpoint
+*.ckpt.data-*
+*.ckpt.index
+*.ckpt.meta
+
+# Ignore TensorFlow model files that are not necessary for version control
+*.pb
+/*.pbtxt
+# Ignore Jupyter Notebook checkpoints
+/.ipynb_checkpoints/
+
 __pycache__
 train_results

+ 77 - 0
aaa.py

@@ -0,0 +1,77 @@
+import torch
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def c(score):
+    # Return a color based on the score; this is only an example, adjust as needed
+    return (1, 0, 0) if score > 0.9 else (0, 1, 0)
+
+
+def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
+    # Placeholder post-processing function that filters line segments by score
+    nlines = []
+    nscores = []
+    for line, score in zip(lines, scores):
+        if score >= min_score:
+            nlines.append(line)
+            nscores.append(score)
+    return np.array(nlines), np.array(nscores)
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0).cpu().numpy()
+
+    # Draw the predicted bounding boxes
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
+
+    H = pred[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+
+    print(f"Lines before deduplication: {len(lines)}")
+
+    # Remove duplicated line segments
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    print(f"Lines after deduplication: {len(lines)}")
+
+    # Post-process the line segments to remove overlapping ones
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    print(f"Lines after postprocessing: {len(nlines)}")
+
+    # Create a new figure and draw the line segments and bounding boxes
+    fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
+    ax.imshow(boxed_image)
+    ax.set_axis_off()
+    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+    plt.margins(0, 0)
+    plt.gca().xaxis.set_major_locator(plt.NullLocator())
+    plt.gca().yaxis.set_major_locator(plt.NullLocator())
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.85:  # adjust this threshold to filter which lines are displayed
+            continue
+        plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+        plt.scatter(a[1], a[0], **PLTOPTS)
+        plt.scatter(b[1], b[0], **PLTOPTS)
+
+    plt.tight_layout()
+    fig.canvas.draw()
+    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+        fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    img2 = transforms.ToTensor()(image_from_plot)
+
+    writer.add_image("output_with_boxes_and_lines", img2, epoch)
+    print("Image with boxes and lines added to TensorBoard.")

+ 5 - 7
config/wireframe.yaml

@@ -1,12 +1,10 @@
 io:
   logdir: logs/
   datadir: D:\python\PycharmProjects\data
-#  datadir: /home/dieu/lcnn/dataset/line_data_104
   resume_from:
-#  resume_from: /home/dieu/lcnn/logs/241112-163302-175fb79-my_data_104_resume
-  num_workers: 0
-  tensorboard_port: 0
-  validation_interval: 300    # evaluation interval
+  num_workers: 8
+  tensorboard_port: 6000
+  validation_interval: 300
 
 model:
   image:
@@ -17,14 +15,14 @@ model:
   batch_size_eval: 2
 
   # backbone multi-task parameters
-  head_size: [[2], [1], [2],[4]]
+  head_size: [[2], [1], [2]]
   loss_weight:
     jmap: 8.0
     lmap: 0.5
     joff: 0.25
     lpos: 1
     lneg: 1
-    boxes: 1.0  # weight for the newly added box loss
+    boxes: 1.0
 
   # backbone parameters
   backbone: fasterrcnn_resnet50
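Since the file is plain YAML, the edited values can be read back with PyYAML; a minimal sketch, assuming the script runs from the repository root:

import yaml

with open("config/wireframe.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["io"]["num_workers"])              # 8 after this change
print(cfg["model"]["head_size"])             # [[2], [1], [2]]
print(cfg["model"]["loss_weight"]["boxes"])  # 1.0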

+ 0 - 4
lcnn/__init__.py

@@ -1,4 +0,0 @@
-import lcnn.models
-import lcnn.trainer
-import lcnn.datasets
-import lcnn.config

+ 0 - 1110
lcnn/box.py

@@ -1,1110 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: UTF-8 -*-
-#
-# Copyright (c) 2017-2019 - Chris Griffith - MIT License
-"""
-Improved dictionary access through dot notation with additional tools.
-"""
-import string
-import sys
-import json
-import re
-import copy
-from keyword import kwlist
-import warnings
-
-try:
-    from collections.abc import Iterable, Mapping, Callable
-except ImportError:
-    from collections import Iterable, Mapping, Callable
-
-yaml_support = True
-
-try:
-    import yaml
-except ImportError:
-    try:
-        import ruamel.yaml as yaml
-    except ImportError:
-        yaml = None
-        yaml_support = False
-
-if sys.version_info >= (3, 0):
-    basestring = str
-else:
-    from io import open
-
-__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox',
-           'BoxError', 'BoxKeyError']
-__author__ = 'Chris Griffith'
-__version__ = '3.2.4'
-
-BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box',
-                  'frozen_box', 'camel_killer_box', 'box_it_up',
-                  'box_safe_prefix', 'box_duplicates', 'ordered_box')
-
-_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
-_all_cap_re = re.compile('([a-z0-9])([A-Z])')
-
-
-class BoxError(Exception):
-    """Non standard dictionary exceptions"""
-
-
-class BoxKeyError(BoxError, KeyError, AttributeError):
-    """Key does not exist"""
-
-
-# Abstract converter functions for use in any Box class
-
-
-def _to_json(obj, filename=None,
-             encoding="utf-8", errors="strict", **json_kwargs):
-    json_dump = json.dumps(obj,
-                           ensure_ascii=False, **json_kwargs)
-    if filename:
-        with open(filename, 'w', encoding=encoding, errors=errors) as f:
-            f.write(json_dump if sys.version_info >= (3, 0) else
-                    json_dump.decode("utf-8"))
-    else:
-        return json_dump
-
-
-def _from_json(json_string=None, filename=None,
-               encoding="utf-8", errors="strict", multiline=False, **kwargs):
-    if filename:
-        with open(filename, 'r', encoding=encoding, errors=errors) as f:
-            if multiline:
-                data = [json.loads(line.strip(), **kwargs) for line in f
-                        if line.strip() and not line.strip().startswith("#")]
-            else:
-                data = json.load(f, **kwargs)
-    elif json_string:
-        data = json.loads(json_string, **kwargs)
-    else:
-        raise BoxError('from_json requires a string or filename')
-    return data
-
-
-def _to_yaml(obj, filename=None, default_flow_style=False,
-             encoding="utf-8", errors="strict",
-             **yaml_kwargs):
-    if filename:
-        with open(filename, 'w',
-                  encoding=encoding, errors=errors) as f:
-            yaml.dump(obj, stream=f,
-                      default_flow_style=default_flow_style,
-                      **yaml_kwargs)
-    else:
-        return yaml.dump(obj,
-                         default_flow_style=default_flow_style,
-                         **yaml_kwargs)
-
-
-def _from_yaml(yaml_string=None, filename=None,
-               encoding="utf-8", errors="strict",
-               **kwargs):
-    if filename:
-        with open(filename, 'r',
-                  encoding=encoding, errors=errors) as f:
-            data = yaml.load(f, **kwargs)
-    elif yaml_string:
-        data = yaml.load(yaml_string, **kwargs)
-    else:
-        raise BoxError('from_yaml requires a string or filename')
-    return data
-
-
-# Helper functions
-
-
-def _safe_key(key):
-    try:
-        return str(key)
-    except UnicodeEncodeError:
-        return key.encode("utf-8", "ignore")
-
-
-def _safe_attr(attr, camel_killer=False, replacement_char='x'):
-    """Convert a key into something that is accessible as an attribute"""
-    allowed = string.ascii_letters + string.digits + '_'
-
-    attr = _safe_key(attr)
-
-    if camel_killer:
-        attr = _camel_killer(attr)
-
-    attr = attr.replace(' ', '_')
-
-    out = ''
-    for character in attr:
-        out += character if character in allowed else "_"
-    out = out.strip("_")
-
-    try:
-        int(out[0])
-    except (ValueError, IndexError):
-        pass
-    else:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    if out in kwlist:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    return re.sub('_+', '_', out)
-
-
-def _camel_killer(attr):
-    """
-    CamelKiller, qu'est-ce que c'est?
-
-    Taken from http://stackoverflow.com/a/1176023/3244542
-    """
-    try:
-        attr = str(attr)
-    except UnicodeEncodeError:
-        attr = attr.encode("utf-8", "ignore")
-
-    s1 = _first_cap_re.sub(r'\1_\2', attr)
-    s2 = _all_cap_re.sub(r'\1_\2', s1)
-    return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else
-                  s2.lower())
-
-
-def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs):
-    out_list = []
-    for i in iterable:
-        if isinstance(i, dict):
-            out_list.append(box_class(i, **kwargs))
-        elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)):
-            out_list.append(_recursive_tuples(i, box_class,
-                                              recreate_tuples, **kwargs))
-        else:
-            out_list.append(i)
-    return tuple(out_list)
-
-
-def _conversion_checks(item, keys, box_config, check_only=False,
-                       pre_check=False):
-    """
-    Internal use for checking if a duplicate safe attribute already exists
-
-    :param item: Item to see if a dup exists
-    :param keys: Keys to check against
-    :param box_config: Easier to pass in than ask for specific items
-    :param check_only: Don't bother doing the conversion work
-    :param pre_check: Need to add the item to the list of keys to check
-    :return: the original unmodified key, if exists and not check_only
-    """
-    if box_config['box_duplicates'] != 'ignore':
-        if pre_check:
-            keys = list(keys) + [item]
-
-        key_list = [(k,
-                     _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                                replacement_char=box_config['box_safe_prefix']
-                                )) for k in keys]
-        if len(key_list) > len(set(x[1] for x in key_list)):
-            seen = set()
-            dups = set()
-            for x in key_list:
-                if x[1] in seen:
-                    dups.add("{0}({1})".format(x[0], x[1]))
-                seen.add(x[1])
-            if box_config['box_duplicates'].startswith("warn"):
-                warnings.warn('Duplicate conversion attributes exist: '
-                              '{0}'.format(dups))
-            else:
-                raise BoxError('Duplicate conversion attributes exist: '
-                               '{0}'.format(dups))
-    if check_only:
-        return
-    # This way will be slower for warnings, as it will have double work
-    # But faster for the default 'ignore'
-    for k in keys:
-        if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                              replacement_char=box_config['box_safe_prefix']):
-            return k
-
-
-def _get_box_config(cls, kwargs):
-    return {
-        # Internal use only
-        '__converted': set(),
-        '__box_heritage': kwargs.pop('__box_heritage', None),
-        '__created': False,
-        '__ordered_box_values': [],
-        # Can be changed by user after box creation
-        'default_box': kwargs.pop('default_box', False),
-        'default_box_attr': kwargs.pop('default_box_attr', cls),
-        'conversion_box': kwargs.pop('conversion_box', True),
-        'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'),
-        'frozen_box': kwargs.pop('frozen_box', False),
-        'camel_killer_box': kwargs.pop('camel_killer_box', False),
-        'modify_tuples_box': kwargs.pop('modify_tuples_box', False),
-        'box_duplicates': kwargs.pop('box_duplicates', 'ignore'),
-        'ordered_box': kwargs.pop('ordered_box', False)
-    }
-
-
-class Box(dict):
-    """
-    Improved dictionary access through dot notation with additional tools.
-
-    :param default_box: Similar to defaultdict, return a default value
-    :param default_box_attr: Specify the default replacement.
-        WARNING: If this is not the default 'Box', it will not be recursive
-    :param frozen_box: After creation, the box cannot be modified
-    :param camel_killer_box: Convert CamelCase to snake_case
-    :param conversion_box: Check for near matching keys as attributes
-    :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes
-    :param box_it_up: Recursively create all Boxes from the start
-    :param box_safe_prefix: Conversion box prefix for unsafe attributes
-    :param box_duplicates: "ignore", "error" or "warn" when duplicates exists
-        in a conversion_box
-    :param ordered_box: Preserve the order of keys entered into the box
-    """
-
-    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
-                                 'from_yaml', 'from_json']
-
-    def __new__(cls, *args, **kwargs):
-        """
-        Due to the way pickling works in python 3, we need to make sure
-        the box config is created as early as possible.
-        """
-        obj = super(Box, cls).__new__(cls, *args, **kwargs)
-        obj._box_config = _get_box_config(cls, kwargs)
-        return obj
-
-    def __init__(self, *args, **kwargs):
-        self._box_config = _get_box_config(self.__class__, kwargs)
-        if self._box_config['ordered_box']:
-            self._box_config['__ordered_box_values'] = []
-        if (not self._box_config['conversion_box'] and
-                self._box_config['box_duplicates'] != "ignore"):
-            raise BoxError('box_duplicates are only for conversion_boxes')
-        if len(args) == 1:
-            if isinstance(args[0], basestring):
-                raise ValueError('Cannot extrapolate Box from string')
-            if isinstance(args[0], Mapping):
-                for k, v in args[0].items():
-                    if v is args[0]:
-                        v = self
-                    self[k] = v
-                    self.__add_ordered(k)
-            elif isinstance(args[0], Iterable):
-                for k, v in args[0]:
-                    self[k] = v
-                    self.__add_ordered(k)
-
-            else:
-                raise ValueError('First argument must be mapping or iterable')
-        elif args:
-            raise TypeError('Box expected at most 1 argument, '
-                            'got {0}'.format(len(args)))
-
-        box_it = kwargs.pop('box_it_up', False)
-        for k, v in kwargs.items():
-            if args and isinstance(args[0], Mapping) and v is args[0]:
-                v = self
-            self[k] = v
-            self.__add_ordered(k)
-
-        if (self._box_config['frozen_box'] or box_it or
-                self._box_config['box_duplicates'] != 'ignore'):
-            self.box_it_up()
-
-        self._box_config['__created'] = True
-
-    def __add_ordered(self, key):
-        if (self._box_config['ordered_box'] and
-                key not in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].append(key)
-
-    def box_it_up(self):
-        """
-        Perform value lookup for all items in current dictionary,
-        generating all sub Box objects, while also running `box_it_up` on
-        any of those sub box objects.
-        """
-        for k in self:
-            _conversion_checks(k, self.keys(), self._box_config,
-                               check_only=True)
-            if self[k] is not self and hasattr(self[k], 'box_it_up'):
-                self[k].box_it_up()
-
-    def __hash__(self):
-        if self._box_config['frozen_box']:
-            hashing = 54321
-            for item in self.items():
-                hashing ^= hash(item)
-            return hashing
-        raise TypeError("unhashable type: 'Box'")
-
-    def __dir__(self):
-        allowed = string.ascii_letters + string.digits + '_'
-        kill_camel = self._box_config['camel_killer_box']
-        items = set(dir(dict) + ['to_dict', 'to_json',
-                                 'from_json', 'box_it_up'])
-        # Only show items accessible by dot notation
-        for key in self.keys():
-            key = _safe_key(key)
-            if (' ' not in key and key[0] not in string.digits and
-                    key not in kwlist):
-                for letter in key:
-                    if letter not in allowed:
-                        break
-                else:
-                    items.add(key)
-
-        for key in self.keys():
-            key = _safe_key(key)
-            if key not in items:
-                if self._box_config['conversion_box']:
-                    key = _safe_attr(key, camel_killer=kill_camel,
-                                     replacement_char=self._box_config[
-                                         'box_safe_prefix'])
-                    if key:
-                        items.add(key)
-            if kill_camel:
-                snake_key = _camel_killer(key)
-                if snake_key:
-                    items.remove(key)
-                    items.add(snake_key)
-
-        if yaml_support:
-            items.add('to_yaml')
-            items.add('from_yaml')
-
-        return list(items)
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            if isinstance(default, dict) and not isinstance(default, Box):
-                return Box(default)
-            if isinstance(default, list) and not isinstance(default, BoxList):
-                return BoxList(default)
-            return default
-
-    def copy(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __copy__(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k, v in self.items():
-            out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict)
-        return out
-
-    def __setstate__(self, state):
-        self._box_config = state['_box_config']
-        self.__dict__.update(state)
-
-    def __getitem__(self, item, _ignore_default=False):
-        try:
-            value = super(Box, self).__getitem__(item)
-        except KeyError as err:
-            if item == '_box_config':
-                raise BoxKeyError('_box_config should only exist as an '
-                                  'attribute and is never defaulted')
-            if self._box_config['default_box'] and not _ignore_default:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            return self.__convert_and_store(item, value)
-
-    def keys(self):
-        if self._box_config['ordered_box']:
-            return self._box_config['__ordered_box_values']
-        return super(Box, self).keys()
-
-    def values(self):
-        return [self[x] for x in self.keys()]
-
-    def items(self):
-        return [(x, self[x]) for x in self.keys()]
-
-    def __get_default(self, item):
-        default_value = self._box_config['default_box_attr']
-        if default_value is self.__class__:
-            return self.__class__(__box_heritage=(self, item),
-                                  **self.__box_config())
-        elif isinstance(default_value, Callable):
-            return default_value()
-        elif hasattr(default_value, 'copy'):
-            return default_value.copy()
-        return default_value
-
-    def __box_config(self):
-        out = {}
-        for k, v in self._box_config.copy().items():
-            if not k.startswith("__"):
-                out[k] = v
-        return out
-
-    def __convert_and_store(self, item, value):
-        if item in self._box_config['__converted']:
-            return value
-        if isinstance(value, dict) and not isinstance(value, Box):
-            value = self.__class__(value, __box_heritage=(self, item),
-                                   **self.__box_config())
-            self[item] = value
-        elif isinstance(value, list) and not isinstance(value, BoxList):
-            if self._box_config['frozen_box']:
-                value = _recursive_tuples(value, self.__class__,
-                                          recreate_tuples=self._box_config[
-                                              'modify_tuples_box'],
-                                          __box_heritage=(self, item),
-                                          **self.__box_config())
-            else:
-                value = BoxList(value, __box_heritage=(self, item),
-                                box_class=self.__class__,
-                                **self.__box_config())
-            self[item] = value
-        elif (self._box_config['modify_tuples_box'] and
-              isinstance(value, tuple)):
-            value = _recursive_tuples(value, self.__class__,
-                                      recreate_tuples=True,
-                                      __box_heritage=(self, item),
-                                      **self.__box_config())
-            self[item] = value
-        self._box_config['__converted'].add(item)
-        return value
-
-    def __create_lineage(self):
-        if (self._box_config['__box_heritage'] and
-                self._box_config['__created']):
-            past, item = self._box_config['__box_heritage']
-            if not past[item]:
-                past[item] = self
-            self._box_config['__box_heritage'] = None
-
-    def __getattr__(self, item):
-        try:
-            try:
-                value = self.__getitem__(item, _ignore_default=True)
-            except KeyError:
-                value = object.__getattribute__(self, item)
-        except AttributeError as err:
-            if item == "__getstate__":
-                raise AttributeError(item)
-            if item == '_box_config':
-                raise BoxError('_box_config key must exist')
-            kill_camel = self._box_config['camel_killer_box']
-            if self._box_config['conversion_box'] and item:
-                k = _conversion_checks(item, self.keys(), self._box_config)
-                if k:
-                    return self.__getitem__(k)
-            if kill_camel:
-                for k in self.keys():
-                    if item == _camel_killer(k):
-                        return self.__getitem__(k)
-            if self._box_config['default_box']:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            if item == '_box_config':
-                return value
-            return self.__convert_and_store(item, value)
-
-    def __setitem__(self, key, value):
-        if (key != '_box_config' and self._box_config['__created'] and
-                self._box_config['frozen_box']):
-            raise BoxError('Box is frozen')
-        if self._box_config['conversion_box']:
-            _conversion_checks(key, self.keys(), self._box_config,
-                               check_only=True, pre_check=True)
-        super(Box, self).__setitem__(key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __setattr__(self, key, value):
-        if (key != '_box_config' and self._box_config['frozen_box'] and
-                self._box_config['__created']):
-            raise BoxError('Box is frozen')
-        if key in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(key))
-        if key == '_box_config':
-            return object.__setattr__(self, key, value)
-        try:
-            object.__getattribute__(self, key)
-        except (AttributeError, UnicodeEncodeError):
-            if (key not in self.keys() and
-                    (self._box_config['conversion_box'] or
-                     self._box_config['camel_killer_box'])):
-                if self._box_config['conversion_box']:
-                    k = _conversion_checks(key, self.keys(),
-                                           self._box_config)
-                    self[key if not k else k] = value
-                elif self._box_config['camel_killer_box']:
-                    for each_key in self:
-                        if key == _camel_killer(each_key):
-                            self[each_key] = value
-                            break
-            else:
-                self[key] = value
-        else:
-            object.__setattr__(self, key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __delitem__(self, key):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        super(Box, self).__delitem__(key)
-        if (self._box_config['ordered_box'] and
-                key in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(key)
-
-    def __delattr__(self, item):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        if item == '_box_config':
-            raise BoxError('"_box_config" is protected')
-        if item in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(item))
-        try:
-            object.__getattribute__(self, item)
-        except AttributeError:
-            del self[item]
-        else:
-            object.__delattr__(self, item)
-        if (self._box_config['ordered_box'] and
-                item in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(item)
-
-    def pop(self, key, *args):
-        if args:
-            if len(args) != 1:
-                raise BoxError('pop() takes only one optional'
-                               ' argument "default"')
-            try:
-                item = self[key]
-            except KeyError:
-                return args[0]
-            else:
-                del self[key]
-                return item
-        try:
-            item = self[key]
-        except KeyError:
-            raise BoxKeyError('{0}'.format(key))
-        else:
-            del self[key]
-            return item
-
-    def clear(self):
-        self._box_config['__ordered_box_values'] = []
-        super(Box, self).clear()
-
-    def popitem(self):
-        try:
-            key = next(self.__iter__())
-        except StopIteration:
-            raise BoxKeyError('Empty box')
-        return key, self.pop(key)
-
-    def __repr__(self):
-        return '<Box: {0}>'.format(str(self.to_dict()))
-
-    def __str__(self):
-        return str(self.to_dict())
-
-    def __iter__(self):
-        for key in self.keys():
-            yield key
-
-    def __reversed__(self):
-        for key in reversed(list(self.keys())):
-            yield key
-
-    def to_dict(self):
-        """
-        Turn the Box and sub Boxes back into a native
-        python dictionary.
-
-        :return: python dictionary of this Box
-        """
-        out_dict = dict(self)
-        for k, v in out_dict.items():
-            if v is self:
-                out_dict[k] = out_dict
-            elif hasattr(v, 'to_dict'):
-                out_dict[k] = v.to_dict()
-            elif hasattr(v, 'to_list'):
-                out_dict[k] = v.to_list()
-        return out_dict
-
-    def update(self, item=None, **kwargs):
-        if not item:
-            item = kwargs
-        iter_over = item.items() if hasattr(item, 'items') else item
-        for k, v in iter_over:
-            if isinstance(v, dict):
-                # Box objects must be created in case they are already
-                # in the `converted` box_config set
-                v = self.__class__(v)
-                if k in self and isinstance(self[k], dict):
-                    self[k].update(v)
-                    continue
-            if isinstance(v, list):
-                v = BoxList(v)
-            try:
-                self.__setattr__(k, v)
-            except (AttributeError, TypeError):
-                self.__setitem__(k, v)
-
-    def setdefault(self, item, default=None):
-        if item in self:
-            return self[item]
-
-        if isinstance(default, dict):
-            default = self.__class__(default)
-        if isinstance(default, list):
-            default = BoxList(default)
-        self[item] = default
-        return default
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict", **json_kwargs):
-        """
-        Transform the Box object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        return _to_json(self.to_dict(), filename=filename,
-                        encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None,
-                  encoding="utf-8", errors="strict", **kwargs):
-        """
-        Transform a json object string into a Box object. If the incoming
-        json is a list, you must use BoxList.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: Box object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename,
-                          encoding=encoding, errors=errors, **kwargs)
-
-        if not isinstance(data, dict):
-            raise BoxError('json data not returned as a dictionary, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the Box object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_dict(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader, **kwargs):
-            """
-            Transform a yaml object string into a Box object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `Box()` or `yaml.load`
-            :return: Box object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, dict):
-                raise BoxError('yaml data not returned as a dictionary'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-
-class BoxList(list):
-    """
-    Drop in replacement of list, that converts added objects to Box or BoxList
-    objects as necessary.
-    """
-
-    def __init__(self, iterable=None, box_class=Box, **box_options):
-        self.box_class = box_class
-        self.box_options = box_options
-        self.box_org_ref = self.box_org_ref = id(iterable) if iterable else 0
-        if iterable:
-            for x in iterable:
-                self.append(x)
-        if box_options.get('frozen_box'):
-            def frozen(*args, **kwargs):
-                raise BoxError('BoxList is frozen')
-
-            for method in ['append', 'extend', 'insert', 'pop',
-                           'remove', 'reverse', 'sort']:
-                self.__setattr__(method, frozen)
-
-    def __delitem__(self, key):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__delitem__(key)
-
-    def __setitem__(self, key, value):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__setitem__(key, value)
-
-    def append(self, p_object):
-        if isinstance(p_object, dict):
-            try:
-                p_object = self.box_class(p_object, **self.box_options)
-            except AttributeError as err:
-                if 'box_class' in self.__dict__:
-                    raise err
-        elif isinstance(p_object, list):
-            try:
-                p_object = (self if id(p_object) == self.box_org_ref else
-                            BoxList(p_object))
-            except AttributeError as err:
-                if 'box_org_ref' in self.__dict__:
-                    raise err
-        super(BoxList, self).append(p_object)
-
-    def extend(self, iterable):
-        for item in iterable:
-            self.append(item)
-
-    def insert(self, index, p_object):
-        if isinstance(p_object, dict):
-            p_object = self.box_class(p_object, **self.box_options)
-        elif isinstance(p_object, list):
-            p_object = (self if id(p_object) == self.box_org_ref else
-                        BoxList(p_object))
-        super(BoxList, self).insert(index, p_object)
-
-    def __repr__(self):
-        return "<BoxList: {0}>".format(self.to_list())
-
-    def __str__(self):
-        return str(self.to_list())
-
-    def __copy__(self):
-        return BoxList((x for x in self),
-                       self.box_class,
-                       **self.box_options)
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k in self:
-            out.append(copy.deepcopy(k))
-        return out
-
-    def __hash__(self):
-        if self.box_options.get('frozen_box'):
-            hashing = 98765
-            hashing ^= hash(tuple(self))
-            return hashing
-        raise TypeError("unhashable type: 'BoxList'")
-
-    def to_list(self):
-        new_list = []
-        for x in self:
-            if x is self:
-                new_list.append(new_list)
-            elif isinstance(x, Box):
-                new_list.append(x.to_dict())
-            elif isinstance(x, BoxList):
-                new_list.append(x.to_list())
-            else:
-                new_list.append(x)
-        return new_list
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict",
-                multiline=False, **json_kwargs):
-        """
-        Transform the BoxList object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: Put each item in list onto its own line
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        if filename and multiline:
-            lines = [_to_json(item, filename=False, encoding=encoding,
-                              errors=errors, **json_kwargs) for item in self]
-            with open(filename, 'w', encoding=encoding, errors=errors) as f:
-                f.write("\n".join(lines).decode('utf-8') if
-                        sys.version_info < (3, 0) else "\n".join(lines))
-        else:
-            return _to_json(self.to_list(), filename=filename,
-                            encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None, encoding="utf-8",
-                  errors="strict", multiline=False, **kwargs):
-        """
-        Transform a json object string into a BoxList object. If the incoming
-        json is a dict, you must use Box.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: One object per line
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: BoxList object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename, encoding=encoding,
-                          errors=errors, multiline=multiline, **kwargs)
-
-        if not isinstance(data, list):
-            raise BoxError('json data not returned as a list, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the BoxList object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_list(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader,
-                      **kwargs):
-            """
-            Transform a yaml object string into a BoxList object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `BoxList()` or `yaml.load`
-            :return: BoxList object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, list):
-                raise BoxError('yaml data not returned as a list'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-    def box_it_up(self):
-        for v in self:
-            if hasattr(v, 'box_it_up') and v is not self:
-                v.box_it_up()
-
-
-class ConfigBox(Box):
-    """
-    Modified box object to add object transforms.
-
-    Allows for build in transforms like:
-
-    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')
-
-    cns.bool('my_bool') # True
-    cns.int('my_int') # 5
-    cns.list('my_list', mod=lambda x: int(x)) # [5, 4, 3, 3, 2]
-    """
-
-    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
-                                 'list', 'getboolean', 'to_json', 'to_yaml',
-                                 'getfloat', 'getint',
-                                 'from_json', 'from_yaml']
-
-    def __getattr__(self, item):
-        """Config file keys are stored in lower case, be a little more
-        loosey goosey"""
-        try:
-            return super(ConfigBox, self).__getattr__(item)
-        except AttributeError:
-            return super(ConfigBox, self).__getattr__(item.lower())
-
-    def __dir__(self):
-        return super(ConfigBox, self).__dir__() + ['bool', 'int', 'float',
-                                                   'list', 'getboolean',
-                                                   'getfloat', 'getint']
-
-    def bool(self, item, default=None):
-        """ Return value of key as a boolean
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: approximated bool of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-
-        if isinstance(item, (bool, int)):
-            return bool(item)
-
-        if (isinstance(item, str) and
-                item.lower() in ('n', 'no', 'false', 'f', '0')):
-            return False
-
-        return True if item else False
-
-    def int(self, item, default=None):
-        """ Return value of key as an int
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: int of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        return int(item)
-
-    def float(self, item, default=None):
-        """ Return value of key as a float
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: float of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        return float(item)
-
-    def list(self, item, default=None, spliter=",", strip=True, mod=None):
-        """ Return value of key as a list
-
-        :param item: key of value to transform
-        :param mod: function to map against list
-        :param default: value to return if item does not exist
-        :param spliter: character to split str on
-        :param strip: clean the list with the `strip`
-        :return: list of items
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        if strip:
-            item = item.lstrip('[').rstrip(']')
-        out = [x.strip() if strip else x for x in item.split(spliter)]
-        if mod:
-            return list(map(mod, out))
-        return out
-
-    # loose configparser compatibility
-
-    def getboolean(self, item, default=None):
-        return self.bool(item, default)
-
-    def getint(self, item, default=None):
-        return self.int(item, default)
-
-    def getfloat(self, item, default=None):
-        return self.float(item, default)
-
-    def __repr__(self):
-        return '<ConfigBox: {0}>'.format(str(self.to_dict()))
-
-
-class SBox(Box):
-    """
-    ShorthandBox (SBox) allows for
-    property access of `dict` `json` and `yaml`
-    """
-    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
-                                 'json', 'yaml', 'from_yaml', 'from_json',
-                                 'dict']
-
-    @property
-    def dict(self):
-        return self.to_dict()
-
-    @property
-    def json(self):
-        return self.to_json()
-
-    if yaml_support:
-        @property
-        def yaml(self):
-            return self.to_yaml()
-
-    def __repr__(self):
-        return '<ShorthandBox: {0}>'.format(str(self.to_dict()))
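For reference, the deleted module implemented dot-notation dictionaries (python-box 3.2.4); a small usage sketch of the API removed above, valid only before this commit:

from lcnn.box import Box, ConfigBox  # no longer available after this commit

b = Box({"model": {"backbone": "fasterrcnn_resnet50"}})
print(b.model.backbone)  # dot-notation access -> "fasterrcnn_resnet50"

# ConfigBox adds typed getters, as described in its docstring above
c = ConfigBox(my_bool="yes", my_int="5", my_list="5,4,3,3,2")
print(c.bool("my_bool"), c.int("my_int"), c.list("my_list", mod=int))  # True 5 [5, 4, 3, 3, 2]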

+ 0 - 9
lcnn/config.py

@@ -1,9 +0,0 @@
-import numpy as np
-
-from lcnn.box import Box
-
-# C is a dict storing all the configuration
-C = Box()
-
-# shortcut for C.model
-M = Box()

+ 0 - 378
lcnn/dataset_tool.py

@@ -1,378 +0,0 @@
-import cv2
-import numpy as np
-import torch
-import torchvision
-from matplotlib import pyplot as plt
-import tools.transforms as reference_transforms
-from collections import defaultdict
-
-from tools import presets
-
-import json
-
-
-def get_modules(use_v2):
-    # We need a protected import to avoid the V2 warning in case just V1 is used
-    if use_v2:
-        import torchvision.transforms.v2
-        import torchvision.tv_tensors
-
-        return torchvision.transforms.v2, torchvision.tv_tensors
-    else:
-        return reference_transforms, None
-
-
-class Augmentation:
-    # Note: this transform assumes that the input to forward() are always PIL
-    # images, regardless of the backend parameter.
-    def __init__(
-            self,
-            *,
-            data_augmentation,
-            hflip_prob=0.5,
-            mean=(123.0, 117.0, 104.0),
-            backend="pil",
-            use_v2=False,
-    ):
-
-        T, tv_tensors = get_modules(use_v2)
-
-        transforms = []
-        backend = backend.lower()
-        if backend == "tv_tensor":
-            transforms.append(T.ToImage())
-        elif backend == "tensor":
-            transforms.append(T.PILToTensor())
-        elif backend != "pil":
-            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
-
-        if data_augmentation == "hflip":
-            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
-        elif data_augmentation == "lsj":
-            transforms += [
-                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
-                # TODO: FixedSizeCrop below doesn't work on tensors!
-                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "multiscale":
-            transforms += [
-                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssd":
-            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
-            transforms += [
-                T.RandomPhotometricDistort(),
-                T.RandomZoomOut(fill=fill),
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssdlite":
-            transforms += [
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        else:
-            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
-
-        if backend == "pil":
-            # Note: we could just convert to pure tensors even in v2.
-            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
-
-        transforms += [T.ToDtype(torch.float, scale=True)]
-
-        if use_v2:
-            transforms += [
-                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
-                T.SanitizeBoundingBoxes(),
-                T.ToPureTensor(),
-            ]
-
-        self.transforms = T.Compose(transforms)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-def read_polygon_points(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = f.readlines()
-
-    for line in lines:
-        parts = line.strip().split()
-        class_id = int(parts[0])
-        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # read the point coordinates
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_pixels(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = reader.readlines()
-        mask_points = []
-        for line in lines:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
-            labels.append(cls)
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-
-            for p in mask_points:
-                mask[p] = 1
-            masks.append(mask)
-    reader.close()
-    return labels, masks
-
-
-def create_masks_from_polygons(polygons, image_shape):
-    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
-    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
-    masks = []
-
-    for polygon_data, col in zip(polygons, colors):
-        mask = np.zeros(image_shape[:2], dtype=np.uint8)
-        # Convert the polygon vertices into a NumPy array
-        _, polygon = polygon_data
-        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
-
-        # Fill the polygon with OpenCV's fillPoly
-        # print(f'color:{col[:3]}')
-        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
-        mask = torch.from_numpy(mask)
-        mask[mask != 0] = 1
-        masks.append(mask)
-
-    return masks
-
-
-def read_masks_from_txt(label_path, shape):
-    polygon_points = read_polygon_points(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
-    """
-    Compute the bounding boxes around the provided masks.
-
-    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
-    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
-
-    Args:
-        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
-            and (H, W) are the spatial dimensions.
-
-    Returns:
-        Tensor[N, 4]: bounding boxes
-    """
-    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-    #     _log_api_usage_once(masks_to_boxes)
-    if masks.numel() == 0:
-        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
-
-    n = masks.shape[0]
-
-    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
-
-    for index, mask in enumerate(masks):
-        y, x = torch.where(mask != 0)
-        bounding_boxes[index, 0] = torch.min(x)
-        bounding_boxes[index, 1] = torch.min(y)
-        bounding_boxes[index, 2] = torch.max(x)
-        bounding_boxes[index, 3] = torch.max(y)
-        # debug to pixel datasets
-
-        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
-            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
-            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
-
-        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
-            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
-            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
-
-    return bounding_boxes
-
-
-def line_boxes_faster(target):
-    boxs = []
-    lpre = target["lpre"].cpu().numpy() * 4
-    vecl_target = target["lpre_label"].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
-
-    lines = lpre
-    sline = np.ones(lpre.shape[0])
-
-    if len(lines) > 0 and not (lines[0] == 0).all():
-        for i, ((a, b), s) in enumerate(zip(lines, sline)):
-            if i > 0 and (lines[i] == lines[0]).all():
-                break
-            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no fixed ordering
-
-            if a[1] > b[1]:
-                ymax = a[1] + 10
-                ymin = b[1] - 10
-            else:
-                ymin = a[1] - 10
-                ymax = b[1] + 10
-            if a[0] > b[0]:
-                xmax = a[0] + 10
-                xmin = b[0] - 10
-            else:
-                xmin = a[0] - 10
-                xmax = b[0] + 10
-            boxs.append([max(0, ymin), max(0, xmin), min(512, ymax), min(512, xmax)])
-
-    return torch.tensor(boxs)
-
-
-# Convert line segments into [start, end] form; input shape [num, 2, 2]
-def a_to_b(line):
-    result_pairs = []
-    for a, b in line:
-        min_x = min(a[0], b[0])
-        min_y = min(a[1], b[1])
-        new_top_left_x = max((min_x - 10), 0)
-        new_top_left_y = max((min_y - 10), 0)
-        dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
-        dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
-
-        # Choose the start point by distance and set the label
-        if dist_a <= dist_b:  # a is closer to (or tied with) the new top-left corner
-            result_pairs.append([a, b])  # use a as the start point, b as the end point
-        else:  # b is closer to the new top-left corner
-            result_pairs.append([b, a])  # use b as the start point, a as the end point
-    result_tensor = torch.stack([torch.stack(row) for row in result_pairs])
-    return result_tensor
-
-
-def read_polygon_points_wire(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = json.load(f)
-
-    for line in lines["segmentations"]:
-        parts = line["data"]
-        class_id = int(line["cls_id"])
-        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # 读取点坐标
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_txt_wire(label_path, shape):
-    polygon_points = read_polygon_points_wire(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def read_masks_from_pixels_wire(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = json.load(reader)
-        mask_points = []
-        for line in lines["segmentations"]:
-            # mask = torch.zeros((h, w), dtype=torch.uint8)
-            # parts = line["data"]
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
-            labels.append(cls)
-            # x_array = parts[0::2]
-            # y_array = parts[1::2]
-            # 
-            # for x, y in zip(x_array, y_array):
-            #     x = float(x)
-            #     y = float(y)
-            #     mask_points.append((int(y * h), int(x * w)))
-
-            # for p in mask_points:
-            #     mask[p] = 1
-            # masks.append(mask)
-    reader.close()
-    return labels
-
-
-def read_lable_keypoint(lbl_path):
-    """判断线段的起点终点, 起点 lable=0, 终点 lable=1"""
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = json.load(reader)
-        aa = lines["wires"][0]["line_pos_coords"]["content"]
-
-        result_pairs = []
-        for a, b in aa:
-            min_x = min(a[0], b[0])
-            min_y = min(a[1], b[1])
-
-            # Define the new top-left corner position
-            new_top_left_x = max((min_x - 10), 0)
-            new_top_left_y = max((min_y - 10), 0)
-
-            # Step 2: compute each point's squared distance to the new top-left corner (avoids floating-point error)
-            dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
-            dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
-
-            # Step 3 & 4: choose the start point by distance and set the label accordingly
-            if dist_a <= dist_b:  # a is closer to (or the same distance from) the new top-left corner
-                result_pairs.append([a, b])  # a becomes the start point, b the end point
-            else:  # b is closer to the new top-left corner
-                result_pairs.append([b, a])  # b becomes the start point, a the end point
-
-            # x_ = abs(a[0] - b[0])
-            # y_ = abs(a[1] - b[1])
-            # if x_ > y_:  # the point with the smaller x is closer to the top-left corner
-            #     if a[0] < b[0]:  # treated as the start point, label=0
-            #         label = [0, 1]
-            #     else:
-            #         label = [1, 0]
-            # else:  # the point with the larger x is the start point
-            #     if a[0] > b[0]:  # treated as the start point, label=0
-            #         label = [0, 1]
-            #     else:
-            #         label = [1, 0]
-            # labels.append(label)
-        # print(result_pairs )
-    reader.close()
-    return labels
-
-
-def adjacency_matrix(n, link):  # adjacency matrix
-    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-    link = torch.tensor(link)
-    if len(link) > 0:
-        mat[link[:, 0], link[:, 1]] = 1
-        mat[link[:, 1], link[:, 0]] = 1
-    return mat
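A toy illustration of the removed adjacency_matrix helper, inlined here so it runs standalone (the edge list is invented):

import torch

link = torch.tensor([[0, 1], [1, 2]])        # undirected junction pairs
n = 3
mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
mat[link[:, 0], link[:, 1]] = 1
mat[link[:, 1], link[:, 0]] = 1              # same symmetric fill as adjacency_matrix(n, link)
print(mat)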

+ 0 - 209
lcnn/metric.py

@@ -1,209 +0,0 @@
-import numpy as np
-import numpy.linalg as LA
-import matplotlib.pyplot as plt
-
-from lcnn.utils import argsort2d
-
-DX = [0, 0, 1, -1, 1, 1, -1, -1]
-DY = [1, -1, 0, 0, 1, -1, 1, -1]
-
-
-def ap(tp, fp):
-    recall = tp
-    precision = tp / np.maximum(tp + fp, 1e-9)
-
-    recall = np.concatenate(([0.0], recall, [1.0]))
-    precision = np.concatenate(([0.0], precision, [0.0]))
-
-    for i in range(precision.size - 1, 0, -1):
-        precision[i - 1] = max(precision[i - 1], precision[i])
-    i = np.where(recall[1:] != recall[:-1])[0]
-    return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
-
-
-def APJ(vert_pred, vert_gt, max_distance, im_ids):
-    if len(vert_pred) == 0:
-        return 0
-
-    vert_pred = np.array(vert_pred)
-    vert_gt = np.array(vert_gt)
-
-    confidence = vert_pred[:, -1]
-    idx = np.argsort(-confidence)
-    vert_pred = vert_pred[idx, :]
-    im_ids = im_ids[idx]
-    n_gt = sum(len(gt) for gt in vert_gt)
-
-    nd = len(im_ids)
-    tp, fp = np.zeros(nd, dtype=np.float), np.zeros(nd, dtype=np.float)
-    hit = [[False for _ in j] for j in vert_gt]
-
-    for i in range(nd):
-        gt_juns = vert_gt[im_ids[i]]
-        pred_juns = vert_pred[i][:-1]
-        if len(gt_juns) == 0:
-            continue
-        dists = np.linalg.norm((pred_juns[None, :] - gt_juns), axis=1)
-        choice = np.argmin(dists)
-        dist = np.min(dists)
-        if dist < max_distance and not hit[im_ids[i]][choice]:
-            tp[i] = 1
-            hit[im_ids[i]][choice] = True
-        else:
-            fp[i] = 1
-
-    tp = np.cumsum(tp) / n_gt
-    fp = np.cumsum(fp) / n_gt
-    return ap(tp, fp)
-
-
-def nms_j(heatmap, delta=1):
-    heatmap = heatmap.copy()
-    disable = np.zeros_like(heatmap, dtype=np.bool)
-    for x, y in argsort2d(heatmap):
-        for dx, dy in zip(DX, DY):
-            xp, yp = x + dx, y + dy
-            if not (0 <= xp < heatmap.shape[0] and 0 <= yp < heatmap.shape[1]):
-                continue
-            if heatmap[x, y] >= heatmap[xp, yp]:
-                disable[xp, yp] = True
-    heatmap[disable] *= 0.6
-    return heatmap
-
-
-def mAPJ(pred, truth, distances, im_ids):
-    return sum(APJ(pred, truth, d, im_ids) for d in distances) / len(distances) * 100
-
-
-def post_jheatmap(heatmap, offset=None, delta=1):
-    heatmap = nms_j(heatmap, delta=delta)
-    # only select the best 1000 junctions for efficiency
-    v0 = argsort2d(-heatmap)[:1000]
-    confidence = -np.sort(-heatmap.ravel())[:1000]
-    keep_id = np.where(confidence >= 1e-2)[0]
-    if len(keep_id) == 0:
-        return np.zeros((0, 3))
-
-    confidence = confidence[keep_id]
-    if offset is not None:
-        v0 = np.array([v + offset[:, v[0], v[1]] for v in v0])
-    v0 = v0[keep_id] + 0.5
-    v0 = np.hstack((v0, confidence[:, np.newaxis]))
-    return v0
-
-
-def vectorized_wireframe_2d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = vert_pred[sorted_confidence, :-1]
-    dpth_pred = dpth_pred[sorted_confidence]
-    d = np.sqrt(
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-
-    # staging 2: compute depth metric: SIL/L2
-    loss_L1 = loss_L2 = 0
-    hit = np.zeros_like(dpth_gt, np.bool)
-    SIL = np.zeros(dpth_pred)
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            loss_L1 += abs(dpth_gt[choice[i]] - dpth_pred[i])
-            loss_L2 += (dpth_gt[choice[i]] - dpth_pred[i]) ** 2
-            a = np.maximum(-dpth_pred[i], 1e-10)
-            b = -dpth_gt[choice[i]]
-            SIL[i] = np.log(a) - np.log(b)
-        else:
-            choice[i] = -1
-
-    n = max(np.sum(hit), 1)
-    loss_L1 /= n
-    loss_L2 /= n
-    loss_SIL = np.sum(SIL ** 2) / n - np.sum(SIL) ** 2 / (n * n)
-
-    # staging 3: compute mAP for edge matching
-    edgeset = set([frozenset(e) for e in edge_gt])
-    tp = np.zeros(len(edge_pred), dtype=np.float)
-    fp = np.zeros(len(edge_pred), dtype=np.float)
-    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=-edge_pred[2])):
-        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
-        if frozenset([choice[v0], choice[v1]]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-    return ap(tp / total_length, fp / total_length), (loss_SIL, loss_L1, loss_L2)
-
-
-def vectorized_wireframe_3d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = np.hstack([vert_pred[:, :-1], dpth_pred[:, None]])[sorted_confidence]
-    vert_gt = np.hstack([vert_gt[:, :-1], dpth_gt[:, None]])
-    d = np.sqrt(
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-    hit = np.zeros_like(dpth_gt, np.bool)
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-        else:
-            choice[i] = -1
-
-    # staging 2: compute mAP for edge matching
-    edgeset = set([frozenset(e) for e in edge_gt])
-    tp = np.zeros(len(edge_pred), dtype=np.float)
-    fp = np.zeros(len(edge_pred), dtype=np.float)
-    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=-edge_pred[2])):
-        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
-        if frozenset([choice[v0], choice[v1]]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-
-    return ap(tp / total_length, fp / total_length)
-
-
-def msTPFP(line_pred, line_gt, threshold):
-    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
-    diff = np.minimum(
-        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
-    )
-    choice = np.argmin(diff, 1)
-    dist = np.min(diff, 1)
-    hit = np.zeros(len(line_gt), np.bool)
-    tp = np.zeros(len(line_pred), np.float)
-    fp = np.zeros(len(line_pred), np.float)
-    for i in range(len(line_pred)):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            tp[i] = 1
-        else:
-            fp[i] = 1
-    return tp, fp
-
-
-def msAP(line_pred, line_gt, threshold):
-    tp, fp = msTPFP(line_pred, line_gt, threshold)
-    tp = np.cumsum(tp) / len(line_gt)
-    fp = np.cumsum(fp) / len(line_gt)
-    return ap(tp, fp)
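A usage sketch for the structural-AP metric deleted here; it assumes a checkout that still contains lcnn/metric.py and a NumPy old enough to provide the np.float/np.bool aliases the module relies on. Coordinates and the squared-distance threshold are made up.

import numpy as np
from lcnn.metric import msAP

line_gt = np.array([[[10.0, 10.0], [50.0, 10.0]]])        # one ground-truth segment, shape (N, 2, 2)
line_pred = np.array([[[11.0, 10.0], [49.0, 10.0]],       # close match (predictions assumed sorted by score)
                      [[100.0, 100.0], [120.0, 100.0]]])  # false positive
print(msAP(line_pred, line_gt, threshold=5.0))            # 1.0 for this toy case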

+ 0 - 9
lcnn/models/__init__.py

@@ -1,9 +0,0 @@
-# flake8: noqa
-from .hourglass_pose import hg
-from .unet import unet
-from .resnet50_pose import resnet50
-from .resnet50 import resnet501
-from .fasterrcnn_resnet50 import fasterrcnn_resnet50
-# from .dla import dla169, dla102x, dla102x2
-
-__all__ = ['hg', 'unet', 'resnet50', 'resnet501', 'fasterrcnn_resnet50']

+ 0 - 48
lcnn/models/base/base_dataset.py

@@ -1,48 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-from torch import nn, Tensor
-from torch.utils.data import Dataset
-from torch.utils.data.dataset import T_co
-
-from torchvision.transforms import  functional as F
-
-class BaseDataset(Dataset, ABC):
-    def __init__(self,dataset_path):
-        self.default_transform=DefaultTransform()
-        pass
-
-    def __getitem__(self, index) -> T_co:
-        pass
-
-    @abstractmethod
-    def read_target(self,item,lbl_path,extra=None):
-        pass
-
-    """显示数据集指定图片"""
-    @abstractmethod
-    def show(self,idx):
-        pass
-
-    """
-    显示数据集指定名字的图片
-    """
-
-    @abstractmethod
-    def show_img(self,img_path):
-        pass
-
-class DefaultTransform(nn.Module):
-    def forward(self, img: Tensor) -> Tensor:
-        if not isinstance(img, Tensor):
-            img = F.pil_to_tensor(img)
-        return F.convert_image_dtype(img, torch.float)
-
-    def __repr__(self) -> str:
-        return self.__class__.__name__ + "()"
-
-    def describe(self) -> str:
-        return (
-            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
-            "The images are rescaled to ``[0.0, 1.0]``."
-        )

+ 876 - 0
lcnn/models/detection/ROI_heads.py

@@ -0,0 +1,876 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[Tensor])
+        regression_targets (list[Tensor])
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
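A minimal shape sketch of the loss above. Since this file tracks the upstream implementation, the identically named torchvision function is imported so the snippet runs on its own; all tensor sizes are illustrative.

import torch
from torchvision.models.detection.roi_heads import fastrcnn_loss  # same signature as above

N, num_classes = 8, 3                                # 8 sampled proposals, 2 classes + background
class_logits = torch.randn(N, num_classes)
box_regression = torch.randn(N, num_classes * 4)     # 4 regression values per class
labels = [torch.randint(0, num_classes, (N,))]       # one image in the batch
regression_targets = [torch.randn(N, 4)]
cls_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
print(cls_loss.item(), box_loss.item())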
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for each image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
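A shape sketch for the decoding step above, again going through the identically named torchvision function so it runs standalone; the heatmap peak and RoI boxes are invented.

import torch
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

maps = torch.zeros(2, 17, 56, 56)            # 2 RoIs, 17 keypoints, 56x56 heatmaps
maps[:, :, 30, 40] = 5.0                     # peak at x=40, y=30 for every keypoint
rois = torch.tensor([[10.0, 20.0, 66.0, 76.0],
                     [0.0, 0.0, 112.0, 112.0]])
xy, scores = heatmaps_to_keypoints(maps, rois)
print(xy.shape, scores.shape)                # torch.Size([2, 17, 3]) torch.Size([2, 17])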
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
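A quick sketch of the pasting step above via the matching torchvision function; the mask probabilities and boxes are random placeholders.

import torch
from torchvision.models.detection.roi_heads import paste_masks_in_image

masks = torch.rand(3, 1, 28, 28)                         # 3 predicted 28x28 mask probabilities
boxes = torch.tensor([[10.0, 10.0, 60.0, 80.0],
                      [100.0, 40.0, 180.0, 120.0],
                      [5.0, 200.0, 55.0, 250.0]])
full = paste_masks_in_image(masks, boxes, (256, 256))    # paste back into a 256x256 image
print(full.shape)                                        # torch.Size([3, 1, 256, 256])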
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to proposals
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses
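For orientation, a sketch of the per-image dict this forward() returns at inference time. It uses the stock torchvision detector (assuming torchvision >= 0.13) purely to illustrate the output contract; the input image is random.

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None)    # randomly initialized, 91 COCO classes
model.eval()
with torch.no_grad():
    detections = model([torch.rand(3, 512, 512)])
d = detections[0]
print(d["boxes"].shape, d["labels"].shape, d["scores"].shape)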

+ 846 - 0
lcnn/models/detection/faster_rcnn.py

@@ -0,0 +1,846 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from .generalized_rcnn import GeneralizedRCNN
+from .roi_heads import RoIHeads
+from .rpn import RegionProposalNetwork, RPNHead
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "FasterRCNN",
+    "FasterRCNN_ResNet50_FPN_Weights",
+    "FasterRCNN_ResNet50_FPN_V2_Weights",
+    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
+    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
+    "fasterrcnn_resnet50_fpn",
+    "fasterrcnn_resnet50_fpn_v2",
+    "fasterrcnn_mobilenet_v3_large_fpn",
+    "fasterrcnn_mobilenet_v3_large_320_fpn",
+]
+
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class FasterRCNN(GeneralizedRCNN):
+    """
+    Implements Faster R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores for each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): Images are rescaled before feeding them to the backbone:
+            we attempt to preserve the aspect ratio and scale the shorter edge
+            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
+            then downscale so that the longer edge does not exceed ``max_size``.
+            This may result in the shorter edge being smaller than ``min_size``.
+        max_size (int): See ``min_size``.
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import FasterRCNN
+        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # FasterRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> # put the pieces together inside a FasterRCNN model
+        >>> model = FasterRCNN(backbone,
+        >>>                    num_classes=2,
+        >>>                    rpn_anchor_generator=anchor_generator,
+        >>>                    box_roi_pool=roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=512,   # originally 800
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        **kwargs,
+    ):
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+
+        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+            )
+        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+            )
+
+        if num_classes is not None:
+            if box_predictor is not None:
+                raise ValueError("num_classes should be None when box_predictor is specified")
+        else:
+            if box_predictor is None:
+                raise ValueError("num_classes should not be None when box_predictor is not specified")
+
+        out_channels = backbone.out_channels
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class FastRCNNPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 41755286,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 159.743,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 43712278,
+            "recipe": "https://github.com/pytorch/vision/pull/5763",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
+            },
+            "_ops": 280.371,
+            "_file_size": 167.104,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
+            },
+            "_ops": 4.494,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
+            },
+            "_ops": 0.719,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
+    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
+    paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> # For training
+        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
+        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
+        >>> labels = torch.randint(1, 91, (4, 11))
+        >>> images = list(image for image in images)
+        >>> targets = []
+        >>> for i in range(len(images)):
+        >>>     d = {}
+        >>>     d['boxes'] = boxes[i]
+        >>>     d['labels'] = labels[i]
+        >>>     targets.append(d)
+        >>> output = model(images, targets)
+        >>> # For inference
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
+    progress: bool,
+    num_classes: Optional[int],
+    weights_backbone: Optional[MobileNet_V3_Large_Weights],
+    trainable_backbone_layers: Optional[int],
+    **kwargs: Any,
+) -> FasterRCNN:
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
+    anchor_sizes = (
+        (
+            32,
+            64,
+            128,
+            256,
+            512,
+        ),
+    ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    model = FasterRCNN(
+        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_320_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
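
The builders above keep the upstream torchvision calling convention: in training mode the model takes images plus targets and returns a dict of losses, while in eval mode it takes images only and returns one prediction dict per image with boxes, labels and scores. A minimal sketch of both modes, shown with the upstream torchvision builder on the assumption that this vendored copy (which changes the default min_size to 512 but otherwise mirrors the signatures above) behaves the same:

    import torch
    import torchvision

    # Upstream builder used for illustration; the vendored faster_rcnn.py above
    # exposes the same functions and signatures.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(num_classes=5)

    # Training mode: images + targets in, dict of losses out.
    images = [torch.rand(3, 512, 512)]
    targets = [{"boxes": torch.tensor([[10.0, 20.0, 200.0, 240.0]]),
                "labels": torch.tensor([1])}]
    losses = model(images, targets)
    print(sorted(losses))  # loss_box_reg, loss_classifier, loss_objectness, loss_rpn_box_reg

    # Eval mode: images in, one prediction dict per image out.
    model.eval()
    with torch.no_grad():
        preds = model(images)
    print(preds[0]["boxes"].shape, preds[0]["labels"].shape, preds[0]["scores"].shape)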

+ 336 - 0
lcnn/models/detection/transform.py

@@ -0,0 +1,336 @@
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchvision
+from torch import nn, Tensor
+
+# from .image_list import ImageList
+# from .roi_heads import paste_masks_in_image
+from .ROI_heads import paste_masks_in_image
+
+class ImageList:
+    """
+    Structure that holds a list of images (of possibly
+    varying sizes) as a single tensor.
+    This works by padding the images to the same size,
+    and storing in a field the original sizes of each image
+
+    Args:
+        tensors (tensor): Tensor containing images.
+        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
+    """
+
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
+        self.tensors = tensors
+        self.image_sizes = image_sizes
+
+    def to(self, device: torch.device) -> "ImageList":
+        cast_tensor = self.tensors.to(device)
+        return ImageList(cast_tensor, self.image_sizes)
+
+
+def _get_shape_onnx(image: Tensor) -> Tensor:
+    from torch.onnx import operators
+
+    return operators.shape_as_tensor(image)[-2:]
+
+def _fake_cast_onnx(v: Tensor) -> float:
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
+
+def _resize_image_and_masks(
+    image: Tensor,
+    self_min_size: int,
+    self_max_size: int,
+    target: Optional[Dict[str, Tensor]] = None,
+    fixed_size: Optional[Tuple[int, int]] = None,
+) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    elif torch.jit.is_scripting():
+        im_shape = torch.tensor(image.shape[-2:])
+    else:
+        im_shape = image.shape[-2:]
+
+    size: Optional[List[int]] = None
+    scale_factor: Optional[float] = None
+    recompute_scale_factor: Optional[bool] = None
+    if fixed_size is not None:
+        size = [fixed_size[1], fixed_size[0]]
+    else:
+        if torch.jit.is_scripting() or torchvision._is_tracing():
+            min_size = torch.min(im_shape).to(dtype=torch.float32)
+            max_size = torch.max(im_shape).to(dtype=torch.float32)
+            self_min_size_f = float(self_min_size)
+            self_max_size_f = float(self_max_size)
+            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
+
+            if torchvision._is_tracing():
+                scale_factor = _fake_cast_onnx(scale)
+            else:
+                scale_factor = scale.item()
+
+        else:
+            # Do it the normal way
+            min_size = min(im_shape)
+            max_size = max(im_shape)
+            scale_factor = min(self_min_size / min_size, self_max_size / max_size)
+
+        recompute_scale_factor = True
+
+    image = torch.nn.functional.interpolate(
+        image[None],
+        size=size,
+        scale_factor=scale_factor,
+        mode="bilinear",
+        recompute_scale_factor=recompute_scale_factor,
+        align_corners=False,
+    )[0]
+
+    if target is None:
+        return image, target
+
+    if "masks" in target:
+        mask = target["masks"]
+        mask = torch.nn.functional.interpolate(
+            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
+        )[:, 0].byte()
+        target["masks"] = mask
+    return image, target
+
+
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=boxes.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_height, ratio_width = ratios
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+
+    xmin = xmin * ratio_width
+    xmax = xmax * ratio_width
+    ymin = ymin * ratio_height
+    ymax = ymax * ratio_height
+    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
+
+
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_h, ratio_w = ratios
+    resized_data = keypoints.clone()
+    if torch._C._get_tracing_state():
+        resized_data_0 = resized_data[:, :, 0] * ratio_w
+        resized_data_1 = resized_data[:, :, 1] * ratio_h
+        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
+    else:
+        resized_data[..., 0] *= ratio_w
+        resized_data[..., 1] *= ratio_h
+    return resized_data
+
+
+class GeneralizedRCNNTransform(nn.Module):
+    """
+    Performs input / target transformation before feeding the data to a GeneralizedRCNN
+    model.
+
+    The transformations it performs are:
+        - input normalization (mean subtraction and std division)
+        - input / target resizing to match min_size / max_size
+
+    It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
+    """
+
+    def __init__(
+        self,
+        min_size: int,
+        max_size: int,
+        image_mean: List[float],
+        image_std: List[float],
+        size_divisible: int = 32,
+        fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size_divisible = size_divisible
+        self.fixed_size = fixed_size
+        self._skip_resize = kwargs.pop("_skip_resize", False)
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
+        images = [img for img in images]
+        if targets is not None:
+            # make a copy of targets to avoid modifying it in-place
+            # once torchscript supports dict comprehension
+            # this can be simplified as follows
+            # targets = [{k: v for k,v in t.items()} for t in targets]
+            targets_copy: List[Dict[str, Tensor]] = []
+            for t in targets:
+                data: Dict[str, Tensor] = {}
+                for k, v in t.items():
+                    data[k] = v
+                targets_copy.append(data)
+            targets = targets_copy
+        for i in range(len(images)):
+            image = images[i]
+            target_index = targets[i] if targets is not None else None
+
+            if image.dim() != 3:
+                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
+            image = self.normalize(image)
+            image, target_index = self.resize(image, target_index)
+            images[i] = image
+            if targets is not None and target_index is not None:
+                targets[i] = target_index
+
+        image_sizes = [img.shape[-2:] for img in images]
+        images = self.batch_images(images, size_divisible=self.size_divisible)
+        image_sizes_list: List[Tuple[int, int]] = []
+        for image_size in image_sizes:
+            torch._assert(
+                len(image_size) == 2,
+                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
+            )
+            image_sizes_list.append((image_size[0], image_size[1]))
+
+        image_list = ImageList(images, image_sizes_list)
+        return image_list, targets
+
+    def normalize(self, image: Tensor) -> Tensor:
+        if not image.is_floating_point():
+            raise TypeError(
+                f"Expected input images to be of floating type (in range [0, 1]), "
+                f"but found type {image.dtype} instead"
+            )
+        dtype, device = image.dtype, image.device
+        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
+        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+        return (image - mean[:, None, None]) / std[:, None, None]
+
+    def torch_choice(self, k: List[int]) -> int:
+        """
+        Implements `random.choice` via torch ops, so it can be compiled with
+        TorchScript and we use PyTorch's RNG (not native RNG)
+        """
+        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
+        return k[index]
+
+    def resize(
+        self,
+        image: Tensor,
+        target: Optional[Dict[str, Tensor]] = None,
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        h, w = image.shape[-2:]
+        if self.training:
+            if self._skip_resize:
+                return image, target
+            size = self.torch_choice(self.min_size)
+        else:
+            size = self.min_size[-1]
+        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
+
+        if target is None:
+            return image, target
+
+        bbox = target["boxes"]
+        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+        target["boxes"] = bbox
+
+        if "keypoints" in target:
+            keypoints = target["keypoints"]
+            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
+            target["keypoints"] = keypoints
+        return image, target
+
+    # _onnx_batch_images() is an implementation of
+    # batch_images() that is supported by ONNX tracing.
+    @torch.jit.unused
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        max_size = []
+        for i in range(images[0].dim()):
+            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+            max_size.append(max_size_i)
+        stride = size_divisible
+        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size = tuple(max_size)
+
+        # work around for
+        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        # which is not yet supported in onnx
+        padded_imgs = []
+        for img in images:
+            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+            padded_imgs.append(padded_img)
+
+        return torch.stack(padded_imgs)
+
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+        maxes = the_list[0]
+        for sublist in the_list[1:]:
+            for index, item in enumerate(sublist):
+                maxes[index] = max(maxes[index], item)
+        return maxes
+
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        if torchvision._is_tracing():
+            # batch_images() does not export well to ONNX
+            # call _onnx_batch_images() instead
+            return self._onnx_batch_images(images, size_divisible)
+
+        max_size = self.max_by_axis([list(img.shape) for img in images])
+        stride = float(size_divisible)
+        max_size = list(max_size)
+        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
+        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
+
+        batch_shape = [len(images)] + max_size
+        batched_imgs = images[0].new_full(batch_shape, 0)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+        return batched_imgs
+
+    def postprocess(
+        self,
+        result: List[Dict[str, Tensor]],
+        image_shapes: List[Tuple[int, int]],
+        original_image_sizes: List[Tuple[int, int]],
+    ) -> List[Dict[str, Tensor]]:
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
+            boxes = pred["boxes"]
+            boxes = resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
+    def __repr__(self) -> str:
+        format_string = f"{self.__class__.__name__}("
+        _indent = "\n    "
+        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
+        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
+        format_string += "\n)"
+        return format_string
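
GeneralizedRCNNTransform normalizes each image, resizes image and target together to satisfy min_size/max_size, and pads the batch to a stride-divisible size, returning an ImageList that keeps the per-image sizes later needed by postprocess(). A short sketch of that round trip, using the upstream torchvision class on the assumption that this vendored copy behaves identically:

    import torch
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    transform = GeneralizedRCNNTransform(
        min_size=512, max_size=1333,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
    )
    transform.eval()

    images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
    image_list, _ = transform(images)
    print(image_list.tensors.shape)   # one padded batch tensor, H and W multiples of 32
    print(image_list.image_sizes)     # per-image (H, W) after resizing, before padding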

+ 8 - 51
lcnn/models/fasterrcnn_resnet50.py

@@ -1,12 +1,8 @@
 import torch
 import torch.nn as nn
 import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
+# from .detection.transform import GeneralizedRCNNTransform
 
 def get_model(num_classes):
     # Load the pretrained ResNet-50 FPN backbone
@@ -21,45 +17,6 @@ def get_model(num_classes):
     return model
 
 
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
 
 class Fasterrcnn_resnet50(nn.Module):
     def __init__(self, num_classes=5, num_stacks=1):
@@ -68,14 +25,14 @@ class Fasterrcnn_resnet50(nn.Module):
         self.model = get_model(num_classes=5)
         self.backbone = self.model.backbone
 
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+        # self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
 
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+        # out_channels = self.backbone.out_channels
+        # resolution = self.box_roi_pool.output_size[0]
+        # representation_size = 1024
+        # self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+        #
+        # self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
 
         # Multi-task output layers
         self.score_layers = nn.ModuleList([
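
The body of get_model() is mostly elided by the diff context above; what a helper like it conventionally does is load a COCO-pretrained detector and swap the box predictor for the requested number of classes. A hedged sketch of that standard torchvision pattern (get_model_sketch and its internals are illustrative, not the author's implementation):

    import torchvision
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    def get_model_sketch(num_classes):
        # Hypothetical stand-in for get_model(): pretrained ResNet-50 FPN detector
        # with its classification/regression head re-sized to num_classes.
        weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model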

+ 0 - 201
lcnn/models/hourglass_pose.py

@@ -1,201 +0,0 @@
-"""
-Hourglass network inserted in the pre-activated Resnet
-Use lr=0.01 for current version
-(c) Yichao Zhou (LCNN)
-(c) YANG, Wei
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["HourglassNet", "hg"]
-
-
-class Bottleneck2D(nn.Module):
-    expansion = 2  # expansion factor
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(Bottleneck2D, self).__init__()
-
-        self.bn1 = nn.BatchNorm2d(inplanes)
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
-        self.bn3 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        residual = x
-
-        out = self.bn1(x)
-        out = self.relu(out)
-        out = self.conv1(out)
-
-        out = self.bn2(out)
-        out = self.relu(out)
-        out = self.conv2(out)
-
-        out = self.bn3(out)
-        out = self.relu(out)
-        out = self.conv3(out)
-
-        if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-
-        return out
-
-
-class Hourglass(nn.Module):
-    def __init__(self, block, num_blocks, planes, depth):
-        super(Hourglass, self).__init__()
-        self.depth = depth
-        self.block = block
-        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
-
-    def _make_residual(self, block, num_blocks, planes):
-        layers = []
-        for i in range(0, num_blocks):
-            layers.append(block(planes * block.expansion, planes))
-        return nn.Sequential(*layers)
-
-    def _make_hour_glass(self, block, num_blocks, planes, depth):
-        hg = []
-        for i in range(depth):
-            res = []
-            for j in range(3):
-                res.append(self._make_residual(block, num_blocks, planes))
-            if i == 0:
-                res.append(self._make_residual(block, num_blocks, planes))
-            hg.append(nn.ModuleList(res))
-        return nn.ModuleList(hg)
-
-    def _hour_glass_forward(self, n, x):
-        up1 = self.hg[n - 1][0](x)
-        low1 = F.max_pool2d(x, 2, stride=2)
-        low1 = self.hg[n - 1][1](low1)
-
-        if n > 1:
-            low2 = self._hour_glass_forward(n - 1, low1)
-        else:
-            low2 = self.hg[n - 1][3](low1)
-        low3 = self.hg[n - 1][2](low2)
-        up2 = F.interpolate(low3, scale_factor=2)
-        out = up1 + up2
-        return out
-
-    def forward(self, x):
-        return self._hour_glass_forward(self.depth, x)
-
-
-class HourglassNet(nn.Module):
-    """Hourglass model from Newell et al ECCV 2016"""
-
-    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
-        super(HourglassNet, self).__init__()
-
-        self.inplanes = 64
-        self.num_feats = 128
-        self.num_stacks = num_stacks
-        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
-        self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_residual(block, self.inplanes, 1)
-        self.layer2 = self._make_residual(block, self.inplanes, 1)
-        self.layer3 = self._make_residual(block, self.num_feats, 1)
-        self.maxpool = nn.MaxPool2d(2, stride=2)
-
-        # build hourglass modules
-        ch = self.num_feats * block.expansion
-        # vpts = []
-        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
-        for i in range(num_stacks):
-            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
-            res.append(self._make_residual(block, self.num_feats, num_blocks))
-            fc.append(self._make_fc(ch, ch))
-            score.append(head(ch, num_classes))
-            # vpts.append(VptsHead(ch))
-            # vpts.append(nn.Linear(ch, 9))
-            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
-            # score[i].bias.data[0] += 4.6
-            # score[i].bias.data[2] += 4.6
-            if i < num_stacks - 1:
-                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
-                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
-        self.hg = nn.ModuleList(hg)
-        self.res = nn.ModuleList(res)
-        self.fc = nn.ModuleList(fc)
-        self.score = nn.ModuleList(score)
-        # self.vpts = nn.ModuleList(vpts)
-        self.fc_ = nn.ModuleList(fc_)
-        self.score_ = nn.ModuleList(score_)
-
-    def _make_residual(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(
-                    self.inplanes,
-                    planes * block.expansion,
-                    kernel_size=1,
-                    stride=stride,
-                )
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def _make_fc(self, inplanes, outplanes):
-        bn = nn.BatchNorm2d(inplanes)
-        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
-        return nn.Sequential(conv, bn, self.relu)
-
-    def forward(self, x):
-        out = []
-        # out_vps = []
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        x = self.layer1(x)
-        x = self.maxpool(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-
-        for i in range(self.num_stacks):
-            y = self.hg[i](x)
-            y = self.res[i](y)
-            y = self.fc[i](y)
-            score = self.score[i](y)
-            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
-            # pre_vpts = pre_vpts.reshape(-1, 256)
-            # vpts = self.vpts[i](x)
-            out.append(score)
-            # out_vps.append(vpts)
-            if i < self.num_stacks - 1:
-                fc_ = self.fc_[i](y)
-                score_ = self.score_[i](score)
-                x = x + fc_ + score_
-
-        return out[::-1], y  # , out_vps[::-1]
-
-
-def hg(**kwargs):
-    model = HourglassNet(
-        Bottleneck2D,
-        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)),
-        depth=kwargs["depth"],
-        num_stacks=kwargs["num_stacks"],
-        num_blocks=kwargs["num_blocks"],
-        num_classes=kwargs["num_classes"],
-    )
-    return model
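
For reference, the removed hg() factory was driven entirely by keyword arguments. One detail worth noting: its default head lambda references nn.Conv2D, which is not a valid torch.nn name (nn.Conv2d is presumably what was intended), so callers had to pass head explicitly. A sketch of a typical invocation, assuming a checkout prior to this commit where the module still exists:

    import torch
    import torch.nn as nn
    from lcnn.models.hourglass_pose import hg  # removed by this commit

    model = hg(
        depth=4,
        num_stacks=2,
        num_blocks=1,
        num_classes=5,
        head=lambda c_in, c_out: nn.Conv2d(c_in, c_out, kernel_size=1),
    )
    outputs, feature = model(torch.rand(1, 3, 512, 512))
    # outputs: one 5-channel score map per stack (reversed order), at 1/4 input resolution
    # feature: the 256-channel feature map from the last stack
    print(len(outputs), outputs[0].shape, feature.shape)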

+ 0 - 276
lcnn/models/line_vectorizer.py

@@ -1,276 +0,0 @@
-import itertools
-import random
-from collections import defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-FEATURE_DIM = 8
-
-
-class LineVectorizer(nn.Module):
-    def __init__(self, backbone):
-        super().__init__()
-        self.backbone = backbone
-
-        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
-
-        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
-        scale_factor = M.n_pts0 // M.n_pts1
-        if M.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(M.dim_loi, M.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, 1),
-            )
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    def forward(self, input_dict):
-        result = self.backbone(input_dict)
-        h = result["preds"]
-        x = self.fc1(result["feature"])
-        n_batch, n_channel, row, col = x.shape
-
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-        for i, meta in enumerate(input_dict["meta"]):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
-            )
-            # print("p.shape:", p.shape)
-            ys.append(label)
-            if input_dict["mode"] == "training" and self.do_static_sampling:
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                    .reshape(n_channel, -1, M.n_pts0)
-                    .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, M.n_pts1 * M.dim_loi)
-        x = torch.cat([x, f], 1)
-        x = self.fc2(x.float()).flatten()
-
-        if input_dict["mode"] != "training":
-            p = torch.cat(ps)
-            s = torch.sigmoid(x)
-            b = s > 0.5
-            lines = []
-            score = []
-            for i in range(n_batch):
-                p0 = p[idx[i]: idx[i + 1]]
-                s0 = s[idx[i]: idx[i + 1]]
-                mask = b[idx[i]: idx[i + 1]]
-                p0 = p0[mask]
-                s0 = s0[mask]
-                if len(p0) == 0:
-                    lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
-                    score.append(torch.zeros([1, M.n_out_line], device=p.device))
-                else:
-                    arg = torch.argsort(s0, descending=True)
-                    p0, s0 = p0[arg], s0[arg]
-                    lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
-                    score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
-                for j in range(len(jcs[i])):
-                    if len(jcs[i][j]) == 0:
-                        jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
-                    jcs[i][j] = jcs[i][j][
-                        None, torch.arange(M.n_out_junc) % len(jcs[i][j])
-                    ]
-            result["preds"]["lines"] = torch.cat(lines)
-            result["preds"]["score"] = torch.cat(score)
-            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-            result["box"] = result['aaa']
-            del result['aaa']
-            if len(jcs[i]) > 1:
-                result["preds"]["junts"] = torch.cat(
-                    [jcs[i][1] for i in range(n_batch)]
-                )
-
-        if input_dict["mode"] != "testing":
-            y = torch.cat(ys)
-            loss = self.loss(x, y)
-            lpos_mask, lneg_mask = y, 1 - y
-            loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-            def sum_batch(x):
-                xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
-                return torch.cat(xs)
-
-            lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-            lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-            result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
-            result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
-
-        if input_dict["mode"] == "training":
-            for i in result["aaa"].keys():
-                result["losses"][0][i] = result["aaa"][i]
-            del result["preds"]
-
-        return result
-
-    def sample_lines(self, meta, jmap, joff, mode):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = M.n_dyn_junc // n_type
-            N = len(junc)
-            if mode != "training":
-                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            if mode == "training":
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > M.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > M.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            feat = torch.cat(
-                [
-                    xyu / 128 * M.use_cood,
-                    xyv / 128 * M.use_cood,
-                    u2v * M.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-class Bottleneck1D(nn.Module):
-    def __init__(self, inplanes, outplanes):
-        super(Bottleneck1D, self).__init__()
-
-        planes = outplanes // 2
-        self.op = nn.Sequential(
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, planes, kernel_size=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, outplanes, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.op(x)
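For reference, the `non_maximum_suppression` helper removed above keeps only the local maxima of a junction heatmap by comparing it against a 3x3 max-pooled copy. A minimal, self-contained sketch of the same trick (the tensor shape is illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

def heatmap_nms(heatmap: torch.Tensor) -> torch.Tensor:
    # keep a pixel only if it equals the maximum of its 3x3 neighbourhood
    pooled = F.max_pool2d(heatmap, kernel_size=3, stride=1, padding=1)
    return heatmap * (heatmap == pooled).float()

jmap = torch.rand(1, 1, 128, 128)   # hypothetical junction heatmap
print((heatmap_nms(jmap) > 0).sum().item(), "surviving peaks")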

+ 0 - 118
lcnn/models/multitask_learner.py

@@ -1,118 +0,0 @@
-from collections import OrderedDict, defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-
-class MultitaskHead(nn.Module):
-    def __init__(self, input_channels, num_class):
-        super(MultitaskHead, self).__init__()
-        # print("输入的维度是:", input_channels)
-        m = int(input_channels / 4)
-        heads = []
-        for output_channels in sum(M.head_size, []):
-            heads.append(
-                nn.Sequential(
-                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
-                    nn.ReLU(inplace=True),
-                    nn.Conv2d(m, output_channels, kernel_size=1),
-                )
-            )
-        self.heads = nn.ModuleList(heads)
-        assert num_class == sum(sum(M.head_size, []))
-
-    def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=1)
-
-
-class MultitaskLearner(nn.Module):
-    def __init__(self, backbone):
-        super(MultitaskLearner, self).__init__()
-        self.backbone = backbone
-        head_size = M.head_size
-        self.num_class = sum(sum(head_size, []))
-        self.head_off = np.cumsum([sum(h) for h in head_size])
-
-    def forward(self, input_dict):
-        image = input_dict["image"]
-        target_b = input_dict["target_b"]
-        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # in training mode aaa carries the detection losses, in validation mode it carries the boxes
-
-        result = {"feature": feature}
-        batch, channel, row, col = outputs[0].shape
-        # print(f"batch:{batch}")
-        # print(f"channel:{channel}")
-        # print(f"row:{row}")
-        # print(f"col:{col}")
-
-        T = input_dict["target"].copy()
-        n_jtyp = T["junc_map"].shape[1]
-
-        # switch to CNHW
-        for task in ["junc_map"]:
-            T[task] = T[task].permute(1, 0, 2, 3)
-        for task in ["junc_offset"]:
-            T[task] = T[task].permute(1, 2, 0, 3, 4)
-
-        offset = self.head_off
-        loss_weight = M.loss_weight
-        losses = []
-        for stack, output in enumerate(outputs):
-            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
-            # print(f"lmap:{lmap.shape}")
-            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-            if stack == 0:
-                result["preds"] = {
-                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                    "lmap": lmap.sigmoid(),
-                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-                }
-                if input_dict["mode"] == "testing":
-                    return result
-
-            L = OrderedDict()
-            L["jmap"] = sum(
-                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-            )
-            L["lmap"] = (
-                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-                    .mean(2)
-                    .mean(1)
-            )
-            L["joff"] = sum(
-                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-                for i in range(n_jtyp)
-                for j in range(2)
-            )
-            for loss_name in L:
-                L[loss_name].mul_(loss_weight[loss_name])
-            losses.append(L)
-        result["losses"] = losses
-        result["aaa"] = aaa
-        return result
-
-
-def l2loss(input, target):
-    return ((target - input) ** 2).mean(2).mean(1)
-
-
-def cross_entropy_loss(logits, positive):
-    nlogp = -F.log_softmax(logits, dim=0)
-    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
-
-
-def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
-    logp = torch.sigmoid(logits) + offset
-    loss = torch.abs(logp - target)
-    if mask is not None:
-        w = mask.mean(2, True).mean(1, True)
-        w[w == 0] = 1
-        loss = loss * (mask / w)
-
-    return loss.mean(2).mean(1)
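The deleted MultitaskLearner slices each stack's output channels according to `M.head_size` using cumulative offsets. A small sketch under an assumed head configuration (two junction-map channels, one line-map channel, two offset channels; the real values come from the config):

import numpy as np

head_size = [[2], [1], [2]]            # assumed: jmap, lmap, joff channel counts
num_class = sum(sum(head_size, []))    # 5 output channels in total
head_off = np.cumsum([sum(h) for h in head_size])
print(num_class, head_off)             # 5 [2 3 5] -> slices [0:2], [2:3], [3:5]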

+ 0 - 182
lcnn/models/resnet50.py

@@ -1,182 +0,0 @@
-# import torch
-# import torch.nn as nn
-# import torchvision.models as models
-#
-#
-# class ResNet50Backbone(nn.Module):
-#     def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-#         super(ResNet50Backbone, self).__init__()
-#
-#         # load a pretrained ResNet-50
-#         if pretrained:
-#             resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-#         else:
-#             resnet = models.resnet50(weights=None)
-#
-#         # remove the final fully connected layer
-#         self.backbone = nn.Sequential(
-#             resnet.conv1,    # resolution drops to 1/2, channels go from 3 to 64
-#             resnet.bn1,
-#             resnet.relu,
-#             resnet.maxpool,  # resolution drops to 1/4, channels stay at 64
-#             resnet.layer1,   # stride 1, resolution stays at 1/4, channels go from 64 to 64*4=256
-#             # resnet.layer2, # stride 2, resolution drops to 1/8, channels go from 256 to 128*4=512
-#             # resnet.layer3, # stride 2, resolution drops to 1/16, channels go from 512 to 256*4=1024
-#             # resnet.layer4, # stride 2, resolution drops to 1/32, channels go from 1024 to 512*4=2048
-#         )
-#
-#         # multi-task output heads
-#         self.score_layers = nn.ModuleList([
-#             nn.Sequential(
-#                 nn.Conv2d(256, 128, kernel_size=3, padding=1),
-#                 nn.BatchNorm2d(128),
-#                 nn.ReLU(inplace=True),
-#                 nn.Conv2d(128, num_classes, kernel_size=1)
-#             )
-#             for _ in range(num_stacks)
-#         ])
-#
-#         # upsampling layer to keep the output at 128x128
-#         self.upsample = nn.Upsample(
-#             scale_factor=0.25,
-#             mode='bilinear',
-#             align_corners=True
-#         )
-#
-#     def forward(self, x):
-#         # backbone feature extraction
-#         x = self.backbone(x)
-#
-#         # # adjust the channel count
-#         # x = self.channel_adjust(x)
-#         #
-#         # # upsample to 128x128
-#         # x = self.upsample(x)
-#
-#         # multi-stack outputs
-#         outputs = []
-#         for score_layer in self.score_layers:
-#             output = score_layer(x)
-#             outputs.append(output)
-#
-#         # return the first output (if there are multiple stacks)
-#         return outputs, x
-#
-#
-# def resnet50(**kwargs):
-#     model = ResNet50Backbone(
-#         num_classes=kwargs.get("num_classes", 5),
-#         num_stacks=kwargs.get("num_stacks", 1),
-#         pretrained=kwargs.get("pretrained", True)
-#     )
-#     return model
-#
-#
-# __all__ = ["ResNet50Backbone", "resnet50"]
-#
-#
-# # test the network output
-# model = resnet50(num_classes=5, num_stacks=1)
-#
-#
-# # method 1: pass an image tensor directly
-# x = torch.randn(2, 3, 512, 512)
-# outputs, feature = model(x)
-# print("Outputs length:", len(outputs))
-# print("Output[0] shape:", outputs[0].shape)
-# print("Feature shape:", feature.shape)
-
-import torch
-import torch.nn as nn
-import torchvision
-import torchvision.models as models
-from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
-
-
-class ResNet50Backbone1(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-        super(ResNet50Backbone1, self).__init__()
-
-        # load a pretrained ResNet-50
-        if pretrained:
-            # self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-            trainable_backbone_layers = _validate_trainable_layers(True, None, 5, 3)
-            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-            self.resnet = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-        else:
-            self.resnet = models.resnet50(weights=None)
-
-    def forward(self, x):
-        features = self.resnet(x)
-
-        # if self.training:
-        #     proposals, matched_idxs, labels, regression_targets = self.select_training_samples20(proposals, targets)
-        # else:
-        #     labels = None
-        #     regression_targets = None
-        #     matched_idxs = None
-        # box_features = self.box_roi_pool(features, proposals, image_shapes)  # map each ROI onto a fixed-size feature representation
-        # # print(f"box_features:{box_features.shape}")  # [2048, 256, 7, 7] proposals pooled to a common size
-        # box_features = self.box_head(box_features)  #
-        # # print(f"box_features:{box_features.shape}")  # [N, 1024] feature vectors after the box head
-        # class_logits, box_regression = self.box_predictor(box_features)
-        #
-        # result: List[Dict[str, torch.Tensor]] = []
-        # losses = {}
-        # if self.training:
-        #     if labels is None:
-        #         raise ValueError("labels cannot be None")
-        #     if regression_targets is None:
-        #         raise ValueError("regression_targets cannot be None")
-        #     loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-        #     losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
-        # else:
-        #     boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-        #     num_images = len(boxes)
-        #     for i in range(num_images):
-        #         result.append(
-        #             {
-        #                 "boxes": boxes[i],
-        #                 "labels": labels[i],
-        #                 "scores": scores[i],
-        #             }
-        #         )
-        #     print(f"boxes:{boxes[0].shape}")
-
-        return x['0']
-
-        # # multi-stack outputs
-        # outputs = []
-        # for score_layer in self.score_layers:
-        #     output = score_layer(x)
-        #     outputs.append(output)
-        #
-        # # return the first output (if there are multiple stacks)
-        # return outputs, x
-
-
-def resnet501(**kwargs):
-    model = ResNet50Backbone1(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1),
-        pretrained=kwargs.get("pretrained", True)
-    )
-    return model
-
-
-# __all__ = ["ResNet50Backbone1", "resnet501"]
-#
-# # test the network output
-# model = resnet501(num_classes=5, num_stacks=1)
-#
-# # method 1: pass an image tensor directly
-# x = torch.randn(2, 3, 512, 512)
-# # outputs, feature = model(x)
-# # print("Outputs length:", len(outputs))
-# # print("Output[0] shape:", outputs[0].shape)
-# # print("Feature shape:", feature.shape)
-# feature = model(x)
-# print("Feature shape:", feature.keys())
-
-
-
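`_resnet_fpn_extractor` wraps the ResNet in a feature pyramid and returns an ordered dict of multi-scale maps keyed '0'..'3' plus 'pool', which is presumably why the deleted forward above indexes a single level (it likely intended `features['0']` rather than `x['0']`). A quick shape check that needs no pretrained weights:

import torch
from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor

backbone = _resnet_fpn_extractor(resnet50(weights=None), trainable_layers=3)
feats = backbone(torch.randn(1, 3, 512, 512))
print({k: tuple(v.shape) for k, v in feats.items()})
# expected: '0' -> (1, 256, 128, 128) down to 'pool' -> (1, 256, 8, 8)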

+ 0 - 87
lcnn/models/resnet50_pose.py

@@ -1,87 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision.models as models
-
-
-class ResNet50Backbone(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-        super(ResNet50Backbone, self).__init__()
-
-        # load a pretrained ResNet-50
-        if pretrained:
-            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-        else:
-            resnet = models.resnet50(weights=None)
-
-        # remove the final fully connected layer
-        self.backbone = nn.Sequential(
-            resnet.conv1,    # resolution drops to 1/2, channels go from 3 to 64
-            resnet.bn1,
-            resnet.relu,
-            resnet.maxpool,  # resolution drops to 1/4, channels stay at 64
-            resnet.layer1,   # stride 1, resolution stays at 1/4, channels go from 64 to 64*4=256
-            # resnet.layer2, # stride 2, resolution drops to 1/8, channels go from 256 to 128*4=512
-            # resnet.layer3, # stride 2, resolution drops to 1/16, channels go from 512 to 256*4=1024
-            # resnet.layer4, # stride 2, resolution drops to 1/32, channels go from 1024 to 512*4=2048
-        )
-
-        # multi-task output heads
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-        # upsampling layer to keep the output at 128x128
-        self.upsample = nn.Upsample(
-            scale_factor=0.25,
-            mode='bilinear',
-            align_corners=True
-        )
-
-    def forward(self, x):
-        # backbone feature extraction
-        x = self.backbone(x)
-
-        # # adjust the channel count
-        # x = self.channel_adjust(x)
-        #
-        # # upsample to 128x128
-        # x = self.upsample(x)
-
-        # multi-stack outputs
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(x)
-            outputs.append(output)
-
-        # return the first output (if there are multiple stacks)
-        return outputs, x
-
-
-def resnet50(**kwargs):
-    model = ResNet50Backbone(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1),
-        pretrained=kwargs.get("pretrained", True)
-    )
-    return model
-
-
-__all__ = ["ResNet50Backbone", "resnet50"]
-
-
-# test the network output
-model = resnet50(num_classes=5, num_stacks=1)
-
-
-# method 1: pass an image tensor directly
-x = torch.randn(2, 3, 512, 512)
-outputs, feature = model(x)
-print("Outputs length:", len(outputs))
-print("Output[0] shape:", outputs[0].shape)
-print("Feature shape:", feature.shape)

+ 0 - 126
lcnn/models/unet.py

@@ -1,126 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["UNetWithMultipleStacks", "unet"]
-
-
-class DoubleConv(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-        self.double_conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True)
-        )
-
-    def forward(self, x):
-        return self.double_conv(x)
-
-
-class UNetWithMultipleStacks(nn.Module):
-    def __init__(self, num_classes, num_stacks=2, base_channels=64):
-        super().__init__()
-        self.num_stacks = num_stacks
-        # encoder
-        self.enc1 = DoubleConv(3, base_channels)
-        self.pool1 = nn.MaxPool2d(2)
-        self.enc2 = DoubleConv(base_channels, base_channels * 2)
-        self.pool2 = nn.MaxPool2d(2)
-        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
-        self.pool3 = nn.MaxPool2d(2)
-        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
-        self.pool4 = nn.MaxPool2d(2)
-        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
-        self.pool5 = nn.MaxPool2d(2)
-
-        # bottleneck
-        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
-
-        # decoder
-        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
-        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
-        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
-        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
-        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
-        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
-        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
-        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
-        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
-        self.dec1 = DoubleConv(base_channels * 2, base_channels)
-
-        # extra upsampling block: from 512 down to 128
-        self.final_upsample = nn.Sequential(
-            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(base_channels),
-            nn.ReLU(inplace=True),
-            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
-        )
-
-        # score_layers adjusted to match 256 channels
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)
-
-    def forward(self, x):
-        # encoder path
-        enc1 = self.enc1(x)
-        enc2 = self.enc2(self.pool1(enc1))
-        enc3 = self.enc3(self.pool2(enc2))
-        enc4 = self.enc4(self.pool3(enc3))
-        enc5 = self.enc5(self.pool4(enc4))
-
-        # bottleneck
-        bottleneck = self.bottleneck(self.pool5(enc5))
-
-        # decoder path
-        dec5 = self.upconv5(bottleneck)
-        dec5 = torch.cat([dec5, enc5], dim=1)
-        dec5 = self.dec5(dec5)
-
-        dec4 = self.upconv4(dec5)
-        dec4 = torch.cat([dec4, enc4], dim=1)
-        dec4 = self.dec4(dec4)
-
-        dec3 = self.upconv3(dec4)
-        dec3 = torch.cat([dec3, enc3], dim=1)
-        dec3 = self.dec3(dec3)
-
-        dec2 = self.upconv2(dec3)
-        dec2 = torch.cat([dec2, enc2], dim=1)
-        dec2 = self.dec2(dec2)
-
-        dec1 = self.upconv1(dec2)
-        dec1 = torch.cat([dec1, enc1], dim=1)
-        dec1 = self.dec1(dec1)
-
-        # extra upsampling so the output size is 128
-        dec1 = self.final_upsample(dec1)
-        # adjust the channel count
-        dec1 = self.channel_adjust(dec1)
-        # multi-stack outputs
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(dec1)
-            outputs.append(output)
-
-        return outputs[::-1], dec1
-
-
-def unet(**kwargs):
-    model = UNetWithMultipleStacks(
-        num_classes=kwargs["num_classes"],
-        num_stacks=kwargs.get("num_stacks", 2),
-        base_channels=kwargs.get("base_channels", 64)
-    )
-    return model
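The decoder above leans on `nn.Upsample(scale_factor=0.25)` to bring the full-resolution decoder output down to the 128x128 heatmap grid; a one-line shape check:

import torch
import torch.nn as nn

down = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
print(down(torch.randn(1, 64, 512, 512)).shape)   # torch.Size([1, 64, 128, 128])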

+ 0 - 77
lcnn/postprocess.py

@@ -1,77 +0,0 @@
-import numpy as np
-
-
-def pline(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
-    dx = x1 + u * px - x
-    dy = y1 + u * py - y
-    return dx * dx + dy * dy
-
-
-def psegment(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
-    dx = x1 + u * px - x
-    dy = y1 + u * py - y
-    return dx * dx + dy * dy
-
-
-def plambda(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
-
-
-def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
-    nlines, nscores = [], []
-    for (p, q), score in zip(lines, scores):
-        start, end = 0, 1
-        for a, b in nlines:
-            if (
-                min(
-                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
-                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
-                )
-                > threshold ** 2
-            ):
-                continue
-            lambda_a = plambda(*p, *q, *a)
-            lambda_b = plambda(*p, *q, *b)
-            if lambda_a > lambda_b:
-                lambda_a, lambda_b = lambda_b, lambda_a
-            lambda_a -= tol
-            lambda_b += tol
-
-            # case 1: skip (if not do_clip)
-            if start < lambda_a and lambda_b < end:
-                continue
-
-            # not intersect
-            if lambda_b < start or lambda_a > end:
-                continue
-
-            # cover
-            if lambda_a <= start and end <= lambda_b:
-                start = 10
-                break
-
-            # case 2 & 3:
-            if lambda_a <= start and start <= lambda_b:
-                start = lambda_b
-            if lambda_a <= end and end <= lambda_b:
-                end = lambda_a
-
-            if start >= end:
-                break
-
-        if start >= end:
-            continue
-        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
-        nscores.append(score)
-    return np.array(nlines), np.array(nscores)
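A hedged usage sketch of the removed `postprocess`: with a small threshold and `tol=0`, a segment fully covered by an already accepted, nearly collinear segment is suppressed. It assumes the pre-removal `lcnn.postprocess` module is still importable, so treat it as illustration only:

import numpy as np
from lcnn.postprocess import postprocess   # pre-removal module, for illustration

lines = np.array([[[0.0, 0.0], [10.0, 0.0]],
                  [[1.0, 0.1], [9.0, 0.1]]])   # second segment lies inside the first
scores = np.array([0.9, 0.8])
nlines, nscores = postprocess(lines, scores, threshold=2, tol=0, do_clip=False)
print(len(nlines), nscores)                    # 1 [0.9] -- the covered duplicate is dropped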

+ 0 - 2
lcnn/trainer.py

@@ -192,8 +192,6 @@ class Trainer(object):
                 # print(result.keys())
                 self.show_line(image[0], result, self.epoch, self.writer)
 
-
-
                 # H = result["preds"]
                 # for i in range(H["jmap"].shape[0]):
                 #     index = batch_idx * M.batch_size_eval + i

+ 0 - 101
lcnn/utils.py

@@ -1,101 +0,0 @@
-import math
-import os.path as osp
-import multiprocessing
-from timeit import default_timer as timer
-
-import numpy as np
-import torch
-import matplotlib.pyplot as plt
-
-
-class benchmark(object):
-    def __init__(self, msg, enable=True, fmt="%0.3g"):
-        self.msg = msg
-        self.fmt = fmt
-        self.enable = enable
-
-    def __enter__(self):
-        if self.enable:
-            self.start = timer()
-        return self
-
-    def __exit__(self, *args):
-        if self.enable:
-            t = timer() - self.start
-            print(("%s : " + self.fmt + " seconds") % (self.msg, t))
-            self.time = t
-
-
-def quiver(x, y, ax):
-    ax.set_xlim(0, x.shape[1])
-    ax.set_ylim(x.shape[0], 0)
-    ax.quiver(
-        x,
-        y,
-        units="xy",
-        angles="xy",
-        scale_units="xy",
-        scale=1,
-        minlength=0.01,
-        width=0.1,
-        color="b",
-    )
-
-
-def recursive_to(input, device):
-    if isinstance(input, torch.Tensor):
-        return input.to(device)
-    if isinstance(input, dict):
-        for name in input:
-            if isinstance(input[name], torch.Tensor):
-                input[name] = input[name].to(device)
-        return input
-    if isinstance(input, list):
-        for i, item in enumerate(input):
-            input[i] = recursive_to(item, device)
-        return input
-    assert False
-
-
-def np_softmax(x, axis=0):
-    """Compute softmax values for each sets of scores in x."""
-    e_x = np.exp(x - np.max(x))
-    return e_x / e_x.sum(axis=axis, keepdims=True)
-
-
-def argsort2d(arr):
-    return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
-
-
-def __parallel_handle(f, q_in, q_out):
-    while True:
-        i, x = q_in.get()
-        if i is None:
-            break
-        q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
-    if nprocs == 0:
-        nprocs = multiprocessing.cpu_count()
-    q_in = multiprocessing.Queue(1)
-    q_out = multiprocessing.Queue()
-
-    proc = [
-        multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
-        for _ in range(nprocs)
-    ]
-    for p in proc:
-        p.daemon = True
-        p.start()
-
-    try:
-        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
-        [q_in.put((None, None)) for _ in range(nprocs)]
-        res = [q_out.get() for _ in progress_bar(range(len(sent)))]
-        [p.join() for p in proc]
-    except KeyboardInterrupt:
-        q_in.close()
-        q_out.close()
-        raise
-    return [x for i, x in sorted(res)]

+ 124 - 0
models/base/base_detection_net.py

@@ -0,0 +1,124 @@
+"""
+Implements the Generalized R-CNN framework
+"""
+
+import warnings
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+
+from libs.vision_libs.utils import _log_api_usage_once
+
+
+class BaseDetectionNet(nn.Module):
+    """
+    Main class for Generalized R-CNN.
+
+    Args:
+        backbone (nn.Module):
+        rpn (nn.Module):
+        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
+            detections / masks from it.
+        transform (nn.Module): performs the data transformation from the inputs to feed into
+            the model
+    """
+
+    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        self.transform = transform
+        self.backbone = backbone
+        self.rpn = rpn
+        self.roi_heads = roi_heads
+        # used only on torchscript mode
+        self._has_warned = False
+
+    @torch.jit.unused
+    def eager_outputs(self, losses, detections):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        if self.training:
+            return losses
+
+        return detections
+
+    def forward(self, images, targets=None):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        """
+        Args:
+            images (list[Tensor]): images to be processed
+            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns a list[BoxList] that contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+        """
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    if isinstance(boxes, torch.Tensor):
+                        torch._assert(
+                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                        )
+                    else:
+                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes
+        # TODO: Move this to a function
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    # print the first degenerate box
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        "All bounding boxes should have positive height and width."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        features = self.backbone(images.tensors)
+
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        proposals, proposal_losses = self.rpn(images, features, targets)
+
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
+        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
+
+        # ->multi task head
+        # ->learner,->vectorize
+
+
+        losses = {}
+        losses.update(detector_losses)
+        losses.update(proposal_losses)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        else:
+            return self.eager_outputs(losses, detections)
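BaseDetectionNet keeps torchvision's GeneralizedRCNN contract: a dict of loss tensors in training mode, a list of per-image detection dicts in eval mode. A sketch of that contract using a stock torchvision detector as a stand-in (this repo wires its own backbone/RPN/ROI heads into the class):

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None, weights_backbone=None, num_classes=3)
images = [torch.rand(3, 256, 256)]
targets = [{"boxes": torch.tensor([[10.0, 20.0, 80.0, 90.0]]),
            "labels": torch.tensor([1])}]

model.train()
print(model(images, targets).keys())   # loss_classifier, loss_box_reg, loss_objectness, ...

model.eval()
with torch.no_grad():
    print(model(images)[0].keys())     # dict_keys(['boxes', 'labels', 'scores'])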

+ 1 - 6
models/dataset_tool.py

@@ -224,17 +224,12 @@ def line_boxes(target):
     lines = lpre
     sline = np.ones(lpre.shape[0])
 
-    keypoints = []
-
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b), s) in enumerate(zip(lines, sline)):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
 
-            keypoints.append([a[0], b[0]])
-            keypoints.append([a[1], b[1]])
-
             if a[1] > b[1]:
                 ymax = a[1] + 10
                 ymin = b[1] - 10
@@ -249,7 +244,7 @@ def line_boxes(target):
                 xmax = b[0] + 10
             boxs.append([ymin, xmin, ymax, xmax])
 
-    return torch.tensor(boxs), torch.tensor(keypoints)
+    return torch.tensor(boxs)
 
 
 def read_polygon_points_wire(lbl_path, shape):

+ 0 - 0
lcnn/models/base/__init__.py → models/ins_detect/__init__.py


+ 143 - 0
models/ins_detect/maskrcnn.py

@@ -0,0 +1,143 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.ins_detect.trainer import train_cfg
+from tools import utils
+
+
+class MaskRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=0, transforms=None):
+        super(MaskRCNNModel, self).__init__()
+        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+        if transforms is None:
+            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        if num_classes != 0:
+            self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes=parameters['num_classes']
+        # print(f'num_classes:{num_classes}')
+        self.set_num_classes(num_classes)
+        train_cfg(self.__model, cfg)
+
+    def set_num_classes(self, num_classes):
+        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+        hidden_layer = 256
+        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
+                                                                  num_classes=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def predict(self, src, show_box=True, show_mask=True):
+        self.__model.eval()
+
+        img = read_image(src)
+        img = self.transforms(img)
+        img = img.to(self.device)
+        result = self.__model([img])
+        print(f'result:{result}')
+        masks = result[0]['masks']
+        boxes = result[0]['boxes']
+        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
+        boxes = boxes.cpu().detach()
+        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
+        print(f'drawn_boxes:{drawn_boxes.shape}')
+        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
+        # boxed_img=cv2.resize(boxed_img,(800,800))
+        # cv2.imshow('boxes',boxed_img)
+
+        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
+
+        mask = cv2.resize(mask, (800, 800))
+        # cv2.imshow('mask',mask)
+        img = img.cpu().detach().permute(1, 2, 0).numpy()
+
+        masked_img = self.overlay_masks_on_image(boxed_img, masks)
+        masked_img = cv2.resize(masked_img, (800, 800))
+        cv2.imshow('img_masks', masked_img)
+        # show_img_boxes_masks(img, boxes, masks)
+        cv2.waitKey(0)
+
+    def generate_colors(self, n):
+        """
+        Generate n colors evenly distributed in HSV color space and convert them to BGR.
+
+        :param n: number of colors required
+        :return: a list of n colors, each a BGR tuple
+        """
+        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
+        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
+        return bgr_colors
+
+    def overlay_masks_on_image(self, image, masks, alpha=0.6):
+        """
+        Overlay multiple masks on the original image, each mask drawn in a different color.
+
+        :param image: original image (NumPy array)
+        :param masks: list of masks (each a binary NumPy array)
+        :param colors: list of colors (each a (B, G, R) tuple); generated internally from the mask count
+        :param alpha: mask transparency (0.0 to 1.0)
+        :return: the image with all masks overlaid
+        """
+        colors = self.generate_colors(len(masks))
+        if len(masks) != len(colors):
+            raise ValueError("The number of masks and colors must be the same.")
+
+        # copy the image so the original is not modified
+        overlay = image.copy()
+
+        for mask, color in zip(masks, colors):
+            # make sure the mask is a binary image
+            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
+            binary_mask = (mask > 0).astype(np.uint8) * 255  # adjust the threshold here if needed
+
+            # build a colored mask
+            colored_mask = np.zeros_like(image)
+
+            colored_mask[:] = color
+            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
+
+            # blend the colored mask into the current overlay
+            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
+
+        return overlay
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    ins_model = MaskRCNNModel()
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    ins_model.train(cfg='train.yaml')
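The color helper above spreads n hues evenly across HSV and converts each one to a BGR tuple for OpenCV drawing; a stripped-down sketch of the same idea:

import cv2
import numpy as np

n = 4
hsv = [(i / n * 180, 255 / 3, 2 * 255 / 3) for i in range(n)]
bgr = [tuple(int(c) for c in cv2.cvtColor(np.uint8([[color]]), cv2.COLOR_HSV2BGR)[0][0])
       for color in hsv]
print(bgr)   # four visually distinct BGR tuples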

+ 93 - 0
models/ins_detect/maskrcnn_dataset.py

@@ -0,0 +1,93 @@
+import os
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
+
+
+class MaskRCNNDataset(Dataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
+        self.data_path = dataset_path
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        # print('maskrcnn inited!')
+
+    def __getitem__(self, item):
+        # print('__getitem__')
+        img_path = os.path.join(self.img_path, self.imgs[item])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
+        img = PIL.Image.open(img_path).convert('RGB')
+        # h, w = np.array(img).shape[:2]
+        w, h = img.size
+        # print(f'h,w:{h, w}')
+        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img,target)
+        else:
+            img=self.deafult_transform(img)
+        # print(f'img:{img.shape},target:{target}')
+        return img, target
+
+    def create_masks_from_polygons(self, polygons, image_shape):
+        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+        masks = []
+
+        for polygon_data, col in zip(polygons, colors):
+            mask = np.zeros(image_shape[:2], dtype=np.uint8)
+            # convert the polygon vertices to a NumPy array
+            _, polygon = polygon_data
+            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+            # fill the polygon using OpenCV's fillPoly
+            # print(f'color:{col[:3]}')
+            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+            mask = torch.from_numpy(mask)
+            mask[mask != 0] = 1
+            masks.append(mask)
+
+        return masks
+
+    def read_target(self, item, lbl_path, shape):
+        # print(f'lbl_path:{lbl_path}')
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        target["area"] = torch.zeros(len(masks))
+        target["iscrowd"] = torch.zeros(len(masks))
+        return target
+
+    def heatmap_enhance(self, img):
+        # histogram equalization
+        img_eq = cv2.equalizeHist(img)
+
+        # adaptive histogram equalization (CLAHE)
+        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        # img_clahe = clahe.apply(img)
+
+        # convert the grayscale image to a heatmap
+        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
+
+    def __len__(self):
+        return len(self.imgs)
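The dataset builds its boxes from the stacked binary masks; `torchvision.ops.masks_to_boxes` performs the same conversion and makes a convenient sanity check for `models.dataset_tool.masks_to_boxes`:

import torch
from torchvision.ops import masks_to_boxes

mask = torch.zeros(1, 64, 64, dtype=torch.bool)
mask[0, 10:20, 30:50] = True           # one rectangular instance
print(masks_to_boxes(mask))            # tensor([[30., 10., 49., 19.]])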

+ 31 - 0
models/ins_detect/train.yaml

@@ -0,0 +1,31 @@
+
+
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.0005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 220 - 0
models/ins_detect/trainer.py

@@ -0,0 +1,220 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def load_train_parameter(cfg):
+    parameters = read_yaml(cfg)
+    return parameters
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # default parameters
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # override defaults with the supplied kwargs
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # select the device
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = MaskRCNNDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # override defaults with the supplied kwargs
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
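For epoch 0, `train_one_epoch` warms the learning rate up linearly from lr/1000 to the configured lr over at most 1000 iterations. A compact sketch of that schedule with a dummy parameter and a shortened warm-up:

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=0.0005)
sched = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0 / 1000, total_iters=5)
for _ in range(5):
    opt.step()
    sched.step()
    print(opt.param_groups[0]["lr"])   # ramps up and reaches 0.0005 after total_iters steps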

+ 312 - 0
models/keypoint/keypoint_dataset.py

@@ -1,3 +1,107 @@
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
+========
+# import glob
+# import json
+# import math
+# import os
+# import random
+#
+# import numpy as np
+# import numpy.linalg as LA
+# import torch
+# from skimage import io
+# from torch.utils.data import Dataset
+# from torch.utils.data.dataloader import default_collate
+#
+# from lcnn.config import M
+#
+# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
+#
+#
+# class WireframeDataset(Dataset):
+#     def __init__(self, rootdir, split):
+#         self.rootdir = rootdir
+#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
+#         filelist.sort()
+#
+#         # print(f"n{split}:", len(filelist))
+#         self.split = split
+#         self.filelist = filelist
+#
+#     def __len__(self):
+#         return len(self.filelist)
+#
+#     def __getitem__(self, idx):
+#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
+#         image = io.imread(iname).astype(float)[:, :, :3]
+#         if "a1" in self.filelist[idx]:
+#             image = image[:, ::-1, :]
+#         image = (image - M.image.mean) / M.image.stddev
+#         image = np.rollaxis(image, 2).copy()
+#
+#         with np.load(self.filelist[idx]) as npz:
+#             target = {
+#                 name: torch.from_numpy(npz[name]).float()
+#                 for name in ["jmap", "joff", "lmap"]
+#             }
+#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
+#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
+#             npos, nneg = len(lpos), len(lneg)
+#             lpre = np.concatenate([lpos, lneg], 0)
+#             for i in range(len(lpre)):
+#                 if random.random() > 0.5:
+#                     lpre[i] = lpre[i, ::-1]
+#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+#             feat = [
+#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+#                 ldir * M.use_slop,
+#                 lpre[:, :, 2],
+#             ]
+#             feat = np.concatenate(feat, 1)
+#             meta = {
+#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
+#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
+#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
+#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
+#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
+#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
+#                 "lpre_feat": torch.from_numpy(feat),
+#             }
+#
+#         labels = []
+#         labels = read_masks_from_pixels_wire(iname, (512, 512))
+#         # if self.target_type == 'polygon':
+#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
+#         # elif self.target_type == 'pixel':
+#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
+#
+#         target["labels"] = torch.stack(labels)
+#         target["boxes"] = line_boxes_faster(meta)
+#
+#
+#         return torch.from_numpy(image).float(), meta, target
+#
+#     def adjacency_matrix(self, n, link):
+#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+#         link = torch.from_numpy(link)
+#         if len(link) > 0:
+#             mat[link[:, 0], link[:, 1]] = 1
+#             mat[link[:, 1], link[:, 0]] = 1
+#         return mat
+#
+#
+# def collate(batch):
+#     return (
+#         default_collate([b[0] for b in batch]),
+#         [b[1] for b in batch],
+#         default_collate([b[2] for b in batch]),
+#     )
+
+
+# original LCNN data format, with attribute names changed and box-related fields added
+
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -30,8 +134,12 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 
 class KeypointDataset(BaseDataset):
+========
+class  WireframeDataset(BaseDataset):
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
 
@@ -199,7 +307,211 @@ class KeypointDataset(BaseDataset):
 
 
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 if __name__ == '__main__':
     path=r"I:\datasets\wirenet_1000"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(7)
+========
+
+'''
+# roi_heads expects a specific data format, so the data format is changed here
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # "multiscale" would change the image size
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        # if self.transforms:
+        #     img, target = self.transforms(img, target)
+        # else:
+        #     img = self.default_transform(img)
+
+        img, target = self.data_augmentation(img, target)
+
+        print(f'img:{img.shape}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # a dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # take as many as available if fewer than n_stc_posl
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative sample coordinates
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # adjacency matrix of lines that do not exist
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # per-sample labels: 1 positive, 0 negative
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        # if self.target_type == 'polygon':
+        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        # elif self.target_type == 'pixel':
+        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+        target = {}
+        # target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        self._draw_vecl(img_path, target)
+
+    def show_img(self, img_path):
+
+        """根据给定的图片路径展示图像及其标注信息"""
+        # 获取对应的标签文件路径
+        img_name = os.path.basename(img_path)
+        img_path = os.path.join(self.img_path, img_name)
+        print(img_path)
+        lbl_name = img_name[:-3] + 'json'
+        lbl_path = os.path.join(self.lbl_path, lbl_name)
+        print(lbl_path)
+
+        if not os.path.exists(lbl_path):
+            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        target = self.read_target(0, lbl_path, shape=(h, w))
+
+        # call the drawing helper
+        self._draw_vecl(img_path, target)
+
+
+    def _draw_vecl(self, img_path, target, fn=None):
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        lines = lpre
+        sline = np.ones(lpre.shape[0])
+        imshow(io.imread(img_path))
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (i == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                          colors="yellow", width=1)
+        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+        plt.show()
+
+        if fn != None:
+            plt.savefig(fn)
+
+'''
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
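A quick sanity check on the per-line feature layout built in read_target in the block above (illustrative sketch with dummy data, not part of this commit): each candidate line is encoded as 4 endpoint coordinates + 2 direction components + 2 per-endpoint scores = 8 columns, which matches the FEATURE_DIM = 8 constant declared later in models/line_detect/line_net.py and models/line_detect/line_predictor.py. With use_cood = use_slop = 0 the coordinate and direction columns are zeroed out, but the 8-column layout is preserved.

    import numpy as np
    import numpy.linalg as LA

    use_cood, use_slop = 0, 0          # values used in read_target above
    lpre = np.random.rand(5, 2, 3)     # 5 dummy lines, 2 endpoints, (y, x, score)

    ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
    ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
    feat = np.concatenate([
        lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,  # normalized endpoint coordinates
        ldir * use_slop,                                 # unit direction vector
        lpre[:, :, 2],                                   # per-endpoint scores
    ], 1)
    print(feat.shape)  # (5, 8)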

+ 0 - 0
models/line_detect/__init__.py


+ 26 - 141
lcnn/datasets.py → models/line_detect/dataset_LD.py

@@ -1,104 +1,7 @@
-# import glob
-# import json
-# import math
-# import os
-# import random
-#
-# import numpy as np
-# import numpy.linalg as LA
-# import torch
-# from skimage import io
-# from torch.utils.data import Dataset
-# from torch.utils.data.dataloader import default_collate
-#
-# from lcnn.config import M
-#
-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
-#
-#
-# class WireframeDataset(Dataset):
-#     def __init__(self, rootdir, split):
-#         self.rootdir = rootdir
-#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
-#         filelist.sort()
-#
-#         # print(f"n{split}:", len(filelist))
-#         self.split = split
-#         self.filelist = filelist
-#
-#     def __len__(self):
-#         return len(self.filelist)
-#
-#     def __getitem__(self, idx):
-#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
-#         image = io.imread(iname).astype(float)[:, :, :3]
-#         if "a1" in self.filelist[idx]:
-#             image = image[:, ::-1, :]
-#         image = (image - M.image.mean) / M.image.stddev
-#         image = np.rollaxis(image, 2).copy()
-#
-#         with np.load(self.filelist[idx]) as npz:
-#             target = {
-#                 name: torch.from_numpy(npz[name]).float()
-#                 for name in ["jmap", "joff", "lmap"]
-#             }
-#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
-#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
-#             npos, nneg = len(lpos), len(lneg)
-#             lpre = np.concatenate([lpos, lneg], 0)
-#             for i in range(len(lpre)):
-#                 if random.random() > 0.5:
-#                     lpre[i] = lpre[i, ::-1]
-#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-#             feat = [
-#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
-#                 ldir * M.use_slop,
-#                 lpre[:, :, 2],
-#             ]
-#             feat = np.concatenate(feat, 1)
-#             meta = {
-#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
-#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
-#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
-#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
-#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
-#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
-#                 "lpre_feat": torch.from_numpy(feat),
-#             }
-#
-#         labels = []
-#         labels = read_masks_from_pixels_wire(iname, (512, 512))
-#         # if self.target_type == 'polygon':
-#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
-#         # elif self.target_type == 'pixel':
-#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
-#
-#         target["labels"] = torch.stack(labels)
-#         target["boxes"] = line_boxes_faster(meta)
-#
-#
-#         return torch.from_numpy(image).float(), meta, target
-#
-#     def adjacency_matrix(self, n, link):
-#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-#         link = torch.from_numpy(link)
-#         if len(link) > 0:
-#             mat[link[:, 0], link[:, 1]] = 1
-#             mat[link[:, 1], link[:, 0]] = 1
-#         return mat
-#
-#
-# def collate(batch):
-#     return (
-#         default_collate([b[0] for b in batch]),
-#         [b[1] for b in batch],
-#         default_collate([b[2] for b in batch]),
-#     )
-
+# roi_heads has requirements on the data format, so the format is changed here
 from torch.utils.data.dataset import T_co
 
-from .models.base.base_dataset import BaseDataset
+from models.base.base_dataset import BaseDataset
 
 import glob
 import json
@@ -120,19 +23,20 @@ from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 
 import matplotlib.pyplot as plt
-from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
+from tools.presets import DetectionPresetTrain
 
 
-class  WireframeDataset(BaseDataset):
+class WirePointDataset(BaseDataset):
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
 
         self.data_path = dataset_path
         print(f'data_path:{dataset_path}')
         self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
@@ -146,18 +50,19 @@ class  WireframeDataset(BaseDataset):
         w, h = img.size
 
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-
-        img = self.default_transform(img)
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
 
         # print(f'img:{img}')
-        return img, meta, target, target_b
+        return img, target
 
     def __len__(self):
         return len(self.imgs)
 
     def read_target(self, item, lbl_path, shape, extra=None):
-        # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
             lable_all = json.load(file)
@@ -184,20 +89,16 @@ class  WireframeDataset(BaseDataset):
         ]
         feat = np.concatenate(feat, 1)
 
-        meta = {
+        wire_labels = {
             "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
             "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
             "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
             # adjacency matrix of lines that actually exist
             "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-
+            # adjacency matrix of lines that do not exist
             "lpre": torch.tensor(lpre)[:, :, :2],
             "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # per-sample labels: 1 positive, 0 negative
             "lpre_feat": torch.from_numpy(feat),
-
-        }
-
-        target = {
             "junc_map": torch.tensor(wire['junc_map']["content"]),
             "junc_offset": torch.tensor(wire['junc_offset']["content"]),
             "line_map": torch.tensor(wire['line_map']["content"]),
@@ -210,15 +111,14 @@ class  WireframeDataset(BaseDataset):
             labels = read_masks_from_pixels_wire(lbl_path, shape)
 
         # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
-        target_b = {}
-
-        # target_b["image_id"] = torch.tensor(item)
-
-        target_b["labels"] = torch.stack(labels)
-        target_b["boxes"] = line_boxes_faster(meta)
-
-        return meta, target, target_b
-
+        target = {}
+        target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        # print(f'boxes:{target["boxes"].shape}')
+        return target
 
     def show(self, idx):
         image, target = self.__getitem__(idx)
@@ -250,6 +150,7 @@ class  WireframeDataset(BaseDataset):
                         break
                     plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
 
+
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -273,22 +174,6 @@ class  WireframeDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
-    def show_img(self, img_path):
-        pass
-
-
-def collate(batch):
-    return (
-        default_collate([b[0] for b in batch]),
-        [b[1] for b in batch],
-        default_collate([b[2] for b in batch]),
-        [b[3] for b in batch],
-    )
-
-
-# if __name__ == '__main__':
-#     path = r"D:\python\PycharmProjects\data"
-#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
-#     dataset.show(0)
-
 
+    def show_img(self, img_path):
+        pass
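Since WirePointDataset.__getitem__ returns (img, target) with variable-sized tensors inside target, and the old collate() helper was removed above, a batch is most easily kept as parallel lists. A minimal usage sketch (not part of this commit; the dataset path is hypothetical, and it assumes BaseDataset provides default_transform):

    from torch.utils.data import DataLoader
    from models.line_detect.dataset_LD import WirePointDataset

    def collate_fn(batch):
        # keep images and targets as parallel lists; image batching happens later
        # inside the detection transform
        return [b[0] for b in batch], [b[1] for b in batch]

    dataset = WirePointDataset(dataset_path=r"I:\datasets\wirenet_1000",  # hypothetical path
                               dataset_type='train')
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    imgs, targets = next(iter(loader))
    print(len(imgs), targets[0]["boxes"].shape)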

+ 24 - 0
models/line_detect/line_head.py

@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+
+
+class LineRCNNHeads(nn.Sequential):
+    def __init__(self, input_channels, num_class):
+        super(LineRCNNHeads, self).__init__()
+        # print("输入的维度是:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        self.head_size = [[2], [1], [2]]
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(self.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
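A minimal shape check for LineRCNNHeads (not part of this commit): head_size = [[2], [1], [2]] creates three small convolutional branches whose outputs are concatenated along the channel dimension, so num_class must equal 2 + 1 + 2 = 5 and spatial resolution is preserved.

    import torch
    from models.line_detect.line_head import LineRCNNHeads

    head = LineRCNNHeads(input_channels=256, num_class=5)
    feat = torch.rand(2, 256, 128, 128)   # dummy feature map, batch of 2
    out = head(feat)
    print(out.shape)                      # torch.Size([2, 5, 128, 128])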

+ 912 - 0
models/line_detect/line_net.py

@@ -0,0 +1,912 @@
+
+from typing import Any, Callable, List, Optional, Tuple, Union
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .line_head import LineRCNNHeads
+from .line_predictor import LineRCNNPredictor
+from .roi_heads import RoIHeads
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
+
+FEATURE_DIM = 8
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+__all__ = [
+    "LineNet",
+    "LineNet_ResNet50_FPN_Weights",
+    "LineNet_ResNet50_FPN_V2_Weights",
+    "LineNet_MobileNet_V3_Large_FPN_Weights",
+    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
+    "linenet_resnet50_fpn",
+    "linenet_resnet50_fpn_v2",
+    "linenet_mobilenet_v3_large_fpn",
+    "linenet_mobilenet_v3_large_320_fpn",
+]
+# __all__ = [
+#     "LineNet",
+#     "LineRCNN_ResNet50_FPN_Weights",
+#     "linercnn_resnet50_fpn",
+# ]
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
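A tiny demo of the max-pool based suppression above (not part of this commit; the function is redefined so the snippet is self-contained): an entry survives only if it equals the maximum of its 3x3 neighbourhood, so the heatmap is reduced to its local peaks.

    import torch
    import torch.nn.functional as F

    def non_maximum_suppression(a):
        ap = F.max_pool2d(a, 3, stride=1, padding=1)
        mask = (a == ap).float().clamp(min=0.0)
        return a * mask

    heat = torch.tensor([[[[0.1, 0.9, 0.2, 0.0, 0.1],
                           [0.3, 0.4, 0.1, 0.0, 0.8],
                           [0.0, 0.3, 0.1, 0.2, 0.1]]]])
    print(non_maximum_suppression(heat))
    # only the local peaks 0.9 and 0.8 remain non-zero; every other entry is zeroed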
+
+
+# class Bottleneck1D(nn.Module):
+#     def __init__(self, inplanes, outplanes):
+#         super(Bottleneck1D, self).__init__()
+#
+#         planes = outplanes // 2
+#         self.op = nn.Sequential(
+#             nn.BatchNorm1d(inplanes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(inplanes, planes, kernel_size=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, outplanes, kernel_size=1),
+#         )
+#
+#     def forward(self, x):
+#         return x + self.op(x)
+
+
+
+
+
+
+
+
+
+from .roi_heads import RoIHeads
+
+from ..base.base_detection_net import BaseDetectionNet
+
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class LineNet(BaseDetectionNet):
+    """
+    Implements LineNet, a Faster R-CNN style detector extended with line-detection heads.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        line_head (nn.Module): module that computes the line-detection feature maps
+            (defaults to ``LineRCNNHeads``)
+        line_predictor (nn.Module): module that computes line predictions from the
+            line head output (defaults to ``LineRCNNPredictor``)
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import FasterRCNN
+        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # FasterRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> # put the pieces together inside a FasterRCNN model
+        >>> model = FasterRCNN(backbone,
+        >>>                    num_classes=2,
+        >>>                    rpn_anchor_generator=anchor_generator,
+        >>>                    box_roi_pool=roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=512,
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        # line parameters
+        line_head=None,
+        line_predictor=None,
+        **kwargs,
+    ):
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+
+        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+            )
+        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+            )
+
+        if num_classes is not None:
+            if box_predictor is not None:
+                raise ValueError("num_classes should be None when box_predictor is specified")
+        else:
+            if box_predictor is None:
+                raise ValueError("num_classes should not be None when box_predictor is not specified")
+
+        out_channels = backbone.out_channels
+
+        if line_head is None:
+            num_class = 5
+            line_head = LineRCNNHeads(out_channels, num_class)
+
+        if line_predictor is None:
+            line_predictor = LineRCNNPredictor()
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = BoxPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            line_head,
+            line_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+        self.roi_heads = roi_heads
+        # self.roi_heads.line_head = line_head
+        # self.roi_heads.line_predictor = line_predictor
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class LineNetConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class BoxPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
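A minimal shape check for BoxPredictor (not part of this commit; it assumes the libs.vision_libs imports at the top of this file resolve in your environment): for flattened RoI features it returns per-class classification scores and num_classes * 4 box regression deltas per proposal.

    import torch
    from models.line_detect.line_net import BoxPredictor

    predictor = BoxPredictor(in_channels=1024, num_classes=5)
    x = torch.rand(8, 1024)            # 8 proposals after TwoMLPHead
    scores, deltas = predictor(x)
    print(scores.shape, deltas.shape)  # torch.Size([8, 5]) torch.Size([8, 20])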
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class LineNet_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 41755286,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 159.743,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 43712278,
+            "recipe": "https://github.com/pytorch/vision/pull/5763",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
+            },
+            "_ops": 280.371,
+            "_file_size": 167.104,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
+            },
+            "_ops": 4.494,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
+            },
+            "_ops": 0.719,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def linenet_resnet50_fpn(
+    *,
+    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> LineNet:
+    """
+    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
+    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
+    paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> # For training
+        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
+        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
+        >>> labels = torch.randint(1, 91, (4, 11))
+        >>> images = list(image for image in images)
+        >>> targets = []
+        >>> for i in range(len(images)):
+        >>>     d = {}
+        >>>     d['boxes'] = boxes[i]
+        >>>     d['labels'] = labels[i]
+        >>>     targets.append(d)
+        >>> output = model(images, targets)
+        >>> # For inference
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineNet(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def linenet_resnet50_fpn_v2(
+    *,
+    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> LineNet:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = LineNetConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = LineNet(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _linenet_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
+    progress: bool,
+    num_classes: Optional[int],
+    weights_backbone: Optional[MobileNet_V3_Large_Weights],
+    trainable_backbone_layers: Optional[int],
+    **kwargs: Any,
+) -> LineNet:
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
+    anchor_sizes = (
+        (
+            32,
+            64,
+            128,
+            256,
+            512,
+        ),
+    ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    model = LineNet(
+        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def linenet_mobilenet_v3_large_320_fpn(
+    *,
+    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> LineNet:
+    """
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
+        :members:
+    """
+    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _linenet_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def linenet_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> LineNet:
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
+        :members:
+    """
+    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _linenet_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+

+ 324 - 0
models/line_detect/line_predictor.py

@@ -0,0 +1,324 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .roi_heads import RoIHeads
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
+
+FEATURE_DIM = 8
+def non_maximum_suppression(a):
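+    # Keep only local maxima: a score survives only where it equals its 3x3 max-pooled neighbourhood.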
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+class LineRCNNPredictor(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
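+        # head_size groups the head's output channels per prediction task; num_class is their total and
+        # head_off holds the cumulative channel offsets used later to slice the output into jmap / lmap / joff.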
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs, features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # Extract the 'junc_map', 'junc_offset' and 'line_map' tensors from every target
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
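+            # No wire annotations at inference time: build zero-filled placeholder targets with the
+            # expected shapes so that sample_lines() can run unchanged.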
+            t = {
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+            wires_targets = [t for b in range(inputs.size(0))]
+
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # print the output shape of each stack
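+            # Slice the head output along the channel axis: junction map (2 logits per junction type),
+            # line map (1 channel) and junction offset (2 channels per junction type).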
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        # print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
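+            # Bilinearly interpolate the feature map at every sampled point along each candidate line.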
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            # print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
+        x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
+        # print("Input dtype1:", x.dtype)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+    def sample_lines(self, meta, jmap, joff):
+        device = jmap.device
+        with torch.no_grad():
+            junc = meta["junc_coords"].to(device)  # [N, 2]
+            jtyp = meta["jtyp"].to(device)  # [N]
+            Lpos = meta["line_pos_idx"].to(device)
+            Lneg = meta["line_neg_idx"].to(device)
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
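+            # Recover sub-pixel junction coordinates: integer cell from the flattened 128x128 index,
+            # plus the regressed offset and a 0.5 pixel-centre shift.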
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
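+            # Reject matches whose junction type disagrees or whose squared distance exceeds 1.5**2 px;
+            # those entries are redirected to the dummy index N.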
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
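+            # 8-D line descriptor (FEATURE_DIM): normalised endpoint coordinates (4), unit direction (2),
+            # and two indicator flags on the endpoint indices (u > K, v > K).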
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
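+            # line: [M, 2, 2] endpoint pairs; label: [M]; feat: [M, FEATURE_DIM];
+            # jcs: one tensor of kept junctions (score > 0.03) per junction type.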
+            return line, label.float(), feat, jcs
+
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}

+ 1177 - 0
models/line_detect/roi_heads.py

@@ -0,0 +1,1177 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+import libs.vision_libs.models.detection._utils as det_utils
+
+from collections import OrderedDict
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
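+    # L1 distance between sigmoid(logits) + offset and the target; when a mask is given, each pixel is
+    # re-weighted by mask / mean(mask) per sample.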
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
+
+
+### Compute the multi-head (junction / line map) losses
+def line_head_loss(input_dict, outputs, feature, loss_weight, mode_train):
+    # image = input_dict["image"]
+    # target_b = input_dict["target_b"]
+    # outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # during training aaa is the loss, during validation it is the boxes
+
+    result = {"feature": feature}
+    batch, channel, row, col = outputs[0].shape
+
+    T = input_dict["target"].copy()
+    n_jtyp = T["junc_map"].shape[1]
+
+    # switch to CNHW
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
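+    # Channel offsets splitting each stack output into jmap (2 ch), lmap (1 ch) and joff (2 ch).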
+    offset = [2, 3, 5]
+    losses = []
+    for stack, output in enumerate(outputs):
+        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+        lmap = output[offset[0]: offset[1]].squeeze(0)
+        # print(f"lmap:{lmap.shape}")
+        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+        if stack == 0:
+            result["preds"] = {
+                "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                "lmap": lmap.sigmoid(),
+                "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+            }
+            if not mode_train:
+                return result
+
+        L = OrderedDict()
+        L["jmap"] = sum(
+            cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+        )
+        L["lmap"] = (
+            F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+                .mean(2)
+                .mean(1)
+        )
+        L["joff"] = sum(
+            sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+            for i in range(n_jtyp)
+            for j in range(2)
+        )
+        for loss_name in L:
+            L[loss_name].mul_(loss_weight[loss_name])
+        losses.append(L)
+    result["losses"] = losses
+    # result["aaa"] = aaa
+    return result
+
+
+#  Compute the line (vectorizer) loss
+def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc, loss_weight, mode_train):
+    if not mode_train:
+        p = torch.cat(ps)
+        s = torch.sigmoid(x)
+        b = s > 0.5
+        lines = []
+        score = []
+        for i in range(n_batch):
+            p0 = p[idx[i]: idx[i + 1]]
+            s0 = s[idx[i]: idx[i + 1]]
+            mask = b[idx[i]: idx[i + 1]]
+            p0 = p0[mask]
+            s0 = s0[mask]
+            if len(p0) == 0:
+                lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+                score.append(torch.zeros([1, n_out_line], device=p.device))
+            else:
+                arg = torch.argsort(s0, descending=True)
+                p0, s0 = p0[arg], s0[arg]
+                lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+                score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+            for j in range(len(jcs[i])):
+                if len(jcs[i][j]) == 0:
+                    jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+                jcs[i][j] = jcs[i][j][
+                    None, torch.arange(n_out_junc) % len(jcs[i][j])
+                ]
+        result["preds"]["lines"] = torch.cat(lines)
+        result["preds"]["score"] = torch.cat(score)
+        result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+        if len(jcs[i]) > 1:
+            result["preds"]["junts"] = torch.cat(
+                [jcs[i][1] for i in range(n_batch)]
+            )
+
+    # if input_dict["mode"] != "testing":
+    y = torch.cat(ys)
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    if mode_train:
+        del result["preds"]
+
+    return result
+
+
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: results returned by the detection head
+    # x, y, idx: intermediate results produced by the line predictor
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # Extract the 'junc_map', 'junc_offset' and 'line_map' tensors
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    L["junc_map"] = sum(
+        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    ).mean()
+    L["line_map"] = (
+        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+            .mean(2)
+            .mean(1)
+    ).mean()
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    ).mean()
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = (lpos * loss_weight["lpos"]).mean()
+    result["losses"][0]["lneg"] = (lneg * loss_weight["lneg"]).mean()
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["wires"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+
+    return result
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
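+    # Reshape the regression output to (N, num_classes, 4) so that the deltas predicted for the
+    # ground-truth class of each positive sample can be selected below.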
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for each image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
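+    # Encode each keypoint as a single class index into the flattened heatmap (heatmap_size**2 bins);
+    # invalid keypoints are zeroed and later excluded via `valid`.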
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+            .index_select(2, x_int.to(dtype=torch.int64))
+            .view(-1)
+            .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            line_head,
+            line_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.line_head = line_head
+        self.line_predictor = line_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_line(self):
+        # if self.mask_roi_pool is None:
+        #     return False
+        if self.line_head is None:
+            return False
+        if self.line_predictor is None:
+            return False
+        return True
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to proposals
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            self.training = True
+
+        else:
+            self.training = False
+
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if t["boxes"].dtype not in floating_point_types:
+                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
+                if t["labels"].dtype != torch.int64:
+                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if t["keypoints"].dtype != torch.float32:
+                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        features_lcnn = features['0']
+        if self.has_line():
+            # print('has line_head')
+            outputs = self.line_head(features_lcnn)
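+            # Hard-coded weights for the wireframe sub-losses: junction map, line map, junction offset,
+            # and the positive / negative line-of-interest terms.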
+            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
+                inputs=outputs, features=features_lcnn, targets=targets)
+
+            # # line_loss(multitasklearner)
+            # if self.training:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=True)
+            # else:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=False)
+
+            if self.training:
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+            else:
+
+                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(pred)
+                loss_wirepoint = {}
+            losses.update(loss_wirepoint)
+        else:
+            pass
+            # print('has not line_head')
+
+
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+                self.keypoint_roi_pool is not None
+                and self.keypoint_head is not None
+                and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 0 - 0
models/obj_detect/__init__.py


+ 111 - 0
predict.py

@@ -0,0 +1,111 @@
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+from torchvision import transforms
+
+
+def box_line(pred):
+    '''
+    :param pred: prediction results
+    :return: one list per image, pairing each box with its best-matching line, e.g.
+
+{'box': [0.0, 34.23157501220703, 151.70858764648438, 125.10173797607422], 'line': array([[ 1.9720564, 81.73457  ],
+[ 1.9933801, 41.730167 ]], dtype=float32)}
+    '''
+    box_line = [[] for _ in range((len(pred) - 1))]
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        for i in box:
+            aaa = {}
+            aaa['box'] = i.tolist()
+            aaa['line'] = []
+            score_max = 0.0
+            for j in range(len(line)):
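+                # keep the highest-scoring line whose endpoints both fall inside this box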
+                if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
+                        line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
+                        line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
+                    if score[j] > score_max:
+                        aaa['line'] = line[j]
+                        score_max = score[j]
+            box_line[idx].append(aaa)
+    return box_line
+
+
+def box_line_(pred):
+    '''
+    Same structure as pred; a matched 'line' tensor is added to each per-image dict
+    '''
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        line_ = []
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
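+            # defaults to a zero line when no prediction falls inside this box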
+            for j in range(len(line)):
+                if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
+                        line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
+                        line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+    return pred
+
+
+def show_(imgs, pred, epoch, writer):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    print(len(col))
+    im = imgs[0].permute(1, 2, 0)
+    boxes = pred[0]['boxes'].cpu().numpy()
+    line = pred[0]['line'].cpu().numpy()
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    for idx, box in enumerate(boxes):
+        x0, y0, x1, y1 = box
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+
+    for idx, (a, b) in enumerate(line):
+        ax.scatter(a[0], a[1], c=col[99 - idx], s=2)
+        ax.scatter(b[0], b[1], c=col[99 - idx], s=2)
+        ax.plot([a[0], b[0]], [a[1], b[1]], c=col[idx], linewidth=1)
+
+    # Convert the Matplotlib figure to a tensor
+    fig.canvas.draw()
+    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+        fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    img2 = transforms.ToTensor()(image_from_plot)
+
+    writer.add_image("all", img2, epoch)
+

+ 18 - 0
readme.md

@@ -0,0 +1,18 @@
+# MultiVisionModels Platform
+
+## 1. Dependencies
+
+| Name | Version | Build | Channel |
+| -------- | -------- | -------- | ------ |
+| python | 3.12.8 | h3f84c4b_1_cpython | conda-forge |
+| pytorch | 2.2.2 | py3.12_cuda12.1_cudnn8_0 | pytorch |
+| pytorch-cuda | 12.1 | hde6ce7c_6 | pytorch |
+| torchvision | 0.17.2 | pypi_0 | pypi |
+| numpy | 1.26.3 | py312h8753938_0 | conda-forge |
+
+
+## 2. Overview
+
+Includes object detection, keypoint detection, instance segmentation, and line detection.
+
+
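
A minimal, hypothetical quick-start for the line detector, assuming `linenet_resnet50_fpn` is built with its defaults (as in `train——line_rcnn.py`) and that a checkpoint has been written to `logs/pth/best_model.pth`; the image path `demo.png` is a placeholder:

```python
import torch
from PIL import Image
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the detector and load the best checkpoint written by the training script
model = linenet_resnet50_fpn().to(device)
checkpoint = torch.load('logs/pth/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Run a single RGB image through the model
img = transforms.ToTensor()(Image.open('demo.png').convert('RGB')).to(device)
with torch.no_grad():
    pred = model([img])  # per-image dicts plus a trailing entry carrying 'wires' (lines and scores)

print(pred[0]['boxes'])  # detected boxes for the first image
```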

+ 314 - 0
train——line_rcnn.py

@@ -0,0 +1,314 @@
+# 2025/2/9
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from tools import utils
+
+from torch.utils.tensorboard import SummaryWriter
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from skimage import io
+
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision.utils import draw_bounding_boxes
+from models.wirenet.postprocess import postprocess
+from torchvision import transforms
+from collections import OrderedDict
+
+from PIL import Image
+
+from predict import box_line_, show_
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def _loss(losses):
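+    # sum the scalar detection losses, then add the mean of each nested wirepoint loss term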
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    # print(f'pred[1]:{pred[1]}')
+    H = pred[-1]['wires']
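+    # wire predictions live on a 128x128 grid; rescale them to image coordinates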
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.85]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+
+def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    if current_loss < best_loss:
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'loss': current_loss
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)
+        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
+
+        return current_loss
+
+    return best_loss
+
+
+def save_latest_model(model, save_path, epoch, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+    }
+
+    if optimizer is not None:
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+    torch.save(checkpoint, save_path)
+
+
+def load_best_model(model, optimizer, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        if optimizer is not None:
+            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model, optimizer
+
+
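+# NOTE: written as an instance method (it relies on self.__model, self.transforms and self.load_weight)
+# and is not called in the __main__ block below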
+def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
+    self.load_weight('weights/best.pt')
+    self.__model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    # Preprocess the image
+    img_tensor = self.transforms(img)
+    with torch.no_grad():
+        predictions = self.__model([img_tensor])
+
+    # Post-process the predictions
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    keypoints = predictions[0]['keypoints'].cpu().numpy()
+
+    # Visualize the predictions
+    if show_boxes or show_keypoint or show_line or save:
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(np.array(img))
+
+        if show_boxes:
+            for box in boxes:
+                x0, y0, x1, y1 = box
+                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
+
+        for (a, b) in keypoints:
+            if show_keypoint:
+                ax.scatter(a[0], a[1], c='c', s=2)
+                ax.scatter(b[0], b[1], c='c', s=2)
+            if show_line:
+                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
+
+        if show_boxes or show_keypoint or show_line:
+            plt.show()
+
+        if save:
+            fig.savefig(save_path)
+            print(f"Prediction saved to {save_path}")
+        plt.close(fig)
+
+
+if __name__ == '__main__':
+    cfg = r'./config/wireframe.yaml'
+    cfg = read_yaml(cfg)
+    print(f'cfg:{cfg}')
+    print(cfg['model']['n_dyn_negl'])
+    # net = WirepointPredictor()
+
+    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    train_sampler = torch.utils.data.RandomSampler(dataset_train)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
+    train_collate_fn = utils.collate_fn_wirepoint
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+    )
+
+    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
+    val_collate_fn = utils.collate_fn_wirepoint
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+    )
+
+    model = linenet_resnet50_fpn().to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
+    writer = SummaryWriter(cfg['io']['logdir'])
+
+    # Load previously saved weights, if available
+    save_path = 'logs/pth/best_model.pth'
+    model, optimizer = load_best_model(model, optimizer, save_path, device)
+
+    logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
+    os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
+    latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
+    best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+    global_step = 0
+
+
+    def move_to_device(data, device):
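+        # recursively move tensors nested inside lists, tuples and dicts onto the target device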
+        if isinstance(data, (list, tuple)):
+            return type(data)(move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # leave non-tensor data unchanged
+
+
+    def writer_loss(writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+
+    best_loss = 10000.0  # track the lowest average training loss across epochs
+    for epoch in range(cfg['optim']['max_epoch']):
+        print(f"epoch:{epoch}")
+        model.train()
+        total_train_loss = 0.0
+
+        for imgs, targets in data_loader_train:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            # print(losses)
+            loss = _loss(losses)
+            total_train_loss += loss.item()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            writer_loss(writer, losses, epoch)
+
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))
+
+                pred_ = box_line_(pred)  # pair each box with its best line
+                show_(imgs, pred_, epoch, writer)
+
+                if batch_idx == 0:
+                    show_line(imgs[0], pred, epoch, writer)
+
+                break
+
+        avg_train_loss = total_train_loss / len(data_loader_train)
+        writer.add_scalar('loss/train', avg_train_loss, epoch)
+        save_latest_model(
+            model,
+            latest_model_path,
+            epoch,
+            optimizer
+        )
+        best_loss = save_best_model(
+            model,
+            best_model_path,
+            epoch,
+            avg_train_loss,
+            best_loss,
+            optimizer
+        )