In recent years, the Transformer architecture has revolutionized how we process sequential data, but its applications in online RL settings remain limited. In this paper, a Transformer architecture is applied as the first layer inside a DQN in order to process a set of observations. The hypothesis is that the encoder-decoder architecture can extract more information from an observation than a linear layer, and that multi-head attention can help the network focus on the experiences deemed most important. Finally, I provide detailed results, backed by a recently suggested statistical framework, showing that the proposed architecture surpasses the baseline on 1 of the 3 classic control environments used for evaluation but performs worse on average, providing a good starting point for future research.
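The core idea can be sketched in a few lines of PyTorch. This is a minimal illustration, not the implementation in `tqn/`: it uses an encoder-only stack, and the layer sizes, number of heads, and the way recent observations are stacked into a sequence are all assumptions.

```python
import torch
import torch.nn as nn

class TransformerQNetwork(nn.Module):
    """Q-network whose first layer is a Transformer encoder applied to a short
    window of stacked observations (illustrative sizes, not the repo's)."""

    def __init__(self, obs_dim: int, n_actions: int, d_model: int = 64,
                 n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        self.embed = nn.Linear(obs_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Sequential(
            nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, n_actions))

    def forward(self, obs_seq: torch.Tensor) -> torch.Tensor:
        # obs_seq: (batch, seq_len, obs_dim) -- a window of recent observations.
        x = self.embed(obs_seq)
        x = self.encoder(x)          # multi-head self-attention over the window
        return self.head(x[:, -1])   # Q-values read from the most recent step
```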

Fig. 2. Example screenshots for each environment. From left to right: Acrobot, CartPole, LunarLander.
The `train` and `play` modules support any environment that has a `Box` observation space and a `Discrete` action space.
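If you want to check whether a custom environment qualifies, you can inspect its spaces; a small illustrative snippet (not part of the package), assuming the Gymnasium API:

```python
import gymnasium as gym
from gymnasium.spaces import Box, Discrete

env = gym.make("CartPole-v1")
assert isinstance(env.observation_space, Box), "observation space must be Box"
assert isinstance(env.action_space, Discrete), "action space must be Discrete"
```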
Plotting utils use hardcoded values for x-axis and y-axis ticks and limits for `Acrobot-v1`, `CartPole-v1` and `LunarLander-v2`, but additional environments can be accommodated with minimal modifications.
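For example, supporting a new environment would mostly mean extending those hardcoded settings. A hypothetical layout (names and values are illustrative, not the ones actually used by the plotting utils):

```python
# Hypothetical mapping from environment id to y-axis limits; the real plotting
# utils hardcode equivalent values for the three supported environments.
AXIS_LIMITS = {
    "Acrobot-v1": {"ylim": (-500, 0)},
    "CartPole-v1": {"ylim": (0, 500)},
    "LunarLander-v2": {"ylim": (-400, 300)},
    # "MountainCar-v0": {"ylim": (-200, 0)},  # a new entry would go here
}
```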
Python 3.11 is required.
```bash
git clone git@github.com:robertoschiavone/transformer-q-network.git
cd transformer-q-network
pip install -r requirements.txt
```
If your machine supports CUDA, it is necessary to export the `CUBLAS_WORKSPACE_CONFIG` environment variable.

```bash
export CUBLAS_WORKSPACE_CONFIG=:4096:8
```
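The variable is required because PyTorch's deterministic cuBLAS kernels refuse to run without it. A minimal sketch of the kind of seeding the training code is expected to perform (the exact calls in `tqn/` may differ):

```python
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed every source of randomness so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Force deterministic kernels; on CUDA this is what makes
    # CUBLAS_WORKSPACE_CONFIG=:4096:8 mandatory.
    torch.use_deterministic_algorithms(True)
```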
```bash
python -m tqn --train --env-id $ENV_ID
```
Optional arguments:

- `--seed`: an integer number. If no seed is passed, the model will still be trained deterministically by using the current Unix timestamp, rounded to the nearest second, as the seed (see the sketch after this list).
- `--network`: either `DQN` or `TQN`. Default is `TQN`.
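A sketch of how that default seed can be produced; the actual argument parsing in `tqn` may differ:

```python
import argparse
import time

parser = argparse.ArgumentParser()
# If --seed is omitted, fall back to the current Unix timestamp in whole seconds.
parser.add_argument("--seed", type=int, default=int(time.time()))
parser.add_argument("--network", choices=["DQN", "TQN"], default="TQN")
args = parser.parse_args()
```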
You can find previously trained models and their TensorBoard training logs inside the `models` and `logs` folders, respectively.
```bash
python -m tqn --play --env-id $ENV_ID --model $MODEL_NAME
python -m tqn --play --env-id LunarLander-v2 \
  --model ./models/LunarLander-v2/DQN/1693526400.zip # example
```
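Under the hood, playing amounts to a greedy rollout with the loaded Q-network. A minimal sketch of that loop, using an untrained placeholder network because the repository's checkpoint-loading code is not reproduced here:

```python
import gymnasium as gym
import torch
import torch.nn as nn

env = gym.make("LunarLander-v2")
obs, _ = env.reset(seed=0)

# Placeholder Q-network; the real CLI restores the weights from the .zip checkpoint.
q_net = nn.Sequential(
    nn.Linear(env.observation_space.shape[0], 64), nn.ReLU(),
    nn.Linear(64, int(env.action_space.n)))

done = False
while not done:
    with torch.no_grad():
        q_values = q_net(torch.as_tensor(obs, dtype=torch.float32))
    action = int(q_values.argmax().item())  # greedy action selection
    obs, _, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
env.close()
```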
Optional arguments:

- `--seed`: as above.
```bash
python -m tqn --plot --env-id $ENV_ID --plot-type $PLOT_TYPE
```
`$PLOT_TYPE` has to be one of the following values:

- `environments` (example)
- `q-vs-frame` (example)
- `q-vs-loss` (example)
- `q-vs-reward` (example, it works only for LunarLander-v2)
- `sample-efficiency-probability-improvement` (example; see the sketch after this list)
- `score-vs-episode` (example)
- `statistics` (example)
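The probability-of-improvement plot presumably follows the evaluation protocol of Agarwal et al. (rliable), the "recently suggested statistical framework" mentioned above. A rough NumPy sketch of the underlying quantity, without the stratified bootstrap confidence intervals the actual plot adds:

```python
import numpy as np

def probability_of_improvement(scores_x: np.ndarray, scores_y: np.ndarray) -> float:
    """P(X > Y): chance that a random run of algorithm X beats a random run of Y.
    Ties count as half a win. Shapes: (n_runs_x,), (n_runs_y,)."""
    wins = (scores_x[:, None] > scores_y[None, :]).mean()
    ties = (scores_x[:, None] == scores_y[None, :]).mean()
    return float(wins + 0.5 * ties)

# Illustrative numbers only, not results from this repository.
print(probability_of_improvement(np.array([200.0, 150.0, 220.0]),
                                 np.array([180.0, 160.0, 210.0])))
```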
You can see further examples inside the folder `plots`.
```bash
python -m tensorboard.main --load_fast true --logdir "./logs"
```
```
.
├── config
│   └── {ENVIRONMENT}.json
├── logs
│   └── {ENVIRONMENT}
│       └── {NETWORK}
│           └── {SEED}
│               └── events.out.tfevents.{TIMESTAMP}.archlinux.0.0
├── models
│   └── {ENVIRONMENT}
│       └── {NETWORK}
│           └── {SEED}.zip
├── plots
│   └── {PLOT}.pdf
├── thesis
│   └── {LATEX FILES}
├── tqn
│   └── {SOURCE CODE}
├── LICENSE
├── LunarLander-v2_TQN_1694390400.gif
├── README.md
├── requirements.txt
└── thesis.pdf
```
- CleanRL, Clean Implementation of RL Algorithms
- Stable-Baselines3, Reliable Reinforcement Learning Implementations