ALDA: Associative Latent DisentAnglement

University of Southern California
Preprint

Abstract

Generalizing vision-based reinforcement learning (RL) agents to novel environments remains a difficult and open challenge. Current trends are to collect large-scale datasets or use data augmentation techniques to prevent overfitting and improve downstream generalization. However, the computational and data collection costs increase exponentially with the number of task variations and can destabilize the already difficult task of training RL agents. In this work, we take inspiration from recent advances in computational neuroscience and propose a model, Associative Latent DisentAnglement (ALDA), that builds on standard off-policy RL towards zero-shot generalization. Specifically, we revisit the role of latent disentanglement in RL and show how combining it with a model of associative memory achieves zero-shot generalization on difficult task variations without relying on data augmentation. Finally, we formally show that data augmentation techniques are a form of weak disentanglement and discuss the implications of this insight. Code coming soon!

ALDA trains only on the original, unmodified task without data augmentation (left) and achieves strong zero-shot generalization performance on challenging distribution shifts (center and right).

Diagram of disentanglement and association
ALDA has two main components. First, it learns a disentangled representation of the latent variables that generate the observations. Second, when presented with out-of-distribution (OOD) observations, it uses an associative memory mechanism to map individual latent variables, zero-shot, to in-distribution values.
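To make the association step more concrete, here is a minimal sketch that assumes a softmax-weighted, modern-Hopfield-style recall applied to each latent dimension independently. This is an illustrative stand-in, not ALDA's actual mechanism: the per-dimension memory banks, the negative-squared-distance similarity, the inverse temperature `beta`, and the helper names `retrieve_scalar` and `associate` are all hypothetical.

```python
import math
from typing import List, Sequence


def retrieve_scalar(query: float, memory: Sequence[float], beta: float = 10.0) -> float:
    """Softmax-weighted recall of a stored value (hypothetical helper).

    Similarity is the negative squared distance between the query and each
    stored value; a large beta makes recall approach nearest-neighbour lookup.
    """
    scores = [-beta * (query - m) ** 2 for m in memory]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return sum(w * m for w, m in zip(weights, memory)) / total


def associate(z: Sequence[float], memories: List[Sequence[float]], beta: float = 10.0) -> List[float]:
    """Map each latent dimension independently onto values seen in training.

    Because the representation is assumed to be disentangled, a shift in one
    factor (e.g. colour) can be corrected without touching the others.
    """
    if len(z) != len(memories):
        raise ValueError("need one memory bank per latent dimension")
    return [retrieve_scalar(zi, mem, beta) for zi, mem in zip(z, memories)]


# Hypothetical memory: values each latent dimension took during training.
memories = [
    [-1.0, -0.5, 0.0, 0.5, 1.0],  # e.g. a pose-like dimension
    [0.2, 0.25, 0.3],             # e.g. a colour dimension seen in a narrow range
]
z_ood = [0.48, 3.0]  # the second dimension lies far outside its training range
print([round(v, 3) for v in associate(z_ood, memories)])
```

Because recall is a convex combination of stored values, every output stays inside the range of values observed during training, which is the property the association step relies on; treating dimensions independently only makes sense if the representation is actually disentangled.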

Latent Traversals

Latent traversals of ALDA trained directly on the "color hard" environment for the Walker task. The colors of the agent, floor, and sky are randomly changed to extreme RGB values on reset.


We can visualize the effects of a disentangled latent representation by interpolating one latent dimension while keeping the others fixed, then decoding the resulting latent codes into images. ALDA automatically learns to separate background, or "distractor", variables from task-relevant variables. The latent traversals on the "color hard" environment show that dimensions that vary aspects of the agent (legs, torso, feet) do not change the color of the agent, sky, or floor, and vice versa.
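The traversal procedure described above can be sketched as follows. This is a minimal illustration, not the authors' code: the function name `latent_traversal`, the default sweep range of [-3, 3], and the `decoder` callable (standing in for a trained image decoder) are assumptions.

```python
from typing import Callable, List, Optional, Sequence


def latent_traversal(
    z: Sequence[float],
    dim: int,
    low: float = -3.0,
    high: float = 3.0,
    steps: int = 7,
    decoder: Optional[Callable[[List[float]], object]] = None,
) -> list:
    """Sweep latent dimension `dim` from `low` to `high`, holding the rest fixed.

    Returns the latent codes, or the decoder outputs when a decoder is given.
    `decoder` is a hypothetical stand-in for a trained image decoder.
    """
    if steps < 2:
        raise ValueError("steps must be at least 2")
    codes = []
    for k in range(steps):
        value = low + (high - low) * k / (steps - 1)  # linear interpolation
        code = list(z)     # copy so the base code is never mutated
        code[dim] = value  # only this coordinate changes
        codes.append(code)
    if decoder is None:
        return codes
    return [decoder(c) for c in codes]


base = [0.1, -0.4, 0.7]
for code in latent_traversal(base, dim=1, low=-1.0, high=1.0, steps=3):
    print(code)
# [0.1, -1.0, 0.7]
# [0.1, 0.0, 0.7]
# [0.1, 1.0, 0.7]
```

Running this once per latent dimension and laying out the decoded frames row by row gives a traversal grid like the one in the figure; in a disentangled model each row should change only one visual factor.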

Latent Trajectories

An interesting property of ALDA is that its latent variable trajectories oscillate over time in patterns similar to some of the agent's proprioceptive state variables. We visualize a few latent variable trajectories alongside some of the agent's proprioceptive state trajectories for a single rollout. Because ALDA learns a disentangled latent representation, it is plausible that some latent variables correspond to specific proprioceptive state variables, such as joint angles over time. The mapping from high-dimensional image observations to the latent space is arbitrary and need not correspond 1:1 with proprioceptive state variables, but this remains an interesting observation that merits further investigation.
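One simple way to probe this observation quantitatively is to correlate each latent trajectory with each proprioceptive state trajectory from the same rollout. The sketch below is not an analysis reported for ALDA; it assumes Pearson correlation as the similarity measure, and the trajectory names (`knee_angle`, `torso_height`, `z_3`) and the toy signals are hypothetical.

```python
import math
from typing import Dict, Sequence, Tuple


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation of two equal-length trajectories (0.0 if either is constant)."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("trajectories must have equal length >= 2")
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    if vx == 0 or vy == 0:
        return 0.0
    return cov / math.sqrt(vx * vy)


def best_matches(
    latents: Dict[str, Sequence[float]],
    states: Dict[str, Sequence[float]],
) -> Dict[str, Tuple[str, float]]:
    """For each latent trajectory, return the state trajectory with the largest |r|."""
    matches = {}
    for lname, ltraj in latents.items():
        scored = [(sname, pearson(ltraj, straj)) for sname, straj in states.items()]
        # Use |r| because the sign and scale of a learned latent are arbitrary.
        matches[lname] = max(scored, key=lambda item: abs(item[1]))
    return matches


# Toy rollout: a latent that tracks a joint angle with flipped sign, scale, and offset.
t = [0.1 * i for i in range(100)]
states = {
    "knee_angle": [math.sin(x) for x in t],
    "torso_height": [math.cos(0.3 * x) for x in t],
}
latents = {"z_3": [-2.0 * math.sin(x) + 0.5 for x in t]}
print(best_matches(latents, states))
```

Pearson correlation only detects linear relationships; since the image-to-latent mapping need not be linear in the state, a rank correlation or mutual-information estimate would be a more permissive alternative.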

Latent and state trajectories

Experiments

Main results

We compare against a set of baselines that together cover the main approaches to zero-shot generalization in vision-based RL: learning task-centric representations (RePo), disentangled representation learning without association (DARLA), and data augmentation (SVEA). We train on four tasks from the DMControl suite and test on two distribution-shift environments: "color hard", which randomizes the colors of the scene, and "DistractingCS", which introduces camera perturbations and plays a video in the background. ALDA outperforms every baseline other than SVEA on all tasks. SVEA draws on additional data during augmentation, which likely expands its training distribution so that its support covers the test distributions induced by the evaluation environments.

Videos

Cartpole Balance

Original Task

Color Hard

Distracting CS

Finger Spin

Original Task

Color Hard

Distracting CS

Ball in Cup Catch

Original Task

Color Hard

Distracting CS