
faster_rcnn and lcnn backbones are OK

xue50 4 months ago
parent
commit
02ffb35cac

BIN
assets/dog1.jpg


BIN
assets/dog1.png


BIN
assets/dog2.jpg


BIN
assets/dog2.png


+ 12 - 8
models/wirenet/wirenet.yaml → config/wireframe.yaml

@@ -1,32 +1,36 @@
 io:
   logdir: logs/
-  datadir: D:/python/PycharmProjects/data
+  datadir: D:\python\PycharmProjects\data
+#  datadir: /home/dieu/lcnn/dataset/line_data_104
   resume_from:
-  num_workers: 4
+#  resume_from: /home/dieu/lcnn/logs/241112-163302-175fb79-my_data_104_resume
+  num_workers: 0
   tensorboard_port: 0
-  validation_interval: 24000
+  validation_interval: 300    # evaluation interval
 
 model:
   image:
       mean: [109.730, 103.832, 98.681]
       stddev: [22.275, 22.124, 23.229]
 
-  batch_size: 2
+  batch_size: 4
   batch_size_eval: 2
 
   # backbone multi-task parameters
-  head_size: [[2], [1], [2]]
+  head_size: [[2], [1], [2], [4]]
   loss_weight:
     jmap: 8.0
     lmap: 0.5
     joff: 0.25
     lpos: 1
     lneg: 1
+    boxes: 1.0  # weight for the newly added box loss
 
   # backbone parameters
-  backbone: stacked_hourglass
+  backbone: fasterrcnn_resnet50
+#  backbone: unet
   depth: 4
-  num_stacks: 2
+  num_stacks: 1
   num_blocks: 1
 
   # sampler parameters
@@ -65,5 +69,5 @@ optim:
   lr: 4.0e-4
   amsgrad: True
   weight_decay: 1.0e-4
-  max_epoch: 24
+  max_epoch: 1000
   lr_decay_epoch: 10

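For orientation: the fourth entry appended to head_size and the new loss_weight.boxes key belong to the box branch that accompanies the fasterrcnn_resnet50 backbone. A minimal sketch of the channel arithmetic, assuming the usual lcnn convention that the multi-task head's output channel count is the sum of all head_size entries (that convention is an assumption here, not shown in this diff):

    # assumed convention: total prediction channels = sum over all head_size entries
    head_size = [[2], [1], [2], [4]]              # the trailing [4] is the newly added head
    num_channels = sum(sum(h) for h in head_size)
    print(num_channels)  # 9
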
+ 4 - 0
lcnn/__init__.py

@@ -0,0 +1,4 @@
+import lcnn.models
+import lcnn.trainer
+import lcnn.datasets
+import lcnn.config

+ 1110 - 0
lcnn/box.py

@@ -0,0 +1,1110 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+#
+# Copyright (c) 2017-2019 - Chris Griffith - MIT License
+"""
+Improved dictionary access through dot notation with additional tools.
+"""
+import string
+import sys
+import json
+import re
+import copy
+from keyword import kwlist
+import warnings
+
+try:
+    from collections.abc import Iterable, Mapping, Callable
+except ImportError:
+    from collections import Iterable, Mapping, Callable
+
+yaml_support = True
+
+try:
+    import yaml
+except ImportError:
+    try:
+        import ruamel.yaml as yaml
+    except ImportError:
+        yaml = None
+        yaml_support = False
+
+if sys.version_info >= (3, 0):
+    basestring = str
+else:
+    from io import open
+
+__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox',
+           'BoxError', 'BoxKeyError']
+__author__ = 'Chris Griffith'
+__version__ = '3.2.4'
+
+BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box',
+                  'frozen_box', 'camel_killer_box', 'box_it_up',
+                  'box_safe_prefix', 'box_duplicates', 'ordered_box')
+
+_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
+_all_cap_re = re.compile('([a-z0-9])([A-Z])')
+
+
+class BoxError(Exception):
+    """Non standard dictionary exceptions"""
+
+
+class BoxKeyError(BoxError, KeyError, AttributeError):
+    """Key does not exist"""
+
+
+# Abstract converter functions for use in any Box class
+
+
+def _to_json(obj, filename=None,
+             encoding="utf-8", errors="strict", **json_kwargs):
+    json_dump = json.dumps(obj,
+                           ensure_ascii=False, **json_kwargs)
+    if filename:
+        with open(filename, 'w', encoding=encoding, errors=errors) as f:
+            f.write(json_dump if sys.version_info >= (3, 0) else
+                    json_dump.decode("utf-8"))
+    else:
+        return json_dump
+
+
+def _from_json(json_string=None, filename=None,
+               encoding="utf-8", errors="strict", multiline=False, **kwargs):
+    if filename:
+        with open(filename, 'r', encoding=encoding, errors=errors) as f:
+            if multiline:
+                data = [json.loads(line.strip(), **kwargs) for line in f
+                        if line.strip() and not line.strip().startswith("#")]
+            else:
+                data = json.load(f, **kwargs)
+    elif json_string:
+        data = json.loads(json_string, **kwargs)
+    else:
+        raise BoxError('from_json requires a string or filename')
+    return data
+
+
+def _to_yaml(obj, filename=None, default_flow_style=False,
+             encoding="utf-8", errors="strict",
+             **yaml_kwargs):
+    if filename:
+        with open(filename, 'w',
+                  encoding=encoding, errors=errors) as f:
+            yaml.dump(obj, stream=f,
+                      default_flow_style=default_flow_style,
+                      **yaml_kwargs)
+    else:
+        return yaml.dump(obj,
+                         default_flow_style=default_flow_style,
+                         **yaml_kwargs)
+
+
+def _from_yaml(yaml_string=None, filename=None,
+               encoding="utf-8", errors="strict",
+               **kwargs):
+    if filename:
+        with open(filename, 'r',
+                  encoding=encoding, errors=errors) as f:
+            data = yaml.load(f, **kwargs)
+    elif yaml_string:
+        data = yaml.load(yaml_string, **kwargs)
+    else:
+        raise BoxError('from_yaml requires a string or filename')
+    return data
+
+
+# Helper functions
+
+
+def _safe_key(key):
+    try:
+        return str(key)
+    except UnicodeEncodeError:
+        return key.encode("utf-8", "ignore")
+
+
+def _safe_attr(attr, camel_killer=False, replacement_char='x'):
+    """Convert a key into something that is accessible as an attribute"""
+    allowed = string.ascii_letters + string.digits + '_'
+
+    attr = _safe_key(attr)
+
+    if camel_killer:
+        attr = _camel_killer(attr)
+
+    attr = attr.replace(' ', '_')
+
+    out = ''
+    for character in attr:
+        out += character if character in allowed else "_"
+    out = out.strip("_")
+
+    try:
+        int(out[0])
+    except (ValueError, IndexError):
+        pass
+    else:
+        out = '{0}{1}'.format(replacement_char, out)
+
+    if out in kwlist:
+        out = '{0}{1}'.format(replacement_char, out)
+
+    return re.sub('_+', '_', out)
+
+
+def _camel_killer(attr):
+    """
+    CamelKiller, qu'est-ce que c'est?
+
+    Taken from http://stackoverflow.com/a/1176023/3244542
+    """
+    try:
+        attr = str(attr)
+    except UnicodeEncodeError:
+        attr = attr.encode("utf-8", "ignore")
+
+    s1 = _first_cap_re.sub(r'\1_\2', attr)
+    s2 = _all_cap_re.sub(r'\1_\2', s1)
+    return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else
+                  s2.lower())
+
+
+def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs):
+    out_list = []
+    for i in iterable:
+        if isinstance(i, dict):
+            out_list.append(box_class(i, **kwargs))
+        elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)):
+            out_list.append(_recursive_tuples(i, box_class,
+                                              recreate_tuples, **kwargs))
+        else:
+            out_list.append(i)
+    return tuple(out_list)
+
+
+def _conversion_checks(item, keys, box_config, check_only=False,
+                       pre_check=False):
+    """
+    Internal use for checking if a duplicate safe attribute already exists
+
+    :param item: Item to see if a dup exists
+    :param keys: Keys to check against
+    :param box_config: Easier to pass in than ask for specific items
+    :param check_only: Don't bother doing the conversion work
+    :param pre_check: Need to add the item to the list of keys to check
+    :return: the original unmodified key, if exists and not check_only
+    """
+    if box_config['box_duplicates'] != 'ignore':
+        if pre_check:
+            keys = list(keys) + [item]
+
+        key_list = [(k,
+                     _safe_attr(k, camel_killer=box_config['camel_killer_box'],
+                                replacement_char=box_config['box_safe_prefix']
+                                )) for k in keys]
+        if len(key_list) > len(set(x[1] for x in key_list)):
+            seen = set()
+            dups = set()
+            for x in key_list:
+                if x[1] in seen:
+                    dups.add("{0}({1})".format(x[0], x[1]))
+                seen.add(x[1])
+            if box_config['box_duplicates'].startswith("warn"):
+                warnings.warn('Duplicate conversion attributes exist: '
+                              '{0}'.format(dups))
+            else:
+                raise BoxError('Duplicate conversion attributes exist: '
+                               '{0}'.format(dups))
+    if check_only:
+        return
+    # This way will be slower for warnings, as it will have double work
+    # But faster for the default 'ignore'
+    for k in keys:
+        if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'],
+                              replacement_char=box_config['box_safe_prefix']):
+            return k
+
+
+def _get_box_config(cls, kwargs):
+    return {
+        # Internal use only
+        '__converted': set(),
+        '__box_heritage': kwargs.pop('__box_heritage', None),
+        '__created': False,
+        '__ordered_box_values': [],
+        # Can be changed by user after box creation
+        'default_box': kwargs.pop('default_box', False),
+        'default_box_attr': kwargs.pop('default_box_attr', cls),
+        'conversion_box': kwargs.pop('conversion_box', True),
+        'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'),
+        'frozen_box': kwargs.pop('frozen_box', False),
+        'camel_killer_box': kwargs.pop('camel_killer_box', False),
+        'modify_tuples_box': kwargs.pop('modify_tuples_box', False),
+        'box_duplicates': kwargs.pop('box_duplicates', 'ignore'),
+        'ordered_box': kwargs.pop('ordered_box', False)
+    }
+
+
+class Box(dict):
+    """
+    Improved dictionary access through dot notation with additional tools.
+
+    :param default_box: Similar to defaultdict, return a default value
+    :param default_box_attr: Specify the default replacement.
+        WARNING: If this is not the default 'Box', it will not be recursive
+    :param frozen_box: After creation, the box cannot be modified
+    :param camel_killer_box: Convert CamelCase to snake_case
+    :param conversion_box: Check for near matching keys as attributes
+    :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes
+    :param box_it_up: Recursively create all Boxes from the start
+    :param box_safe_prefix: Conversion box prefix for unsafe attributes
+    :param box_duplicates: "ignore", "error" or "warn" when duplicates exist
+        in a conversion_box
+    :param ordered_box: Preserve the order of keys entered into the box
+    """
+
+    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
+                                 'from_yaml', 'from_json']
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Due to the way pickling works in python 3, we need to make sure
+        the box config is created as early as possible.
+        """
+        obj = super(Box, cls).__new__(cls, *args, **kwargs)
+        obj._box_config = _get_box_config(cls, kwargs)
+        return obj
+
+    def __init__(self, *args, **kwargs):
+        self._box_config = _get_box_config(self.__class__, kwargs)
+        if self._box_config['ordered_box']:
+            self._box_config['__ordered_box_values'] = []
+        if (not self._box_config['conversion_box'] and
+                self._box_config['box_duplicates'] != "ignore"):
+            raise BoxError('box_duplicates are only for conversion_boxes')
+        if len(args) == 1:
+            if isinstance(args[0], basestring):
+                raise ValueError('Cannot extrapolate Box from string')
+            if isinstance(args[0], Mapping):
+                for k, v in args[0].items():
+                    if v is args[0]:
+                        v = self
+                    self[k] = v
+                    self.__add_ordered(k)
+            elif isinstance(args[0], Iterable):
+                for k, v in args[0]:
+                    self[k] = v
+                    self.__add_ordered(k)
+
+            else:
+                raise ValueError('First argument must be mapping or iterable')
+        elif args:
+            raise TypeError('Box expected at most 1 argument, '
+                            'got {0}'.format(len(args)))
+
+        box_it = kwargs.pop('box_it_up', False)
+        for k, v in kwargs.items():
+            if args and isinstance(args[0], Mapping) and v is args[0]:
+                v = self
+            self[k] = v
+            self.__add_ordered(k)
+
+        if (self._box_config['frozen_box'] or box_it or
+                self._box_config['box_duplicates'] != 'ignore'):
+            self.box_it_up()
+
+        self._box_config['__created'] = True
+
+    def __add_ordered(self, key):
+        if (self._box_config['ordered_box'] and
+                key not in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].append(key)
+
+    def box_it_up(self):
+        """
+        Perform value lookup for all items in current dictionary,
+        generating all sub Box objects, while also running `box_it_up` on
+        any of those sub box objects.
+        """
+        for k in self:
+            _conversion_checks(k, self.keys(), self._box_config,
+                               check_only=True)
+            if self[k] is not self and hasattr(self[k], 'box_it_up'):
+                self[k].box_it_up()
+
+    def __hash__(self):
+        if self._box_config['frozen_box']:
+            hashing = 54321
+            for item in self.items():
+                hashing ^= hash(item)
+            return hashing
+        raise TypeError("unhashable type: 'Box'")
+
+    def __dir__(self):
+        allowed = string.ascii_letters + string.digits + '_'
+        kill_camel = self._box_config['camel_killer_box']
+        items = set(dir(dict) + ['to_dict', 'to_json',
+                                 'from_json', 'box_it_up'])
+        # Only show items accessible by dot notation
+        for key in self.keys():
+            key = _safe_key(key)
+            if (' ' not in key and key[0] not in string.digits and
+                    key not in kwlist):
+                for letter in key:
+                    if letter not in allowed:
+                        break
+                else:
+                    items.add(key)
+
+        for key in self.keys():
+            key = _safe_key(key)
+            if key not in items:
+                if self._box_config['conversion_box']:
+                    key = _safe_attr(key, camel_killer=kill_camel,
+                                     replacement_char=self._box_config[
+                                         'box_safe_prefix'])
+                    if key:
+                        items.add(key)
+            if kill_camel:
+                snake_key = _camel_killer(key)
+                if snake_key:
+                    items.remove(key)
+                    items.add(snake_key)
+
+        if yaml_support:
+            items.add('to_yaml')
+            items.add('from_yaml')
+
+        return list(items)
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            if isinstance(default, dict) and not isinstance(default, Box):
+                return Box(default)
+            if isinstance(default, list) and not isinstance(default, BoxList):
+                return BoxList(default)
+            return default
+
+    def copy(self):
+        return self.__class__(super(self.__class__, self).copy())
+
+    def __copy__(self):
+        return self.__class__(super(self.__class__, self).copy())
+
+    def __deepcopy__(self, memodict=None):
+        out = self.__class__()
+        memodict = memodict or {}
+        memodict[id(self)] = out
+        for k, v in self.items():
+            out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict)
+        return out
+
+    def __setstate__(self, state):
+        self._box_config = state['_box_config']
+        self.__dict__.update(state)
+
+    def __getitem__(self, item, _ignore_default=False):
+        try:
+            value = super(Box, self).__getitem__(item)
+        except KeyError as err:
+            if item == '_box_config':
+                raise BoxKeyError('_box_config should only exist as an '
+                                  'attribute and is never defaulted')
+            if self._box_config['default_box'] and not _ignore_default:
+                return self.__get_default(item)
+            raise BoxKeyError(str(err))
+        else:
+            return self.__convert_and_store(item, value)
+
+    def keys(self):
+        if self._box_config['ordered_box']:
+            return self._box_config['__ordered_box_values']
+        return super(Box, self).keys()
+
+    def values(self):
+        return [self[x] for x in self.keys()]
+
+    def items(self):
+        return [(x, self[x]) for x in self.keys()]
+
+    def __get_default(self, item):
+        default_value = self._box_config['default_box_attr']
+        if default_value is self.__class__:
+            return self.__class__(__box_heritage=(self, item),
+                                  **self.__box_config())
+        elif isinstance(default_value, Callable):
+            return default_value()
+        elif hasattr(default_value, 'copy'):
+            return default_value.copy()
+        return default_value
+
+    def __box_config(self):
+        out = {}
+        for k, v in self._box_config.copy().items():
+            if not k.startswith("__"):
+                out[k] = v
+        return out
+
+    def __convert_and_store(self, item, value):
+        if item in self._box_config['__converted']:
+            return value
+        if isinstance(value, dict) and not isinstance(value, Box):
+            value = self.__class__(value, __box_heritage=(self, item),
+                                   **self.__box_config())
+            self[item] = value
+        elif isinstance(value, list) and not isinstance(value, BoxList):
+            if self._box_config['frozen_box']:
+                value = _recursive_tuples(value, self.__class__,
+                                          recreate_tuples=self._box_config[
+                                              'modify_tuples_box'],
+                                          __box_heritage=(self, item),
+                                          **self.__box_config())
+            else:
+                value = BoxList(value, __box_heritage=(self, item),
+                                box_class=self.__class__,
+                                **self.__box_config())
+            self[item] = value
+        elif (self._box_config['modify_tuples_box'] and
+              isinstance(value, tuple)):
+            value = _recursive_tuples(value, self.__class__,
+                                      recreate_tuples=True,
+                                      __box_heritage=(self, item),
+                                      **self.__box_config())
+            self[item] = value
+        self._box_config['__converted'].add(item)
+        return value
+
+    def __create_lineage(self):
+        if (self._box_config['__box_heritage'] and
+                self._box_config['__created']):
+            past, item = self._box_config['__box_heritage']
+            if not past[item]:
+                past[item] = self
+            self._box_config['__box_heritage'] = None
+
+    def __getattr__(self, item):
+        try:
+            try:
+                value = self.__getitem__(item, _ignore_default=True)
+            except KeyError:
+                value = object.__getattribute__(self, item)
+        except AttributeError as err:
+            if item == "__getstate__":
+                raise AttributeError(item)
+            if item == '_box_config':
+                raise BoxError('_box_config key must exist')
+            kill_camel = self._box_config['camel_killer_box']
+            if self._box_config['conversion_box'] and item:
+                k = _conversion_checks(item, self.keys(), self._box_config)
+                if k:
+                    return self.__getitem__(k)
+            if kill_camel:
+                for k in self.keys():
+                    if item == _camel_killer(k):
+                        return self.__getitem__(k)
+            if self._box_config['default_box']:
+                return self.__get_default(item)
+            raise BoxKeyError(str(err))
+        else:
+            if item == '_box_config':
+                return value
+            return self.__convert_and_store(item, value)
+
+    def __setitem__(self, key, value):
+        if (key != '_box_config' and self._box_config['__created'] and
+                self._box_config['frozen_box']):
+            raise BoxError('Box is frozen')
+        if self._box_config['conversion_box']:
+            _conversion_checks(key, self.keys(), self._box_config,
+                               check_only=True, pre_check=True)
+        super(Box, self).__setitem__(key, value)
+        self.__add_ordered(key)
+        self.__create_lineage()
+
+    def __setattr__(self, key, value):
+        if (key != '_box_config' and self._box_config['frozen_box'] and
+                self._box_config['__created']):
+            raise BoxError('Box is frozen')
+        if key in self._protected_keys:
+            raise AttributeError("Key name '{0}' is protected".format(key))
+        if key == '_box_config':
+            return object.__setattr__(self, key, value)
+        try:
+            object.__getattribute__(self, key)
+        except (AttributeError, UnicodeEncodeError):
+            if (key not in self.keys() and
+                    (self._box_config['conversion_box'] or
+                     self._box_config['camel_killer_box'])):
+                if self._box_config['conversion_box']:
+                    k = _conversion_checks(key, self.keys(),
+                                           self._box_config)
+                    self[key if not k else k] = value
+                elif self._box_config['camel_killer_box']:
+                    for each_key in self:
+                        if key == _camel_killer(each_key):
+                            self[each_key] = value
+                            break
+            else:
+                self[key] = value
+        else:
+            object.__setattr__(self, key, value)
+        self.__add_ordered(key)
+        self.__create_lineage()
+
+    def __delitem__(self, key):
+        if self._box_config['frozen_box']:
+            raise BoxError('Box is frozen')
+        super(Box, self).__delitem__(key)
+        if (self._box_config['ordered_box'] and
+                key in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].remove(key)
+
+    def __delattr__(self, item):
+        if self._box_config['frozen_box']:
+            raise BoxError('Box is frozen')
+        if item == '_box_config':
+            raise BoxError('"_box_config" is protected')
+        if item in self._protected_keys:
+            raise AttributeError("Key name '{0}' is protected".format(item))
+        try:
+            object.__getattribute__(self, item)
+        except AttributeError:
+            del self[item]
+        else:
+            object.__delattr__(self, item)
+        if (self._box_config['ordered_box'] and
+                item in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].remove(item)
+
+    def pop(self, key, *args):
+        if args:
+            if len(args) != 1:
+                raise BoxError('pop() takes only one optional'
+                               ' argument "default"')
+            try:
+                item = self[key]
+            except KeyError:
+                return args[0]
+            else:
+                del self[key]
+                return item
+        try:
+            item = self[key]
+        except KeyError:
+            raise BoxKeyError('{0}'.format(key))
+        else:
+            del self[key]
+            return item
+
+    def clear(self):
+        self._box_config['__ordered_box_values'] = []
+        super(Box, self).clear()
+
+    def popitem(self):
+        try:
+            key = next(self.__iter__())
+        except StopIteration:
+            raise BoxKeyError('Empty box')
+        return key, self.pop(key)
+
+    def __repr__(self):
+        return '<Box: {0}>'.format(str(self.to_dict()))
+
+    def __str__(self):
+        return str(self.to_dict())
+
+    def __iter__(self):
+        for key in self.keys():
+            yield key
+
+    def __reversed__(self):
+        for key in reversed(list(self.keys())):
+            yield key
+
+    def to_dict(self):
+        """
+        Turn the Box and sub Boxes back into a native
+        python dictionary.
+
+        :return: python dictionary of this Box
+        """
+        out_dict = dict(self)
+        for k, v in out_dict.items():
+            if v is self:
+                out_dict[k] = out_dict
+            elif hasattr(v, 'to_dict'):
+                out_dict[k] = v.to_dict()
+            elif hasattr(v, 'to_list'):
+                out_dict[k] = v.to_list()
+        return out_dict
+
+    def update(self, item=None, **kwargs):
+        if not item:
+            item = kwargs
+        iter_over = item.items() if hasattr(item, 'items') else item
+        for k, v in iter_over:
+            if isinstance(v, dict):
+                # Box objects must be created in case they are already
+                # in the `converted` box_config set
+                v = self.__class__(v)
+                if k in self and isinstance(self[k], dict):
+                    self[k].update(v)
+                    continue
+            if isinstance(v, list):
+                v = BoxList(v)
+            try:
+                self.__setattr__(k, v)
+            except (AttributeError, TypeError):
+                self.__setitem__(k, v)
+
+    def setdefault(self, item, default=None):
+        if item in self:
+            return self[item]
+
+        if isinstance(default, dict):
+            default = self.__class__(default)
+        if isinstance(default, list):
+            default = BoxList(default)
+        self[item] = default
+        return default
+
+    def to_json(self, filename=None,
+                encoding="utf-8", errors="strict", **json_kwargs):
+        """
+        Transform the Box object into a JSON string.
+
+        :param filename: If provided will save to file
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param json_kwargs: additional arguments to pass to json.dump(s)
+        :return: string of JSON or return of `json.dump`
+        """
+        return _to_json(self.to_dict(), filename=filename,
+                        encoding=encoding, errors=errors, **json_kwargs)
+
+    @classmethod
+    def from_json(cls, json_string=None, filename=None,
+                  encoding="utf-8", errors="strict", **kwargs):
+        """
+        Transform a json object string into a Box object. If the incoming
+        json is a list, you must use BoxList.from_json.
+
+        :param json_string: string to pass to `json.loads`
+        :param filename: filename to open and pass to `json.load`
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param kwargs: parameters to pass to `Box()` or `json.loads`
+        :return: Box object from json data
+        """
+        bx_args = {}
+        for arg in kwargs.copy():
+            if arg in BOX_PARAMETERS:
+                bx_args[arg] = kwargs.pop(arg)
+
+        data = _from_json(json_string, filename=filename,
+                          encoding=encoding, errors=errors, **kwargs)
+
+        if not isinstance(data, dict):
+            raise BoxError('json data not returned as a dictionary, '
+                           'but rather a {0}'.format(type(data).__name__))
+        return cls(data, **bx_args)
+
+    if yaml_support:
+        def to_yaml(self, filename=None, default_flow_style=False,
+                    encoding="utf-8", errors="strict",
+                    **yaml_kwargs):
+            """
+            Transform the Box object into a YAML string.
+
+            :param filename:  If provided will save to file
+            :param default_flow_style: False will recursively dump dicts
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param yaml_kwargs: additional arguments to pass to yaml.dump
+            :return: string of YAML or return of `yaml.dump`
+            """
+            return _to_yaml(self.to_dict(), filename=filename,
+                            default_flow_style=default_flow_style,
+                            encoding=encoding, errors=errors, **yaml_kwargs)
+
+        @classmethod
+        def from_yaml(cls, yaml_string=None, filename=None,
+                      encoding="utf-8", errors="strict",
+                      loader=yaml.SafeLoader, **kwargs):
+            """
+            Transform a yaml object string into a Box object.
+
+            :param yaml_string: string to pass to `yaml.load`
+            :param filename: filename to open and pass to `yaml.load`
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param loader: YAML Loader, defaults to SafeLoader
+            :param kwargs: parameters to pass to `Box()` or `yaml.load`
+            :return: Box object from yaml data
+            """
+            bx_args = {}
+            for arg in kwargs.copy():
+                if arg in BOX_PARAMETERS:
+                    bx_args[arg] = kwargs.pop(arg)
+
+            data = _from_yaml(yaml_string=yaml_string, filename=filename,
+                              encoding=encoding, errors=errors,
+                              Loader=loader, **kwargs)
+            if not isinstance(data, dict):
+                raise BoxError('yaml data not returned as a dictionary, '
+                               'but rather a {0}'.format(type(data).__name__))
+            return cls(data, **bx_args)
+
+
+class BoxList(list):
+    """
+    Drop in replacement of list, that converts added objects to Box or BoxList
+    objects as necessary.
+    """
+
+    def __init__(self, iterable=None, box_class=Box, **box_options):
+        self.box_class = box_class
+        self.box_options = box_options
+        self.box_org_ref = self.box_org_ref = id(iterable) if iterable else 0
+        if iterable:
+            for x in iterable:
+                self.append(x)
+        if box_options.get('frozen_box'):
+            def frozen(*args, **kwargs):
+                raise BoxError('BoxList is frozen')
+
+            for method in ['append', 'extend', 'insert', 'pop',
+                           'remove', 'reverse', 'sort']:
+                self.__setattr__(method, frozen)
+
+    def __delitem__(self, key):
+        if self.box_options.get('frozen_box'):
+            raise BoxError('BoxList is frozen')
+        super(BoxList, self).__delitem__(key)
+
+    def __setitem__(self, key, value):
+        if self.box_options.get('frozen_box'):
+            raise BoxError('BoxList is frozen')
+        super(BoxList, self).__setitem__(key, value)
+
+    def append(self, p_object):
+        if isinstance(p_object, dict):
+            try:
+                p_object = self.box_class(p_object, **self.box_options)
+            except AttributeError as err:
+                if 'box_class' in self.__dict__:
+                    raise err
+        elif isinstance(p_object, list):
+            try:
+                p_object = (self if id(p_object) == self.box_org_ref else
+                            BoxList(p_object))
+            except AttributeError as err:
+                if 'box_org_ref' in self.__dict__:
+                    raise err
+        super(BoxList, self).append(p_object)
+
+    def extend(self, iterable):
+        for item in iterable:
+            self.append(item)
+
+    def insert(self, index, p_object):
+        if isinstance(p_object, dict):
+            p_object = self.box_class(p_object, **self.box_options)
+        elif isinstance(p_object, list):
+            p_object = (self if id(p_object) == self.box_org_ref else
+                        BoxList(p_object))
+        super(BoxList, self).insert(index, p_object)
+
+    def __repr__(self):
+        return "<BoxList: {0}>".format(self.to_list())
+
+    def __str__(self):
+        return str(self.to_list())
+
+    def __copy__(self):
+        return BoxList((x for x in self),
+                       self.box_class,
+                       **self.box_options)
+
+    def __deepcopy__(self, memodict=None):
+        out = self.__class__()
+        memodict = memodict or {}
+        memodict[id(self)] = out
+        for k in self:
+            out.append(copy.deepcopy(k))
+        return out
+
+    def __hash__(self):
+        if self.box_options.get('frozen_box'):
+            hashing = 98765
+            hashing ^= hash(tuple(self))
+            return hashing
+        raise TypeError("unhashable type: 'BoxList'")
+
+    def to_list(self):
+        new_list = []
+        for x in self:
+            if x is self:
+                new_list.append(new_list)
+            elif isinstance(x, Box):
+                new_list.append(x.to_dict())
+            elif isinstance(x, BoxList):
+                new_list.append(x.to_list())
+            else:
+                new_list.append(x)
+        return new_list
+
+    def to_json(self, filename=None,
+                encoding="utf-8", errors="strict",
+                multiline=False, **json_kwargs):
+        """
+        Transform the BoxList object into a JSON string.
+
+        :param filename: If provided will save to file
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param multiline: Put each item in list onto its own line
+        :param json_kwargs: additional arguments to pass to json.dump(s)
+        :return: string of JSON or return of `json.dump`
+        """
+        if filename and multiline:
+            lines = [_to_json(item, filename=False, encoding=encoding,
+                              errors=errors, **json_kwargs) for item in self]
+            with open(filename, 'w', encoding=encoding, errors=errors) as f:
+                f.write("\n".join(lines).decode('utf-8') if
+                        sys.version_info < (3, 0) else "\n".join(lines))
+        else:
+            return _to_json(self.to_list(), filename=filename,
+                            encoding=encoding, errors=errors, **json_kwargs)
+
+    @classmethod
+    def from_json(cls, json_string=None, filename=None, encoding="utf-8",
+                  errors="strict", multiline=False, **kwargs):
+        """
+        Transform a json object string into a BoxList object. If the incoming
+        json is a dict, you must use Box.from_json.
+
+        :param json_string: string to pass to `json.loads`
+        :param filename: filename to open and pass to `json.load`
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param multiline: One object per line
+        :param kwargs: parameters to pass to `Box()` or `json.loads`
+        :return: BoxList object from json data
+        """
+        bx_args = {}
+        for arg in kwargs.copy():
+            if arg in BOX_PARAMETERS:
+                bx_args[arg] = kwargs.pop(arg)
+
+        data = _from_json(json_string, filename=filename, encoding=encoding,
+                          errors=errors, multiline=multiline, **kwargs)
+
+        if not isinstance(data, list):
+            raise BoxError('json data not returned as a list, '
+                           'but rather a {0}'.format(type(data).__name__))
+        return cls(data, **bx_args)
+
+    if yaml_support:
+        def to_yaml(self, filename=None, default_flow_style=False,
+                    encoding="utf-8", errors="strict",
+                    **yaml_kwargs):
+            """
+            Transform the BoxList object into a YAML string.
+
+            :param filename:  If provided will save to file
+            :param default_flow_style: False will recursively dump dicts
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param yaml_kwargs: additional arguments to pass to yaml.dump
+            :return: string of YAML or return of `yaml.dump`
+            """
+            return _to_yaml(self.to_list(), filename=filename,
+                            default_flow_style=default_flow_style,
+                            encoding=encoding, errors=errors, **yaml_kwargs)
+
+        @classmethod
+        def from_yaml(cls, yaml_string=None, filename=None,
+                      encoding="utf-8", errors="strict",
+                      loader=yaml.SafeLoader,
+                      **kwargs):
+            """
+            Transform a yaml object string into a BoxList object.
+
+            :param yaml_string: string to pass to `yaml.load`
+            :param filename: filename to open and pass to `yaml.load`
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param loader: YAML Loader, defaults to SafeLoader
+            :param kwargs: parameters to pass to `BoxList()` or `yaml.load`
+            :return: BoxList object from yaml data
+            """
+            bx_args = {}
+            for arg in kwargs.copy():
+                if arg in BOX_PARAMETERS:
+                    bx_args[arg] = kwargs.pop(arg)
+
+            data = _from_yaml(yaml_string=yaml_string, filename=filename,
+                              encoding=encoding, errors=errors,
+                              Loader=loader, **kwargs)
+            if not isinstance(data, list):
+                raise BoxError('yaml data not returned as a list, '
+                               'but rather a {0}'.format(type(data).__name__))
+            return cls(data, **bx_args)
+
+    def box_it_up(self):
+        for v in self:
+            if hasattr(v, 'box_it_up') and v is not self:
+                v.box_it_up()
+
+
+class ConfigBox(Box):
+    """
+    Modified box object to add object transforms.
+
+    Allows for built-in transforms like:
+
+    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')
+
+    cns.bool('my_bool') # True
+    cns.int('my_int') # 5
+    cns.list('my_list', mod=lambda x: int(x)) # [5, 4, 3, 3, 2]
+    """
+
+    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
+                                 'list', 'getboolean', 'to_json', 'to_yaml',
+                                 'getfloat', 'getint',
+                                 'from_json', 'from_yaml']
+
+    def __getattr__(self, item):
+        """Config file keys are stored in lower case, be a little more
+        loosey goosey"""
+        try:
+            return super(ConfigBox, self).__getattr__(item)
+        except AttributeError:
+            return super(ConfigBox, self).__getattr__(item.lower())
+
+    def __dir__(self):
+        return super(ConfigBox, self).__dir__() + ['bool', 'int', 'float',
+                                                   'list', 'getboolean',
+                                                   'getfloat', 'getint']
+
+    def bool(self, item, default=None):
+        """ Return value of key as a boolean
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: approximated bool of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+
+        if isinstance(item, (bool, int)):
+            return bool(item)
+
+        if (isinstance(item, str) and
+                item.lower() in ('n', 'no', 'false', 'f', '0')):
+            return False
+
+        return True if item else False
+
+    def int(self, item, default=None):
+        """ Return value of key as an int
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: int of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return int(item)
+
+    def float(self, item, default=None):
+        """ Return value of key as a float
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: float of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return float(item)
+
+    def list(self, item, default=None, spliter=",", strip=True, mod=None):
+        """ Return value of key as a list
+
+        :param item: key of value to transform
+        :param mod: function to map against list
+        :param default: value to return if item does not exist
+        :param spliter: character to split str on
+        :param strip: clean the list with the `strip`
+        :return: list of items
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        if strip:
+            item = item.lstrip('[').rstrip(']')
+        out = [x.strip() if strip else x for x in item.split(spliter)]
+        if mod:
+            return list(map(mod, out))
+        return out
+
+    # loose configparser compatibility
+
+    def getboolean(self, item, default=None):
+        return self.bool(item, default)
+
+    def getint(self, item, default=None):
+        return self.int(item, default)
+
+    def getfloat(self, item, default=None):
+        return self.float(item, default)
+
+    def __repr__(self):
+        return '<ConfigBox: {0}>'.format(str(self.to_dict()))
+
+
+class SBox(Box):
+    """
+    ShorthandBox (SBox) allows for
+    property access of `dict` `json` and `yaml`
+    """
+    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
+                                 'json', 'yaml', 'from_yaml', 'from_json',
+                                 'dict']
+
+    @property
+    def dict(self):
+        return self.to_dict()
+
+    @property
+    def json(self):
+        return self.to_json()
+
+    if yaml_support:
+        @property
+        def yaml(self):
+            return self.to_yaml()
+
+    def __repr__(self):
+        return '<ShorthandBox: {0}>'.format(str(self.to_dict()))

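The vendored Box class above exists so that nested configuration dictionaries can be read with dot notation in the rest of the code. A small illustrative example of the behaviour relied on later (the values are made up for illustration):

    from lcnn.box import Box

    cfg = Box({"model": {"backbone": "fasterrcnn_resnet50", "depth": 4}})
    print(cfg.model.backbone)                    # nested dicts become Boxes on access
    cfg.model.num_stacks = 1                     # attribute assignment writes through to the dict
    print(cfg.to_dict()["model"]["num_stacks"])  # 1
    print(cfg.to_yaml())                         # YAML round-trip when PyYAML or ruamel.yaml is installed
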
+ 9 - 0
lcnn/config.py

@@ -0,0 +1,9 @@
+import numpy as np
+
+from lcnn.box import Box
+
+# C is a dict storing all the configuration
+C = Box()
+
+# shortcut for C.model
+M = Box()

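C and M start out as empty Boxes; the training entry point is expected to populate them from the YAML above. The exact call site is not part of this diff, so the following is only a sketch of the assumed usage:

    from lcnn.config import C, M

    C.update(C.from_yaml(filename="config/wireframe.yaml"))  # load the full configuration into C
    M.update(C.model)                                        # M then acts as the shortcut for C.model

    print(M.backbone)           # fasterrcnn_resnet50
    print(M.loss_weight.boxes)  # 1.0
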
+ 111 - 14
models/dataset_tool.py → lcnn/dataset_tool.py

@@ -173,7 +173,7 @@ def read_masks_from_txt(label_path, shape):
     return labels, masks
 
 
-def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
     """
     Compute the bounding boxes around the provided masks.
 
@@ -215,6 +215,58 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
     return bounding_boxes
 
 
+def line_boxes_faster(target):
+    boxs = []
+    lpre = target["lpre"].cpu().numpy() * 4
+    vecl_target = target["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+
+            if a[1] > b[1]:
+                ymax = a[1] + 10
+                ymin = b[1] - 10
+            else:
+                ymin = a[1] - 10
+                ymax = b[1] + 10
+            if a[0] > b[0]:
+                xmax = a[0] + 10
+                xmin = b[0] - 10
+            else:
+                xmin = a[0] - 10
+                xmax = b[0] + 10
+            boxs.append([max(0, ymin), max(0, xmin), min(512, ymax), min(512, xmax)])
+
+    return torch.tensor(boxs)
+
+
+# Convert each line segment into [start point, end point] form; input shape [num, 2, 2]
+def a_to_b(line):
+    result_pairs = []
+    for a, b in line:
+        min_x = min(a[0], b[0])
+        min_y = min(a[1], b[1])
+        new_top_left_x = max((min_x - 10), 0)
+        new_top_left_y = max((min_y - 10), 0)
+        dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
+        dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
+
+        # choose the start point by distance and order the pair accordingly
+        if dist_a <= dist_b:  # a is closer to (or equidistant from) the new top-left corner
+            result_pairs.append([a, b])  # a becomes the start point, b the end point
+        else:  # b is closer to the new top-left corner
+            result_pairs.append([b, a])  # b becomes the start point, a the end point
+    result_tensor = torch.stack([torch.stack(row) for row in result_pairs])
+    return result_tensor
+
+
 def read_polygon_points_wire(lbl_path, shape):
     """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
     polygon_points = []
@@ -252,24 +304,69 @@ def read_masks_from_pixels_wire(lbl_path, shape):
         lines = json.load(reader)
         mask_points = []
         for line in lines["segmentations"]:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line["data"]
+            # mask = torch.zeros((h, w), dtype=torch.uint8)
+            # parts = line["data"]
             # print(f'parts:{parts}')
             cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
             labels.append(cls)
-            x_array = parts[0::2]
-            y_array = parts[1::2]
+            # x_array = parts[0::2]
+            # y_array = parts[1::2]
+            # 
+            # for x, y in zip(x_array, y_array):
+            #     x = float(x)
+            #     y = float(y)
+            #     mask_points.append((int(y * h), int(x * w)))
+
+            # for p in mask_points:
+            #     mask[p] = 1
+            # masks.append(mask)
+    reader.close()
+    return labels
 
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
 
-            for p in mask_points:
-                mask[p] = 1
-            masks.append(mask)
+def read_lable_keypoint(lbl_path):
+    """判断线段的起点终点, 起点 lable=0, 终点 lable=1"""
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = json.load(reader)
+        aa = lines["wires"][0]["line_pos_coords"]["content"]
+
+        result_pairs = []
+        for a, b in aa:
+            min_x = min(a[0], b[0])
+            min_y = min(a[1], b[1])
+
+            # define the new top-left corner position
+            new_top_left_x = max((min_x - 10), 0)
+            new_top_left_y = max((min_y - 10), 0)
+
+            # Step 2: squared distance from each endpoint to the new top-left corner (avoids floating-point error)
+            dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
+            dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
+
+            # Step 3 & 4: 根据距离选择起点并设置标签
+            if dist_a <= dist_b:  # 如果a点离新左上角更近或两者距离相等
+                result_pairs.append([a, b])  # 将a设为起点,b为终点
+            else:  # 如果b点离新左上角更近
+                result_pairs.append([b, a])  # 将b设为起点,a为终点
+
+            # x_ = abs(a[0] - b[0])
+            # y_ = abs(a[1] - b[1])
+            # if x_ > y_:  # the point with the smaller x is closer to the top-left corner
+            #     if a[0] < b[0]:  # treated as the start point, label=0
+            #         label = [0, 1]
+            #     else:
+            #         label = [1, 0]
+            # else:  # the point with the larger x is the start point
+            #     if a[0] > b[0]:  # treated as the start point, label=0
+            #         label = [0, 1]
+            #     else:
+            #         label = [1, 0]
+            # labels.append(label)
+        # print(result_pairs )
     reader.close()
-    return labels, masks
+    return labels
 
 
 def adjacency_matrix(n, link):  # adjacency matrix
@@ -278,4 +375,4 @@ def adjacency_matrix(n, link):  # adjacency matrix
     if len(link) > 0:
         mat[link[:, 0], link[:, 1]] = 1
         mat[link[:, 1], link[:, 0]] = 1
-    return mat
+    return mat

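line_boxes_faster above turns the positive lines of a meta dict into axis-aligned boxes: endpoint coordinates are multiplied by 4 (lpre is stored at quarter resolution), padded by 10 px, and clamped to the 512-pixel canvas; each returned row follows the [ymin, xmin, ymax, xmax] naming used inside the function. A minimal, self-contained sketch with made-up coordinates:

    import torch
    from lcnn.dataset_tool import line_boxes_faster

    # illustrative meta dict in the shape read_target() produces
    meta = {
        "lpre": torch.tensor([[[10.0, 20.0], [30.0, 60.0]]]),  # one line: two endpoints at quarter resolution
        "lpre_label": torch.tensor([1.0]),                      # 1 = positive sample, 0 = negative
    }
    print(line_boxes_faster(meta))  # tensor([[ 70.,  30., 250., 130.]])
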
+ 294 - 0
lcnn/datasets.py

@@ -0,0 +1,294 @@
+# import glob
+# import json
+# import math
+# import os
+# import random
+#
+# import numpy as np
+# import numpy.linalg as LA
+# import torch
+# from skimage import io
+# from torch.utils.data import Dataset
+# from torch.utils.data.dataloader import default_collate
+#
+# from lcnn.config import M
+#
+# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
+#
+#
+# class WireframeDataset(Dataset):
+#     def __init__(self, rootdir, split):
+#         self.rootdir = rootdir
+#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
+#         filelist.sort()
+#
+#         # print(f"n{split}:", len(filelist))
+#         self.split = split
+#         self.filelist = filelist
+#
+#     def __len__(self):
+#         return len(self.filelist)
+#
+#     def __getitem__(self, idx):
+#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
+#         image = io.imread(iname).astype(float)[:, :, :3]
+#         if "a1" in self.filelist[idx]:
+#             image = image[:, ::-1, :]
+#         image = (image - M.image.mean) / M.image.stddev
+#         image = np.rollaxis(image, 2).copy()
+#
+#         with np.load(self.filelist[idx]) as npz:
+#             target = {
+#                 name: torch.from_numpy(npz[name]).float()
+#                 for name in ["jmap", "joff", "lmap"]
+#             }
+#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
+#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
+#             npos, nneg = len(lpos), len(lneg)
+#             lpre = np.concatenate([lpos, lneg], 0)
+#             for i in range(len(lpre)):
+#                 if random.random() > 0.5:
+#                     lpre[i] = lpre[i, ::-1]
+#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+#             feat = [
+#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+#                 ldir * M.use_slop,
+#                 lpre[:, :, 2],
+#             ]
+#             feat = np.concatenate(feat, 1)
+#             meta = {
+#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
+#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
+#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
+#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
+#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
+#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
+#                 "lpre_feat": torch.from_numpy(feat),
+#             }
+#
+#         labels = []
+#         labels = read_masks_from_pixels_wire(iname, (512, 512))
+#         # if self.target_type == 'polygon':
+#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
+#         # elif self.target_type == 'pixel':
+#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
+#
+#         target["labels"] = torch.stack(labels)
+#         target["boxes"] = line_boxes_faster(meta)
+#
+#
+#         return torch.from_numpy(image).float(), meta, target
+#
+#     def adjacency_matrix(self, n, link):
+#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+#         link = torch.from_numpy(link)
+#         if len(link) > 0:
+#             mat[link[:, 0], link[:, 1]] = 1
+#             mat[link[:, 1], link[:, 0]] = 1
+#         return mat
+#
+#
+# def collate(batch):
+#     return (
+#         default_collate([b[0] for b in batch]),
+#         [b[1] for b in batch],
+#         default_collate([b[2] for b in batch]),
+#     )
+
+from torch.utils.data.dataset import T_co
+
+from .models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+
+        img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, meta, target, target_b
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer than the cap are available, take as many as there are
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates concatenated
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        meta = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+
+        }
+
+        target = {
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_lines, 512, 512]
+        target_b = {}
+
+        # target_b["image_id"] = torch.tensor(item)
+
+        target_b["labels"] = torch.stack(labels)
+        target_b["boxes"] = line_boxes_faster(meta)
+
+        return meta, target, target_b
+
+
+    def show(self, idx):
+        image, meta, target, target_b = self.__getitem__(idx)  # __getitem__ returns four values
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] carry no particular ordering
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (juncs[i] == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target_b["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)
+            plt.show()
+
+        junc = meta['junc_coords'].cpu().numpy() * 4
+        jtyp = meta['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = meta["lpre"].cpu().numpy() * 4
+        vecl_target = meta["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+    def show_img(self, img_path):
+        pass
+
+
+def collate(batch):
+    return (
+        default_collate([b[0] for b in batch]),
+        [b[1] for b in batch],
+        default_collate([b[2] for b in batch]),
+        [b[3] for b in batch],
+    )
+
+
+# if __name__ == '__main__':
+#     path = r"D:\python\PycharmProjects\data"
+#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
+#     dataset.show(0)
+
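A minimal usage sketch for the new dataset together with the custom `collate` above; the dataset root is hypothetical and is assumed to follow the `images/<split>` and `labels/<split>` layout implied by the constructor:

    from torch.utils.data import DataLoader

    # "path/to/data" is a placeholder, not part of this commit
    dataset = WireframeDataset(dataset_path="path/to/data", dataset_type="train")
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate)
    imgs, metas, targets, targets_b = next(iter(loader))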
+

+ 209 - 0
lcnn/metric.py

@@ -0,0 +1,209 @@
+import numpy as np
+import numpy.linalg as LA
+import matplotlib.pyplot as plt
+
+from lcnn.utils import argsort2d
+
+DX = [0, 0, 1, -1, 1, 1, -1, -1]
+DY = [1, -1, 0, 0, 1, -1, 1, -1]
+
+
+def ap(tp, fp):
+    recall = tp
+    precision = tp / np.maximum(tp + fp, 1e-9)
+
+    recall = np.concatenate(([0.0], recall, [1.0]))
+    precision = np.concatenate(([0.0], precision, [0.0]))
+
+    for i in range(precision.size - 1, 0, -1):
+        precision[i - 1] = max(precision[i - 1], precision[i])
+    i = np.where(recall[1:] != recall[:-1])[0]
+    return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
+
+
+def APJ(vert_pred, vert_gt, max_distance, im_ids):
+    if len(vert_pred) == 0:
+        return 0
+
+    vert_pred = np.array(vert_pred)
+    vert_gt = np.array(vert_gt)
+
+    confidence = vert_pred[:, -1]
+    idx = np.argsort(-confidence)
+    vert_pred = vert_pred[idx, :]
+    im_ids = im_ids[idx]
+    n_gt = sum(len(gt) for gt in vert_gt)
+
+    nd = len(im_ids)
+    tp, fp = np.zeros(nd, dtype=float), np.zeros(nd, dtype=float)
+    hit = [[False for _ in j] for j in vert_gt]
+
+    for i in range(nd):
+        gt_juns = vert_gt[im_ids[i]]
+        pred_juns = vert_pred[i][:-1]
+        if len(gt_juns) == 0:
+            continue
+        dists = np.linalg.norm((pred_juns[None, :] - gt_juns), axis=1)
+        choice = np.argmin(dists)
+        dist = np.min(dists)
+        if dist < max_distance and not hit[im_ids[i]][choice]:
+            tp[i] = 1
+            hit[im_ids[i]][choice] = True
+        else:
+            fp[i] = 1
+
+    tp = np.cumsum(tp) / n_gt
+    fp = np.cumsum(fp) / n_gt
+    return ap(tp, fp)
+
+
+def nms_j(heatmap, delta=1):
+    heatmap = heatmap.copy()
+    disable = np.zeros_like(heatmap, dtype=bool)
+    for x, y in argsort2d(heatmap):
+        for dx, dy in zip(DX, DY):
+            xp, yp = x + dx, y + dy
+            if not (0 <= xp < heatmap.shape[0] and 0 <= yp < heatmap.shape[1]):
+                continue
+            if heatmap[x, y] >= heatmap[xp, yp]:
+                disable[xp, yp] = True
+    heatmap[disable] *= 0.6
+    return heatmap
+
+
+def mAPJ(pred, truth, distances, im_ids):
+    return sum(APJ(pred, truth, d, im_ids) for d in distances) / len(distances) * 100
+
+
+def post_jheatmap(heatmap, offset=None, delta=1):
+    heatmap = nms_j(heatmap, delta=delta)
+    # only select the best 1000 junctions for efficiency
+    v0 = argsort2d(-heatmap)[:1000]
+    confidence = -np.sort(-heatmap.ravel())[:1000]
+    keep_id = np.where(confidence >= 1e-2)[0]
+    if len(keep_id) == 0:
+        return np.zeros((0, 3))
+
+    confidence = confidence[keep_id]
+    if offset is not None:
+        v0 = np.array([v + offset[:, v[0], v[1]] for v in v0])
+    v0 = v0[keep_id] + 0.5
+    v0 = np.hstack((v0, confidence[:, np.newaxis]))
+    return v0
+
+
+def vectorized_wireframe_2d_metric(
+    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
+):
+    # staging 1: matching
+    nd = len(vert_pred)
+    sorted_confidence = np.argsort(-vert_pred[:, -1])
+    vert_pred = vert_pred[sorted_confidence, :-1]
+    dpth_pred = dpth_pred[sorted_confidence]
+    d = np.sqrt(
+        np.sum(vert_pred ** 2, 1)[:, None]
+        + np.sum(vert_gt ** 2, 1)[None, :]
+        - 2 * vert_pred @ vert_gt.T
+    )
+    choice = np.argmin(d, 1)
+    dist = np.min(d, 1)
+
+    # staging 2: compute depth metric: SIL/L2
+    loss_L1 = loss_L2 = 0
+    hit = np.zeros_like(dpth_gt, bool)
+    SIL = np.zeros(len(dpth_pred))
+    for i in range(nd):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+            loss_L1 += abs(dpth_gt[choice[i]] - dpth_pred[i])
+            loss_L2 += (dpth_gt[choice[i]] - dpth_pred[i]) ** 2
+            a = np.maximum(-dpth_pred[i], 1e-10)
+            b = -dpth_gt[choice[i]]
+            SIL[i] = np.log(a) - np.log(b)
+        else:
+            choice[i] = -1
+
+    n = max(np.sum(hit), 1)
+    loss_L1 /= n
+    loss_L2 /= n
+    loss_SIL = np.sum(SIL ** 2) / n - np.sum(SIL) ** 2 / (n * n)
+
+    # staging 3: compute mAP for edge matching
+    edgeset = set([frozenset(e) for e in edge_gt])
+    tp = np.zeros(len(edge_pred), dtype=float)
+    fp = np.zeros(len(edge_pred), dtype=float)
+    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
+        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
+        if frozenset([choice[v0], choice[v1]]) in edgeset:
+            tp[i] = length
+        else:
+            fp[i] = length
+    total_length = LA.norm(
+        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
+    ).sum()
+    return ap(tp / total_length, fp / total_length), (loss_SIL, loss_L1, loss_L2)
+
+
+def vectorized_wireframe_3d_metric(
+    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
+):
+    # staging 1: matching
+    nd = len(vert_pred)
+    sorted_confidence = np.argsort(-vert_pred[:, -1])
+    vert_pred = np.hstack([vert_pred[:, :-1], dpth_pred[:, None]])[sorted_confidence]
+    vert_gt = np.hstack([vert_gt[:, :-1], dpth_gt[:, None]])
+    d = np.sqrt(
+        np.sum(vert_pred ** 2, 1)[:, None]
+        + np.sum(vert_gt ** 2, 1)[None, :]
+        - 2 * vert_pred @ vert_gt.T
+    )
+    choice = np.argmin(d, 1)
+    dist = np.min(d, 1)
+    hit = np.zeros_like(dpth_gt, bool)
+    for i in range(nd):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+        else:
+            choice[i] = -1
+
+    # staging 2: compute mAP for edge matching
+    edgeset = set([frozenset(e) for e in edge_gt])
+    tp = np.zeros(len(edge_pred), dtype=float)
+    fp = np.zeros(len(edge_pred), dtype=float)
+    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
+        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
+        if frozenset([choice[v0], choice[v1]]) in edgeset:
+            tp[i] = length
+        else:
+            fp[i] = length
+    total_length = LA.norm(
+        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
+    ).sum()
+
+    return ap(tp / total_length, fp / total_length)
+
+
+def msTPFP(line_pred, line_gt, threshold):
+    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
+    diff = np.minimum(
+        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
+    )
+    choice = np.argmin(diff, 1)
+    dist = np.min(diff, 1)
+    hit = np.zeros(len(line_gt), bool)
+    tp = np.zeros(len(line_pred), float)
+    fp = np.zeros(len(line_pred), float)
+    for i in range(len(line_pred)):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+            tp[i] = 1
+        else:
+            fp[i] = 1
+    return tp, fp
+
+
+def msAP(line_pred, line_gt, threshold):
+    tp, fp = msTPFP(line_pred, line_gt, threshold)
+    tp = np.cumsum(tp) / len(line_gt)
+    fp = np.cumsum(fp) / len(line_gt)
+    return ap(tp, fp)
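Hedged usage note: `msAP` above is the per-threshold building block of the structural AP metric; a typical evaluation averages it over several pixel thresholds. A small sketch (the 5/10/15 thresholds are conventional values, not taken from this diff):

    def sAP(line_pred, line_gt, thresholds=(5, 10, 15)):
        # line_pred / line_gt: [N, 2, 2] endpoint arrays in the same coordinate frame
        return {t: msAP(line_pred, line_gt, t) for t in thresholds}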

+ 9 - 0
lcnn/models/__init__.py

@@ -0,0 +1,9 @@
+# flake8: noqa
+from .hourglass_pose import hg
+from .unet import unet
+from .resnet50_pose import resnet50
+from .resnet50 import resnet501
+from .fasterrcnn_resnet50 import fasterrcnn_resnet50
+# from .dla import dla169, dla102x, dla102x2
+
+__all__ = ['hg', 'unet', 'resnet50', 'resnet501', 'fasterrcnn_resnet50']

+ 0 - 0
models/__init__.py → lcnn/models/base/__init__.py


+ 0 - 0
models/base/base_dataset.py → lcnn/models/base/base_dataset.py


+ 120 - 0
lcnn/models/fasterrcnn_resnet50.py

@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torchvision
+from typing import Dict, List, Optional, Tuple
+import torch.nn.functional as F
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+
+def get_model(num_classes):
+    # Load a pretrained ResNet-50 FPN detector
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # Number of input features of the box classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # Replace the predictor head to match the new number of classes
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+class Fasterrcnn_resnet50(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1):
+        super(Fasterrcnn_resnet50, self).__init__()
+
+        self.model = get_model(num_classes=num_classes)
+        self.backbone = self.model.backbone
+
+        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
+
+        out_channels = self.backbone.out_channels
+        resolution = self.box_roi_pool.output_size[0]
+        representation_size = 1024
+        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        # Multi-task output heads
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
+
+        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
+                                             image_std=[0.229, 0.224, 0.225])
+        images, targets = transform(x, target1)
+        x_ = self.backbone(images.tensors)
+
+        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
+        # print(f'backbone:{self.backbone}')
+        # print(f'Fasterrcnn_resnet50 x_:{x_}')
+        feature_ = x_['0']  # image feature map (highest-resolution FPN level)
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(feature_)
+            outputs.append(output)  # one prediction per head
+
+        if train_or_val == "training":
+            loss_box = self.model(x, target1)
+            return outputs, feature_, loss_box
+        else:
+            box_all = self.model(x, target1)
+            return outputs, feature_, box_all
+
+
+def fasterrcnn_resnet50(**kwargs):
+    model = Fasterrcnn_resnet50(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1)
+    )
+    return model
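How this factory is selected from the `backbone: fasterrcnn_resnet50` entry in config/wireframe.yaml is not shown in this commit; a plausible dispatch sketch (the `build_backbone` name and the mapping dict are assumptions, only the factory names come from lcnn/models/__init__.py):

    import lcnn.models

    def build_backbone(name, **kwargs):
        # maps the config string to the factories exported by lcnn.models
        factories = {
            "fasterrcnn_resnet50": lcnn.models.fasterrcnn_resnet50,
            "unet": lcnn.models.unet,
            "hg": lcnn.models.hg,
        }
        return factories[name](**kwargs)

    # backbone = build_backbone("fasterrcnn_resnet50", num_classes=5, num_stacks=1)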

+ 201 - 0
lcnn/models/hourglass_pose.py

@@ -0,0 +1,201 @@
+"""
+Hourglass network inserted in the pre-activated Resnet
+Use lr=0.01 for current version
+(c) Yichao Zhou (LCNN)
+(c) YANG, Wei
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["HourglassNet", "hg"]
+
+
+class Bottleneck2D(nn.Module):
+    expansion = 2  # channel expansion factor
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck2D, self).__init__()
+
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.bn1(x)
+        out = self.relu(out)
+        out = self.conv1(out)
+
+        out = self.bn2(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        out = self.bn3(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+
+        return out
+
+
+class Hourglass(nn.Module):
+    def __init__(self, block, num_blocks, planes, depth):
+        super(Hourglass, self).__init__()
+        self.depth = depth
+        self.block = block
+        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+    def _make_residual(self, block, num_blocks, planes):
+        layers = []
+        for i in range(0, num_blocks):
+            layers.append(block(planes * block.expansion, planes))
+        return nn.Sequential(*layers)
+
+    def _make_hour_glass(self, block, num_blocks, planes, depth):
+        hg = []
+        for i in range(depth):
+            res = []
+            for j in range(3):
+                res.append(self._make_residual(block, num_blocks, planes))
+            if i == 0:
+                res.append(self._make_residual(block, num_blocks, planes))
+            hg.append(nn.ModuleList(res))
+        return nn.ModuleList(hg)
+
+    def _hour_glass_forward(self, n, x):
+        up1 = self.hg[n - 1][0](x)
+        low1 = F.max_pool2d(x, 2, stride=2)
+        low1 = self.hg[n - 1][1](low1)
+
+        if n > 1:
+            low2 = self._hour_glass_forward(n - 1, low1)
+        else:
+            low2 = self.hg[n - 1][3](low1)
+        low3 = self.hg[n - 1][2](low2)
+        up2 = F.interpolate(low3, scale_factor=2)
+        out = up1 + up2
+        return out
+
+    def forward(self, x):
+        return self._hour_glass_forward(self.depth, x)
+
+
+class HourglassNet(nn.Module):
+    """Hourglass model from Newell et al ECCV 2016"""
+
+    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
+        super(HourglassNet, self).__init__()
+
+        self.inplanes = 64
+        self.num_feats = 128
+        self.num_stacks = num_stacks
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_residual(block, self.inplanes, 1)
+        self.layer2 = self._make_residual(block, self.inplanes, 1)
+        self.layer3 = self._make_residual(block, self.num_feats, 1)
+        self.maxpool = nn.MaxPool2d(2, stride=2)
+
+        # build hourglass modules
+        ch = self.num_feats * block.expansion
+        # vpts = []
+        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+        for i in range(num_stacks):
+            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+            res.append(self._make_residual(block, self.num_feats, num_blocks))
+            fc.append(self._make_fc(ch, ch))
+            score.append(head(ch, num_classes))
+            # vpts.append(VptsHead(ch))
+            # vpts.append(nn.Linear(ch, 9))
+            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+            # score[i].bias.data[0] += 4.6
+            # score[i].bias.data[2] += 4.6
+            if i < num_stacks - 1:
+                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+        self.hg = nn.ModuleList(hg)
+        self.res = nn.ModuleList(res)
+        self.fc = nn.ModuleList(fc)
+        self.score = nn.ModuleList(score)
+        # self.vpts = nn.ModuleList(vpts)
+        self.fc_ = nn.ModuleList(fc_)
+        self.score_ = nn.ModuleList(score_)
+
+    def _make_residual(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                )
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _make_fc(self, inplanes, outplanes):
+        bn = nn.BatchNorm2d(inplanes)
+        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+        return nn.Sequential(conv, bn, self.relu)
+
+    def forward(self, x):
+        out = []
+        # out_vps = []
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.maxpool(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        for i in range(self.num_stacks):
+            y = self.hg[i](x)
+            y = self.res[i](y)
+            y = self.fc[i](y)
+            score = self.score[i](y)
+            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
+            # pre_vpts = pre_vpts.reshape(-1, 256)
+            # vpts = self.vpts[i](x)
+            out.append(score)
+            # out_vps.append(vpts)
+            if i < self.num_stacks - 1:
+                fc_ = self.fc_[i](y)
+                score_ = self.score_[i](score)
+                x = x + fc_ + score_
+
+        return out[::-1], y  # , out_vps[::-1]
+
+
+def hg(**kwargs):
+    model = HourglassNet(
+        Bottleneck2D,
+        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1)),
+        depth=kwargs["depth"],
+        num_stacks=kwargs["num_stacks"],
+        num_blocks=kwargs["num_blocks"],
+        num_classes=kwargs["num_classes"],
+    )
+    return model

+ 276 - 0
lcnn/models/line_vectorizer.py

@@ -0,0 +1,276 @@
+import itertools
+import random
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lcnn.config import M
+
+FEATURE_DIM = 8
+
+
+class LineVectorizer(nn.Module):
+    def __init__(self, backbone):
+        super().__init__()
+        self.backbone = backbone
+
+        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
+        scale_factor = M.n_pts0 // M.n_pts1
+        if M.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(M.dim_loi, M.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(M.dim_fc, M.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(M.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, input_dict):
+        result = self.backbone(input_dict)
+        h = result["preds"]
+        x = self.fc1(result["feature"])
+        n_batch, n_channel, row, col = x.shape
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+        for i, meta in enumerate(input_dict["meta"]):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
+            )
+            # print("p.shape:", p.shape)
+            ys.append(label)
+            if input_dict["mode"] == "training" and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                    .reshape(n_channel, -1, M.n_pts0)
+                    .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, M.n_pts1 * M.dim_loi)
+        x = torch.cat([x, f], 1)
+        x = self.fc2(x.float()).flatten()
+
+        if input_dict["mode"] != "training":
+            p = torch.cat(ps)
+            s = torch.sigmoid(x)
+            b = s > 0.5
+            lines = []
+            score = []
+            for i in range(n_batch):
+                p0 = p[idx[i]: idx[i + 1]]
+                s0 = s[idx[i]: idx[i + 1]]
+                mask = b[idx[i]: idx[i + 1]]
+                p0 = p0[mask]
+                s0 = s0[mask]
+                if len(p0) == 0:
+                    lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
+                    score.append(torch.zeros([1, M.n_out_line], device=p.device))
+                else:
+                    arg = torch.argsort(s0, descending=True)
+                    p0, s0 = p0[arg], s0[arg]
+                    lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
+                    score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
+                for j in range(len(jcs[i])):
+                    if len(jcs[i][j]) == 0:
+                        jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
+                    jcs[i][j] = jcs[i][j][
+                        None, torch.arange(M.n_out_junc) % len(jcs[i][j])
+                    ]
+            result["preds"]["lines"] = torch.cat(lines)
+            result["preds"]["score"] = torch.cat(score)
+            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+            result["box"] = result['aaa']
+            del result['aaa']
+            if len(jcs[i]) > 1:
+                result["preds"]["junts"] = torch.cat(
+                    [jcs[i][1] for i in range(n_batch)]
+                )
+
+        if input_dict["mode"] != "testing":
+            y = torch.cat(ys)
+            loss = self.loss(x, y)
+            lpos_mask, lneg_mask = y, 1 - y
+            loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+            def sum_batch(x):
+                xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
+                return torch.cat(xs)
+
+            lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+            lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+            result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
+            result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
+
+        if input_dict["mode"] == "training":
+            for i in result["aaa"].keys():
+                result["losses"][0][i] = result["aaa"][i]
+            del result["preds"]
+
+        return result
+
+    def sample_lines(self, meta, jmap, joff, mode):
+        with torch.no_grad():
+            junc = meta["junc_coords"]  # [N, 2]
+            jtyp = meta["jtyp"]  # [N]
+            Lpos = meta["line_pos_idx"]
+            Lneg = meta["line_neg_idx"]
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = M.n_dyn_junc // n_type
+            N = len(junc)
+            if mode != "training":
+                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            if mode == "training":
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > M.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > M.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * M.use_cood,
+                    xyv / 128 * M.use_cood,
+                    u2v * M.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
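A tiny check of `non_maximum_suppression` above: only heatmap entries equal to their 3x3-neighbourhood maximum survive, everything else is zeroed (the example values are made up):

    import torch

    a = torch.tensor([[[0.1, 0.9, 0.2],
                       [0.3, 0.8, 0.1],
                       [0.0, 0.4, 0.7]]])
    print(non_maximum_suppression(a))
    # only the 0.9 entry remains non-zero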

+ 118 - 0
lcnn/models/multitask_learner.py

@@ -0,0 +1,118 @@
+from collections import OrderedDict, defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lcnn.config import M
+
+
+class MultitaskHead(nn.Module):
+    def __init__(self, input_channels, num_class):
+        super(MultitaskHead, self).__init__()
+        # print("input channel dimension:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        for output_channels in sum(M.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(M.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+class MultitaskLearner(nn.Module):
+    def __init__(self, backbone):
+        super(MultitaskLearner, self).__init__()
+        self.backbone = backbone
+        head_size = M.head_size
+        self.num_class = sum(sum(head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in head_size])
+
+    def forward(self, input_dict):
+        image = input_dict["image"]
+        target_b = input_dict["target_b"]
+        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # during training "aaa" holds the box losses; during validation it holds the predicted boxes
+
+        result = {"feature": feature}
+        batch, channel, row, col = outputs[0].shape
+        # print(f"batch:{batch}")
+        # print(f"channel:{channel}")
+        # print(f"row:{row}")
+        # print(f"col:{col}")
+
+        T = input_dict["target"].copy()
+        n_jtyp = T["junc_map"].shape[1]
+
+        # switch to CNHW
+        for task in ["junc_map"]:
+            T[task] = T[task].permute(1, 0, 2, 3)
+        for task in ["junc_offset"]:
+            T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+        offset = self.head_off
+        loss_weight = M.loss_weight
+        losses = []
+        for stack, output in enumerate(outputs):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            # print(f"lmap:{lmap.shape}")
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                if input_dict["mode"] == "testing":
+                    return result
+
+            L = OrderedDict()
+            L["jmap"] = sum(
+                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+            )
+            L["lmap"] = (
+                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+                    .mean(2)
+                    .mean(1)
+            )
+            L["joff"] = sum(
+                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+                for i in range(n_jtyp)
+                for j in range(2)
+            )
+            for loss_name in L:
+                L[loss_name].mul_(loss_weight[loss_name])
+            losses.append(L)
+        result["losses"] = losses
+        result["aaa"] = aaa
+        return result
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
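For reference, a sketch of how the `head_size` list from config/wireframe.yaml turns into the channel offsets used in `forward` above (the `[[2], [1], [2], [4]]` value is the one set in this commit; the trailing `[4]` is presumably reserved for the new box-related channels and is not sliced here):

    import numpy as np

    head_size = [[2], [1], [2], [4]]
    offset = np.cumsum([sum(h) for h in head_size])   # -> [2, 3, 5, 9]
    # output[:2] -> jmap, output[2:3] -> lmap, output[3:5] -> joff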

+ 182 - 0
lcnn/models/resnet50.py

@@ -0,0 +1,182 @@
+# import torch
+# import torch.nn as nn
+# import torchvision.models as models
+#
+#
+# class ResNet50Backbone(nn.Module):
+#     def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+#         super(ResNet50Backbone, self).__init__()
+#
+#         # Load a pretrained ResNet50
+#         if pretrained:
+#             resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+#         else:
+#             resnet = models.resnet50(weights=None)
+#
+#         # Remove the final fully connected layer
+#         self.backbone = nn.Sequential(
+#             resnet.conv1,  # resolution /2, channels 3 -> 64
+#             resnet.bn1,
+#             resnet.relu,
+#             resnet.maxpool,  # resolution /4, channels stay 64
+#             resnet.layer1,  # stride 1, keeps 1/4 resolution, channels 64 -> 64*4=256
+#             # resnet.layer2,  # stride 2, resolution /8, channels 256 -> 128*4=512
+#             # resnet.layer3,  # stride 2, resolution /16, channels 512 -> 256*4=1024
+#             # resnet.layer4,  # stride 2, resolution /32, channels 1024 -> 512*4=2048
+#         )
+#
+#         # Multi-task output heads
+#         self.score_layers = nn.ModuleList([
+#             nn.Sequential(
+#                 nn.Conv2d(256, 128, kernel_size=3, padding=1),
+#                 nn.BatchNorm2d(128),
+#                 nn.ReLU(inplace=True),
+#                 nn.Conv2d(128, num_classes, kernel_size=1)
+#             )
+#             for _ in range(num_stacks)
+#         ])
+#
+#         # Resampling layer to make the output 128x128
+#         self.upsample = nn.Upsample(
+#             scale_factor=0.25,
+#             mode='bilinear',
+#             align_corners=True
+#         )
+#
+#     def forward(self, x):
+#         # Backbone feature extraction
+#         x = self.backbone(x)
+#
+#         # # Adjust channel count
+#         # x = self.channel_adjust(x)
+#         #
+#         # # Resample to 128x128
+#         # x = self.upsample(x)
+#
+#         # Multi-stack outputs
+#         outputs = []
+#         for score_layer in self.score_layers:
+#             output = score_layer(x)
+#             outputs.append(output)
+#
+#         # Return the first output (if there are multiple stacks)
+#         return outputs, x
+#
+#
+# def resnet50(**kwargs):
+#     model = ResNet50Backbone(
+#         num_classes=kwargs.get("num_classes", 5),
+#         num_stacks=kwargs.get("num_stacks", 1),
+#         pretrained=kwargs.get("pretrained", True)
+#     )
+#     return model
+#
+#
+# __all__ = ["ResNet50Backbone", "resnet50"]
+#
+#
+# # Test the network output
+# model = resnet50(num_classes=5, num_stacks=1)
+#
+#
+# # Option 1: pass an image tensor directly
+# x = torch.randn(2, 3, 512, 512)
+# outputs, feature = model(x)
+# print("Outputs length:", len(outputs))
+# print("Output[0] shape:", outputs[0].shape)
+# print("Feature shape:", feature.shape)
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.models as models
+from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
+
+
+class ResNet50Backbone1(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+        super(ResNet50Backbone1, self).__init__()
+
+        # Load a pretrained ResNet50
+        if pretrained:
+            # self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+            trainable_backbone_layers = _validate_trainable_layers(True, None, 5, 3)
+            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+            self.resnet = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+        else:
+            self.resnet = models.resnet50(weights=None)
+
+    def forward(self, x):
+        features = self.resnet(x)
+
+        # if self.training:
+        #     proposals, matched_idxs, labels, regression_targets = self.select_training_samples20(proposals, targets)
+        # else:
+        #     labels = None
+        #     regression_targets = None
+        #     matched_idxs = None
+        # box_features = self.box_roi_pool(features, proposals, image_shapes)  # map each RoI to a fixed-size feature
+        # # print(f"box_features:{box_features.shape}")  # [2048, 256, 7, 7] proposals pooled to a uniform size
+        # box_features = self.box_head(box_features)  #
+        # # print(f"box_features:{box_features.shape}")  # [N, 1024] feature vectors after the box head
+        # class_logits, box_regression = self.box_predictor(box_features)
+        #
+        # result: List[Dict[str, torch.Tensor]] = []
+        # losses = {}
+        # if self.training:
+        #     if labels is None:
+        #         raise ValueError("labels cannot be None")
+        #     if regression_targets is None:
+        #         raise ValueError("regression_targets cannot be None")
+        #     loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+        #     losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        # else:
+        #     boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+        #     num_images = len(boxes)
+        #     for i in range(num_images):
+        #         result.append(
+        #             {
+        #                 "boxes": boxes[i],
+        #                 "labels": labels[i],
+        #                 "scores": scores[i],
+        #             }
+        #         )
+        #     print(f"boxes:{boxes[0].shape}")
+
+        return features['0']  # highest-resolution FPN feature map
+
+        # # Multi-stack outputs
+        # outputs = []
+        # for score_layer in self.score_layers:
+        #     output = score_layer(x)
+        #     outputs.append(output)
+        #
+        # # Return the first output (if there are multiple stacks)
+        # return outputs, x
+
+
+def resnet501(**kwargs):
+    model = ResNet50Backbone1(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1),
+        pretrained=kwargs.get("pretrained", True)
+    )
+    return model
+
+
+# __all__ = ["ResNet50Backbone1", "resnet501"]
+#
+# # Test the network output
+# model = resnet501(num_classes=5, num_stacks=1)
+#
+# # Option 1: pass an image tensor directly
+# x = torch.randn(2, 3, 512, 512)
+# # outputs, feature = model(x)
+# # print("Outputs length:", len(outputs))
+# # print("Output[0] shape:", outputs[0].shape)
+# # print("Feature shape:", feature.shape)
+# feature = model(x)
+# print("Feature shape:", feature.keys())
+
+
+

+ 87 - 0
lcnn/models/resnet50_pose.py

@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+class ResNet50Backbone(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+        super(ResNet50Backbone, self).__init__()
+
+        # Load a pretrained ResNet50
+        if pretrained:
+            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+        else:
+            resnet = models.resnet50(weights=None)
+
+        # Remove the final fully connected layer
+        self.backbone = nn.Sequential(
+            resnet.conv1,  # resolution /2, channels 3 -> 64
+            resnet.bn1,
+            resnet.relu,
+            resnet.maxpool,  # resolution /4, channels stay 64
+            resnet.layer1,  # stride 1, keeps 1/4 resolution, channels 64 -> 64*4=256
+            # resnet.layer2,  # stride 2, resolution /8, channels 256 -> 128*4=512
+            # resnet.layer3,  # stride 2, resolution /16, channels 512 -> 256*4=1024
+            # resnet.layer4,  # stride 2, resolution /32, channels 1024 -> 512*4=2048
+        )
+
+        # Multi-task output heads
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+        # Resampling layer to make the output 128x128
+        self.upsample = nn.Upsample(
+            scale_factor=0.25,
+            mode='bilinear',
+            align_corners=True
+        )
+
+    def forward(self, x):
+        # Backbone feature extraction
+        x = self.backbone(x)
+
+        # # Adjust channel count
+        # x = self.channel_adjust(x)
+        #
+        # # Resample to 128x128
+        # x = self.upsample(x)
+
+        # Multi-stack outputs
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(x)
+            outputs.append(output)
+
+        # Return the first output (if there are multiple stacks)
+        return outputs, x
+
+
+def resnet50(**kwargs):
+    model = ResNet50Backbone(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1),
+        pretrained=kwargs.get("pretrained", True)
+    )
+    return model
+
+
+__all__ = ["ResNet50Backbone", "resnet50"]
+
+
+# Quick smoke test of the network output (guarded so importing this module has no side effects)
+if __name__ == "__main__":
+    model = resnet50(num_classes=5, num_stacks=1)
+
+    # Option 1: pass an image tensor directly
+    x = torch.randn(2, 3, 512, 512)
+    outputs, feature = model(x)
+    print("Outputs length:", len(outputs))
+    print("Output[0] shape:", outputs[0].shape)
+    print("Feature shape:", feature.shape)

+ 126 - 0
lcnn/models/unet.py

@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["UNetWithMultipleStacks", "unet"]
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class UNetWithMultipleStacks(nn.Module):
+    def __init__(self, num_classes, num_stacks=2, base_channels=64):
+        super().__init__()
+        self.num_stacks = num_stacks
+        # Encoder
+        self.enc1 = DoubleConv(3, base_channels)
+        self.pool1 = nn.MaxPool2d(2)
+        self.enc2 = DoubleConv(base_channels, base_channels * 2)
+        self.pool2 = nn.MaxPool2d(2)
+        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
+        self.pool3 = nn.MaxPool2d(2)
+        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
+        self.pool4 = nn.MaxPool2d(2)
+        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
+        self.pool5 = nn.MaxPool2d(2)
+
+        # bottleneck
+        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
+
+        # Decoder
+        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
+        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
+        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
+        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
+        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
+        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
+        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
+        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
+        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
+        self.dec1 = DoubleConv(base_channels * 2, base_channels)
+
+        # Extra resampling layer, bringing 512 down to 128
+        self.final_upsample = nn.Sequential(
+            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(base_channels),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
+        )
+
+        # score_layers sized for 256 input channels
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)
+
+    def forward(self, x):
+        # Encoding path
+        enc1 = self.enc1(x)
+        enc2 = self.enc2(self.pool1(enc1))
+        enc3 = self.enc3(self.pool2(enc2))
+        enc4 = self.enc4(self.pool3(enc3))
+        enc5 = self.enc5(self.pool4(enc4))
+
+        # Bottleneck
+        bottleneck = self.bottleneck(self.pool5(enc5))
+
+        # Decoding path
+        dec5 = self.upconv5(bottleneck)
+        dec5 = torch.cat([dec5, enc5], dim=1)
+        dec5 = self.dec5(dec5)
+
+        dec4 = self.upconv4(dec5)
+        dec4 = torch.cat([dec4, enc4], dim=1)
+        dec4 = self.dec4(dec4)
+
+        dec3 = self.upconv3(dec4)
+        dec3 = torch.cat([dec3, enc3], dim=1)
+        dec3 = self.dec3(dec3)
+
+        dec2 = self.upconv2(dec3)
+        dec2 = torch.cat([dec2, enc2], dim=1)
+        dec2 = self.dec2(dec2)
+
+        dec1 = self.upconv1(dec2)
+        dec1 = torch.cat([dec1, enc1], dim=1)
+        dec1 = self.dec1(dec1)
+
+        # Extra resampling so the output size is 128
+        dec1 = self.final_upsample(dec1)
+        # Adjust channel count
+        dec1 = self.channel_adjust(dec1)
+        # Multi-stack outputs
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(dec1)
+            outputs.append(output)
+
+        return outputs[::-1], dec1
+
+
+def unet(**kwargs):
+    model = UNetWithMultipleStacks(
+        num_classes=kwargs["num_classes"],
+        num_stacks=kwargs.get("num_stacks", 2),
+        base_channels=kwargs.get("base_channels", 64)
+    )
+    return model

+ 77 - 0
lcnn/postprocess.py

@@ -0,0 +1,77 @@
+import numpy as np
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def psegment(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                min(
+                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                )
+                > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)
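A hedged example of `postprocess` merging near-duplicate detections, mirroring how lcnn/trainer.py calls it below with a threshold scaled by the image diagonal (the coordinates here are invented):

    import numpy as np

    lines = np.array([[[0.0, 0.0], [0.0, 100.0]],
                      [[0.0, 1.0], [0.0, 99.0]]])   # two nearly identical segments
    scores = np.array([0.9, 0.6])
    nlines, nscores = postprocess(lines, scores, threshold=2.0, tol=0, do_clip=False)
    # the second, overlapping segment is dropped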

+ 429 - 0
lcnn/trainer.py

@@ -0,0 +1,429 @@
+import atexit
+import os
+import os.path as osp
+import shutil
+import signal
+import subprocess
+import threading
+import time
+from timeit import default_timer as timer
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage import io
+from tensorboardX import SummaryWriter
+
+from lcnn.config import C, M
+from lcnn.utils import recursive_to
+import matplotlib
+
+from 冻结参数训练 import verify_freeze_params
+import os
+
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from .postprocess import postprocess
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+
+# matplotlib.use('Agg')  # headless backend
+
+
+# Plotting helper
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["box"][0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    # score-to-color mapping; `c` was referenced below but never defined
+    sm = plt.cm.ScalarMappable(cmap=plt.get_cmap("jet"), norm=mpl.colors.Normalize(vmin=0.4, vmax=1.0))
+    c = sm.to_rgba
+    H = pred["preds"]
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.7]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+class Trainer(object):
+    def __init__(self, device, model, optimizer, train_loader, val_loader, out):
+        self.device = device
+
+        self.model = model
+        self.optim = optimizer
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.batch_size = C.model.batch_size
+
+        self.validation_interval = C.io.validation_interval
+
+        self.out = out
+        if not osp.exists(self.out):
+            os.makedirs(self.out)
+
+        # self.run_tensorboard()
+        self.writer = SummaryWriter('logs/')
+        time.sleep(1)
+
+        self.epoch = 0
+        self.iteration = 0
+        self.max_epoch = C.optim.max_epoch
+        self.lr_decay_epoch = C.optim.lr_decay_epoch
+        self.num_stacks = C.model.num_stacks
+        self.mean_loss = self.best_mean_loss = 1e1000
+
+        self.loss_labels = None
+        self.avg_metrics = None
+        self.metrics = np.zeros(0)
+
+        self.show_line = show_line
+
+    # def run_tensorboard(self):
+    #     board_out = osp.join(self.out, "tensorboard")
+    #     if not osp.exists(board_out):
+    #         os.makedirs(board_out)
+    #     self.writer = SummaryWriter(board_out)
+    #     os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    #     p = subprocess.Popen(
+    #         ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
+    #     )
+    #
+    #     def killme():
+    #         os.kill(p.pid, signal.SIGTERM)
+    #
+    #     atexit.register(killme)
+
+    def _loss(self, result):
+        losses = result["losses"]
+        # Don't move loss label to other place.
+        # If I want to change the loss, I just need to change this function.
+        if self.loss_labels is None:
+            self.loss_labels = ["sum"] + list(losses[0].keys())
+            self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
+            print()
+            print(
+                "| ".join(
+                    ["progress "]
+                    + list(map("{:7}".format, self.loss_labels))
+                    + ["speed"]
+                )
+            )
+            with open(f"{self.out}/loss.csv", "a") as fout:
+                print(",".join(["progress"] + self.loss_labels), file=fout)
+
+        total_loss = 0
+        for i in range(self.num_stacks):
+            for j, name in enumerate(self.loss_labels):
+                if name == "sum":
+                    continue
+                if name not in losses[i]:
+                    assert i != 0
+                    continue
+                loss = losses[i][name].mean()
+                self.metrics[i, 0] += loss.item()
+                self.metrics[i, j] += loss.item()
+                total_loss += loss
+        return total_loss
+
+
+    def validate(self):
+        tprint("Running validation...", " " * 75)
+        training = self.model.training
+        self.model.eval()
+
+        # viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
+        # npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
+        # osp.exists(viz) or os.makedirs(viz)
+        # osp.exists(npz) or os.makedirs(npz)
+
+        total_loss = 0
+        self.metrics[...] = 0
+        with torch.no_grad():
+            for batch_idx, (image, meta, target, target_b) in enumerate(self.val_loader):
+                input_dict = {
+                    "image": recursive_to(image, self.device),
+                    "meta": recursive_to(meta, self.device),
+                    "target": recursive_to(target, self.device),
+                    "target_b": recursive_to(target_b, self.device),
+                    "mode": "validation",
+                }
+                result = self.model(input_dict)
+                # print(f'image:{image.shape}')
+                # print(result["box"])
+
+                # total_loss += self._loss(result)
+
+                print(f'self.epoch:{self.epoch}')
+                # print(result.keys())
+                self.show_line(image[0], result, self.epoch, self.writer)
+
+
+
+                # H = result["preds"]
+                # for i in range(H["jmap"].shape[0]):
+                #     index = batch_idx * M.batch_size_eval + i
+                #     np.savez(
+                #         f"{npz}/{index:06}.npz",
+                #         **{k: v[i].cpu().numpy() for k, v in H.items()},
+                #     )
+                #     if index >= 20:
+                #         continue
+                #     self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
+
+        # self._write_metrics(len(self.val_loader), total_loss, "validation", True)
+        # self.mean_loss = total_loss / len(self.val_loader)
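+        # NOTE: with the loss accumulation above disabled, self.mean_loss keeps its initial
+        # value, so the best-checkpoint copy below never triggers.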
+
+        torch.save(
+            {
+                "iteration": self.iteration,
+                "arch": self.model.__class__.__name__,
+                "optim_state_dict": self.optim.state_dict(),
+                "model_state_dict": self.model.state_dict(),
+                "best_mean_loss": self.best_mean_loss,
+            },
+            osp.join(self.out, "checkpoint_latest.pth"),
+        )
+        # shutil.copy(
+        #     osp.join(self.out, "checkpoint_latest.pth"),
+        #     osp.join(npz, "checkpoint.pth"),
+        # )
+        if self.mean_loss < self.best_mean_loss:
+            self.best_mean_loss = self.mean_loss
+            shutil.copy(
+                osp.join(self.out, "checkpoint_latest.pth"),
+                osp.join(self.out, "checkpoint_best.pth"),
+            )
+
+        if training:
+            self.model.train()
+
+    @staticmethod
+    def verify_freeze_params(model, freeze_config):
+        """
+        Verify that parameter freezing has actually taken effect.
+        """
+        print("\n===== Verifying Parameter Freezing =====")
+
+        for name, module in model.named_children():
+            if name in freeze_config:
+                if freeze_config[name]:
+                    print(f"\nChecking module: {name}")
+                    for param_name, param in module.named_parameters():
+                        print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+            # special handling for the fc2 submodules
+            if name == 'fc2' and 'fc2_submodules' in freeze_config:
+                for subname, submodule in module.named_children():
+                    if subname in freeze_config['fc2_submodules']:
+                        if freeze_config['fc2_submodules'][subname]:
+                            print(f"\nChecking fc2 submodule: {subname}")
+                            for param_name, param in submodule.named_parameters():
+                                print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+    def train_epoch(self):
+        self.model.train()
+
+        time = timer()
+        for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):
+            self.optim.zero_grad()
+            self.metrics[...] = 0
+
+            input_dict = {
+                "image": recursive_to(image, self.device),
+                "meta": recursive_to(meta, self.device),
+                "target": recursive_to(target, self.device),
+                "target_b": recursive_to(target_b, self.device),
+                "mode": "training",
+            }
+            result = self.model(input_dict)
+
+            loss = self._loss(result)
+            if np.isnan(loss.item()):
+                raise ValueError("loss is nan while training")
+            loss.backward()
+            self.optim.step()
+
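+            # keep an exponential moving average (decay 0.9) of the per-batch metrics for smoother logging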
+            if self.avg_metrics is None:
+                self.avg_metrics = self.metrics
+            else:
+                self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
+            self.iteration += 1
+            self._write_metrics(1, loss.item(), "training", do_print=False)
+
+            if self.iteration % 4 == 0:
+                tprint(
+                    f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                    + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
+                    + f"| {4 * self.batch_size / (timer() - time):04.1f} "
+                )
+                time = timer()
+            num_images = self.batch_size * self.iteration
+            # if num_images % self.validation_interval == 0 or num_images == 4:
+            #     self.validate()
+            #     time = timer()
+        self.validate()
+        # verify_freeze_params()
+
+    def _write_metrics(self, size, total_loss, prefix, do_print=False):
+        for i, metrics in enumerate(self.metrics):
+            for label, metric in zip(self.loss_labels, metrics):
+                self.writer.add_scalar(
+                    f"{prefix}/{i}/{label}", metric / size, self.iteration
+                )
+            if i == 0 and do_print:
+                csv_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
+                        + ",".join(map("{:.11f}".format, metrics / size))
+                )
+                prt_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                        + "| ".join(map("{:.5f}".format, metrics / size))
+                )
+                with open(f"{self.out}/loss.csv", "a") as fout:
+                    print(csv_str, file=fout)
+                pprint(prt_str, " " * 7)
+        self.writer.add_scalar(
+            f"{prefix}/total_loss", total_loss / size, self.iteration
+        )
+        return total_loss
+
+    def _plot_samples(self, i, index, result, meta, target, prefix):
+        fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
+        img = io.imread(fn)
+        imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
+
+        mask_result = result["jmap"][i].cpu().numpy()
+        mask_target = target["jmap"][i].cpu().numpy()
+        for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
+            imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
+            imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
+
+        line_result = result["lmap"][i].cpu().numpy()
+        line_target = target["lmap"][i].cpu().numpy()
+        imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
+        imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
+
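+        # helper: overlay line segments (colored by score) and junction points on the image, then save to fn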
+        def draw_vecl(lines, sline, juncs, junts, fn):
+            imshow(img)
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (j == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+            if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+                for i, j in enumerate(junts):
+                    if i > 0 and (j == junts[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+            plt.savefig(fn), plt.close()
+
+        junc = meta[i]["junc"].cpu().numpy() * 4
+        jtyp = meta[i]["jtyp"].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+        rjuncs = result["juncs"][i].cpu().numpy() * 4
+        rjunts = None
+        if "junts" in result:
+            rjunts = result["junts"][i].cpu().numpy() * 4
+
+        lpre = meta[i]["lpre"].cpu().numpy() * 4
+        vecl_target = meta[i]["lpre_label"].cpu().numpy()
+        vecl_result = result["lines"][i].cpu().numpy() * 4
+        score = result["score"][i].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
+        draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+
+    def train(self):
+        plt.rcParams["figure.figsize"] = (24, 24)
+        # if self.iteration == 0:
+        #     self.validate()
+        epoch_size = len(self.train_loader)
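+        # derive the starting epoch from the global iteration counter (supports resuming)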
+        start_epoch = self.iteration // epoch_size
+
+        for self.epoch in range(start_epoch, self.max_epoch):
+            print(f"Epoch {self.epoch}/{self.max_epoch} - iteration {self.iteration} ({epoch_size} batches per epoch)")
+            if self.epoch == self.lr_decay_epoch:
+                self.optim.param_groups[0]["lr"] /= 10
+            self.train_epoch()
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
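+# map a line score to an RGBA color on the jet colormap, normalized over [0.4, 1.0]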
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[1]])
+    plt.ylim([im.shape[0], 0])
+
+
+def tprint(*args):
+    """Temporarily prints things on the screen"""
+    print("\r", end="")
+    print(*args, end="")
+
+
+def pprint(*args):
+    """Permanently prints things on the screen"""
+    print("\r", end="")
+    print(*args)
+
+
+def _launch_tensorboard(board_out, port, out):
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
+
+    def kill():
+        os.kill(p.pid, signal.SIGTERM)
+
+    atexit.register(kill)

+ 101 - 0
lcnn/utils.py

@@ -0,0 +1,101 @@
+import math
+import os.path as osp
+import multiprocessing
+from timeit import default_timer as timer
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+
+
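+# context manager for timing a block of code, e.g. "with benchmark('forward pass'): ..."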
+class benchmark(object):
+    def __init__(self, msg, enable=True, fmt="%0.3g"):
+        self.msg = msg
+        self.fmt = fmt
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable:
+            self.start = timer()
+        return self
+
+    def __exit__(self, *args):
+        if self.enable:
+            t = timer() - self.start
+            print(("%s : " + self.fmt + " seconds") % (self.msg, t))
+            self.time = t
+
+
+def quiver(x, y, ax):
+    ax.set_xlim(0, x.shape[1])
+    ax.set_ylim(x.shape[0], 0)
+    ax.quiver(
+        x,
+        y,
+        units="xy",
+        angles="xy",
+        scale_units="xy",
+        scale=1,
+        minlength=0.01,
+        width=0.1,
+        color="b",
+    )
+
+
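+# move a tensor to the given device; for dicts and lists, move their tensor elements
+# (lists are handled recursively)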
+def recursive_to(input, device):
+    if isinstance(input, torch.Tensor):
+        return input.to(device)
+    if isinstance(input, dict):
+        for name in input:
+            if isinstance(input[name], torch.Tensor):
+                input[name] = input[name].to(device)
+        return input
+    if isinstance(input, list):
+        for i, item in enumerate(input):
+            input[i] = recursive_to(item, device)
+        return input
+    raise TypeError(f"recursive_to: unsupported input type {type(input)}")
+
+
+def np_softmax(x, axis=0):
+    """Compute softmax values for each sets of scores in x."""
+    e_x = np.exp(x - np.max(x))
+    return e_x / e_x.sum(axis=axis, keepdims=True)
+
+
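+# return the (row, col) indices of arr, ordered by ascending value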
+def argsort2d(arr):
+    return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
+
+
+def __parallel_handle(f, q_in, q_out):
+    while True:
+        i, x = q_in.get()
+        if i is None:
+            break
+        q_out.put((i, f(x)))
+
+
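+# order-preserving multiprocessing map; nprocs=0 falls back to cpu_count()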
+def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
+    if nprocs == 0:
+        nprocs = multiprocessing.cpu_count()
+    q_in = multiprocessing.Queue(1)
+    q_out = multiprocessing.Queue()
+
+    proc = [
+        multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
+        for _ in range(nprocs)
+    ]
+    for p in proc:
+        p.daemon = True
+        p.start()
+
+    try:
+        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+        [q_in.put((None, None)) for _ in range(nprocs)]
+        res = [q_out.get() for _ in progress_bar(range(len(sent)))]
+        [p.join() for p in proc]
+    except KeyboardInterrupt:
+        q_in.close()
+        q_out.close()
+        raise
+    return [x for i, x in sorted(res)]

+ 0 - 494
main.py

@@ -1,494 +0,0 @@
-import math
-import os.path
-import re
-import sys
-
-import PIL.Image
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-import torchvision.transforms
-import torchvision.transforms.functional as F
-from torch.utils.data import DataLoader
-from torchvision.transforms import v2
-
-from torchvision.utils import make_grid, draw_bounding_boxes
-from torchvision.io import read_image
-from pathlib import Path
-from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
-# PyTorch TensorBoard support
-from torch.utils.tensorboard import SummaryWriter
-import cv2
-from sklearn.cluster import DBSCAN
-from test.MaskRCNN import MaskRCNNDataset
-from tools import utils
-import pandas as pd
-
-plt.rcParams["savefig.bbox"] = 'tight'
-orig_path = r'F:\Downloads\severstal-steel-defect-detection'
-dst_path = r'F:\Downloads\severstal-steel-defect-detection'
-
-
-def show(imgs):
-    if not isinstance(imgs, list):
-        imgs = [imgs]
-    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
-    for i, img in enumerate(imgs):
-        img = img.detach()
-        img = F.to_pil_image(img)
-        axs[0, i].imshow(np.asarray(img))
-        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-    plt.show()
-
-
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
-    metric_logger = utils.MetricLogger(delimiter="  ")
-    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
-    header = f"Epoch: [{epoch}]"
-
-    lr_scheduler = None
-    if epoch == 0:
-        warmup_factor = 1.0 / 1000
-        warmup_iters = min(1000, len(data_loader) - 1)
-
-        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
-        )
-
-    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
-        images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
-        with torch.cuda.amp.autocast(enabled=scaler is not None):
-            loss_dict = model(images, targets)
-            losses = sum(loss for loss in loss_dict.values())
-
-        # reduce losses over all GPUs for logging purposes
-        loss_dict_reduced = utils.reduce_dict(loss_dict)
-        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
-
-        loss_value = losses_reduced.item()
-
-        if not math.isfinite(loss_value):
-            print(f"Loss is {loss_value}, stopping training")
-            print(loss_dict_reduced)
-            sys.exit(1)
-
-        optimizer.zero_grad()
-        if scaler is not None:
-            scaler.scale(losses).backward()
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            losses.backward()
-            optimizer.step()
-
-        if lr_scheduler is not None:
-            lr_scheduler.step()
-
-        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
-        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
-
-    return metric_logger
-
-
-def train():
-    pass
-
-
-def trans_datasets_format():
-    # 使用pandas的read_csv函数读取文件
-    df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
-
-    # 显示数据的前几行
-    print(df.head())
-    for row in df.itertuples():
-        # print(f"Row index: {row.Index}")
-        # print(getattr(row, 'ImageId'))  # 输出特定列的值
-        img_name = getattr(row, 'ImageId')
-        img_path = os.path.join(orig_path + '/train_images', img_name)
-        dst_img_path = os.path.join(dst_path + '/images/train', img_name)
-        dst_label_path = os.path.join(dst_path + '/labels/train', img_name[:-3] + 'txt')
-        print(f'dst label:{dst_label_path}')
-        im = cv2.imread(img_path)
-        # cv2.imshow('test',im)
-        cv2.imwrite(dst_img_path, im)
-        img = PIL.Image.open(img_path)
-        height, width = im.shape[:2]
-        print(f'cv2 size:{im.shape}')
-        label, mask = compute_mask(row, img.size)
-        lbls, ins_masks=cluster_dbscan(mask,img)
-
-
-
-        with open(dst_label_path, 'a+') as writer:
-            # writer.write(label)
-            for ins_mask in ins_masks:
-                lbl_data = str(label) + ' '
-                for mp in ins_mask:
-                    h,w=mp
-                    lbl_data += str(w / width) + ' ' + str(h / height) + ' '
-
-                # non_zero_coords = np.nonzero(inm.reshape(width,height).T)
-                # coords_list = list(zip(non_zero_coords[0], non_zero_coords[1]))
-                # # print(f'mask:{mask[0,333]}')
-                # print(f'mask pixels:{coords_list}')
-                #
-                #
-                # for coord in coords_list:
-                #     h, w = coord
-                #     lbl_data += str(w / width) + ' ' + str(h / height) + ' '
-
-                writer.write(lbl_data + '\n')
-                print(f'lbl_data:{lbl_data}')
-        writer.close()
-        print(f'label:{label}')
-        # plt.imshow(img)
-        # plt.imshow(mask, cmap='Reds', alpha=0.3)
-        # plt.show()
-
-
-def compute_mask(row, shape):
-    width, height = shape
-    print(f'shape:{shape}')
-    mask = np.zeros(width * height, dtype=np.uint8)
-    pixels = np.array(list(map(int, row.EncodedPixels.split())))
-    label = row.ClassId
-    # print(f'pixels:{pixels}')
-    mask_start = pixels[0::2]
-    mask_length = pixels[1::2]
-
-    for s, l in zip(mask_start, mask_length):
-        mask[s:s + l] = 255
-    mask = mask.reshape((width, height)).T
-
-    # mask = np.flipud(np.rot90(mask.reshape((height, width))))
-    return label, mask
-
-def cluster_dbscan(mask,image):
-    # 将 mask 转换为二值图像
-    _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
-
-    # 将 mask 一维化
-    mask_flattened = mask_binary.flatten()
-
-    # 获取 mask 中的前景像素坐标
-    foreground_pixels = np.argwhere(mask_flattened == 255)
-
-    # 将像素坐标转换为二维坐标
-    foreground_pixels_2d = np.column_stack(
-        (foreground_pixels // mask_binary.shape[1], foreground_pixels % mask_binary.shape[1]))
-
-    # 定义 DBSCAN 参数
-    eps = 3  # 邻域半径
-    min_samples = 10  # 最少样本数量
-
-    # 应用 DBSCAN
-    dbscan = DBSCAN(eps=eps, min_samples=min_samples).fit(foreground_pixels_2d)
-
-    # 获取聚类标签
-    labels = dbscan.labels_
-    print(f'labels:{labels}')
-    # 获取唯一的标签
-    unique_labels = set(labels)
-
-    print(f'unique_labels:{unique_labels}')
-    # 创建一个空的图像来保存聚类结果
-    clustered_image = np.zeros_like(image)
-    # print(f'clustered_image shape:{clustered_image.shape}')
-
-
-    # 将每个像素分配给相应的簇
-    clustered_points=[]
-    for k in unique_labels:
-
-
-        class_member_mask = (labels == k)
-        # print(f'class_member_mask:{class_member_mask}')
-        # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
-
-        pixel_indices = foreground_pixels_2d[class_member_mask]
-        clustered_points.append(pixel_indices)
-
-    return unique_labels,clustered_points
-
-def show_cluster_dbscan(mask,image,unique_labels,clustered_points,):
-    print(f'mask shape:{mask.shape}')
-    # 将 mask 转换为二值图像
-    _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
-
-    # 将 mask 一维化
-    mask_flattened = mask_binary.flatten()
-
-    # 获取 mask 中的前景像素坐标
-    foreground_pixels = np.argwhere(mask_flattened == 255)
-    # print(f'unique_labels:{unique_labels}')
-    # 创建一个空的图像来保存聚类结果
-    print(f'image shape:{image.shape}')
-    clustered_image = np.zeros_like(image)
-    print(f'clustered_image shape:{clustered_image.shape}')
-
-    # 为每个簇分配颜色
-    colors =np.array( [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))])
-    # print(f'colors:{colors}')
-    plt.figure(figsize=(12, 6))
-    for points_coord,col in  zip(clustered_points,colors):
-        for coord in points_coord:
-
-            clustered_image[coord[0], coord[1]] = (np.array(col[:3]) * 255)
-
-    # # 将每个像素分配给相应的簇
-    # for k, col in zip(unique_labels, colors):
-    #     print(f'col:{col*255}')
-    #     if k == -1:
-    #         # 黑色用于噪声点
-    #         col = [0, 0, 0, 1]
-    #
-    #     class_member_mask = (labels == k)
-    #     # print(f'class_member_mask:{class_member_mask}')
-    #     # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
-    #
-    #     pixel_indices = foreground_pixels_2d[class_member_mask]
-    #     clustered_points.append(pixel_indices)
-    #     # print(f'pixel_indices:{pixel_indices}')
-    #     for pixel_index in pixel_indices:
-    #         clustered_image[pixel_index[0], pixel_index[1]] = (np.array(col[:3]) * 255)
-
-    print(f'clustered_points:{len(clustered_points)}')
-    # print(f'clustered_image:{clustered_image}')
-    # 显示原图和聚类结果
-    # plt.figure(figsize=(12, 6))
-    plt.subplot(131), plt.imshow(image), plt.title('Original Image')
-    # print(f'image:{image}')
-    plt.subplot(132), plt.imshow(mask_binary, cmap='gray'), plt.title('Mask')
-    plt.subplot(133), plt.imshow(clustered_image.astype(np.uint8)), plt.title('Clustered Image')
-    plt.show()
-def test():
-    dog1_int = read_image(str(Path('./assets') / 'dog1.jpg'))
-    dog2_int = read_image(str(Path('./assets') / 'dog2.jpg'))
-    dog_list = [dog1_int, dog2_int]
-    grid = make_grid(dog_list)
-
-    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
-    transforms = weights.transforms()
-
-    images = [transforms(d) for d in dog_list]
-    # 假设输入图像的尺寸为 (3, 800, 800)
-    dummy_input = torch.randn(1, 3, 800, 800)
-    model = maskrcnn_resnet50_fpn_v2(weights=weights, progress=False)
-    model = model.eval()
-
-    # 使用 torch.jit.script
-    scripted_model = torch.jit.script(model)
-
-    output = model(dummy_input)
-    print(f'output:{output}')
-
-    writer = SummaryWriter('runs/')
-    writer.add_graph(scripted_model, input_to_model=dummy_input)
-    writer.flush()
-
-    # torch.onnx.export(models,images, f='maskrcnn.onnx')  # 导出 .onnx 文
-    # netron.start('AlexNet.onnx')  # 展示结构图
-
-    show(grid)
-
-
-def test_mask():
-    name = 'fdb7c0397'
-    label_path = os.path.join(dst_path + '/labels/train', name + '.txt')
-    img_path = os.path.join(orig_path + '/train_images', name + '.jpg')
-    mask = np.zeros((256, 1600), dtype=np.uint8)
-    df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
-    # 显示数据的前几行
-    print(df.head())
-    points = []
-    with open(label_path, 'r') as reader:
-        lines = reader.readlines()
-        for line in lines:
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            class_id = int(parts[0])
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                points.append((int(y * 255), int(x * 1600)))
-            # points = np.array([[float(parts[i]), float(parts[i + 1])] for i in range(1, len(parts), 2)])
-            # mask_resized = cv2.resize(points, (1600, 256), interpolation=cv2.INTER_NEAREST)
-            print(f'points:{points}')
-            # mask[points[:,0],points[:,1]]=255
-            for p in points:
-                mask[p] = 255
-            # cv2.fillPoly(mask, points, color=(255,))
-    cv2.imshow('mask', mask)
-    for row in df.itertuples():
-        img_name = name + '.jpg'
-        if img_name == getattr(row, 'ImageId'):
-            img = PIL.Image.open(img_path)
-            height, width = img.size
-            print(f'img size:{img.size}')
-            label, mask = compute_mask(row, img.size)
-            plt.imshow(img)
-            plt.imshow(mask, cmap='Reds', alpha=0.3)
-            plt.show()
-    cv2.waitKey(0)
-
-def show_img_mask(img_path):
-    test_img = PIL.Image.open(img_path)
-
-    w,h=test_img.size
-    test_img=torchvision.transforms.ToTensor()(test_img)
-    test_img=test_img.permute(1, 2, 0)
-    print(f'test_img shape:{test_img.shape}')
-    lbl_path=re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
-    # print(f'lbl_path:{lbl_path}')
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = reader.readlines()
-        # 为每个簇分配颜色
-        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(lines))])
-        print(f'colors:{colors*255}')
-        mask_points = []
-        for line ,col in zip(lines,colors):
-            print(f'col:{np.array(col[:3]) * 255}')
-            mask = torch.zeros(test_img.shape, dtype=torch.uint8)
-            # print(f'mask shape:{mask.shape}')
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
-            labels.append(cls)
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-            for p in mask_points:
-                # print(f'p:{p}')
-                mask[p] = torch.tensor(np.array(col[:3])*255)
-            masks.append(mask)
-    reader.close()
-    target = {}
-
-    # target["boxes"] = masks_to_boxes(torch.stack(masks))
-
-    # target["labels"] = torch.stack(labels)
-
-    target["masks"] = torch.stack(masks)
-    print(f'target:{target}')
-
-    # plt.imshow(test_img.permute(1, 2, 0))
-    fig, axs = plt.subplots(2, 1)
-    print(f'test_img:{test_img*255}')
-    axs[0].imshow(test_img)
-    axs[0].axis('off')
-    axs[1].axis('off')
-    axs[1].imshow(test_img*255)
-    for img_mask in target['masks']:
-        # img_mask=img_mask.unsqueeze(0)
-        # img_mask = img_mask.expand_as(test_img)
-        # print(f'img_mask:{img_mask.shape}')
-        axs[1].imshow(img_mask,alpha=0.3)
-
-        # img_mask=np.array(img_mask)
-        # print(f'img_mask:{img_mask.shape}')
-        # plt.imshow(img_mask,alpha=0.5)
-        # mask_3channel = cv2.merge([np.zeros_like(img_mask), np.zeros_like(img_mask), img_mask])
-        # masked_image = cv2.addWeighted(test_img, 1, mask_3channel, 0.6, 0)
-
-    # cv2.imshow('cv2 mask img', masked_image)
-    # cv2.waitKey(0)
-    plt.show()
-def show_dataset():
-    global transforms, dataset, imgs
-    transforms = v2.Compose([
-        # v2.RandomResizedCrop(size=(224, 224), antialias=True),
-        # v2.RandomPhotometricDistort(p=1),
-        # v2.RandomHorizontalFlip(p=1),
-        v2.ToTensor()
-    ])
-    dataset = MaskRCNNDataset(dataset_path=r'F:\Downloads\severstal-steel-defect-detection', transforms=transforms,
-                              dataset_type='train')
-    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)
-    imgs, targets = next(iter(dataloader))
-
-    mask = np.array(targets[2]['masks'][0])
-    boxes = targets[2]['boxes']
-    print(f'boxes:{boxes}')
-    # mask[mask == 255] = 1
-    img = np.array(imgs[2].permute(1, 2, 0)) * 255
-    img = img.astype(np.uint8)
-    print(f'img shape:{img.shape}')
-    print(f'mask:{mask.shape}')
-    # print(f'target:{targets}')
-    # print(f'imgs:{imgs[0]}')
-    # print(f'cv2 img shape:{np.array(imgs[0]).shape}')
-    # cv2.imshow('cv2 img',img)
-    # cv2.imshow('cv2 mask', mask)
-    # plt.imshow('mask',mask)
-    mask_3channel = cv2.merge([np.zeros_like(mask), np.zeros_like(mask), mask])
-    # cv2.imshow('mask_3channel',mask_3channel)
-    print(f'mask_3channel:{mask_3channel.shape}')
-    masked_image = cv2.addWeighted(img, 1, mask_3channel, 0.6, 0)
-    # cv2.imshow('cv2 mask img', masked_image)
-    plt.imshow(imgs[0].permute(1, 2, 0))
-    plt.imshow(mask, cmap='Reds', alpha=0.3)
-    drawn_boxes = draw_bounding_boxes((imgs[2] * 255).to(torch.uint8), boxes, colors="red", width=5)
-    plt.imshow(drawn_boxes.permute(1, 2, 0))
-    # show(drawn_boxes)
-    plt.show()
-    cv2.waitKey(0)
-
-def test_cluster(img_path):
-    test_img = PIL.Image.open(img_path)
-    w, h = test_img.size
-    test_img = torchvision.transforms.ToTensor()(test_img)
-    test_img=(test_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
-    # print(f'test_img:{test_img}')
-    lbl_path = re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
-    # print(f'lbl_path:{lbl_path}')
-    masks = []
-    labels = []
-    with open(lbl_path, 'r') as reader:
-        lines = reader.readlines()
-        mask_points = []
-        for line in lines:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
-            labels.append(cls)
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-            for p in mask_points:
-                mask[p] = 255
-            masks.append(mask)
-    # print(f'masks:{masks}')
-    labels,clustered_points=cluster_dbscan(masks[0].numpy(),test_img)
-    print(f'labels:{labels}')
-    print(f'clustered_points len:{len(clustered_points)}')
-    show_cluster_dbscan(masks[0].numpy(),test_img,labels,clustered_points)
-
-if __name__ == '__main__':
-    # trans_datasets_format()
-    # test_mask()
-    # 定义转换
-    # show_dataset()
-
-    # test_img_path= r"F:\Downloads\severstal-steel-defect-detection\images\train\0025bde0c.jpg"
-    test_img_path = r"F:\DevTools\datasets\renyaun\1012\spilt\images\train\2024-09-27-14-32-53_SaveImage.png"
-    # test_img1_path=r"F:\Downloads\severstal-steel-defect-detection\images\train\1d00226a0.jpg"
-    show_img_mask(test_img_path)
-    #
-    # test_cluster(test_img_path)

+ 0 - 0
models/base/__init__.py


+ 0 - 0
models/config/__init__.py


+ 0 - 22
models/config/config_tool.py

@@ -1,22 +0,0 @@
-import yaml
-
-
-def read_yaml(path='application.yaml'):
-    try:
-        with open(path, 'r') as file:
-            data = file.read()
-            # result = yaml.load(data)
-            result = yaml.load(data, Loader=yaml.FullLoader)
-
-            return result
-    except Exception as e:
-        print(e)
-        return None
-
-
-def write_yaml(path='application.yaml', data=None):
-    try:
-        with open(path, 'w', encoding='utf-8') as f:
-            yaml.dump(data=data, stream=f, allow_unicode=True)
-    except Exception as e:
-        print(e)

+ 0 - 42
models/config/test_config.py

@@ -1,42 +0,0 @@
-import yaml
-
-test_data = {
-    'cameras': [{
-        'id': 1,
-        'ip': "192.168.1.2"
-    }, {
-        'id': 2,
-        'ip': "192.168.1.3"
-    }]
-}
-
-
-def read_yaml(path):
-    try:
-        with open(path, 'r') as file:
-            data = file.read()
-            # result = yaml.load(data)
-            result = yaml.load(data, Loader=yaml.FullLoader)
-
-            return result
-    except Exception as e:
-        print(e)
-        return None
-
-
-def write_yaml(path):
-    try:
-        with open('path', 'w', encoding='utf-8') as f:
-            yaml.dump(data=test_data, stream=f, allow_unicode=True)
-    except Exception as e:
-        print(e)
-
-
-if __name__ == '__main__':
-    p = 'train.yaml'
-    result = read_yaml(p)
-    # j=json.load(result)
-    print('result', result)
-    # print('cameras', result['cameras'])
-    # print('json',j)
-

+ 0 - 34
models/config/train.yaml

@@ -1,34 +0,0 @@
-
-
-# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
-dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
-#train: images/train  # train images (relative to 'path') 128 images
-#val: images/train  # val images (relative to 'path') 128 images
-#test: images/test  # test images (optional)
-
-#train parameters
-num_classes: 5
-opt: 'adamw'
-batch_size: 2
-epochs: 10
-lr: 0.005
-momentum: 0.9
-weight_decay: 0.0001
-lr_step_size: 3
-lr_gamma: 0.1
-num_workers: 4
-print_freq: 10
-target_type: polygon
-enable_logs: True
-augmentation: True
-
-
-## Classes
-#names:
-#  0: fire
-#  1: dust
-#  2: move_machine
-#  3: open_machine
-#  4: close_machine
-
-

+ 0 - 0
models/ins/__init__.py


+ 0 - 142
models/ins/maskrcnn.py

@@ -1,142 +0,0 @@
-import math
-import os
-import sys
-from datetime import datetime
-from typing import Mapping, Any
-import cv2
-import numpy as np
-import torch
-import torchvision
-from torch import nn
-from torchvision.io import read_image
-from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
-from torchvision.utils import draw_bounding_boxes
-
-from models.config.config_tool import read_yaml
-from models.ins.trainer import train_cfg
-from tools import utils
-
-
-class MaskRCNNModel(nn.Module):
-
-    def __init__(self, num_classes=0, transforms=None):
-        super(MaskRCNNModel, self).__init__()
-        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
-            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
-        if transforms is None:
-            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
-        if num_classes != 0:
-            self.set_num_classes(num_classes)
-            # self.__num_classes=0
-
-        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-    def forward(self, inputs):
-        outputs = self.__model(inputs)
-        return outputs
-
-    def train(self, cfg):
-        parameters = read_yaml(cfg)
-        num_classes=parameters['num_classes']
-        # print(f'num_classes:{num_classes}')
-        self.set_num_classes(num_classes)
-        train_cfg(self.__model, cfg)
-
-    def set_num_classes(self, num_classes):
-        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
-        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
-        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
-        hidden_layer = 256
-        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
-                                                                  num_classes=num_classes)
-
-    def load_weight(self, pt_path):
-        state_dict = torch.load(pt_path)
-        self.__model.load_state_dict(state_dict)
-
-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
-        self.__model.load_state_dict(state_dict)
-        # return super().load_state_dict(state_dict, strict)
-
-    def predict(self, src, show_box=True, show_mask=True):
-        self.__model.eval()
-
-        img = read_image(src)
-        img = self.transforms(img)
-        img = img.to(self.device)
-        result = self.__model([img])
-        print(f'result:{result}')
-        masks = result[0]['masks']
-        boxes = result[0]['boxes']
-        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
-        boxes = boxes.cpu().detach()
-        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
-        print(f'drawn_boxes:{drawn_boxes.shape}')
-        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
-        # boxed_img=cv2.resize(boxed_img,(800,800))
-        # cv2.imshow('boxes',boxed_img)
-
-        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
-
-        mask = cv2.resize(mask, (800, 800))
-        # cv2.imshow('mask',mask)
-        img = img.cpu().detach().permute(1, 2, 0).numpy()
-
-        masked_img = self.overlay_masks_on_image(boxed_img, masks)
-        masked_img = cv2.resize(masked_img, (800, 800))
-        cv2.imshow('img_masks', masked_img)
-        # show_img_boxes_masks(img, boxes, masks)
-        cv2.waitKey(0)
-
-    def generate_colors(self, n):
-        """
-        生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
-
-        :param n: 需要的颜色数量
-        :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
-        """
-        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
-        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
-        return bgr_colors
-
-    def overlay_masks_on_image(self, image, masks, alpha=0.6):
-        """
-        在原图上叠加多个掩码,每个掩码使用不同的颜色。
-
-        :param image: 原图 (NumPy 数组)
-        :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
-        :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
-        :param alpha: 掩码的透明度 (0.0 到 1.0)
-        :return: 叠加了多个掩码的图像
-        """
-        colors = self.generate_colors(len(masks))
-        if len(masks) != len(colors):
-            raise ValueError("The number of masks and colors must be the same.")
-
-        # 复制原图,避免修改原始图像
-        overlay = image.copy()
-
-        for mask, color in zip(masks, colors):
-            # 确保掩码是二值图像
-            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
-            binary_mask = (mask > 0).astype(np.uint8) * 255  # 你可以根据实际情况调整阈值
-
-            # 创建彩色掩码
-            colored_mask = np.zeros_like(image)
-            colored_mask[:] = color
-            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
-
-            # 将彩色掩码与当前的叠加图像混合
-            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
-
-        return overlay
-
-
-if __name__ == '__main__':
-    # ins_model = MaskRCNNModel(num_classes=5)
-    ins_model = MaskRCNNModel()
-    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
-    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
-    ins_model.train(cfg='train.yaml')

+ 0 - 93
models/ins/maskrcnn_dataset.py

@@ -1,93 +0,0 @@
-import os
-
-import PIL
-import cv2
-import numpy as np
-import torch
-from matplotlib import pyplot as plt
-from torch.utils.data import Dataset
-from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
-
-from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
-
-
-class MaskRCNNDataset(Dataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
-        self.data_path = dataset_path
-        self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
-        self.imgs = os.listdir(self.img_path)
-        self.lbls = os.listdir(self.lbl_path)
-        self.target_type = target_type
-        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
-        # print('maskrcnn inited!')
-
-    def __getitem__(self, item):
-        # print('__getitem__')
-        img_path = os.path.join(self.img_path, self.imgs[item])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
-        img = PIL.Image.open(img_path).convert('RGB')
-        # h, w = np.array(img).shape[:2]
-        w, h = img.size
-        # print(f'h,w:{h, w}')
-        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
-        if self.transforms:
-            img, target = self.transforms(img,target)
-        else:
-            img=self.deafult_transform(img)
-        # print(f'img:{img.shape},target:{target}')
-        return img, target
-
-    def create_masks_from_polygons(self, polygons, image_shape):
-        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
-        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
-        masks = []
-
-        for polygon_data, col in zip(polygons, colors):
-            mask = np.zeros(image_shape[:2], dtype=np.uint8)
-            # 将多边形顶点转换为 NumPy 数组
-            _, polygon = polygon_data
-            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
-
-            # 使用 OpenCV 的 fillPoly 函数填充多边形
-            # print(f'color:{col[:3]}')
-            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
-            mask = torch.from_numpy(mask)
-            mask[mask != 0] = 1
-            masks.append(mask)
-
-        return masks
-
-    def read_target(self, item, lbl_path, shape):
-        # print(f'lbl_path:{lbl_path}')
-        h, w = shape
-        labels = []
-        masks = []
-        if self.target_type == 'polygon':
-            labels, masks = read_masks_from_txt(lbl_path, shape)
-        elif self.target_type == 'pixel':
-            labels, masks = read_masks_from_pixels(lbl_path, shape)
-
-        target = {}
-        target["boxes"] = masks_to_boxes(torch.stack(masks))
-        target["labels"] = torch.stack(labels)
-        target["masks"] = torch.stack(masks)
-        target["image_id"] = torch.tensor(item)
-        target["area"] = torch.zeros(len(masks))
-        target["iscrowd"] = torch.zeros(len(masks))
-        return target
-
-    def heatmap_enhance(self, img):
-        # 直方图均衡化
-        img_eq = cv2.equalizeHist(img)
-
-        # 自适应直方图均衡化
-        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
-        # img_clahe = clahe.apply(img)
-
-        # 将灰度图转换为热力图
-        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
-
-    def __len__(self):
-        return len(self.imgs)

+ 0 - 31
models/ins/train.yaml

@@ -1,31 +0,0 @@
-
-
-dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
-
-#train parameters
-num_classes: 5
-opt: 'adamw'
-batch_size: 2
-epochs: 10
-lr: 0.005
-momentum: 0.9
-weight_decay: 0.0001
-lr_step_size: 3
-lr_gamma: 0.1
-num_workers: 4
-print_freq: 10
-target_type: polygon
-enable_logs: True
-augmentation: True
-checkpoint: None
-
-
-## Classes
-#names:
-#  0: fire
-#  1: dust
-#  2: move_machine
-#  3: open_machine
-#  4: close_machine
-
-

+ 0 - 219
models/ins/trainer.py

@@ -1,219 +0,0 @@
-import math
-import os
-import sys
-from datetime import datetime
-
-import torch
-import torchvision
-from torch.utils.tensorboard import SummaryWriter
-from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
-
-from models.config.config_tool import read_yaml
-from models.ins.maskrcnn_dataset import MaskRCNNDataset
-from tools import utils, presets
-
-
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
-    metric_logger = utils.MetricLogger(delimiter="  ")
-    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
-    header = f"Epoch: [{epoch}]"
-
-    lr_scheduler = None
-    if epoch == 0:
-        warmup_factor = 1.0 / 1000
-        warmup_iters = min(1000, len(data_loader) - 1)
-
-        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
-        )
-
-    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
-        print(f'images:{images}')
-        images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
-        with torch.cuda.amp.autocast(enabled=scaler is not None):
-            loss_dict = model(images, targets)
-            losses = sum(loss for loss in loss_dict.values())
-
-        # reduce losses over all GPUs for logging purposes
-        loss_dict_reduced = utils.reduce_dict(loss_dict)
-        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
-
-        loss_value = losses_reduced.item()
-
-        if not math.isfinite(loss_value):
-            print(f"Loss is {loss_value}, stopping training")
-            print(loss_dict_reduced)
-            sys.exit(1)
-
-        optimizer.zero_grad()
-        if scaler is not None:
-            scaler.scale(losses).backward()
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            losses.backward()
-            optimizer.step()
-
-        if lr_scheduler is not None:
-            lr_scheduler.step()
-
-        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
-        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
-
-    return metric_logger
-
-
-def load_train_parameter(cfg):
-    parameters = read_yaml(cfg)
-    return parameters
-
-
-def train_cfg(model, cfg):
-    parameters = read_yaml(cfg)
-    print(f'train parameters:{parameters}')
-    train(model, **parameters)
-
-
-def train(model, **kwargs):
-    # 默认参数
-    default_params = {
-        'dataset_path': '/path/to/dataset',
-        'num_classes': 10,
-        'opt': 'adamw',
-        'batch_size': 2,
-        'epochs': 10,
-        'lr': 0.005,
-        'momentum': 0.9,
-        'weight_decay': 1e-4,
-        'lr_step_size': 3,
-        'lr_gamma': 0.1,
-        'num_workers': 4,
-        'print_freq': 10,
-        'target_type': 'polygon',
-        'enable_logs': True,
-        'augmentation': False,
-        'checkpoint':None
-    }
-    # 更新默认参数
-    for key, value in kwargs.items():
-        if key in default_params:
-            default_params[key] = value
-        else:
-            raise ValueError(f"Unknown argument: {key}")
-
-    # 解析参数
-    dataset_path = default_params['dataset_path']
-    num_classes = default_params['num_classes']
-    batch_size = default_params['batch_size']
-    epochs = default_params['epochs']
-    lr = default_params['lr']
-    momentum = default_params['momentum']
-    weight_decay = default_params['weight_decay']
-    lr_step_size = default_params['lr_step_size']
-    lr_gamma = default_params['lr_gamma']
-    num_workers = default_params['num_workers']
-    print_freq = default_params['print_freq']
-    target_type = default_params['target_type']
-    augmentation = default_params['augmentation']
-    # 设置设备
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
-    wts_path = os.path.join(train_result_ptath, 'weights')
-    tb_path = os.path.join(train_result_ptath, 'logs')
-    writer = SummaryWriter(tb_path)
-
-    transforms = None
-    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
-    if augmentation:
-        transforms = get_transform(is_train=True)
-        print(f'transforms:{transforms}')
-    if not os.path.exists('train_results'):
-        os.mkdir('train_results')
-
-    model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
-
-    dataset = MaskRCNNDataset(dataset_path=dataset_path,
-                              transforms=transforms, dataset_type='train', target_type=target_type)
-    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
-                                   dataset_type='val')
-
-    train_sampler = torch.utils.data.RandomSampler(dataset)
-    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
-    train_collate_fn = utils.collate_fn
-    data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
-    )
-    # data_loader_test = torch.utils.data.DataLoader(
-    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
-    # )
-
-    img_results_path = os.path.join(train_result_ptath, 'img_results')
-    if os.path.exists(train_result_ptath):
-        pass
-    #     os.remove(train_result_ptath)
-    else:
-        os.mkdir(train_result_ptath)
-
-    if os.path.exists(train_result_ptath):
-        os.mkdir(wts_path)
-        os.mkdir(img_results_path)
-
-    for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
-        losses = metric_logger.meters['loss'].global_avg
-        print(f'epoch {epoch}:loss:{losses}')
-        if os.path.exists(f'{wts_path}/last.pt'):
-            os.remove(f'{wts_path}/last.pt')
-        torch.save(model.state_dict(), f'{wts_path}/last.pt')
-        write_metric_logs(epoch, metric_logger, writer)
-        if epoch == 0:
-            best_loss = losses;
-        if best_loss >= losses:
-            best_loss = losses
-            if os.path.exists(f'{wts_path}/best.pt'):
-                os.remove(f'{wts_path}/best.pt')
-            torch.save(model.state_dict(), f'{wts_path}/best.pt')
-
-
-def get_transform(is_train, **kwargs):
-    default_params = {
-        'augmentation': 'multiscale',
-        'backend': 'tensor',
-        'use_v2': False,
-
-    }
-    # 更新默认参数
-    for key, value in kwargs.items():
-        if key in default_params:
-            default_params[key] = value
-        else:
-            raise ValueError(f"Unknown argument: {key}")
-
-    # 解析参数
-    augmentation = default_params['augmentation']
-    backend = default_params['backend']
-    use_v2 = default_params['use_v2']
-    if is_train:
-        return presets.DetectionPresetTrain(
-            data_augmentation=augmentation, backend=backend, use_v2=use_v2
-        )
-    # elif weights and test_only:
-    #     weights = torchvision.models.get_weight(args.weights)
-    #     trans = weights.transforms()
-    #     return lambda img, target: (trans(img), target)
-    else:
-        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
-
-
-def write_metric_logs(epoch, metric_logger, writer):
-    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
-    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
-    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
-    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
-    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
-    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)

+ 0 - 0
models/wirenet/__init__.py


+ 0 - 548
models/wirenet/_utils.py

@@ -1,548 +0,0 @@
-import math
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from torch import nn, Tensor
-from torch.nn import functional as F
-from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
-
-
-class BalancedPositiveNegativeSampler:
-    """
-    This class samples batches, ensuring that they contain a fixed proportion of positives
-    """
-
-    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
-        """
-        Args:
-            batch_size_per_image (int): number of elements to be selected per image
-            positive_fraction (float): percentage of positive elements per batch
-        """
-        self.batch_size_per_image = batch_size_per_image
-        self.positive_fraction = positive_fraction
-
-    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
-        """
-        Args:
-            matched_idxs: list of tensors containing -1, 0 or positive values.
-                Each tensor corresponds to a specific image.
-                -1 values are ignored, 0 are considered as negatives and > 0 as
-                positives.
-
-        Returns:
-            pos_idx (list[tensor])
-            neg_idx (list[tensor])
-
-        Returns two lists of binary masks for each image.
-        The first list contains the positive elements that were selected,
-        and the second list the negative example.
-        """
-        pos_idx = []
-        neg_idx = []
-        for matched_idxs_per_image in matched_idxs:
-            positive = torch.where(matched_idxs_per_image >= 1)[0]
-            negative = torch.where(matched_idxs_per_image == 0)[0]
-
-            num_pos = int(self.batch_size_per_image * self.positive_fraction)
-            # protect against not enough positive examples
-            num_pos = min(positive.numel(), num_pos)
-            num_neg = self.batch_size_per_image - num_pos
-            # protect against not enough negative examples
-            num_neg = min(negative.numel(), num_neg)
-
-            # randomly select positive and negative examples
-            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
-            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
-
-            pos_idx_per_image = positive[perm1]
-            neg_idx_per_image = negative[perm2]
-
-            # create binary mask from indices
-            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
-            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
-
-            pos_idx_per_image_mask[pos_idx_per_image] = 1
-            neg_idx_per_image_mask[neg_idx_per_image] = 1
-
-            pos_idx.append(pos_idx_per_image_mask)
-            neg_idx.append(neg_idx_per_image_mask)
-
-        return pos_idx, neg_idx
-
-
-@torch.jit._script_if_tracing
-def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
-    """
-    Encode a set of proposals with respect to some
-    reference boxes
-
-    Args:
-        reference_boxes (Tensor): reference boxes
-        proposals (Tensor): boxes to be encoded
-        weights (Tensor[4]): the weights for ``(x, y, w, h)``
-    """
-
-    # perform some unpacking to make it JIT-fusion friendly
-    wx = weights[0]
-    wy = weights[1]
-    ww = weights[2]
-    wh = weights[3]
-
-    proposals_x1 = proposals[:, 0].unsqueeze(1)
-    proposals_y1 = proposals[:, 1].unsqueeze(1)
-    proposals_x2 = proposals[:, 2].unsqueeze(1)
-    proposals_y2 = proposals[:, 3].unsqueeze(1)
-
-    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
-    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
-    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
-    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
-
-    # implementation starts here
-    ex_widths = proposals_x2 - proposals_x1
-    ex_heights = proposals_y2 - proposals_y1
-    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
-    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
-
-    gt_widths = reference_boxes_x2 - reference_boxes_x1
-    gt_heights = reference_boxes_y2 - reference_boxes_y1
-    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
-    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
-
-    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
-    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
-    targets_dw = ww * torch.log(gt_widths / ex_widths)
-    targets_dh = wh * torch.log(gt_heights / ex_heights)
-
-    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
-    return targets
-
-
-class BoxCoder:
-    """
-    This class encodes and decodes a set of bounding boxes into
-    the representation used for training the regressors.
-    """
-
-    def __init__(
-        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
-    ) -> None:
-        """
-        Args:
-            weights (4-element tuple)
-            bbox_xform_clip (float)
-        """
-        self.weights = weights
-        self.bbox_xform_clip = bbox_xform_clip
-
-    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
-        boxes_per_image = [len(b) for b in reference_boxes]
-        reference_boxes = torch.cat(reference_boxes, dim=0)
-        proposals = torch.cat(proposals, dim=0)
-        targets = self.encode_single(reference_boxes, proposals)
-        return targets.split(boxes_per_image, 0)
-
-    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
-        """
-        Encode a set of proposals with respect to some
-        reference boxes
-
-        Args:
-            reference_boxes (Tensor): reference boxes
-            proposals (Tensor): boxes to be encoded
-        """
-        dtype = reference_boxes.dtype
-        device = reference_boxes.device
-        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
-        targets = encode_boxes(reference_boxes, proposals, weights)
-
-        return targets
-
-    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
-        torch._assert(
-            isinstance(boxes, (list, tuple)),
-            "This function expects boxes of type list or tuple.",
-        )
-        torch._assert(
-            isinstance(rel_codes, torch.Tensor),
-            "This function expects rel_codes of type torch.Tensor.",
-        )
-        boxes_per_image = [b.size(0) for b in boxes]
-        concat_boxes = torch.cat(boxes, dim=0)
-        box_sum = 0
-        for val in boxes_per_image:
-            box_sum += val
-        if box_sum > 0:
-            rel_codes = rel_codes.reshape(box_sum, -1)
-        pred_boxes = self.decode_single(rel_codes, concat_boxes)
-        if box_sum > 0:
-            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
-        return pred_boxes
-
-    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
-        """
-        From a set of original boxes and encoded relative box offsets,
-        get the decoded boxes.
-
-        Args:
-            rel_codes (Tensor): encoded boxes
-            boxes (Tensor): reference boxes.
-        """
-
-        boxes = boxes.to(rel_codes.dtype)
-
-        widths = boxes[:, 2] - boxes[:, 0]
-        heights = boxes[:, 3] - boxes[:, 1]
-        ctr_x = boxes[:, 0] + 0.5 * widths
-        ctr_y = boxes[:, 1] + 0.5 * heights
-
-        wx, wy, ww, wh = self.weights
-        dx = rel_codes[:, 0::4] / wx
-        dy = rel_codes[:, 1::4] / wy
-        dw = rel_codes[:, 2::4] / ww
-        dh = rel_codes[:, 3::4] / wh
-
-        # Prevent sending too large values into torch.exp()
-        dw = torch.clamp(dw, max=self.bbox_xform_clip)
-        dh = torch.clamp(dh, max=self.bbox_xform_clip)
-
-        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
-        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
-        pred_w = torch.exp(dw) * widths[:, None]
-        pred_h = torch.exp(dh) * heights[:, None]
-
-        # Distance from center to box's corner.
-        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
-        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
-
-        pred_boxes1 = pred_ctr_x - c_to_c_w
-        pred_boxes2 = pred_ctr_y - c_to_c_h
-        pred_boxes3 = pred_ctr_x + c_to_c_w
-        pred_boxes4 = pred_ctr_y + c_to_c_h
-        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
-        return pred_boxes
-
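A small round-trip sketch for the coder above, assuming the default Faster R-CNN weights; the box values are arbitrary (x1, y1, x2, y2) coordinates:

    import torch

    coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
    gt = [torch.tensor([[10., 10., 50., 60.]])]
    anchors = [torch.tensor([[12., 8., 48., 58.]])]
    codes = coder.encode(gt, anchors)                   # one [1, 4] tensor of (dx, dy, dw, dh)
    decoded = coder.decode(torch.cat(codes), anchors)   # [1, 1, 4], recovers gt up to float error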
-
-class BoxLinearCoder:
-    """
-    The linear box-to-box transform defined in FCOS. The transformation is parameterized
-    by the distance from the center of (square) src box to 4 edges of the target box.
-    """
-
-    def __init__(self, normalize_by_size: bool = True) -> None:
-        """
-        Args:
-            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
-        """
-        self.normalize_by_size = normalize_by_size
-
-    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
-        """
-        Encode a set of proposals with respect to some reference boxes
-
-        Args:
-            reference_boxes (Tensor): reference boxes
-            proposals (Tensor): boxes to be encoded
-
-        Returns:
-            Tensor: the encoded relative box offsets that can be used to
-            decode the boxes.
-
-        """
-
-        # get the center of reference_boxes
-        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
-        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
-
-        # get box regression transformation deltas
-        target_l = reference_boxes_ctr_x - proposals[..., 0]
-        target_t = reference_boxes_ctr_y - proposals[..., 1]
-        target_r = proposals[..., 2] - reference_boxes_ctr_x
-        target_b = proposals[..., 3] - reference_boxes_ctr_y
-
-        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
-
-        if self.normalize_by_size:
-            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
-            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
-            reference_boxes_size = torch.stack(
-                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
-            )
-            targets = targets / reference_boxes_size
-        return targets
-
-    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
-
-        """
-        From a set of original boxes and encoded relative box offsets,
-        get the decoded boxes.
-
-        Args:
-            rel_codes (Tensor): encoded boxes
-            boxes (Tensor): reference boxes.
-
-        Returns:
-            Tensor: the predicted boxes with the encoded relative box offsets.
-
-        .. note::
-            This method assumes that ``rel_codes`` and ``boxes`` have the same size along the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
-
-        """
-
-        boxes = boxes.to(dtype=rel_codes.dtype)
-
-        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
-        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
-
-        if self.normalize_by_size:
-            boxes_w = boxes[..., 2] - boxes[..., 0]
-            boxes_h = boxes[..., 3] - boxes[..., 1]
-
-            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
-            rel_codes = rel_codes * list_box_size
-
-        pred_boxes1 = ctr_x - rel_codes[..., 0]
-        pred_boxes2 = ctr_y - rel_codes[..., 1]
-        pred_boxes3 = ctr_x + rel_codes[..., 2]
-        pred_boxes4 = ctr_y + rel_codes[..., 3]
-
-        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
-        return pred_boxes
-
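For comparison, a sketch of the FCOS-style linear coder above, with normalization disabled so the (l, t, r, b) distances are easy to read; the numbers are illustrative:

    import torch

    lin = BoxLinearCoder(normalize_by_size=False)
    anchors = torch.tensor([[20., 20., 40., 40.]])   # center at (30, 30)
    gt = torch.tensor([[25., 22., 44., 38.]])
    codes = lin.encode(anchors, gt)      # tensor([[ 5.,  8., 14.,  8.]])
    boxes = lin.decode(codes, anchors)   # recovers gt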
-
-class Matcher:
-    """
-    This class assigns to each predicted "element" (e.g., a box) a ground-truth
-    element. Each predicted element will have exactly zero or one matches; each
-    ground-truth element may be assigned to zero or more predicted elements.
-
-    Matching is based on the MxN match_quality_matrix, that characterizes how well
-    each (ground-truth, predicted)-pair match. For example, if the elements are
-    boxes, the matrix may contain box IoU overlap values.
-
-    The matcher returns a tensor of size N containing the index of the ground-truth
-    element m that matches to prediction n. If there is no match, a negative value
-    is returned.
-    """
-
-    BELOW_LOW_THRESHOLD = -1
-    BETWEEN_THRESHOLDS = -2
-
-    __annotations__ = {
-        "BELOW_LOW_THRESHOLD": int,
-        "BETWEEN_THRESHOLDS": int,
-    }
-
-    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
-        """
-        Args:
-            high_threshold (float): quality values greater than or equal to
-                this value are candidate matches.
-            low_threshold (float): a lower quality threshold used to stratify
-                matches into three levels:
-                1) matches >= high_threshold
-                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
-                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
-            allow_low_quality_matches (bool): if True, produce additional matches
-                for predictions that have only low-quality match candidates. See
-                set_low_quality_matches_ for more details.
-        """
-        self.BELOW_LOW_THRESHOLD = -1
-        self.BETWEEN_THRESHOLDS = -2
-        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
-        self.high_threshold = high_threshold
-        self.low_threshold = low_threshold
-        self.allow_low_quality_matches = allow_low_quality_matches
-
-    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
-        """
-        Args:
-            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
-            pairwise quality between M ground-truth elements and N predicted elements.
-
-        Returns:
-            matches (Tensor[int64]): a tensor of length N where element i is the index
-            of the matched gt in [0, M - 1], or a negative value indicating that
-            prediction i could not be matched.
-        """
-        if match_quality_matrix.numel() == 0:
-            # empty targets or proposals not supported during training
-            if match_quality_matrix.shape[0] == 0:
-                raise ValueError("No ground-truth boxes available for one of the images during training")
-            else:
-                raise ValueError("No proposal boxes available for one of the images during training")
-
-        # match_quality_matrix is M (gt) x N (predicted)
-        # Max over gt elements (dim 0) to find best gt candidate for each prediction
-        matched_vals, matches = match_quality_matrix.max(dim=0)
-        if self.allow_low_quality_matches:
-            all_matches = matches.clone()
-        else:
-            all_matches = None  # type: ignore[assignment]
-
-        # Assign candidate matches with low quality to negative (unassigned) values
-        below_low_threshold = matched_vals < self.low_threshold
-        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
-        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
-        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
-
-        if self.allow_low_quality_matches:
-            if all_matches is None:
-                torch._assert(False, "all_matches should not be None")
-            else:
-                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
-
-        return matches
-
-    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
-        """
-        Produce additional matches for predictions that have only low-quality matches.
-        Specifically, for each ground-truth find the set of predictions that have
-        maximum overlap with it (including ties); for each prediction in that set, if
-        it is unmatched, then match it to the ground-truth with which it has the highest
-        quality value.
-        """
-        # For each gt, find the prediction with which it has the highest quality
-        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
-        # Find the highest quality match available, even if it is low, including ties
-        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
-        # Example gt_pred_pairs_of_highest_quality:
-        #   tensor([[    0, 39796],
-        #           [    1, 32055],
-        #           [    1, 32070],
-        #           [    2, 39190],
-        #           [    2, 40255],
-        #           [    3, 40390],
-        #           [    3, 41455],
-        #           [    4, 45470],
-        #           [    5, 45325],
-        #           [    5, 46390]])
-        # Each row is a (gt index, prediction index)
-        # Note how gt items 1, 2, 3, and 5 each have two ties
-
-        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
-        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
-
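A toy run of the matcher above, assuming 2 ground-truth boxes and 3 predictions; the IoU values are invented to show the three outcome levels:

    import torch

    matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=False)
    iou = torch.tensor([[0.80, 0.40, 0.10],    # gt 0 vs. the three predictions
                        [0.20, 0.10, 0.05]])   # gt 1 vs. the three predictions
    matches = matcher(iou)
    # tensor([ 0, -2, -1]): prediction 0 matches gt 0, prediction 1 lands between
    # the thresholds (-2), prediction 2 falls below the low threshold (-1)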
-
-class SSDMatcher(Matcher):
-    def __init__(self, threshold: float) -> None:
-        super().__init__(threshold, threshold, allow_low_quality_matches=False)
-
-    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
-        matches = super().__call__(match_quality_matrix)
-
-        # For each gt, find the prediction with which it has the highest quality
-        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
-        matches[highest_quality_pred_foreach_gt] = torch.arange(
-            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
-        )
-
-        return matches
-
-
-def overwrite_eps(model: nn.Module, eps: float) -> None:
-    """
-    This method overwrites the default eps values of all the
-    FrozenBatchNorm2d layers of the model with the provided value.
-    This is necessary to address the BC-breaking change introduced
-    by the bug-fix at pytorch/vision#2933. The overwrite is applied
-    only when the pretrained weights are loaded to maintain compatibility
-    with previous versions.
-
-    Args:
-        model (nn.Module): The model on which we perform the overwrite.
-        eps (float): The new value of eps.
-    """
-    for module in model.modules():
-        if isinstance(module, FrozenBatchNorm2d):
-            module.eps = eps
-
-
-def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
-    """
-    This method retrieves the number of output channels of a specific model.
-
-    Args:
-        model (nn.Module): The model for which we estimate the out_channels.
-            It should return a single Tensor or an OrderedDict[Tensor].
-        size (Tuple[int, int]): The size (wxh) of the input.
-
-    Returns:
-        out_channels (List[int]): A list of the output channels of the model.
-    """
-    in_training = model.training
-    model.eval()
-
-    with torch.no_grad():
-        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
-        device = next(model.parameters()).device
-        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
-        features = model(tmp_img)
-        if isinstance(features, torch.Tensor):
-            features = OrderedDict([("0", features)])
-        out_channels = [x.size(1) for x in features.values()]
-
-    if in_training:
-        model.train()
-
-    return out_channels
-
-
-@torch.jit.unused
-def _fake_cast_onnx(v: Tensor) -> int:
-    return v  # type: ignore[return-value]
-
-
-def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
-    """
-    ONNX spec requires the k-value to be less than or equal to the number of inputs along
-    provided dim. Certain models use the number of elements along a particular axis instead of K
-    if K exceeds the number of elements along that axis. Previously, python's min() function was
-    used to determine whether to use the provided k-value or the specified dim axis value.
-
-    However, when the model is exported in tracing mode, python's min() is evaluated
-    statically, causing the model to be traced incorrectly and eventually fail at the topk node.
-    To avoid this, torch.min() is used instead while tracing.
-
-    Args:
-        input (Tensor): The original input tensor.
-        orig_kval (int): The provided k-value.
-        axis(int): Axis along which we retrieve the input size.
-
-    Returns:
-        min_kval (int): Appropriately selected k-value.
-    """
-    if not torch.jit.is_tracing():
-        return min(orig_kval, input.size(axis))
-    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
-    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
-    return _fake_cast_onnx(min_kval)
-
-
-def _box_loss(
-    type: str,
-    box_coder: BoxCoder,
-    anchors_per_image: Tensor,
-    matched_gt_boxes_per_image: Tensor,
-    bbox_regression_per_image: Tensor,
-    cnf: Optional[Dict[str, float]] = None,
-) -> Tensor:
-    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
-
-    if type == "l1":
-        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
-        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
-    elif type == "smooth_l1":
-        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
-        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
-        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
-    else:
-        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
-        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
-        if type == "ciou":
-            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
-        if type == "diou":
-            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
-        # otherwise giou
-        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
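A self-contained check of the box-loss helper above: if the regression output already equals the encoded targets, the smooth-L1 branch returns zero (the tensors below are made up):

    import torch

    coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
    anchors = torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]])
    gt      = torch.tensor([[1., 1., 11., 11.], [5., 5., 19., 21.]])
    perfect = coder.encode_single(gt, anchors)
    print(_box_loss("smooth_l1", coder, anchors, gt, perfect, cnf={"beta": 1.0 / 9}))  # tensor(0.)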

+ 0 - 1193
models/wirenet/head.py

@@ -1,1193 +0,0 @@
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple
-
-import matplotlib.pyplot as plt
-import torch
-import torch.nn.functional as F
-import torchvision
-from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops, roi_align
-
-from . import _utils as det_utils
-
-from torch.utils.data.dataloader import default_collate
-
-
-def l2loss(input, target):
-    return ((target - input) ** 2).mean(2).mean(1)
-
-
-def cross_entropy_loss(logits, positive):
-    nlogp = -F.log_softmax(logits, dim=0)
-    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
-
-
-def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
-    logp = torch.sigmoid(logits) + offset
-    loss = torch.abs(logp - target)
-    if mask is not None:
-        w = mask.mean(2, True).mean(1, True)
-        w[w == 0] = 1
-        loss = loss * (mask / w)
-
-    return loss.mean(2).mean(1)
-
-
-# def wirepoint_loss(target, outputs, feature, loss_weight,mode):
-#     wires = target['wires']
-#     result = {"feature": feature}
-#     batch, channel, row, col = outputs[0].shape
-#     print(f"Initial Output[0] shape: {outputs[0].shape}")  # 打印初始输出形状
-#     print(f"Total Stacks: {len(outputs)}")  # 打印堆栈数
-#
-#     T = wires.copy()
-#     n_jtyp = T["junc_map"].shape[1]
-#     for task in ["junc_map"]:
-#         T[task] = T[task].permute(1, 0, 2, 3)
-#     for task in ["junc_offset"]:
-#         T[task] = T[task].permute(1, 2, 0, 3, 4)
-#
-#     offset = self.head_off
-#     loss_weight = loss_weight
-#     losses = []
-#
-#     for stack, output in enumerate(outputs):
-#         output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-#         print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
-#         jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-#         lmap = output[offset[0]: offset[1]].squeeze(0)
-#         joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-#
-#         if stack == 0:
-#             result["preds"] = {
-#                 "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-#                 "lmap": lmap.sigmoid(),
-#                 "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-#             }
-#             # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
-#             # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
-#             # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
-#
-#             if mode == "testing":
-#                 return result
-#
-#         L = OrderedDict()
-#         L["junc_map"] = sum(
-#             cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-#         )
-#         L["line_map"] = (
-#             F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-#             .mean(2)
-#             .mean(1)
-#         )
-#         L["junc_offset"] = sum(
-#             sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-#             for i in range(n_jtyp)
-#             for j in range(2)
-#         )
-#         for loss_name in L:
-#             L[loss_name].mul_(loss_weight[loss_name])
-#         losses.append(L)
-#
-#     result["losses"] = losses
-#     return result
-
-def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
-    # output, feature: results returned by the head
-    # x, y, idx: intermediate results from the line generation step
-    result = {}
-    batch, channel, row, col = output.shape
-
-    wires_targets = [t["wires"] for t in targets]
-    wires_targets = wires_targets.copy()
-    # print(f'wires_target:{wires_targets}')
-    # extract all 'junc_map', 'junc_offset' and 'line_map' tensors
-    junc_maps = [d["junc_map"] for d in wires_targets]
-    junc_offsets = [d["junc_offset"] for d in wires_targets]
-    line_maps = [d["line_map"] for d in wires_targets]
-
-    junc_map_tensor = torch.stack(junc_maps, dim=0)
-    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-    line_map_tensor = torch.stack(line_maps, dim=0)
-    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
-
-    n_jtyp = T["junc_map"].shape[1]
-
-    for task in ["junc_map"]:
-        T[task] = T[task].permute(1, 0, 2, 3)
-    for task in ["junc_offset"]:
-        T[task] = T[task].permute(1, 2, 0, 3, 4)
-
-    offset = [2, 3, 5]
-    losses = []
-    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-    lmap = output[offset[0]: offset[1]].squeeze(0)
-    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-    L = OrderedDict()
-    L["junc_map"] = sum(
-        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-    )
-    L["line_map"] = (
-        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-            .mean(2)
-            .mean(1)
-    )
-    L["junc_offset"] = sum(
-        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-        for i in range(n_jtyp)
-        for j in range(2)
-    )
-    for loss_name in L:
-        L[loss_name].mul_(loss_weight[loss_name])
-    losses.append(L)
-    result["losses"] = losses
-
-    loss = nn.BCEWithLogitsLoss(reduction="none")
-    loss = loss(x, y)
-    lpos_mask, lneg_mask = y, 1 - y
-    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-    def sum_batch(x):
-        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
-        return torch.cat(xs)
-
-    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
-    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
-
-    return result
-
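A sketch of the weight dictionary this loss expects, inferred from the keys read above; the values are placeholders, not the project's tuned settings:

    loss_weight = {
        "junc_map": 1.0,      # junction heatmap cross-entropy
        "line_map": 1.0,      # line map BCE
        "junc_offset": 1.0,   # junction offset sigmoid-L1
        "lpos": 1.0,          # BCE over positive line samples
        "lneg": 1.0,          # BCE over negative line samples
    }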
-
-def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
-    result = {}
-    result["wires"] = {}
-    p = torch.cat(ps)
-    s = torch.sigmoid(input)
-    b = s > 0.5
-    lines = []
-    score = []
-    print(f"n_batch:{n_batch}")
-    for i in range(n_batch):
-        print(f"idx:{idx}")
-        p0 = p[idx[i]: idx[i + 1]]
-        s0 = s[idx[i]: idx[i + 1]]
-        mask = b[idx[i]: idx[i + 1]]
-        p0 = p0[mask]
-        s0 = s0[mask]
-        if len(p0) == 0:
-            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
-            score.append(torch.zeros([1, n_out_line], device=p.device))
-        else:
-            arg = torch.argsort(s0, descending=True)
-            p0, s0 = p0[arg], s0[arg]
-            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
-            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
-        for j in range(len(jcs[i])):
-            if len(jcs[i][j]) == 0:
-                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
-            jcs[i][j] = jcs[i][j][
-                None, torch.arange(n_out_junc) % len(jcs[i][j])
-            ]
-    result["wires"]["lines"] = torch.cat(lines)
-    result["wires"]["score"] = torch.cat(score)
-    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-
-    if len(jcs[i]) > 1:
-        result["preds"]["junts"] = torch.cat(
-            [jcs[i][1] for i in range(n_batch)]
-        )
-
-    return result
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[Tensor])
-        regression_targets (list[Tensor])
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
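A self-contained sketch exercising the loss above with fabricated inputs (2 proposals, 3 classes including background):

    import torch

    class_logits = torch.tensor([[2.0, 0.5, 0.1],    # proposal 0, labeled class 1
                                 [1.5, 0.2, 0.3]])   # proposal 1, labeled background
    box_regression = torch.zeros(2, 12)              # 3 classes * 4 coordinates
    labels = [torch.tensor([1, 0])]
    regression_targets = [torch.zeros(2, 4)]
    cls_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)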
-
-def maskrcnn_inference(x, labels):
-    # type: (Tensor, List[Tensor]) -> List[Tensor]
-    """
-    From the results of the CNN, post process the masks
-    by taking the mask corresponding to the class with max
-    probability (which are of fixed size and directly output
-    by the CNN) and return the masks in the mask field of the BoxList.
-
-    Args:
-        x (Tensor): the mask logits
-        labels (list[BoxList]): bounding boxes that are used as
-            reference, one for each image
-
-    Returns:
-        results (list[BoxList]): one BoxList for each image, containing
-            the extra field mask
-    """
-    mask_prob = x.sigmoid()
-
-    # select masks corresponding to the predicted classes
-    num_masks = x.shape[0]
-    boxes_per_image = [label.shape[0] for label in labels]
-    labels = torch.cat(labels)
-    index = torch.arange(num_masks, device=labels.device)
-    mask_prob = mask_prob[index, labels][:, None]
-    mask_prob = mask_prob.split(boxes_per_image, dim=0)
-
-    return mask_prob
-
-
-def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-    # type: (Tensor, Tensor, Tensor, int) -> Tensor
-    """
-    Given segmentation masks and the bounding boxes corresponding
-    to the location of the masks in the image, this function
-    crops and resizes the masks in the position defined by the
-    boxes. This prepares the masks to be fed to the
-    loss computation as the targets.
-    """
-    matched_idxs = matched_idxs.to(boxes)
-    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
-    gt_masks = gt_masks[:, None].to(rois)
-    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
-
-
-def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    """
-    Args:
-        mask_logits (Tensor)
-        proposals (list[Tensor])
-        gt_masks (list[Tensor])
-        gt_labels (list[Tensor])
-        mask_matched_idxs (list[Tensor])
-
-    Return:
-        mask_loss (Tensor): scalar tensor containing the loss
-    """
-
-    discretization_size = mask_logits.shape[-1]
-    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
-    # print(f'mask discretization_size:{discretization_size}')
-    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
-    # print(f'mask labels:{labels}')
-    mask_targets = [
-        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
-    ]
-
-    labels = torch.cat(labels, dim=0)
-    # print(f'mask labels1:{labels}')
-    mask_targets = torch.cat(mask_targets, dim=0)
-
-    # torch.mean (in binary_cross_entropy_with_logits) doesn't
-    # accept empty tensors, so handle it separately
-    if mask_targets.numel() == 0:
-        return mask_logits.sum() * 0
-    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
-    # print(f'mask_targets:{mask_targets}')
-    mask_loss = F.binary_cross_entropy_with_logits(
-        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
-    )
-    # print(f'mask_loss:{mask_loss}')
-    return mask_loss
-
-
-def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-
-    offset_x = offset_x[:, None]
-    offset_y = offset_y[:, None]
-    scale_x = scale_x[:, None]
-    scale_y = scale_y[:, None]
-
-    x = keypoints[..., 0]
-    y = keypoints[..., 1]
-
-    x_boundary_inds = x == rois[:, 2][:, None]
-    y_boundary_inds = y == rois[:, 3][:, None]
-
-    x = (x - offset_x) * scale_x
-    x = x.floor().long()
-    y = (y - offset_y) * scale_y
-    y = y.floor().long()
-
-    x[x_boundary_inds] = heatmap_size - 1
-    y[y_boundary_inds] = heatmap_size - 1
-
-    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
-    vis = keypoints[..., 2] > 0
-    valid = (valid_loc & vis).long()
-
-    lin_ind = y * heatmap_size + x
-    heatmaps = lin_ind * valid
-
-    return heatmaps, valid
-
-
-def _onnx_heatmaps_to_keypoints(
-        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
-):
-    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
-
-    width_correction = widths_i / roi_map_width
-    height_correction = heights_i / roi_map_height
-
-    roi_map = F.interpolate(
-        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
-    )[:, 0]
-
-    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
-    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-    x_int = pos % w
-    y_int = (pos - x_int) // w
-
-    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
-        dtype=torch.float32
-    )
-    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
-        dtype=torch.float32
-    )
-
-    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
-    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
-    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
-    xy_preds_i = torch.stack(
-        [
-            xy_preds_i_0.to(dtype=torch.float32),
-            xy_preds_i_1.to(dtype=torch.float32),
-            xy_preds_i_2.to(dtype=torch.float32),
-        ],
-        0,
-    )
-
-    # TODO: simplify when indexing without rank will be supported by ONNX
-    base = num_keypoints * num_keypoints + num_keypoints + 1
-    ind = torch.arange(num_keypoints)
-    ind = ind.to(dtype=torch.int64) * base
-    end_scores_i = (
-        roi_map.index_select(1, y_int.to(dtype=torch.int64))
-            .index_select(2, x_int.to(dtype=torch.int64))
-            .view(-1)
-            .index_select(0, ind.to(dtype=torch.int64))
-    )
-
-    return xy_preds_i, end_scores_i
-
-
-@torch.jit._script_if_tracing
-def _onnx_heatmaps_to_keypoints_loop(
-        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
-):
-    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
-    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
-
-    for i in range(int(rois.size(0))):
-        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
-            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
-        )
-        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
-        end_scores = torch.cat(
-            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
-        )
-    return xy_preds, end_scores
-
-
-def heatmaps_to_keypoints(maps, rois):
-    """Extract predicted keypoint locations from heatmaps. Output has shape
-    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
-    for each keypoint.
-    """
-    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
-    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
-    # consistency with keypoints_to_heatmap_labels by using the conversion from
-    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
-    # continuous coordinate.
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-
-    widths = rois[:, 2] - rois[:, 0]
-    heights = rois[:, 3] - rois[:, 1]
-    widths = widths.clamp(min=1)
-    heights = heights.clamp(min=1)
-    widths_ceil = widths.ceil()
-    heights_ceil = heights.ceil()
-
-    num_keypoints = maps.shape[1]
-
-    if torchvision._is_tracing():
-        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
-            maps,
-            rois,
-            widths_ceil,
-            heights_ceil,
-            widths,
-            heights,
-            offset_x,
-            offset_y,
-            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
-        )
-        return xy_preds.permute(0, 2, 1), end_scores
-
-    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
-    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
-    for i in range(len(rois)):
-        roi_map_width = int(widths_ceil[i].item())
-        roi_map_height = int(heights_ceil[i].item())
-        width_correction = widths[i] / roi_map_width
-        height_correction = heights[i] / roi_map_height
-        roi_map = F.interpolate(
-            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
-        )[:, 0]
-        # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = roi_map.shape[2]
-        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-        x_int = pos % w
-        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
-        # assert (roi_map_probs[k, y_int, x_int] ==
-        #         roi_map_probs[k, :, :].max())
-        x = (x_int.float() + 0.5) * width_correction
-        y = (y_int.float() + 0.5) * height_correction
-        xy_preds[i, 0, :] = x + offset_x[i]
-        xy_preds[i, 1, :] = y + offset_y[i]
-        xy_preds[i, 2, :] = 1
-        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
-
-    return xy_preds.permute(0, 2, 1), end_scores
-
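A quick sketch of the heatmap decoding above on a single ROI with one keypoint; the peak is placed by hand so the recovered image coordinate is predictable:

    import torch

    maps = torch.zeros(1, 1, 56, 56)                  # one ROI, one keypoint, 56x56 heatmap
    maps[0, 0, 28, 14] = 10.0                         # peak at (x=14, y=28) in heatmap space
    rois = torch.tensor([[100., 100., 156., 156.]])   # 56x56 box, so the scale is 1:1
    xy, scores = heatmaps_to_keypoints(maps, rois)
    # xy[0, 0, :2] is approximately (114.5, 128.5): the peak mapped back into the image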
-
-def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    N, K, H, W = keypoint_logits.shape
-    if H != W:
-        raise ValueError(
-            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
-        )
-    discretization_size = H
-    heatmaps = []
-    valid = []
-    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
-        kp = gt_kp_in_image[midx]
-        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
-        heatmaps.append(heatmaps_per_image.view(-1))
-        valid.append(valid_per_image.view(-1))
-
-    keypoint_targets = torch.cat(heatmaps, dim=0)
-    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
-    valid = torch.where(valid)[0]
-
-    # torch.mean (in binary_cross_entropy_with_logits) doesn't
-    # accept empty tensors, so handle it separately
-    if keypoint_targets.numel() == 0 or len(valid) == 0:
-        return keypoint_logits.sum() * 0
-
-    keypoint_logits = keypoint_logits.view(N * K, H * W)
-
-    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
-    return keypoint_loss
-
-
-def keypointrcnn_inference(x, boxes):
-    print(f'x:{x.shape}')
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    kp_probs = []
-    kp_scores = []
-
-    boxes_per_image = [box.size(0) for box in boxes]
-    x2 = x.split(boxes_per_image, dim=0)
-    print(f'x2:{x2}')
-
-    for xx, bb in zip(x2, boxes):
-        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
-        kp_probs.append(kp_prob)
-        kp_scores.append(scores)
-
-    return kp_probs, kp_scores
-
-
-def _onnx_expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
-    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
-    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
-    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
-    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
-
-    w_half = w_half.to(dtype=torch.float32) * scale
-    h_half = h_half.to(dtype=torch.float32) * scale
-
-    boxes_exp0 = x_c - w_half
-    boxes_exp1 = y_c - h_half
-    boxes_exp2 = x_c + w_half
-    boxes_exp3 = y_c + h_half
-    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
-    return boxes_exp
-
-
-# the next two functions should be merged inside Masker
-# but are kept here for the moment while we need them
-# temporarily for paste_mask_in_image
-def expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
-    if torchvision._is_tracing():
-        return _onnx_expand_boxes(boxes, scale)
-    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
-    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
-    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
-    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
-
-    w_half *= scale
-    h_half *= scale
-
-    boxes_exp = torch.zeros_like(boxes)
-    boxes_exp[:, 0] = x_c - w_half
-    boxes_exp[:, 2] = x_c + w_half
-    boxes_exp[:, 1] = y_c - h_half
-    boxes_exp[:, 3] = y_c + h_half
-    return boxes_exp
-
-
-@torch.jit.unused
-def expand_masks_tracing_scale(M, padding):
-    # type: (int, int) -> float
-    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
-
-
-def expand_masks(mask, padding):
-    # type: (Tensor, int) -> Tuple[Tensor, float]
-    M = mask.shape[-1]
-    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
-        scale = expand_masks_tracing_scale(M, padding)
-    else:
-        scale = float(M + 2 * padding) / M
-    padded_mask = F.pad(mask, (padding,) * 4)
-    return padded_mask, scale
-
-
-def paste_mask_in_image(mask, box, im_h, im_w):
-    # type: (Tensor, Tensor, int, int) -> Tensor
-    TO_REMOVE = 1
-    w = int(box[2] - box[0] + TO_REMOVE)
-    h = int(box[3] - box[1] + TO_REMOVE)
-    w = max(w, 1)
-    h = max(h, 1)
-
-    # Set shape to [batchxCxHxW]
-    mask = mask.expand((1, 1, -1, -1))
-
-    # Resize mask
-    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
-    mask = mask[0][0]
-
-    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
-    x_0 = max(box[0], 0)
-    x_1 = min(box[2] + 1, im_w)
-    y_0 = max(box[1], 0)
-    y_1 = min(box[3] + 1, im_h)
-
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
-    return im_mask
-
-
-def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
-    one = torch.ones(1, dtype=torch.int64)
-    zero = torch.zeros(1, dtype=torch.int64)
-
-    w = box[2] - box[0] + one
-    h = box[3] - box[1] + one
-    w = torch.max(torch.cat((w, one)))
-    h = torch.max(torch.cat((h, one)))
-
-    # Set shape to [batchxCxHxW]
-    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
-
-    # Resize mask
-    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
-    mask = mask[0][0]
-
-    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
-    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
-    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
-    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
-
-    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
-
-    # TODO : replace below with a dynamic padding when support is added in ONNX
-
-    # pad y
-    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
-    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
-    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
-    # pad x
-    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
-    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
-    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
-    return im_mask
-
-
-@torch.jit._script_if_tracing
-def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
-    res_append = torch.zeros(0, im_h, im_w)
-    for i in range(masks.size(0)):
-        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
-        mask_res = mask_res.unsqueeze(0)
-        res_append = torch.cat((res_append, mask_res))
-    return res_append
-
-
-def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
-    masks, scale = expand_masks(masks, padding=padding)
-    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
-    im_h, im_w = img_shape
-
-    if torchvision._is_tracing():
-        return _onnx_paste_masks_in_image_loop(
-            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
-        )[:, None]
-    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
-    if len(res) > 0:
-        ret = torch.stack(res, dim=0)[:, None]
-    else:
-        ret = masks.new_empty((0, 1, im_h, im_w))
-    return ret
-
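A small sketch of the mask pasting above: a 2x2 mask of ones is resized and written into a 10x10 image at the (expanded) box; the numbers are illustrative:

    import torch

    masks = torch.ones(1, 1, 2, 2)                    # one instance, 2x2 mask probabilities
    boxes = torch.tensor([[2., 2., 5., 5.]])
    pasted = paste_masks_in_image(masks, boxes, img_shape=(10, 10), padding=1)
    print(pasted.shape)                               # torch.Size([1, 1, 10, 10])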
-
-class RoIHeads(nn.Module):
-    __annotations__ = {
-        "box_coder": det_utils.BoxCoder,
-        "proposal_matcher": det_utils.Matcher,
-        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
-    }
-
-    def __init__(
-            self,
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            # Faster R-CNN training
-            fg_iou_thresh,
-            bg_iou_thresh,
-            batch_size_per_image,
-            positive_fraction,
-            bbox_reg_weights,
-            # Faster R-CNN inference
-            score_thresh,
-            nms_thresh,
-            detections_per_img,
-            # Mask
-            mask_roi_pool=None,
-            mask_head=None,
-            mask_predictor=None,
-            keypoint_roi_pool=None,
-            keypoint_head=None,
-            keypoint_predictor=None,
-            wirepoint_roi_pool=None,
-            wirepoint_head=None,
-            wirepoint_predictor=None,
-    ):
-        super().__init__()
-
-        self.box_similarity = box_ops.box_iou
-        # assign ground-truth boxes for each proposal
-        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
-
-        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
-
-        if bbox_reg_weights is None:
-            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
-        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
-
-        self.box_roi_pool = box_roi_pool
-        self.box_head = box_head
-        self.box_predictor = box_predictor
-
-        self.score_thresh = score_thresh
-        self.nms_thresh = nms_thresh
-        self.detections_per_img = detections_per_img
-
-        self.mask_roi_pool = mask_roi_pool
-        self.mask_head = mask_head
-        self.mask_predictor = mask_predictor
-
-        self.keypoint_roi_pool = keypoint_roi_pool
-        self.keypoint_head = keypoint_head
-        self.keypoint_predictor = keypoint_predictor
-
-        self.wirepoint_roi_pool = wirepoint_roi_pool
-        self.wirepoint_head = wirepoint_head
-        self.wirepoint_predictor = wirepoint_predictor
-
-    def has_mask(self):
-        if self.mask_roi_pool is None:
-            return False
-        if self.mask_head is None:
-            return False
-        if self.mask_predictor is None:
-            return False
-        return True
-
-    def has_keypoint(self):
-        if self.keypoint_roi_pool is None:
-            return False
-        if self.keypoint_head is None:
-            return False
-        if self.keypoint_predictor is None:
-            return False
-        return True
-
-    def has_wirepoint(self):
-        if self.wirepoint_roi_pool is None:
-            print(f'wirepoint_roi_pool is None')
-            return False
-        if self.wirepoint_head is None:
-            print(f'wirepoint_head is None')
-            return False
-        if self.wirepoint_predictor is None:
-            print(f'wirepoint_predictor is None')
-            return False
-        return True
-
-    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-        matched_idxs = []
-        labels = []
-        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
-
-            if gt_boxes_in_image.numel() == 0:
-                # Background image
-                device = proposals_in_image.device
-                clamped_matched_idxs_in_image = torch.zeros(
-                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
-                )
-                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
-            else:
-                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
-                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
-                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
-
-                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
-
-                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
-                labels_in_image = labels_in_image.to(dtype=torch.int64)
-
-                # Label background (below the low threshold)
-                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
-                labels_in_image[bg_inds] = 0
-
-                # Label ignore proposals (between low and high thresholds)
-                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
-                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
-
-            matched_idxs.append(clamped_matched_idxs_in_image)
-            labels.append(labels_in_image)
-        return matched_idxs, labels
-
-    def subsample(self, labels):
-        # type: (List[Tensor]) -> List[Tensor]
-        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
-        sampled_inds = []
-        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
-            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
-            sampled_inds.append(img_sampled_inds)
-        return sampled_inds
-
-    def add_gt_proposals(self, proposals, gt_boxes):
-        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
-        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
-
-        return proposals
-
-    def check_targets(self, targets):
-        # type: (Optional[List[Dict[str, Tensor]]]) -> None
-        if targets is None:
-            raise ValueError("targets should not be None")
-        if not all(["boxes" in t for t in targets]):
-            raise ValueError("Every element of targets should have a boxes key")
-        if not all(["labels" in t for t in targets]):
-            raise ValueError("Every element of targets should have a labels key")
-        if self.has_mask():
-            if not all(["masks" in t for t in targets]):
-                raise ValueError("Every element of targets should have a masks key")
-
-    def select_training_samples(
-            self,
-            proposals,  # type: List[Tensor]
-            targets,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
-        self.check_targets(targets)
-        if targets is None:
-            raise ValueError("targets should not be None")
-        dtype = proposals[0].dtype
-        device = proposals[0].device
-
-        gt_boxes = [t["boxes"].to(dtype) for t in targets]
-        gt_labels = [t["labels"] for t in targets]
-
-        # append ground-truth bboxes to proposals
-        proposals = self.add_gt_proposals(proposals, gt_boxes)
-
-        # get matching gt indices for each proposal
-        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
-        # sample a fixed proportion of positive-negative proposals
-        sampled_inds = self.subsample(labels)
-        matched_gt_boxes = []
-        num_images = len(proposals)
-        for img_id in range(num_images):
-            img_sampled_inds = sampled_inds[img_id]
-            proposals[img_id] = proposals[img_id][img_sampled_inds]
-            labels[img_id] = labels[img_id][img_sampled_inds]
-            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
-
-            gt_boxes_in_image = gt_boxes[img_id]
-            if gt_boxes_in_image.numel() == 0:
-                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
-            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
-
-        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
-        return proposals, matched_idxs, labels, regression_targets
-
-    def postprocess_detections(
-            self,
-            class_logits,  # type: Tensor
-            box_regression,  # type: Tensor
-            proposals,  # type: List[Tensor]
-            image_shapes,  # type: List[Tuple[int, int]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
-        pred_boxes = self.box_coder.decode(box_regression, proposals)
-
-        pred_scores = F.softmax(class_logits, -1)
-
-        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
-        pred_scores_list = pred_scores.split(boxes_per_image, 0)
-
-        all_boxes = []
-        all_scores = []
-        all_labels = []
-        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
-            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
-
-            # create labels for each prediction
-            labels = torch.arange(num_classes, device=device)
-            labels = labels.view(1, -1).expand_as(scores)
-
-            # remove predictions with the background label
-            boxes = boxes[:, 1:]
-            scores = scores[:, 1:]
-            labels = labels[:, 1:]
-
-            # batch everything, by making every class prediction be a separate instance
-            boxes = boxes.reshape(-1, 4)
-            scores = scores.reshape(-1)
-            labels = labels.reshape(-1)
-
-            # remove low scoring boxes
-            inds = torch.where(scores > self.score_thresh)[0]
-            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
-
-            # remove empty boxes
-            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
-            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-            # non-maximum suppression, independently done per class
-            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # keep only topk scoring predictions
-            keep = keep[: self.detections_per_img]
-            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-            all_boxes.append(boxes)
-            all_scores.append(scores)
-            all_labels.append(labels)
-
-        return all_boxes, all_scores, all_labels
-
-    def forward(
-            self,
-            features,  # type: Dict[str, Tensor]
-            proposals,  # type: List[Tensor]
-            image_shapes,  # type: List[Tuple[int, int]]
-            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
-        """
-        Args:
-            features (List[Tensor])
-            proposals (List[Tensor[N, 4]])
-            image_shapes (List[Tuple[H, W]])
-            targets (List[Dict])
-        """
-        if targets is not None:
-            for t in targets:
-                # TODO: https://github.com/pytorch/pytorch/issues/26731
-                floating_point_types = (torch.float, torch.double, torch.half)
-                if not t["boxes"].dtype in floating_point_types:
-                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
-                if not t["labels"].dtype == torch.int64:
-                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
-                if self.has_keypoint():
-                    if not t["keypoints"].dtype == torch.float32:
-                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
-
-        if self.training:
-            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
-        else:
-            labels = None
-            regression_targets = None
-            matched_idxs = None
-
-        box_features = self.box_roi_pool(features, proposals, image_shapes)
-        box_features = self.box_head(box_features)
-        class_logits, box_regression = self.box_predictor(box_features)
-
-        result: List[Dict[str, torch.Tensor]] = []
-        losses = {}
-        if self.training:
-            if labels is None:
-                raise ValueError("labels cannot be None")
-            if regression_targets is None:
-                raise ValueError("regression_targets cannot be None")
-            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
-        else:
-            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-            num_images = len(boxes)
-            for i in range(num_images):
-                result.append(
-                    {
-                        "boxes": boxes[i],
-                        "labels": labels[i],
-                        "scores": scores[i],
-                    }
-                )
-
-        if self.has_mask():
-            mask_proposals = [p["boxes"] for p in result]
-            if self.training:
-                if matched_idxs is None:
-                    raise ValueError("if in training, matched_idxs should not be None")
-
-                # during training, only focus on positive boxes
-                num_images = len(proposals)
-                mask_proposals = []
-                pos_matched_idxs = []
-                for img_id in range(num_images):
-                    pos = torch.where(labels[img_id] > 0)[0]
-                    mask_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
-            else:
-                pos_matched_idxs = None
-
-            if self.mask_roi_pool is not None:
-                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
-                mask_features = self.mask_head(mask_features)
-                mask_logits = self.mask_predictor(mask_features)
-            else:
-                raise Exception("Expected mask_roi_pool to be not None")
-
-            loss_mask = {}
-            if self.training:
-                if targets is None or pos_matched_idxs is None or mask_logits is None:
-                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
-
-                gt_masks = [t["masks"] for t in targets]
-                gt_labels = [t["labels"] for t in targets]
-                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
-                loss_mask = {"loss_mask": rcnn_loss_mask}
-            else:
-                labels = [r["labels"] for r in result]
-                masks_probs = maskrcnn_inference(mask_logits, labels)
-                for mask_prob, r in zip(masks_probs, result):
-                    r["masks"] = mask_prob
-
-            losses.update(loss_mask)
-
-        # keep none checks in if conditional so torchscript will conditionally
-        # compile each branch
-        if self.has_keypoint():
-            keypoint_proposals = [p["boxes"] for p in result]
-            if self.training:
-                # during training, only focus on positive boxes
-                num_images = len(proposals)
-                keypoint_proposals = []
-                pos_matched_idxs = []
-                if matched_idxs is None:
-                    raise ValueError("if in trainning, matched_idxs should not be None")
-
-                for img_id in range(num_images):
-                    pos = torch.where(labels[img_id] > 0)[0]
-                    keypoint_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
-            else:
-                pos_matched_idxs = None
-
-            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
-            # tmp = keypoint_features[0][0]
-            # plt.imshow(tmp.detach().numpy())
-            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
-            keypoint_features = self.keypoint_head(keypoint_features)
-
-            print(f'keypoint_features:{keypoint_features.shape}')
-            tmp = keypoint_features[0][0]
-            plt.imshow(tmp.detach().numpy())
-            keypoint_logits = self.keypoint_predictor(keypoint_features)
-            print(f'keypoint_logits:{keypoint_logits.shape}')
-            """
-            hook into the wirenet branch here
-            """
-
-            loss_keypoint = {}
-            if self.training:
-                if targets is None or pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
-
-                gt_keypoints = [t["keypoints"] for t in targets]
-                rcnn_loss_keypoint = keypointrcnn_loss(
-                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
-                )
-                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
-            else:
-                if keypoint_logits is None or keypoint_proposals is None:
-                    raise ValueError(
-                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                    )
-
-                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
-                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
-                    r["keypoints"] = keypoint_prob
-                    r["keypoints_scores"] = kps
-            losses.update(loss_keypoint)
-
-        if self.has_wirepoint():
-            wirepoint_proposals = [p["boxes"] for p in result]
-            if self.training:
-                # during training, only focus on positive boxes
-                num_images = len(proposals)
-                wirepoint_proposals = []
-                pos_matched_idxs = []
-                if matched_idxs is None:
-                    raise ValueError("if in trainning, matched_idxs should not be None")
-
-                for img_id in range(num_images):
-                    pos = torch.where(labels[img_id] > 0)[0]
-                    wirepoint_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
-            else:
-                pos_matched_idxs = None
-
-            print(f'proposals:{len(proposals)}')
-            wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
-
-            # tmp = keypoint_features[0][0]
-            # plt.imshow(tmp.detach().numpy())
-            print(f'wirepoint_features from roi_pool:{wirepoint_features.shape}')
-            outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
-
-            outputs = merge_features(outputs, wirepoint_proposals)
-            wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
-
-            print(f'outputs:{outputs.shape}')
-
-            wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
-            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
-
-            print(f'wirepoint_features:{wirepoint_features.shape}')
-            if self.training:
-
-                if targets is None or pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
-
-                loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
-                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
-
-                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
-
-            else:
-                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
-                result.append(pred)
-
-            # tmp = wirepoint_features[0][0]
-            # plt.imshow(tmp.detach().numpy())
-            # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))
-            # print(f'keypoint_logits:{wirepoint_logits.shape}')
-
-            # loss_wirepoint = {}    lm
-            # result=wirepoint_logits
-
-            # result.append(pred)    lm
-            losses.update(loss_wirepoint)
-        # print(f"result{result}")
-        # print(f"losses{losses}")
-
-        return result, losses
-
-
-def merge_features(features, proposals):
-    # Assume roi_pool_features is the input tensor with shape [600, 256, 128, 128]
-
-    # Use torch.split to divide features according to the number of proposals for each image
-    proposals_count = sum([p.size(0) for p in proposals])
-    features_size = features.size(0)
-    print(f'proposals sum:{proposals_count},features batch:{features.size(0)}')
-    if proposals_count != features_size:
-        raise ValueError("The length of proposals must match the batch size of features.")
-
-    split_features = []
-    start_idx = 0
-    for proposal in proposals:
-        # Extract the features belonging to the current image
-        current_features = features[start_idx:start_idx + proposal.size(0)]
-        print(f'current_features:{current_features.shape}')
-        split_features.append(current_features)
-        start_idx += proposal.size(0)  # advance by the number of proposals consumed for this image
-
-    features_imgs = []
-    for features_per_img in split_features:
-        features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
-        features_imgs.append(features_per_img)
-
-    merged_features = torch.cat(features_imgs, dim=0)
-    print(f' merged_features:{merged_features.shape}')
-    return merged_features
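
Editor's note on the deleted merge_features helper above: it splits the ROI-pooled feature batch back into per-image groups and max-pools each group into a single feature map. Below is a minimal sketch of the same idea using torch.split; the helper name merge_features_sketch and the toy shapes are illustrative assumptions, not code from this repository.

import torch

def merge_features_sketch(features, proposals):
    # Split ROI-pooled features by the number of proposals per image,
    # then max-pool each group down to one feature map per image.
    counts = [p.size(0) for p in proposals]
    if sum(counts) != features.size(0):
        raise ValueError("total number of proposals must match the first dimension of features")
    per_image = torch.split(features, counts, dim=0)                 # one chunk per image
    merged = [chunk.max(dim=0, keepdim=True).values for chunk in per_image]
    return torch.cat(merged, dim=0)                                  # [num_images, C, H, W]

# toy usage: two images with 3 and 2 proposals, 256-channel 8x8 ROI features
feats = torch.randn(5, 256, 8, 8)
props = [torch.zeros(3, 4), torch.zeros(2, 4)]
print(merge_features_sketch(feats, props).shape)                     # torch.Size([2, 256, 8, 8])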

+ 0 - 896
models/wirenet/roi_head.py

@@ -1,896 +0,0 @@
-from typing import Dict, List, Optional, Tuple
-
-import matplotlib.pyplot as plt
-import torch
-import torch.nn.functional as F
-import torchvision
-from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops, roi_align
-
-from . import _utils as det_utils
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
-
-def maskrcnn_inference(x, labels):
-    # type: (Tensor, List[Tensor]) -> List[Tensor]
-    """
-    From the results of the CNN, post process the masks
-    by taking the mask corresponding to the class with max
-    probability (which are of fixed size and directly output
-    by the CNN) and return the masks in the mask field of the BoxList.
-
-    Args:
-        x (Tensor): the mask logits
-        labels (list[BoxList]): bounding boxes that are used as
-            reference, one for each image
-
-    Returns:
-        results (list[BoxList]): one BoxList for each image, containing
-            the extra field mask
-    """
-    mask_prob = x.sigmoid()
-
-    # select masks corresponding to the predicted classes
-    num_masks = x.shape[0]
-    boxes_per_image = [label.shape[0] for label in labels]
-    labels = torch.cat(labels)
-    index = torch.arange(num_masks, device=labels.device)
-    mask_prob = mask_prob[index, labels][:, None]
-    mask_prob = mask_prob.split(boxes_per_image, dim=0)
-
-    return mask_prob
-
-
-def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-    # type: (Tensor, Tensor, Tensor, int) -> Tensor
-    """
-    Given segmentation masks and the bounding boxes corresponding
-    to the location of the masks in the image, this function
-    crops and resizes the masks in the position defined by the
-    boxes. This prepares the masks for them to be fed to the
-    loss computation as the targets.
-    """
-    matched_idxs = matched_idxs.to(boxes)
-    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
-    gt_masks = gt_masks[:, None].to(rois)
-    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
-
-
-def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    """
-    Args:
-        proposals (list[BoxList])
-        mask_logits (Tensor)
-        targets (list[BoxList])
-
-    Return:
-        mask_loss (Tensor): scalar tensor containing the loss
-    """
-
-    discretization_size = mask_logits.shape[-1]
-    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
-    # print(f'mask discretization_size:{discretization_size}')
-    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
-    # print(f'mask labels:{labels}')
-    mask_targets = [
-        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
-    ]
-
-    labels = torch.cat(labels, dim=0)
-    # print(f'mask labels1:{labels}')
-    mask_targets = torch.cat(mask_targets, dim=0)
-
-    # torch.mean (in binary_cross_entropy_with_logits) doesn't
-    # accept empty tensors, so handle it separately
-    if mask_targets.numel() == 0:
-        return mask_logits.sum() * 0
-    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
-    # print(f'mask_targets:{mask_targets}')
-    mask_loss = F.binary_cross_entropy_with_logits(
-        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
-    )
-    # print(f'mask_loss:{mask_loss}')
-    return mask_loss
-
-
-def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-
-    offset_x = offset_x[:, None]
-    offset_y = offset_y[:, None]
-    scale_x = scale_x[:, None]
-    scale_y = scale_y[:, None]
-
-    x = keypoints[..., 0]
-    y = keypoints[..., 1]
-
-    x_boundary_inds = x == rois[:, 2][:, None]
-    y_boundary_inds = y == rois[:, 3][:, None]
-
-    x = (x - offset_x) * scale_x
-    x = x.floor().long()
-    y = (y - offset_y) * scale_y
-    y = y.floor().long()
-
-    x[x_boundary_inds] = heatmap_size - 1
-    y[y_boundary_inds] = heatmap_size - 1
-
-    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
-    vis = keypoints[..., 2] > 0
-    valid = (valid_loc & vis).long()
-
-    lin_ind = y * heatmap_size + x
-    heatmaps = lin_ind * valid
-
-    return heatmaps, valid
-
-
-def _onnx_heatmaps_to_keypoints(
-    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
-):
-    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
-
-    width_correction = widths_i / roi_map_width
-    height_correction = heights_i / roi_map_height
-
-    roi_map = F.interpolate(
-        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
-    )[:, 0]
-
-    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
-    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-    x_int = pos % w
-    y_int = (pos - x_int) // w
-
-    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
-        dtype=torch.float32
-    )
-    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
-        dtype=torch.float32
-    )
-
-    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
-    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
-    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
-    xy_preds_i = torch.stack(
-        [
-            xy_preds_i_0.to(dtype=torch.float32),
-            xy_preds_i_1.to(dtype=torch.float32),
-            xy_preds_i_2.to(dtype=torch.float32),
-        ],
-        0,
-    )
-
-    # TODO: simplify when indexing without rank will be supported by ONNX
-    base = num_keypoints * num_keypoints + num_keypoints + 1
-    ind = torch.arange(num_keypoints)
-    ind = ind.to(dtype=torch.int64) * base
-    end_scores_i = (
-        roi_map.index_select(1, y_int.to(dtype=torch.int64))
-        .index_select(2, x_int.to(dtype=torch.int64))
-        .view(-1)
-        .index_select(0, ind.to(dtype=torch.int64))
-    )
-
-    return xy_preds_i, end_scores_i
-
-
-@torch.jit._script_if_tracing
-def _onnx_heatmaps_to_keypoints_loop(
-    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
-):
-    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
-    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
-
-    for i in range(int(rois.size(0))):
-        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
-            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
-        )
-        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
-        end_scores = torch.cat(
-            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
-        )
-    return xy_preds, end_scores
-
-
-def heatmaps_to_keypoints(maps, rois):
-    """Extract predicted keypoint locations from heatmaps. Output has shape
-    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
-    for each keypoint.
-    """
-    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
-    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
-    # consistency with keypoints_to_heatmap_labels by using the conversion from
-    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
-    # continuous coordinate.
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-
-    widths = rois[:, 2] - rois[:, 0]
-    heights = rois[:, 3] - rois[:, 1]
-    widths = widths.clamp(min=1)
-    heights = heights.clamp(min=1)
-    widths_ceil = widths.ceil()
-    heights_ceil = heights.ceil()
-
-    num_keypoints = maps.shape[1]
-
-    if torchvision._is_tracing():
-        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
-            maps,
-            rois,
-            widths_ceil,
-            heights_ceil,
-            widths,
-            heights,
-            offset_x,
-            offset_y,
-            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
-        )
-        return xy_preds.permute(0, 2, 1), end_scores
-
-    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
-    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
-    for i in range(len(rois)):
-        roi_map_width = int(widths_ceil[i].item())
-        roi_map_height = int(heights_ceil[i].item())
-        width_correction = widths[i] / roi_map_width
-        height_correction = heights[i] / roi_map_height
-        roi_map = F.interpolate(
-            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
-        )[:, 0]
-        # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = roi_map.shape[2]
-        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-        x_int = pos % w
-        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
-        # assert (roi_map_probs[k, y_int, x_int] ==
-        #         roi_map_probs[k, :, :].max())
-        x = (x_int.float() + 0.5) * width_correction
-        y = (y_int.float() + 0.5) * height_correction
-        xy_preds[i, 0, :] = x + offset_x[i]
-        xy_preds[i, 1, :] = y + offset_y[i]
-        xy_preds[i, 2, :] = 1
-        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
-
-    return xy_preds.permute(0, 2, 1), end_scores
-
-
-def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    N, K, H, W = keypoint_logits.shape
-    if H != W:
-        raise ValueError(
-            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
-        )
-    discretization_size = H
-    heatmaps = []
-    valid = []
-    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
-        kp = gt_kp_in_image[midx]
-        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
-        heatmaps.append(heatmaps_per_image.view(-1))
-        valid.append(valid_per_image.view(-1))
-
-    keypoint_targets = torch.cat(heatmaps, dim=0)
-    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
-    valid = torch.where(valid)[0]
-
-    # torch.mean (in binary_cross_entropy_with_logits) doesn't
-    # accept empty tensors, so handle it separately
-    if keypoint_targets.numel() == 0 or len(valid) == 0:
-        return keypoint_logits.sum() * 0
-
-    keypoint_logits = keypoint_logits.view(N * K, H * W)
-
-    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
-    return keypoint_loss
-
-
-def keypointrcnn_inference(x, boxes):
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    print(f'x:{x.shape}')
-    kp_probs = []
-    kp_scores = []
-
-    boxes_per_image = [box.size(0) for box in boxes]
-    x2 = x.split(boxes_per_image, dim=0)
-    print(f'x2:{x2}')
-
-    for xx, bb in zip(x2, boxes):
-        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
-        kp_probs.append(kp_prob)
-        kp_scores.append(scores)
-
-    return kp_probs, kp_scores
-
-
-def _onnx_expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
-    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
-    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
-    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
-    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
-
-    w_half = w_half.to(dtype=torch.float32) * scale
-    h_half = h_half.to(dtype=torch.float32) * scale
-
-    boxes_exp0 = x_c - w_half
-    boxes_exp1 = y_c - h_half
-    boxes_exp2 = x_c + w_half
-    boxes_exp3 = y_c + h_half
-    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
-    return boxes_exp
-
-
-# the next two functions should be merged inside Masker
-# but are kept here for the moment while we need them
-# temporarily for paste_mask_in_image
-def expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
-    if torchvision._is_tracing():
-        return _onnx_expand_boxes(boxes, scale)
-    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
-    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
-    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
-    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
-
-    w_half *= scale
-    h_half *= scale
-
-    boxes_exp = torch.zeros_like(boxes)
-    boxes_exp[:, 0] = x_c - w_half
-    boxes_exp[:, 2] = x_c + w_half
-    boxes_exp[:, 1] = y_c - h_half
-    boxes_exp[:, 3] = y_c + h_half
-    return boxes_exp
-
-
-@torch.jit.unused
-def expand_masks_tracing_scale(M, padding):
-    # type: (int, int) -> float
-    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
-
-
-def expand_masks(mask, padding):
-    # type: (Tensor, int) -> Tuple[Tensor, float]
-    M = mask.shape[-1]
-    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
-        scale = expand_masks_tracing_scale(M, padding)
-    else:
-        scale = float(M + 2 * padding) / M
-    padded_mask = F.pad(mask, (padding,) * 4)
-    return padded_mask, scale
-
-
-def paste_mask_in_image(mask, box, im_h, im_w):
-    # type: (Tensor, Tensor, int, int) -> Tensor
-    TO_REMOVE = 1
-    w = int(box[2] - box[0] + TO_REMOVE)
-    h = int(box[3] - box[1] + TO_REMOVE)
-    w = max(w, 1)
-    h = max(h, 1)
-
-    # Set shape to [batchxCxHxW]
-    mask = mask.expand((1, 1, -1, -1))
-
-    # Resize mask
-    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
-    mask = mask[0][0]
-
-    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
-    x_0 = max(box[0], 0)
-    x_1 = min(box[2] + 1, im_w)
-    y_0 = max(box[1], 0)
-    y_1 = min(box[3] + 1, im_h)
-
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
-    return im_mask
-
-
-def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
-    one = torch.ones(1, dtype=torch.int64)
-    zero = torch.zeros(1, dtype=torch.int64)
-
-    w = box[2] - box[0] + one
-    h = box[3] - box[1] + one
-    w = torch.max(torch.cat((w, one)))
-    h = torch.max(torch.cat((h, one)))
-
-    # Set shape to [batchxCxHxW]
-    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
-
-    # Resize mask
-    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
-    mask = mask[0][0]
-
-    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
-    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
-    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
-    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
-
-    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
-
-    # TODO : replace below with a dynamic padding when support is added in ONNX
-
-    # pad y
-    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
-    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
-    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
-    # pad x
-    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
-    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
-    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
-    return im_mask
-
-
-@torch.jit._script_if_tracing
-def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
-    res_append = torch.zeros(0, im_h, im_w)
-    for i in range(masks.size(0)):
-        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
-        mask_res = mask_res.unsqueeze(0)
-        res_append = torch.cat((res_append, mask_res))
-    return res_append
-
-
-def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
-    masks, scale = expand_masks(masks, padding=padding)
-    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
-    im_h, im_w = img_shape
-
-    if torchvision._is_tracing():
-        return _onnx_paste_masks_in_image_loop(
-            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
-        )[:, None]
-    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
-    if len(res) > 0:
-        ret = torch.stack(res, dim=0)[:, None]
-    else:
-        ret = masks.new_empty((0, 1, im_h, im_w))
-    return ret
-
-
-class RoIHeads(nn.Module):
-    __annotations__ = {
-        "box_coder": det_utils.BoxCoder,
-        "proposal_matcher": det_utils.Matcher,
-        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
-    }
-
-    def __init__(
-        self,
-        box_roi_pool,
-        box_head,
-        box_predictor,
-        # Faster R-CNN training
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        bbox_reg_weights,
-        # Faster R-CNN inference
-        score_thresh,
-        nms_thresh,
-        detections_per_img,
-        # Mask
-        mask_roi_pool=None,
-        mask_head=None,
-        mask_predictor=None,
-        keypoint_roi_pool=None,
-        keypoint_head=None,
-        keypoint_predictor=None,
-    ):
-        super().__init__()
-
-        self.box_similarity = box_ops.box_iou
-        # assign ground-truth boxes for each proposal
-        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
-
-        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
-
-        if bbox_reg_weights is None:
-            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
-        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
-
-        self.box_roi_pool = box_roi_pool
-        self.box_head = box_head
-        self.box_predictor = box_predictor
-
-        self.score_thresh = score_thresh
-        self.nms_thresh = nms_thresh
-        self.detections_per_img = detections_per_img
-
-        self.mask_roi_pool = mask_roi_pool
-        self.mask_head = mask_head
-        self.mask_predictor = mask_predictor
-
-        self.keypoint_roi_pool = keypoint_roi_pool
-        self.keypoint_head = keypoint_head
-        self.keypoint_predictor = keypoint_predictor
-
-    def has_mask(self):
-        if self.mask_roi_pool is None:
-            return False
-        if self.mask_head is None:
-            return False
-        if self.mask_predictor is None:
-            return False
-        return True
-
-    def has_keypoint(self):
-        if self.keypoint_roi_pool is None:
-            return False
-        if self.keypoint_head is None:
-            return False
-        if self.keypoint_predictor is None:
-            return False
-        return True
-
-    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-        matched_idxs = []
-        labels = []
-        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
-
-            if gt_boxes_in_image.numel() == 0:
-                # Background image
-                device = proposals_in_image.device
-                clamped_matched_idxs_in_image = torch.zeros(
-                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
-                )
-                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
-            else:
-                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
-                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
-                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
-
-                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
-
-                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
-                labels_in_image = labels_in_image.to(dtype=torch.int64)
-
-                # Label background (below the low threshold)
-                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
-                labels_in_image[bg_inds] = 0
-
-                # Label ignore proposals (between low and high thresholds)
-                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
-                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
-
-            matched_idxs.append(clamped_matched_idxs_in_image)
-            labels.append(labels_in_image)
-        return matched_idxs, labels
-
-    def subsample(self, labels):
-        # type: (List[Tensor]) -> List[Tensor]
-        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
-        sampled_inds = []
-        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
-            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
-            sampled_inds.append(img_sampled_inds)
-        return sampled_inds
-
-    def add_gt_proposals(self, proposals, gt_boxes):
-        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
-        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
-
-        return proposals
-
-    def check_targets(self, targets):
-        # type: (Optional[List[Dict[str, Tensor]]]) -> None
-        if targets is None:
-            raise ValueError("targets should not be None")
-        if not all(["boxes" in t for t in targets]):
-            raise ValueError("Every element of targets should have a boxes key")
-        if not all(["labels" in t for t in targets]):
-            raise ValueError("Every element of targets should have a labels key")
-        if self.has_mask():
-            if not all(["masks" in t for t in targets]):
-                raise ValueError("Every element of targets should have a masks key")
-
-    def select_training_samples(
-        self,
-        proposals,  # type: List[Tensor]
-        targets,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
-        self.check_targets(targets)
-        if targets is None:
-            raise ValueError("targets should not be None")
-        dtype = proposals[0].dtype
-        device = proposals[0].device
-
-        gt_boxes = [t["boxes"].to(dtype) for t in targets]
-        gt_labels = [t["labels"] for t in targets]
-
-        # append ground-truth bboxes to proposals
-        proposals = self.add_gt_proposals(proposals, gt_boxes)
-
-        # get matching gt indices for each proposal
-        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
-        # sample a fixed proportion of positive-negative proposals
-        sampled_inds = self.subsample(labels)
-        matched_gt_boxes = []
-        num_images = len(proposals)
-        for img_id in range(num_images):
-            img_sampled_inds = sampled_inds[img_id]
-            proposals[img_id] = proposals[img_id][img_sampled_inds]
-            labels[img_id] = labels[img_id][img_sampled_inds]
-            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
-
-            gt_boxes_in_image = gt_boxes[img_id]
-            if gt_boxes_in_image.numel() == 0:
-                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
-            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
-
-        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
-        return proposals, matched_idxs, labels, regression_targets
-
-    def postprocess_detections(
-        self,
-        class_logits,  # type: Tensor
-        box_regression,  # type: Tensor
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
-        pred_boxes = self.box_coder.decode(box_regression, proposals)
-
-        pred_scores = F.softmax(class_logits, -1)
-
-        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
-        pred_scores_list = pred_scores.split(boxes_per_image, 0)
-
-        all_boxes = []
-        all_scores = []
-        all_labels = []
-        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
-            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
-
-            # create labels for each prediction
-            labels = torch.arange(num_classes, device=device)
-            labels = labels.view(1, -1).expand_as(scores)
-
-            # remove predictions with the background label
-            boxes = boxes[:, 1:]
-            scores = scores[:, 1:]
-            labels = labels[:, 1:]
-
-            # batch everything, by making every class prediction be a separate instance
-            boxes = boxes.reshape(-1, 4)
-            scores = scores.reshape(-1)
-            labels = labels.reshape(-1)
-
-            # remove low scoring boxes
-            inds = torch.where(scores > self.score_thresh)[0]
-            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
-
-            # remove empty boxes
-            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
-            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-            # non-maximum suppression, independently done per class
-            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # keep only topk scoring predictions
-            keep = keep[: self.detections_per_img]
-            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-            all_boxes.append(boxes)
-            all_scores.append(scores)
-            all_labels.append(labels)
-
-        return all_boxes, all_scores, all_labels
-
-    def forward(
-        self,
-        features,  # type: Dict[str, Tensor]
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
-        """
-        Args:
-            features (List[Tensor])
-            proposals (List[Tensor[N, 4]])
-            image_shapes (List[Tuple[H, W]])
-            targets (List[Dict])
-        """
-        if targets is not None:
-            for t in targets:
-                # TODO: https://github.com/pytorch/pytorch/issues/26731
-                floating_point_types = (torch.float, torch.double, torch.half)
-                if not t["boxes"].dtype in floating_point_types:
-                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
-                if not t["labels"].dtype == torch.int64:
-                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
-                if self.has_keypoint():
-                    if not t["keypoints"].dtype == torch.float32:
-                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
-
-        if self.training:
-            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
-        else:
-            labels = None
-            regression_targets = None
-            matched_idxs = None
-
-        box_features = self.box_roi_pool(features, proposals, image_shapes)
-        box_features = self.box_head(box_features)
-        class_logits, box_regression = self.box_predictor(box_features)
-
-        result: List[Dict[str, torch.Tensor]] = []
-        losses = {}
-        if self.training:
-            if labels is None:
-                raise ValueError("labels cannot be None")
-            if regression_targets is None:
-                raise ValueError("regression_targets cannot be None")
-            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
-        else:
-            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-            num_images = len(boxes)
-            for i in range(num_images):
-                result.append(
-                    {
-                        "boxes": boxes[i],
-                        "labels": labels[i],
-                        "scores": scores[i],
-                    }
-                )
-
-        if self.has_mask():
-            mask_proposals = [p["boxes"] for p in result]
-            if self.training:
-                if matched_idxs is None:
-                    raise ValueError("if in training, matched_idxs should not be None")
-
-                # during training, only focus on positive boxes
-                num_images = len(proposals)
-                mask_proposals = []
-                pos_matched_idxs = []
-                for img_id in range(num_images):
-                    pos = torch.where(labels[img_id] > 0)[0]
-                    mask_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
-            else:
-                pos_matched_idxs = None
-
-            if self.mask_roi_pool is not None:
-                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
-                mask_features = self.mask_head(mask_features)
-                mask_logits = self.mask_predictor(mask_features)
-            else:
-                raise Exception("Expected mask_roi_pool to be not None")
-
-            loss_mask = {}
-            if self.training:
-                if targets is None or pos_matched_idxs is None or mask_logits is None:
-                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
-
-                gt_masks = [t["masks"] for t in targets]
-                gt_labels = [t["labels"] for t in targets]
-                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
-                loss_mask = {"loss_mask": rcnn_loss_mask}
-            else:
-                labels = [r["labels"] for r in result]
-                masks_probs = maskrcnn_inference(mask_logits, labels)
-                for mask_prob, r in zip(masks_probs, result):
-                    r["masks"] = mask_prob
-
-            losses.update(loss_mask)
-
-        # keep none checks in if conditional so torchscript will conditionally
-        # compile each branch
-        if (
-            self.keypoint_roi_pool is not None
-            and self.keypoint_head is not None
-            and self.keypoint_predictor is not None
-        ):
-            keypoint_proposals = [p["boxes"] for p in result]
-            if self.training:
-                # during training, only focus on positive boxes
-                num_images = len(proposals)
-                keypoint_proposals = []
-                pos_matched_idxs = []
-                if matched_idxs is None:
-                    raise ValueError("if in trainning, matched_idxs should not be None")
-
-                for img_id in range(num_images):
-                    pos = torch.where(labels[img_id] > 0)[0]
-                    keypoint_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
-            else:
-                pos_matched_idxs = None
-
-            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
-            # tmp = keypoint_features[0][0]
-            # plt.imshow(tmp.detach().numpy())
-            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
-            keypoint_features = self.keypoint_head(keypoint_features)
-
-            print(f'keypoint_features:{keypoint_features.shape}')
-            tmp = keypoint_features[0][0]
-            plt.imshow(tmp.detach().numpy())
-            keypoint_logits = self.keypoint_predictor(keypoint_features)
-            print(f'keypoint_logits:{keypoint_logits.shape}')
-            """
-            hook into the wirenet branch here
-            """
-
-            loss_keypoint = {}
-            if self.training:
-                if targets is None or pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
-
-                gt_keypoints = [t["keypoints"] for t in targets]
-                rcnn_loss_keypoint = keypointrcnn_loss(
-                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
-                )
-                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
-            else:
-                if keypoint_logits is None or keypoint_proposals is None:
-                    raise ValueError(
-                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                    )
-
-                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
-                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
-                    r["keypoints"] = keypoint_prob
-                    r["keypoints_scores"] = kps
-            losses.update(loss_keypoint)
-
-        return result, losses
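
Editor's note on postprocess_detections in the deleted roi_head.py above: at inference time it filters each image's decoded boxes in a fixed order (score threshold, small-box removal, class-aware NMS, top-k). The sketch below illustrates that order with toy values; the function name filter_detections and the threshold defaults are assumptions for illustration only, not code from this repository.

import torch
from torchvision.ops import batched_nms, remove_small_boxes

def filter_detections(boxes, scores, labels, score_thresh=0.05, nms_thresh=0.5, topk=100):
    # 1) drop low-scoring predictions
    keep = torch.where(scores > score_thresh)[0]
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
    # 2) drop degenerate (tiny) boxes
    keep = remove_small_boxes(boxes, min_size=1e-2)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
    # 3) class-aware NMS, then 4) keep only the top-k results
    keep = batched_nms(boxes, scores, labels, nms_thresh)[:topk]
    return boxes[keep], scores[keep], labels[keep]

# toy usage: two overlapping boxes of the same class plus one box of another class
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.6])
labels = torch.tensor([1, 1, 2])
print(filter_detections(boxes, scores, labels))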

+ 0 - 151
models/wirenet/wirepoint_dataset.py

@@ -1,151 +0,0 @@
-from torch.utils.data.dataset import T_co
-
-from models.base.base_dataset import BaseDataset
-
-import glob
-import json
-import math
-import os
-import random
-import cv2
-import PIL
-
-import numpy as np
-import numpy.linalg as LA
-import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
-
-import matplotlib.pyplot as plt
-from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
-
-
-class WirePointDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
-        super().__init__(dataset_path)
-
-        self.data_path = dataset_path
-        print(f'data_path:{dataset_path}')
-        self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
-        self.imgs = os.listdir(self.img_path)
-        self.lbls = os.listdir(self.lbl_path)
-        self.target_type = target_type
-        # self.default_transform = DefaultTransform()
-
-    def __getitem__(self, index) -> T_co:
-        img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        if self.transforms:
-            img, target = self.transforms(img, target)
-        else:
-            img = self.default_transform(img)
-
-        # print(f'img:{img}')
-        return img, target
-    def __len__(self):
-        return len(self.imgs)
-    def read_target(self, item, lbl_path, shape, extra=None):
-        # print(f'lbl_path:{lbl_path}')
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
-
-        n_stc_posl = 300
-        n_stc_negl = 40
-        use_cood = 0
-        use_slop = 0
-
-        wire = lable_all["wires"][0]  # 字典
-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
-        npos, nneg = len(line_pos_coords), len(line_neg_coords)
-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
-        for i in range(len(lpre)):
-            if random.random() > 0.5:
-                lpre[i] = lpre[i, ::-1]
-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-        feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
-            ldir * use_slop,
-            lpre[:, :, 2],
-        ]
-        feat = np.concatenate(feat, 1)
-
-        wire_labels = {
-            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
-            # adjacency matrix of lines that actually exist
-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-            # adjacency matrix of lines that do not exist
-            "lpre": torch.tensor(lpre)[:, :, :2],
-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # labels for the samples: 1 positive, 0 negative
-            "lpre_feat": torch.from_numpy(feat),
-            "junc_map": torch.tensor(wire['junc_map']["content"]),
-            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
-            "line_map": torch.tensor(wire['line_map']["content"]),
-        }
-
-        h, w = shape
-        labels = []
-        masks = []
-        if self.target_type == 'polygon':
-            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
-        elif self.target_type == 'pixel':
-            labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
-
-        target = {}
-        target["boxes"] = masks_to_boxes(torch.stack(masks))
-        target["labels"] = torch.stack(labels)
-        target["masks"] = torch.stack(masks)
-        target["image_id"] = torch.tensor(item)
-        # return wire_labels, target
-        target["wires"] = wire_labels
-        return target
-
-    def show(self, idx):
-        img_path = os.path.join(self.img_path, self.imgs[idx])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[idx][:-3] + 'json')
-
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
-
-        # Visualize the image and its annotations
-        image = cv2.imread(img_path)  # [H,W,3]  # BGR by default
-        # print(image.shape)
-        # Draw the polygon for each annotation
-        # for ann in lable_all["segmentations"]:
-        #     segmentation = [[x * 512 for x in ann['data']]]
-        #     # segmentation = [ann['data']]
-        #     # for i in range(len(ann['data'])):
-        #     #     if i % 2 == 0:
-        #     #         segmentation[0][i] *= image.shape[0]
-        #     #     else:
-        #     #         segmentation[0][i] *= image.shape[0]
-        #
-        #     # if isinstance(segmentation, list):
-        #     #     for seg in segmentation:
-        #     #         poly = np.array(seg).reshape((-1, 2)).astype(int)
-        #     #         cv2.polylines(image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
-        #     #         cv2.fillPoly(image, [poly], color=(0, 255, 0))
-
-
-        #
-        # # 显示图像
-        # cv2.namedWindow('Image with Segmentations', cv2.WINDOW_NORMAL)
-        # cv2.imshow('Image with Segmentations', image)
-        # cv2.waitKey(0)
-        # cv2.destroyAllWindows()
-
-    def show_img(self,img_path):
-        pass
-
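
Editor's note on the deleted WirePointDataset above: __getitem__ returns an (image, target) pair where target is a dict of variably sized tensors (boxes, labels, masks, wires), so PyTorch's default_collate cannot stack a batch of them. A common workaround is a collate function that keeps the batch as parallel lists, sketched below; the dataset path and keyword arguments in the commented usage are placeholders, not values from this repository.

from torch.utils.data import DataLoader

def collate_keep_lists(batch):
    # keep images and per-image target dicts as two parallel lists
    images, targets = zip(*batch)
    return list(images), list(targets)

# usage sketch (path and arguments are placeholders):
# dataset = WirePointDataset("path/to/dataset", dataset_type="train")
# loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_keep_lists)
# images, targets = next(iter(loader))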

+ 0 - 703
models/wirenet/wirepoint_rcnn.py

@@ -1,703 +0,0 @@
-import os
-from typing import Optional, Any
-
-import numpy as np
-import torch
-from tensorboardX import SummaryWriter
-from torch import nn
-import torch.nn.functional as F
-# from torchinfo import summary
-from torchvision.io import read_image
-from torchvision.models import resnet50, ResNet50_Weights
-from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
-from torchvision.models.detection._utils import overwrite_eps
-from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.keypoint_rcnn import KeypointRCNNHeads, KeypointRCNNPredictor, \
-    KeypointRCNN_ResNet50_FPN_Weights
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.ops import misc as misc_nn_ops
-# from visdom import Visdom
-
-from models.config import config_tool
-from models.config.config_tool import read_yaml
-from models.ins.trainer import get_transform
-from models.wirenet.head import RoIHeads
-from models.wirenet.wirepoint_dataset import WirePointDataset
-from tools import utils
-
-
-FEATURE_DIM = 8
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-class Bottleneck1D(nn.Module):
-    def __init__(self, inplanes, outplanes):
-        super(Bottleneck1D, self).__init__()
-
-        planes = outplanes // 2
-        self.op = nn.Sequential(
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, planes, kernel_size=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, outplanes, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.op(x)
-
-
-class WirepointRCNN(FasterRCNN):
-    def __init__(
-            self,
-            backbone,
-            num_classes=None,
-            # transform parameters
-            min_size=None,
-            max_size=1333,
-            image_mean=None,
-            image_std=None,
-            # RPN parameters
-            rpn_anchor_generator=None,
-            rpn_head=None,
-            rpn_pre_nms_top_n_train=2000,
-            rpn_pre_nms_top_n_test=1000,
-            rpn_post_nms_top_n_train=2000,
-            rpn_post_nms_top_n_test=1000,
-            rpn_nms_thresh=0.7,
-            rpn_fg_iou_thresh=0.7,
-            rpn_bg_iou_thresh=0.3,
-            rpn_batch_size_per_image=256,
-            rpn_positive_fraction=0.5,
-            rpn_score_thresh=0.0,
-            # Box parameters
-            box_roi_pool=None,
-            box_head=None,
-            box_predictor=None,
-            box_score_thresh=0.05,
-            box_nms_thresh=0.5,
-            box_detections_per_img=100,
-            box_fg_iou_thresh=0.5,
-            box_bg_iou_thresh=0.5,
-            box_batch_size_per_image=512,
-            box_positive_fraction=0.25,
-            bbox_reg_weights=None,
-            # keypoint parameters
-            keypoint_roi_pool=None,
-            keypoint_head=None,
-            keypoint_predictor=None,
-            num_keypoints=None,
-            wirepoint_roi_pool=None,
-            wirepoint_head=None,
-            wirepoint_predictor=None,
-            **kwargs,
-    ):
-        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
-            raise TypeError(
-                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
-            )
-        if min_size is None:
-            min_size = (640, 672, 704, 736, 768, 800)
-
-        if num_keypoints is not None:
-            if keypoint_predictor is not None:
-                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
-        else:
-            num_keypoints = 17
-
-        out_channels = backbone.out_channels
-
-        if wirepoint_roi_pool is None:
-            wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
-                                                    sampling_ratio=2,)
-
-        if wirepoint_head is None:
-            keypoint_layers = tuple(512 for _ in range(8))
-            print(f'keypointrcnnHeads in_channels:{out_channels}, layers:{keypoint_layers}')
-            wirepoint_head = WirepointHead(out_channels, keypoint_layers)
-
-        if wirepoint_predictor is None:
-            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            wirepoint_predictor = WirepointPredictor()
-
-        super().__init__(
-            backbone,
-            num_classes,
-            # transform parameters
-            min_size,
-            max_size,
-            image_mean,
-            image_std,
-            # RPN-specific parameters
-            rpn_anchor_generator,
-            rpn_head,
-            rpn_pre_nms_top_n_train,
-            rpn_pre_nms_top_n_test,
-            rpn_post_nms_top_n_train,
-            rpn_post_nms_top_n_test,
-            rpn_nms_thresh,
-            rpn_fg_iou_thresh,
-            rpn_bg_iou_thresh,
-            rpn_batch_size_per_image,
-            rpn_positive_fraction,
-            rpn_score_thresh,
-            # Box parameters
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            **kwargs,
-        )
-
-        if box_roi_pool is None:
-            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
-
-        if box_head is None:
-            resolution = box_roi_pool.output_size[0]
-            representation_size = 1024
-            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        if box_predictor is None:
-            representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        roi_heads = RoIHeads(
-            # Box
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-            # wirepoint_roi_pool=wirepoint_roi_pool,
-            # wirepoint_head=wirepoint_head,
-            # wirepoint_predictor=wirepoint_predictor,
-        )
-        self.roi_heads = roi_heads
-
-        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
-        self.roi_heads.wirepoint_head = wirepoint_head
-        self.roi_heads.wirepoint_predictor = wirepoint_predictor
-
-
-class WirepointHead(nn.Module):
-    def __init__(self, input_channels, num_class):
-        super(WirepointHead, self).__init__()
-        self.head_size = [[2], [1], [2]]
-        m = int(input_channels / 4)
-        heads = []
-        # print(f'M.head_size:{M.head_size}')
-        # for output_channels in sum(M.head_size, []):
-        for output_channels in sum(self.head_size, []):
-            heads.append(
-                nn.Sequential(
-                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
-                    nn.ReLU(inplace=True),
-                    nn.Conv2d(m, output_channels, kernel_size=1),
-                )
-            )
-        self.heads = nn.ModuleList(heads)
-
-    def forward(self, x):
-        # for idx, head in enumerate(self.heads):
-        #     print(f'{idx},multitask head:{head(x).shape},input x:{x.shape}')
-
-        outputs = torch.cat([head(x) for head in self.heads], dim=1)
-
-        features = x
-        return outputs, features
-
-
-class WirepointPredictor(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        # self.backbone = backbone
-        # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml('wirenet.yaml')
-        self.n_pts0 = self.cfg['model']['n_pts0']
-        self.n_pts1 = self.cfg['model']['n_pts1']
-        self.n_stc_posl = self.cfg['model']['n_stc_posl']
-        self.dim_loi = self.cfg['model']['dim_loi']
-        self.use_conv = self.cfg['model']['use_conv']
-        self.dim_fc = self.cfg['model']['dim_fc']
-        self.n_out_line = self.cfg['model']['n_out_line']
-        self.n_out_junc = self.cfg['model']['n_out_junc']
-        self.loss_weight = self.cfg['model']['loss_weight']
-        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
-        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
-        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
-        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
-        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
-        self.use_cood = self.cfg['model']['use_cood']
-        self.use_slop = self.cfg['model']['use_slop']
-        self.n_stc_negl = self.cfg['model']['n_stc_negl']
-        self.head_size = self.cfg['model']['head_size']
-        self.num_class = sum(sum(self.head_size, []))
-        self.head_off = np.cumsum([sum(h) for h in self.head_size])
-
-        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
-
-        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
-        scale_factor = self.n_pts0 // self.n_pts1
-        if self.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(self.dim_loi, self.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, 1),
-            )
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    def forward(self, inputs, features, targets=None):
-
-        # outputs, features = input
-        # for out in outputs:
-        #     print(f'out:{out.shape}')
-        # outputs=merge_features(outputs,100)
-        batch, channel, row, col = inputs.shape
-        print(f'inputs:{inputs.shape}')
-        print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
-
-        if targets is not None:
-            self.training = True
-            # print(f'target:{targets}')
-            wires_targets = [t["wires"] for t in targets]
-            # print(f'wires_target:{wires_targets}')
-            # extract the 'junc_map', 'junc_offset' and 'line_map' tensors from every target
-            junc_maps = [d["junc_map"] for d in wires_targets]
-            junc_offsets = [d["junc_offset"] for d in wires_targets]
-            line_maps = [d["line_map"] for d in wires_targets]
-
-            junc_map_tensor = torch.stack(junc_maps, dim=0)
-            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-            line_map_tensor = torch.stack(line_maps, dim=0)
-
-            wires_meta = {
-                "junc_map": junc_map_tensor,
-                "junc_offset": junc_offset_tensor,
-                # "line_map": line_map_tensor,
-            }
-        else:
-            self.training = False
-            t = {
-                    "junc_coords": torch.zeros(1, 2),
-                    "jtyp": torch.zeros(1, dtype=torch.uint8),
-                    "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                    "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                    "junc_map": torch.zeros([1, 1, 128, 128]),
-                    "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-                }
-            wires_targets = [t for _ in range(inputs.size(0))]
-
-            wires_meta = {
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-            }
-
-
-        T = wires_meta.copy()
-        n_jtyp = T["junc_map"].shape[1]
-        offset = self.head_off
-        result={}
-        for stack, output in enumerate([inputs]):
-            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
-            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
-            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-
-            if stack == 0:
-                result["preds"] = {
-                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                    "lmap": lmap.sigmoid(),
-                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-                }
-                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
-                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
-                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
-
-        h = result["preds"]
-        print(f'features shape:{features.shape}')
-        x = self.fc1(features)
-
-        print(f'x:{x.shape}')
-
-        n_batch, n_channel, row, col = x.shape
-
-        print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
-
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-
-        for i, meta in enumerate(wires_targets):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i],
-            )
-            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
-            ys.append(label)
-            if self.training and self.do_static_sampling:
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                .reshape(n_channel, -1, self.n_pts0)
-                .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            print(f'xp.shape:{xp.shape}')
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-            print(f'idx__:{idx}')
-
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-
-        # print("Weight dtype:", self.fc2.weight.dtype)
-        x = torch.cat([x, f], 1)
-        print("Input dtype:", x.dtype)
-        x = x.to(dtype=torch.float32)
-        print("Input dtype1:", x.dtype)
-        x = self.fc2(x).flatten()
-
-        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
-        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
-
-
-        # if mode != "training":
-        # self.inference(x, idx, jcs, n_batch, ps)
-
-        # return result
-
-
-    ####deprecated
-    # def inference(self,input, idx, jcs, n_batch, ps):
-    #     if not self.training:
-    #         p = torch.cat(ps)
-    #         s = torch.sigmoid(input)
-    #         b = s > 0.5
-    #         lines = []
-    #         score = []
-    #         print(f"n_batch:{n_batch}")
-    #         for i in range(n_batch):
-    #             print(f"idx:{idx}")
-    #             p0 = p[idx[i]: idx[i + 1]]
-    #             s0 = s[idx[i]: idx[i + 1]]
-    #             mask = b[idx[i]: idx[i + 1]]
-    #             p0 = p0[mask]
-    #             s0 = s0[mask]
-    #             if len(p0) == 0:
-    #                 lines.append(torch.zeros([1, self.n_out_line, 2, 2], device=p.device))
-    #                 score.append(torch.zeros([1, self.n_out_line], device=p.device))
-    #             else:
-    #                 arg = torch.argsort(s0, descending=True)
-    #                 p0, s0 = p0[arg], s0[arg]
-    #                 lines.append(p0[None, torch.arange(self.n_out_line) % len(p0)])
-    #                 score.append(s0[None, torch.arange(self.n_out_line) % len(s0)])
-    #             for j in range(len(jcs[i])):
-    #                 if len(jcs[i][j]) == 0:
-    #                     jcs[i][j] = torch.zeros([self.n_out_junc, 2], device=p.device)
-    #                 jcs[i][j] = jcs[i][j][
-    #                     None, torch.arange(self.n_out_junc) % len(jcs[i][j])
-    #                 ]
-    #         result["preds"]["lines"] = torch.cat(lines)
-    #         result["preds"]["score"] = torch.cat(score)
-    #         result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-    #
-    #         if len(jcs[i]) > 1:
-    #             result["preds"]["junts"] = torch.cat(
-    #                 [jcs[i][1] for i in range(n_batch)]
-    #             )
-    #     if self.training:
-    #         del result["preds"]
-
-    def sample_lines(self, meta, jmap, joff):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = self.n_dyn_junc // n_type
-            N = len(junc)
-            # if mode != "training":
-            if not self.training:
-                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            # if mode == "training":
-            if self.training:
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > self.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > self.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            feat = torch.cat(
-                [
-                    xyu / 128 * self.use_cood,
-                    xyv / 128 * self.use_cood,
-                    u2v * self.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-
-
-
-def wirepointrcnn_resnet50_fpn(
-        *,
-        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
-        progress: bool = True,
-        num_classes: Optional[int] = None,
-        num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
-        trainable_backbone_layers: Optional[int] = None,
-        **kwargs: Any,
-) -> WirepointRCNN:
-    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = WirepointRCNN(backbone, num_classes=5, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
-
-
-if __name__ == '__main__':
-    cfg = 'wirenet.yaml'
-    cfg = read_yaml(cfg)
-    print(f'cfg:{cfg}')
-    print(cfg['model']['n_dyn_negl'])
-    # net = WirepointPredictor()
-
-    dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
-    train_sampler = torch.utils.data.RandomSampler(dataset)
-    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
-    train_collate_fn = utils.collate_fn_wirepoint
-    data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=10, collate_fn=train_collate_fn
-    )
-    model = wirepointrcnn_resnet50_fpn()
-
-    imgs, targets = next(iter(data_loader))
-
-    model.train()
-    pred = model(imgs, targets)
-    print(f'pred:{pred}')
-    # result, losses = model(imgs, targets)
-    # print(f'result:{result}')
-    # print(f'pred:{losses}')
-'''
-########### predict#############
-
-    img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
-    transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
-    img = read_image(img_path)
-    img = transforms(img)
-
-    img = torch.ones((2, 3, 512, 512))
-    # print(f'img shape:{img.shape}')
-    model.eval()
-    onnx_file_path = "./wirenet.onnx"
-
-    # export the model to ONNX format
-    # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
-    #                   output_names=['output'])
-    # torch.save(model,'./wirenet.pt')
-
-
-
-    # 5. specify the output ONNX file name
-    # onnx_file_path = "./wirepoint_rcnn.onnx"
-
-    # prepare an example input: Mask R-CNN expects a list of images, each tensor shaped [C, H, W]
-    img = [torch.ones((3, 800, 800))]  # example input: one 800x800 image with 3 channels
-
-
-
-    # specify the output ONNX file name
-    # onnx_file_path = "./mask_rcnn.onnx"
-
-
-
-    # model_scripted = torch.jit.script(model)
-    # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
-    #                   output_names=["output"])
-    #
-    # print(f"Model has been converted to ONNX and saved to {onnx_file_path}")
-
-    pred=model(img)
-    #
-    print(f'pred:{pred}')
-
-
-
-################################################## end predict
-
-
-
-########## training ###################################
-    # imgs, targets = next(iter(data_loader))
-
-    # model.train()
-    # pred = model(imgs, targets)
-
-    # class WrapperModule(torch.nn.Module):
-    #     def __init__(self, model):
-    #         super(WrapperModule, self).__init__()
-    #         self.model = model
-    #
-    #     def forward(self,img, targets):
-    #         # handle the complex input structure here and convert it into a form suitable for tracing
-    #         return self.model(img,targets)
-
-    # torch.save(model.state_dict(),'./wire.pt')
-    # wrap the original model
-    # wrapped_model = WrapperModule(model)
-    # # model_scripted = torch.jit.trace(wrapped_model,img)
-    # writer = SummaryWriter('./')
-    # writer.add_graph(wrapped_model, (imgs,targets))
-    # writer.close()
-
-
-    #
-    # print(f'pred:{pred}')
-########## end training ###################################
-    # for imgs,targets in data_loader:
-    #     print(f'imgs:{imgs}')
-    #     print(f'targets:{targets}')
-'''
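The line-of-interest sampling removed above (WirepointPredictor.forward) interpolates the backbone feature map at n_pts0 points between each candidate line's endpoints before pooling and scoring. A minimal sketch of that bilinear lookup, assuming a single feature map x of shape [C, H, W] and endpoints given in (y, x) pixel coordinates; the function name and signature are illustrative only, not part of the repository:

import torch

def sample_line_features(x, p0, p1, n_pts=32):
    # x: [C, H, W] feature map; p0, p1: [N, 2] line endpoints in (y, x) pixel coordinates
    C, H, W = x.shape
    t = torch.linspace(0, 1, n_pts, device=x.device)[None, :, None]   # [1, n_pts, 1]
    pts = p0[:, None, :] * t + p1[:, None, :] * (1 - t) - 0.5         # [N, n_pts, 2]
    py = pts[..., 0].clamp(0, H - 1)
    px = pts[..., 1].clamp(0, W - 1)
    y0, x0 = py.floor().long(), px.floor().long()
    y1, x1 = (y0 + 1).clamp(max=H - 1), (x0 + 1).clamp(max=W - 1)
    wy, wx = py - y0, px - x0
    # bilinear interpolation at every sampled point -> [C, N, n_pts]
    feat = (x[:, y0, x0] * (1 - wy) * (1 - wx)
            + x[:, y1, x0] * wy * (1 - wx)
            + x[:, y0, x1] * (1 - wy) * wx
            + x[:, y1, x1] * wy * wx)
    return feat.permute(1, 0, 2)   # [N, C, n_pts], ready for 1-D pooling

The removed forward pass performs the same lookup per image in the batch and then applies the MaxPool1d / Bottleneck1D pooling before the fc2 scorer.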

+ 4 - 4
tools/coco_utils.py

@@ -139,8 +139,8 @@ def convert_to_coco_api(ds):
         bboxes[:, 2:] -= bboxes[:, :2]
         bboxes = bboxes.tolist()
         labels = targets["labels"].tolist()
-        areas = targets["area"].tolist()
-        iscrowd = targets["iscrowd"].tolist()
+        # areas = targets["area"].tolist()
+        # iscrowd = targets["iscrowd"].tolist()
         if "masks" in targets:
             masks = targets["masks"]
             # make masks Fortran contiguous for coco_mask
@@ -155,8 +155,8 @@ def convert_to_coco_api(ds):
             ann["bbox"] = bboxes[i]
             ann["category_id"] = labels[i]
             categories.add(labels[i])
-            ann["area"] = areas[i]
-            ann["iscrowd"] = iscrowd[i]
+            # ann["area"] = areas[i]
+            # ann["iscrowd"] = iscrowd[i]
             ann["id"] = ann_id
             if "masks" in targets:
                 ann["segmentation"] = coco_mask.encode(masks[i].numpy())

+ 1 - 1
tools/utils.py

@@ -201,7 +201,7 @@ class MetricLogger:
 
 
 def collate_fn(batch):
-    print(f'batch:{len(batch)}')
+    # print(f'batch:{len(batch)}')
     return tuple(zip(*batch))
 
 def collate_fn_wirepoint(batch):
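For reference, tuple(zip(*batch)) in collate_fn above simply transposes the list of (image, target) samples into one tuple of images and one tuple of targets, which is the structure the detection models here consume. A toy illustration with hypothetical values:

batch = [("img0", {"boxes": 0}), ("img1", {"boxes": 1})]
images, targets = tuple(zip(*batch))
# images  -> ("img0", "img1")
# targets -> ({"boxes": 0}, {"boxes": 1})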

+ 371 - 0
冻结参数训练.py

@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+import datetime
+import glob
+import os
+import os.path as osp
+import platform
+import pprint
+import random
+import shlex
+import shutil
+import subprocess
+import sys
+import numpy as np
+import torch
+import torchvision
+import yaml
+import lcnn
+from lcnn.config import C, M
+from lcnn.datasets import WireframeDataset, collate
+from lcnn.models.line_vectorizer import LineVectorizer
+from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
+from torchvision.models import resnet50
+
+
+def print_model_structure(model):
+    """
+    Print the model structure and parameter counts in detail.
+    """
+    print("\n========= Model Structure =========")
+
+    # overall model information
+    print("Model Type:", type(model))
+
+    # total parameter counts
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    print(f"\nTotal Parameters: {total_params:,}")
+    print(f"Trainable Parameters: {trainable_params:,}")
+    print(f"Non-trainable Parameters: {total_params - trainable_params:,}")
+
+    # per-module parameter counts and trainable status
+    print("\n===== Detailed Model Components =====")
+    for name, module in model.named_children():
+        module_params = sum(p.numel() for p in module.parameters())
+        module_trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+        print(f"\nmodel.named:{name}:")
+        print(f"  Total Parameters: {module_params:,}")
+        print(f"  Trainable Parameters: {module_trainable_params:,}")
+
+        # submodules
+        for subname, submodule in module.named_children():
+            sub_params = sum(p.numel() for p in submodule.parameters())
+            sub_trainable_params = sum(p.numel() for p in submodule.parameters() if p.requires_grad)
+
+            print(f"    {subname}:")
+            print(f"      Total Parameters: {sub_params:,}")
+            print(f"      Trainable Parameters: {sub_trainable_params:,}")
+
+
+def verify_freeze_params(model, freeze_config):
+    """
+    Verify that parameter freezing actually took effect.
+    """
+    print("\n===== Verifying Parameter Freezing =====")
+
+    for name, module in model.named_children():
+        if name in freeze_config:
+            if freeze_config[name]:
+                print(f"\nChecking module: {name}")
+                for param_name, param in module.named_parameters():
+                    print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+        # special handling for fc2 submodules
+        if name == 'fc2' and 'fc2_submodules' in freeze_config:
+            for subname, submodule in module.named_children():
+                if subname in freeze_config['fc2_submodules']:
+                    if freeze_config['fc2_submodules'][subname]:
+                        print(f"\nChecking fc2 submodule: {subname}")
+                        for param_name, param in submodule.named_parameters():
+                            print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+
+def freeze_params(model, freeze_config=None):
+    """
+    Fine-grained parameter freezing.
+
+    Args:
+        model: the model whose parameters should be (un)frozen
+        freeze_config: dictionary describing which modules to freeze
+    """
+    # default freezing configuration
+    default_config = {
+        'backbone': False,
+        'fc1': False,
+        'fc2': False,
+        'fc2_submodules': {
+            '0': False,  # first submodule of fc2
+            '2': False,  # third submodule of fc2
+            '4': False  # fifth submodule of fc2
+        },
+        'pooling': False,
+        'loss': False
+    }
+
+    # merge the user-supplied configuration into the defaults
+    if freeze_config is not None:
+        for key, value in freeze_config.items():
+            if isinstance(value, dict):
+                default_config[key].update(value)
+            else:
+                default_config[key] = value
+
+    print("\n===== Parameter Freezing Configuration =====")
+
+    for name, module in model.named_children():
+        # handle freezing of top-level modules
+        if name in default_config:
+            for param in module.parameters():
+                param.requires_grad = not default_config[name]
+
+            if not default_config[name]:
+                print(f"Module {name} is trainable")
+            else:
+                print(f"Freezing module: {name}")
+
+        # handle fc2 submodules
+        if name == 'fc2' and 'fc2_submodules' in default_config:
+            for subname, submodule in module.named_children():
+                if subname in default_config['fc2_submodules']:
+                    for param in submodule.parameters():
+                        param.requires_grad = not default_config['fc2_submodules'][subname]
+
+                    if not default_config['fc2_submodules'][subname]:
+                        print(f"Submodule fc2.{subname} is trainable")
+                    else:
+                        print(f"Freezing submodule: fc2.{subname}")
+
+    # print a summary after freezing
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    print(f"\nTotal Parameters: {total_params:,}")
+    print(f"Trainable Parameters: {trainable_params:,}")
+    print(f"Frozen Parameters: {total_params - trainable_params:,}")
+
+
+def get_model(num_classes):
+    # load a Faster R-CNN model with a pre-trained ResNet-50 FPN backbone
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # number of input features of the box classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # replace the box predictor to match the new number of classes
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+def main():
+    # training configuration
+    config = {
+        # dataset settings
+        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
+        'config_file': 'config/wireframe.yaml',  # path to the config file
+
+        # GPU settings
+        'devices': '0',  # GPU device(s) to use
+        'identifier': 'fasterrcnn_resnet50',  # run identifier (alternatives: stacked_hourglass, unet)
+
+        # path to pre-trained weights
+        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',
+
+        # fine-grained freezing configuration (True means the module is frozen)
+        'freeze_config': {
+            'backbone': False,  # backbone stays trainable
+            'fc1': False,  # fc1 stays trainable
+            'fc2': False,  # fc2 stays trainable
+            'fc2_submodules': {
+                '0': False,  # first fc2 submodule stays trainable
+                '2': False,  # third fc2 submodule stays trainable
+                '4': False  # fifth fc2 submodule stays trainable
+            },
+            'pooling': False,  # pooling stays trainable
+            'loss': False  # loss stays trainable
+        }
+    }
+
+    # update the global configuration
+    C.update(C.from_yaml(filename=config['config_file']))
+    M.update(C.model)
+
+    # fix the random seeds
+    random.seed(0)
+    np.random.seed(0)
+    torch.manual_seed(0)
+
+    # device setup
+    device_name = "cpu"
+    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.backends.cudnn.deterministic = True
+        torch.cuda.manual_seed(0)
+        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    else:
+        print("CUDA is not available")
+
+    device = torch.device(device_name)
+
+    # data loading
+    kwargs = {
+        "collate_fn": collate,
+        "num_workers": C.io.num_workers if os.name != "nt" else 0,
+        "pin_memory": True,
+    }
+
+    train_loader = torch.utils.data.DataLoader(
+        WireframeDataset(config['datadir'], dataset_type="train"),
+        shuffle=True,
+        batch_size=M.batch_size,
+        **kwargs,
+    )
+
+
+    val_loader = torch.utils.data.DataLoader(
+        WireframeDataset(config['datadir'], dataset_type="val"),
+        shuffle=False,
+        batch_size=M.batch_size_eval,
+        **kwargs,
+    )
+
+    # build the model
+    if M.backbone == "stacked_hourglass":
+        print(f"backbone == stacked_hourglass")
+        model = lcnn.models.hg(
+            depth=M.depth,
+            head=MultitaskHead,
+            num_stacks=M.num_stacks,
+            num_blocks=M.num_blocks,
+            num_classes=sum(sum(M.head_size, [])),
+        )
+        print(f"model.shape:{model}")
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+    elif M.backbone == "unet":
+        print(f"backbone == unet")
+        # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+        model = lcnn.models.unet(
+            num_classes=sum(sum(M.head_size, [])),
+            num_stacks=M.num_stacks,
+            base_channels=kwargs.get("base_channels", 64)
+        )
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+    elif M.backbone == "resnet50":
+        print(f"backbone == resnet50")
+        model = lcnn.models.resnet50(
+            # num_stacks=M.num_stacks,
+            num_classes=sum(sum(M.head_size, [])),
+        )
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+    elif M.backbone == "resnet501":
+        print(f"backbone == resnet501")
+        model = lcnn.models.resnet501(
+            # num_stacks=M.num_stacks,
+            num_classes=sum(sum(M.head_size, [])),
+        )
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+    elif M.backbone == "fasterrcnn_resnet50":
+        print(f"backbone == fasterrcnn_resnet50")
+        model = lcnn.models.fasterrcnn_resnet50(
+            # num_stacks=M.num_stacks,
+            num_classes=sum(sum(M.head_size, [])),
+        )
+        model = MultitaskLearner(model)
+        model = LineVectorizer(model)
+    else:
+        raise NotImplementedError
+
+    # load pre-trained weights; if 'pretrained_model' is not set in config, the except below logs the error and training starts from scratch
+
+    try:
+        # load the checkpoint
+        checkpoint = torch.load(config['pretrained_model'], map_location=device)
+
+        # pick the loading path according to the checkpoint structure
+        if 'model_state_dict' in checkpoint:
+            # full training checkpoint
+            model.load_state_dict(checkpoint['model_state_dict'])
+        elif 'state_dict' in checkpoint:
+            # checkpoint that stores only a state dict
+            model.load_state_dict(checkpoint['state_dict'])
+        else:
+            # plain weight dictionary
+            model.load_state_dict(checkpoint)
+
+        print("Successfully loaded pre-trained model weights.")
+    except Exception as e:
+        print(f"Error loading model weights: {e}")
+
+    # print the model structure
+    # print_model_structure(model)
+
+    # # freeze parameters
+    # freeze_params(
+    #     model,
+    #     freeze_config=config['freeze_config']
+    # )
+    # # verify which parameters were frozen
+    # verify_freeze_params(model, config['freeze_config'])
+    #
+    # # print the model structure
+    # print("\n========= After Freezing Backbone =========")
+    # print_model_structure(model)
+
+    # move the model to the device
+    model = model.to(device)
+
+    # optimizer setup
+    if C.optim.name == "Adam":
+        optim = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=C.optim.lr,
+            weight_decay=C.optim.weight_decay,
+            amsgrad=C.optim.amsgrad,
+        )
+    elif C.optim.name == "SGD":
+        optim = torch.optim.SGD(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=C.optim.lr,
+            weight_decay=C.optim.weight_decay,
+            momentum=C.optim.momentum,
+        )
+    else:
+        raise NotImplementedError
+
+    # output directory
+    outdir = osp.join(
+        osp.expanduser(C.io.logdir),
+        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
+    )
+    os.makedirs(outdir, exist_ok=True)
+
+    try:
+        trainer = lcnn.trainer.Trainer(
+            device=device,
+            model=model,
+            optimizer=optim,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            out=outdir,
+        )
+
+        print("Starting training...")
+        trainer.train()
+        print("Training completed.")
+
+    except BaseException:
+        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
+            shutil.rmtree(outdir)
+        raise
+
+
+if __name__ == "__main__":
+    main()
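As a usage sketch for the freezing helpers above (True means a module is frozen; the class count of 5 is only an example, and the snippet reuses the imports and functions already defined in this script), freezing the backbone of the detector returned by get_model and training the remaining parameters could look like:

model = get_model(num_classes=5)
freeze_params(model, freeze_config={'backbone': True})
verify_freeze_params(model, {'backbone': True})

# hand only the still-trainable parameters to the optimizer
optim = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=4.0e-4, weight_decay=1.0e-4, amsgrad=True,
)

The same freeze_config dictionary structure is what main() already carries in config['freeze_config'], so switching individual entries to True is enough to reuse the commented-out freeze_params / verify_freeze_params calls.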