Quellcode durchsuchen

Remove .tar suffices for checkpoints

Yichao Zhou vor 6 Jahren
Ursprung
Commit
49f1b45fdb
3 geänderte Dateien mit 14 neuen und 14 gelöschten Zeilen
  1. 8 8
      README.md
  2. 5 5
      lcnn/trainer.py
  3. 1 1
      train.py

+ 8 - 8
README.md

@@ -22,10 +22,10 @@ The following table reports the performance metrics of several wireframe and lin
 
 |                                                      | ShanghaiTech (sAP<sup>10</sup>) | ShanghaiTech (AP<sup>H</sup>) | ShanghaiTech (F<sup>H</sup>) | ShanghaiTech (mAP<sup>J</sup>) | 
 | :--------------------------------------------------: | :--------------------------------: | :-----------------------------: | :----------------------------: | :------------------------------: | 
-| [LSD](https://ieeexplore.ieee.org/document/4731268/) |                 /                  |              52.0             |              61.0                |                /                 |              
-|  [AFM](https://github.com/cherubicXN/afm_cvpr2019)   |                24.4                |              69.5               |              77.2              |               23.3               |           
-| [Wireframe](https://github.com/huangkuns/wireframe)  |                5.1                 |              67.8               |              72.6              |               40.9               |              
-|                      **L-CNN**                       |              **62.9**              |            **82.8**             |            **81.2**            |             **59.3**             |               
+| [LSD](https://ieeexplore.ieee.org/document/4731268/) |                 /                  |              52.0             |              61.0                |                /                 |
+|  [AFM](https://github.com/cherubicXN/afm_cvpr2019)   |                24.4                |              69.5               |              77.2              |               23.3               |
+| [Wireframe](https://github.com/huangkuns/wireframe)  |                5.1                 |              67.8               |              72.6              |               40.9               |
+|                      **L-CNN**                       |              **62.9**              |            **82.8**             |            **81.2**            |             **59.3**             |
 
 ### Precision-Recall Curves
 <p align="center">
@@ -82,8 +82,8 @@ git clone https://github.com/zhou13/lcnn
 cd lcnn
 conda create -y -n lcnn
 source activate lcnn
-# Replace cudatoolkit=10.0 with your CUDA version: https://pytorch.org/
-conda install -y pytorch cudatoolkit=10.0 -c pytorch
+# Replace cudatoolkit=10.1 with your CUDA version: https://pytorch.org/
+conda install -y pytorch cudatoolkit=10.1 -c pytorch
 conda install -y tensorboardx -c conda-forge
 conda install -y pyyaml docopt matplotlib scikit-image opencv
 mkdir data logs post
@@ -94,7 +94,7 @@ mkdir data logs post
 You can download our reference pre-trained models from [Google
 Drive](https://drive.google.com/file/d/1NvZkEqWNUBAfuhFPNGiCItjy4iU0UOy2).  Those models were
 trained with `config/wireframe.yaml` for 312k iterations.  Use `demo.py`, `process.py`, and
-`eval-*.py` to evaluate the pre-trained models. **Do not try to unzip them!**
+`eval-*.py` to evaluate the pre-trained models.
 
 ### Detect Wireframes for Your Own Images
 To test LCNN on your own images, you need download the pre-trained models and execute
@@ -144,7 +144,7 @@ python ./train.py -d 0 --identifier baseline config/wireframe.yaml
 To generate wireframes on the validation dataset with the pretrained model, execute
 
 ```bash
-./process.py config/wireframe.yaml <path-to-checkpoint.pth.tar> data/wireframe logs/pretrained-model/npz/000312000
+./process.py config/wireframe.yaml <path-to-checkpoint.pth> data/wireframe logs/pretrained-model/npz/000312000
 ```
 
 ### Post Processing

+ 5 - 5
lcnn/trainer.py

@@ -144,17 +144,17 @@ class Trainer(object):
                 "model_state_dict": self.model.state_dict(),
                 "best_mean_loss": self.best_mean_loss,
             },
-            osp.join(self.out, "checkpoint_latest.pth.tar"),
+            osp.join(self.out, "checkpoint_latest.pth"),
         )
         shutil.copy(
-            osp.join(self.out, "checkpoint_latest.pth.tar"),
-            osp.join(npz, "checkpoint.pth.tar"),
+            osp.join(self.out, "checkpoint_latest.pth"),
+            osp.join(npz, "checkpoint.pth"),
         )
         if self.mean_loss < self.best_mean_loss:
             self.best_mean_loss = self.mean_loss
             shutil.copy(
-                osp.join(self.out, "checkpoint_latest.pth.tar"),
-                osp.join(self.out, "checkpoint_best.pth.tar"),
+                osp.join(self.out, "checkpoint_latest.pth"),
+                osp.join(self.out, "checkpoint_best.pth"),
             )
 
         if training:

+ 1 - 1
train.py

@@ -114,7 +114,7 @@ def main():
     # print("epoch_size (valid):", len(val_loader))
 
     if resume_from:
-        checkpoint = torch.load(osp.join(resume_from, "checkpoint_latest.pth.tar"))
+        checkpoint = torch.load(osp.join(resume_from, "checkpoint_latest.pth"))
 
     # 2. model
     if M.backbone == "stacked_hourglass":