xue50 3 months ago
parent
commit
d0d3f9387f

+ 77 - 0
aaa.py

@@ -0,0 +1,77 @@
+import torch
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def c(score):
+    # Return a color based on the score; this is only an example and can be adjusted as needed
+    return (1, 0, 0) if score > 0.9 else (0, 1, 0)
+
+
+def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
+    # Placeholder post-processing function used to filter line segments
+    nlines = []
+    nscores = []
+    for line, score in zip(lines, scores):
+        if score >= min_score:
+            nlines.append(line)
+            nscores.append(score)
+    return np.array(nlines), np.array(nscores)
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0).cpu().numpy()
+
+    # Draw the bounding boxes
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
+
+    H = pred[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+
+    print(f"Lines before deduplication: {len(lines)}")
+
+    # Remove duplicate line segments
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    print(f"Lines after deduplication: {len(lines)}")
+
+    # Post-process the line segments to remove overlapping ones
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    print(f"Lines after postprocessing: {len(nlines)}")
+
+    # Create a new figure and draw the line segments and bounding boxes on it
+    fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
+    ax.imshow(boxed_image)
+    ax.set_axis_off()
+    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+    plt.margins(0, 0)
+    plt.gca().xaxis.set_major_locator(plt.NullLocator())
+    plt.gca().yaxis.set_major_locator(plt.NullLocator())
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.85:  # adjust this threshold to choose which segments are displayed
+            continue
+        plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+        plt.scatter(a[1], a[0], **PLTOPTS)
+        plt.scatter(b[1], b[0], **PLTOPTS)
+
+    plt.tight_layout()
+    fig.canvas.draw()
+    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+        fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    img2 = transforms.ToTensor()(image_from_plot)
+
+    writer.add_image("output_with_boxes_and_lines", img2, epoch)
+    print("Image with boxes and lines added to TensorBoard.")

+ 3 - 3
config/wireframe.yaml

@@ -2,8 +2,8 @@ io:
   logdir: logs/
   datadir: D:\python\PycharmProjects\data
   resume_from:
-  num_workers: 0
-  tensorboard_port: 0
+  num_workers: 8
+  tensorboard_port: 6000
   validation_interval: 300
 
 model:
@@ -15,7 +15,7 @@ model:
   batch_size_eval: 2
 
   # backbone multi-task parameters
-  head_size: [[2], [1], [2],[4]]
+  head_size: [[2], [1], [2]]
   loss_weight:
     jmap: 8.0
     lmap: 0.5
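
For reference, the updated fields can be checked with plain PyYAML, since the file is ordinary YAML (the project itself reads it through its own config code); a minimal sketch, assuming PyYAML is installed and the script is run from the repository root:

    import yaml

    with open("config/wireframe.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["io"]["num_workers"])       # 8 after this change (was 0)
    print(cfg["io"]["tensorboard_port"])  # 6000 after this change (was 0)
    print(cfg["model"]["head_size"])      # [[2], [1], [2]] -- the [4] head entry was dropped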

+ 0 - 4
lcnn/__init__.py

@@ -1,4 +0,0 @@
-import lcnn.models
-import lcnn.trainer
-import lcnn.datasets
-import lcnn.config

+ 0 - 1110
lcnn/box.py

@@ -1,1110 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: UTF-8 -*-
-#
-# Copyright (c) 2017-2019 - Chris Griffith - MIT License
-"""
-Improved dictionary access through dot notation with additional tools.
-"""
-import string
-import sys
-import json
-import re
-import copy
-from keyword import kwlist
-import warnings
-
-try:
-    from collections.abc import Iterable, Mapping, Callable
-except ImportError:
-    from collections import Iterable, Mapping, Callable
-
-yaml_support = True
-
-try:
-    import yaml
-except ImportError:
-    try:
-        import ruamel.yaml as yaml
-    except ImportError:
-        yaml = None
-        yaml_support = False
-
-if sys.version_info >= (3, 0):
-    basestring = str
-else:
-    from io import open
-
-__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox',
-           'BoxError', 'BoxKeyError']
-__author__ = 'Chris Griffith'
-__version__ = '3.2.4'
-
-BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box',
-                  'frozen_box', 'camel_killer_box', 'box_it_up',
-                  'box_safe_prefix', 'box_duplicates', 'ordered_box')
-
-_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
-_all_cap_re = re.compile('([a-z0-9])([A-Z])')
-
-
-class BoxError(Exception):
-    """Non standard dictionary exceptions"""
-
-
-class BoxKeyError(BoxError, KeyError, AttributeError):
-    """Key does not exist"""
-
-
-# Abstract converter functions for use in any Box class
-
-
-def _to_json(obj, filename=None,
-             encoding="utf-8", errors="strict", **json_kwargs):
-    json_dump = json.dumps(obj,
-                           ensure_ascii=False, **json_kwargs)
-    if filename:
-        with open(filename, 'w', encoding=encoding, errors=errors) as f:
-            f.write(json_dump if sys.version_info >= (3, 0) else
-                    json_dump.decode("utf-8"))
-    else:
-        return json_dump
-
-
-def _from_json(json_string=None, filename=None,
-               encoding="utf-8", errors="strict", multiline=False, **kwargs):
-    if filename:
-        with open(filename, 'r', encoding=encoding, errors=errors) as f:
-            if multiline:
-                data = [json.loads(line.strip(), **kwargs) for line in f
-                        if line.strip() and not line.strip().startswith("#")]
-            else:
-                data = json.load(f, **kwargs)
-    elif json_string:
-        data = json.loads(json_string, **kwargs)
-    else:
-        raise BoxError('from_json requires a string or filename')
-    return data
-
-
-def _to_yaml(obj, filename=None, default_flow_style=False,
-             encoding="utf-8", errors="strict",
-             **yaml_kwargs):
-    if filename:
-        with open(filename, 'w',
-                  encoding=encoding, errors=errors) as f:
-            yaml.dump(obj, stream=f,
-                      default_flow_style=default_flow_style,
-                      **yaml_kwargs)
-    else:
-        return yaml.dump(obj,
-                         default_flow_style=default_flow_style,
-                         **yaml_kwargs)
-
-
-def _from_yaml(yaml_string=None, filename=None,
-               encoding="utf-8", errors="strict",
-               **kwargs):
-    if filename:
-        with open(filename, 'r',
-                  encoding=encoding, errors=errors) as f:
-            data = yaml.load(f, **kwargs)
-    elif yaml_string:
-        data = yaml.load(yaml_string, **kwargs)
-    else:
-        raise BoxError('from_yaml requires a string or filename')
-    return data
-
-
-# Helper functions
-
-
-def _safe_key(key):
-    try:
-        return str(key)
-    except UnicodeEncodeError:
-        return key.encode("utf-8", "ignore")
-
-
-def _safe_attr(attr, camel_killer=False, replacement_char='x'):
-    """Convert a key into something that is accessible as an attribute"""
-    allowed = string.ascii_letters + string.digits + '_'
-
-    attr = _safe_key(attr)
-
-    if camel_killer:
-        attr = _camel_killer(attr)
-
-    attr = attr.replace(' ', '_')
-
-    out = ''
-    for character in attr:
-        out += character if character in allowed else "_"
-    out = out.strip("_")
-
-    try:
-        int(out[0])
-    except (ValueError, IndexError):
-        pass
-    else:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    if out in kwlist:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    return re.sub('_+', '_', out)
-
-
-def _camel_killer(attr):
-    """
-    CamelKiller, qu'est-ce que c'est?
-
-    Taken from http://stackoverflow.com/a/1176023/3244542
-    """
-    try:
-        attr = str(attr)
-    except UnicodeEncodeError:
-        attr = attr.encode("utf-8", "ignore")
-
-    s1 = _first_cap_re.sub(r'\1_\2', attr)
-    s2 = _all_cap_re.sub(r'\1_\2', s1)
-    return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else
-                  s2.lower())
-
-
-def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs):
-    out_list = []
-    for i in iterable:
-        if isinstance(i, dict):
-            out_list.append(box_class(i, **kwargs))
-        elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)):
-            out_list.append(_recursive_tuples(i, box_class,
-                                              recreate_tuples, **kwargs))
-        else:
-            out_list.append(i)
-    return tuple(out_list)
-
-
-def _conversion_checks(item, keys, box_config, check_only=False,
-                       pre_check=False):
-    """
-    Internal use for checking if a duplicate safe attribute already exists
-
-    :param item: Item to see if a dup exists
-    :param keys: Keys to check against
-    :param box_config: Easier to pass in than ask for specific items
-    :param check_only: Don't bother doing the conversion work
-    :param pre_check: Need to add the item to the list of keys to check
-    :return: the original unmodified key, if exists and not check_only
-    """
-    if box_config['box_duplicates'] != 'ignore':
-        if pre_check:
-            keys = list(keys) + [item]
-
-        key_list = [(k,
-                     _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                                replacement_char=box_config['box_safe_prefix']
-                                )) for k in keys]
-        if len(key_list) > len(set(x[1] for x in key_list)):
-            seen = set()
-            dups = set()
-            for x in key_list:
-                if x[1] in seen:
-                    dups.add("{0}({1})".format(x[0], x[1]))
-                seen.add(x[1])
-            if box_config['box_duplicates'].startswith("warn"):
-                warnings.warn('Duplicate conversion attributes exist: '
-                              '{0}'.format(dups))
-            else:
-                raise BoxError('Duplicate conversion attributes exist: '
-                               '{0}'.format(dups))
-    if check_only:
-        return
-    # This way will be slower for warnings, as it will have double work
-    # But faster for the default 'ignore'
-    for k in keys:
-        if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                              replacement_char=box_config['box_safe_prefix']):
-            return k
-
-
-def _get_box_config(cls, kwargs):
-    return {
-        # Internal use only
-        '__converted': set(),
-        '__box_heritage': kwargs.pop('__box_heritage', None),
-        '__created': False,
-        '__ordered_box_values': [],
-        # Can be changed by user after box creation
-        'default_box': kwargs.pop('default_box', False),
-        'default_box_attr': kwargs.pop('default_box_attr', cls),
-        'conversion_box': kwargs.pop('conversion_box', True),
-        'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'),
-        'frozen_box': kwargs.pop('frozen_box', False),
-        'camel_killer_box': kwargs.pop('camel_killer_box', False),
-        'modify_tuples_box': kwargs.pop('modify_tuples_box', False),
-        'box_duplicates': kwargs.pop('box_duplicates', 'ignore'),
-        'ordered_box': kwargs.pop('ordered_box', False)
-    }
-
-
-class Box(dict):
-    """
-    Improved dictionary access through dot notation with additional tools.
-
-    :param default_box: Similar to defaultdict, return a default value
-    :param default_box_attr: Specify the default replacement.
-        WARNING: If this is not the default 'Box', it will not be recursive
-    :param frozen_box: After creation, the box cannot be modified
-    :param camel_killer_box: Convert CamelCase to snake_case
-    :param conversion_box: Check for near matching keys as attributes
-    :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes
-    :param box_it_up: Recursively create all Boxes from the start
-    :param box_safe_prefix: Conversion box prefix for unsafe attributes
-    :param box_duplicates: "ignore", "error" or "warn" when duplicates exists
-        in a conversion_box
-    :param ordered_box: Preserve the order of keys entered into the box
-    """
-
-    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
-                                 'from_yaml', 'from_json']
-
-    def __new__(cls, *args, **kwargs):
-        """
-        Due to the way pickling works in python 3, we need to make sure
-        the box config is created as early as possible.
-        """
-        obj = super(Box, cls).__new__(cls, *args, **kwargs)
-        obj._box_config = _get_box_config(cls, kwargs)
-        return obj
-
-    def __init__(self, *args, **kwargs):
-        self._box_config = _get_box_config(self.__class__, kwargs)
-        if self._box_config['ordered_box']:
-            self._box_config['__ordered_box_values'] = []
-        if (not self._box_config['conversion_box'] and
-                self._box_config['box_duplicates'] != "ignore"):
-            raise BoxError('box_duplicates are only for conversion_boxes')
-        if len(args) == 1:
-            if isinstance(args[0], basestring):
-                raise ValueError('Cannot extrapolate Box from string')
-            if isinstance(args[0], Mapping):
-                for k, v in args[0].items():
-                    if v is args[0]:
-                        v = self
-                    self[k] = v
-                    self.__add_ordered(k)
-            elif isinstance(args[0], Iterable):
-                for k, v in args[0]:
-                    self[k] = v
-                    self.__add_ordered(k)
-
-            else:
-                raise ValueError('First argument must be mapping or iterable')
-        elif args:
-            raise TypeError('Box expected at most 1 argument, '
-                            'got {0}'.format(len(args)))
-
-        box_it = kwargs.pop('box_it_up', False)
-        for k, v in kwargs.items():
-            if args and isinstance(args[0], Mapping) and v is args[0]:
-                v = self
-            self[k] = v
-            self.__add_ordered(k)
-
-        if (self._box_config['frozen_box'] or box_it or
-                self._box_config['box_duplicates'] != 'ignore'):
-            self.box_it_up()
-
-        self._box_config['__created'] = True
-
-    def __add_ordered(self, key):
-        if (self._box_config['ordered_box'] and
-                key not in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].append(key)
-
-    def box_it_up(self):
-        """
-        Perform value lookup for all items in current dictionary,
-        generating all sub Box objects, while also running `box_it_up` on
-        any of those sub box objects.
-        """
-        for k in self:
-            _conversion_checks(k, self.keys(), self._box_config,
-                               check_only=True)
-            if self[k] is not self and hasattr(self[k], 'box_it_up'):
-                self[k].box_it_up()
-
-    def __hash__(self):
-        if self._box_config['frozen_box']:
-            hashing = 54321
-            for item in self.items():
-                hashing ^= hash(item)
-            return hashing
-        raise TypeError("unhashable type: 'Box'")
-
-    def __dir__(self):
-        allowed = string.ascii_letters + string.digits + '_'
-        kill_camel = self._box_config['camel_killer_box']
-        items = set(dir(dict) + ['to_dict', 'to_json',
-                                 'from_json', 'box_it_up'])
-        # Only show items accessible by dot notation
-        for key in self.keys():
-            key = _safe_key(key)
-            if (' ' not in key and key[0] not in string.digits and
-                    key not in kwlist):
-                for letter in key:
-                    if letter not in allowed:
-                        break
-                else:
-                    items.add(key)
-
-        for key in self.keys():
-            key = _safe_key(key)
-            if key not in items:
-                if self._box_config['conversion_box']:
-                    key = _safe_attr(key, camel_killer=kill_camel,
-                                     replacement_char=self._box_config[
-                                         'box_safe_prefix'])
-                    if key:
-                        items.add(key)
-            if kill_camel:
-                snake_key = _camel_killer(key)
-                if snake_key:
-                    items.remove(key)
-                    items.add(snake_key)
-
-        if yaml_support:
-            items.add('to_yaml')
-            items.add('from_yaml')
-
-        return list(items)
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            if isinstance(default, dict) and not isinstance(default, Box):
-                return Box(default)
-            if isinstance(default, list) and not isinstance(default, BoxList):
-                return BoxList(default)
-            return default
-
-    def copy(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __copy__(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k, v in self.items():
-            out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict)
-        return out
-
-    def __setstate__(self, state):
-        self._box_config = state['_box_config']
-        self.__dict__.update(state)
-
-    def __getitem__(self, item, _ignore_default=False):
-        try:
-            value = super(Box, self).__getitem__(item)
-        except KeyError as err:
-            if item == '_box_config':
-                raise BoxKeyError('_box_config should only exist as an '
-                                  'attribute and is never defaulted')
-            if self._box_config['default_box'] and not _ignore_default:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            return self.__convert_and_store(item, value)
-
-    def keys(self):
-        if self._box_config['ordered_box']:
-            return self._box_config['__ordered_box_values']
-        return super(Box, self).keys()
-
-    def values(self):
-        return [self[x] for x in self.keys()]
-
-    def items(self):
-        return [(x, self[x]) for x in self.keys()]
-
-    def __get_default(self, item):
-        default_value = self._box_config['default_box_attr']
-        if default_value is self.__class__:
-            return self.__class__(__box_heritage=(self, item),
-                                  **self.__box_config())
-        elif isinstance(default_value, Callable):
-            return default_value()
-        elif hasattr(default_value, 'copy'):
-            return default_value.copy()
-        return default_value
-
-    def __box_config(self):
-        out = {}
-        for k, v in self._box_config.copy().items():
-            if not k.startswith("__"):
-                out[k] = v
-        return out
-
-    def __convert_and_store(self, item, value):
-        if item in self._box_config['__converted']:
-            return value
-        if isinstance(value, dict) and not isinstance(value, Box):
-            value = self.__class__(value, __box_heritage=(self, item),
-                                   **self.__box_config())
-            self[item] = value
-        elif isinstance(value, list) and not isinstance(value, BoxList):
-            if self._box_config['frozen_box']:
-                value = _recursive_tuples(value, self.__class__,
-                                          recreate_tuples=self._box_config[
-                                              'modify_tuples_box'],
-                                          __box_heritage=(self, item),
-                                          **self.__box_config())
-            else:
-                value = BoxList(value, __box_heritage=(self, item),
-                                box_class=self.__class__,
-                                **self.__box_config())
-            self[item] = value
-        elif (self._box_config['modify_tuples_box'] and
-              isinstance(value, tuple)):
-            value = _recursive_tuples(value, self.__class__,
-                                      recreate_tuples=True,
-                                      __box_heritage=(self, item),
-                                      **self.__box_config())
-            self[item] = value
-        self._box_config['__converted'].add(item)
-        return value
-
-    def __create_lineage(self):
-        if (self._box_config['__box_heritage'] and
-                self._box_config['__created']):
-            past, item = self._box_config['__box_heritage']
-            if not past[item]:
-                past[item] = self
-            self._box_config['__box_heritage'] = None
-
-    def __getattr__(self, item):
-        try:
-            try:
-                value = self.__getitem__(item, _ignore_default=True)
-            except KeyError:
-                value = object.__getattribute__(self, item)
-        except AttributeError as err:
-            if item == "__getstate__":
-                raise AttributeError(item)
-            if item == '_box_config':
-                raise BoxError('_box_config key must exist')
-            kill_camel = self._box_config['camel_killer_box']
-            if self._box_config['conversion_box'] and item:
-                k = _conversion_checks(item, self.keys(), self._box_config)
-                if k:
-                    return self.__getitem__(k)
-            if kill_camel:
-                for k in self.keys():
-                    if item == _camel_killer(k):
-                        return self.__getitem__(k)
-            if self._box_config['default_box']:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            if item == '_box_config':
-                return value
-            return self.__convert_and_store(item, value)
-
-    def __setitem__(self, key, value):
-        if (key != '_box_config' and self._box_config['__created'] and
-                self._box_config['frozen_box']):
-            raise BoxError('Box is frozen')
-        if self._box_config['conversion_box']:
-            _conversion_checks(key, self.keys(), self._box_config,
-                               check_only=True, pre_check=True)
-        super(Box, self).__setitem__(key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __setattr__(self, key, value):
-        if (key != '_box_config' and self._box_config['frozen_box'] and
-                self._box_config['__created']):
-            raise BoxError('Box is frozen')
-        if key in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(key))
-        if key == '_box_config':
-            return object.__setattr__(self, key, value)
-        try:
-            object.__getattribute__(self, key)
-        except (AttributeError, UnicodeEncodeError):
-            if (key not in self.keys() and
-                    (self._box_config['conversion_box'] or
-                     self._box_config['camel_killer_box'])):
-                if self._box_config['conversion_box']:
-                    k = _conversion_checks(key, self.keys(),
-                                           self._box_config)
-                    self[key if not k else k] = value
-                elif self._box_config['camel_killer_box']:
-                    for each_key in self:
-                        if key == _camel_killer(each_key):
-                            self[each_key] = value
-                            break
-            else:
-                self[key] = value
-        else:
-            object.__setattr__(self, key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __delitem__(self, key):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        super(Box, self).__delitem__(key)
-        if (self._box_config['ordered_box'] and
-                key in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(key)
-
-    def __delattr__(self, item):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        if item == '_box_config':
-            raise BoxError('"_box_config" is protected')
-        if item in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(item))
-        try:
-            object.__getattribute__(self, item)
-        except AttributeError:
-            del self[item]
-        else:
-            object.__delattr__(self, item)
-        if (self._box_config['ordered_box'] and
-                item in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(item)
-
-    def pop(self, key, *args):
-        if args:
-            if len(args) != 1:
-                raise BoxError('pop() takes only one optional'
-                               ' argument "default"')
-            try:
-                item = self[key]
-            except KeyError:
-                return args[0]
-            else:
-                del self[key]
-                return item
-        try:
-            item = self[key]
-        except KeyError:
-            raise BoxKeyError('{0}'.format(key))
-        else:
-            del self[key]
-            return item
-
-    def clear(self):
-        self._box_config['__ordered_box_values'] = []
-        super(Box, self).clear()
-
-    def popitem(self):
-        try:
-            key = next(self.__iter__())
-        except StopIteration:
-            raise BoxKeyError('Empty box')
-        return key, self.pop(key)
-
-    def __repr__(self):
-        return '<Box: {0}>'.format(str(self.to_dict()))
-
-    def __str__(self):
-        return str(self.to_dict())
-
-    def __iter__(self):
-        for key in self.keys():
-            yield key
-
-    def __reversed__(self):
-        for key in reversed(list(self.keys())):
-            yield key
-
-    def to_dict(self):
-        """
-        Turn the Box and sub Boxes back into a native
-        python dictionary.
-
-        :return: python dictionary of this Box
-        """
-        out_dict = dict(self)
-        for k, v in out_dict.items():
-            if v is self:
-                out_dict[k] = out_dict
-            elif hasattr(v, 'to_dict'):
-                out_dict[k] = v.to_dict()
-            elif hasattr(v, 'to_list'):
-                out_dict[k] = v.to_list()
-        return out_dict
-
-    def update(self, item=None, **kwargs):
-        if not item:
-            item = kwargs
-        iter_over = item.items() if hasattr(item, 'items') else item
-        for k, v in iter_over:
-            if isinstance(v, dict):
-                # Box objects must be created in case they are already
-                # in the `converted` box_config set
-                v = self.__class__(v)
-                if k in self and isinstance(self[k], dict):
-                    self[k].update(v)
-                    continue
-            if isinstance(v, list):
-                v = BoxList(v)
-            try:
-                self.__setattr__(k, v)
-            except (AttributeError, TypeError):
-                self.__setitem__(k, v)
-
-    def setdefault(self, item, default=None):
-        if item in self:
-            return self[item]
-
-        if isinstance(default, dict):
-            default = self.__class__(default)
-        if isinstance(default, list):
-            default = BoxList(default)
-        self[item] = default
-        return default
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict", **json_kwargs):
-        """
-        Transform the Box object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        return _to_json(self.to_dict(), filename=filename,
-                        encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None,
-                  encoding="utf-8", errors="strict", **kwargs):
-        """
-        Transform a json object string into a Box object. If the incoming
-        json is a list, you must use BoxList.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: Box object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename,
-                          encoding=encoding, errors=errors, **kwargs)
-
-        if not isinstance(data, dict):
-            raise BoxError('json data not returned as a dictionary, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the Box object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_dict(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader, **kwargs):
-            """
-            Transform a yaml object string into a Box object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `Box()` or `yaml.load`
-            :return: Box object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, dict):
-                raise BoxError('yaml data not returned as a dictionary'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-
-class BoxList(list):
-    """
-    Drop in replacement of list, that converts added objects to Box or BoxList
-    objects as necessary.
-    """
-
-    def __init__(self, iterable=None, box_class=Box, **box_options):
-        self.box_class = box_class
-        self.box_options = box_options
-        self.box_org_ref = self.box_org_ref = id(iterable) if iterable else 0
-        if iterable:
-            for x in iterable:
-                self.append(x)
-        if box_options.get('frozen_box'):
-            def frozen(*args, **kwargs):
-                raise BoxError('BoxList is frozen')
-
-            for method in ['append', 'extend', 'insert', 'pop',
-                           'remove', 'reverse', 'sort']:
-                self.__setattr__(method, frozen)
-
-    def __delitem__(self, key):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__delitem__(key)
-
-    def __setitem__(self, key, value):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__setitem__(key, value)
-
-    def append(self, p_object):
-        if isinstance(p_object, dict):
-            try:
-                p_object = self.box_class(p_object, **self.box_options)
-            except AttributeError as err:
-                if 'box_class' in self.__dict__:
-                    raise err
-        elif isinstance(p_object, list):
-            try:
-                p_object = (self if id(p_object) == self.box_org_ref else
-                            BoxList(p_object))
-            except AttributeError as err:
-                if 'box_org_ref' in self.__dict__:
-                    raise err
-        super(BoxList, self).append(p_object)
-
-    def extend(self, iterable):
-        for item in iterable:
-            self.append(item)
-
-    def insert(self, index, p_object):
-        if isinstance(p_object, dict):
-            p_object = self.box_class(p_object, **self.box_options)
-        elif isinstance(p_object, list):
-            p_object = (self if id(p_object) == self.box_org_ref else
-                        BoxList(p_object))
-        super(BoxList, self).insert(index, p_object)
-
-    def __repr__(self):
-        return "<BoxList: {0}>".format(self.to_list())
-
-    def __str__(self):
-        return str(self.to_list())
-
-    def __copy__(self):
-        return BoxList((x for x in self),
-                       self.box_class,
-                       **self.box_options)
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k in self:
-            out.append(copy.deepcopy(k))
-        return out
-
-    def __hash__(self):
-        if self.box_options.get('frozen_box'):
-            hashing = 98765
-            hashing ^= hash(tuple(self))
-            return hashing
-        raise TypeError("unhashable type: 'BoxList'")
-
-    def to_list(self):
-        new_list = []
-        for x in self:
-            if x is self:
-                new_list.append(new_list)
-            elif isinstance(x, Box):
-                new_list.append(x.to_dict())
-            elif isinstance(x, BoxList):
-                new_list.append(x.to_list())
-            else:
-                new_list.append(x)
-        return new_list
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict",
-                multiline=False, **json_kwargs):
-        """
-        Transform the BoxList object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: Put each item in the list onto its own line
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        if filename and multiline:
-            lines = [_to_json(item, filename=False, encoding=encoding,
-                              errors=errors, **json_kwargs) for item in self]
-            with open(filename, 'w', encoding=encoding, errors=errors) as f:
-                f.write("\n".join(lines).decode('utf-8') if
-                        sys.version_info < (3, 0) else "\n".join(lines))
-        else:
-            return _to_json(self.to_list(), filename=filename,
-                            encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None, encoding="utf-8",
-                  errors="strict", multiline=False, **kwargs):
-        """
-        Transform a json object string into a BoxList object. If the incoming
-        json is a dict, you must use Box.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: One object per line
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: BoxList object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename, encoding=encoding,
-                          errors=errors, multiline=multiline, **kwargs)
-
-        if not isinstance(data, list):
-            raise BoxError('json data not returned as a list, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the BoxList object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_list(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader,
-                      **kwargs):
-            """
-            Transform a yaml object string into a BoxList object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `BoxList()` or `yaml.load`
-            :return: BoxList object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, list):
-                raise BoxError('yaml data not returned as a list'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-    def box_it_up(self):
-        for v in self:
-            if hasattr(v, 'box_it_up') and v is not self:
-                v.box_it_up()
-
-
-class ConfigBox(Box):
-    """
-    Modified box object to add object transforms.
-
-    Allows for built-in transforms like:
-
-    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')
-
-    cns.bool('my_bool') # True
-    cns.int('my_int') # 5
-    cns.list('my_list', mod=lambda x: int(x)) # [5, 4, 3, 3, 2]
-    """
-
-    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
-                                 'list', 'getboolean', 'to_json', 'to_yaml',
-                                 'getfloat', 'getint',
-                                 'from_json', 'from_yaml']
-
-    def __getattr__(self, item):
-        """Config file keys are stored in lower case, be a little more
-        loosey goosey"""
-        try:
-            return super(ConfigBox, self).__getattr__(item)
-        except AttributeError:
-            return super(ConfigBox, self).__getattr__(item.lower())
-
-    def __dir__(self):
-        return super(ConfigBox, self).__dir__() + ['bool', 'int', 'float',
-                                                   'list', 'getboolean',
-                                                   'getfloat', 'getint']
-
-    def bool(self, item, default=None):
-        """ Return value of key as a boolean
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: approximated bool of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-
-        if isinstance(item, (bool, int)):
-            return bool(item)
-
-        if (isinstance(item, str) and
-                item.lower() in ('n', 'no', 'false', 'f', '0')):
-            return False
-
-        return True if item else False
-
-    def int(self, item, default=None):
-        """ Return value of key as an int
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: int of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        return int(item)
-
-    def float(self, item, default=None):
-        """ Return value of key as a float
-
-        :param item: key of value to transform
-        :param default: value to return if item does not exist
-        :return: float of value
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        return float(item)
-
-    def list(self, item, default=None, spliter=",", strip=True, mod=None):
-        """ Return value of key as a list
-
-        :param item: key of value to transform
-        :param mod: function to map against list
-        :param default: value to return if item does not exist
-        :param spliter: character to split str on
-        :param strip: clean the list with the `strip`
-        :return: list of items
-        """
-        try:
-            item = self.__getattr__(item)
-        except AttributeError as err:
-            if default is not None:
-                return default
-            raise err
-        if strip:
-            item = item.lstrip('[').rstrip(']')
-        out = [x.strip() if strip else x for x in item.split(spliter)]
-        if mod:
-            return list(map(mod, out))
-        return out
-
-    # loose configparser compatibility
-
-    def getboolean(self, item, default=None):
-        return self.bool(item, default)
-
-    def getint(self, item, default=None):
-        return self.int(item, default)
-
-    def getfloat(self, item, default=None):
-        return self.float(item, default)
-
-    def __repr__(self):
-        return '<ConfigBox: {0}>'.format(str(self.to_dict()))
-
-
-class SBox(Box):
-    """
-    ShorthandBox (SBox) allows for
-    property access of `dict` `json` and `yaml`
-    """
-    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
-                                 'json', 'yaml', 'from_yaml', 'from_json',
-                                 'dict']
-
-    @property
-    def dict(self):
-        return self.to_dict()
-
-    @property
-    def json(self):
-        return self.to_json()
-
-    if yaml_support:
-        @property
-        def yaml(self):
-            return self.to_yaml()
-
-    def __repr__(self):
-        return '<ShorthandBox: {0}>'.format(str(self.to_dict()))
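
The file removed above is a vendored copy of python-box (dot-notation dictionary access); lcnn/config.py, deleted below, built its C and M config objects on it. As a reminder of what it provided, a minimal sketch based on the Box class shown above (the config values are illustrative only):

    from lcnn.box import Box  # as it existed before this commit

    cfg = Box({"model": {"batch_size": 2, "head_size": [[2], [1], [2]]}})
    assert cfg.model.batch_size == 2      # nested dicts become Boxes with attribute access
    assert cfg.model.head_size[0] == [2]  # nested lists become BoxList but compare equal to plain lists
    cfg.model.depth = 4                   # attribute assignment writes through to the underlying dict
    assert cfg["model"]["depth"] == 4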

+ 0 - 9
lcnn/config.py

@@ -1,9 +0,0 @@
-import numpy as np
-
-from lcnn.box import Box
-
-# C is a dict storing all the configuration
-C = Box()
-
-# shortcut for C.model
-M = Box()

+ 0 - 378
lcnn/dataset_tool.py

@@ -1,378 +0,0 @@
-import cv2
-import numpy as np
-import torch
-import torchvision
-from matplotlib import pyplot as plt
-import tools.transforms as reference_transforms
-from collections import defaultdict
-
-from tools import presets
-
-import json
-
-
-def get_modules(use_v2):
-    # We need a protected import to avoid the V2 warning in case just V1 is used
-    if use_v2:
-        import torchvision.transforms.v2
-        import torchvision.tv_tensors
-
-        return torchvision.transforms.v2, torchvision.tv_tensors
-    else:
-        return reference_transforms, None
-
-
-class Augmentation:
-    # Note: this transform assumes that the input to forward() are always PIL
-    # images, regardless of the backend parameter.
-    def __init__(
-            self,
-            *,
-            data_augmentation,
-            hflip_prob=0.5,
-            mean=(123.0, 117.0, 104.0),
-            backend="pil",
-            use_v2=False,
-    ):
-
-        T, tv_tensors = get_modules(use_v2)
-
-        transforms = []
-        backend = backend.lower()
-        if backend == "tv_tensor":
-            transforms.append(T.ToImage())
-        elif backend == "tensor":
-            transforms.append(T.PILToTensor())
-        elif backend != "pil":
-            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
-
-        if data_augmentation == "hflip":
-            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
-        elif data_augmentation == "lsj":
-            transforms += [
-                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
-                # TODO: FixedSizeCrop below doesn't work on tensors!
-                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "multiscale":
-            transforms += [
-                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssd":
-            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
-            transforms += [
-                T.RandomPhotometricDistort(),
-                T.RandomZoomOut(fill=fill),
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssdlite":
-            transforms += [
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        else:
-            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
-
-        if backend == "pil":
-            # Note: we could just convert to pure tensors even in v2.
-            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
-
-        transforms += [T.ToDtype(torch.float, scale=True)]
-
-        if use_v2:
-            transforms += [
-                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
-                T.SanitizeBoundingBoxes(),
-                T.ToPureTensor(),
-            ]
-
-        self.transforms = T.Compose(transforms)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-def read_polygon_points(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = f.readlines()
-
-    for line in lines:
-        parts = line.strip().split()
-        class_id = int(parts[0])
-        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # read the point coordinates
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_pixels(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = reader.readlines()
-        mask_points = []
-        for line in lines:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
-            labels.append(cls)
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-
-            for p in mask_points:
-                mask[p] = 1
-            masks.append(mask)
-    reader.close()
-    return labels, masks
-
-
-def create_masks_from_polygons(polygons, image_shape):
-    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
-    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
-    masks = []
-
-    for polygon_data, col in zip(polygons, colors):
-        mask = np.zeros(image_shape[:2], dtype=np.uint8)
-        # Convert the polygon vertices to a NumPy array
-        _, polygon = polygon_data
-        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
-
-        # Fill the polygon with OpenCV's fillPoly
-        # print(f'color:{col[:3]}')
-        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
-        mask = torch.from_numpy(mask)
-        mask[mask != 0] = 1
-        masks.append(mask)
-
-    return masks
-
-
-def read_masks_from_txt(label_path, shape):
-    polygon_points = read_polygon_points(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
-    """
-    Compute the bounding boxes around the provided masks.
-
-    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
-    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
-
-    Args:
-        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
-            and (H, W) are the spatial dimensions.
-
-    Returns:
-        Tensor[N, 4]: bounding boxes
-    """
-    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-    #     _log_api_usage_once(masks_to_boxes)
-    if masks.numel() == 0:
-        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
-
-    n = masks.shape[0]
-
-    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
-
-    for index, mask in enumerate(masks):
-        y, x = torch.where(mask != 0)
-        bounding_boxes[index, 0] = torch.min(x)
-        bounding_boxes[index, 1] = torch.min(y)
-        bounding_boxes[index, 2] = torch.max(x)
-        bounding_boxes[index, 3] = torch.max(y)
-        # debug to pixel datasets
-
-        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
-            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
-            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
-
-        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
-            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
-            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
-
-    return bounding_boxes
-
-
-def line_boxes_faster(target):
-    boxs = []
-    lpre = target["lpre"].cpu().numpy() * 4
-    vecl_target = target["lpre_label"].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
-
-    lines = lpre
-    sline = np.ones(lpre.shape[0])
-
-    if len(lines) > 0 and not (lines[0] == 0).all():
-        for i, ((a, b), s) in enumerate(zip(lines, sline)):
-            if i > 0 and (lines[i] == lines[0]).all():
-                break
-            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no fixed ordering
-
-            if a[1] > b[1]:
-                ymax = a[1] + 10
-                ymin = b[1] - 10
-            else:
-                ymin = a[1] - 10
-                ymax = b[1] + 10
-            if a[0] > b[0]:
-                xmax = a[0] + 10
-                xmin = b[0] - 10
-            else:
-                xmin = a[0] - 10
-                xmax = b[0] + 10
-            boxs.append([max(0, ymin), max(0, xmin), min(512, ymax), min(512, xmax)])
-
-    return torch.tensor(boxs)
-
-
-# Convert line segments to [start point, end point] form; input shape is [num, 2, 2]
-def a_to_b(line):
-    result_pairs = []
-    for a, b in line:
-        min_x = min(a[0], b[0])
-        min_y = min(a[1], b[1])
-        new_top_left_x = max((min_x - 10), 0)
-        new_top_left_y = max((min_y - 10), 0)
-        dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
-        dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
-
-        # Choose the start point by distance and set the label
-        if dist_a <= dist_b:  # a is closer to the new top-left corner (or equally close)
-            result_pairs.append([a, b])  # use a as the start point and b as the end point
-        else:  # b is closer to the new top-left corner
-            result_pairs.append([b, a])  # use b as the start point and a as the end point
-    result_tensor = torch.stack([torch.stack(row) for row in result_pairs])
-    return result_tensor
-
-
-def read_polygon_points_wire(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = json.load(f)
-
-    for line in lines["segmentations"]:
-        parts = line["data"]
-        class_id = int(line["cls_id"])
-        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # read the point coordinates
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_txt_wire(label_path, shape):
-    polygon_points = read_polygon_points_wire(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def read_masks_from_pixels_wire(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = json.load(reader)
-        mask_points = []
-        for line in lines["segmentations"]:
-            # mask = torch.zeros((h, w), dtype=torch.uint8)
-            # parts = line["data"]
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
-            labels.append(cls)
-            # x_array = parts[0::2]
-            # y_array = parts[1::2]
-            # 
-            # for x, y in zip(x_array, y_array):
-            #     x = float(x)
-            #     y = float(y)
-            #     mask_points.append((int(y * h), int(x * w)))
-
-            # for p in mask_points:
-            #     mask[p] = 1
-            # masks.append(mask)
-    reader.close()
-    return labels
-
-
-def read_lable_keypoint(lbl_path):
-    """判断线段的起点终点, 起点 lable=0, 终点 lable=1"""
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = json.load(reader)
-        aa = lines["wires"][0]["line_pos_coords"]["content"]
-
-        result_pairs = []
-        for a, b in aa:
-            min_x = min(a[0], b[0])
-            min_y = min(a[1], b[1])
-
-            # Define the position of the new top-left corner
-            new_top_left_x = max((min_x - 10), 0)
-            new_top_left_y = max((min_y - 10), 0)
-
-            # Step 2: compute the squared distance from each point to the new top-left corner (avoids floating-point error)
-            dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
-            dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
-
-            # Step 3 & 4: choose the start point by distance and set the label
-            if dist_a <= dist_b:  # a is closer to the new top-left corner (or equally close)
-                result_pairs.append([a, b])  # use a as the start point and b as the end point
-            else:  # b is closer to the new top-left corner
-                result_pairs.append([b, a])  # use b as the start point and a as the end point
-
-            # x_ = abs(a[0] - b[0])
-            # y_ = abs(a[1] - b[1])
-            # if x_ > y_:  # the point with the smaller x is closer to the top-left corner
-            #     if a[0] < b[0]:  # treated as the start point, label=0
-            #         label = [0, 1]
-            #     else:
-            #         label = [1, 0]
-            # else:  # the point with the larger x is the start point
-            #     if a[0] > b[0]:  # treated as the start point, label=0
-            #         label = [0, 1]
-            #     else:
-            #         label = [1, 0]
-            # labels.append(label)
-        # print(result_pairs )
-    reader.close()
-    return labels
-
-
-def adjacency_matrix(n, link):  # adjacency matrix
-    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-    link = torch.tensor(link)
-    if len(link) > 0:
-        mat[link[:, 0], link[:, 1]] = 1
-        mat[link[:, 1], link[:, 0]] = 1
-    return mat
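
For context, a minimal sketch (with made-up junction indices) of how the `adjacency_matrix` helper above turns a list of index pairs into a symmetric 0/1 matrix:

import torch

# Hypothetical example: 4 junctions, with lines connecting junctions (0, 1) and (2, 3).
link = [[0, 1], [2, 3]]
mat = adjacency_matrix(4, link)            # (n + 1) x (n + 1) uint8 tensor, here 5 x 5
print(mat[0, 1].item(), mat[1, 0].item())  # 1 1 - the matrix is symmetric
print(mat[0, 2].item())                    # 0 - no line between junctions 0 and 2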

+ 0 - 496
lcnn/datasets.py

@@ -1,496 +0,0 @@
-# import glob
-# import json
-# import math
-# import os
-# import random
-#
-# import numpy as np
-# import numpy.linalg as LA
-# import torch
-# from skimage import io
-# from torch.utils.data import Dataset
-# from torch.utils.data.dataloader import default_collate
-#
-# from lcnn.config import M
-#
-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
-#
-#
-# class WireframeDataset(Dataset):
-#     def __init__(self, rootdir, split):
-#         self.rootdir = rootdir
-#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
-#         filelist.sort()
-#
-#         # print(f"n{split}:", len(filelist))
-#         self.split = split
-#         self.filelist = filelist
-#
-#     def __len__(self):
-#         return len(self.filelist)
-#
-#     def __getitem__(self, idx):
-#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
-#         image = io.imread(iname).astype(float)[:, :, :3]
-#         if "a1" in self.filelist[idx]:
-#             image = image[:, ::-1, :]
-#         image = (image - M.image.mean) / M.image.stddev
-#         image = np.rollaxis(image, 2).copy()
-#
-#         with np.load(self.filelist[idx]) as npz:
-#             target = {
-#                 name: torch.from_numpy(npz[name]).float()
-#                 for name in ["jmap", "joff", "lmap"]
-#             }
-#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
-#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
-#             npos, nneg = len(lpos), len(lneg)
-#             lpre = np.concatenate([lpos, lneg], 0)
-#             for i in range(len(lpre)):
-#                 if random.random() > 0.5:
-#                     lpre[i] = lpre[i, ::-1]
-#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-#             feat = [
-#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
-#                 ldir * M.use_slop,
-#                 lpre[:, :, 2],
-#             ]
-#             feat = np.concatenate(feat, 1)
-#             meta = {
-#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
-#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
-#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
-#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
-#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
-#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
-#                 "lpre_feat": torch.from_numpy(feat),
-#             }
-#
-#         labels = []
-#         labels = read_masks_from_pixels_wire(iname, (512, 512))
-#         # if self.target_type == 'polygon':
-#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
-#         # elif self.target_type == 'pixel':
-#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
-#
-#         target["labels"] = torch.stack(labels)
-#         target["boxes"] = line_boxes_faster(meta)
-#
-#
-#         return torch.from_numpy(image).float(), meta, target
-#
-#     def adjacency_matrix(self, n, link):
-#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-#         link = torch.from_numpy(link)
-#         if len(link) > 0:
-#             mat[link[:, 0], link[:, 1]] = 1
-#             mat[link[:, 1], link[:, 0]] = 1
-#         return mat
-#
-#
-# def collate(batch):
-#     return (
-#         default_collate([b[0] for b in batch]),
-#         [b[1] for b in batch],
-#         default_collate([b[2] for b in batch]),
-#     )
-
-
-# Original LCNN data format, with renamed attributes and added box-related fields
-
-from torch.utils.data.dataset import T_co
-
-from .models.base.base_dataset import BaseDataset
-
-import glob
-import json
-import math
-import os
-import random
-import cv2
-import PIL
-
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-from torchvision.utils import draw_bounding_boxes
-
-import numpy as np
-import numpy.linalg as LA
-import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
-
-import matplotlib.pyplot as plt
-from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
-
-class WireframeDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
-        super().__init__(dataset_path)
-
-        self.data_path = dataset_path
-        print(f'data_path:{dataset_path}')
-        self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
-        self.imgs = os.listdir(self.img_path)
-        self.lbls = os.listdir(self.lbl_path)
-        self.target_type = target_type
-        # self.default_transform = DefaultTransform()
-
-    def __getitem__(self, index) -> T_co:
-        img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-
-        img = self.default_transform(img)
-
-        # print(f'img:{img}')
-        return img, meta, target, target_b
-
-    def __len__(self):
-        return len(self.imgs)
-
-    def read_target(self, item, lbl_path, shape, extra=None):
-        # print(f'shape:{shape}')
-        # print(f'lbl_path:{lbl_path}')
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
-
-        n_stc_posl = 300
-        n_stc_negl = 40
-        use_cood = 0
-        use_slop = 0
-
-        wire = lable_all["wires"][0]  # a dict
-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer are available, take what there is
-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
-        npos, nneg = len(line_pos_coords), len(line_neg_coords)
-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates concatenated
-        for i in range(len(lpre)):
-            if random.random() > 0.5:
-                lpre[i] = lpre[i, ::-1]
-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-        feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
-            ldir * use_slop,
-            lpre[:, :, 2],
-        ]
-        feat = np.concatenate(feat, 1)
-
-        meta = {
-            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
-            # adjacency matrix of lines that actually exist
-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-
-            "lpre": torch.tensor(lpre)[:, :, :2],
-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
-            "lpre_feat": torch.from_numpy(feat),
-
-        }
-
-        target = {
-            "junc_map": torch.tensor(wire['junc_map']["content"]),
-            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
-            "line_map": torch.tensor(wire['line_map']["content"]),
-        }
-
-        labels = []
-        if self.target_type == 'polygon':
-            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
-        elif self.target_type == 'pixel':
-            labels = read_masks_from_pixels_wire(lbl_path, shape)
-
-        # print(torch.stack(masks).shape)    # [number of line segments, 512, 512]
-        target_b = {}
-
-        # target_b["image_id"] = torch.tensor(item)
-
-        target_b["labels"] = torch.stack(labels)
-        target_b["boxes"] = line_boxes_faster(meta)
-
-        return meta, target, target_b
-
-
-    def show(self, idx):
-        image, meta, target, target_b = self.__getitem__(idx)  # __getitem__ returns four values
-
-        cmap = plt.get_cmap("jet")
-        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
-        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
-        sm.set_array([])
-
-        def imshow(im):
-            plt.close()
-            plt.tight_layout()
-            plt.imshow(im)
-            plt.colorbar(sm, fraction=0.046)
-            plt.xlim([0, im.shape[0]])
-            plt.ylim([im.shape[0], 0])
-
-        def draw_vecl(lines, sline, juncs, junts, fn=None):
-            img_path = os.path.join(self.img_path, self.imgs[idx])
-            imshow(io.imread(img_path))
-            if len(lines) > 0 and not (lines[0] == 0).all():
-                for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                    if i > 0 and (lines[i] == lines[0]).all():
-                        break
-                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no guaranteed ordering
-            if not (juncs[0] == 0).all():
-                for i, j in enumerate(juncs):
-                    if i > 0 and (i == juncs[0]).all():
-                        break
-                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
-
-            img_path = os.path.join(self.img_path, self.imgs[idx])
-            img = PIL.Image.open(img_path).convert('RGB')
-            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target_b["boxes"],
-                                              colors="yellow", width=1)
-            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
-            plt.show()
-
-            if fn is not None:
-                plt.savefig(fn)
-
-        junc = meta['junc_coords'].cpu().numpy() * 4
-        jtyp = meta['jtyp'].cpu().numpy()
-        juncs = junc[jtyp == 0]
-        junts = junc[jtyp == 1]
-
-        lpre = meta["lpre"].cpu().numpy() * 4
-        vecl_target = meta["lpre_label"].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
-
-        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
-        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
-
-    def show_img(self, img_path):
-        pass
-
-
-def collate(batch):
-    return (
-        default_collate([b[0] for b in batch]),
-        [b[1] for b in batch],
-        default_collate([b[2] for b in batch]),
-        [b[3] for b in batch],
-    )
-
-
-# if __name__ == '__main__':
-#     path = r"D:\python\PycharmProjects\data"
-#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
-#     dataset.show(0)
-
-
-
-'''
-# The roi_head imposes requirements on the data format, so the format is changed here
-from torch.utils.data.dataset import T_co
-
-from models.base.base_dataset import BaseDataset
-
-import glob
-import json
-import math
-import os
-import random
-import cv2
-import PIL
-
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-from torchvision.utils import draw_bounding_boxes
-
-import numpy as np
-import numpy.linalg as LA
-import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
-
-import matplotlib.pyplot as plt
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
-
-from tools.presets import DetectionPresetTrain
-
-
-class WireframeDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
-        super().__init__(dataset_path)
-
-        self.data_path = dataset_path
-        print(f'data_path:{dataset_path}')
-        self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
-        self.imgs = os.listdir(self.img_path)
-        self.lbls = os.listdir(self.lbl_path)
-        self.target_type = target_type
-        # self.default_transform = DefaultTransform()
-        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # "multiscale" would change the image size
-
-    def __getitem__(self, index) -> T_co:
-        img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        # if self.transforms:
-        #     img, target = self.transforms(img, target)
-        # else:
-        #     img = self.default_transform(img)
-
-        img, target = self.data_augmentation(img, target)
-
-        print(f'img:{img.shape}')
-        return img, target
-
-    def __len__(self):
-        return len(self.imgs)
-
-    def read_target(self, item, lbl_path, shape, extra=None):
-        # print(f'lbl_path:{lbl_path}')
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
-
-        n_stc_posl = 300
-        n_stc_negl = 40
-        use_cood = 0
-        use_slop = 0
-
-        wire = lable_all["wires"][0]  # a dict
-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer are available, take what there is
-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
-        npos, nneg = len(line_pos_coords), len(line_neg_coords)
-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates concatenated
-        for i in range(len(lpre)):
-            if random.random() > 0.5:
-                lpre[i] = lpre[i, ::-1]
-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-        feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
-            ldir * use_slop,
-            lpre[:, :, 2],
-        ]
-        feat = np.concatenate(feat, 1)
-
-        wire_labels = {
-            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
-            # adjacency matrix of lines that actually exist
-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-            # adjacency matrix of lines that do not exist
-            "lpre": torch.tensor(lpre)[:, :, :2],
-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # sample labels: 1 = positive, 0 = negative
-            "lpre_feat": torch.from_numpy(feat),
-            "junc_map": torch.tensor(wire['junc_map']["content"]),
-            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
-            "line_map": torch.tensor(wire['line_map']["content"]),
-        }
-
-        labels = []
-        # if self.target_type == 'polygon':
-        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
-        # elif self.target_type == 'pixel':
-        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
-
-        # print(torch.stack(masks).shape)    # [number of line segments, 512, 512]
-        target = {}
-        # target["labels"] = torch.stack(labels)
-        target["image_id"] = torch.tensor(item)
-        # return wire_labels, target
-        target["wires"] = wire_labels
-        target["boxes"] = line_boxes(target)
-        return target
-
-    def show(self, idx):
-        image, target = self.__getitem__(idx)
-        img_path = os.path.join(self.img_path, self.imgs[idx])
-        self._draw_vecl(img_path, target)
-
-    def show_img(self, img_path):
-
-        """根据给定的图片路径展示图像及其标注信息"""
-        # 获取对应的标签文件路径
-        img_name = os.path.basename(img_path)
-        img_path = os.path.join(self.img_path, img_name)
-        print(img_path)
-        lbl_name = img_name[:-3] + 'json'
-        lbl_path = os.path.join(self.lbl_path, lbl_name)
-        print(lbl_path)
-
-        if not os.path.exists(lbl_path):
-            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        target = self.read_target(0, lbl_path, shape=(h, w))
-
-        # Call the drawing helper
-        self._draw_vecl(img_path, target)
-
-
-    def _draw_vecl(self, img_path, target, fn=None):
-        cmap = plt.get_cmap("jet")
-        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
-        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
-        sm.set_array([])
-
-        def imshow(im):
-            plt.close()
-            plt.tight_layout()
-            plt.imshow(im)
-            plt.colorbar(sm, fraction=0.046)
-            plt.xlim([0, im.shape[0]])
-            plt.ylim([im.shape[0], 0])
-
-        junc = target['wires']['junc_coords'].cpu().numpy() * 4
-        jtyp = target['wires']['jtyp'].cpu().numpy()
-        juncs = junc[jtyp == 0]
-        junts = junc[jtyp == 1]
-
-        lpre = target['wires']["lpre"].cpu().numpy() * 4
-        vecl_target = target['wires']["lpre_label"].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
-
-        lines = lpre
-        sline = np.ones(lpre.shape[0])
-        imshow(io.imread(img_path))
-        if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
-                    break
-                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no guaranteed ordering
-        if not (juncs[0] == 0).all():
-            for i, j in enumerate(juncs):
-                if i > 0 and (i == juncs[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
-                                          colors="yellow", width=1)
-        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
-        plt.show()
-
-        if fn is not None:
-            plt.savefig(fn)
-
-'''
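
For orientation, a minimal sketch (dataset path and split name are assumptions, and all images are assumed to share the same size) of wiring the active `WireframeDataset` and its `collate` function into a DataLoader:

from torch.utils.data import DataLoader
from lcnn.datasets import WireframeDataset, collate

# Hypothetical layout: <dataset_path>/images/<split> and <dataset_path>/labels/<split>.
dataset = WireframeDataset(dataset_path="path/to/data", dataset_type="train")
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate)

imgs, metas, targets, targets_b = next(iter(loader))
print(imgs.shape)                    # batched image tensor
print(targets_b[0]["boxes"].shape)   # boxes derived from the first sample's wire annotations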

+ 0 - 209
lcnn/metric.py

@@ -1,209 +0,0 @@
-import numpy as np
-import numpy.linalg as LA
-import matplotlib.pyplot as plt
-
-from lcnn.utils import argsort2d
-
-DX = [0, 0, 1, -1, 1, 1, -1, -1]
-DY = [1, -1, 0, 0, 1, -1, 1, -1]
-
-
-def ap(tp, fp):
-    recall = tp
-    precision = tp / np.maximum(tp + fp, 1e-9)
-
-    recall = np.concatenate(([0.0], recall, [1.0]))
-    precision = np.concatenate(([0.0], precision, [0.0]))
-
-    for i in range(precision.size - 1, 0, -1):
-        precision[i - 1] = max(precision[i - 1], precision[i])
-    i = np.where(recall[1:] != recall[:-1])[0]
-    return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
-
-
-def APJ(vert_pred, vert_gt, max_distance, im_ids):
-    if len(vert_pred) == 0:
-        return 0
-
-    vert_pred = np.array(vert_pred)
-    vert_gt = np.array(vert_gt)
-
-    confidence = vert_pred[:, -1]
-    idx = np.argsort(-confidence)
-    vert_pred = vert_pred[idx, :]
-    im_ids = im_ids[idx]
-    n_gt = sum(len(gt) for gt in vert_gt)
-
-    nd = len(im_ids)
-    tp, fp = np.zeros(nd, dtype=float), np.zeros(nd, dtype=float)  # np.float is removed in modern NumPy
-    hit = [[False for _ in j] for j in vert_gt]
-
-    for i in range(nd):
-        gt_juns = vert_gt[im_ids[i]]
-        pred_juns = vert_pred[i][:-1]
-        if len(gt_juns) == 0:
-            continue
-        dists = np.linalg.norm((pred_juns[None, :] - gt_juns), axis=1)
-        choice = np.argmin(dists)
-        dist = np.min(dists)
-        if dist < max_distance and not hit[im_ids[i]][choice]:
-            tp[i] = 1
-            hit[im_ids[i]][choice] = True
-        else:
-            fp[i] = 1
-
-    tp = np.cumsum(tp) / n_gt
-    fp = np.cumsum(fp) / n_gt
-    return ap(tp, fp)
-
-
-def nms_j(heatmap, delta=1):
-    heatmap = heatmap.copy()
-    disable = np.zeros_like(heatmap, dtype=bool)
-    for x, y in argsort2d(heatmap):
-        for dx, dy in zip(DX, DY):
-            xp, yp = x + dx, y + dy
-            if not (0 <= xp < heatmap.shape[0] and 0 <= yp < heatmap.shape[1]):
-                continue
-            if heatmap[x, y] >= heatmap[xp, yp]:
-                disable[xp, yp] = True
-    heatmap[disable] *= 0.6
-    return heatmap
-
-
-def mAPJ(pred, truth, distances, im_ids):
-    return sum(APJ(pred, truth, d, im_ids) for d in distances) / len(distances) * 100
-
-
-def post_jheatmap(heatmap, offset=None, delta=1):
-    heatmap = nms_j(heatmap, delta=delta)
-    # only select the best 1000 junctions for efficiency
-    v0 = argsort2d(-heatmap)[:1000]
-    confidence = -np.sort(-heatmap.ravel())[:1000]
-    keep_id = np.where(confidence >= 1e-2)[0]
-    if len(keep_id) == 0:
-        return np.zeros((0, 3))
-
-    confidence = confidence[keep_id]
-    if offset is not None:
-        v0 = np.array([v + offset[:, v[0], v[1]] for v in v0])
-    v0 = v0[keep_id] + 0.5
-    v0 = np.hstack((v0, confidence[:, np.newaxis]))
-    return v0
-
-
-def vectorized_wireframe_2d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = vert_pred[sorted_confidence, :-1]
-    dpth_pred = dpth_pred[sorted_confidence]
-    d = np.sqrt(
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-
-    # staging 2: compute depth metric: SIL/L2
-    loss_L1 = loss_L2 = 0
-    hit = np.zeros_like(dpth_gt, bool)
-    SIL = np.zeros(len(dpth_pred))  # one SIL term per prediction
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            loss_L1 += abs(dpth_gt[choice[i]] - dpth_pred[i])
-            loss_L2 += (dpth_gt[choice[i]] - dpth_pred[i]) ** 2
-            a = np.maximum(-dpth_pred[i], 1e-10)
-            b = -dpth_gt[choice[i]]
-            SIL[i] = np.log(a) - np.log(b)
-        else:
-            choice[i] = -1
-
-    n = max(np.sum(hit), 1)
-    loss_L1 /= n
-    loss_L2 /= n
-    loss_SIL = np.sum(SIL ** 2) / n - np.sum(SIL) ** 2 / (n * n)
-
-    # staging 3: compute mAP for edge matching
-    edgeset = set([frozenset(e) for e in edge_gt])
-    tp = np.zeros(len(edge_pred), dtype=float)
-    fp = np.zeros(len(edge_pred), dtype=float)
-    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
-        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
-        if frozenset([choice[v0], choice[v1]]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-    return ap(tp / total_length, fp / total_length), (loss_SIL, loss_L1, loss_L2)
-
-
-def vectorized_wireframe_3d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = np.hstack([vert_pred[:, :-1], dpth_pred[:, None]])[sorted_confidence]
-    vert_gt = np.hstack([vert_gt[:, :-1], dpth_gt[:, None]])
-    d = np.sqrt(
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-    hit = np.zeros_like(dpth_gt, bool)
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-        else:
-            choice[i] = -1
-
-    # staging 2: compute mAP for edge matching
-    edgeset = set([frozenset(e) for e in edge_gt])
-    tp = np.zeros(len(edge_pred), dtype=float)
-    fp = np.zeros(len(edge_pred), dtype=float)
-    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
-        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
-        if frozenset([choice[v0], choice[v1]]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-
-    return ap(tp / total_length, fp / total_length)
-
-
-def msTPFP(line_pred, line_gt, threshold):
-    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
-    diff = np.minimum(
-        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
-    )
-    choice = np.argmin(diff, 1)
-    dist = np.min(diff, 1)
-    hit = np.zeros(len(line_gt), bool)
-    tp = np.zeros(len(line_pred), float)
-    fp = np.zeros(len(line_pred), float)
-    for i in range(len(line_pred)):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            tp[i] = 1
-        else:
-            fp[i] = 1
-    return tp, fp
-
-
-def msAP(line_pred, line_gt, threshold):
-    tp, fp = msTPFP(line_pred, line_gt, threshold)
-    tp = np.cumsum(tp) / len(line_gt)
-    fp = np.cumsum(fp) / len(line_gt)
-    return ap(tp, fp)
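
As a quick illustration (toy arrays; the threshold is compared against squared endpoint distance), `msAP` scores predicted line segments against the ground truth:

import numpy as np

line_gt = np.array([[[10.0, 10.0], [20.0, 20.0]]])           # one ground-truth segment
line_pred = np.array([[[10.5, 10.5], [20.5, 20.5]],          # near-perfect match
                      [[100.0, 100.0], [120.0, 120.0]]])     # spurious prediction
print(msAP(line_pred, line_gt, threshold=5.0))                # 1.0 for this toy case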

+ 0 - 9
lcnn/models/__init__.py

@@ -1,9 +0,0 @@
-# flake8: noqa
-from .hourglass_pose import hg
-from .unet import unet
-from .resnet50_pose import resnet50
-from .resnet50 import resnet501
-from .fasterrcnn_resnet50 import fasterrcnn_resnet50
-# from .dla import dla169, dla102x, dla102x2
-
-__all__ = ['hg', 'unet', 'resnet50', 'resnet501', 'fasterrcnn_resnet50']

+ 0 - 48
lcnn/models/base/base_dataset.py

@@ -1,48 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-from torch import nn, Tensor
-from torch.utils.data import Dataset
-from torch.utils.data.dataset import T_co
-
-from torchvision.transforms import  functional as F
-
-class BaseDataset(Dataset, ABC):
-    def __init__(self, dataset_path):
-        self.default_transform = DefaultTransform()
-
-    def __getitem__(self, index) -> T_co:
-        pass
-
-    @abstractmethod
-    def read_target(self, item, lbl_path, extra=None):
-        pass
-
-    """显示数据集指定图片"""
-    @abstractmethod
-    def show(self,idx):
-        pass
-
-    """
-    显示数据集指定名字的图片
-    """
-
-    @abstractmethod
-    def show_img(self,img_path):
-        pass
-
-class DefaultTransform(nn.Module):
-    def forward(self, img: Tensor) -> Tensor:
-        if not isinstance(img, Tensor):
-            img = F.pil_to_tensor(img)
-        return F.convert_image_dtype(img, torch.float)
-
-    def __repr__(self) -> str:
-        return self.__class__.__name__ + "()"
-
-    def describe(self) -> str:
-        return (
-            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
-            "The images are rescaled to ``[0.0, 1.0]``."
-        )
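
A small sketch (using a blank Pillow image as a stand-in) of what `DefaultTransform` does to an input image:

import PIL.Image

transform = DefaultTransform()
img = PIL.Image.new("RGB", (512, 512))   # hypothetical blank image
tensor = transform(img)                  # PIL image -> float tensor scaled to [0.0, 1.0]
print(tensor.shape, tensor.dtype)        # torch.Size([3, 512, 512]) torch.float32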

+ 0 - 201
lcnn/models/hourglass_pose.py

@@ -1,201 +0,0 @@
-"""
-Hourglass network inserted in the pre-activated Resnet
-Use lr=0.01 for current version
-(c) Yichao Zhou (LCNN)
-(c) YANG, Wei
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["HourglassNet", "hg"]
-
-
-class Bottleneck2D(nn.Module):
-    expansion = 2  # expansion factor
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(Bottleneck2D, self).__init__()
-
-        self.bn1 = nn.BatchNorm2d(inplanes)
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
-        self.bn3 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        residual = x
-
-        out = self.bn1(x)
-        out = self.relu(out)
-        out = self.conv1(out)
-
-        out = self.bn2(out)
-        out = self.relu(out)
-        out = self.conv2(out)
-
-        out = self.bn3(out)
-        out = self.relu(out)
-        out = self.conv3(out)
-
-        if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-
-        return out
-
-
-class Hourglass(nn.Module):
-    def __init__(self, block, num_blocks, planes, depth):
-        super(Hourglass, self).__init__()
-        self.depth = depth
-        self.block = block
-        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
-
-    def _make_residual(self, block, num_blocks, planes):
-        layers = []
-        for i in range(0, num_blocks):
-            layers.append(block(planes * block.expansion, planes))
-        return nn.Sequential(*layers)
-
-    def _make_hour_glass(self, block, num_blocks, planes, depth):
-        hg = []
-        for i in range(depth):
-            res = []
-            for j in range(3):
-                res.append(self._make_residual(block, num_blocks, planes))
-            if i == 0:
-                res.append(self._make_residual(block, num_blocks, planes))
-            hg.append(nn.ModuleList(res))
-        return nn.ModuleList(hg)
-
-    def _hour_glass_forward(self, n, x):
-        up1 = self.hg[n - 1][0](x)
-        low1 = F.max_pool2d(x, 2, stride=2)
-        low1 = self.hg[n - 1][1](low1)
-
-        if n > 1:
-            low2 = self._hour_glass_forward(n - 1, low1)
-        else:
-            low2 = self.hg[n - 1][3](low1)
-        low3 = self.hg[n - 1][2](low2)
-        up2 = F.interpolate(low3, scale_factor=2)
-        out = up1 + up2
-        return out
-
-    def forward(self, x):
-        return self._hour_glass_forward(self.depth, x)
-
-
-class HourglassNet(nn.Module):
-    """Hourglass model from Newell et al ECCV 2016"""
-
-    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
-        super(HourglassNet, self).__init__()
-
-        self.inplanes = 64
-        self.num_feats = 128
-        self.num_stacks = num_stacks
-        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
-        self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_residual(block, self.inplanes, 1)
-        self.layer2 = self._make_residual(block, self.inplanes, 1)
-        self.layer3 = self._make_residual(block, self.num_feats, 1)
-        self.maxpool = nn.MaxPool2d(2, stride=2)
-
-        # build hourglass modules
-        ch = self.num_feats * block.expansion
-        # vpts = []
-        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
-        for i in range(num_stacks):
-            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
-            res.append(self._make_residual(block, self.num_feats, num_blocks))
-            fc.append(self._make_fc(ch, ch))
-            score.append(head(ch, num_classes))
-            # vpts.append(VptsHead(ch))
-            # vpts.append(nn.Linear(ch, 9))
-            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
-            # score[i].bias.data[0] += 4.6
-            # score[i].bias.data[2] += 4.6
-            if i < num_stacks - 1:
-                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
-                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
-        self.hg = nn.ModuleList(hg)
-        self.res = nn.ModuleList(res)
-        self.fc = nn.ModuleList(fc)
-        self.score = nn.ModuleList(score)
-        # self.vpts = nn.ModuleList(vpts)
-        self.fc_ = nn.ModuleList(fc_)
-        self.score_ = nn.ModuleList(score_)
-
-    def _make_residual(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(
-                    self.inplanes,
-                    planes * block.expansion,
-                    kernel_size=1,
-                    stride=stride,
-                )
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def _make_fc(self, inplanes, outplanes):
-        bn = nn.BatchNorm2d(inplanes)
-        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
-        return nn.Sequential(conv, bn, self.relu)
-
-    def forward(self, x):
-        out = []
-        # out_vps = []
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        x = self.layer1(x)
-        x = self.maxpool(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-
-        for i in range(self.num_stacks):
-            y = self.hg[i](x)
-            y = self.res[i](y)
-            y = self.fc[i](y)
-            score = self.score[i](y)
-            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
-            # pre_vpts = pre_vpts.reshape(-1, 256)
-            # vpts = self.vpts[i](x)
-            out.append(score)
-            # out_vps.append(vpts)
-            if i < self.num_stacks - 1:
-                fc_ = self.fc_[i](y)
-                score_ = self.score_[i](score)
-                x = x + fc_ + score_
-
-        return out[::-1], y  # , out_vps[::-1]
-
-
-def hg(**kwargs):
-    model = HourglassNet(
-        Bottleneck2D,
-        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)),
-        depth=kwargs["depth"],
-        num_stacks=kwargs["num_stacks"],
-        num_blocks=kwargs["num_blocks"],
-        num_classes=kwargs["num_classes"],
-    )
-    return model
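
For reference, a minimal sketch (configuration values are assumptions, not the ones used by the trainer) of constructing the stacked hourglass and running a forward pass:

import torch

model = hg(depth=4, num_stacks=2, num_blocks=1, num_classes=5)  # falls back to the default 1x1-conv head
x = torch.randn(1, 3, 512, 512)
outputs, feature = model(x)
print(len(outputs), outputs[0].shape, feature.shape)
# 2 stacks; each head predicts at 1/4 resolution: (1, 5, 128, 128); feature is (1, 256, 128, 128)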

+ 0 - 276
lcnn/models/line_vectorizer.py

@@ -1,276 +0,0 @@
-import itertools
-import random
-from collections import defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-FEATURE_DIM = 8
-
-
-class LineVectorizer(nn.Module):
-    def __init__(self, backbone):
-        super().__init__()
-        self.backbone = backbone
-
-        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
-
-        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
-        scale_factor = M.n_pts0 // M.n_pts1
-        if M.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(M.dim_loi, M.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, 1),
-            )
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    def forward(self, input_dict):
-        result = self.backbone(input_dict)
-        h = result["preds"]
-        x = self.fc1(result["feature"])
-        n_batch, n_channel, row, col = x.shape
-
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-        for i, meta in enumerate(input_dict["meta"]):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
-            )
-            # print("p.shape:", p.shape)
-            ys.append(label)
-            if input_dict["mode"] == "training" and self.do_static_sampling:
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                    .reshape(n_channel, -1, M.n_pts0)
-                    .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, M.n_pts1 * M.dim_loi)
-        x = torch.cat([x, f], 1)
-        x = self.fc2(x.float()).flatten()
-
-        if input_dict["mode"] != "training":
-            p = torch.cat(ps)
-            s = torch.sigmoid(x)
-            b = s > 0.5
-            lines = []
-            score = []
-            for i in range(n_batch):
-                p0 = p[idx[i]: idx[i + 1]]
-                s0 = s[idx[i]: idx[i + 1]]
-                mask = b[idx[i]: idx[i + 1]]
-                p0 = p0[mask]
-                s0 = s0[mask]
-                if len(p0) == 0:
-                    lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
-                    score.append(torch.zeros([1, M.n_out_line], device=p.device))
-                else:
-                    arg = torch.argsort(s0, descending=True)
-                    p0, s0 = p0[arg], s0[arg]
-                    lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
-                    score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
-                for j in range(len(jcs[i])):
-                    if len(jcs[i][j]) == 0:
-                        jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
-                    jcs[i][j] = jcs[i][j][
-                        None, torch.arange(M.n_out_junc) % len(jcs[i][j])
-                    ]
-            result["preds"]["lines"] = torch.cat(lines)
-            result["preds"]["score"] = torch.cat(score)
-            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-            result["box"] = result['aaa']
-            del result['aaa']
-            if len(jcs[i]) > 1:
-                result["preds"]["junts"] = torch.cat(
-                    [jcs[i][1] for i in range(n_batch)]
-                )
-
-        if input_dict["mode"] != "testing":
-            y = torch.cat(ys)
-            loss = self.loss(x, y)
-            lpos_mask, lneg_mask = y, 1 - y
-            loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-            def sum_batch(x):
-                xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
-                return torch.cat(xs)
-
-            lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-            lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-            result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
-            result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
-
-        if input_dict["mode"] == "training":
-            for i in result["aaa"].keys():
-                result["losses"][0][i] = result["aaa"][i]
-            del result["preds"]
-
-        return result
-
-    def sample_lines(self, meta, jmap, joff, mode):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = M.n_dyn_junc // n_type
-            N = len(junc)
-            if mode != "training":
-                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            if mode == "training":
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > M.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > M.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            feat = torch.cat(
-                [
-                    xyu / 128 * M.use_cood,
-                    xyv / 128 * M.use_cood,
-                    u2v * M.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-class Bottleneck1D(nn.Module):
-    def __init__(self, inplanes, outplanes):
-        super(Bottleneck1D, self).__init__()
-
-        planes = outplanes // 2
-        self.op = nn.Sequential(
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, planes, kernel_size=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, outplanes, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.op(x)
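
As a small illustration of the `non_maximum_suppression` helper above (made-up heatmap), only scores that are the maximum of their 3x3 neighbourhood survive:

import torch

heatmap = torch.tensor([[[0.2, 0.9, 0.1],
                         [0.4, 0.3, 0.0],
                         [0.1, 0.2, 0.8]]])   # shape (1, 3, 3): one junction type
print(non_maximum_suppression(heatmap))
# Only the 0.9 and 0.8 entries remain non-zero; every other score is zeroed out.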

+ 0 - 118
lcnn/models/multitask_learner.py

@@ -1,118 +0,0 @@
-from collections import OrderedDict, defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-
-class MultitaskHead(nn.Module):
-    def __init__(self, input_channels, num_class):
-        super(MultitaskHead, self).__init__()
-        # print("输入的维度是:", input_channels)
-        m = int(input_channels / 4)
-        heads = []
-        for output_channels in sum(M.head_size, []):
-            heads.append(
-                nn.Sequential(
-                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
-                    nn.ReLU(inplace=True),
-                    nn.Conv2d(m, output_channels, kernel_size=1),
-                )
-            )
-        self.heads = nn.ModuleList(heads)
-        assert num_class == sum(sum(M.head_size, []))
-
-    def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=1)
-
-
-class MultitaskLearner(nn.Module):
-    def __init__(self, backbone):
-        super(MultitaskLearner, self).__init__()
-        self.backbone = backbone
-        head_size = M.head_size
-        self.num_class = sum(sum(head_size, []))
-        self.head_off = np.cumsum([sum(h) for h in head_size])
-
-    def forward(self, input_dict):
-        image = input_dict["image"]
-        target_b = input_dict["target_b"]
-        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # during training "aaa" holds the detection losses; during validation it holds the boxes
-
-        result = {"feature": feature}
-        batch, channel, row, col = outputs[0].shape
-        # print(f"batch:{batch}")
-        # print(f"channel:{channel}")
-        # print(f"row:{row}")
-        # print(f"col:{col}")
-
-        T = input_dict["target"].copy()
-        n_jtyp = T["junc_map"].shape[1]
-
-        # switch to CNHW
-        for task in ["junc_map"]:
-            T[task] = T[task].permute(1, 0, 2, 3)
-        for task in ["junc_offset"]:
-            T[task] = T[task].permute(1, 2, 0, 3, 4)
-
-        offset = self.head_off
-        loss_weight = M.loss_weight
-        losses = []
-        for stack, output in enumerate(outputs):
-            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
-            # print(f"lmap:{lmap.shape}")
-            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-            if stack == 0:
-                result["preds"] = {
-                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                    "lmap": lmap.sigmoid(),
-                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-                }
-                if input_dict["mode"] == "testing":
-                    return result
-
-            L = OrderedDict()
-            L["jmap"] = sum(
-                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-            )
-            L["lmap"] = (
-                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-                    .mean(2)
-                    .mean(1)
-            )
-            L["joff"] = sum(
-                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-                for i in range(n_jtyp)
-                for j in range(2)
-            )
-            for loss_name in L:
-                L[loss_name].mul_(loss_weight[loss_name])
-            losses.append(L)
-        result["losses"] = losses
-        result["aaa"] = aaa
-        return result
-
-
-def l2loss(input, target):
-    return ((target - input) ** 2).mean(2).mean(1)
-
-
-def cross_entropy_loss(logits, positive):
-    nlogp = -F.log_softmax(logits, dim=0)
-    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
-
-
-def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
-    logp = torch.sigmoid(logits) + offset
-    loss = torch.abs(logp - target)
-    if mask is not None:
-        w = mask.mean(2, True).mean(1, True)
-        w[w == 0] = 1
-        loss = loss * (mask / w)
-
-    return loss.mean(2).mean(1)
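
To make the loss helpers concrete, a tiny sketch (random tensors; shapes follow the per-sample maps used above) of calling them directly:

import torch

logits = torch.randn(2, 128, 128)                 # batch of 2, 128x128 maps
target = torch.rand(2, 128, 128)
mask = (torch.rand(2, 128, 128) > 0.5).float()

print(l2loss(torch.sigmoid(logits), target).shape)        # per-sample L2 loss: torch.Size([2])
print(sigmoid_l1_loss(logits, target, -0.5, mask).shape)  # masked, offset L1 loss: torch.Size([2])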

+ 0 - 182
lcnn/models/resnet50.py

@@ -1,182 +0,0 @@
-# import torch
-# import torch.nn as nn
-# import torchvision.models as models
-#
-#
-# class ResNet50Backbone(nn.Module):
-#     def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-#         super(ResNet50Backbone, self).__init__()
-#
-#         # Load a pretrained ResNet50
-#         if pretrained:
-#             resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-#         else:
-#             resnet = models.resnet50(weights=None)
-#
-#         # Remove the final fully-connected layer
-#         self.backbone = nn.Sequential(
-#             resnet.conv1,  # resolution drops to 1/2, channels go from 3 to 64
-#             resnet.bn1,
-#             resnet.relu,
-#             resnet.maxpool,  # resolution drops to 1/4, channels stay at 64
-#             resnet.layer1,  # stride 1, resolution stays at 1/4, channels go from 64 to 64*4=256
-#             # resnet.layer2,  # stride 2, resolution drops to 1/8, channels go from 256 to 128*4=512
-#             # resnet.layer3,  # stride 2, resolution drops to 1/16, channels go from 512 to 256*4=1024
-#             # resnet.layer4,  # stride 2, resolution drops to 1/32, channels go from 1024 to 512*4=2048
-#         )
-#
-#         # Multi-task output heads
-#         self.score_layers = nn.ModuleList([
-#             nn.Sequential(
-#                 nn.Conv2d(256, 128, kernel_size=3, padding=1),
-#                 nn.BatchNorm2d(128),
-#                 nn.ReLU(inplace=True),
-#                 nn.Conv2d(128, num_classes, kernel_size=1)
-#             )
-#             for _ in range(num_stacks)
-#         ])
-#
-#         # Upsampling layer to ensure the output is 128x128
-#         self.upsample = nn.Upsample(
-#             scale_factor=0.25,
-#             mode='bilinear',
-#             align_corners=True
-#         )
-#
-#     def forward(self, x):
-#         # Backbone feature extraction
-#         x = self.backbone(x)
-#
-#         # # Adjust the channel count
-#         # x = self.channel_adjust(x)
-#         #
-#         # # Upsample to 128x128
-#         # x = self.upsample(x)
-#
-#         # Multi-stack outputs
-#         outputs = []
-#         for score_layer in self.score_layers:
-#             output = score_layer(x)
-#             outputs.append(output)
-#
-#         # Return the first output (when there are multiple stacks)
-#         return outputs, x
-#
-#
-# def resnet50(**kwargs):
-#     model = ResNet50Backbone(
-#         num_classes=kwargs.get("num_classes", 5),
-#         num_stacks=kwargs.get("num_stacks", 1),
-#         pretrained=kwargs.get("pretrained", True)
-#     )
-#     return model
-#
-#
-# __all__ = ["ResNet50Backbone", "resnet50"]
-#
-#
-# # Test the network output
-# model = resnet50(num_classes=5, num_stacks=1)
-#
-#
-# # Method 1: pass an image tensor directly
-# x = torch.randn(2, 3, 512, 512)
-# outputs, feature = model(x)
-# print("Outputs length:", len(outputs))
-# print("Output[0] shape:", outputs[0].shape)
-# print("Feature shape:", feature.shape)
-
-import torch
-import torch.nn as nn
-import torchvision
-import torchvision.models as models
-from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
-
-
-class ResNet50Backbone1(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-        super(ResNet50Backbone1, self).__init__()
-
-        # Load a pretrained ResNet50
-        if pretrained:
-            # self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-            trainable_backbone_layers = _validate_trainable_layers(True, None, 5, 3)
-            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-            self.resnet = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-        else:
-            self.resnet = models.resnet50(weights=None)
-
-    def forward(self, x):
-        features = self.resnet(x)
-
-        # if self.training:
-        #     proposals, matched_idxs, labels, regression_targets = self.select_training_samples20(proposals, targets)
-        # else:
-        #     labels = None
-        #     regression_targets = None
-        #     matched_idxs = None
-        # box_features = self.box_roi_pool(features, proposals, image_shapes)  # map each ROI to a fixed-size feature representation
-        # # print(f"box_features:{box_features.shape}")  # [2048, 256, 7, 7] proposals pooled to a uniform size
-        # box_features = self.box_head(box_features)  #
-        # # print(f"box_features:{box_features.shape}")  # [N, 1024] feature vector after the box head
-        # class_logits, box_regression = self.box_predictor(box_features)
-        #
-        # result: List[Dict[str, torch.Tensor]] = []
-        # losses = {}
-        # if self.training:
-        #     if labels is None:
-        #         raise ValueError("labels cannot be None")
-        #     if regression_targets is None:
-        #         raise ValueError("regression_targets cannot be None")
-        #     loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-        #     losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
-        # else:
-        #     boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-        #     num_images = len(boxes)
-        #     for i in range(num_images):
-        #         result.append(
-        #             {
-        #                 "boxes": boxes[i],
-        #                 "labels": labels[i],
-        #                 "scores": scores[i],
-        #             }
-        #         )
-        #     print(f"boxes:{boxes[0].shape}")
-
-        return features['0']  # the FPN extractor returns an OrderedDict of feature maps
-
-        # # Multi-stack outputs
-        # outputs = []
-        # for score_layer in self.score_layers:
-        #     output = score_layer(x)
-        #     outputs.append(output)
-        #
-        # # Return the first output (when there are multiple stacks)
-        # return outputs, x
-
-
-def resnet501(**kwargs):
-    model = ResNet50Backbone1(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1),
-        pretrained=kwargs.get("pretrained", True)
-    )
-    return model
-
-
-# __all__ = ["ResNet50Backbone1", "resnet501"]
-#
-# # Test the network output
-# model = resnet501(num_classes=5, num_stacks=1)
-#
-# # Method 1: pass an image tensor directly
-# x = torch.randn(2, 3, 512, 512)
-# # outputs, feature = model(x)
-# # print("Outputs length:", len(outputs))
-# # print("Output[0] shape:", outputs[0].shape)
-# # print("Feature shape:", feature.shape)
-# feature = model(x)
-# print("Feature shape:", feature.keys())
-
-
-

+ 0 - 87
lcnn/models/resnet50_pose.py

@@ -1,87 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision.models as models
-
-
-class ResNet50Backbone(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-        super(ResNet50Backbone, self).__init__()
-
-        # Load a pretrained ResNet50
-        if pretrained:
-            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-        else:
-            resnet = models.resnet50(weights=None)
-
-        # Remove the final fully-connected layer
-        self.backbone = nn.Sequential(
-            resnet.conv1,  # resolution drops to 1/2, channels go from 3 to 64
-            resnet.bn1,
-            resnet.relu,
-            resnet.maxpool,  # resolution drops to 1/4, channels stay at 64
-            resnet.layer1,  # stride 1, resolution stays at 1/4, channels go from 64 to 64*4=256
-            # resnet.layer2,  # stride 2, resolution drops to 1/8, channels go from 256 to 128*4=512
-            # resnet.layer3,  # stride 2, resolution drops to 1/16, channels go from 512 to 256*4=1024
-            # resnet.layer4,  # stride 2, resolution drops to 1/32, channels go from 1024 to 512*4=2048
-        )
-
-        # multi-task output heads
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-        # resampling layer so the output size becomes 128x128
-        self.upsample = nn.Upsample(
-            scale_factor=0.25,
-            mode='bilinear',
-            align_corners=True
-        )
-
-    def forward(self, x):
-        # feature extraction through the backbone
-        x = self.backbone(x)
-
-        # # adjust the channel count
-        # x = self.channel_adjust(x)
-        #
-        # # resample to 128x128
-        # x = self.upsample(x)
-
-        # multi-stack outputs
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(x)
-            outputs.append(output)
-
-        # return the first output (when there are multiple stacks)
-        return outputs, x
-
-
-def resnet50(**kwargs):
-    model = ResNet50Backbone(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1),
-        pretrained=kwargs.get("pretrained", True)
-    )
-    return model
-
-
-__all__ = ["ResNet50Backbone", "resnet50"]
-
-
-# test the network output
-model = resnet50(num_classes=5, num_stacks=1)
-
-
-# Option 1: pass an image tensor directly
-x = torch.randn(2, 3, 512, 512)
-outputs, feature = model(x)
-print("Outputs length:", len(outputs))
-print("Output[0] shape:", outputs[0].shape)
-print("Feature shape:", feature.shape)

+ 0 - 126
lcnn/models/unet.py

@@ -1,126 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["UNetWithMultipleStacks", "unet"]
-
-
-class DoubleConv(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-        self.double_conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True)
-        )
-
-    def forward(self, x):
-        return self.double_conv(x)
-
-
-class UNetWithMultipleStacks(nn.Module):
-    def __init__(self, num_classes, num_stacks=2, base_channels=64):
-        super().__init__()
-        self.num_stacks = num_stacks
-        # encoder
-        self.enc1 = DoubleConv(3, base_channels)
-        self.pool1 = nn.MaxPool2d(2)
-        self.enc2 = DoubleConv(base_channels, base_channels * 2)
-        self.pool2 = nn.MaxPool2d(2)
-        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
-        self.pool3 = nn.MaxPool2d(2)
-        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
-        self.pool4 = nn.MaxPool2d(2)
-        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
-        self.pool5 = nn.MaxPool2d(2)
-
-        # bottleneck
-        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
-
-        # decoder
-        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
-        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
-        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
-        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
-        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
-        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
-        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
-        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
-        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
-        self.dec1 = DoubleConv(base_channels * 2, base_channels)
-
-        # extra resampling layer, bringing the size from 512 down to 128
-        self.final_upsample = nn.Sequential(
-            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(base_channels),
-            nn.ReLU(inplace=True),
-            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
-        )
-
-        # score_layers adjusted to match 256 channels
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)
-
-    def forward(self, x):
-        # encoding path
-        enc1 = self.enc1(x)
-        enc2 = self.enc2(self.pool1(enc1))
-        enc3 = self.enc3(self.pool2(enc2))
-        enc4 = self.enc4(self.pool3(enc3))
-        enc5 = self.enc5(self.pool4(enc4))
-
-        # bottleneck
-        bottleneck = self.bottleneck(self.pool5(enc5))
-
-        # decoding path
-        dec5 = self.upconv5(bottleneck)
-        dec5 = torch.cat([dec5, enc5], dim=1)
-        dec5 = self.dec5(dec5)
-
-        dec4 = self.upconv4(dec5)
-        dec4 = torch.cat([dec4, enc4], dim=1)
-        dec4 = self.dec4(dec4)
-
-        dec3 = self.upconv3(dec4)
-        dec3 = torch.cat([dec3, enc3], dim=1)
-        dec3 = self.dec3(dec3)
-
-        dec2 = self.upconv2(dec3)
-        dec2 = torch.cat([dec2, enc2], dim=1)
-        dec2 = self.dec2(dec2)
-
-        dec1 = self.upconv1(dec2)
-        dec1 = torch.cat([dec1, enc1], dim=1)
-        dec1 = self.dec1(dec1)
-
-        # extra resampling so the output size becomes 128
-        dec1 = self.final_upsample(dec1)
-        # adjust the channel count
-        dec1 = self.channel_adjust(dec1)
-        # multi-stack outputs
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(dec1)
-            outputs.append(output)
-
-        return outputs[::-1], dec1
-
-
-def unet(**kwargs):
-    model = UNetWithMultipleStacks(
-        num_classes=kwargs["num_classes"],
-        num_stacks=kwargs.get("num_stacks", 2),
-        base_channels=kwargs.get("base_channels", 64)
-    )
-    return model
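
An illustrative sketch of how the removed UNetWithMultipleStacks above would be exercised with the 512x512 inputs used elsewhere in this repo; the shapes are inferred from the layer definitions, not taken from the commit:

    import torch
    model = unet(num_classes=5, num_stacks=2)
    x = torch.randn(1, 3, 512, 512)
    outputs, feature = model(x)
    # outputs: list of num_stacks tensors, each [1, 5, 128, 128] (after final_upsample and channel_adjust)
    # feature: [1, 256, 128, 128]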

+ 0 - 77
lcnn/postprocess.py

@@ -1,77 +0,0 @@
-import numpy as np
-
-
-def pline(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
-    dx = x1 + u * px - x
-    dy = y1 + u * py - y
-    return dx * dx + dy * dy
-
-
-def psegment(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
-    dx = x1 + u * px - x
-    dy = y1 + u * py - y
-    return dx * dx + dy * dy
-
-
-def plambda(x1, y1, x2, y2, x, y):
-    px = x2 - x1
-    py = y2 - y1
-    dd = px * px + py * py
-    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
-
-
-def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
-    nlines, nscores = [], []
-    for (p, q), score in zip(lines, scores):
-        start, end = 0, 1
-        for a, b in nlines:
-            if (
-                min(
-                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
-                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
-                )
-                > threshold ** 2
-            ):
-                continue
-            lambda_a = plambda(*p, *q, *a)
-            lambda_b = plambda(*p, *q, *b)
-            if lambda_a > lambda_b:
-                lambda_a, lambda_b = lambda_b, lambda_a
-            lambda_a -= tol
-            lambda_b += tol
-
-            # case 1: skip (if not do_clip)
-            if start < lambda_a and lambda_b < end:
-                continue
-
-            # not intersect
-            if lambda_b < start or lambda_a > end:
-                continue
-
-            # cover
-            if lambda_a <= start and end <= lambda_b:
-                start = 10
-                break
-
-            # case 2 & 3:
-            if lambda_a <= start and start <= lambda_b:
-                start = lambda_b
-            if lambda_a <= end and end <= lambda_b:
-                end = lambda_a
-
-            if start >= end:
-                break
-
-        if start >= end:
-            continue
-        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
-        nscores.append(score)
-    return np.array(nlines), np.array(nscores)
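
A small sketch of how the removed postprocess above suppresses overlapping segments; the coordinates are illustrative, lines are assumed to be ordered by descending score, and the threshold is in the same pixel units as the endpoints:

    import numpy as np
    lines = np.array([[[0.0, 0.0], [0.0, 100.0]],   # kept: nothing precedes it
                      [[0.0, 1.0], [0.0, 50.0]]])   # dropped: fully covered by the first segment
    scores = np.array([0.9, 0.6])
    nlines, nscores = postprocess(lines, scores, threshold=2, tol=0, do_clip=False)
    # nlines.shape == (1, 2, 2); nscores == array([0.9])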

+ 0 - 101
lcnn/utils.py

@@ -1,101 +0,0 @@
-import math
-import os.path as osp
-import multiprocessing
-from timeit import default_timer as timer
-
-import numpy as np
-import torch
-import matplotlib.pyplot as plt
-
-
-class benchmark(object):
-    def __init__(self, msg, enable=True, fmt="%0.3g"):
-        self.msg = msg
-        self.fmt = fmt
-        self.enable = enable
-
-    def __enter__(self):
-        if self.enable:
-            self.start = timer()
-        return self
-
-    def __exit__(self, *args):
-        if self.enable:
-            t = timer() - self.start
-            print(("%s : " + self.fmt + " seconds") % (self.msg, t))
-            self.time = t
-
-
-def quiver(x, y, ax):
-    ax.set_xlim(0, x.shape[1])
-    ax.set_ylim(x.shape[0], 0)
-    ax.quiver(
-        x,
-        y,
-        units="xy",
-        angles="xy",
-        scale_units="xy",
-        scale=1,
-        minlength=0.01,
-        width=0.1,
-        color="b",
-    )
-
-
-def recursive_to(input, device):
-    if isinstance(input, torch.Tensor):
-        return input.to(device)
-    if isinstance(input, dict):
-        for name in input:
-            if isinstance(input[name], torch.Tensor):
-                input[name] = input[name].to(device)
-        return input
-    if isinstance(input, list):
-        for i, item in enumerate(input):
-            input[i] = recursive_to(item, device)
-        return input
-    assert False
-
-
-def np_softmax(x, axis=0):
-    """Compute softmax values for each sets of scores in x."""
-    e_x = np.exp(x - np.max(x))
-    return e_x / e_x.sum(axis=axis, keepdims=True)
-
-
-def argsort2d(arr):
-    return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
-
-
-def __parallel_handle(f, q_in, q_out):
-    while True:
-        i, x = q_in.get()
-        if i is None:
-            break
-        q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
-    if nprocs == 0:
-        nprocs = multiprocessing.cpu_count()
-    q_in = multiprocessing.Queue(1)
-    q_out = multiprocessing.Queue()
-
-    proc = [
-        multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
-        for _ in range(nprocs)
-    ]
-    for p in proc:
-        p.daemon = True
-        p.start()
-
-    try:
-        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
-        [q_in.put((None, None)) for _ in range(nprocs)]
-        res = [q_out.get() for _ in progress_bar(range(len(sent)))]
-        [p.join() for p in proc]
-    except KeyboardInterrupt:
-        q_in.close()
-        q_out.close()
-        raise
-    return [x for i, x in sorted(res)]
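
For reference, a usage sketch of the removed parmap helper above; the mapped function should be defined at module level so it can be dispatched to the worker processes:

    def square(x):
        return x * x

    squares = parmap(square, range(8), nprocs=4)
    # [0, 1, 4, 9, 16, 25, 36, 49] -- results are re-sorted to match the input order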

BIN
libs/vision_libs/_C.pyd


+ 2 - 1
libs/vision_libs/__init__.py

@@ -3,7 +3,8 @@ import warnings
 from modulefinder import Module
 
 import torch
-from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+# from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+from torchvision import datasets, io, models, ops, transforms, utils
 
 from .extension import _HAS_OPS
 

BIN
libs/vision_libs/image.pyd


+ 1 - 0
libs/vision_libs/models/detection/rpn.py

@@ -370,6 +370,7 @@ class RegionProposalNetwork(torch.nn.Module):
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+        # print(f'boxes:{boxes.shape},scores:{scores.shape}')
 
         losses = {}
         if self.training:

+ 0 - 0
lcnn/models/base/__init__.py → models/ins_detect/__init__.py


+ 143 - 0
models/ins_detect/maskrcnn.py

@@ -0,0 +1,143 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.ins_detect.trainer import train_cfg
+from tools import utils
+
+
+class MaskRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=0, transforms=None):
+        super(MaskRCNNModel, self).__init__()
+        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+        if transforms is None:
+            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        if num_classes != 0:
+            self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes=parameters['num_classes']
+        # print(f'num_classes:{num_classes}')
+        self.set_num_classes(num_classes)
+        train_cfg(self.__model, cfg)
+
+    def set_num_classes(self, num_classes):
+        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+        hidden_layer = 256
+        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
+                                                                  num_classes=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def predict(self, src, show_box=True, show_mask=True):
+        self.__model.eval()
+
+        img = read_image(src)
+        img = self.transforms(img)
+        img = img.to(self.device)
+        result = self.__model([img])
+        print(f'result:{result}')
+        masks = result[0]['masks']
+        boxes = result[0]['boxes']
+        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
+        boxes = boxes.cpu().detach()
+        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
+        print(f'drawn_boxes:{drawn_boxes.shape}')
+        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
+        # boxed_img=cv2.resize(boxed_img,(800,800))
+        # cv2.imshow('boxes',boxed_img)
+
+        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
+
+        mask = cv2.resize(mask, (800, 800))
+        # cv2.imshow('mask',mask)
+        img = img.cpu().detach().permute(1, 2, 0).numpy()
+
+        masked_img = self.overlay_masks_on_image(boxed_img, masks)
+        masked_img = cv2.resize(masked_img, (800, 800))
+        cv2.imshow('img_masks', masked_img)
+        # show_img_boxes_masks(img, boxes, masks)
+        cv2.waitKey(0)
+
+    def generate_colors(self, n):
+        """
+        Generate n colors evenly spaced in HSV color space and convert them to BGR.
+
+        :param n: number of colors needed
+        :return: a list of n colors, each a tuple in BGR format
+        """
+        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
+        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
+        return bgr_colors
+
+    def overlay_masks_on_image(self, image, masks, alpha=0.6):
+        """
+        Overlay multiple masks on the original image, each in a different color
+        (the colors are generated internally via generate_colors).
+
+        :param image: original image (NumPy array)
+        :param masks: list of masks (each a binary NumPy array)
+        :param alpha: mask transparency (0.0 to 1.0)
+        :return: the image with all masks overlaid
+        """
+        colors = self.generate_colors(len(masks))
+        if len(masks) != len(colors):
+            raise ValueError("The number of masks and colors must be the same.")
+
+        # copy the original image to avoid modifying it in place
+        overlay = image.copy()
+
+        for mask, color in zip(masks, colors):
+            # make sure the mask is a binary image
+            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
+            binary_mask = (mask > 0).astype(np.uint8) * 255  # adjust this threshold as needed
+
+            # build the colored mask
+            colored_mask = np.zeros_like(image)
+
+            colored_mask[:] = color
+            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
+
+            # blend the colored mask into the current overlay
+            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
+
+        return overlay
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    ins_model = MaskRCNNModel()
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    ins_model.train(cfg='train.yaml')
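
A minimal inference sketch against the MaskRCNNModel defined above; the checkpoint and image paths are placeholders:

    ins_model = MaskRCNNModel(num_classes=5)
    ins_model.load_weight('train_results/<run>/weights/best.pt')  # placeholder: a checkpoint written by trainer.py
    ins_model.predict('path/to/image.png', show_box=True, show_mask=True)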

+ 93 - 0
models/ins_detect/maskrcnn_dataset.py

@@ -0,0 +1,93 @@
+import os
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
+
+
+class MaskRCNNDataset(Dataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
+        self.data_path = dataset_path
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        # print('maskrcnn inited!')
+
+    def __getitem__(self, item):
+        # print('__getitem__')
+        img_path = os.path.join(self.img_path, self.imgs[item])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
+        img = PIL.Image.open(img_path).convert('RGB')
+        # h, w = np.array(img).shape[:2]
+        w, h = img.size
+        # print(f'h,w:{h, w}')
+        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img,target)
+        else:
+            img=self.deafult_transform(img)
+        # print(f'img:{img.shape},target:{target}')
+        return img, target
+
+    def create_masks_from_polygons(self, polygons, image_shape):
+        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+        masks = []
+
+        for polygon_data, col in zip(polygons, colors):
+            mask = np.zeros(image_shape[:2], dtype=np.uint8)
+            # 将多边形顶点转换为 NumPy 数组
+            _, polygon = polygon_data
+            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+            # 使用 OpenCV 的 fillPoly 函数填充多边形
+            # print(f'color:{col[:3]}')
+            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+            mask = torch.from_numpy(mask)
+            mask[mask != 0] = 1
+            masks.append(mask)
+
+        return masks
+
+    def read_target(self, item, lbl_path, shape):
+        # print(f'lbl_path:{lbl_path}')
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        target["area"] = torch.zeros(len(masks))
+        target["iscrowd"] = torch.zeros(len(masks))
+        return target
+
+    def heatmap_enhance(self, img):
+        # histogram equalization
+        img_eq = cv2.equalizeHist(img)
+
+        # adaptive histogram equalization
+        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        # img_clahe = clahe.apply(img)
+
+        # convert the grayscale image to a heatmap
+        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
+
+    def __len__(self):
+        return len(self.imgs)
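
A minimal sketch of wiring the dataset above into a DataLoader, mirroring how trainer.py consumes it; the dataset root is a placeholder and utils.collate_fn is the same helper the trainer imports from tools:

    from torch.utils.data import DataLoader
    from tools import utils

    dataset = MaskRCNNDataset(dataset_path='path/to/dataset', dataset_type='train', target_type='polygon')
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=utils.collate_fn)
    images, targets = next(iter(loader))  # a tuple of image tensors and a tuple of target dicts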

+ 31 - 0
models/ins_detect/train.yaml

@@ -0,0 +1,31 @@
+
+
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.0005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 220 - 0
models/ins_detect/trainer.py

@@ -0,0 +1,220 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def load_train_parameter(cfg):
+    parameters = read_yaml(cfg)
+    return parameters
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # default parameters
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # update the defaults with the provided keyword arguments
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack the parameters
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # select the device
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = MaskRCNNDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # update the defaults with the provided keyword arguments
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack the parameters
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
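
As a sketch of the configuration flow above (values are illustrative): train_cfg reads the YAML into a dict and forwards it as keyword arguments, so any key not present in default_params raises ValueError:

    params = read_yaml('train.yaml')  # e.g. {'num_classes': 5, 'batch_size': 2, 'epochs': 10, ...}
    train(model, **params)            # equivalent to train_cfg(model, 'train.yaml')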

+ 312 - 0
models/keypoint/keypoint_dataset.py

@@ -1,3 +1,107 @@
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
+========
+# import glob
+# import json
+# import math
+# import os
+# import random
+#
+# import numpy as np
+# import numpy.linalg as LA
+# import torch
+# from skimage import io
+# from torch.utils.data import Dataset
+# from torch.utils.data.dataloader import default_collate
+#
+# from lcnn.config import M
+#
+# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
+#
+#
+# class WireframeDataset(Dataset):
+#     def __init__(self, rootdir, split):
+#         self.rootdir = rootdir
+#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
+#         filelist.sort()
+#
+#         # print(f"n{split}:", len(filelist))
+#         self.split = split
+#         self.filelist = filelist
+#
+#     def __len__(self):
+#         return len(self.filelist)
+#
+#     def __getitem__(self, idx):
+#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
+#         image = io.imread(iname).astype(float)[:, :, :3]
+#         if "a1" in self.filelist[idx]:
+#             image = image[:, ::-1, :]
+#         image = (image - M.image.mean) / M.image.stddev
+#         image = np.rollaxis(image, 2).copy()
+#
+#         with np.load(self.filelist[idx]) as npz:
+#             target = {
+#                 name: torch.from_numpy(npz[name]).float()
+#                 for name in ["jmap", "joff", "lmap"]
+#             }
+#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
+#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
+#             npos, nneg = len(lpos), len(lneg)
+#             lpre = np.concatenate([lpos, lneg], 0)
+#             for i in range(len(lpre)):
+#                 if random.random() > 0.5:
+#                     lpre[i] = lpre[i, ::-1]
+#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+#             feat = [
+#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+#                 ldir * M.use_slop,
+#                 lpre[:, :, 2],
+#             ]
+#             feat = np.concatenate(feat, 1)
+#             meta = {
+#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
+#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
+#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
+#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
+#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
+#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
+#                 "lpre_feat": torch.from_numpy(feat),
+#             }
+#
+#         labels = []
+#         labels = read_masks_from_pixels_wire(iname, (512, 512))
+#         # if self.target_type == 'polygon':
+#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
+#         # elif self.target_type == 'pixel':
+#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
+#
+#         target["labels"] = torch.stack(labels)
+#         target["boxes"] = line_boxes_faster(meta)
+#
+#
+#         return torch.from_numpy(image).float(), meta, target
+#
+#     def adjacency_matrix(self, n, link):
+#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+#         link = torch.from_numpy(link)
+#         if len(link) > 0:
+#             mat[link[:, 0], link[:, 1]] = 1
+#             mat[link[:, 1], link[:, 0]] = 1
+#         return mat
+#
+#
+# def collate(batch):
+#     return (
+#         default_collate([b[0] for b in batch]),
+#         [b[1] for b in batch],
+#         default_collate([b[2] for b in batch]),
+#     )
+
+
+# Original LCNN data format, with renamed attributes and box-related fields added
+
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -30,8 +134,12 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 
 class KeypointDataset(BaseDataset):
+========
+class  WireframeDataset(BaseDataset):
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
 
@@ -199,7 +307,211 @@ class KeypointDataset(BaseDataset):
 
 
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 if __name__ == '__main__':
     path=r"I:\datasets\wirenet_1000"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(7)
+========
+
+'''
+# roi_heads expects a specific data format, so the format is changed here
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # "multiscale" would change the image size
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        # if self.transforms:
+        #     img, target = self.transforms(img, target)
+        # else:
+        #     img = self.default_transform(img)
+
+        img, target = self.data_augmentation(img, target)
+
+        print(f'img:{img.shape}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # adjacency matrix of lines that do not exist
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # sample labels: 1 positive, 0 negative
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        # if self.target_type == 'polygon':
+        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        # elif self.target_type == 'pixel':
+        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+        target = {}
+        # target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        self._draw_vecl(img_path, target)
+
+    def show_img(self, img_path):
+
+        """根据给定的图片路径展示图像及其标注信息"""
+        # 获取对应的标签文件路径
+        img_name = os.path.basename(img_path)
+        img_path = os.path.join(self.img_path, img_name)
+        print(img_path)
+        lbl_name = img_name[:-3] + 'json'
+        lbl_path = os.path.join(self.lbl_path, lbl_name)
+        print(lbl_path)
+
+        if not os.path.exists(lbl_path):
+            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        target = self.read_target(0, lbl_path, shape=(h, w))
+
+        # call the drawing helper
+        self._draw_vecl(img_path, target)
+
+
+    def _draw_vecl(self, img_path, target, fn=None):
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        lines = lpre
+        sline = np.ones(lpre.shape[0])
+        imshow(io.imread(img_path))
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # no defined ordering between a[1] and b[1]
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (i == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                          colors="yellow", width=1)
+        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+        plt.show()
+
+        if fn != None:
+            plt.savefig(fn)
+
+'''
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py

+ 1 - 0
models/line_detect/dataset_LD.py

@@ -117,6 +117,7 @@ class WirePointDataset(BaseDataset):
         # return wire_labels, target
         target["wires"] = wire_labels
         target["boxes"] = line_boxes(target)
+        # print(f'boxes:{target["boxes"].shape}')
         return target
 
     def show(self, idx):

+ 0 - 120
models/line_detect/fasterrcnn_resnet50.py

@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
-
-def get_model(num_classes):
-    # load a pretrained ResNet-50 FPN backbone
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-
-    # number of input features of the classifier
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-
-    # replace the classifier to match the new number of classes
-    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-
-    return model
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
-
-class Fasterrcnn_resnet50(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1):
-        super(Fasterrcnn_resnet50, self).__init__()
-
-        self.model = get_model(num_classes=5)
-        self.backbone = self.model.backbone
-
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        # multi-task output heads
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
-
-        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
-                                             image_std=[0.229, 0.224, 0.225])
-        images, targets = transform(x, target1)
-        x_ = self.backbone(images.tensors)
-
-        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
-        # print(f'backbone:{self.backbone}')
-        # print(f'Fasterrcnn_resnet50 x_:{x_}')
-        feature_ = x_['0']  # image features
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(feature_)
-            outputs.append(output)  # multiple heads
-
-        if train_or_val == "training":
-            loss_box = self.model(x, target1)
-            return outputs, feature_, loss_box
-        else:
-            box_all = self.model(x, target1)
-            return outputs, feature_, box_all
-
-
-def fasterrcnn_resnet50(**kwargs):
-    model = Fasterrcnn_resnet50(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1)
-    )
-    return model

+ 24 - 0
models/line_detect/line_head.py

@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+
+
+class LineRCNNHeads(nn.Sequential):
+    def __init__(self, input_channels, num_class):
+        super(LineRCNNHeads, self).__init__()
+        # print("输入的维度是:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        self.head_size = [[2], [1], [2]]
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(self.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
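
A shape-level sketch of the head above; the 256-channel input is an assumption based on the FPN features used elsewhere, and the channel split follows head_size = [[2], [1], [2]] as consumed by line_predictor.py:

    head = LineRCNNHeads(input_channels=256, num_class=5)  # 5 == sum of head_size
    x = torch.randn(2, 256, 128, 128)
    out = head(x)  # [2, 5, 128, 128]: channels 0-1 -> jmap, 2 -> lmap, 3-4 -> joff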

+ 77 - 12
models/line_detect/LineNet.py → models/line_detect/line_net.py

@@ -1,25 +1,33 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
 
+from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
 
-from  libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
+from .line_head import LineRCNNHeads
+from .line_predictor import LineRCNNPredictor
+from .roi_heads import RoIHeads
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_CATEGORIES
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
 from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
-from libs.vision_libs.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 
-from libs.vision_libs.models.detection.rpn import RegionProposalNetwork, RPNHead
-from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
 
-######## Deprecated ###########
+FEATURE_DIM = 8
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 __all__ = [
     "LineNet",
@@ -28,10 +36,50 @@ __all__ = [
     "LineNet_MobileNet_V3_Large_FPN_Weights",
     "LineNet_MobileNet_V3_Large_320_FPN_Weights",
     "linenet_resnet50_fpn",
-    "fasterrcnn_resnet50_fpn_v2",
+    "linenet_resnet50_fpn_v2",
     "linenet_mobilenet_v3_large_fpn",
     "linenet_mobilenet_v3_large_320_fpn",
 ]
+# __all__ = [
+#     "LineNet",
+#     "LineRCNN_ResNet50_FPN_Weights",
+#     "linercnn_resnet50_fpn",
+# ]
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+# class Bottleneck1D(nn.Module):
+#     def __init__(self, inplanes, outplanes):
+#         super(Bottleneck1D, self).__init__()
+#
+#         planes = outplanes // 2
+#         self.op = nn.Sequential(
+#             nn.BatchNorm1d(inplanes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(inplanes, planes, kernel_size=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, outplanes, kernel_size=1),
+#         )
+#
+#     def forward(self, x):
+#         return x + self.op(x)
+
+
+
+
+
+
+
+
 
 from .roi_heads import RoIHeads
 
@@ -199,6 +247,9 @@ class LineNet(BaseDetectionNet):
         box_batch_size_per_image=512,
         box_positive_fraction=0.25,
         bbox_reg_weights=None,
+        # line parameters
+        line_head=None,
+        line_predictor=None,
         **kwargs,
     ):
 
@@ -227,6 +278,13 @@ class LineNet(BaseDetectionNet):
 
         out_channels = backbone.out_channels
 
+        if line_head is None:
+            num_class = 5
+            line_head = LineRCNNHeads(out_channels, num_class)
+
+        if line_predictor is None:
+            line_predictor = LineRCNNPredictor()
+
         if rpn_anchor_generator is None:
             rpn_anchor_generator = _default_anchorgen()
         if rpn_head is None:
@@ -265,6 +323,8 @@ class LineNet(BaseDetectionNet):
             box_roi_pool,
             box_head,
             box_predictor,
+            line_head,
+            line_predictor,
             box_fg_iou_thresh,
             box_bg_iou_thresh,
             box_batch_size_per_image,
@@ -283,6 +343,10 @@ class LineNet(BaseDetectionNet):
 
         super().__init__(backbone, rpn, roi_heads, transform)
 
+        self.roi_heads = roi_heads
+        # self.roi_heads.line_head = line_head
+        # self.roi_heads.line_predictor = line_predictor
+
 
 class TwoMLPHead(nn.Module):
     """
@@ -587,7 +651,7 @@ def linenet_resnet50_fpn(
     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
-def fasterrcnn_resnet50_fpn_v2(
+def linenet_resnet50_fpn_v2(
     *,
     weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
     progress: bool = True,
@@ -845,3 +909,4 @@ def linenet_mobilenet_v3_large_fpn(
         trainable_backbone_layers=trainable_backbone_layers,
         **kwargs,
     )
+

+ 324 - 0
models/line_detect/line_predictor.py

@@ -0,0 +1,324 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .roi_heads import RoIHeads
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
+
+FEATURE_DIM = 8
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+class LineRCNNPredictor(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs, features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # collect the 'junc_map', 'junc_offset' and 'line_map' tensors from every target
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
+            t = {
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+            wires_targets = [t for b in range(inputs.size(0))]
+
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
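+                # jmap: per-pixel junction probability (softmax over the two logits per type),
+                # lmap: line heatmap, joff: sub-pixel junction offsets in (-0.5, 0.5)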
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        # print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
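+            # interpolate n_pts0 points between the two endpoints of every candidate line,
+            # then bilinearly sample the dim_loi-channel feature map at those sub-pixel points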
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            # print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
+        x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
+        # print("Input dtype1:", x.dtype)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+    def sample_lines(self, meta, jmap, joff):
+        device = jmap.device
+        with torch.no_grad():
+            junc = meta["junc_coords"].to(device)  # [N, 2]
+            jtyp = meta["jtyp"].to(device)  # [N]
+            Lpos = meta["line_pos_idx"].to(device)
+            Lneg = meta["line_neg_idx"].to(device)
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
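+            # decode the flat top-K indices into (row, col) on the 128x128 grid and add the
+            # predicted sub-pixel junction offset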
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
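+            # each proposal is matched to its nearest ground-truth junction; mismatched types
+            # or distances above 1.5 px are mapped to the dummy index N (unmatched)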
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
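+            # 8-d line descriptor (= FEATURE_DIM): normalized endpoint coordinates (2 + 2),
+            # unit direction vector (2) and two endpoint-type indicator bits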
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}

+ 0 - 1020
models/line_detect/line_rcnn.py

@@ -1,1020 +0,0 @@
-from typing import Any, Optional
-
-import torch
-from torch import nn
-from torchvision.ops import MultiScaleRoIAlign
-
-from libs.vision_libs.ops import misc as misc_nn_ops
-from libs.vision_libs.transforms._presets import ObjectDetection
-from .roi_heads import RoIHeads
-from libs.vision_libs.models._api  import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
-from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
-from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
-
-from models.config.config_tool import read_yaml
-import numpy as np
-import torch.nn.functional as F
-
-FEATURE_DIM = 8
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-__all__ = [
-    "LineRCNN",
-    "LineRCNN_ResNet50_FPN_Weights",
-    "linercnn_resnet50_fpn",
-]
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-class Bottleneck1D(nn.Module):
-    def __init__(self, inplanes, outplanes):
-        super(Bottleneck1D, self).__init__()
-
-        planes = outplanes // 2
-        self.op = nn.Sequential(
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, planes, kernel_size=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, outplanes, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.op(x)
-
-class LineRCNN(FasterRCNN):
-    """
-    Implements Keypoint R-CNN.
-
-    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
-    image, and should be in 0-1 range. Different images can have different sizes.
-
-    The behavior of the model changes depending on if it is in training or evaluation mode.
-
-    During training, the model expects both the input tensors and targets (list of dictionary),
-    containing:
-
-        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
-            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the class label for each ground-truth box
-        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
-          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
-
-    The model returns a Dict[Tensor] during training, containing the classification and regression
-    losses for both the RPN and the R-CNN, and the keypoint loss.
-
-    During inference, the model requires only the input tensors, and returns the post-processed
-    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
-    follows:
-
-        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
-            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the predicted labels for each image
-        - scores (Tensor[N]): the scores or each prediction
-        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
-
-    Args:
-        backbone (nn.Module): the network used to compute the features for the model.
-            It should contain an out_channels attribute, which indicates the number of output
-            channels that each feature map has (and it should be the same for all feature maps).
-            The backbone should return a single Tensor or and OrderedDict[Tensor].
-        num_classes (int): number of output classes of the model (including the background).
-            If box_predictor is specified, num_classes should be None.
-        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
-        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        image_mean (Tuple[float, float, float]): mean values used for input normalization.
-            They are generally the mean values of the dataset on which the backbone has been trained
-            on
-        image_std (Tuple[float, float, float]): std values used for input normalization.
-            They are generally the std values of the dataset on which the backbone has been trained on
-        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
-            maps.
-        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
-        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
-        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
-        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
-        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
-        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
-        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
-            considered as positive during training of the RPN.
-        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
-            considered as negative during training of the RPN.
-        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
-            for computing the loss
-        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
-            of the RPN
-        rpn_score_thresh (float): during inference, only return proposals with a classification score
-            greater than rpn_score_thresh
-        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
-            the locations indicated by the bounding boxes
-        box_head (nn.Module): module that takes the cropped feature maps as input
-        box_predictor (nn.Module): module that takes the output of box_head and returns the
-            classification logits and box regression deltas.
-        box_score_thresh (float): during inference, only return proposals with a classification score
-            greater than box_score_thresh
-        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
-        box_detections_per_img (int): maximum number of detections per image, for all classes.
-        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
-            considered as positive during training of the classification head
-        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
-            considered as negative during training of the classification head
-        box_batch_size_per_image (int): number of proposals that are sampled during training of the
-            classification head
-        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
-            of the classification head
-        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
-            bounding boxes
-        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
-             the locations indicated by the bounding boxes, which will be used for the keypoint head.
-        keypoint_head (nn.Module): module that takes the cropped feature maps as input
-        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
-            heatmap logits
-
-    Example::
-
-        >>> import torch
-        >>> import torchvision
-        >>> from torchvision.models.detection import KeypointRCNN
-        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
-        >>>
-        >>> # load a pre-trained model for classification and return
-        >>> # only the features
-        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
-        >>> # KeypointRCNN needs to know the number of
-        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
-        >>> # so we need to add it here
-        >>> backbone.out_channels = 1280
-        >>>
-        >>> # let's make the RPN generate 5 x 3 anchors per spatial
-        >>> # location, with 5 different sizes and 3 different aspect
-        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
-        >>> # map could potentially have different sizes and
-        >>> # aspect ratios
-        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
-        >>>
-        >>> # let's define what are the feature maps that we will
-        >>> # use to perform the region of interest cropping, as well as
-        >>> # the size of the crop after rescaling.
-        >>> # if your backbone returns a Tensor, featmap_names is expected to
-        >>> # be ['0']. More generally, the backbone should return an
-        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
-        >>> # feature maps to use.
-        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-        >>>                                                 output_size=7,
-        >>>                                                 sampling_ratio=2)
-        >>>
-        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-        >>>                                                          output_size=14,
-        >>>                                                          sampling_ratio=2)
-        >>> # put the pieces together inside a KeypointRCNN model
-        >>> model = KeypointRCNN(backbone,
-        >>>                      num_classes=2,
-        >>>                      rpn_anchor_generator=anchor_generator,
-        >>>                      box_roi_pool=roi_pooler,
-        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
-        >>> model.eval()
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-    """
-
-    def __init__(
-            self,
-            backbone,
-            num_classes=None,
-            # transform parameters
-            min_size=512,   # 原为None
-            max_size=1333,
-            image_mean=None,
-            image_std=None,
-            # RPN parameters
-            rpn_anchor_generator=None,
-            rpn_head=None,
-            rpn_pre_nms_top_n_train=2000,
-            rpn_pre_nms_top_n_test=1000,
-            rpn_post_nms_top_n_train=2000,
-            rpn_post_nms_top_n_test=1000,
-            rpn_nms_thresh=0.7,
-            rpn_fg_iou_thresh=0.7,
-            rpn_bg_iou_thresh=0.3,
-            rpn_batch_size_per_image=256,
-            rpn_positive_fraction=0.5,
-            rpn_score_thresh=0.0,
-            # Box parameters
-            box_roi_pool=None,
-            box_head=None,
-            box_predictor=None,
-            box_score_thresh=0.05,
-            box_nms_thresh=0.5,
-            box_detections_per_img=100,
-            box_fg_iou_thresh=0.5,
-            box_bg_iou_thresh=0.5,
-            box_batch_size_per_image=512,
-            box_positive_fraction=0.25,
-            bbox_reg_weights=None,
-            # line parameters
-            line_head=None,
-            line_predictor=None,
-
-            **kwargs,
-    ):
-
-        # if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
-        #     raise TypeError(
-        #         "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
-        #     )
-        # if min_size is None:
-        #     min_size = (640, 672, 704, 736, 768, 800)
-        #
-        # if num_keypoints is not None:
-        #     if keypoint_predictor is not None:
-        #         raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
-        # else:
-        #     num_keypoints = 17
-
-        out_channels = backbone.out_channels
-
-        if line_head is None:
-            # keypoint_layers = tuple(512 for _ in range(8))
-            num_class = 5
-            line_head = LineRCNNHeads(out_channels, num_class)
-
-        if line_predictor is None:
-            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LineRCNNPredictor()
-
-        super().__init__(
-            backbone,
-            num_classes,
-            # transform parameters
-            min_size,
-            max_size,
-            image_mean,
-            image_std,
-            # RPN-specific parameters
-            rpn_anchor_generator,
-            rpn_head,
-            rpn_pre_nms_top_n_train,
-            rpn_pre_nms_top_n_test,
-            rpn_post_nms_top_n_train,
-            rpn_post_nms_top_n_test,
-            rpn_nms_thresh,
-            rpn_fg_iou_thresh,
-            rpn_bg_iou_thresh,
-            rpn_batch_size_per_image,
-            rpn_positive_fraction,
-            rpn_score_thresh,
-            # Box parameters
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            **kwargs,
-        )
-
-        if box_roi_pool is None:
-            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
-
-        if box_head is None:
-            resolution = box_roi_pool.output_size[0]
-            representation_size = 1024
-            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        if box_predictor is None:
-            representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        roi_heads = RoIHeads(
-            # Box
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            line_head,
-            line_predictor,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-
-        )
-        # super().roi_heads = roi_heads
-        self.roi_heads = roi_heads
-        self.roi_heads.line_head = line_head
-        self.roi_heads.line_predictor = line_predictor
-
-
-class LineRCNNHeads(nn.Sequential):
-    def __init__(self, input_channels, num_class):
-        super(LineRCNNHeads, self).__init__()
-        # print("输入的维度是:", input_channels)
-        m = int(input_channels / 4)
-        heads = []
-        self.head_size = [[2], [1], [2]]
-        for output_channels in sum(self.head_size, []):
-            heads.append(
-                nn.Sequential(
-                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
-                    nn.ReLU(inplace=True),
-                    nn.Conv2d(m, output_channels, kernel_size=1),
-                )
-            )
-        self.heads = nn.ModuleList(heads)
-        assert num_class == sum(sum(self.head_size, []))
-
-    def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=1)
-    # def __init__(self, in_channels, layers):
-    #     d = []
-    #     next_feature = in_channels
-    #     for out_channels in layers:
-    #         d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
-    #         d.append(nn.ReLU(inplace=True))
-    #         next_feature = out_channels
-    #     super().__init__(*d)
-    #     for m in self.children():
-    #         if isinstance(m, nn.Conv2d):
-    #             nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-    #             nn.init.constant_(m.bias, 0)
-
-
-class LineRCNNPredictor(nn.Module):
-    def __init__(self):
-        super().__init__()
-        # self.backbone = backbone
-        # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml')
-        self.n_pts0 = self.cfg['model']['n_pts0']
-        self.n_pts1 = self.cfg['model']['n_pts1']
-        self.n_stc_posl = self.cfg['model']['n_stc_posl']
-        self.dim_loi = self.cfg['model']['dim_loi']
-        self.use_conv = self.cfg['model']['use_conv']
-        self.dim_fc = self.cfg['model']['dim_fc']
-        self.n_out_line = self.cfg['model']['n_out_line']
-        self.n_out_junc = self.cfg['model']['n_out_junc']
-        self.loss_weight = self.cfg['model']['loss_weight']
-        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
-        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
-        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
-        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
-        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
-        self.use_cood = self.cfg['model']['use_cood']
-        self.use_slop = self.cfg['model']['use_slop']
-        self.n_stc_negl = self.cfg['model']['n_stc_negl']
-        self.head_size = self.cfg['model']['head_size']
-        self.num_class = sum(sum(self.head_size, []))
-        self.head_off = np.cumsum([sum(h) for h in self.head_size])
-
-        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
-
-        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
-        scale_factor = self.n_pts0 // self.n_pts1
-        if self.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(self.dim_loi, self.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, 1),
-            )
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    def forward(self, inputs, features, targets=None):
-
-        # outputs, features = input
-        # for out in outputs:
-        #     print(f'out:{out.shape}')
-        # outputs=merge_features(outputs,100)
-        batch, channel, row, col = inputs.shape
-        # print(f'outputs:{inputs.shape}')
-        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
-
-        if targets is not None:
-            self.training = True
-            # print(f'target:{targets}')
-            wires_targets = [t["wires"] for t in targets]
-            # print(f'wires_target:{wires_targets}')
-            # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
-            junc_maps = [d["junc_map"] for d in wires_targets]
-            junc_offsets = [d["junc_offset"] for d in wires_targets]
-            line_maps = [d["line_map"] for d in wires_targets]
-
-            junc_map_tensor = torch.stack(junc_maps, dim=0)
-            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-            line_map_tensor = torch.stack(line_maps, dim=0)
-
-            wires_meta = {
-                "junc_map": junc_map_tensor,
-                "junc_offset": junc_offset_tensor,
-                # "line_map": line_map_tensor,
-            }
-        else:
-            self.training = False
-            t = {
-                "junc_coords": torch.zeros(1, 2),
-                "jtyp": torch.zeros(1, dtype=torch.uint8),
-                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-            }
-            wires_targets = [t for b in range(inputs.size(0))]
-
-            wires_meta = {
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-            }
-
-        T = wires_meta.copy()
-        n_jtyp = T["junc_map"].shape[1]
-        offset = self.head_off
-        result = {}
-        for stack, output in enumerate([inputs]):
-            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
-            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
-            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-
-            if stack == 0:
-                result["preds"] = {
-                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                    "lmap": lmap.sigmoid(),
-                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-                }
-                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
-                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
-                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
-
-        h = result["preds"]
-        # print(f'features shape:{features.shape}')
-        x = self.fc1(features)
-
-        # print(f'x:{x.shape}')
-
-        n_batch, n_channel, row, col = x.shape
-
-        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
-
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-
-        for i, meta in enumerate(wires_targets):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i],
-            )
-            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
-            ys.append(label)
-            if self.training and self.do_static_sampling:
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                    .reshape(n_channel, -1, self.n_pts0)
-                    .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            # print(f'xp.shape:{xp.shape}')
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-            # print(f'idx__:{idx}')
-
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-
-        # print("Weight dtype:", self.fc2.weight.dtype)
-        x = torch.cat([x, f], 1)
-        # print("Input dtype:", x.dtype)
-        x = x.to(dtype=torch.float32)
-        # print("Input dtype1:", x.dtype)
-        x = self.fc2(x).flatten()
-
-        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
-        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
-
-        # if mode != "training":
-        # self.inference(x, idx, jcs, n_batch, ps)
-
-        # return result
-
-    def sample_lines(self, meta, jmap, joff):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = self.n_dyn_junc // n_type
-            N = len(junc)
-            # if mode != "training":
-            if not self.training:
-                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            # if mode == "training":
-            if self.training:
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > self.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > self.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            feat = torch.cat(
-                [
-                    xyu / 128 * self.use_cood,
-                    xyv / 128 * self.use_cood,
-                    u2v * self.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-    # def forward(self, result, targets=None):
-    #
-    #     # result = self.backbone(input_dict)
-    #     h = result["preds"]
-    #     x = self.fc1(result["feature"])
-    #     n_batch, n_channel, row, col = x.shape
-    #
-    #     if targets is not None:
-    #         self.training = True
-    #         # print(f'target:{targets}')
-    #         wires_targets = [t["wires"] for t in targets]
-    #         # print(f'wires_target:{wires_targets}')
-    #         # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
-    #         junc_maps = [d["junc_map"] for d in wires_targets]
-    #         junc_offsets = [d["junc_offset"] for d in wires_targets]
-    #         line_maps = [d["line_map"] for d in wires_targets]
-    #
-    #         junc_map_tensor = torch.stack(junc_maps, dim=0)
-    #         junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-    #         line_map_tensor = torch.stack(line_maps, dim=0)
-    #
-    #         wires_meta = {
-    #             "junc_map": junc_map_tensor,
-    #             "junc_offset": junc_offset_tensor,
-    #             # "line_map": line_map_tensor,
-    #         }
-    #     else:
-    #         self.training = False
-    #         # self.training = False
-    #         t = {
-    #             "junc_coords": torch.zeros(1, 2).to(device),
-    #             "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
-    #             "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-    #             "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-    #             "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
-    #             "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
-    #         }
-    #         wires_targets = [t for b in range(inputs.size(0))]
-    #
-    #         wires_meta = {
-    #             "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
-    #             "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
-    #         }
-    #
-    #     xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-    #     for i, meta in enumerate(input_dict["meta"]):
-    #         p, label, feat, jc = self.sample_lines(
-    #             meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
-    #         )
-    #         # print("p.shape:", p.shape)
-    #         ys.append(label)
-    #         if input_dict["mode"] == "training" and self.do_static_sampling:
-    #             p = torch.cat([p, meta["lpre"]])
-    #             feat = torch.cat([feat, meta["lpre_feat"]])
-    #             ys.append(meta["lpre_label"])
-    #             del jc
-    #         else:
-    #             jcs.append(jc)
-    #             ps.append(p)
-    #         fs.append(feat)
-    #
-    #         p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-    #         p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-    #         px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-    #         px0 = px.floor().clamp(min=0, max=127)
-    #         py0 = py.floor().clamp(min=0, max=127)
-    #         px1 = (px0 + 1).clamp(min=0, max=127)
-    #         py1 = (py0 + 1).clamp(min=0, max=127)
-    #         px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-    #
-    #         # xp: [N_LINE, N_CHANNEL, N_POINT]
-    #         xp = (
-    #             (
-    #                     x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-    #                     + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-    #                     + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-    #                     + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-    #             )
-    #                 .reshape(n_channel, -1, M.n_pts0)
-    #                 .permute(1, 0, 2)
-    #         )
-    #         xp = self.pooling(xp)
-    #         xs.append(xp)
-    #         idx.append(idx[-1] + xp.shape[0])
-    #
-    #
-    #     x, y = torch.cat(xs), torch.cat(ys)
-    #     f = torch.cat(fs)
-    #     x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-    #     x = torch.cat([x, f], 1)
-    #     x = x.to(dtype=torch.float32)
-    #     x = self.fc2(x).flatten()
-    #
-    #     # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
-    #     all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
-    #     return all
-    #     # return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
-    #
-    #     # if mode != "training":
-    #     # self.inference(x, idx, jcs, n_batch, ps)
-    #
-    #     # return result
-    #
-    # def sample_lines(self, meta, jmap, joff):
-    #     with torch.no_grad():
-    #         junc = meta["junc_coords"]  # [N, 2]
-    #         jtyp = meta["jtyp"]  # [N]
-    #         Lpos = meta["line_pos_idx"]
-    #         Lneg = meta["line_neg_idx"]
-    #
-    #         n_type = jmap.shape[0]
-    #         jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-    #         joff = joff.reshape(n_type, 2, -1)
-    #         max_K = self.n_dyn_junc // n_type
-    #         N = len(junc)
-    #         # if mode != "training":
-    #         if not self.training:
-    #             K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
-    #         else:
-    #             K = min(int(N * 2 + 2), max_K)
-    #         if K < 2:
-    #             K = 2
-    #         device = jmap.device
-    #
-    #         # index: [N_TYPE, K]
-    #         score, index = torch.topk(jmap, k=K)
-    #         y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-    #         x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-    #
-    #         # xy: [N_TYPE, K, 2]
-    #         xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-    #         xy_ = xy[..., None, :]
-    #         del x, y, index
-    #
-    #         # print(f"xy_.is_cuda: {xy_.is_cuda}")
-    #         # print(f"junc.is_cuda: {junc.is_cuda}")
-    #
-    #         # dist: [N_TYPE, K, N]
-    #         dist = torch.sum((xy_ - junc) ** 2, -1)
-    #         cost, match = torch.min(dist, -1)
-    #
-    #         # xy: [N_TYPE * K, 2]
-    #         # match: [N_TYPE, K]
-    #         for t in range(n_type):
-    #             match[t, jtyp[match[t]] != t] = N
-    #         match[cost > 1.5 * 1.5] = N
-    #         match = match.flatten()
-    #
-    #         _ = torch.arange(n_type * K, device=device)
-    #         u, v = torch.meshgrid(_, _)
-    #         u, v = u.flatten(), v.flatten()
-    #         up, vp = match[u], match[v]
-    #         label = Lpos[up, vp]
-    #
-    #         # if mode == "training":
-    #         if self.training:
-    #             c = torch.zeros_like(label, dtype=torch.bool)
-    #
-    #             # sample positive lines
-    #             cdx = label.nonzero().flatten()
-    #             if len(cdx) > self.n_dyn_posl:
-    #                 # print("too many positive lines")
-    #                 perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
-    #                 cdx = cdx[perm]
-    #             c[cdx] = 1
-    #
-    #             # sample negative lines
-    #             cdx = Lneg[up, vp].nonzero().flatten()
-    #             if len(cdx) > self.n_dyn_negl:
-    #                 # print("too many negative lines")
-    #                 perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
-    #                 cdx = cdx[perm]
-    #             c[cdx] = 1
-    #
-    #             # sample other (unmatched) lines
-    #             cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
-    #             c[cdx] = 1
-    #         else:
-    #             c = (u < v).flatten()
-    #
-    #         # sample lines
-    #         u, v, label = u[c], v[c], label[c]
-    #         xy = xy.reshape(n_type * K, 2)
-    #         xyu, xyv = xy[u], xy[v]
-    #
-    #         u2v = xyu - xyv
-    #         u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-    #         feat = torch.cat(
-    #             [
-    #                 xyu / 128 * self.use_cood,
-    #                 xyv / 128 * self.use_cood,
-    #                 u2v * self.use_slop,
-    #                 (u[:, None] > K).float(),
-    #                 (v[:, None] > K).float(),
-    #             ],
-    #             1,
-    #         )
-    #         line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-    #
-    #         xy = xy.reshape(n_type, K, 2)
-    #         jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-    #         return line, label.float(), feat, jcs
-
-
-
-
-_COMMON_META = {
-    "categories": _COCO_PERSON_CATEGORIES,
-    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
-    "min_size": (1, 1),
-}
-
-
-class LineRCNN_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_LEGACY = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=ObjectDetection,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/issues/1606",
-            "_metrics": {
-                "COCO-val2017": {
-                    "box_map": 50.6,
-                    "kp_map": 61.1,
-                }
-            },
-            "_ops": 133.924,
-            "_file_size": 226.054,
-            "_docs": """
-                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
-                from an early epoch.
-            """,
-        },
-    )
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=ObjectDetection,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
-            "_metrics": {
-                "COCO-val2017": {
-                    "box_map": 54.6,
-                    "kp_map": 65.0,
-                }
-            },
-            "_ops": 137.42,
-            "_file_size": 226.054,
-            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@register_model()
-@handle_legacy_interface(
-    weights=(
-            "pretrained",
-            lambda kwargs: LineRCNN_ResNet50_FPN_Weights.COCO_LEGACY
-            if kwargs["pretrained"] == "legacy"
-            else LineRCNN_ResNet50_FPN_Weights.COCO_V1,
-    ),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def linercnn_resnet50_fpn(
-        *,
-        weights: Optional[LineRCNN_ResNet50_FPN_Weights] = None,
-        progress: bool = True,
-        num_classes: Optional[int] = None,
-        num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
-        trainable_backbone_layers: Optional[int] = None,
-        **kwargs: Any,
-) -> LineRCNN:
-    """
-    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
-
-    .. betastatus:: detection module
-
-    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
-
-    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
-    image, and should be in ``0-1`` range. Different images can have different sizes.
-
-    The behavior of the model changes depending on if it is in training or evaluation mode.
-
-    During training, the model expects both the input tensors and targets (list of dictionary),
-    containing:
-
-        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
-        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
-          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
-
-    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
-    losses for both the RPN and the R-CNN, and the keypoint loss.
-
-    During inference, the model requires only the input tensors, and returns the post-processed
-    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
-    follows, where ``N`` is the number of detected instances:
-
-        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (``Int64Tensor[N]``): the predicted labels for each instance
-        - scores (``Tensor[N]``): the scores or each instance
-        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
-
-    For more details on the output, you may refer to :ref:`instance_seg_output`.
-
-    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
-
-    Example::
-
-        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-        >>>
-        >>> # optionally, if you want to export the model to ONNX:
-        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
-
-    Args:
-        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
-            below for more details, and possible values. By default, no
-            pre-trained weights are used.
-        progress (bool): If True, displays a progress bar of the download to stderr
-        num_classes (int, optional): number of output classes of the model (including the background)
-        num_keypoints (int, optional): number of keypoints
-        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
-            pretrained weights for the backbone.
-        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
-            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
-            passed (the default) this value is set to 3.
-
-    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
-        :members:
-    """
-    weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
-        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
-    else:
-        if num_classes is None:
-            num_classes = 2
-        if num_keypoints is None:
-            num_keypoints = 17
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = LineRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
-        if weights == LineRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model

+ 13 - 6
models/line_detect/roi_heads.py

@@ -6,7 +6,7 @@ import torchvision
 from torch import nn, Tensor
 from torchvision.ops import boxes as box_ops, roi_align
 
-from libs.vision_libs.models.detection import _utils as det_utils
+import libs.vision_libs.models.detection._utils as det_utils
 
 from collections import OrderedDict
 
@@ -182,17 +182,17 @@ def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
     L = OrderedDict()
     L["junc_map"] = sum(
         cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-    )
+    ).mean()
     L["line_map"] = (
         F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
             .mean(2)
             .mean(1)
-    )
+    ).mean()
     L["junc_offset"] = sum(
         sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
         for i in range(n_jtyp)
         for j in range(2)
-    )
+    ).mean()
     for loss_name in L:
         L[loss_name].mul_(loss_weight[loss_name])
     losses.append(L)
@@ -209,8 +209,8 @@ def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
 
     lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
     lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
-    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+    result["losses"][0]["lpos"] = (lpos * loss_weight["lpos"]).mean()
+    result["losses"][0]["lneg"] = (lneg * loss_weight["lneg"]).mean()
 
     return result
 
@@ -1053,6 +1053,7 @@ class RoIHeads(nn.Module):
 
         features_lcnn = features['0']
         if self.has_line():
+            # print('has line_head')
             outputs = self.line_head(features_lcnn)
             loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
@@ -1072,10 +1073,16 @@ class RoIHeads(nn.Module):
                 rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
                 loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
             else:
+
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                 result.append(pred)
                 loss_wirepoint = {}
             losses.update(loss_wirepoint)
+        else:
+            pass
+            # print('no line_head')
+
+
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
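
The .mean() calls added in this file reduce each per-image loss tensor to a scalar before it is weighted, so the entries of the loss dict can later be summed into one value for backpropagation. A minimal sketch of that reduce-and-weight step (toy tensors, not the project's real losses):

import torch
from collections import OrderedDict

loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25}
per_image = {name: torch.rand(4) for name in loss_weight}  # one value per image in the batch

# Reduce each loss to a scalar, then apply the weights used in the diff above.
L = OrderedDict((name, t.mean() * loss_weight[name]) for name, t in per_image.items())
total = sum(L.values())  # a 0-dim tensor, safe to call backward() on
print({k: round(float(v), 4) for k, v in L.items()}, float(total))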

+ 111 - 0
predict.py

@@ -0,0 +1,111 @@
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+from torchvision import transforms
+
+
+def box_line(pred):
+    '''
+    :param pred: prediction results
+    :return:
+
+    Each box is paired with its single best-matching line, e.g.
+{'box': [0.0, 34.23157501220703, 151.70858764648438, 125.10173797607422], 'line': array([[ 1.9720564, 81.73457  ],
+[ 1.9933801, 41.730167 ]], dtype=float32)}
+    '''
+    box_line = [[] for _ in range(len(pred) - 1)]
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        for i in box:
+            aaa = {}
+            aaa['box'] = i.tolist()
+            aaa['line'] = []
+            score_max = 0.0
+            for j in range(len(line)):
+                if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
+                        line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
+                        line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
+                    if score[j] > score_max:
+                        aaa['line'] = line[j]
+                        score_max = score[j]
+            box_line[idx].append(aaa)
+    return box_line
+
+
+def box_line_(pred):
+    '''
+    Same structure as pred; adds a 'line' entry to each image's predictions
+    '''
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        line_ = []
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+            for j in range(len(line)):
+                if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
+                        line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
+                        line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+    return pred
+
+
+def show_(imgs, pred, epoch, writer):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    print(len(col))
+    im = imgs[0].permute(1, 2, 0)
+    boxes = pred[0]['boxes'].cpu().numpy()
+    line = pred[0]['line'].cpu().numpy()
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    for idx, box in enumerate(boxes):
+        x0, y0, x1, y1 = box
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+
+    for idx, (a, b) in enumerate(line):
+        ax.scatter(a[0], a[1], c=col[99 - idx], s=2)
+        ax.scatter(b[0], b[1], c=col[99 - idx], s=2)
+        ax.plot([a[0], b[0]], [a[1], b[1]], c=col[idx], linewidth=1)
+
+    # Convert the Matplotlib figure to a tensor
+    fig.canvas.draw()
+    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+        fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    img2 = transforms.ToTensor()(image_from_plot)
+
+    writer.add_image("all", img2, epoch)
+
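
box_line_ assigns to each detected box the highest-scoring predicted line whose two endpoints both fall inside that box. A minimal standalone sketch of that matching rule (hypothetical names; endpoint coordinates assumed to be in the same (x, y) frame as the boxes):

import numpy as np

def endpoints_in_box(line: np.ndarray, box) -> bool:
    """line: (2, 2) array of endpoints; box: (x0, y0, x1, y1)."""
    x0, y0, x1, y1 = box
    return bool(np.all((line[:, 0] >= x0) & (line[:, 0] <= x1) &
                       (line[:, 1] >= y0) & (line[:, 1] <= y1)))

def best_line_for_box(lines, scores, box):
    best, best_score = np.zeros((2, 2), dtype=np.float32), 0.0
    for line, score in zip(lines, scores):
        if score > best_score and endpoints_in_box(line, box):
            best, best_score = line, float(score)
    return best, best_score

lines = np.array([[[10, 20], [30, 40]], [[5, 5], [200, 200]]], dtype=np.float32)
scores = np.array([0.9, 0.7])
print(best_line_for_box(lines, scores, (0, 0, 50, 50)))  # picks the first line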

+ 142 - 188
train——line_rcnn.py

@@ -1,166 +1,3 @@
-# Training script based on LCNN    2025/2/7
-'''
-#!/usr/bin/env python3
-import datetime
-import glob
-import os
-import os.path as osp
-import platform
-import pprint
-import random
-import shlex
-import shutil
-import subprocess
-import sys
-import numpy as np
-import torch
-import torchvision
-import yaml
-import lcnn
-from lcnn.config import C, M
-from lcnn.datasets import WireframeDataset, collate
-from lcnn.models.line_vectorizer import LineVectorizer
-from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
-from torchvision.models import resnet50
-
-from models.line_detect.line_rcnn import linercnn_resnet50_fpn
-
-
-
-def main():
-
-    # Training configuration
-    config = {
-        # Dataset configuration
-        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
-        'config_file': 'config/wireframe.yaml',  # config file path
-
-        # GPU configuration
-        'devices': '0',  # GPU devices to use
-        'identifier': 'fasterrcnn_resnet50',  # run identifier: stacked_hourglass, unet
-
-        # Path to pre-trained weights
-        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',  # pre-trained model path
-    }
-
-    # Update the config
-    C.update(C.from_yaml(filename=config['config_file']))
-    M.update(C.model)
-
-    # Set random seeds
-    random.seed(0)
-    np.random.seed(0)
-    torch.manual_seed(0)
-
-    # Device configuration
-    device_name = "cpu"
-    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
-
-    if torch.cuda.is_available():
-        device_name = "cuda"
-        torch.backends.cudnn.deterministic = True
-        torch.cuda.manual_seed(0)
-        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
-    else:
-        print("CUDA is not available")
-
-    device = torch.device(device_name)
-
-    # Data loading
-    kwargs = {
-        "collate_fn": collate,
-        "num_workers": C.io.num_workers if os.name != "nt" else 0,
-        "pin_memory": True,
-    }
-
-    train_loader = torch.utils.data.DataLoader(
-        WireframeDataset(config['datadir'], dataset_type="train"),
-        shuffle=True,
-        batch_size=M.batch_size,
-        **kwargs,
-    )
-
-    val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(config['datadir'], dataset_type="val"),
-        shuffle=False,
-        batch_size=M.batch_size_eval,
-        **kwargs,
-    )
-
-    model = linercnn_resnet50_fpn().to(device)
-
-    # Load pre-trained weights
-
-    try:
-        # Load the model weights
-        checkpoint = torch.load(config['pretrained_model'], map_location=device)
-
-        # Choose how to load based on the actual checkpoint structure
-        if 'model_state_dict' in checkpoint:
-            # Full checkpoint
-            model.load_state_dict(checkpoint['model_state_dict'])
-        elif 'state_dict' in checkpoint:
-            # Checkpoint containing only a state dict
-            model.load_state_dict(checkpoint['state_dict'])
-        else:
-            # Load the weight dict directly
-            model.load_state_dict(checkpoint)
-
-        print("Successfully loaded pre-trained model weights.")
-    except Exception as e:
-        print(f"Error loading model weights: {e}")
-
-
-    # Optimizer configuration
-    if C.optim.name == "Adam":
-        optim = torch.optim.Adam(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=C.optim.lr,
-            weight_decay=C.optim.weight_decay,
-            amsgrad=C.optim.amsgrad,
-        )
-    elif C.optim.name == "SGD":
-        optim = torch.optim.SGD(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=C.optim.lr,
-            weight_decay=C.optim.weight_decay,
-            momentum=C.optim.momentum,
-        )
-    else:
-        raise NotImplementedError
-
-    # Output directory
-    outdir = osp.join(
-        osp.expanduser(C.io.logdir),
-        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
-    )
-    os.makedirs(outdir, exist_ok=True)
-
-    try:
-        trainer = lcnn.trainer.Trainer(
-            device=device,
-            model=model,
-            optimizer=optim,
-            train_loader=train_loader,
-            val_loader=val_loader,
-            out=outdir,
-        )
-
-        print("Starting training...")
-        trainer.train()
-        print("Training completed.")
-
-    except BaseException:
-        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
-            shutil.rmtree(outdir)
-        raise
-
-
-if __name__ == "__main__":
-    main()
-'''
-
-
 # 2025/2/9
 import os
 from typing import Optional, Any
@@ -178,12 +15,17 @@ import matplotlib.pyplot as plt
 import matplotlib as mpl
 from skimage import io
 
-from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+from models.line_detect.line_net import linenet_resnet50_fpn
 from torchvision.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
 from collections import OrderedDict
 
+from PIL import Image
+
+from predict import box_line_, show_
+import matplotlib.pyplot as plt
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
@@ -221,7 +63,7 @@ def imshow(im):
     plt.ylim([im.shape[0], 0])
 
 
-def show_line(img, pred,  epoch, writer):
+def show_line(img, pred, epoch, writer):
     im = img.permute(1, 2, 0)
     writer.add_image("ori", im, epoch, dataformats="HWC")
 
@@ -230,7 +72,8 @@ def show_line(img, pred,  epoch, writer):
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    H = pred[1]['wires']
+    # print(f'pred[1]:{pred[1]}')
+    H = pred[-1]['wires']
     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
     scores = H["score"][0].cpu().numpy()
     for i in range(1, len(lines)):
@@ -243,7 +86,7 @@ def show_line(img, pred,  epoch, writer):
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
 
-    for i, t in enumerate([0.8]):
+    for i, t in enumerate([0.85]):
         plt.gca().set_axis_off()
         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
         plt.margins(0, 0)
@@ -267,8 +110,99 @@ def show_line(img, pred,  epoch, writer):
         writer.add_image("output", img2, epoch)
 
 
+def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    if current_loss < best_loss:
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'loss': current_loss
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)
+        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
+
+        return current_loss
+
+    return best_loss
+
+
+def save_latest_model(model, save_path, epoch, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+    }
+
+    if optimizer is not None:
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+    torch.save(checkpoint, save_path)
+
+
+def load_best_model(model, optimizer, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        if optimizer is not None:
+            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model, optimizer
+
+
+def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
+    self.load_weight('weights/best.pt')
+    self.__model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    # Preprocess the image
+    img_tensor = self.transforms(img)
+    with torch.no_grad():
+        predictions = self.__model([img_tensor])
+
+    # Post-process the predictions
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    keypoints = predictions[0]['keypoints'].cpu().numpy()
+
+    # Visualize the predictions
+    if show_boxes or show_keypoint or show_line or save:
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(np.array(img))
+
+        if show_boxes:
+            for box in boxes:
+                x0, y0, x1, y1 = box
+                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
+
+        for (a, b) in keypoints:
+            if show_keypoint:
+                ax.scatter(a[0], a[1], c='c', s=2)
+                ax.scatter(b[0], b[1], c='c', s=2)
+            if show_line:
+                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
+
+        if show_boxes or show_keypoint or show_line:
+            plt.show()
+
+        if save:
+            fig.savefig(save_path)
+            print(f"Prediction saved to {save_path}")
+        plt.close(fig)
+
+
 if __name__ == '__main__':
-    cfg = r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml'
+    cfg = r'./config/wireframe.yaml'
     cfg = read_yaml(cfg)
     print(f'cfg:{cfg}')
     print(cfg['model']['n_dyn_negl'])
@@ -277,26 +211,36 @@ if __name__ == '__main__':
     dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
     train_sampler = torch.utils.data.RandomSampler(dataset_train)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
     train_collate_fn = utils.collate_fn_wirepoint
     data_loader_train = torch.utils.data.DataLoader(
-        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
     )
 
     dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
     val_sampler = torch.utils.data.RandomSampler(dataset_val)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
     val_collate_fn = utils.collate_fn_wirepoint
     data_loader_val = torch.utils.data.DataLoader(
-        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
     )
 
-    model = linercnn_resnet50_fpn().to(device)
-
+    model = linenet_resnet50_fpn().to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
 
+    # Load saved weights
+    save_path = 'logs/pth/best_model.pth'
+    model, optimizer = load_best_model(model, optimizer, save_path, device)
+
+    logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
+    os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it doesn't exist
+    latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
+    best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+    global_step = 0
+
+
     def move_to_device(data, device):
         if isinstance(data, (list, tuple)):
             return type(data)(move_to_device(item, device) for item in data)
@@ -308,19 +252,6 @@ if __name__ == '__main__':
             return data  # leave non-tensor data unchanged
 
 
-    # def writer_loss(writer, losses, epoch):
-    #     try:
-    #         for key, value in losses.items():
-    #             if key == 'loss_wirepoint':
-    #                 for subdict in losses['loss_wirepoint']['losses']:
-    #                     for subkey, subvalue in subdict.items():
-    #                         writer.add_scalar(f'loss_wirepoint/{subkey}',
-    #                                           subvalue.item() if hasattr(subvalue, 'item') else subvalue,
-    #                                           epoch)
-    #             elif isinstance(value, torch.Tensor):
-    #                 writer.add_scalar(key, value.item(), epoch)
-    #     except Exception as e:
-    #         print(f"TensorBoard logging error: {e}")
     def writer_loss(writer, losses, epoch):
         try:
             for key, value in losses.items():
@@ -339,11 +270,13 @@ if __name__ == '__main__':
     for epoch in range(cfg['optim']['max_epoch']):
         print(f"epoch:{epoch}")
         model.train()
+        total_train_loss = 0.0
 
         for imgs, targets in data_loader_train:
             losses = model(move_to_device(imgs, device), move_to_device(targets, device))
             # print(losses)
             loss = _loss(losses)
+            total_train_loss += loss.item()
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
@@ -353,8 +286,29 @@ if __name__ == '__main__':
         with torch.no_grad():
             for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                 pred = model(move_to_device(imgs, device))
+
+                pred_ = box_line_(pred)  # match each box with a line
+                show_(imgs, pred_, epoch, writer)
+
                 if batch_idx == 0:
                     show_line(imgs[0], pred, epoch, writer)
-                break
 
+                break
 
+        avg_train_loss = total_train_loss / len(data_loader_train)
+        writer.add_scalar('loss/train', avg_train_loss, epoch)
+        best_loss = 10000
+        save_latest_model(
+            model,
+            latest_model_path,
+            epoch,
+            optimizer
+        )
+        best_loss = save_best_model(
+            model,
+            best_model_path,
+            epoch,
+            avg_train_loss,
+            best_loss,
+            optimizer
+        )
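
The training loop above keeps two checkpoints per run: latest_model.pth is overwritten every epoch, while best_model.pth is rewritten only when the average training loss improves. A minimal sketch of that pattern with the best loss initialized once and carried across epochs (hypothetical callback names, not the script's exact code):

import torch

def train_with_checkpoints(model, optimizer, epochs, compute_epoch_loss,
                           latest_path="latest_model.pth", best_path="best_model.pth"):
    best_loss = float("inf")  # set once, before the epoch loop
    for epoch in range(epochs):
        avg_loss = compute_epoch_loss(epoch)  # assumed to run one training epoch and return its mean loss
        # Always refresh the "latest" checkpoint.
        torch.save({"epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict()}, latest_path)
        # Refresh the "best" checkpoint only on improvement.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({"epoch": epoch, "loss": avg_loss,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()}, best_path)
    return best_loss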