This is the official PyTorch implementation of the paper Learning Words by Drawing Images by:
Dídac Surís(*), Adrià Recasens(*), David Bau, David Harwath, James Glass, Antonio Torralba
(*) Equal contribution
In this paper, we propose a framework for learning through drawing. Our goal is to learn the correspondence between spoken words and abstract visual attributes from a dataset of spoken descriptions of images. We exploit the learned representations of GANs, manipulating them to edit semantic concepts in the generated outputs, and use these GAN-generated images to train a model with a triplet loss.
The code is written in Python 3.6; we do not guarantee it will work with other versions of Python. We use the PyTorch framework, version 0.4. It should also work with PyTorch 1.0, but this has not been tested.
To install everything needed, have conda available, and create a new virtual environment:
conda create -n env_drawing python=3.6
conda activate env_drawing
Then install the libraries listed in `requirements.txt`:
pip install -r requirements.txt
After that, we still need to install the `netdissect` module, which is provided as part of the code. To do so, go to the root folder of the project and run:
gandissect/script/download_data_drawing.sh # Download support models
pip install -v -e ./gandissect/ # Link the local netdissect package into the env
The data can be obtained by downloading the files in this link.
The default folder structure is:
/path/to/project/
    data/
        audio_clevrgan_natural/
        audio_clevrgan_synth/
Please note that there is NO need to download the images, as the default training setting generates them with the GAN. They are only needed if GAN generation is not desired (the flag `loading_image` has to be activated) or for evaluation purposes. The text transcriptions of the audio are also NOT used, but we make them available in case they are useful.
In order to download the basic data (audio and name_list files) and prepare it with the correct folder structure, execute:
./download_data.sh
Modify the `DATA_FOLDER` variable in `download_data.sh` to choose your preferred data folder, and change the flag `folder_dataset` accordingly when running the code.
The name (ID) of the files corresponds to the ID of the noise vector used to generate the image of the image/caption pair. Do NOT modify the seed of the random noise generation.
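To see why the seed matters: the image/caption pairing only holds if the same ID always reproduces the same noise vector. The sketch below illustrates this idea with NumPy; the function name, the per-ID seeding scheme, and the dimensionality of 128 are all hypothetical (the actual code may instead draw from one globally seeded sequence).

```python
import numpy as np

def noise_for_id(sample_id: int, dim: int = 128) -> np.ndarray:
    """Deterministically derive a noise vector from a sample ID.

    Hypothetical sketch: seeding per ID guarantees that the audio file
    named by this ID stays paired with the image generated from it.
    """
    rng = np.random.RandomState(sample_id)
    return rng.randn(dim)

# The same ID always yields the same vector; changing the seed logic
# would silently break every image/caption pair.
assert np.array_equal(noise_for_id(42), noise_for_id(42))
```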
Pretrained models are downloaded during the execution of the scripts if they are not found, so no manual action is required. In case there is any problem, they can be found here. Note that the default folder where these models are stored is `./downloaded_files`. To change it, set the flags `path_negatives_test`, `path_model_gan` and `path_model_segmenter` accordingly.
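For example, assuming these flags are set as entries in the `.yaml` configuration like the other parameters (the flag names come from this README; the paths below are placeholders):

```yaml
# Hypothetical override of the default ./downloaded_files location.
path_negatives_test: /my/models/negatives_test.pth
path_model_gan: /my/models/gan.pth
path_model_segmenter: /my/models/segmenter.pth
```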
The code is structured as follows:
- `run.py`: the script to execute. Given a configuration file (in the `config_files` folder), it calls the script named in the `file` attribute of the configuration file, with all the parameters in the configuration file. Example of execution:
CUDA_VISIBLE_DEVICES=0 python run.py -f train_example.yaml
- `main.py`: main script called from `run.py`. It creates all the actors (trainer, optimizer, dataset) and calls their methods in order to train, evaluate or test the system.
- `trainer.py`: contains the main training class (`Trainer`), including the training loop.
- `dataset.py`: implements a `Dataset` class inheriting from `torch.utils.data.Dataset`, which loads images, audio and text.
- `models.py`: networks used for this project; classes inheriting from `torch.nn.Module`.
- `clusterer.py`: class that performs the clustering of features, as well as some auxiliary methods.
- `segmenter.py`: class that performs segmentations, both using ground-truth labels and using the cluster classes.
- `losses.py`: methods used to compute the loss; loss definitions.
- `experiments.py`: methods implementing experiments for testing. See the example in `config_files/`.
- `utils.py`: general utility methods.
- `active_learning.py`: methods to generate and use active learning samples.
- `README.md`: this file.
- `requirements.txt`: list of required Python libraries.
- `download_data.sh`: script used to download the data; see instructions above.
- `config_files/`: folder with `.yaml` configuration files. Take a look at the examples to understand the format. All the checkpoints and results contain the checkpoint name in the `.yaml` file, so it is easy to track which parameters were used.
- `ablate/`: auxiliary scripts to help ablate units of the GAN.
- `gandissect/`: folder containing the `netdissect` module.
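Putting this together, a configuration file might look like the sketch below. Only the `file` attribute and the `folder_dataset` flag are described in this README; the other field names and values are hypothetical placeholders, so check the examples in `config_files/` for the real format:

```yaml
# Hypothetical minimal config; see config_files/ for real examples.
file: main                      # script to call (the `file` attribute)
folder_dataset: /path/to/data   # data folder flag (see data section)
checkpoint_name: my_experiment  # hypothetical field name
```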
The training can be visualized in a web browser (localhost), using tensorboard with the following command (from the project root folder):
tensorboard --logdir=./results/runs/ --port=6006
If you want to cite our research, please use:
@InProceedings{Suris_2019_CVPR,
author = {Suris, Didac and Recasens, Adria and Bau, David and Harwath, David and Glass, James and Torralba, Antonio},
title = {Learning Words by Drawing Images},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2019}
}