Bladeren bron

删除多余代码文件

RenLiqiang 3 maanden geleden
bovenliggende
commit
8a49ca689a

+ 0 - 4
lcnn/__init__.py

@@ -1,4 +0,0 @@
-import lcnn.models
-import lcnn.trainer
-import lcnn.datasets
-import lcnn.config

+ 0 - 1110
lcnn/box.py

@@ -1,1110 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: UTF-8 -*-
-#
-# Copyright (c) 2017-2019 - Chris Griffith - MIT License
-"""
-Improved dictionary access through dot notation with additional tools.
-"""
-import string
-import sys
-import json
-import re
-import copy
-from keyword import kwlist
-import warnings
-
-try:
-    from collections.abc import Iterable, Mapping, Callable
-except ImportError:
-    from collections import Iterable, Mapping, Callable
-
-yaml_support = True
-
-try:
-    import yaml
-except ImportError:
-    try:
-        import ruamel.yaml as yaml
-    except ImportError:
-        yaml = None
-        yaml_support = False
-
-if sys.version_info >= (3, 0):
-    basestring = str
-else:
-    from io import open
-
-__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox',
-           'BoxError', 'BoxKeyError']
-__author__ = 'Chris Griffith'
-__version__ = '3.2.4'
-
-BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box',
-                  'frozen_box', 'camel_killer_box', 'box_it_up',
-                  'box_safe_prefix', 'box_duplicates', 'ordered_box')
-
-_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
-_all_cap_re = re.compile('([a-z0-9])([A-Z])')
-
-
-class BoxError(Exception):
-    """Non standard dictionary exceptions"""
-
-
-class BoxKeyError(BoxError, KeyError, AttributeError):
-    """Key does not exist"""
-
-
-# Abstract converter functions for use in any Box class
-
-
-def _to_json(obj, filename=None,
-             encoding="utf-8", errors="strict", **json_kwargs):
-    json_dump = json.dumps(obj,
-                           ensure_ascii=False, **json_kwargs)
-    if filename:
-        with open(filename, 'w', encoding=encoding, errors=errors) as f:
-            f.write(json_dump if sys.version_info >= (3, 0) else
-                    json_dump.decode("utf-8"))
-    else:
-        return json_dump
-
-
-def _from_json(json_string=None, filename=None,
-               encoding="utf-8", errors="strict", multiline=False, **kwargs):
-    if filename:
-        with open(filename, 'r', encoding=encoding, errors=errors) as f:
-            if multiline:
-                data = [json.loads(line.strip(), **kwargs) for line in f
-                        if line.strip() and not line.strip().startswith("#")]
-            else:
-                data = json.load(f, **kwargs)
-    elif json_string:
-        data = json.loads(json_string, **kwargs)
-    else:
-        raise BoxError('from_json requires a string or filename')
-    return data
-
-
-def _to_yaml(obj, filename=None, default_flow_style=False,
-             encoding="utf-8", errors="strict",
-             **yaml_kwargs):
-    if filename:
-        with open(filename, 'w',
-                  encoding=encoding, errors=errors) as f:
-            yaml.dump(obj, stream=f,
-                      default_flow_style=default_flow_style,
-                      **yaml_kwargs)
-    else:
-        return yaml.dump(obj,
-                         default_flow_style=default_flow_style,
-                         **yaml_kwargs)
-
-
-def _from_yaml(yaml_string=None, filename=None,
-               encoding="utf-8", errors="strict",
-               **kwargs):
-    if filename:
-        with open(filename, 'r',
-                  encoding=encoding, errors=errors) as f:
-            data = yaml.load(f, **kwargs)
-    elif yaml_string:
-        data = yaml.load(yaml_string, **kwargs)
-    else:
-        raise BoxError('from_yaml requires a string or filename')
-    return data
-
-
-# Helper functions
-
-
-def _safe_key(key):
-    try:
-        return str(key)
-    except UnicodeEncodeError:
-        return key.encode("utf-8", "ignore")
-
-
-def _safe_attr(attr, camel_killer=False, replacement_char='x'):
-    """Convert a key into something that is accessible as an attribute"""
-    allowed = string.ascii_letters + string.digits + '_'
-
-    attr = _safe_key(attr)
-
-    if camel_killer:
-        attr = _camel_killer(attr)
-
-    attr = attr.replace(' ', '_')
-
-    out = ''
-    for character in attr:
-        out += character if character in allowed else "_"
-    out = out.strip("_")
-
-    try:
-        int(out[0])
-    except (ValueError, IndexError):
-        pass
-    else:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    if out in kwlist:
-        out = '{0}{1}'.format(replacement_char, out)
-
-    return re.sub('_+', '_', out)
-
-
-def _camel_killer(attr):
-    """
-    CamelKiller, qu'est-ce que c'est?
-
-    Taken from http://stackoverflow.com/a/1176023/3244542
-    """
-    try:
-        attr = str(attr)
-    except UnicodeEncodeError:
-        attr = attr.encode("utf-8", "ignore")
-
-    s1 = _first_cap_re.sub(r'\1_\2', attr)
-    s2 = _all_cap_re.sub(r'\1_\2', s1)
-    return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else
-                  s2.lower())
-
-
-def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs):
-    out_list = []
-    for i in iterable:
-        if isinstance(i, dict):
-            out_list.append(box_class(i, **kwargs))
-        elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)):
-            out_list.append(_recursive_tuples(i, box_class,
-                                              recreate_tuples, **kwargs))
-        else:
-            out_list.append(i)
-    return tuple(out_list)
-
-
-def _conversion_checks(item, keys, box_config, check_only=False,
-                       pre_check=False):
-    """
-    Internal use for checking if a duplicate safe attribute already exists
-
-    :param item: Item to see if a dup exists
-    :param keys: Keys to check against
-    :param box_config: Easier to pass in than ask for specfic items
-    :param check_only: Don't bother doing the conversion work
-    :param pre_check: Need to add the item to the list of keys to check
-    :return: the original unmodified key, if exists and not check_only
-    """
-    if box_config['box_duplicates'] != 'ignore':
-        if pre_check:
-            keys = list(keys) + [item]
-
-        key_list = [(k,
-                     _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                                replacement_char=box_config['box_safe_prefix']
-                                )) for k in keys]
-        if len(key_list) > len(set(x[1] for x in key_list)):
-            seen = set()
-            dups = set()
-            for x in key_list:
-                if x[1] in seen:
-                    dups.add("{0}({1})".format(x[0], x[1]))
-                seen.add(x[1])
-            if box_config['box_duplicates'].startswith("warn"):
-                warnings.warn('Duplicate conversion attributes exist: '
-                              '{0}'.format(dups))
-            else:
-                raise BoxError('Duplicate conversion attributes exist: '
-                               '{0}'.format(dups))
-    if check_only:
-        return
-    # This way will be slower for warnings, as it will have double work
-    # But faster for the default 'ignore'
-    for k in keys:
-        if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'],
-                              replacement_char=box_config['box_safe_prefix']):
-            return k
-
-
-def _get_box_config(cls, kwargs):
-    return {
-        # Internal use only
-        '__converted': set(),
-        '__box_heritage': kwargs.pop('__box_heritage', None),
-        '__created': False,
-        '__ordered_box_values': [],
-        # Can be changed by user after box creation
-        'default_box': kwargs.pop('default_box', False),
-        'default_box_attr': kwargs.pop('default_box_attr', cls),
-        'conversion_box': kwargs.pop('conversion_box', True),
-        'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'),
-        'frozen_box': kwargs.pop('frozen_box', False),
-        'camel_killer_box': kwargs.pop('camel_killer_box', False),
-        'modify_tuples_box': kwargs.pop('modify_tuples_box', False),
-        'box_duplicates': kwargs.pop('box_duplicates', 'ignore'),
-        'ordered_box': kwargs.pop('ordered_box', False)
-    }
-
-
-class Box(dict):
-    """
-    Improved dictionary access through dot notation with additional tools.
-
-    :param default_box: Similar to defaultdict, return a default value
-    :param default_box_attr: Specify the default replacement.
-        WARNING: If this is not the default 'Box', it will not be recursive
-    :param frozen_box: After creation, the box cannot be modified
-    :param camel_killer_box: Convert CamelCase to snake_case
-    :param conversion_box: Check for near matching keys as attributes
-    :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes
-    :param box_it_up: Recursively create all Boxes from the start
-    :param box_safe_prefix: Conversion box prefix for unsafe attributes
-    :param box_duplicates: "ignore", "error" or "warn" when duplicates exists
-        in a conversion_box
-    :param ordered_box: Preserve the order of keys entered into the box
-    """
-
-    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
-                                 'from_yaml', 'from_json']
-
-    def __new__(cls, *args, **kwargs):
-        """
-        Due to the way pickling works in python 3, we need to make sure
-        the box config is created as early as possible.
-        """
-        obj = super(Box, cls).__new__(cls, *args, **kwargs)
-        obj._box_config = _get_box_config(cls, kwargs)
-        return obj
-
-    def __init__(self, *args, **kwargs):
-        self._box_config = _get_box_config(self.__class__, kwargs)
-        if self._box_config['ordered_box']:
-            self._box_config['__ordered_box_values'] = []
-        if (not self._box_config['conversion_box'] and
-                self._box_config['box_duplicates'] != "ignore"):
-            raise BoxError('box_duplicates are only for conversion_boxes')
-        if len(args) == 1:
-            if isinstance(args[0], basestring):
-                raise ValueError('Cannot extrapolate Box from string')
-            if isinstance(args[0], Mapping):
-                for k, v in args[0].items():
-                    if v is args[0]:
-                        v = self
-                    self[k] = v
-                    self.__add_ordered(k)
-            elif isinstance(args[0], Iterable):
-                for k, v in args[0]:
-                    self[k] = v
-                    self.__add_ordered(k)
-
-            else:
-                raise ValueError('First argument must be mapping or iterable')
-        elif args:
-            raise TypeError('Box expected at most 1 argument, '
-                            'got {0}'.format(len(args)))
-
-        box_it = kwargs.pop('box_it_up', False)
-        for k, v in kwargs.items():
-            if args and isinstance(args[0], Mapping) and v is args[0]:
-                v = self
-            self[k] = v
-            self.__add_ordered(k)
-
-        if (self._box_config['frozen_box'] or box_it or
-                self._box_config['box_duplicates'] != 'ignore'):
-            self.box_it_up()
-
-        self._box_config['__created'] = True
-
-    def __add_ordered(self, key):
-        if (self._box_config['ordered_box'] and
-                key not in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].append(key)
-
-    def box_it_up(self):
-        """
-        Perform value lookup for all items in current dictionary,
-        generating all sub Box objects, while also running `box_it_up` on
-        any of those sub box objects.
-        """
-        for k in self:
-            _conversion_checks(k, self.keys(), self._box_config,
-                               check_only=True)
-            if self[k] is not self and hasattr(self[k], 'box_it_up'):
-                self[k].box_it_up()
-
-    def __hash__(self):
-        if self._box_config['frozen_box']:
-            hashing = 54321
-            for item in self.items():
-                hashing ^= hash(item)
-            return hashing
-        raise TypeError("unhashable type: 'Box'")
-
-    def __dir__(self):
-        allowed = string.ascii_letters + string.digits + '_'
-        kill_camel = self._box_config['camel_killer_box']
-        items = set(dir(dict) + ['to_dict', 'to_json',
-                                 'from_json', 'box_it_up'])
-        # Only show items accessible by dot notation
-        for key in self.keys():
-            key = _safe_key(key)
-            if (' ' not in key and key[0] not in string.digits and
-                    key not in kwlist):
-                for letter in key:
-                    if letter not in allowed:
-                        break
-                else:
-                    items.add(key)
-
-        for key in self.keys():
-            key = _safe_key(key)
-            if key not in items:
-                if self._box_config['conversion_box']:
-                    key = _safe_attr(key, camel_killer=kill_camel,
-                                     replacement_char=self._box_config[
-                                         'box_safe_prefix'])
-                    if key:
-                        items.add(key)
-            if kill_camel:
-                snake_key = _camel_killer(key)
-                if snake_key:
-                    items.remove(key)
-                    items.add(snake_key)
-
-        if yaml_support:
-            items.add('to_yaml')
-            items.add('from_yaml')
-
-        return list(items)
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            if isinstance(default, dict) and not isinstance(default, Box):
-                return Box(default)
-            if isinstance(default, list) and not isinstance(default, BoxList):
-                return BoxList(default)
-            return default
-
-    def copy(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __copy__(self):
-        return self.__class__(super(self.__class__, self).copy())
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k, v in self.items():
-            out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict)
-        return out
-
-    def __setstate__(self, state):
-        self._box_config = state['_box_config']
-        self.__dict__.update(state)
-
-    def __getitem__(self, item, _ignore_default=False):
-        try:
-            value = super(Box, self).__getitem__(item)
-        except KeyError as err:
-            if item == '_box_config':
-                raise BoxKeyError('_box_config should only exist as an '
-                                  'attribute and is never defaulted')
-            if self._box_config['default_box'] and not _ignore_default:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            return self.__convert_and_store(item, value)
-
-    def keys(self):
-        if self._box_config['ordered_box']:
-            return self._box_config['__ordered_box_values']
-        return super(Box, self).keys()
-
-    def values(self):
-        return [self[x] for x in self.keys()]
-
-    def items(self):
-        return [(x, self[x]) for x in self.keys()]
-
-    def __get_default(self, item):
-        default_value = self._box_config['default_box_attr']
-        if default_value is self.__class__:
-            return self.__class__(__box_heritage=(self, item),
-                                  **self.__box_config())
-        elif isinstance(default_value, Callable):
-            return default_value()
-        elif hasattr(default_value, 'copy'):
-            return default_value.copy()
-        return default_value
-
-    def __box_config(self):
-        out = {}
-        for k, v in self._box_config.copy().items():
-            if not k.startswith("__"):
-                out[k] = v
-        return out
-
-    def __convert_and_store(self, item, value):
-        if item in self._box_config['__converted']:
-            return value
-        if isinstance(value, dict) and not isinstance(value, Box):
-            value = self.__class__(value, __box_heritage=(self, item),
-                                   **self.__box_config())
-            self[item] = value
-        elif isinstance(value, list) and not isinstance(value, BoxList):
-            if self._box_config['frozen_box']:
-                value = _recursive_tuples(value, self.__class__,
-                                          recreate_tuples=self._box_config[
-                                              'modify_tuples_box'],
-                                          __box_heritage=(self, item),
-                                          **self.__box_config())
-            else:
-                value = BoxList(value, __box_heritage=(self, item),
-                                box_class=self.__class__,
-                                **self.__box_config())
-            self[item] = value
-        elif (self._box_config['modify_tuples_box'] and
-              isinstance(value, tuple)):
-            value = _recursive_tuples(value, self.__class__,
-                                      recreate_tuples=True,
-                                      __box_heritage=(self, item),
-                                      **self.__box_config())
-            self[item] = value
-        self._box_config['__converted'].add(item)
-        return value
-
-    def __create_lineage(self):
-        if (self._box_config['__box_heritage'] and
-                self._box_config['__created']):
-            past, item = self._box_config['__box_heritage']
-            if not past[item]:
-                past[item] = self
-            self._box_config['__box_heritage'] = None
-
-    def __getattr__(self, item):
-        try:
-            try:
-                value = self.__getitem__(item, _ignore_default=True)
-            except KeyError:
-                value = object.__getattribute__(self, item)
-        except AttributeError as err:
-            if item == "__getstate__":
-                raise AttributeError(item)
-            if item == '_box_config':
-                raise BoxError('_box_config key must exist')
-            kill_camel = self._box_config['camel_killer_box']
-            if self._box_config['conversion_box'] and item:
-                k = _conversion_checks(item, self.keys(), self._box_config)
-                if k:
-                    return self.__getitem__(k)
-            if kill_camel:
-                for k in self.keys():
-                    if item == _camel_killer(k):
-                        return self.__getitem__(k)
-            if self._box_config['default_box']:
-                return self.__get_default(item)
-            raise BoxKeyError(str(err))
-        else:
-            if item == '_box_config':
-                return value
-            return self.__convert_and_store(item, value)
-
-    def __setitem__(self, key, value):
-        if (key != '_box_config' and self._box_config['__created'] and
-                self._box_config['frozen_box']):
-            raise BoxError('Box is frozen')
-        if self._box_config['conversion_box']:
-            _conversion_checks(key, self.keys(), self._box_config,
-                               check_only=True, pre_check=True)
-        super(Box, self).__setitem__(key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __setattr__(self, key, value):
-        if (key != '_box_config' and self._box_config['frozen_box'] and
-                self._box_config['__created']):
-            raise BoxError('Box is frozen')
-        if key in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(key))
-        if key == '_box_config':
-            return object.__setattr__(self, key, value)
-        try:
-            object.__getattribute__(self, key)
-        except (AttributeError, UnicodeEncodeError):
-            if (key not in self.keys() and
-                    (self._box_config['conversion_box'] or
-                     self._box_config['camel_killer_box'])):
-                if self._box_config['conversion_box']:
-                    k = _conversion_checks(key, self.keys(),
-                                           self._box_config)
-                    self[key if not k else k] = value
-                elif self._box_config['camel_killer_box']:
-                    for each_key in self:
-                        if key == _camel_killer(each_key):
-                            self[each_key] = value
-                            break
-            else:
-                self[key] = value
-        else:
-            object.__setattr__(self, key, value)
-        self.__add_ordered(key)
-        self.__create_lineage()
-
-    def __delitem__(self, key):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        super(Box, self).__delitem__(key)
-        if (self._box_config['ordered_box'] and
-                key in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(key)
-
-    def __delattr__(self, item):
-        if self._box_config['frozen_box']:
-            raise BoxError('Box is frozen')
-        if item == '_box_config':
-            raise BoxError('"_box_config" is protected')
-        if item in self._protected_keys:
-            raise AttributeError("Key name '{0}' is protected".format(item))
-        try:
-            object.__getattribute__(self, item)
-        except AttributeError:
-            del self[item]
-        else:
-            object.__delattr__(self, item)
-        if (self._box_config['ordered_box'] and
-                item in self._box_config['__ordered_box_values']):
-            self._box_config['__ordered_box_values'].remove(item)
-
-    def pop(self, key, *args):
-        if args:
-            if len(args) != 1:
-                raise BoxError('pop() takes only one optional'
-                               ' argument "default"')
-            try:
-                item = self[key]
-            except KeyError:
-                return args[0]
-            else:
-                del self[key]
-                return item
-        try:
-            item = self[key]
-        except KeyError:
-            raise BoxKeyError('{0}'.format(key))
-        else:
-            del self[key]
-            return item
-
-    def clear(self):
-        self._box_config['__ordered_box_values'] = []
-        super(Box, self).clear()
-
-    def popitem(self):
-        try:
-            key = next(self.__iter__())
-        except StopIteration:
-            raise BoxKeyError('Empty box')
-        return key, self.pop(key)
-
-    def __repr__(self):
-        return '<Box: {0}>'.format(str(self.to_dict()))
-
-    def __str__(self):
-        return str(self.to_dict())
-
-    def __iter__(self):
-        for key in self.keys():
-            yield key
-
-    def __reversed__(self):
-        for key in reversed(list(self.keys())):
-            yield key
-
-    def to_dict(self):
-        """
-        Turn the Box and sub Boxes back into a native
-        python dictionary.
-
-        :return: python dictionary of this Box
-        """
-        out_dict = dict(self)
-        for k, v in out_dict.items():
-            if v is self:
-                out_dict[k] = out_dict
-            elif hasattr(v, 'to_dict'):
-                out_dict[k] = v.to_dict()
-            elif hasattr(v, 'to_list'):
-                out_dict[k] = v.to_list()
-        return out_dict
-
-    def update(self, item=None, **kwargs):
-        if not item:
-            item = kwargs
-        iter_over = item.items() if hasattr(item, 'items') else item
-        for k, v in iter_over:
-            if isinstance(v, dict):
-                # Box objects must be created in case they are already
-                # in the `converted` box_config set
-                v = self.__class__(v)
-                if k in self and isinstance(self[k], dict):
-                    self[k].update(v)
-                    continue
-            if isinstance(v, list):
-                v = BoxList(v)
-            try:
-                self.__setattr__(k, v)
-            except (AttributeError, TypeError):
-                self.__setitem__(k, v)
-
-    def setdefault(self, item, default=None):
-        if item in self:
-            return self[item]
-
-        if isinstance(default, dict):
-            default = self.__class__(default)
-        if isinstance(default, list):
-            default = BoxList(default)
-        self[item] = default
-        return default
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict", **json_kwargs):
-        """
-        Transform the Box object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        return _to_json(self.to_dict(), filename=filename,
-                        encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None,
-                  encoding="utf-8", errors="strict", **kwargs):
-        """
-        Transform a json object string into a Box object. If the incoming
-        json is a list, you must use BoxList.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: Box object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename,
-                          encoding=encoding, errors=errors, **kwargs)
-
-        if not isinstance(data, dict):
-            raise BoxError('json data not returned as a dictionary, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the Box object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_dict(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader, **kwargs):
-            """
-            Transform a yaml object string into a Box object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `Box()` or `yaml.load`
-            :return: Box object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, dict):
-                raise BoxError('yaml data not returned as a dictionary'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-
-class BoxList(list):
-    """
-    Drop in replacement of list, that converts added objects to Box or BoxList
-    objects as necessary.
-    """
-
-    def __init__(self, iterable=None, box_class=Box, **box_options):
-        self.box_class = box_class
-        self.box_options = box_options
-        self.box_org_ref = self.box_org_ref = id(iterable) if iterable else 0
-        if iterable:
-            for x in iterable:
-                self.append(x)
-        if box_options.get('frozen_box'):
-            def frozen(*args, **kwargs):
-                raise BoxError('BoxList is frozen')
-
-            for method in ['append', 'extend', 'insert', 'pop',
-                           'remove', 'reverse', 'sort']:
-                self.__setattr__(method, frozen)
-
-    def __delitem__(self, key):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__delitem__(key)
-
-    def __setitem__(self, key, value):
-        if self.box_options.get('frozen_box'):
-            raise BoxError('BoxList is frozen')
-        super(BoxList, self).__setitem__(key, value)
-
-    def append(self, p_object):
-        if isinstance(p_object, dict):
-            try:
-                p_object = self.box_class(p_object, **self.box_options)
-            except AttributeError as err:
-                if 'box_class' in self.__dict__:
-                    raise err
-        elif isinstance(p_object, list):
-            try:
-                p_object = (self if id(p_object) == self.box_org_ref else
-                            BoxList(p_object))
-            except AttributeError as err:
-                if 'box_org_ref' in self.__dict__:
-                    raise err
-        super(BoxList, self).append(p_object)
-
-    def extend(self, iterable):
-        for item in iterable:
-            self.append(item)
-
-    def insert(self, index, p_object):
-        if isinstance(p_object, dict):
-            p_object = self.box_class(p_object, **self.box_options)
-        elif isinstance(p_object, list):
-            p_object = (self if id(p_object) == self.box_org_ref else
-                        BoxList(p_object))
-        super(BoxList, self).insert(index, p_object)
-
-    def __repr__(self):
-        return "<BoxList: {0}>".format(self.to_list())
-
-    def __str__(self):
-        return str(self.to_list())
-
-    def __copy__(self):
-        return BoxList((x for x in self),
-                       self.box_class,
-                       **self.box_options)
-
-    def __deepcopy__(self, memodict=None):
-        out = self.__class__()
-        memodict = memodict or {}
-        memodict[id(self)] = out
-        for k in self:
-            out.append(copy.deepcopy(k))
-        return out
-
-    def __hash__(self):
-        if self.box_options.get('frozen_box'):
-            hashing = 98765
-            hashing ^= hash(tuple(self))
-            return hashing
-        raise TypeError("unhashable type: 'BoxList'")
-
-    def to_list(self):
-        new_list = []
-        for x in self:
-            if x is self:
-                new_list.append(new_list)
-            elif isinstance(x, Box):
-                new_list.append(x.to_dict())
-            elif isinstance(x, BoxList):
-                new_list.append(x.to_list())
-            else:
-                new_list.append(x)
-        return new_list
-
-    def to_json(self, filename=None,
-                encoding="utf-8", errors="strict",
-                multiline=False, **json_kwargs):
-        """
-        Transform the BoxList object into a JSON string.
-
-        :param filename: If provided will save to file
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: Put each item in list onto it's own line
-        :param json_kwargs: additional arguments to pass to json.dump(s)
-        :return: string of JSON or return of `json.dump`
-        """
-        if filename and multiline:
-            lines = [_to_json(item, filename=False, encoding=encoding,
-                              errors=errors, **json_kwargs) for item in self]
-            with open(filename, 'w', encoding=encoding, errors=errors) as f:
-                f.write("\n".join(lines).decode('utf-8') if
-                        sys.version_info < (3, 0) else "\n".join(lines))
-        else:
-            return _to_json(self.to_list(), filename=filename,
-                            encoding=encoding, errors=errors, **json_kwargs)
-
-    @classmethod
-    def from_json(cls, json_string=None, filename=None, encoding="utf-8",
-                  errors="strict", multiline=False, **kwargs):
-        """
-        Transform a json object string into a BoxList object. If the incoming
-        json is a dict, you must use Box.from_json.
-
-        :param json_string: string to pass to `json.loads`
-        :param filename: filename to open and pass to `json.load`
-        :param encoding: File encoding
-        :param errors: How to handle encoding errors
-        :param multiline: One object per line
-        :param kwargs: parameters to pass to `Box()` or `json.loads`
-        :return: BoxList object from json data
-        """
-        bx_args = {}
-        for arg in kwargs.copy():
-            if arg in BOX_PARAMETERS:
-                bx_args[arg] = kwargs.pop(arg)
-
-        data = _from_json(json_string, filename=filename, encoding=encoding,
-                          errors=errors, multiline=multiline, **kwargs)
-
-        if not isinstance(data, list):
-            raise BoxError('json data not returned as a list, '
-                           'but rather a {0}'.format(type(data).__name__))
-        return cls(data, **bx_args)
-
-    if yaml_support:
-        def to_yaml(self, filename=None, default_flow_style=False,
-                    encoding="utf-8", errors="strict",
-                    **yaml_kwargs):
-            """
-            Transform the BoxList object into a YAML string.
-
-            :param filename:  If provided will save to file
-            :param default_flow_style: False will recursively dump dicts
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param yaml_kwargs: additional arguments to pass to yaml.dump
-            :return: string of YAML or return of `yaml.dump`
-            """
-            return _to_yaml(self.to_list(), filename=filename,
-                            default_flow_style=default_flow_style,
-                            encoding=encoding, errors=errors, **yaml_kwargs)
-
-        @classmethod
-        def from_yaml(cls, yaml_string=None, filename=None,
-                      encoding="utf-8", errors="strict",
-                      loader=yaml.SafeLoader,
-                      **kwargs):
-            """
-            Transform a yaml object string into a BoxList object.
-
-            :param yaml_string: string to pass to `yaml.load`
-            :param filename: filename to open and pass to `yaml.load`
-            :param encoding: File encoding
-            :param errors: How to handle encoding errors
-            :param loader: YAML Loader, defaults to SafeLoader
-            :param kwargs: parameters to pass to `BoxList()` or `yaml.load`
-            :return: BoxList object from yaml data
-            """
-            bx_args = {}
-            for arg in kwargs.copy():
-                if arg in BOX_PARAMETERS:
-                    bx_args[arg] = kwargs.pop(arg)
-
-            data = _from_yaml(yaml_string=yaml_string, filename=filename,
-                              encoding=encoding, errors=errors,
-                              Loader=loader, **kwargs)
-            if not isinstance(data, list):
-                raise BoxError('yaml data not returned as a list'
-                               'but rather a {0}'.format(type(data).__name__))
-            return cls(data, **bx_args)
-
-    def box_it_up(self):
-        for v in self:
-            if hasattr(v, 'box_it_up') and v is not self:
-                v.box_it_up()
-
-
class ConfigBox(Box):
    """
    Box subclass with typed accessors for string-valued configuration.

    Example:

    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')

    cns.bool('my_bool') # True
    cns.int('my_int') # 5
    cns.list('my_list', mod=lambda x: int(x)) # [5, 4, 3, 3, 2]
    """

    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
                                 'list', 'getboolean', 'to_json', 'to_yaml',
                                 'getfloat', 'getint',
                                 'from_json', 'from_yaml']

    def __getattr__(self, item):
        """Look up ``item``; on failure retry with the lower-cased name."""
        parent = super(ConfigBox, self)
        try:
            return parent.__getattr__(item)
        except AttributeError:
            return parent.__getattr__(item.lower())

    def __dir__(self):
        extra = ['bool', 'int', 'float',
                 'list', 'getboolean',
                 'getfloat', 'getint']
        return super(ConfigBox, self).__dir__() + extra

    def bool(self, item, default=None):
        """ Return value of key as a boolean

        :param item: key of value to transform
        :param default: returned when the key is missing, unless it is
            ``None``, in which case the lookup error propagates
        :return: approximated bool of value
        """
        try:
            value = self.__getattr__(item)
        except AttributeError:
            if default is None:
                raise
            return default

        if isinstance(value, (bool, int)):
            return bool(value)

        falsy_words = ('n', 'no', 'false', 'f', '0')
        if isinstance(value, str) and value.lower() in falsy_words:
            return False

        return bool(value)

    def int(self, item, default=None):
        """ Return value of key as an int

        :param item: key of value to transform
        :param default: returned when the key is missing, unless it is
            ``None``, in which case the lookup error propagates
        :return: int of value
        """
        try:
            value = self.__getattr__(item)
        except AttributeError:
            if default is None:
                raise
            return default
        return int(value)

    def float(self, item, default=None):
        """ Return value of key as a float

        :param item: key of value to transform
        :param default: returned when the key is missing, unless it is
            ``None``, in which case the lookup error propagates
        :return: float of value
        """
        try:
            value = self.__getattr__(item)
        except AttributeError:
            if default is None:
                raise
            return default
        return float(value)

    def list(self, item, default=None, spliter=",", strip=True, mod=None):
        """ Return value of key as a list

        :param item: key of value to transform
        :param mod: function to map against list
        :param default: returned when the key is missing, unless it is
            ``None``, in which case the lookup error propagates
        :param spliter: character to split str on
        :param strip: drop surrounding ``[``/``]`` and strip each element
        :return: list of items
        """
        try:
            value = self.__getattr__(item)
        except AttributeError:
            if default is None:
                raise
            return default

        if strip:
            value = value.lstrip('[').rstrip(']')
            pieces = [piece.strip() for piece in value.split(spliter)]
        else:
            pieces = value.split(spliter)

        if mod:
            return list(map(mod, pieces))
        return pieces

    # loose configparser compatibility

    def getboolean(self, item, default=None):
        return self.bool(item, default)

    def getint(self, item, default=None):
        return self.int(item, default)

    def getfloat(self, item, default=None):
        return self.float(item, default)

    def __repr__(self):
        as_text = str(self.to_dict())
        return '<ConfigBox: {0}>'.format(as_text)
-
-
class SBox(Box):
    """
    ShorthandBox (SBox): exposes the ``to_dict`` / ``to_json`` /
    ``to_yaml`` conversions as the read-only properties ``dict``,
    ``json`` and ``yaml``.
    """
    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
                                 'json', 'yaml', 'from_yaml', 'from_json',
                                 'dict']

    @property
    def dict(self):
        """Plain-dict form, as returned by ``to_dict``."""
        plain = self.to_dict()
        return plain

    @property
    def json(self):
        """JSON form, as returned by ``to_json``."""
        dumped = self.to_json()
        return dumped

    if yaml_support:
        @property
        def yaml(self):
            """YAML form, as returned by ``to_yaml``."""
            dumped = self.to_yaml()
            return dumped

    def __repr__(self):
        as_text = str(self.to_dict())
        return '<ShorthandBox: {0}>'.format(as_text)

+ 0 - 9
lcnn/config.py

@@ -1,9 +0,0 @@
-import numpy as np
-
-from lcnn.box import Box
-
-# C is a dict storing all the configuration
-C = Box()
-
-# shortcut for C.model
-M = Box()

+ 0 - 378
lcnn/dataset_tool.py

@@ -1,378 +0,0 @@
-import cv2
-import numpy as np
-import torch
-import torchvision
-from matplotlib import pyplot as plt
-import tools.transforms as reference_transforms
-from collections import defaultdict
-
-from tools import presets
-
-import json
-
-
def get_modules(use_v2):
    """Return the ``(transforms, tv_tensors)`` pair for the chosen API.

    With ``use_v2`` false the project's reference transforms are returned
    together with ``None``.  The torchvision v2 modules are imported only
    on demand so the V2 warning is not triggered when only V1 is used.
    """
    if not use_v2:
        return reference_transforms, None

    from torchvision import tv_tensors
    from torchvision.transforms import v2

    return v2, tv_tensors
-
-
class Augmentation:
    """Detection augmentation pipeline selected by policy name.

    Note: this transform assumes that the inputs to ``__call__`` are always
    PIL images, regardless of the ``backend`` parameter.
    """

    def __init__(
            self,
            *,
            data_augmentation,
            hflip_prob=0.5,
            mean=(123.0, 117.0, 104.0),
            backend="pil",
            use_v2=False,
    ):
        """
        :param data_augmentation: one of "hflip", "lsj", "multiscale",
            "ssd" or "ssdlite"
        :param hflip_prob: probability of the random horizontal flip
        :param mean: fill value used by the crop / zoom-out transforms
        :param backend: "tv_tensor", "tensor" or "pil" (case-insensitive)
        :param use_v2: use torchvision's v2 transforms instead of the
            reference implementation
        :raises ValueError: on an unknown backend or augmentation policy
        """
        T, tv_tensors = get_modules(use_v2)

        backend = backend.lower()
        steps = []
        if backend == "tv_tensor":
            steps.append(T.ToImage())
        elif backend == "tensor":
            steps.append(T.PILToTensor())
        elif backend != "pil":
            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")

        if data_augmentation == "hflip":
            steps.append(T.RandomHorizontalFlip(p=hflip_prob))
        elif data_augmentation == "lsj":
            steps.append(T.ScaleJitter(target_size=(1024, 1024), antialias=True))
            # TODO: FixedSizeCrop below doesn't work on tensors!
            steps.append(reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean))
            steps.append(T.RandomHorizontalFlip(p=hflip_prob))
        elif data_augmentation == "multiscale":
            short_sides = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
            steps.append(T.RandomShortestSize(min_size=short_sides, max_size=1333))
            steps.append(T.RandomHorizontalFlip(p=hflip_prob))
        elif data_augmentation == "ssd":
            if use_v2:
                fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0})
            else:
                fill = list(mean)
            steps.append(T.RandomPhotometricDistort())
            steps.append(T.RandomZoomOut(fill=fill))
            steps.append(T.RandomIoUCrop())
            steps.append(T.RandomHorizontalFlip(p=hflip_prob))
        elif data_augmentation == "ssdlite":
            steps.append(T.RandomIoUCrop())
            steps.append(T.RandomHorizontalFlip(p=hflip_prob))
        else:
            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

        if backend == "pil":
            # Note: we could just convert to pure tensors even in v2.
            steps.append(T.ToImage() if use_v2 else T.PILToTensor())

        steps.append(T.ToDtype(torch.float, scale=True))

        if use_v2:
            steps.append(T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY))
            steps.append(T.SanitizeBoundingBoxes())
            steps.append(T.ToPureTensor())

        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Apply the composed pipeline to ``img`` and ``target``."""
        return self.transforms(img, target)
-
-
def read_polygon_points(lbl_path, shape):
    """Read a YOLOv8-style label file and return its polygon outlines.

    Each non-empty line is ``<class_id> x1 y1 x2 y2 ...``.  The first value
    of every pair is multiplied by ``shape[1]`` and the second by
    ``shape[0]``, i.e. ``shape`` is expected to be ``(height, width)``.

    :param lbl_path: path of the label file
    :param shape: image shape; only the first two entries are used
    :return: list of ``(class_id, points)`` where ``points`` is a float32
        array of shape ``(n, 2)``
    """
    # The original bound these as ``w, h = shape[:2]`` and then scaled x by
    # ``h`` and y by ``w``; the arithmetic is unchanged, only the names now
    # say what the values are.
    rows, cols = shape[:2]
    polygon_points = []
    with open(lbl_path, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank lines used to raise IndexError on ``parts[0]``.
                continue
            class_id = int(parts[0])
            points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)
            points[:, 0] *= cols
            points[:, 1] *= rows
            polygon_points.append((class_id, points))

    return polygon_points
-
-
def read_masks_from_pixels(lbl_path, shape):
    """Read a label file that lists raw mask pixels (not outline points).

    Each non-empty line is ``<class_id> x1 y1 x2 y2 ...`` with coordinates
    that are scaled by the image width and height respectively.

    :param lbl_path: path of the label file
    :param shape: ``(height, width)`` of the masks to create
    :return: ``(labels, masks)``; ``labels`` is a list of int64 scalar
        tensors and ``masks`` a list of ``(height, width)`` uint8 tensors
        with the listed pixels set to 1
    """
    h, w = shape
    masks = []
    labels = []

    with open(lbl_path, 'r') as reader:
        for line in reader:
            parts = line.split()
            if not parts:
                continue
            labels.append(torch.tensor(int(parts[0]), dtype=torch.int64))

            # A fresh mask per line: previously the point list was shared
            # across lines, so each mask also contained the pixels of all
            # earlier lines.
            mask = torch.zeros((h, w), dtype=torch.uint8)
            for x, y in zip(parts[1::2], parts[2::2]):
                # A coordinate of exactly 1.0 maps to index h (or w);
                # clamp it onto the last row / column.
                row = min(int(float(y) * h), h - 1)
                col = min(int(float(x) * w), w - 1)
                mask[row, col] = 1
            masks.append(mask)
    return labels, masks
-
-
def create_masks_from_polygons(polygons, image_shape):
    """Rasterise each polygon into its own binary mask.

    :param polygons: iterable of ``(class_id, points)`` pairs; only the
        points are used here
    :param image_shape: image shape; the first two entries give the mask
        size
    :return: list of uint8 tensors (one per polygon) holding 1 inside the
        filled polygon and 0 elsewhere
    """
    masks = []

    for _, polygon in polygons:
        canvas = np.zeros(image_shape[:2], dtype=np.uint8)
        # fillPoly wants int32 vertices shaped (n, 1, 2).
        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))

        # Fill with 1 directly.  The previous version filled with the red
        # channel of a colormap colour and then collapsed every non-zero
        # value to 1, which only worked while that channel was non-zero.
        cv2.fillPoly(canvas, [pts], 1)
        masks.append(torch.from_numpy(canvas))

    return masks
