Commit d4c60d2c authored by Haichao Zhang, committed by GitHub

VQ VAE Algorithm (#1409)

* VQ VAE Algorithm

* Address comments

* Fix embedding summary
parent 59cf10cd
@@ -48,6 +48,7 @@ Read the ALF documentation [here](https://alf.readthedocs.io/).
|[VAE](alf/algorithms/vae.py)|General|Higgins et al. "beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework" [ICLR2017](https://openreview.net/forum?id=Sy2fzU9gl)|
|[RealNVP](alf/networks/normalizing_flow_networks.py)|General|Dinh et al. "Density estimation using Real NVP" [arXiv:1605.08803](https://arxiv.org/abs/1605.08803)|
|[SpatialBroadcastDecoder](alf/networks/encoding_networks.py)|General|Watters et al. "Spatial Broadcast Decoder: A Simple Architecture for Learning Disentangled Representations in VAEs" [arXiv:1901.07017](https://arxiv.org/abs/1901.07017)|
|[VQ-VAE](alf/algorithms/vq_vae.py)|General|A van den Oord et al. "Neural Discrete Representation Learning" [NeurIPS2017](https://proceedings.neurips.cc/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf)|
## Installation
alf/algorithms/vq_vae.py
# Copyright (c) 2022 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vector Quantized Variational AutoEncoder Algorithm."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable
import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import AlgStep, LossInfo, namedtuple
from alf.networks import EncodingNetwork
VqvaeLossInfo = namedtuple(
"VqvaeLossInfo", ["quantization", "commitment", "reconstruction"],
default_value=())
class Vqvae(Algorithm):
    r"""Vector Quantized Variational AutoEncoder (VQVAE) algorithm, described in:

    ::

        A van den Oord et al. "Neural Discrete Representation Learning",
        NeurIPS 2017.

    VQVAE differs from the standard VAE mainly in the following aspects:

    1. A discrete latent is used, instead of the continuous latent of the
       standard VAE.
    2. The standard VAE uses a Gaussian prior and posterior. VQVAE can be
       viewed as using a deterministic form of posterior: a categorical
       distribution with one-hot samples computed by nearest neighbor matching
       (Eq. 1 of the paper). With a uniform prior, the KL divergence is then
       a constant and can be dropped from the training objective.
    """
    def __init__(self,
                 input_tensor_spec: alf.NestedTensorSpec,
                 num_embeddings: int,
                 embedding_dim: int,
                 encoder_ctor: Callable = EncodingNetwork,
                 decoder_ctor: Callable = EncodingNetwork,
                 optimizer: torch.optim.Optimizer = None,
                 commitment_loss_weight: float = 1.0,
                 debug_summaries: bool = False,
                 name: str = "Vqvae"):
        """
        Args:
            input_tensor_spec (TensorSpec): the tensor spec of the input.
            num_embeddings (int): the number of embeddings (the size of the
                codebook).
            embedding_dim (int): the dimensionality of the embedding vectors.
            encoder_ctor (Callable): called as ``encoder_ctor(observation_spec)``
                to construct the encoding ``Network``. The network takes raw
                observation as input and outputs the latent representation.
            decoder_ctor (Callable): called as ``decoder_ctor(latent_spec)`` to
                construct the decoder.
            optimizer (Optimizer|None): if provided, it will be used to optimize
                the parameters of the encoder, the decoder and the embedding
                vectors.
            commitment_loss_weight (float): the weight for the commitment loss.
        """
        super().__init__(debug_summaries=debug_summaries, name=name)

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        # codebook of shape [n, d]
        self._embedding = torch.nn.Parameter(
            torch.FloatTensor(self._num_embeddings, self._embedding_dim))
        torch.nn.init.uniform_(
            self._embedding,
            a=-1 / self._num_embeddings,
            b=1 / self._num_embeddings)

        self._encoding_net = encoder_ctor(input_tensor_spec)
        self._decoding_net = decoder_ctor(self._encoding_net.output_spec)

        if optimizer is not None:
            self.add_optimizer(
                optimizer,
                [self._encoding_net, self._decoding_net, self._embedding])
        self._optimizer = optimizer
        self._commitment_loss_weight = commitment_loss_weight
    def _predict_step(self, inputs, state=()):
        """
        Args:
            inputs (tensor): with the same shape as ``input_tensor_spec``
        """
        # [B, d]
        input_embedding, _ = self._encoding_net(inputs)

        # squared distances to all codebook entries, using the expansion
        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 * x . e
        # [B, 1] + [n] - 2 * [B, n] -> [B, n]
        distances = (torch.sum(input_embedding**2, dim=1, keepdim=True) +
                     torch.sum(self._embedding**2, dim=1) -
                     2 * torch.matmul(input_embedding, self._embedding.t()))

        # nearest neighbor matching (Eq. 1 of the paper)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self._embedding[encoding_indices]

        # straight-through estimator: the forward value is the quantized code,
        # while the gradient flows back to the encoder output unchanged
        quantized_st = input_embedding + (quantized - input_embedding).detach()
        return input_embedding, quantized, quantized_st
    def predict_step(self, inputs, state=()):
        _, _, quantized_st = self._predict_step(inputs)
        rec = self._decoding_net(quantized_st)[0]
        return AlgStep(output=rec, state=state, info=quantized_st)
    def train_step(self, inputs, state=()):
        """
        Args:
            inputs (tensor): with the same shape as ``input_tensor_spec``
        """
        input_embedding, quantized, quantized_st = self._predict_step(inputs)

        # commitment loss: pulls the encoder output towards the (fixed) code
        e_latent_loss = F.mse_loss(
            quantized.detach(), input_embedding, reduction="none")
        # quantization loss: pulls the code towards the (fixed) encoder output
        q_latent_loss = F.mse_loss(
            quantized, input_embedding.detach(), reduction="none")
        # encoding loss
        enc_loss = (q_latent_loss +
                    self._commitment_loss_weight * e_latent_loss).mean(dim=1)

        # decoding (reconstruction) loss
        rec = self._decoding_net(quantized_st)[0]
        recon_loss = F.mse_loss(rec, inputs, reduction="none").mean(dim=1)

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                alf.summary.embedding("vq_embedding", self._embedding.detach())

        loss = enc_loss + recon_loss
        info = VqvaeLossInfo(
            quantization=q_latent_loss.mean(1),
            commitment=e_latent_loss.mean(1),
            reconstruction=recon_loss)
        loss_info = LossInfo(loss=loss, extra=info)
        return AlgStep(output=rec, state=state, info=loss_info)
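For reference, the loss computed in ``train_step`` corresponds to the training objective of Eq. 3 in the paper, with the reconstruction log-likelihood realized as an MSE loss, ``sg`` denoting stop-gradient (``detach()`` above), and :math:`\beta` given by ``commitment_loss_weight``; the KL term is a constant under the uniform prior and is omitted:

```latex
L = \underbrace{\lVert x - \mathrm{dec}(z_q) \rVert_2^2}_{\text{reconstruction}}
  + \underbrace{\lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2}_{\text{quantization}}
  + \beta \, \underbrace{\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2}_{\text{commitment}}
```

where :math:`z_e(x)` is the encoder output and :math:`e` is its nearest codebook vector. The three terms map directly to the ``reconstruction``, ``quantization`` and ``commitment`` fields of ``VqvaeLossInfo``.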
alf/algorithms/vq_vae_test.py

# Copyright (c) 2022 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import logging
import os
import numpy as np
import tempfile
from functools import partial

import torch

import alf
from alf.algorithms.vq_vae import Vqvae
from alf.layers import FC
from alf.nest.utils import NestConcat
from alf.networks import EncodingNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import math_ops
class VQVaeTest(alf.test.TestCase):
    def setUp(self):
        super().setUp()
        self._input_spec = TensorSpec((1, ))
        self._epochs = 10
        self._batch_size = 100
        self._num_embeddings = 32
        self._embedding_dim = 10
        self._loss_f = math_ops.square
        self._learning_rate = 1e-3
        self._commitment_loss_weight = 0.1
    def test_vq_vae(self):
        """Test for a one dimensional signal."""
        fc_layers_params = (256, ) * 2
        encoder_cls = partial(
            alf.networks.EncodingNetwork,
            fc_layer_params=fc_layers_params,
            last_layer_size=self._embedding_dim,
            last_activation=math_ops.identity,
            last_kernel_initializer=partial(
                torch.nn.init.uniform_, a=-0.03, b=0.03))
        decoder_cls = partial(
            alf.networks.EncodingNetwork,
            fc_layer_params=fc_layers_params,
            last_layer_size=1,
            last_activation=math_ops.identity,
            last_kernel_initializer=partial(
                torch.nn.init.uniform_, a=-0.03, b=0.03))

        optimizer = alf.optimizers.Adam(lr=self._learning_rate)
        vq_vae = Vqvae(
            input_tensor_spec=self._input_spec,
            num_embeddings=self._num_embeddings,
            embedding_dim=self._embedding_dim,
            encoder_ctor=encoder_cls,
            decoder_ctor=decoder_cls,
            optimizer=optimizer,
            commitment_loss_weight=self._commitment_loss_weight)

        # construct 1d samples around two centers with additive noise
        num_centers = 2
        x_train = 1e-1 * self._input_spec.randn(outer_dims=(10000, ))
        x_test = 1e-1 * self._input_spec.randn(outer_dims=(10, ))
        x_train = x_train.view(-1, num_centers) + torch.arange(0, num_centers)
        x_test = x_test.view(-1, num_centers) + torch.arange(0, num_centers)
        x_train = x_train.view(-1, 1)
        x_test = x_test.view(-1, 1)

        for _ in range(self._epochs):
            # shuffle and iterate over mini-batches
            x_train = x_train[torch.randperm(x_train.shape[0])]
            for i in range(0, x_train.shape[0], self._batch_size):
                batch = x_train[i:i + self._batch_size]
                alg_step = vq_vae.train_step(batch)
                vq_vae.update_with_gradient(alg_step.info)

        alg_step = vq_vae.predict_step(x_test)
        reconstruction_loss = float(
            torch.mean(self._loss_f(x_test - alg_step.output)))
        print("reconstruction_loss:", reconstruction_loss)
        self.assertLess(reconstruction_loss, 0.05)


if __name__ == '__main__':
    alf.test.main()
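As a side note, the core quantization logic of this commit can be distilled into a few lines of plain PyTorch. The following is a minimal, self-contained sketch (names and shapes are illustrative, not part of ALF) of the nearest-neighbor lookup and straight-through gradient trick used in ``Vqvae._predict_step``:

```python
# Minimal sketch of VQ-VAE nearest-neighbor quantization with a
# straight-through gradient, mirroring Vqvae._predict_step above.
import torch

B, n, d = 4, 32, 10  # batch size, codebook size, embedding dim (illustrative)
codebook = torch.nn.Parameter(torch.empty(n, d).uniform_(-1 / n, 1 / n))
z_e = torch.randn(B, d, requires_grad=True)  # stand-in for the encoder output

# ||z_e - e||^2 expanded as ||z_e||^2 + ||e||^2 - 2 * z_e . e  -> [B, n]
distances = (z_e.pow(2).sum(dim=1, keepdim=True) +
             codebook.pow(2).sum(dim=1) -
             2 * z_e @ codebook.t())
indices = distances.argmin(dim=1)  # [B], nearest code index per sample
z_q = codebook[indices]            # [B, d], quantized latent

# Straight-through: the forward value is z_q, but d(z_q_st)/d(z_e) is the
# identity, so the decoder's gradient passes to the encoder unchanged.
z_q_st = z_e + (z_q - z_e).detach()

z_q_st.sum().backward()
print(z_e.grad)  # all ones: the gradient skipped over the argmin
```

Running it prints an all-ones gradient for ``z_e``, confirming that the non-differentiable ``argmin`` is bypassed while the quantized value is still what the decoder sees in the forward pass.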