Install the dependencies (the torch line is commented out on the assumption that PyTorch is already available)

#!pip install torch
!pip install tensordict torchvision torchrl gymnasium==0.29.1 pygame av==12.0.0

Import the modules

import torch
import time
import matplotlib.pyplot as plt

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

from tensordict.nn import TensorDictModule as TensorDict, TensorDictSequential as Seq

from torchrl.modules import EGreedyModule, MLP, QValueModule

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torch.optim import Adam

from torchrl.objectives import DQNLoss, SoftUpdate

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

Create the CartPole environment with a step counter and seed everything for reproducibility

torch.manual_seed(0)
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)
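
A quick sanity check (not part of the original recipe): resetting the transformed environment returns a TensorDict, and the StepCounter transform adds a "step_count" entry next to the observation.

# Illustrative only: inspect what the environment produces.
reset_td = env.reset()
print(reset_td)          # TensorDict with "observation", "step_count", "done", ...
print(env.action_spec)   # one-hot action spec with two actions for CartPole-v1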

Set the hyperparameters for data collection and training

INIT_RAND_STEPS = 5000         # frames collected with random actions before the policy is used
FRAMES_PER_BATCH = 100         # frames per collector batch
OPTIM_STEPS = 10               # optimization steps per collected batch (used in the training loop)
EPS_0 = 0.5                    # initial epsilon for epsilon-greedy exploration
BUFFER_LEN = 100_000           # replay buffer capacity, also used as the epsilon annealing horizon
ALPHA = 0.05                   # learning rate
TARGET_UPDATE_EPS = 0.95       # fraction of the target-network weights kept at each soft update
REPLAY_BUFFER_SAMPLE = 128     # mini-batch size sampled from the replay buffer
LOG_EVERY = 1000               # log progress every LOG_EVERY frames (used in the training loop)
MLP_SIZE = 64                  # width of each hidden layer in the Q-network

Create the neural network for the value function and wrap it in a value module that reads the observation and writes the action values.

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[MLP_SIZE, MLP_SIZE])
value_net = TensorDict(value_mlp, in_keys=["observation"], out_keys=["action_value"])
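
As a quick, illustrative check (names as defined above), passing a reset TensorDict through the value module writes one Q-value per action under "action_value":

# Illustrative only: the untrained Q-network scores both actions.
td = value_net(env.reset())
print(td["action_value"])   # tensor of shape [2], one value per CartPole action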

Create the greedy policy and add epsilon-greedy exploration on top of it

policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=BUFFER_LEN, eps_init=EPS_0
)
policy_explore = Seq(policy, exploration_module)
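
To see how the pieces fit together, a short rollout with the exploratory policy (a hedged sketch, not part of the recipe) produces one-hot actions: QValueModule converts "action_value" into a greedy "action", and EGreedyModule replaces it with a random action with probability eps.

# Illustrative only: a few exploratory steps in the environment.
with torch.no_grad():
    rollout = env.rollout(max_steps=10, policy=policy_explore)
print(rollout["action"].shape)   # up to 10 one-hot actions of size 2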

Declare the data collector and the replay buffer

collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=-1,
    init_random_frames=INIT_RAND_STEPS,
)
rb = ReplayBuffer(storage=LazyTensorStorage(BUFFER_LEN))
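
The buffer stores whole TensorDicts. As a minimal sketch of the mechanics (illustrative only; the training loop fills the buffer from the collector instead), transitions can be added with extend() and drawn with sample():

# Illustrative only: fill the buffer from a short random rollout and sample a mini-batch.
rb.extend(env.rollout(max_steps=20))
mini_batch = rb.sample(8)
print(mini_batch["observation"].shape)   # torch.Size([8, 4])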

Declare the loss function, the optimizer, and the target-network updater

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=ALPHA)
updater = SoftUpdate(loss, eps=TARGET_UPDATE_EPS)
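
Putting these objects together, a single DQN update looks as follows (a hedged sketch; the actual training loop comes next, and this snippet assumes the replay buffer already holds transitions):

# Illustrative only: one optimization step.
sample = rb.sample(REPLAY_BUFFER_SAMPLE)   # mini-batch of transitions
loss_vals = loss(sample)                   # TensorDict holding the TD loss under "loss"
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()
updater.step()                             # soft-update the target network toward the online network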