-
-
def read_masks_from_txt(label_path, shape):
    """Load polygon labels from ``label_path`` and rasterise them.

    :return: ``(labels, masks)`` - one class-id tensor and one mask per
        polygon, in file order
    """
    polygons = read_polygon_points(label_path, shape)
    labels = []
    for class_id, _ in polygons:
        labels.append(torch.tensor(class_id))
    masks = create_masks_from_polygons(polygons, shape)
    return labels, masks
-
-
def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    A mask that is a single pixel wide (or high) would give ``x1 == x2``
    (or ``y1 == y2``); such boxes are widened by one pixel on each side,
    with the lower bound clamped at 0.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    n = masks.shape[0]
    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        y, x = torch.where(mask != 0)
        x_min, x_max = torch.min(x), torch.max(x)
        y_min, y_max = torch.min(y), torch.max(y)

        # Widen degenerate boxes (needed for pixel datasets).  Without the
        # clamp a pixel in column/row 0 produced a coordinate of -1, which
        # contradicts the documented ``0 <= x1`` / ``0 <= y1``.
        if x_min == x_max:
            x_min, x_max = torch.clamp(x_min - 1, min=0), x_max + 1
        if y_min == y_max:
            y_min, y_max = torch.clamp(y_min - 1, min=0), y_max + 1

        bounding_boxes[index, 0] = x_min
        bounding_boxes[index, 1] = y_min
        bounding_boxes[index, 2] = x_max
        bounding_boxes[index, 3] = y_max

    return bounding_boxes
