diff --git a/README.md b/README.md index cfa2f13..aeb7b1b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ A GAN toolbox for researchers and developers with: - Progressive Growing of GAN(PGAN): https://arxiv.org/pdf/1710.10196.pdf - DCGAN: https://arxiv.org/pdf/1511.06434.pdf -- To come: StyleGAN https://arxiv.org/abs/1812.04948 illustration Picture: Generated samples from GANs trained on celebaHQ, fashionGen, DTD. @@ -48,6 +47,21 @@ pip install -r requirements.txt - DTD: https://www.robots.ox.ac.uk/~vgg/data/dtd/ - CIFAR10: http://www.cs.toronto.edu/~kriz/cifar.html +For a quick start with CelebAHQ, you might: +``` +git clone https://github.com/nperraud/download-celebA-HQ.git +cd download-celebA-HQ +conda create -n celebaHQ python=3 +source activate celebaHQ +conda install jpeg=8d tqdm requests pillow==3.1.1 urllib3 numpy cryptography scipy +pip install opencv-python==3.4.0.12 cryptography==2.1.4 +sudo apt-get install p7zip-full +python download_celebA.py ./ +python download_celebA_HQ.py ./ +python make_HQ_images.py ./ +export PATH_TO_CELEBAHQ=`readlink -f ./celebA-HQ/512` +``` + ## Quick training The datasets.py script allows you to prepare your datasets and build their corresponding configuration files. @@ -64,8 +78,9 @@ And wait for a few days. Your checkpoints will be dumped in output_networks/cele For celebaHQ: ``` -python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET - f -python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ +python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET # Prepare the dataset and build the configuration file. +python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ # Train. +python eval.py inception -n celebaHQ -m PGAN # If you want to check the inception score. ``` Your checkpoints will be dumped in output_networks/celebaHQ. You should get 1024x1024 generations at the end. @@ -130,7 +145,7 @@ Where: 1 - MODEL_NAME is the name of the model you want to run. Currently, two models are available: - PGAN(progressive growing of gan) - - PPGAN(decoupled version of PGAN) + - DCGAN 2 - CONFIGURATION_FILE(mandatory): path to a training configuration file. This file is a json file containing at least a pathDB entry with the path to the training dataset. See below for more informations about this file. @@ -209,19 +224,19 @@ You need to use the eval.py script. You can generate more images from an existing checkpoint using: ``` -python eval.py visualization -n $modelName -m $modelType +python eval.py visualization -n $runName -m $modelName ``` -Where modelType is in [PGAN, PPGAN, DCGAN] and modelName is the name given to your model. This script will load the last checkpoint detected at testNets/$modelName. If you want to load a specific iteration, please call: +Where modelName is in [PGAN, DCGAN] and runName is the name given to your run (trained model). This script will load the last checkpoint detected at output_networks/$modelName. If you want to load a specific iteration, please call: ``` -python eval.py visualization -n $modelName -m $modelType -s $SCALE -i $ITER +python eval.py visualization -n $runName -m $modelName -s $SCALE -i $ITER ``` If your model is conditioned, you can ask the visualizer to print out some conditioned generations. For example: ``` -python eval.py visualization -n $modelName -m $modelType --Class T_SHIRT +python eval.py visualization -n $runName -m $modelName --Class T_SHIRT ``` Will plot a batch of T_SHIRTS in visdom. Please use the option - -showLabels to see all the available labels for your model. @@ -231,16 +246,21 @@ Will plot a batch of T_SHIRTS in visdom. Please use the option - -showLabels to To save a randomly generated fake dataset from a checkpoint please use: ``` -python eval.py visualization -n $modelName -m $modelType --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT +python eval.py visualization -n $runName -m $modelName --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT ``` ### SWD metric Using the same kind of configuration file as above, just launch: - ``` -python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $modelName -m $modelType +python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $runName -m $modelName ``` +for the SWD score, to be maximized, or for the inception score: +``` +python eval.py inception -c $CONFIGURATION_FILE -n $runName -m $modelName +``` +also to be maximized (see https://hal.inria.fr/hal-01850447/document for a discussion). + Where $CONFIGURATION_FILE is the training configuration file called by train.py (see above): it must contains a "pathDB" field pointing to path to the dataset's directory. For example, if you followed the instruction of the Quick Training section to launch a training session on celebaHQ your configuration file will be config_celebaHQ.json. @@ -250,27 +270,53 @@ You can add optional arguments: - -i $ITER: specify the iteration to evaluate(if not set, will take the highest one) - --selfNoise: returns the typical noise of the SWD distance for each resolution -### Inspirational generation +### Inspirational generation (https://arxiv.org/abs/1906.11661) + +You might want to generate clothese (or faces, or whatever) using an inspirational image, e.g.: +celeba +An inspirational generation consists in generating with your GAN an image which looks like a given input image. +This is based on optimizing the latent vector z such that similarity(GAN(z), TargetImage) is maximum. To make an inspirational generation, you first need to build a feature extractor: ``` python save_feature_extractor.py {vgg16, vgg19} $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --layers 3 4 5 ``` +This feature extractor is then used for computing the similarity. Then run your model: ``` -python eval.py inspirational_generation -n $modelName -m $modelType --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR +python eval.py inspirational_generation -n $runName -m $modelName --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR ``` +You can compare choose for the optimization one of the optimizers in Nevergrad +(https://github.com/facebookresearch/nevergrad/). For example you can run: +``` +python eval.py inspirational_generation -n $runName -m $modelName --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --nevergrad CMA +``` +if you want to use CMA-ES; or another optimizer in 'CMA', 'DE', 'PSO', 'TwoPointsDE', 'PortfolioDiscreteOnePlusOne', 'DiscreteOnePlusOne', 'OnePlusOne'. If you do not specify --nevergrad, then Adam is used. + ### I have generated my metrics. How can i plot them on visdom ? Just run ``` -python eval.py metric_plot -n $modelName +python eval.py metric_plot -n $runName ``` ## LICENSE This project is under BSD-3 license. + +## Citing + +```bibtex +@misc{pytorchganzoo, + author = {M. Riviere}, + title = {{Pytorch GAN Zoo}}, + year = {2019}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://GitHub.com/FacebookResearch/pytorch_GAN_zoo}}, +} +``` diff --git a/datasets.py b/datasets.py index d71a0b7..e5cbaa0 100644 --- a/datasets.py +++ b/datasets.py @@ -281,7 +281,8 @@ def resizeDataset(inputPath, outputPath, maxSize): maxSize = 1024 moveLastScale = False keepOriginalDataset = True - config["miniBatchScheduler"] = {"7": 12, "8": 8} + if args.model_type == 'PGAN': + config["miniBatchScheduler"] = {"7": 12, "8": 8} if args.model_type == 'DCGAN': print("WARNING: DCGAN is diverging for celebaHQ") diff --git a/inspir.png b/inspir.png new file mode 100644 index 0000000..69f77da Binary files /dev/null and b/inspir.png differ diff --git a/models/eval/inspirational_generation.py b/models/eval/inspirational_generation.py index 78ab10b..8b1ce72 100644 --- a/models/eval/inspirational_generation.py +++ b/models/eval/inspirational_generation.py @@ -71,7 +71,7 @@ def updateParser(parser): parser.add_argument('--weights', type=float, dest='weights', nargs='*', help="Weight of each classifier. Default \ value is one. If specified, the number of weights must\ - match the number of feature exatrcators.") + match the number of feature extractors.") parser.add_argument('--gradient_descent', help='gradient descent', action='store_true') parser.add_argument('--random_search', help='Random search', diff --git a/models/trainer/DCGAN_trainer.py b/models/trainer/DCGAN_trainer.py index 1b99141..55d73d2 100644 --- a/models/trainer/DCGAN_trainer.py +++ b/models/trainer/DCGAN_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import os +import time from ..DCGAN import DCGAN from .gan_trainer import GANTrainer @@ -18,6 +19,7 @@ def getDefaultConfig(self): def __init__(self, pathdb, + miniBatchScheduler=None, **kwargs): r""" Args: @@ -46,10 +48,12 @@ def train(self): self.saveBaseConfig(pathBaseConfig) maxShift = int(self.modelConfig.nEpoch * len(self.getDBLoader(0))) - + start = time.time() for epoch in range(self.modelConfig.nEpoch): dbLoader = self.getDBLoader(0) self.trainOnEpoch(dbLoader, 0, shiftIter=shift) + if self.max_time > 0 and time.time() - start > self.max_time: + break shift += len(dbLoader) diff --git a/models/trainer/gan_trainer.py b/models/trainer/gan_trainer.py index 01bcf6a..071aa4f 100644 --- a/models/trainer/gan_trainer.py +++ b/models/trainer/gan_trainer.py @@ -27,6 +27,7 @@ def __init__(self, checkPointDir=None, modelLabel="GAN", config=None, + max_time=0, pathAttribDict=None, selectedAttributes=None, imagefolderDataset=False, @@ -50,6 +51,7 @@ def __init__(self, - modelLabel (string): name of the model - config (dictionary): configuration dictionnary. for all the possible options + - max_time (int): max number of seconds for training (0 = infinity). - pathAttribDict (string): path to the attribute dictionary giving the labels of the dataset - selectedAttributes (list): if not None, consider only the listed @@ -70,6 +72,7 @@ def __init__(self, self.path_db = pathdb self.pathPartition = pathPartition self.partitionValue = partitionValue + self.max_time = max_time if config is None: config = {} diff --git a/models/trainer/progressive_gan_trainer.py b/models/trainer/progressive_gan_trainer.py index ae34639..4b1e541 100644 --- a/models/trainer/progressive_gan_trainer.py +++ b/models/trainer/progressive_gan_trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import os +import time from .standard_configurations.pgan_config import _C from ..progressive_gan import ProgressiveGAN @@ -46,7 +47,6 @@ def __init__(self, - stopOnShitStorm (bool): should we stop the training if a diverging behavior is detected ? """ - self.configScheduler = {} if configScheduler is not None: self.configScheduler = { @@ -208,6 +208,7 @@ def train(self): + "_train_config.json") self.saveBaseConfig(pathBaseConfig) + start = time.time() for scale in range(self.startScale, n_scales): self.updateDatasetForScale(scale) @@ -230,7 +231,8 @@ def train(self): shiftAlpha += 1 while shiftIter < self.modelConfig.maxIterAtScale[scale]: - + if self.max_time > 0 and time.time() - start > self.max_time: + break self.indexJumpAlpha = shiftAlpha status = self.trainOnEpoch(dbLoader, scale, shiftIter=shiftIter, diff --git a/train.py b/train.py index 1587608..7d97dab 100644 --- a/train.py +++ b/train.py @@ -14,7 +14,6 @@ def getTrainer(name): match = {"PGAN": ("progressive_gan_trainer", "ProgressiveGANTrainer"), - "StyleGAN":("styleGAN_trainer", "StyleGANTrainer"), "DCGAN": ("DCGAN_trainer", "DCGANTrainer")} if name not in match: @@ -30,13 +29,15 @@ def getTrainer(name): parser = argparse.ArgumentParser(description='Testing script') parser.add_argument('model_name', type=str, help='Name of the model to launch, available models are\ - PGAN and PPGAN. To get all possible option for a model\ + PGAN and DCGAN. To get all possible option for a model\ please run train.py $MODEL_NAME -overrides') parser.add_argument('--no_vis', help=' Disable all visualizations', action='store_true') parser.add_argument('--np_vis', help=' Replace visdom by a numpy based \ visualizer (SLURM)', action='store_true') + parser.add_argument('--max_time', help=' Maximum time in seconds (0 for infinity)', type=int, + dest='max_time', default=0) parser.add_argument('--restart', help=' If a checkpoint is detected, do \ not try to load it', action='store_true') @@ -124,6 +125,7 @@ def getTrainer(name): lossIterEvaluation=kwargs["evalIter"], checkPointDir=checkPointDir, saveIter= kwargs["saveIter"], + max_time=kwargs["max_time"], modelLabel=modelLabel, partitionValue=partitionValue, **trainingConfig)