diff --git a/README.md b/README.md
index cfa2f13..aeb7b1b 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,6 @@
A GAN toolbox for researchers and developers with:
- Progressive Growing of GAN(PGAN): https://arxiv.org/pdf/1710.10196.pdf
- DCGAN: https://arxiv.org/pdf/1511.06434.pdf
-- To come: StyleGAN https://arxiv.org/abs/1812.04948
Picture: Generated samples from GANs trained on celebaHQ, fashionGen, DTD.
@@ -48,6 +47,21 @@ pip install -r requirements.txt
- DTD: https://www.robots.ox.ac.uk/~vgg/data/dtd/
- CIFAR10: http://www.cs.toronto.edu/~kriz/cifar.html
+For a quick start with CelebAHQ, you might:
+```
+git clone https://github.com/nperraud/download-celebA-HQ.git
+cd download-celebA-HQ
+conda create -n celebaHQ python=3
+source activate celebaHQ
+conda install jpeg=8d tqdm requests pillow==3.1.1 urllib3 numpy cryptography scipy
+pip install opencv-python==3.4.0.12 cryptography==2.1.4
+sudo apt-get install p7zip-full
+python download_celebA.py ./
+python download_celebA_HQ.py ./
+python make_HQ_images.py ./
+export PATH_TO_CELEBAHQ=`readlink -f ./celebA-HQ/512`
+```
+
## Quick training
The datasets.py script allows you to prepare your datasets and build their corresponding configuration files.
@@ -64,8 +78,9 @@ And wait for a few days. Your checkpoints will be dumped in output_networks/cele
For celebaHQ:
```
-python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET - f
-python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ
+python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET # Prepare the dataset and build the configuration file.
+python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ # Train.
+python eval.py inception -n celebaHQ -m PGAN # If you want to check the inception score.
```
Your checkpoints will be dumped in output_networks/celebaHQ. You should get 1024x1024 generations at the end.
@@ -130,7 +145,7 @@ Where:
1 - MODEL_NAME is the name of the model you want to run. Currently, two models are available:
- PGAN(progressive growing of gan)
- - PPGAN(decoupled version of PGAN)
+ - DCGAN
2 - CONFIGURATION_FILE(mandatory): path to a training configuration file. This file is a json file containing at least a pathDB entry with the path to the training dataset. See below for more informations about this file.
@@ -209,19 +224,19 @@ You need to use the eval.py script.
You can generate more images from an existing checkpoint using:
```
-python eval.py visualization -n $modelName -m $modelType
+python eval.py visualization -n $runName -m $modelName
```
-Where modelType is in [PGAN, PPGAN, DCGAN] and modelName is the name given to your model. This script will load the last checkpoint detected at testNets/$modelName. If you want to load a specific iteration, please call:
+Where modelName is in [PGAN, DCGAN] and runName is the name given to your run (trained model). This script will load the last checkpoint detected at output_networks/$runName. If you want to load a specific iteration, please call:
```
-python eval.py visualization -n $modelName -m $modelType -s $SCALE -i $ITER
+python eval.py visualization -n $runName -m $modelName -s $SCALE -i $ITER
```
If your model is conditioned, you can ask the visualizer to print out some conditioned generations. For example:
```
-python eval.py visualization -n $modelName -m $modelType --Class T_SHIRT
+python eval.py visualization -n $runName -m $modelName --Class T_SHIRT
```
Will plot a batch of T_SHIRTS in visdom. Please use the option - -showLabels to see all the available labels for your model.
@@ -231,16 +246,21 @@ Will plot a batch of T_SHIRTS in visdom. Please use the option - -showLabels to
To save a randomly generated fake dataset from a checkpoint please use:
```
-python eval.py visualization -n $modelName -m $modelType --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT
+python eval.py visualization -n $runName -m $modelName --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT
```
### SWD metric
Using the same kind of configuration file as above, just launch:
-
```
-python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $modelName -m $modelType
+python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $runName -m $modelName
```
+for the SWD score (a distance between generated and real images, to be minimized), or for the inception score:
+```
+python eval.py inception -c $CONFIGURATION_FILE -n $runName -m $modelName
+```
+which is to be maximized (see https://hal.inria.fr/hal-01850447/document for a discussion).
+
Where $CONFIGURATION_FILE is the training configuration file called by train.py (see above): it must contains a "pathDB" field pointing to path to the dataset's directory. For example, if you followed the instruction of the Quick Training section to launch a training session on celebaHQ your configuration file will be config_celebaHQ.json.
@@ -250,27 +270,53 @@ You can add optional arguments:
- -i $ITER: specify the iteration to evaluate(if not set, will take the highest one)
- --selfNoise: returns the typical noise of the SWD distance for each resolution
-### Inspirational generation
+### Inspirational generation (https://arxiv.org/abs/1906.11661)
+
+You might want to generate clothes (or faces, or whatever) using an inspirational image, e.g.:
+
+An inspirational generation consists in generating with your GAN an image which looks like a given input image.
+This is based on optimizing the latent vector z such that similarity(GAN(z), TargetImage) is maximum.
To make an inspirational generation, you first need to build a feature extractor:
```
python save_feature_extractor.py {vgg16, vgg19} $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --layers 3 4 5
```
+This feature extractor is then used for computing the similarity.
Then run your model:
```
-python eval.py inspirational_generation -n $modelName -m $modelType --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR
+python eval.py inspirational_generation -n $runName -m $modelName --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR
```
+For the optimization, you can choose one of the optimizers available in Nevergrad
+(https://github.com/facebookresearch/nevergrad/). For example you can run:
+```
+python eval.py inspirational_generation -n $runName -m $modelName --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --nevergrad CMA
+```
+if you want to use CMA-ES; or another optimizer in 'CMA', 'DE', 'PSO', 'TwoPointsDE', 'PortfolioDiscreteOnePlusOne', 'DiscreteOnePlusOne', 'OnePlusOne'. If you do not specify --nevergrad, then Adam is used.
+
### I have generated my metrics. How can i plot them on visdom ?
Just run
```
-python eval.py metric_plot -n $modelName
+python eval.py metric_plot -n $runName
```
## LICENSE
This project is under BSD-3 license.
+
+## Citing
+
+```bibtex
+@misc{pytorchganzoo,
+ author = {M. Riviere},
+ title = {{Pytorch GAN Zoo}},
+ year = {2019},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://GitHub.com/FacebookResearch/pytorch_GAN_zoo}},
+}
+```
diff --git a/datasets.py b/datasets.py
index d71a0b7..e5cbaa0 100644
--- a/datasets.py
+++ b/datasets.py
@@ -281,7 +281,8 @@ def resizeDataset(inputPath, outputPath, maxSize):
maxSize = 1024
moveLastScale = False
keepOriginalDataset = True
- config["miniBatchScheduler"] = {"7": 12, "8": 8}
+ if args.model_type == 'PGAN':
+ config["miniBatchScheduler"] = {"7": 12, "8": 8}
if args.model_type == 'DCGAN':
print("WARNING: DCGAN is diverging for celebaHQ")
diff --git a/inspir.png b/inspir.png
new file mode 100644
index 0000000..69f77da
Binary files /dev/null and b/inspir.png differ
diff --git a/models/eval/inspirational_generation.py b/models/eval/inspirational_generation.py
index 78ab10b..8b1ce72 100644
--- a/models/eval/inspirational_generation.py
+++ b/models/eval/inspirational_generation.py
@@ -71,7 +71,7 @@ def updateParser(parser):
parser.add_argument('--weights', type=float, dest='weights',
nargs='*', help="Weight of each classifier. Default \
value is one. If specified, the number of weights must\
- match the number of feature exatrcators.")
+ match the number of feature extractors.")
parser.add_argument('--gradient_descent', help='gradient descent',
action='store_true')
parser.add_argument('--random_search', help='Random search',
diff --git a/models/trainer/DCGAN_trainer.py b/models/trainer/DCGAN_trainer.py
index 1b99141..55d73d2 100644
--- a/models/trainer/DCGAN_trainer.py
+++ b/models/trainer/DCGAN_trainer.py
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
+import time
from ..DCGAN import DCGAN
from .gan_trainer import GANTrainer
@@ -18,6 +19,7 @@ def getDefaultConfig(self):
def __init__(self,
pathdb,
+ miniBatchScheduler=None,
**kwargs):
r"""
Args:
@@ -46,10 +48,12 @@ def train(self):
self.saveBaseConfig(pathBaseConfig)
maxShift = int(self.modelConfig.nEpoch * len(self.getDBLoader(0)))
-
+ start = time.time()
for epoch in range(self.modelConfig.nEpoch):
dbLoader = self.getDBLoader(0)
self.trainOnEpoch(dbLoader, 0, shiftIter=shift)
+ if self.max_time > 0 and time.time() - start > self.max_time:
+ break
shift += len(dbLoader)
diff --git a/models/trainer/gan_trainer.py b/models/trainer/gan_trainer.py
index 01bcf6a..071aa4f 100644
--- a/models/trainer/gan_trainer.py
+++ b/models/trainer/gan_trainer.py
@@ -27,6 +27,7 @@ def __init__(self,
checkPointDir=None,
modelLabel="GAN",
config=None,
+ max_time=0,
pathAttribDict=None,
selectedAttributes=None,
imagefolderDataset=False,
@@ -50,6 +51,7 @@ def __init__(self,
- modelLabel (string): name of the model
- config (dictionary): configuration dictionnary.
for all the possible options
+ - max_time (int): max number of seconds for training (0 = infinity).
- pathAttribDict (string): path to the attribute dictionary giving
the labels of the dataset
- selectedAttributes (list): if not None, consider only the listed
@@ -70,6 +72,7 @@ def __init__(self,
self.path_db = pathdb
self.pathPartition = pathPartition
self.partitionValue = partitionValue
+ self.max_time = max_time
if config is None:
config = {}
diff --git a/models/trainer/progressive_gan_trainer.py b/models/trainer/progressive_gan_trainer.py
index ae34639..4b1e541 100644
--- a/models/trainer/progressive_gan_trainer.py
+++ b/models/trainer/progressive_gan_trainer.py
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
+import time
from .standard_configurations.pgan_config import _C
from ..progressive_gan import ProgressiveGAN
@@ -46,7 +47,6 @@ def __init__(self,
- stopOnShitStorm (bool): should we stop the training if a diverging
behavior is detected ?
"""
-
self.configScheduler = {}
if configScheduler is not None:
self.configScheduler = {
@@ -208,6 +208,7 @@ def train(self):
+ "_train_config.json")
self.saveBaseConfig(pathBaseConfig)
+ start = time.time()
for scale in range(self.startScale, n_scales):
self.updateDatasetForScale(scale)
@@ -230,7 +231,8 @@ def train(self):
shiftAlpha += 1
while shiftIter < self.modelConfig.maxIterAtScale[scale]:
-
+ if self.max_time > 0 and time.time() - start > self.max_time:
+ break
self.indexJumpAlpha = shiftAlpha
status = self.trainOnEpoch(dbLoader, scale,
shiftIter=shiftIter,
diff --git a/train.py b/train.py
index 1587608..7d97dab 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,6 @@
def getTrainer(name):
match = {"PGAN": ("progressive_gan_trainer", "ProgressiveGANTrainer"),
- "StyleGAN":("styleGAN_trainer", "StyleGANTrainer"),
"DCGAN": ("DCGAN_trainer", "DCGANTrainer")}
if name not in match:
@@ -30,13 +29,15 @@ def getTrainer(name):
parser = argparse.ArgumentParser(description='Testing script')
parser.add_argument('model_name', type=str,
help='Name of the model to launch, available models are\
- PGAN and PPGAN. To get all possible option for a model\
+ PGAN and DCGAN. To get all possible option for a model\
please run train.py $MODEL_NAME -overrides')
parser.add_argument('--no_vis', help=' Disable all visualizations',
action='store_true')
parser.add_argument('--np_vis', help=' Replace visdom by a numpy based \
visualizer (SLURM)',
action='store_true')
+ parser.add_argument('--max_time', help=' Maximum time in seconds (0 for infinity)', type=int,
+ dest='max_time', default=0)
parser.add_argument('--restart', help=' If a checkpoint is detected, do \
not try to load it',
action='store_true')
@@ -124,6 +125,7 @@ def getTrainer(name):
lossIterEvaluation=kwargs["evalIter"],
checkPointDir=checkPointDir,
saveIter= kwargs["saveIter"],
+ max_time=kwargs["max_time"],
modelLabel=modelLabel,
partitionValue=partitionValue,
**trainingConfig)