-
-
def line_boxes_faster(target, scale=4, margin=10, image_size=512):
    """Build one padded, axis-aligned box around every positive line.

    :param target: mapping with ``"lpre"`` (tensor of line end points,
        indexed ``[line, end_point, coord]``) and ``"lpre_label"`` (tensor
        with 1 for the lines to keep)
    :param scale: factor applied to the coordinates before boxing
    :param margin: padding added on every side of a line's extent
    :param image_size: upper clamp for the box coordinates (lower is 0)
    :return: tensor of shape ``(n, 4)`` whose rows are
        ``[min(c1), min(c0), max(c1), max(c0)]`` after padding and
        clamping, where ``c0``/``c1`` are the two stored coordinates;
        shape ``(0, 4)`` when there is nothing to box
    """
    segments = target["lpre"].cpu().numpy() * scale
    positive = target["lpre_label"].cpu().numpy() == 1
    segments = segments[positive]

    boxes = []
    if len(segments) > 0 and not (segments[0] == 0).all():
        first = segments[0]
        for i, (a, b) in enumerate(segments):
            # A later repeat of the first segment ends the list.
            # NOTE(review): presumably padding in the stored data - confirm.
            if i > 0 and (segments[i] == first).all():
                break
            low1 = min(a[1], b[1]) - margin
            high1 = max(a[1], b[1]) + margin
            low0 = min(a[0], b[0]) - margin
            high0 = max(a[0], b[0]) + margin
            boxes.append([max(0, low1), max(0, low0),
                          min(image_size, high1), min(image_size, high0)])

    if not boxes:
        # ``torch.tensor([])`` has shape (0,); keep the box dimension so
        # callers can rely on an (n, 4) layout.
        return torch.zeros((0, 4))
    return torch.tensor(boxes)
-
-
# Reorder each segment as [start, end]; expects input of shape [num, 2, 2].
def a_to_b(line, margin=10):
    """Order the two end points of every segment as ``[start, end]``.

    The start is the end point closer to a reference corner placed
    ``margin`` below the segment's minimum in both coordinates (clamped at
    0); on a tie the original order is kept.

    :param line: tensor of segments indexed ``[segment, end_point, coord]``
    :param margin: offset of the reference corner from the segment minimum
    :return: tensor with the same layout and reordered end points; shape
        ``(0, 2, 2)`` for empty input
    """
    ordered = []
    for a, b in line:
        corner_x = max(min(a[0], b[0]) - margin, 0)
        corner_y = max(min(a[1], b[1]) - margin, 0)
        dist_a = (a[0] - corner_x) ** 2 + (a[1] - corner_y) ** 2
        dist_b = (b[0] - corner_x) ** 2 + (b[1] - corner_y) ** 2
        pair = [a, b] if dist_a <= dist_b else [b, a]
        ordered.append(torch.stack(pair))

    if not ordered:
        # torch.stack refuses an empty list.
        return torch.empty((0, 2, 2))
    return torch.stack(ordered)
-
-
def read_polygon_points_wire(lbl_path, shape):
    """Read polygon outlines from a JSON label file.

    Every entry of the top-level ``"segmentations"`` list provides a
    ``"cls_id"`` and a flat ``"data"`` list of coordinate pairs.  The first
    value of each pair is multiplied by ``shape[1]`` and the second by
    ``shape[0]``.

    :return: list of ``(class_id, points)`` with float32 ``(n, 2)`` points
    """
    first, second = shape[:2]
    with open(lbl_path, 'r') as f:
        document = json.load(f)

    outlines = []
    for entry in document["segmentations"]:
        coords = np.array(entry["data"], dtype=np.float32).reshape(-1, 2)
        coords[:, 0] *= second
        coords[:, 1] *= first
        outlines.append((int(entry["cls_id"]), coords))
    return outlines
-
-
def read_masks_from_txt_wire(label_path, shape):
    """Load polygon labels from a JSON label file and rasterise them.

    :return: ``(labels, masks)`` - one class-id tensor and one mask per
        polygon, in file order
    """
    polygons = read_polygon_points_wire(label_path, shape)
    labels = []
    for class_id, _ in polygons:
        labels.append(torch.tensor(class_id))
    masks = create_masks_from_polygons(polygons, shape)
    return labels, masks
-
-
def read_masks_from_pixels_wire(lbl_path, shape):
    """Read the class labels from a JSON pixel-label file.

    Despite its name this function builds no masks: it returns one label
    per entry of the top-level ``"segmentations"`` list.

    :param lbl_path: path of the JSON label file
    :param shape: unused; kept so existing callers keep working
    :return: list of int64 scalar tensors, one per segmentation
    """
    with open(lbl_path, 'r') as reader:
        document = json.load(reader)

    return [torch.tensor(int(entry["cls_id"]), dtype=torch.int64)
            for entry in document["segmentations"]]
-
-
def read_lable_keypoint(lbl_path):
    """Label the end points of every positive line as start (0) or end (1).

    Lines are read from ``wires[0]["line_pos_coords"]["content"]``.  The
    start is the end point closer to a reference corner placed 10 below
    the segment's minimum in both coordinates (clamped at 0); on a tie the
    first point is the start.

    NOTE(review): the previous body computed this ordering but always
    returned an empty list.  Returning ``[0, 1]`` / ``[1, 0]`` per line
    follows the label convention stated in the original docstring -
    confirm this is what callers expect.

    :param lbl_path: path of the JSON label file
    :return: one ``[label_of_first_point, label_of_second_point]`` pair
        per line
    """
    with open(lbl_path, 'r') as reader:
        document = json.load(reader)
    segments = document["wires"][0]["line_pos_coords"]["content"]

    labels = []
    for a, b in segments:
        # Reference corner for this segment.
        corner_x = max(min(a[0], b[0]) - 10, 0)
        corner_y = max(min(a[1], b[1]) - 10, 0)

        # Squared distances are enough for the comparison.
        dist_a = (a[0] - corner_x) ** 2 + (a[1] - corner_y) ** 2
        dist_b = (b[0] - corner_x) ** 2 + (b[1] - corner_y) ** 2

        labels.append([0, 1] if dist_a <= dist_b else [1, 0])
    return labels
-
-
def adjacency_matrix(n, link):
    """Build a symmetric ``(n + 1, n + 1)`` uint8 adjacency matrix.

    :param n: number of nodes; the matrix has one extra row and column
    :param link: sequence of ``[i, j]`` index pairs to connect
    :return: matrix with 1 at ``[i, j]`` and ``[j, i]`` for every pair
    """
    size = n + 1
    matrix = torch.zeros(size, size, dtype=torch.uint8)
    edges = torch.tensor(link)
    if len(edges) == 0:
        return matrix
    src, dst = edges[:, 0], edges[:, 1]
    matrix[src, dst] = 1
    matrix[dst, src] = 1
    return matrix

+ 0 - 294
lcnn/datasets.py

