Date: 09/2019

Author(s):

Micha Livne
livne@seraphlabs.ca
University of Toronto
Vector Institute
Kevin Swersky
lkswersky@google.com
Google Research
David J. Fleet
fleet@cs.toronto.edu
University of Toronto
Vector Institute

Last Updated: 11/10/2019 16:22:51

MIM: Mutual Information Machine


Why should you care? Posterior Collapse!

[Figure panels: MIM, VAE]

MIM and VAE models with 2D inputs and a 2D latent space.
Top row: Black contours depict level sets of P(x); red dots are reconstructed test points.
Bottom row: Green contours are one standard deviation ellipses of q(z|x) for test points. Dashed black circles depict one standard deviation of P(z).

MIM produces lower predictive variance and lower reconstruction errors, consistent with high mutual information.
The VAE is trained with beta annealing, as in beta-VAE. Once annealing is complete (i.e., beta = 1), the VAE posteriors show high predictive variance, which indicates partial posterior collapse. The increased variance reduces mutual information and worsens reconstruction error.

Requirements

The code has been tested on CPU and on an NVIDIA Titan Xp GPU, using Anaconda, Python 3.6, and zsh:

# tools
zsh 5.4.2
Cuda compilation tools, release 9.0, V9.0.176
conda 4.6.14
Python 3.6.8

# python packages (see requirements.txt for complete list)
scipy==1.1.0
matplotlib==3.0.3
numpy==1.15.4
torchvision==0.2.1
torch==1.0.0
scikit_learn==0.21.3

Installation

Please install PyTorch by following the official installation instructions on the PyTorch website, then install the remaining requirements:

pip install -r requirements.txt

Data

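The experiments use the following datasets (names as passed to the scripts): toyMIM, toy4, toy4_20, pca-fashion-mnist20, dynamic_mnist, dynamic_fashion_mnist, and omniglot.

All datasets are included in the repo for convenience; download links are provided only as a fallback in case of issues.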

Experiments

Directory structure (if the code fails due to a missing directory, please create it manually):

src/ - Experiments are assumed to be executed from this directory.
data/assets - Datasets will be saved here.
data/torch-generated - Results will be saved here.
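A minimal sketch for creating the output directories ahead of time. It assumes the paths listed above are relative to the repository root (the directory that contains src/) and that you run it from there:

```shell
#!/bin/sh
# Assumption: run from the repository root, i.e. the directory that contains src/.
# `mkdir -p` creates any missing parent directories and is a no-op for existing ones.
mkdir -p data/assets data/torch-generated

# List the directories to confirm they exist; the experiments are then run from src/.
ls -d data/assets data/torch-generated
```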

NOTE (if the code fails due to CUDA/GPU issues): to disable CUDA/GPU use and force CPU computation, add the following flag to any of the command lines below:

--no-cuda

Otherwise, CUDA is used by default whenever PyTorch detects it.
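To apply the flag consistently across the sweeps below, one option is to compute it once and append it to every command. This sketch uses a hypothetical FORCE_CPU environment variable; it is a convention of the sketch only and is not read by the MIM scripts:

```shell
#!/bin/sh
# FORCE_CPU is hypothetical (part of this sketch, not of the repo).
# Set FORCE_CPU=1 to append --no-cuda to every command that uses $CUDA_FLAG.
if [ "${FORCE_CPU:-0}" = "1" ]; then
    CUDA_FLAG="--no-cuda"
else
    CUDA_FLAG=""
fi
echo "CUDA flag: '${CUDA_FLAG}'"

# Example (not executed here); keep $CUDA_FLAG unquoted so an empty value adds no argument:
#   ./vae-as-mim-dataset.py --dataset toy4 --z-dim 2 --mid-dim 50 --seed 1 $CUDA_FLAG
```

Leaving $CUDA_FLAG unquoted means an empty value contributes no argument at all, in both sh and zsh.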

For a detailed explanation of the plots below, please see the paper.

Animation

To produce the animation at the top:

# MIM
./vae-as-mim-dataset.py \
    --dataset toyMIM \
    --z-dim 2 \
    --mid-dim 50 \
    --min-logvar 6 \
    --seed 1 \
    --batch-size 128 \
    --epochs 49 \
    --warmup-steps 25 \
    --vis-progress \
    --mim-loss \
    --mim-samp
# VAE
./vae-as-mim-dataset.py \
    --dataset toyMIM \
    --z-dim 2 \
    --mid-dim 50 \
    --min-logvar 6 \
    --seed 1 \
    --batch-size 128 \
    --epochs 49 \
    --warmup-steps 25  \
    --vis-progress

2D Experiments

Experimenting with the expressiveness of MIM and VAE by varying the number of hidden units (mid_dim):

for seed in 1 2 3 4 5 6 7 8 9 10; do
        for mid_dim in 5 20 50 100 200 300 400 500; do
            # MIM
            ./vae-as-mim-dataset.py \
                --dataset toy4 \
                --z-dim 2 \
                --mid-dim ${mid_dim} \
                --min-logvar 6 \
                --seed ${seed} \
                --batch-size 128 \
                --epochs 200 \
                --warmup-steps 3 \
                --mim-loss \
                --mim-samp
            # VAE
            ./vae-as-mim-dataset.py \
                --dataset toy4 \
                --z-dim 2 \
                --mid-dim ${mid_dim} \
                --min-logvar 6 \
                --seed ${seed} \
                --batch-size 128 \
                --epochs 200 \
                --warmup-steps 3
        done
done
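The sweep above amounts to 10 seeds × 8 hidden-layer sizes × 2 models = 160 training runs. Before launching it, the grid can be previewed with a dry run that prints each command via echo instead of executing it. The flags are copied from the loop above; the `set --` trick only holds the MIM-specific flags (--mim-loss --mim-samp) for the MIM runs:

```shell
#!/bin/sh
# Dry run of the expressiveness sweep: commands are printed with `echo`, not executed.
runs=0
for seed in 1 2 3 4 5 6 7 8 9 10; do
    for mid_dim in 5 20 50 100 200 300 400 500; do
        for model in mim vae; do
            # MIM runs add --mim-loss --mim-samp; VAE runs use the shared flags only.
            if [ "$model" = "mim" ]; then
                set -- --mim-loss --mim-samp
            else
                set --
            fi
            echo ./vae-as-mim-dataset.py --dataset toy4 --z-dim 2 \
                --mid-dim "$mid_dim" --min-logvar 6 --seed "$seed" \
                --batch-size 128 --epochs 200 --warmup-steps 3 "$@"
            runs=$((runs + 1))
        done
    done
done
echo "total runs: $runs"   # prints "total runs: 160"
```

Dropping the `echo` runs the commands for real (from src/, as described above). Using "$@" instead of an unquoted string keeps the extra flags as separate arguments under both sh and zsh, since zsh does not word-split unquoted variables by default.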

The results below demonstrate posterior collapse in the VAE and its absence in MIM.

[Figures: MIM (5, 20, 500 hidden units); VAE (5, 20, 500 hidden units)]
[Plots, MIM vs. VAE: MI, NLL, reconstruction RMSE, classification accuracy]

Experimenting with the effect of an entropy prior on MIM and VAE (flag --inv-H-loss):

for seed in 1 2 3 4 5 6 7 8 9 10; do
        for mid_dim in 5 20 50 100 200 300 400 500; do
                # MIM
                ./vae-as-mim-dataset.py \
                    --dataset toy4 \
                    --z-dim 2 \
                    --mid-dim ${mid_dim} \
                    --min-logvar 6 \
                    --seed ${seed} \
                    --batch-size 128 \
                    --epochs 200 \
                    --warmup-steps 3 \
                    --mim-loss \
                    --mim-samp \
                    --inv-H-loss
                # VAE
                ./vae-as-mim-dataset.py \
                    --dataset toy4 \
                    --z-dim 2 \
                    --mid-dim ${mid_dim} \
                    --min-logvar 6 \
                    --seed ${seed} \
                    --batch-size 128 \
                    --epochs 200 \
                    --warmup-steps 3 \
                    --inv-H-loss
        done
done

The results below demonstrate that adding the joint entropy as a regularizer can prevent posterior collapse in the VAE, while subtracting the joint entropy can induce strong posterior collapse in MIM.

[Figures: MIM - H (5, 20, 500 hidden units); VAE + H (5, 20, 500 hidden units)]
[Plots, MIM vs. VAE: MI, NLL, reconstruction RMSE, classification accuracy]

Bottleneck

Experimenting with the effect of the latent bottleneck (latent dimensionality, z_dim) on VAE and MIM.

20D with 5 GMM

A synthetic dataset drawn from a 5-component GMM, with 20D x:

for seed in 1 2 3 4 5 6 7 8 9 10; do
        for z_dim in 2 4 6 8 10 12 14 16 18 20; do
            # MIM
            ./vae-as-mim-dataset.py \
                --dataset toy4_20  \
                --z-dim ${z_dim}  \
                --mid-dim 50  \
                --seed ${seed}  \
                --epochs 200   \
                --min-logvar 6  \
                --warmup-steps 3   \
                --mim-loss  \
                --mim-samp
            # VAE
            ./vae-as-mim-dataset.py  \
                --dataset toy4_20   \
                --z-dim ${z_dim}  \
                --mid-dim 50  \
                --seed ${seed}  \
                --epochs 200  \
                --min-logvar 6  \
                --warmup-steps 3
        done
done
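This sweep has 200 runs (10 seeds × 10 latent sizes × 2 models), so keeping one log per configuration makes failures easier to trace. The sketch below wraps each command in a small `run` helper. DRY_RUN, the logs/ directory, and the log-file naming are assumptions of this sketch rather than features of the repo, and by default it only records the commands:

```shell
#!/bin/sh
# DRY_RUN and LOG_DIR are hypothetical settings of this sketch, not repo options.
DRY_RUN=${DRY_RUN:-1}   # default: only record commands; set DRY_RUN=0 to train
LOG_DIR=logs
mkdir -p "$LOG_DIR"

run() {
    # $1 = log file name; the remaining arguments form the command.
    log="$LOG_DIR/$1"
    shift
    if [ "$DRY_RUN" = "1" ]; then
        echo "$*" > "$log"        # write the command line only
    else
        "$@" > "$log" 2>&1        # capture stdout and stderr of the real run
    fi
}

for seed in 1 2 3 4 5 6 7 8 9 10; do
    for z_dim in 2 4 6 8 10 12 14 16 18 20; do
        run "mim_toy4_20_z${z_dim}_s${seed}.log" ./vae-as-mim-dataset.py \
            --dataset toy4_20 --z-dim "$z_dim" --mid-dim 50 --seed "$seed" \
            --epochs 200 --min-logvar 6 --warmup-steps 3 --mim-loss --mim-samp
        run "vae_toy4_20_z${z_dim}_s${seed}.log" ./vae-as-mim-dataset.py \
            --dataset toy4_20 --z-dim "$z_dim" --mid-dim 50 --seed "$seed" \
            --epochs 200 --min-logvar 6 --warmup-steps 3
    done
done
```

To actually train, set DRY_RUN=0 and run from src/. The same wrapper can serve the Fashion-MNIST PCA sweep by replacing toy4_20 with pca-fashion-mnist20 in both the --dataset value and the log names.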

The results below demonstrate posterior collapse in the VAE, which worsens as the latent dimensionality increases, and its absence in MIM.

[Plots, MIM vs. VAE: MI, NLL, reconstruction RMSE, classification accuracy]

20D with Fashion-MNIST PCA

A PCA reduction of Fashion-MNIST to 20D x:

for seed in 1 2 3 4 5 6 7 8 9 10; do
        for z_dim in 2 4 6 8 10 12 14 16 18 20; do
            # MIM
            ./vae-as-mim-dataset.py  \
                --dataset pca-fashion-mnist20   \
                --z-dim ${z_dim}  \
                --mid-dim 50  \
                --seed ${seed}  \
                --epochs 200  \
                --min-logvar 6  \
                --warmup-steps 3  \
                --mim-loss  \
                --mim-samp
            # VAE
            ./vae-as-mim-dataset.py  \
                --dataset pca-fashion-mnist20   \
                --z-dim ${z_dim}  \
                --mid-dim 50  \
                --seed ${seed}  \
                --epochs 200  \
                --min-logvar 6  \
                --warmup-steps 3
        done
done

The results below demonstrate posterior collapse in the VAE, which worsens as the latent dimensionality increases, and its absence in MIM, here on real-world observations.

[Plots, MIM vs. VAE: MI, NLL, reconstruction RMSE, classification accuracy]

High-Dimensional Image Data

Experimenting with high-dimensional image data, where mutual information cannot be reliably measured:

for seed in 1 2 3 4 5 6 7 8 9 10; do
    for dataset_name in dynamic_mnist dynamic_fashion_mnist omniglot; do
        for model_name in convhvae_2level convhvae_2level-smim pixelhvae_2level pixelhvae_2level-amim; do
            for prior in vampprior standard; do
                ./vae-as-mim-image.py \
                    --dataset_name ${dataset_name} \
                    --model_name ${model_name} \
                    --prior ${prior} \
                    --seed ${seed} \
                    --use_training_data_init
            done
        done
    done
done
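This is the largest sweep: 10 seeds × 3 datasets × 4 models × 2 priors = 240 runs. A dry-run sketch that prints the commands instead of running them, with a hypothetical ONLY_DATASET variable (a convention of this sketch, not something the scripts read) for restricting the sweep to a single dataset:

```shell
#!/bin/sh
# Dry run of the image sweep: commands are printed with `echo`, not executed.
# ONLY_DATASET is hypothetical; e.g. ONLY_DATASET=omniglot keeps only that dataset.
runs=0
for seed in 1 2 3 4 5 6 7 8 9 10; do
    for dataset_name in dynamic_mnist dynamic_fashion_mnist omniglot; do
        # Skip other datasets when ONLY_DATASET is set (continue applies to this loop).
        if [ -n "${ONLY_DATASET:-}" ] && [ "$dataset_name" != "$ONLY_DATASET" ]; then
            continue
        fi
        for model_name in convhvae_2level convhvae_2level-smim pixelhvae_2level pixelhvae_2level-amim; do
            for prior in vampprior standard; do
                echo ./vae-as-mim-image.py --dataset_name "$dataset_name" \
                    --model_name "$model_name" --prior "$prior" \
                    --seed "$seed" --use_training_data_init
                runs=$((runs + 1))
            done
        done
    done
done
echo "total runs: $runs"
```

With ONLY_DATASET unset this prints all 240 commands; with it set to one dataset name, 80.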

The results below show comparable samples and reconstructions for VAE and MIM, and better unsupervised clustering for MIM, as a result of its higher mutual information.

[Figure, Fashion-MNIST: samples, reconstructions, and latent embeddings (Z) for MIM and VAE]

[Figure, MNIST: samples, reconstructions, and latent embeddings (Z) for MIM and VAE]

Code for this experiment is based on the VampPrior paper:

@article{TW:2017,
  title={{VAE with a VampPrior}},
  author={Tomczak, Jakub M and Welling, Max},
  journal={arXiv},
  year={2017}
}

Citation

Please cite our paper if you use this code in your research:

@misc{livne2019mim,
    title={MIM: Mutual Information Machine},
    author={Micha Livne and Kevin Swersky and David J. Fleet},
    year={2019},
    eprint={1910.03175},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Acknowledgements

Many thanks to Ethan Fetaya, Jacob Goldberger, Roger Grosse, Chris Maddison, and Daniel Roy for interesting discussions and for their helpful comments. We are especially grateful to Sajad Nourozi for extensive discussions and for his help to empirically validate the formulation and experimental work. This work was financially supported in part by the Canadian Institute for Advanced Research (Program on Learning in Machines and Brains), and NSERC Canada.

Your Feedback Is Appreciated

If you find this paper and/or repo to be useful, we would love to hear back! Tell us your success stories, and we will include them in this README.