@@ -1,294 +0,0 @@
-# import glob
-# import json
-# import math
-# import os
-# import random
-#
-# import numpy as np
-# import numpy.linalg as LA
-# import torch
-# from skimage import io
-# from torch.utils.data import Dataset
-# from torch.utils.data.dataloader import default_collate
-#
-# from lcnn.config import M
-#
-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
-#
-#
-# class WireframeDataset(Dataset):
-#     def __init__(self, rootdir, split):
-#         self.rootdir = rootdir
-#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
-#         filelist.sort()
-#
-#         # print(f"n{split}:", len(filelist))
-#         self.split = split
-#         self.filelist = filelist
-#
-#     def __len__(self):
-#         return len(self.filelist)
-#
-#     def __getitem__(self, idx):
-#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
-#         image = io.imread(iname).astype(float)[:, :, :3]
-#         if "a1" in self.filelist[idx]:
-#             image = image[:, ::-1, :]
-#         image = (image - M.image.mean) / M.image.stddev
-#         image = np.rollaxis(image, 2).copy()
-#
-#         with np.load(self.filelist[idx]) as npz:
-#             target = {
-#                 name: torch.from_numpy(npz[name]).float()
-#                 for name in ["jmap", "joff", "lmap"]
-#             }
-#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
-#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
-#             npos, nneg = len(lpos), len(lneg)
-#             lpre = np.concatenate([lpos, lneg], 0)
-#             for i in range(len(lpre)):
-#                 if random.random() > 0.5:
-#                     lpre[i] = lpre[i, ::-1]
-#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-#             feat = [
-#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
-#                 ldir * M.use_slop,
-#                 lpre[:, :, 2],
-#             ]
-#             feat = np.concatenate(feat, 1)
-#             meta = {
-#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
-#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
-#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
-#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
-#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
-#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
-#                 "lpre_feat": torch.from_numpy(feat),
-#             }
-#
-#         labels = []
-#         labels = read_masks_from_pixels_wire(iname, (512, 512))
-#         # if self.target_type == 'polygon':
-#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
-#         # elif self.target_type == 'pixel':
-#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
-#
-#         target["labels"] = torch.stack(labels)
-#         target["boxes"] = line_boxes_faster(meta)
-#
-#
-#         return torch.from_numpy(image).float(), meta, target
-#
-#     def adjacency_matrix(self, n, link):
-#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-#         link = torch.from_numpy(link)
-#         if len(link) > 0:
-#             mat[link[:, 0], link[:, 1]] = 1
-#             mat[link[:, 1], link[:, 0]] = 1
-#         return mat
-#
-#
-# def collate(batch):
-#     return (
-#         default_collate([b[0] for b in batch]),
-#         [b[1] for b in batch],
-#         default_collate([b[2] for b in batch]),
-#     )
-
-from torch.utils.data.dataset import T_co
-
-from .models.base.base_dataset import BaseDataset
-
-import glob
-import json
-import math
-import os
-import random
-import cv2
-import PIL
-
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-from torchvision.utils import draw_bounding_boxes
-
-import numpy as np
-import numpy.linalg as LA
-import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
-
-import matplotlib.pyplot as plt
-from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
-
-
-
class WireframeDataset(BaseDataset):
    """Wireframe (junction and line) dataset backed by JSON label files.

    Images are listed from ``<dataset_path>/images/<dataset_type>`` and
    every image is paired with the JSON file of the same base name in
    ``<dataset_path>/labels/<dataset_type>``.
    """

    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        """
        :param dataset_path: dataset root directory
        :param transforms: stored on the instance; not applied by this class
        :param dataset_type: sub-directory name under ``images`` and
            ``labels``; must be a string (the ``None`` default fails when
            the paths are built)
        :param target_type: ``'pixel'`` or ``'polygon'``; selects how the
            class labels are read
        """
        super().__init__(dataset_path)

        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type

    def __getitem__(self, index) -> T_co:
        """Return ``(image, meta, target, target_b)`` for one sample."""
        # ``import PIL`` alone does not guarantee the Image submodule is loaded.
        from PIL import Image

        img_name = self.imgs[index]
        img_path = os.path.join(self.img_path, img_name)
        # splitext copes with any extension length; slicing off three
        # characters turned e.g. "a.jpeg" into "a.jjson".
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(img_name)[0] + '.json')

        img = Image.open(img_path).convert('RGB')
        w, h = img.size

        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))

        img = self.default_transform(img)
        return img, meta, target, target_b

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        """Parse one JSON label file.

        :param item: sample index (unused)
        :param lbl_path: path of the JSON label file
        :param shape: ``(height, width)`` of the image
        :param extra: unused
        :return: ``(meta, target, target_b)`` - junction/line metadata,
            dense maps, and the ``labels`` / ``boxes`` dictionary
        :raises ValueError: if ``self.target_type`` is not supported
        """
        if self.target_type not in ('polygon', 'pixel'):
            # Previously this surfaced later as an opaque torch.stack error.
            raise ValueError(f'Unknown target_type "{self.target_type}"')

        with open(lbl_path, 'r') as file:
            lable_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = lable_all["wires"][0]
        # Random subset of positive / negative lines; takes whatever is
        # available when there are fewer than the limit.
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)
        # Randomly swap the two end points of each line.
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        feat = np.concatenate(
            [
                lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
                ldir * use_slop,
                lpre[:, :, 2],
            ],
            1,
        )

        junc_content = wire["junc_coords"]["content"]
        meta = {
            "junc_coords": torch.tensor(junc_content)[:, :2],
            "jtyp": torch.tensor(junc_content)[:, 2].byte(),
            # Adjacency matrix of the lines that really exist.
            "line_pos_idx": adjacency_matrix(len(junc_content), wire["line_pos_idx"]["content"]),
            "line_neg_idx": adjacency_matrix(len(junc_content), wire["line_neg_idx"]["content"]),

            "lpre": torch.tensor(lpre)[:, :, :2],
            # 1 for positive samples, 0 for negative ones.
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
            "lpre_feat": torch.from_numpy(feat),
        }

        target = {
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        if self.target_type == 'polygon':
            labels, _ = read_masks_from_txt_wire(lbl_path, shape)
        else:
            labels = read_masks_from_pixels_wire(lbl_path, shape)

        target_b = {
            "labels": torch.stack(labels),
            "boxes": line_boxes_faster(meta),
        }

        return meta, target, target_b

    def show(self, idx, save_path=None):
        """Plot sample ``idx`` with its positive lines, junctions and boxes.

        :param idx: sample index
        :param save_path: if given, the figure is saved there before it is
            shown
        """
        from PIL import Image

        # __getitem__ returns four values; the drawing data lives in
        # ``meta`` (junctions, lines) and ``target_b`` (boxes).
        _, meta, _, target_b = self[idx]
        img_file = os.path.join(self.img_path, self.imgs[idx])

        sm = plt.cm.ScalarMappable(cmap=plt.get_cmap("jet"),
                                   norm=mpl.colors.Normalize(vmin=0.4, vmax=1.0))
        sm.set_array([])

        plt.close()
        plt.tight_layout()
        raw = io.imread(img_file)
        plt.imshow(raw)
        # An explicit axes is needed for a mappable not attached to one.
        plt.colorbar(sm, ax=plt.gca(), fraction=0.046)
        plt.xlim([0, raw.shape[1]])
        plt.ylim([raw.shape[0], 0])

        # NOTE(review): the factor 4 presumably maps label coordinates to
        # image pixels - confirm against the label generator.
        junc = meta["junc_coords"].cpu().numpy() * 4
        jtyp = meta["jtyp"].cpu().numpy()
        juncs = junc[jtyp == 0]

        lines = meta["lpre"].cpu().numpy() * 4
        lines = lines[meta["lpre_label"].cpu().numpy() == 1]

        if len(lines) > 0 and not (lines[0] == 0).all():
            for i, (a, b) in enumerate(lines):
                if i > 0 and (lines[i] == lines[0]).all():
                    break
                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)
        if len(juncs) > 0 and not (juncs[0] == 0).all():
            for i, j in enumerate(juncs):
                # Compare the junction itself (not its index) with the first.
                if i > 0 and (j == juncs[0]).all():
                    break
                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)

        img = Image.open(img_file).convert('RGB')
        canvas = (self.default_transform(img) * 255).to(torch.uint8)
        boxes = target_b["boxes"]
        if boxes.numel() > 0:
            canvas = draw_bounding_boxes(canvas, boxes, colors="yellow", width=1)
        plt.imshow(canvas.permute(1, 2, 0).numpy())

        # Save before show(): afterwards the figure may already be gone.
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

    def show_img(self, img_path):
        pass
-
-
def collate(batch):
    """Collate ``(image, meta, target, target_b)`` samples into a batch.

    Positions 0 and 2 go through ``default_collate``; positions 1 and 3
    are returned as plain per-sample lists.
    """
    def column(position):
        return [sample[position] for sample in batch]

    images = default_collate(column(0))
    metas = column(1)
    targets = default_collate(column(2))
    extras = column(3)
    return (images, metas, targets, extras)
-
-
-# if __name__ == '__main__':
-#     path = r"D:\python\PycharmProjects\data"
-#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
-#     dataset.show(0)
-
-

+ 0 - 209
lcnn/metric.py

@@ -1,209 +0,0 @@
-import numpy as np
-import numpy.linalg as LA
-import matplotlib.pyplot as plt
-
-from lcnn.utils import argsort2d
-
-DX = [0, 0, 1, -1, 1, 1, -1, -1]
-DY = [1, -1, 0, 0, 1, -1, 1, -1]
-
-
-def ap(tp, fp):
-    """Area under the interpolated precision/recall curve.
-
-    ``tp`` is used directly as the recall axis and precision is
-    ``tp / max(tp + fp, 1e-9)``; both are padded with sentinel end points.
-    """
-    rec = np.concatenate(([0.0], tp, [1.0]))
-    prec = np.concatenate(([0.0], tp / np.maximum(tp + fp, 1e-9), [0.0]))
-
-    # Sweep from the last entry backwards so precision becomes non-increasing.
-    for k in reversed(range(1, prec.size)):
-        prec[k - 1] = max(prec[k - 1], prec[k])
-
-    # Only positions where recall changes contribute area.
-    steps = np.nonzero(rec[1:] != rec[:-1])[0]
-    return np.sum((rec[steps + 1] - rec[steps]) * prec[steps + 1])
-
-
-def APJ(vert_pred, vert_gt, max_distance, im_ids):
-    """Average precision of junction predictions at one distance threshold.
-
-    Args:
-        vert_pred: sequence of prediction rows; the last column is the
-            confidence and the preceding columns are the coordinates.
-        vert_gt: per-image sequence of ground-truth coordinate arrays. Images
-            may hold different numbers of junctions.
-        max_distance: a prediction is a true positive when its nearest
-            ground-truth junction is closer than this and not yet matched.
-        im_ids: for every prediction, the index into ``vert_gt`` of its image.
-
-    Returns:
-        The result of ``ap`` on the cumulative TP/FP curves, or 0 when there
-        are no predictions or no ground-truth junctions at all.
-    """
-    if len(vert_pred) == 0:
-        return 0
-
-    vert_pred = np.asarray(vert_pred)
-    # Keep ground truth as a list of per-image arrays: building one array from
-    # ragged per-image lists raises on current NumPy.
-    vert_gt = [np.asarray(gt) for gt in vert_gt]
-    im_ids = np.asarray(im_ids)
-
-    # Process predictions from most to least confident.
-    order = np.argsort(-vert_pred[:, -1])
-    vert_pred = vert_pred[order, :]
-    im_ids = im_ids[order]
-
-    n_gt = sum(len(gt) for gt in vert_gt)
-    if n_gt == 0:
-        # Nothing to recall; avoid dividing by zero below.
-        return 0
-
-    nd = len(im_ids)
-    # `np.float` was removed from NumPy; the builtin is the same dtype.
-    tp, fp = np.zeros(nd, dtype=float), np.zeros(nd, dtype=float)
-    hit = [[False] * len(gt) for gt in vert_gt]
-
-    for i in range(nd):
-        image = im_ids[i]
-        gt_juns = vert_gt[image]
-        pred_juns = vert_pred[i][:-1]
-        if len(gt_juns) == 0:
-            # As before: predictions on images without ground truth count
-            # neither as TP nor as FP.
-            continue
-        dists = np.linalg.norm(pred_juns[None, :] - gt_juns, axis=1)
-        choice = np.argmin(dists)
-        # Each ground-truth junction may be claimed only once.
-        if dists[choice] < max_distance and not hit[image][choice]:
-            tp[i] = 1
-            hit[image][choice] = True
-        else:
-            fp[i] = 1
-
-    tp = np.cumsum(tp) / n_gt
-    fp = np.cumsum(fp) / n_gt
-    return ap(tp, fp)
-
-
-def nms_j(heatmap, delta=1):
-    """Return a copy of ``heatmap`` with non-peak pixels damped by 0.6.
-
-    A pixel is damped when at least one of its 8 neighbours (``DX``/``DY``)
-    has a value greater than or equal to its own. The input is not modified.
-
-    Args:
-        heatmap: 2-D array of scores.
-        delta: accepted for interface compatibility; not used by this function.
-    """
-    heatmap = heatmap.copy()
-    # `np.bool` was removed from NumPy; the builtin `bool` is the same dtype.
-    disable = np.zeros_like(heatmap, dtype=bool)
-    n_rows, n_cols = heatmap.shape[0], heatmap.shape[1]
-    for x, y in argsort2d(heatmap):
-        for dx, dy in zip(DX, DY):
-            xp, yp = x + dx, y + dy
-            if not (0 <= xp < n_rows and 0 <= yp < n_cols):
-                continue
-            if heatmap[x, y] >= heatmap[xp, yp]:
-                disable[xp, yp] = True
-    # Damping happens only after the scan, so comparisons above always see
-    # the original scores.
-    heatmap[disable] *= 0.6
-    return heatmap
-
-
-def mAPJ(pred, truth, distances, im_ids):
-    """Mean of ``APJ`` over every threshold in ``distances``, as a percentage."""
-    total = 0
-    for max_dist in distances:
-        total = total + APJ(pred, truth, max_dist, im_ids)
-    return total / len(distances) * 100
-
-
-def post_jheatmap(heatmap, offset=None, delta=1):
-    """Turn a junction heatmap into rows of ``(coord0, coord1, confidence)``.
-
-    After ``nms_j``, the 1000 strongest pixels are taken, those with
-    confidence below 1e-2 are dropped, the optional per-pixel ``offset`` is
-    added and coordinates are shifted by 0.5. Returns an empty ``(0, 3)``
-    array when nothing passes the confidence cut.
-    """
-    suppressed = nms_j(heatmap, delta=delta)
-
-    # Cap the candidate set at the 1000 best pixels for efficiency.
-    coords = argsort2d(-suppressed)[:1000]
-    scores = (-np.sort(-suppressed.ravel()))[:1000]
-
-    kept = np.flatnonzero(scores >= 1e-2)
-    if kept.size == 0:
-        return np.zeros((0, 3))
-    scores = scores[kept]
-
-    if offset is not None:
-        coords = np.array([c + offset[:, c[0], c[1]] for c in coords])
-    coords = coords[kept] + 0.5
-    return np.column_stack((coords, scores))
-
-
-def vectorized_wireframe_2d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    """Match predicted vertices to ground truth and score depth and edges.
-
-    Args:
-        vert_pred: array whose last column is the vertex confidence and whose
-            other columns are coordinates.
-        dpth_pred: per-predicted-vertex depth.
-        edge_pred: iterable of ``(v0, v1, score)`` triples.
-        vert_gt: array of ground-truth vertex coordinates.
-        dpth_gt: per-ground-truth-vertex depth.
-        edge_gt: integer array of ground-truth vertex index pairs.
-        threshold: maximum vertex distance for a match.
-
-    Returns:
-        ``(edge_ap, (loss_SIL, loss_L1, loss_L2))``.
-
-    NOTE(review): ``v0``/``v1`` index both ``choice`` (which is in
-    confidence-sorted prediction order) and ``vert_gt``; confirm with the
-    caller which vertex numbering ``edge_pred`` actually uses.
-    """
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = vert_pred[sorted_confidence, :-1]
-    dpth_pred = dpth_pred[sorted_confidence]
-    sq_dist = (
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    # Rounding can push the expanded squared distance slightly below zero,
-    # which would turn exact matches into NaN under the square root.
-    d = np.sqrt(np.maximum(sq_dist, 0))
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-
-    # staging 2: compute depth metric: SIL/L2
-    loss_L1 = loss_L2 = 0
-    # `np.bool` / `np.float` were removed from NumPy; use the builtins.
-    hit = np.zeros_like(dpth_gt, bool)
-    # One slot per prediction (the original passed the depth values as a shape).
-    SIL = np.zeros(nd, dtype=float)
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            loss_L1 += abs(dpth_gt[choice[i]] - dpth_pred[i])
-            loss_L2 += (dpth_gt[choice[i]] - dpth_pred[i]) ** 2
-            a = np.maximum(-dpth_pred[i], 1e-10)
-            b = -dpth_gt[choice[i]]
-            SIL[i] = np.log(a) - np.log(b)
-        else:
-            # Unmatched predictions get a sentinel that never forms a GT edge.
-            choice[i] = -1
-
-    n = max(np.sum(hit), 1)
-    loss_L1 /= n
-    loss_L2 /= n
-    loss_SIL = np.sum(SIL ** 2) / n - np.sum(SIL) ** 2 / (n * n)
-
-    # staging 3: compute mAP for edge matching
-    edge_gt = np.asarray(edge_gt)
-    edgeset = {frozenset(int(v) for v in e) for e in edge_gt}
-    tp = np.zeros(len(edge_pred), dtype=float)
-    fp = np.zeros(len(edge_pred), dtype=float)
-    # Highest score first; `key` must be a callable, not an array.
-    ranked = sorted(edge_pred, key=lambda edge: -edge[2])
-    for i, (v0, v1, _score) in enumerate(ranked):
-        v0, v1 = int(v0), int(v1)
-        # Difference of two vertices is 1-D, so take the plain vector norm.
-        length = LA.norm(vert_gt[v0] - vert_gt[v1])
-        if frozenset([int(choice[v0]), int(choice[v1])]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-    return ap(tp / total_length, fp / total_length), (loss_SIL, loss_L1, loss_L2)
-
-
-def vectorized_wireframe_3d_metric(
-    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
-):
-    """Match vertices in (coordinates + depth) space and score edges.
-
-    The last column of ``vert_pred`` is the confidence. For both prediction
-    and ground truth the last column is dropped and the depth is appended
-    before distances are computed. ``edge_pred`` holds ``(v0, v1, score)``
-    triples and ``edge_gt`` is an integer array of vertex index pairs.
-
-    Returns:
-        The result of ``ap`` on length-weighted TP/FP values.
-
-    NOTE(review): ``v0``/``v1`` index both ``choice`` (confidence-sorted
-    prediction order) and ``vert_gt``; confirm the intended numbering.
-    """
-    # staging 1: matching
-    nd = len(vert_pred)
-    sorted_confidence = np.argsort(-vert_pred[:, -1])
-    vert_pred = np.hstack([vert_pred[:, :-1], dpth_pred[:, None]])[sorted_confidence]
-    vert_gt = np.hstack([vert_gt[:, :-1], dpth_gt[:, None]])
-    sq_dist = (
-        np.sum(vert_pred ** 2, 1)[:, None]
-        + np.sum(vert_gt ** 2, 1)[None, :]
-        - 2 * vert_pred @ vert_gt.T
-    )
-    # Clamp tiny negative rounding errors so exact matches do not become NaN.
-    d = np.sqrt(np.maximum(sq_dist, 0))
-    choice = np.argmin(d, 1)
-    dist = np.min(d, 1)
-    # `np.bool` was removed from NumPy; use the builtin.
-    hit = np.zeros_like(dpth_gt, bool)
-    for i in range(nd):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-        else:
-            # Unmatched predictions get a sentinel that never forms a GT edge.
-            choice[i] = -1
-
-    # staging 2: compute mAP for edge matching
-    edge_gt = np.asarray(edge_gt)
-    edgeset = {frozenset(int(v) for v in e) for e in edge_gt}
-    tp = np.zeros(len(edge_pred), dtype=float)
-    fp = np.zeros(len(edge_pred), dtype=float)
-    # Highest score first; `key` must be a callable, not an array.
-    ranked = sorted(edge_pred, key=lambda edge: -edge[2])
-    for i, (v0, v1, _score) in enumerate(ranked):
-        v0, v1 = int(v0), int(v1)
-        # Difference of two vertices is 1-D, so take the plain vector norm.
-        length = LA.norm(vert_gt[v0] - vert_gt[v1])
-        if frozenset([int(choice[v0]), int(choice[v1])]) in edgeset:
-            tp[i] = length
-        else:
-            fp[i] = length
-    total_length = LA.norm(
-        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
-    ).sum()
-
-    return ap(tp / total_length, fp / total_length)
-
-
-def msTPFP(line_pred, line_gt, threshold):
-    """Flag each predicted line as true or false positive.
-
-    The distance between two lines is the sum of squared endpoint distances,
-    minimised over the two possible endpoint pairings. A prediction is a true
-    positive when its nearest ground-truth line is closer than ``threshold``
-    and has not been claimed by an earlier prediction.
-
-    Args:
-        line_pred: array of predicted lines, indexed as
-            ``[line, endpoint, coordinate]``; assumed to be ordered by
-            decreasing confidence — TODO confirm with callers.
-        line_gt: array of ground-truth lines with the same layout.
-        threshold: maximum squared-distance sum for a match.
-
-    Returns:
-        ``(tp, fp)``: two float arrays of 0/1 flags, one entry per prediction.
-    """
-    n_pred = len(line_pred)
-    # `np.bool` / `np.float` were removed from NumPy; use the builtins.
-    tp = np.zeros(n_pred, dtype=float)
-    fp = np.zeros(n_pred, dtype=float)
-    if n_pred == 0:
-        return tp, fp
-    if len(line_gt) == 0:
-        # No ground truth: argmin over an empty axis would raise, and every
-        # prediction is necessarily a false positive.
-        fp[:] = 1
-        return tp, fp
-
-    # diff[p, g, a, b]: squared distance between endpoint a of prediction p
-    # and endpoint b of ground-truth line g.
-    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
-    diff = np.minimum(
-        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
-    )
-    choice = np.argmin(diff, 1)
-    dist = np.min(diff, 1)
-    hit = np.zeros(len(line_gt), dtype=bool)
-    for i in range(n_pred):
-        if dist[i] < threshold and not hit[choice[i]]:
-            hit[choice[i]] = True
-            tp[i] = 1
-        else:
-            fp[i] = 1
-    return tp, fp
-
-
-def msAP(line_pred, line_gt, threshold):
-    """Average precision of ``line_pred`` against ``line_gt`` at ``threshold``."""
-    hits, misses = msTPFP(line_pred, line_gt, threshold)
-    n_gt = len(line_gt)
-    # Cumulative counts normalised by the number of ground-truth lines.
-    return ap(np.cumsum(hits) / n_gt, np.cumsum(misses) / n_gt)

+ 0 - 9
lcnn/models/__init__.py

@@ -1,9 +0,0 @@
-# flake8: noqa
-from .hourglass_pose import hg
-from .unet import unet
-from .resnet50_pose import resnet50
-from .resnet50 import resnet501
-from .fasterrcnn_resnet50 import fasterrcnn_resnet50
-# from .dla import dla169, dla102x, dla102x2
-
-__all__ = ['hg', 'unet', 'resnet50', 'resnet501', 'fasterrcnn_resnet50']

+ 0 - 48
lcnn/models/base/base_dataset.py

@@ -1,48 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-from torch import nn, Tensor
-from torch.utils.data import Dataset
-from torch.utils.data.dataset import T_co
-
-from torchvision.transforms import  functional as F
-
-class BaseDataset(Dataset, ABC):
-    """Abstract dataset base that provides a default image transform.
-
-    Subclasses must implement ``read_target``, ``show`` and ``show_img``.
-    """
-
-    def __init__(self, dataset_path):
-        # `dataset_path` is accepted but not stored by the base class.
-        self.default_transform = DefaultTransform()
-
-    def __getitem__(self, index) -> T_co:
-        """Return the sample at ``index``; the base implementation returns None."""
-        pass
-
-    @abstractmethod
-    def read_target(self, item, lbl_path, extra=None):
-        """Read the target for ``item`` from ``lbl_path`` (subclass hook)."""
-        pass
-
-    @abstractmethod
-    def show(self, idx):
-        """Display the dataset image at index ``idx``."""
-        pass
-
-    @abstractmethod
-    def show_img(self, img_path):
-        """Display the dataset image with the given name/path."""
-        pass
-
-class DefaultTransform(nn.Module):
-    """Convert a PIL image or tensor into a float image tensor."""
-
-    def forward(self, img: Tensor) -> Tensor:
-        # Non-tensor inputs are converted with pil_to_tensor first.
-        as_tensor = img if isinstance(img, Tensor) else F.pil_to_tensor(img)
-        return F.convert_image_dtype(as_tensor, torch.float)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-    def describe(self) -> str:
-        accepted = "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
-        rescaled = "The images are rescaled to ``[0.0, 1.0]``."
-        return accepted + rescaled

+ 0 - 120
lcnn/models/fasterrcnn_resnet50.py

@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
-
-def get_model(num_classes):
-    """Build torchvision's Faster R-CNN (ResNet-50 FPN) with a new box predictor.
-
-    The detector is created with ``pretrained=True`` and its box predictor is
-    replaced by a ``FastRCNNPredictor`` producing ``num_classes`` outputs.
-    """
-    # Load the detector with its pretrained ResNet-50 FPN backbone.
-    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-
-    # Reuse the existing predictor's input width and swap in a head sized for
-    # the new number of classes.
-    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
-    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
-    return detector
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Classification and box-regression losses of the Faster R-CNN box head.
-
-    Args:
-        class_logits (Tensor): per-proposal class scores, 2-D.
-        box_regression (Tensor): per-proposal, per-class box deltas.
-        labels (list[Tensor]): per-image proposal labels.
-        regression_targets (list[Tensor]): per-image regression targets.
-
-    Returns:
-        classification_loss (Tensor): cross entropy over all proposals.
-        box_loss (Tensor): summed smooth-L1 over positive proposals, divided
-            by the total number of proposals.
-    """
-    all_labels = torch.cat(labels, dim=0)
-    all_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, all_labels)
-
-    # Only proposals with a foreground label (> 0) contribute to the box
-    # loss, and only the deltas predicted for their own class are used.
-    positive_idx = torch.nonzero(all_labels > 0, as_tuple=True)[0]
-    positive_labels = all_labels[positive_idx]
-    n_proposals, _ = class_logits.shape
-    per_class_deltas = box_regression.reshape(
-        n_proposals, box_regression.size(-1) // 4, 4
-    )
-
-    summed_box_loss = F.smooth_l1_loss(
-        per_class_deltas[positive_idx, positive_labels],
-        all_targets[positive_idx],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = summed_box_loss / all_labels.numel()
-
-    return classification_loss, box_loss
-
-
-class Fasterrcnn_resnet50(nn.Module):
-    """Faster R-CNN detector with extra per-pixel score heads on FPN level '0'.
-
-    Args:
-        num_classes: number of detector classes and output channels of each
-            score head.
-        num_stacks: number of parallel score heads.
-    """
-
-    def __init__(self, num_classes=5, num_stacks=1):
-        super(Fasterrcnn_resnet50, self).__init__()
-
-        # Honour the requested class count (it was hard-coded to 5 before,
-        # silently ignoring the argument).
-        self.model = get_model(num_classes=num_classes)
-        self.backbone = self.model.backbone
-
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        # Multi-task output layers: one small conv head per stack.
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-        # Built once instead of on every forward call. It holds no parameters
-        # or buffers, so the state dict is unchanged.
-        self.transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
-                                                  image_std=[0.229, 0.224, 0.225])
-
-    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
-        """Run the score heads and the detector on ``x``.
-
-        Args:
-            x: images, passed both to the transform and to the detector.
-            target1: detection targets forwarded to the transform and detector.
-            train_or_val: mode string; kept for interface compatibility. What
-                the detector returns is decided by the detector itself.
-            image_shapes: unused, kept for interface compatibility.
-
-        Returns:
-            ``(outputs, feature, detector_result)`` where ``outputs`` is the
-            list of score-head outputs, ``feature`` is backbone level ``'0'``
-            and ``detector_result`` is whatever ``self.model`` returns.
-        """
-        images, _ = self.transform(x, target1)
-        pyramid = self.backbone(images.tensors)  # keys: '0' '1' '2' '3' 'pool'
-        feature_ = pyramid['0']  # image feature used by the score heads
-
-        outputs = [score_layer(feature_) for score_layer in self.score_layers]
-
-        # Both branches of the former training/validation switch made this
-        # same call, so it is issued once.
-        detector_result = self.model(x, target1)
-        return outputs, feature_, detector_result
-
-
-def fasterrcnn_resnet50(**kwargs):
-    """Factory for ``Fasterrcnn_resnet50``.
-
-    Reads ``num_classes`` (default 5) and ``num_stacks`` (default 1) from
-    ``kwargs``; any other keyword is ignored.
-    """
-    num_classes = kwargs.get("num_classes", 5)
-    num_stacks = kwargs.get("num_stacks", 1)
-    return Fasterrcnn_resnet50(num_classes=num_classes, num_stacks=num_stacks)

+ 0 - 201
lcnn/models/hourglass_pose.py

@@ -1,201 +0,0 @@
-"""
-Hourglass network inserted in the pre-activated Resnet
-Use lr=0.01 for current version
-(c) Yichao Zhou (LCNN)
-(c) YANG, Wei
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["HourglassNet", "hg"]
-
-
-class Bottleneck2D(nn.Module):
-    """Residual bottleneck applying BN -> ReLU -> conv three times.
-
-    The block outputs ``planes * 2`` channels; ``downsample`` (if given) is
-    applied to the input before it is added back.
-    """
-
-    expansion = 2  # expansion factor
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super().__init__()
-
-        # 1x1 reduce, 3x3 (optionally strided), 1x1 expand.
-        self.bn1 = nn.BatchNorm2d(inplanes)
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
-        self.bn3 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        stages = (
-            (self.bn1, self.conv1),
-            (self.bn2, self.conv2),
-            (self.bn3, self.conv3),
-        )
-        out = x
-        for norm, conv in stages:
-            out = conv(self.relu(norm(out)))
-
-        shortcut = x if self.downsample is None else self.downsample(x)
-        out += shortcut
-        return out
-
-
-class Hourglass(nn.Module):
-    """Recursive hourglass: pool, process, upsample and add a skip branch."""
-
-    def __init__(self, block, num_blocks, planes, depth):
-        super().__init__()
-        self.depth = depth
-        self.block = block
-        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
-
-    def _make_residual(self, block, num_blocks, planes):
-        """Chain ``num_blocks`` blocks that keep ``planes * expansion`` channels."""
-        width = planes * block.expansion
-        return nn.Sequential(*[block(width, planes) for _ in range(num_blocks)])
-
-    def _make_hour_glass(self, block, num_blocks, planes, depth):
-        """Create three branches per level, plus a fourth at level 0."""
-        levels = []
-        for level in range(depth):
-            n_branches = 4 if level == 0 else 3
-            branches = [
-                self._make_residual(block, num_blocks, planes)
-                for _ in range(n_branches)
-            ]
-            levels.append(nn.ModuleList(branches))
-        return nn.ModuleList(levels)
-
-    def _hour_glass_forward(self, n, x):
-        branches = self.hg[n - 1]
-        skip = branches[0](x)
-        pooled = branches[1](F.max_pool2d(x, 2, stride=2))
-
-        # Recurse until the innermost level, which uses the extra branch.
-        if n > 1:
-            inner = self._hour_glass_forward(n - 1, pooled)
-        else:
-            inner = branches[3](pooled)
-
-        upsampled = F.interpolate(branches[2](inner), scale_factor=2)
-        return skip + upsampled
-
-    def forward(self, x):
-        return self._hour_glass_forward(self.depth, x)
-
-
-class HourglassNet(nn.Module):
-    """Hourglass model from Newell et al ECCV 2016"""
-
-    # Layout: a stem (7x7 stride-2 conv, BN, ReLU, three residual stages with
-    # one 2x2 max-pool) followed by `num_stacks` hourglass modules, each with
-    # its own residual stage, fc layer and score head built by `head`.
-    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
-        super(HourglassNet, self).__init__()
-
-        # `self.inplanes` tracks the current channel count and is updated by
-        # every `_make_residual` call below.
-        self.inplanes = 64
-        self.num_feats = 128
-        self.num_stacks = num_stacks
-        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
-        self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_residual(block, self.inplanes, 1)
-        self.layer2 = self._make_residual(block, self.inplanes, 1)
-        self.layer3 = self._make_residual(block, self.num_feats, 1)
-        self.maxpool = nn.MaxPool2d(2, stride=2)
-
-        # build hourglass modules
-        # `ch` is the channel width used by every stack.
-        ch = self.num_feats * block.expansion
-        # vpts = []
-        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
-        for i in range(num_stacks):
-            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
-            res.append(self._make_residual(block, self.num_feats, num_blocks))
-            fc.append(self._make_fc(ch, ch))
-            score.append(head(ch, num_classes))
-            # vpts.append(VptsHead(ch))
-            # vpts.append(nn.Linear(ch, 9))
-            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
-            # score[i].bias.data[0] += 4.6
-            # score[i].bias.data[2] += 4.6
-            # Every stack except the last gets 1x1 convs that map its features
-            # and its scores back to `ch` channels for the next stack's input.
-            if i < num_stacks - 1:
-                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
-                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
-        self.hg = nn.ModuleList(hg)
-        self.res = nn.ModuleList(res)
-        self.fc = nn.ModuleList(fc)
-        self.score = nn.ModuleList(score)
-        # self.vpts = nn.ModuleList(vpts)
-        self.fc_ = nn.ModuleList(fc_)
-        self.score_ = nn.ModuleList(score_)
-
-    # Builds `blocks` chained blocks. A 1x1 conv projects the shortcut when the
-    # stride or the channel count changes. Side effect: updates `self.inplanes`.
-    def _make_residual(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(
-                    self.inplanes,
-                    planes * block.expansion,
-                    kernel_size=1,
-                    stride=stride,
-                )
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    # 1x1 conv -> BN -> ReLU. The BN is sized with `inplanes` but applied to the
-    # conv output, so this relies on inplanes == outplanes (it is called with
-    # (ch, ch) above). The ReLU module is the shared `self.relu`.
-    def _make_fc(self, inplanes, outplanes):
-        bn = nn.BatchNorm2d(inplanes)
-        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
-        return nn.Sequential(conv, bn, self.relu)
-
-    # Returns (scores, feature): the per-stack score maps with the LAST stack
-    # first, and the fc feature `y` of the last stack.
-    def forward(self, x):
-        out = []
-        # out_vps = []
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        x = self.layer1(x)
-        x = self.maxpool(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-
-        for i in range(self.num_stacks):
-            y = self.hg[i](x)
-            y = self.res[i](y)
-            y = self.fc[i](y)
-            score = self.score[i](y)
-            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
-            # pre_vpts = pre_vpts.reshape(-1, 256)
-            # vpts = self.vpts[i](x)
-            out.append(score)
-            # out_vps.append(vpts)
-            # Feed features and scores of this stack into the next one.
-            if i < self.num_stacks - 1:
-                fc_ = self.fc_[i](y)
-                score_ = self.score_[i](score)
-                x = x + fc_ + score_
-
-        return out[::-1], y  # , out_vps[::-1]
-
-
-def hg(**kwargs):
-    """Build a ``HourglassNet`` made of ``Bottleneck2D`` blocks.
-
-    Keyword Args:
-        depth, num_stacks, num_blocks, num_classes: required; forwarded to
-            ``HourglassNet``. A missing one raises ``KeyError``.
-        head: optional callable ``(c_in, c_out) -> nn.Module`` that builds
-            each stack's score head. Defaults to a 1x1 convolution.
-    """
-
-    def default_head(c_in, c_out):
-        # `nn.Conv2d` is the real class name; the former `nn.Conv2D` raised
-        # AttributeError as soon as the default head was used.
-        return nn.Conv2d(c_in, c_out, 1)
-
-    model = HourglassNet(
-        Bottleneck2D,
-        head=kwargs.get("head", default_head),
-        depth=kwargs["depth"],
-        num_stacks=kwargs["num_stacks"],
-        num_blocks=kwargs["num_blocks"],
-        num_classes=kwargs["num_classes"],
-    )
-    return model

+ 0 - 276
lcnn/models/line_vectorizer.py

@@ -1,276 +0,0 @@
-import itertools
-import random
-from collections import defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-FEATURE_DIM = 8
-
-
-# Wraps a backbone: samples candidate lines between junctions, pools the
-# backbone feature along each line and scores every candidate with `fc2`.
-class LineVectorizer(nn.Module):
-    def __init__(self, backbone):
-        super().__init__()
-        self.backbone = backbone
-
-        # n_pts0 evenly spaced interpolation weights in [0, 1], as a column.
-        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        # Static (precomputed) line samples are used only if any are configured.
-        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
-
-        # 1x1 conv that maps the 256-channel feature to dim_loi channels.
-        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
-        # Pooling reduces the n_pts0 samples along a line to n_pts1.
-        scale_factor = M.n_pts0 // M.n_pts1
-        if M.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(M.dim_loi, M.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, M.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(M.dim_fc, 1),
-            )
-        # Element-wise loss; it is reduced per image in forward().
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    # Reads "meta" and "mode" from input_dict; the backbone result must provide
-    # "preds" (with "jmap"/"joff"), "feature", and, depending on the mode,
-    # "losses" and "aaa". Returns the (mutated) backbone result dict.
-    def forward(self, input_dict):
-        result = self.backbone(input_dict)
-        h = result["preds"]
-        x = self.fc1(result["feature"])
-        n_batch, n_channel, row, col = x.shape
-
-        # idx holds cumulative line counts so per-image slices can be recovered.
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-        for i, meta in enumerate(input_dict["meta"]):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
-            )
-            # print("p.shape:", p.shape)
-            ys.append(label)
-            if input_dict["mode"] == "training" and self.do_static_sampling:
-                # Append the precomputed lines from meta; junctions are not
-                # needed in this case.
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            # Interpolate n_pts0 points between the two endpoints of each line.
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            # Indices are clamped to [0, 127] — presumably the feature map is
-            # 128x128; TODO confirm (row/col from x.shape are not used here).
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # Bilinear interpolation of the feature at every sampled point.
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                    .reshape(n_channel, -1, M.n_pts0)
-                    .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-
-        # Flatten the pooled line features, append the hand-made line features
-        # and score every candidate; x becomes one logit per line.
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, M.n_pts1 * M.dim_loi)
-        x = torch.cat([x, f], 1)
-        x = self.fc2(x.float()).flatten()
-
-        if input_dict["mode"] != "training":
-            p = torch.cat(ps)
-            s = torch.sigmoid(x)
-            b = s > 0.5
-            lines = []
-            score = []
-            for i in range(n_batch):
-                p0 = p[idx[i]: idx[i + 1]]
-                s0 = s[idx[i]: idx[i + 1]]
-                mask = b[idx[i]: idx[i + 1]]
-                p0 = p0[mask]
-                s0 = s0[mask]
-                if len(p0) == 0:
-                    lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
-                    score.append(torch.zeros([1, M.n_out_line], device=p.device))
-                else:
-                    # Sort by score and repeat cyclically (modulo indexing) so
-                    # every image yields exactly n_out_line lines.
-                    arg = torch.argsort(s0, descending=True)
-                    p0, s0 = p0[arg], s0[arg]
-                    lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
-                    score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
-                # Junctions are padded the same way to n_out_junc per type.
-                for j in range(len(jcs[i])):
-                    if len(jcs[i][j]) == 0:
-                        jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
-                    jcs[i][j] = jcs[i][j][
-                        None, torch.arange(M.n_out_junc) % len(jcs[i][j])
-                    ]
-            result["preds"]["lines"] = torch.cat(lines)
-            result["preds"]["score"] = torch.cat(score)
-            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-            # Outside training the backbone's "aaa" entry is re-exposed as "box".
-            result["box"] = result['aaa']
-            del result['aaa']
-            # NOTE: `i` here is the value left over from the loop above (the
-            # last image); the number of junction types is the same per image.
-            if len(jcs[i]) > 1:
-                result["preds"]["junts"] = torch.cat(
-                    [jcs[i][1] for i in range(n_batch)]
-                )
-
-        if input_dict["mode"] != "testing":
-            y = torch.cat(ys)
-            loss = self.loss(x, y)
-            lpos_mask, lneg_mask = y, 1 - y
-            loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-            # Sum a per-line tensor separately for every image in the batch.
-            def sum_batch(x):
-                xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
-                return torch.cat(xs)
-
-            # Mean loss over positive / negative lines of each image.
-            lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-            lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-            result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
-            result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
-
-        if input_dict["mode"] == "training":
-            # In training every entry of "aaa" is copied into the first loss dict.
-            for i in result["aaa"].keys():
-                result["losses"][0][i] = result["aaa"][i]
-            del result["preds"]
-
-        return result
-
-    # Proposes candidate lines between detected junctions for one image.
-    # Returns (line endpoints, float labels, per-line features, junctions per type).
-    def sample_lines(self, meta, jmap, joff, mode):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = M.n_dyn_junc // n_type
-            N = len(junc)
-            # K junction candidates per type: thresholded count at eval time,
-            # 2N + 2 in training, capped by max_K and never below 2.
-            if mode != "training":
-                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            # Decode the flat index with a hard-coded row width of 128, then
-            # add the predicted offset and 0.5.
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            # Candidates whose nearest junction has another type, or is farther
-            # than 1.5 (compared on squared distance), get index N. Presumably
-            # Lpos/Lneg carry an extra row/column at N for "unmatched" — verify.
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            # All ordered pairs (u, v) of candidate junctions.
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            if mode == "training":
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > M.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > M.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                # Outside training keep every unordered pair exactly once.
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            # Unit direction from v to u (norm clamped to avoid division by zero).
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            # 2 + 2 + 2 + 1 + 1 = 8 columns, matching FEATURE_DIM.
-            feat = torch.cat(
-                [
-                    xyu / 128 * M.use_cood,
-                    xyv / 128 * M.use_cood,
-                    u2v * M.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            # Per type, keep the candidate junctions whose score exceeds 0.03.
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-
-
-def non_maximum_suppression(a):
-    """Zero every element that is not the maximum of its 3x3 neighbourhood."""
-    local_max = F.max_pool2d(a, kernel_size=3, stride=1, padding=1)
-    is_peak = a.eq(local_max).float().clamp(min=0.0)
-    return a * is_peak
-
-
-class Bottleneck1D(nn.Module):
-    """1-D residual bottleneck: ``x + op(x)`` with a half-width middle conv."""
-
-    def __init__(self, inplanes, outplanes):
-        super().__init__()
-
-        mid = outplanes // 2
-        # BN -> ReLU -> conv, three times: 1x1 reduce, 3-wide, 1x1 expand.
-        stages = [
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, mid, kernel_size=1),
-            nn.BatchNorm1d(mid),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(mid, mid, kernel_size=3, padding=1),
-            nn.BatchNorm1d(mid),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(mid, outplanes, kernel_size=1),
-        ]
-        self.op = nn.Sequential(*stages)
-
-    def forward(self, x):
-        transformed = self.op(x)
-        return x + transformed

+ 0 - 118
lcnn/models/multitask_learner.py

@@ -1,118 +0,0 @@
-from collections import OrderedDict, defaultdict
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lcnn.config import M
-
-
class MultitaskHead(nn.Module):
    """Parallel convolutional heads whose outputs are concatenated on dim 1.

    One head is created per entry of the flattened ``M.head_size``. Each head
    is a 3x3 conv to ``int(input_channels / 4)`` channels, a ReLU, and a 1x1
    conv to that entry's channel count.
    """

    def __init__(self, input_channels, num_class):
        super().__init__()
        hidden = int(input_channels / 4)
        head_channels = sum(M.head_size, [])

        def build_head(out_channels):
            return nn.Sequential(
                nn.Conv2d(input_channels, hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, out_channels, kernel_size=1),
            )

        self.heads = nn.ModuleList([build_head(width) for width in head_channels])
        # The caller's channel count must match the total configured in M.head_size.
        assert num_class == sum(sum(M.head_size, []))

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results along dim 1."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=1)
-
-
class MultitaskLearner(nn.Module):
    """Wrap a backbone and compute junction-map, line-map and offset losses.

    The channel layout of the backbone output is taken from ``M.head_size``:
    ``head_off`` holds the cumulative channel counts used to slice the output
    into the ``jmap``, ``lmap`` and ``joff`` parts.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        head_size = M.head_size
        self.num_class = sum(sum(head_size, []))
        self.head_off = np.cumsum([sum(h) for h in head_size])

    def forward(self, input_dict):
        """Run the backbone and return predictions plus per-stack losses.

        Reads ``image``, ``target_b``, ``mode`` and ``target`` from
        ``input_dict``. Returns a dict with ``feature`` and ``preds``; when
        ``mode`` is ``"testing"`` it returns right after the first stack,
        otherwise it also carries ``losses`` (one ordered dict per stack) and
        ``aaa`` (the third value returned by the backbone; per the original
        author's note this is a loss during training and boxes during
        validation).
        """
        image = input_dict["image"]
        target_b = input_dict["target_b"]
        mode = input_dict["mode"]
        outputs, feature, aaa = self.backbone(image, target_b, mode)

        result = {"feature": feature}
        batch, _, row, col = outputs[0].shape

        # Shallow copy so re-assigning keys below does not touch the caller's dict.
        T = input_dict["target"].copy()
        n_jtyp = T["junc_map"].shape[1]

        # Move the batch axis behind the type/component axes so targets line up
        # with the channel-first outputs sliced below.
        T["junc_map"] = T["junc_map"].permute(1, 0, 2, 3)
        T["junc_offset"] = T["junc_offset"].permute(1, 2, 0, 3, 4)

        offset = self.head_off
        loss_weight = M.loss_weight
        losses = []
        for stack, output in enumerate(outputs):
            # channel-first view: (channels, batch, row, col)
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)

            if stack == 0:
                # Predictions are only taken from the first stack.
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                if mode == "testing":
                    return result

            L = OrderedDict()
            L["jmap"] = sum(
                cross_entropy_loss(jmap[t], T["junc_map"][t]) for t in range(n_jtyp)
            )
            lmap_bce = F.binary_cross_entropy_with_logits(
                lmap, T["line_map"], reduction="none"
            )
            L["lmap"] = lmap_bce.mean(2).mean(1)
            L["joff"] = sum(
                sigmoid_l1_loss(joff[t, k], T["junc_offset"][t, k], -0.5, T["junc_map"][t])
                for t in range(n_jtyp)
                for k in range(2)
            )
            # Scale each loss in place by its configured weight.
            for loss_name, loss_value in L.items():
                loss_value.mul_(loss_weight[loss_name])
            losses.append(L)

        result["losses"] = losses
        result["aaa"] = aaa
        return result
-
-
def l2loss(input, target):
    """Return the squared error between ``target`` and ``input`` averaged over dims 2 and 1."""
    squared_error = (target - input).pow(2)
    return squared_error.mean(2).mean(1)
-
-
def cross_entropy_loss(logits, positive):
    """Two-class cross entropy with soft labels, averaged over dims 2 and 1.

    ``logits`` is soft-maxed over dim 0; index 1 is weighted by ``positive``
    and index 0 by ``1 - positive``.
    """
    neg_log_prob = -F.log_softmax(logits, dim=0)
    per_element = positive * neg_log_prob[1] + (1 - positive) * neg_log_prob[0]
    return per_element.mean(2).mean(1)
-
-
def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    """L1 distance between ``sigmoid(logits) + offset`` and ``target``.

    When ``mask`` is given, the per-element loss is multiplied by ``mask``
    divided by its mean over dims 2 and 1; a zero mean is replaced by 1 to
    avoid dividing by zero. The result is averaged over dims 2 and 1.
    """
    prediction = torch.sigmoid(logits) + offset
    loss = torch.abs(prediction - target)
    if mask is not None:
        mask_mean = mask.mean(2, True).mean(1, True)
        mask_mean = mask_mean.masked_fill(mask_mean == 0, 1)
        loss = loss * (mask / mask_mean)

    return loss.mean(2).mean(1)

+ 0 - 182
lcnn/models/resnet50.py

@@ -1,182 +0,0 @@
-# import torch
-# import torch.nn as nn
-# import torchvision.models as models
-#
-#
-# class ResNet50Backbone(nn.Module):
-#     def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
-#         super(ResNet50Backbone, self).__init__()
-#
-#         # 加载预训练的ResNet50
-#         if pretrained:
-#             resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
-#         else:
-#             resnet = models.resnet50(weights=None)
-#
-#         # 移除最后的全连接层
-#         self.backbone = nn.Sequential(
-#             resnet.conv1,#特征图分辨率降低为1/2,通道数从3升为64
-#             resnet.bn1,
-#             resnet.relu,
-#             resnet.maxpool,#特征图分辨率降低为1/4,通道数仍然为64
-#             resnet.layer1,#stride为1,不改变分辨率,依然为1/4,通道数从64升为64*4=256
-#             # resnet.layer2,#stride为2,特征图分辨率降低为1/8,通道数从256升为128*4=512
-#             # resnet.layer3,#stride为2,特征图分辨率降低为1/16,通道数从512升为256*4=1024
-#             # resnet.layer4,#stride为2,特征图分辨率降低为1/32,通道数从512升为256*4=2048
-#         )
-#
-#         # 多任务输出层
-#         self.score_layers = nn.ModuleList([
-#             nn.Sequential(
-#                 nn.Conv2d(256, 128, kernel_size=3, padding=1),
-#                 nn.BatchNorm2d(128),
-#                 nn.ReLU(inplace=True),
-#                 nn.Conv2d(128, num_classes, kernel_size=1)
-#             )
-#             for _ in range(num_stacks)
-#         ])
-#
-#         # 上采样层,确保输出大小为128x128
-#         self.upsample = nn.Upsample(
-#             scale_factor=0.25,
-#             mode='bilinear',
-#             align_corners=True
-#         )
-#
-#     def forward(self, x):
-#         # 主干网络特征提取
-#         x = self.backbone(x)
-#
-#         # # 调整通道数
-#         # x = self.channel_adjust(x)
-#         #
-#         # # 上采样到128x128
-#         # x = self.upsample(x)
-#
-#         # 多堆栈输出
-#         outputs = []
-#         for score_layer in self.score_layers:
-#             output = score_layer(x)
-#             outputs.append(output)
-#
-#         # 返回第一个输出(如果有多个堆栈)
-#         return outputs, x
-#
-#
-# def resnet50(**kwargs):
-#     model = ResNet50Backbone(
-#         num_classes=kwargs.get("num_classes", 5),
-#         num_stacks=kwargs.get("num_stacks", 1),
-#         pretrained=kwargs.get("pretrained", True)
-#     )
-#     return model
-#
-#
-# __all__ = ["ResNet50Backbone", "resnet50"]
-#
-#
-# # 测试网络输出
-# model = resnet50(num_classes=5, num_stacks=1)
-#
-#
-# # 方法1:直接传入图像张量
-# x = torch.randn(2, 3, 512, 512)
-# outputs, feature = model(x)
-# print("Outputs length:", len(outputs))
-# print("Output[0] shape:", outputs[0].shape)
-# print("Feature shape:", feature.shape)
-
-import torch
-import torch.nn as nn
-import torchvision
-import torchvision.models as models
-from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
-
-
class ResNet50Backbone1(nn.Module):
    """ResNet-50 feature extractor, wrapped in torchvision's FPN when pretrained.

    Args:
        num_classes: Accepted for signature compatibility; not used here.
        num_stacks: Accepted for signature compatibility; not used here.
        pretrained: When true, load the ImageNet weights and wrap the network
            with ``_resnet_fpn_extractor``. When false, use a plain,
            randomly initialised ``resnet50``.
    """

    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
        super().__init__()

        if pretrained:
            trainable_backbone_layers = _validate_trainable_layers(True, None, 5, 3)
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            self.resnet = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
        else:
            self.resnet = models.resnet50(weights=None)

    def forward(self, x):
        """Return the ``'0'`` feature map, or the raw output if it is not a mapping.

        The previous code returned ``x['0']``, which indexes the *input*
        tensor with a string and always raises; the value must come from the
        network output instead.
        """
        features = self.resnet(x)
        # NOTE(review): the FPN extractor is expected to return a dict keyed
        # '0', '1', ... - confirm against the installed torchvision version.
        if isinstance(features, dict):
            return features['0']
        # Non-pretrained branch: plain resnet50 returns a tensor, pass it through.
        return features
-
-
def resnet501(**kwargs):
    """Build a ``ResNet50Backbone1`` from keyword options.

    Recognised keys and their defaults: ``num_classes`` (5), ``num_stacks``
    (1) and ``pretrained`` (True). Other keys are ignored.
    """
    num_classes = kwargs.get("num_classes", 5)
    num_stacks = kwargs.get("num_stacks", 1)
    pretrained = kwargs.get("pretrained", True)
    return ResNet50Backbone1(
        num_classes=num_classes,
        num_stacks=num_stacks,
        pretrained=pretrained,
    )
-
-
-# __all__ = ["ResNet50Backbone1", "resnet501"]
-#
-# # 测试网络输出
-# model = resnet501(num_classes=5, num_stacks=1)
-#
-# # 方法1:直接传入图像张量
-# x = torch.randn(2, 3, 512, 512)
-# # outputs, feature = model(x)
-# # print("Outputs length:", len(outputs))
-# # print("Output[0] shape:", outputs[0].shape)
-# # print("Feature shape:", feature.shape)
-# feature = model(x)
-# print("Feature shape:", feature.keys())
-
-
-

+ 0 - 87
lcnn/models/resnet50_pose.py

@@ -1,87 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision.models as models
-
-
class ResNet50Backbone(nn.Module):
    """ResNet-50 stem plus ``layer1`` followed by per-stack score heads.

    Args:
        num_classes: Output channels of every score head.
        num_stacks: Number of score heads applied to the shared feature map.
        pretrained: When true, load the ImageNet (IMAGENET1K_V1) weights.
    """

    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
        super().__init__()

        # Load ResNet-50, optionally with pretrained weights.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Keep only the stem and layer1; layer2-4 and the classifier are dropped.
        # Per the original author's notes: conv1 halves the resolution (3 -> 64
        # channels), maxpool brings it to 1/4, and layer1 keeps 1/4 while
        # raising the channel count to 256.
        kept_stages = [
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
        ]
        self.backbone = nn.Sequential(*kept_stages)

        # One score head per stack, fed by the 256-channel feature map.
        score_heads = []
        for _ in range(num_stacks):
            score_heads.append(
                nn.Sequential(
                    nn.Conv2d(256, 128, kernel_size=3, padding=1),
                    nn.BatchNorm2d(128),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(128, num_classes, kernel_size=1),
                )
            )
        self.score_layers = nn.ModuleList(score_heads)

        # Resampling layer (scale factor 0.25); it is defined but not used in forward().
        self.upsample = nn.Upsample(
            scale_factor=0.25,
            mode='bilinear',
            align_corners=True
        )

    def forward(self, x):
        """Return ``(outputs, features)``: one score map per head and the shared features."""
        x = self.backbone(x)
        outputs = [score_layer(x) for score_layer in self.score_layers]
        return outputs, x
-
-
def resnet50(**kwargs):
    """Build a ``ResNet50Backbone`` from keyword options.

    Recognised keys and their defaults: ``num_classes`` (5), ``num_stacks``
    (1) and ``pretrained`` (True). Other keys are ignored.
    """
    num_classes = kwargs.get("num_classes", 5)
    num_stacks = kwargs.get("num_stacks", 1)
    pretrained = kwargs.get("pretrained", True)
    return ResNet50Backbone(
        num_classes=num_classes,
        num_stacks=num_stacks,
        pretrained=pretrained,
    )
-
-
-__all__ = ["ResNet50Backbone", "resnet50"]
-
-
# Smoke test of the network output.
# NOTE(review): this runs at import time, so importing the module builds the
# model and performs a forward pass.
model = resnet50(num_classes=5, num_stacks=1)

# Option 1: feed an image tensor directly.
x = torch.randn(2, 3, 512, 512)
outputs, feature = model(x)
print(f"Outputs length: {len(outputs)}")
print(f"Output[0] shape: {outputs[0].shape}")
print(f"Feature shape: {feature.shape}")

+ 0 - 126
lcnn/models/unet.py

@@ -1,126 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-__all__ = ["UNetWithMultipleStacks", "unet"]
-
-
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv2d, BatchNorm2d, ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)
-
-
class UNetWithMultipleStacks(nn.Module):
    """Five-level U-Net whose decoder output feeds ``num_stacks`` score heads.

    Args:
        num_classes: Output channels of every score head.
        num_stacks: Number of score heads.
        base_channels: Width of the first encoder level; each deeper level
            doubles it, up to ``base_channels * 32`` in the bottleneck.
    """

    _LEVELS = 5

    def __init__(self, num_classes, num_stacks=2, base_channels=64):
        super().__init__()
        self.num_stacks = num_stacks

        # Encoder: enc1..enc5, each followed by a 2x max-pool (pool1..pool5).
        in_ch = 3
        for level in range(1, self._LEVELS + 1):
            out_ch = base_channels * 2 ** (level - 1)
            setattr(self, f"enc{level}", DoubleConv(in_ch, out_ch))
            setattr(self, f"pool{level}", nn.MaxPool2d(2))
            in_ch = out_ch

        # Bottleneck
        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)

        # Decoder: upconv5/dec5 down to upconv1/dec1. Each dec block takes the
        # upsampled map concatenated with the encoder skip, hence 2 * out_ch.
        for level in range(self._LEVELS, 0, -1):
            out_ch = base_channels * 2 ** (level - 1)
            setattr(
                self,
                f"upconv{level}",
                nn.ConvTranspose2d(out_ch * 2, out_ch, kernel_size=2, stride=2),
            )
            setattr(self, f"dec{level}", DoubleConv(out_ch * 2, out_ch))

        # Extra resampling stage with scale factor 0.25 (the original author's
        # note says this takes 512 down to 128).
        self.final_upsample = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
        )

        # Score heads operate on the 256-channel map produced by channel_adjust.
        self.score_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, num_classes, kernel_size=1)
            )
            for _ in range(num_stacks)
        ])

        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)

    def forward(self, x):
        """Return ``(outputs, features)``.

        ``outputs`` is the list of score-head results in reverse head order;
        ``features`` is the 256-channel map the heads were applied to.
        """
        # Encoder pass, remembering every level's output for the skip links.
        skips = []
        h = x
        for level in range(1, self._LEVELS + 1):
            if level > 1:
                h = getattr(self, f"pool{level - 1}")(h)
            h = getattr(self, f"enc{level}")(h)
            skips.append(h)

        # Bottleneck
        h = self.bottleneck(self.pool5(h))

        # Decoder pass: upsample, concatenate the matching skip, convolve.
        for level in range(self._LEVELS, 0, -1):
            h = getattr(self, f"upconv{level}")(h)
            h = torch.cat([h, skips[level - 1]], dim=1)
            h = getattr(self, f"dec{level}")(h)

        # Resample, then adjust the channel count for the score heads.
        h = self.final_upsample(h)
        h = self.channel_adjust(h)

        # Multi-stack outputs
        outputs = [score_layer(h) for score_layer in self.score_layers]
        return list(reversed(outputs)), h
-
-
def unet(**kwargs):
    """Build a ``UNetWithMultipleStacks`` from keyword options.

    ``num_classes`` is required (a missing key raises ``KeyError``);
    ``num_stacks`` defaults to 2 and ``base_channels`` to 64.
    """
    num_classes = kwargs["num_classes"]
    num_stacks = kwargs.get("num_stacks", 2)
    base_channels = kwargs.get("base_channels", 64)
    return UNetWithMultipleStacks(
        num_classes=num_classes,
        num_stacks=num_stacks,
        base_channels=base_channels,
    )

+ 0 - 77
lcnn/postprocess.py

@@ -1,77 +0,0 @@
-import numpy as np
-
-
def pline(x1, y1, x2, y2, x, y):
    """Squared distance from ``(x, y)`` to the infinite line through the two points.

    The squared segment length used as the denominator is floored at 1e-9, so
    coincident end points do not divide by zero.
    """
    dir_x, dir_y = x2 - x1, y2 - y1
    sq_len = dir_x * dir_x + dir_y * dir_y
    # Parameter of the orthogonal projection of (x, y) onto the line.
    t = ((x - x1) * dir_x + (y - y1) * dir_y) / max(1e-9, float(sq_len))
    off_x = x1 + t * dir_x - x
    off_y = y1 + t * dir_y - y
    return off_x * off_x + off_y * off_y
-
-
def psegment(x1, y1, x2, y2, x, y):
    """Squared distance from ``(x, y)`` to the segment ``(x1, y1)-(x2, y2)``.

    The projection parameter is clamped to ``[0, 1]`` so points beyond either
    end are measured against that end point. The squared segment length used
    as the denominator is floored at 1e-9, matching ``pline`` and ``plambda``;
    without the floor a degenerate segment (both end points equal) divided by
    zero.
    """
    px = x2 - x1
    py = y2 - y1
    dd = px * px + py * py
    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
    u = max(min(u, 1), 0)  # clamp onto the segment
    dx = x1 + u * px - x
    dy = y1 + u * py - y
    return dx * dx + dy * dy
-
-
def plambda(x1, y1, x2, y2, x, y):
    """Projection parameter of ``(x, y)`` along the line from point 1 to point 2.

    0 corresponds to ``(x1, y1)`` and 1 to ``(x2, y2)``. The squared length in
    the denominator is floored at 1e-9.
    """
    dir_x, dir_y = x2 - x1, y2 - y1
    sq_len = dir_x * dir_x + dir_y * dir_y
    dot = (x - x1) * dir_x + (y - y1) * dir_y
    return dot / max(1e-9, float(sq_len))
-
-
def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
    """Drop or trim lines that overlap lines already accepted earlier in the list.

    Lines are visited in the given order. A candidate ``(p, q)`` is compared
    with every accepted line ``(a, b)``: if the two are within ``threshold``
    of being collinear, the projection of ``(a, b)`` onto the candidate
    (widened by ``tol`` on both sides) is removed from the candidate's
    ``[0, 1]`` parameter range. Candidates with nothing left are discarded.

    Args:
        lines: Iterable of ``(p, q)`` end-point pairs (array-like points).
        scores: One score per line, kept alongside the surviving lines.
        threshold: Distance tolerance; compared squared against squared
            point-to-line distances.
        tol: Amount added to both ends of an accepted line's parameter range.
        do_clip: Not used by this function.

    Returns:
        ``(np.array(kept_lines), np.array(kept_scores))``.
    """
    nlines, nscores = [], []
    sq_threshold = threshold ** 2
    for (p, q), score in zip(lines, scores):
        start, end = 0, 1
        for a, b in nlines:
            # Worst end-point distance, measured in both directions; the pair
            # counts as overlapping if either direction is within threshold.
            dist_kept_to_new = max(pline(*p, *q, *a), pline(*p, *q, *b))
            dist_new_to_kept = max(pline(*a, *b, *p), pline(*a, *b, *q))
            if min(dist_kept_to_new, dist_new_to_kept) > sq_threshold:
                continue

            lambda_a = plambda(*p, *q, *a)
            lambda_b = plambda(*p, *q, *b)
            lo, hi = (lambda_b, lambda_a) if lambda_a > lambda_b else (lambda_a, lambda_b)
            lo -= tol
            hi += tol

            # case 1: kept line lies strictly inside the candidate -> skip
            # (if not do_clip)
            if start < lo and hi < end:
                continue

            # no intersection
            if hi < start or lo > end:
                continue

            # fully covered: mark the candidate as empty (start > end) and stop
            if lo <= start and end <= hi:
                start = 10
                break

            # case 2 & 3: trim the covered end of the candidate
            if lo <= start <= hi:
                start = hi
            if lo <= end <= hi:
                end = lo

            if start >= end:
                break

        if start >= end:
            continue
        direction = q - p
        nlines.append(np.array([p + direction * start, p + direction * end]))
        nscores.append(score)
    return np.array(nlines), np.array(nscores)

+ 0 - 429
lcnn/trainer.py

@@ -1,429 +0,0 @@
-import atexit
-import os
-import os.path as osp
-import shutil
-import signal
-import subprocess
-import threading
-import time
-from timeit import default_timer as timer
-
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import torch.nn.functional as F
-from skimage import io
-from tensorboardX import SummaryWriter
-
-from lcnn.config import C, M
-from lcnn.utils import recursive_to
-import matplotlib
-
-from 冻结参数训练 import verify_freeze_params
-import os
-
-from torchvision.utils import draw_bounding_boxes
-from torchvision import transforms
-from .postprocess import postprocess
-
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
-
-
-# matplotlib.use('Agg')  # 使用无窗口后端
-
-
-# 绘图
def show_line(img, pred, epoch, writer):
    """Log the input image, predicted boxes and predicted lines to ``writer``.

    Three images are written under the tags ``"ori"``, ``"boxes"`` and
    ``"output"``. Boxes come from ``pred["box"][0]["boxes"]``; lines and
    scores come from the first entry of ``pred["preds"]["lines"]`` and
    ``pred["preds"]["score"]``. Only lines scoring at least 0.7 are drawn.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    img_uint8 = (img * 255).to(torch.uint8)
    boxed_image = draw_bounding_boxes(img_uint8, pred["box"][0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred["preds"]
    # Rescale line coordinates from the 128-unit grid to the image size.
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()

    # The list is cut at the first later entry that repeats the first line.
    first_repeat = next(
        (k for k in range(1, len(lines)) if (lines[k] == lines[0]).all()), None
    )
    if first_repeat is not None:
        lines = lines[:first_repeat]
        scores = scores[:first_repeat]

    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    for min_score in (0.7,):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < min_score:
                continue
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()

        # Rasterise the figure into an RGB array so it can be logged as an image.
        fig = plt.gcf()
        fig.canvas.draw()
        rgb_bytes = fig.canvas.tostring_rgb()
        canvas_shape = fig.canvas.get_width_height()[::-1] + (3,)
        image_from_plot = np.frombuffer(rgb_bytes, dtype=np.uint8).reshape(canvas_shape)
        plt.close()
        img2 = transforms.ToTensor()(image_from_plot)

        writer.add_image("output", img2, epoch)
-
-class Trainer(object):
-    def __init__(self, device, model, optimizer, train_loader, val_loader, out):
-        self.device = device
-
-        self.model = model
-        self.optim = optimizer
-
-        self.train_loader = train_loader
-        self.val_loader = val_loader
-        self.batch_size = C.model.batch_size
-
-        self.validation_interval = C.io.validation_interval
-
-        self.out = out
-        if not osp.exists(self.out):
-            os.makedirs(self.out)
-
-        # self.run_tensorboard()
-        self.writer = SummaryWriter('logs/')
-        time.sleep(1)
-
-        self.epoch = 0
-        self.iteration = 0
-        self.max_epoch = C.optim.max_epoch
-        self.lr_decay_epoch = C.optim.lr_decay_epoch
-        self.num_stacks = C.model.num_stacks
-        self.mean_loss = self.best_mean_loss = 1e1000
-
-        self.loss_labels = None
-        self.avg_metrics = None
-        self.metrics = np.zeros(0)
-
-        self.show_line = show_line
-
-    # def run_tensorboard(self):
-    #     board_out = osp.join(self.out, "tensorboard")
-    #     if not osp.exists(board_out):
-    #         os.makedirs(board_out)
-    #     self.writer = SummaryWriter(board_out)
-    #     os.environ["CUDA_VISIBLE_DEVICES"] = ""
-    #     p = subprocess.Popen(
-    #         ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
-    #     )
-    #
-    #     def killme():
-    #         os.kill(p.pid, signal.SIGTERM)
-    #
-    #     atexit.register(killme)
-
-    def _loss(self, result):
-        losses = result["losses"]
-        # Don't move loss label to other place.
-        # If I want to change the loss, I just need to change this function.
-        if self.loss_labels is None:
-            self.loss_labels = ["sum"] + list(losses[0].keys())
-            self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
-            print()
-            print(
-                "| ".join(
-                    ["progress "]
-                    + list(map("{:7}".format, self.loss_labels))
-                    + ["speed"]
-                )
-            )
-            with open(f"{self.out}/loss.csv", "a") as fout:
-                print(",".join(["progress"] + self.loss_labels), file=fout)
-
-        total_loss = 0
-        for i in range(self.num_stacks):
-            for j, name in enumerate(self.loss_labels):
-                if name == "sum":
-                    continue
-                if name not in losses[i]:
-                    assert i != 0
-                    continue
-                loss = losses[i][name].mean()
-                self.metrics[i, 0] += loss.item()
-                self.metrics[i, j] += loss.item()
-                total_loss += loss
-        return total_loss
-
-
-    def validate(self):
-        tprint("Running validation...", " " * 75)
-        training = self.model.training
-        self.model.eval()
-
-        # viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
-        # npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
-        # osp.exists(viz) or os.makedirs(viz)
-        # osp.exists(npz) or os.makedirs(npz)
-
-        total_loss = 0
-        self.metrics[...] = 0
-        with torch.no_grad():
-            for batch_idx, (image, meta, target, target_b) in enumerate(self.val_loader):
-                input_dict = {
-                    "image": recursive_to(image, self.device),
-                    "meta": recursive_to(meta, self.device),
-                    "target": recursive_to(target, self.device),
-                    "target_b": recursive_to(target_b, self.device),
-                    "mode": "validation",
-                }
-                result = self.model(input_dict)
-                # print(f'image:{image.shape}')
-                # print(result["box"])
-
-                # total_loss += self._loss(result)
-
-                print(f'self.epoch:{self.epoch}')
-                # print(result.keys())
-                self.show_line(image[0], result, self.epoch, self.writer)
-
-
-
-                # H = result["preds"]
-                # for i in range(H["jmap"].shape[0]):
-                #     index = batch_idx * M.batch_size_eval + i
-                #     np.savez(
-                #         f"{npz}/{index:06}.npz",
-                #         **{k: v[i].cpu().numpy() for k, v in H.items()},
-                #     )
-                #     if index >= 20:
-                #         continue
-                #     self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
-
-        # self._write_metrics(len(self.val_loader), total_loss, "validation", True)
-        # self.mean_loss = total_loss / len(self.val_loader)
-
-        torch.save(
-            {
-                "iteration": self.iteration,
-                "arch": self.model.__class__.__name__,
-                "optim_state_dict": self.optim.state_dict(),
-                "model_state_dict": self.model.state_dict(),
-                "best_mean_loss": self.best_mean_loss,
-            },
-            osp.join(self.out, "checkpoint_latest.pth"),
-        )
-        # shutil.copy(
-        #     osp.join(self.out, "checkpoint_latest.pth"),
-        #     osp.join(npz, "checkpoint.pth"),
-        # )
-        if self.mean_loss < self.best_mean_loss:
-            self.best_mean_loss = self.mean_loss
-            shutil.copy(
-                osp.join(self.out, "checkpoint_latest.pth"),
-                osp.join(self.out, "checkpoint_best.pth"),
-            )
-
-        if training:
-            self.model.train()
-
-    def verify_freeze_params(model, freeze_config):
-        """
-        验证参数冻结是否生效
-        """
-        print("\n===== Verifying Parameter Freezing =====")
-
-        for name, module in model.named_children():
-            if name in freeze_config:
-                if freeze_config[name]:
-                    print(f"\nChecking module: {name}")
-                    for param_name, param in module.named_parameters():
-                        print(f"  {param_name}: requires_grad = {param.requires_grad}")
-
-            # 特别处理fc2子模块
-            if name == 'fc2' and 'fc2_submodules' in freeze_config:
-                for subname, submodule in module.named_children():
-                    if subname in freeze_config['fc2_submodules']:
-                        if freeze_config['fc2_submodules'][subname]:
-                            print(f"\nChecking fc2 submodule: {subname}")
-                            for param_name, param in submodule.named_parameters():
-                                print(f"  {param_name}: requires_grad = {param.requires_grad}")
-
-    def train_epoch(self):
-        self.model.train()
-
-        time = timer()
-        for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):
-            self.optim.zero_grad()
-            self.metrics[...] = 0
-
-            input_dict = {
-                "image": recursive_to(image, self.device),
-                "meta": recursive_to(meta, self.device),
-                "target": recursive_to(target, self.device),
-                "target_b": recursive_to(target_b, self.device),
-                "mode": "training",
-            }
-            result = self.model(input_dict)
-
-            loss = self._loss(result)
-            if np.isnan(loss.item()):
-                raise ValueError("loss is nan while training")
-            loss.backward()
-            self.optim.step()
-
-            if self.avg_metrics is None:
-                self.avg_metrics = self.metrics
-            else:
-                self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
-            self.iteration += 1
-            self._write_metrics(1, loss.item(), "training", do_print=False)
-
-            if self.iteration % 4 == 0:
-                tprint(
-                    f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
-                    + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
-                    + f"| {4 * self.batch_size / (timer() - time):04.1f} "
-                )
-                time = timer()
-            num_images = self.batch_size * self.iteration
-            # if num_images % self.validation_interval == 0 or num_images == 4:
-            #     self.validate()
-            #     time = timer()
-        self.validate()
-        # verify_freeze_params()
-
-    def _write_metrics(self, size, total_loss, prefix, do_print=False):
-        for i, metrics in enumerate(self.metrics):
-            for label, metric in zip(self.loss_labels, metrics):
-                self.writer.add_scalar(
-                    f"{prefix}/{i}/{label}", metric / size, self.iteration
-                )
-            if i == 0 and do_print:
-                csv_str = (
-                        f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
-                        + ",".join(map("{:.11f}".format, metrics / size))
-                )
-                prt_str = (
-                        f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
-                        + "| ".join(map("{:.5f}".format, metrics / size))
-                )
-                with open(f"{self.out}/loss.csv", "a") as fout:
-                    print(csv_str, file=fout)
-                pprint(prt_str, " " * 7)
-        self.writer.add_scalar(
-            f"{prefix}/total_loss", total_loss / size, self.iteration
-        )
-        return total_loss
-
-    def _plot_samples(self, i, index, result, meta, target, prefix):
-        fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
-        img = io.imread(fn)
-        imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
-
-        mask_result = result["jmap"][i].cpu().numpy()
-        mask_target = target["jmap"][i].cpu().numpy()
-        for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
-            imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
-            imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
-
-        line_result = result["lmap"][i].cpu().numpy()
-        line_target = target["lmap"][i].cpu().numpy()
-        imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
-        imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
-
-        def draw_vecl(lines, sline, juncs, junts, fn):
-            imshow(img)
-            if len(lines) > 0 and not (lines[0] == 0).all():
-                for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                    if i > 0 and (lines[i] == lines[0]).all():
-                        break
-                    plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
-            if not (juncs[0] == 0).all():
-                for i, j in enumerate(juncs):
-                    if i > 0 and (i == juncs[0]).all():
-                        break
-                    plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
-            if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
-                for i, j in enumerate(junts):
-                    if i > 0 and (i == junts[0]).all():
-                        break
-                    plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
-            plt.savefig(fn), plt.close()
-
-        junc = meta[i]["junc"].cpu().numpy() * 4
-        jtyp = meta[i]["jtyp"].cpu().numpy()
-        juncs = junc[jtyp == 0]
-        junts = junc[jtyp == 1]
-        rjuncs = result["juncs"][i].cpu().numpy() * 4
-        rjunts = None
-        if "junts" in result:
-            rjunts = result["junts"][i].cpu().numpy() * 4
-
-        lpre = meta[i]["lpre"].cpu().numpy() * 4
-        vecl_target = meta[i]["lpre_label"].cpu().numpy()
-        vecl_result = result["lines"][i].cpu().numpy() * 4
-        score = result["score"][i].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
-
-        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
-        draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
-
-    def train(self):
-        plt.rcParams["figure.figsize"] = (24, 24)
-        # if self.iteration == 0:
-        #     self.validate()
-        epoch_size = len(self.train_loader)
-        start_epoch = self.iteration // epoch_size
-
-        for self.epoch in range(start_epoch, self.max_epoch):
-            print(f"Epoch {self.epoch}/{C.optim.max_epoch} - Iteration {self.iteration}/{epoch_size}")
-            if self.epoch == self.lr_decay_epoch:
-                self.optim.param_groups[0]["lr"] /= 10
-            self.train_epoch()
-
-
-cmap = plt.get_cmap("jet")
-norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
-sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
-sm.set_array([])
-
-
-def c(x):
-    return sm.to_rgba(x)
-
-
-def imshow(im):
-    plt.close()
-    plt.tight_layout()
-    plt.imshow(im)
-    plt.colorbar(sm, fraction=0.046)
-    plt.xlim([0, im.shape[0]])
-    plt.ylim([im.shape[0], 0])
-
-
-def tprint(*args):
-    """Temporarily prints things on the screen"""
-    print("\r", end="")
-    print(*args, end="")
-
-
-def pprint(*args):
-    """Permanently prints things on the screen"""
-    print("\r", end="")
-    print(*args)
-
-
-def _launch_tensorboard(board_out, port, out):
-    os.environ["CUDA_VISIBLE_DEVICES"] = ""
-    p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
-
-    def kill():
-        os.kill(p.pid, signal.SIGTERM)
-
-    atexit.register(kill)

+ 0 - 101
lcnn/utils.py

@@ -1,101 +0,0 @@
-import math
-import os.path as osp
-import multiprocessing
-from timeit import default_timer as timer
-
-import numpy as np
-import torch
-import matplotlib.pyplot as plt
-
-
-class benchmark(object):
-    def __init__(self, msg, enable=True, fmt="%0.3g"):
-        self.msg = msg
-        self.fmt = fmt
-        self.enable = enable
-
-    def __enter__(self):
-        if self.enable:
-            self.start = timer()
-        return self
-
-    def __exit__(self, *args):
-        if self.enable:
-            t = timer() - self.start
-            print(("%s : " + self.fmt + " seconds") % (self.msg, t))
-            self.time = t
-
-
-def quiver(x, y, ax):
-    ax.set_xlim(0, x.shape[1])
-    ax.set_ylim(x.shape[0], 0)
-    ax.quiver(
-        x,
-        y,
-        units="xy",
-        angles="xy",
-        scale_units="xy",
-        scale=1,
-        minlength=0.01,
-        width=0.1,
-        color="b",
-    )
-
-
-def recursive_to(input, device):
-    if isinstance(input, torch.Tensor):
-        return input.to(device)
-    if isinstance(input, dict):
-        for name in input:
-            if isinstance(input[name], torch.Tensor):
-                input[name] = input[name].to(device)
-        return input
-    if isinstance(input, list):
-        for i, item in enumerate(input):
-            input[i] = recursive_to(item, device)
-        return input
-    assert False
-
-
-def np_softmax(x, axis=0):
-    """Compute softmax values for each sets of scores in x."""
-    e_x = np.exp(x - np.max(x))
-    return e_x / e_x.sum(axis=axis, keepdims=True)
-
-
-def argsort2d(arr):
-    return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
-
-
-def __parallel_handle(f, q_in, q_out):
-    while True:
-        i, x = q_in.get()
-        if i is None:
-            break
-        q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
-    if nprocs == 0:
-        nprocs = multiprocessing.cpu_count()
-    q_in = multiprocessing.Queue(1)
-    q_out = multiprocessing.Queue()
-
-    proc = [
-        multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
-        for _ in range(nprocs)
-    ]
-    for p in proc:
-        p.daemon = True
-        p.start()
-
-    try:
-        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
-        [q_in.put((None, None)) for _ in range(nprocs)]
-        res = [q_out.get() for _ in progress_bar(range(len(sent)))]
-        [p.join() for p in proc]
-    except KeyboardInterrupt:
-        q_in.close()
-        q_out.close()
-        raise
-    return [x for i, x in sorted(res)]

+ 0 - 0
models/ins/__init__.py


+ 0 - 0
lcnn/models/base/__init__.py → models/ins_detect/__init__.py


+ 1 - 1
models/ins/maskrcnn.py → models/ins_detect/maskrcnn.py

@@ -15,7 +15,7 @@ from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
 from torchvision.utils import draw_bounding_boxes
 
 from models.config.config_tool import read_yaml
-from models.ins.trainer import train_cfg
+from models.ins_detect.trainer import train_cfg
 from tools import utils
 
 

+ 0 - 0
models/ins/maskrcnn_dataset.py → models/ins_detect/maskrcnn_dataset.py


+ 0 - 0
models/ins/train.yaml → models/ins_detect/train.yaml


+ 1 - 1
models/ins/trainer.py → models/ins_detect/trainer.py

@@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
 from models.config.config_tool import read_yaml
-from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
 from tools import utils, presets
 
 

+ 1 - 1
models/keypoint/trainer.py

@@ -19,7 +19,7 @@ from tools.coco_eval import CocoEvaluator
 import time
 
 from models.config.config_tool import read_yaml
-from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
 

+ 0 - 120
models/line_detect/fasterrcnn_resnet50.py

@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
-
-def get_model(num_classes):
-    # 加载预训练的ResNet-50 FPN backbone
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-
-    # 获取分类器的输入特征数
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-
-    # 替换分类器以适应新的类别数量
-    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-
-    return model
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
-
-class Fasterrcnn_resnet50(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1):
-        super(Fasterrcnn_resnet50, self).__init__()
-
-        self.model = get_model(num_classes=5)
-        self.backbone = self.model.backbone
-
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        # 多任务输出层
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
-
-        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
-                                             image_std=[0.229, 0.224, 0.225])
-        images, targets = transform(x, target1)
-        x_ = self.backbone(images.tensors)
-
-        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
-        # print(f'backbone:{self.backbone}')
-        # print(f'Fasterrcnn_resnet50 x_:{x_}')
-        feature_ = x_['0']  # 图片特征
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(feature_)
-            outputs.append(output)  # 多头
-
-        if train_or_val == "training":
-            loss_box = self.model(x, target1)
-            return outputs, feature_, loss_box
-        else:
-            box_all = self.model(x, target1)
-            return outputs, feature_, box_all
-
-
-def fasterrcnn_resnet50(**kwargs):
-    model = Fasterrcnn_resnet50(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1)
-    )
-    return model

+ 1 - 1
models/wirenet/wirepoint_rcnn.py

@@ -22,7 +22,7 @@ from torchvision.ops import misc as misc_nn_ops
 
 from models.config import config_tool
 from models.config.config_tool import read_yaml
-from models.ins.trainer import get_transform
+from models.ins_detect.trainer import get_transform
 from models.wirenet.head import RoIHeads
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from tools import utils

+ 1 - 1
models/wirenet2/trainer.py

@@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
 from models.config.config_tool import read_yaml
-from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):