TorchRL - DQN
#!pip install torch
!pip install tensordict torchvision torchrl gymnasium==0.29.1 pygame av==12.0.0
Import the modules
import torch
import time
import matplotlib.pyplot as plt
from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq  # "Mod" avoids shadowing the TensorDict class
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder
torch.manual_seed(0)
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)
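As a quick, optional sanity check, you can reset the environment and inspect the TensorDict it returns; the StepCounter transform adds a "step_count" entry next to the usual "observation" and "done" entries.
print(env.reset())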
Set the hyperparameters
INIT_RAND_STEPS = 5000  # frames collected with random actions before learning starts
FRAMES_PER_BATCH = 100  # frames per collector batch
OPTIM_STEPS = 10  # optimization steps per collected batch
EPS_0 = 0.5  # initial epsilon for epsilon-greedy exploration
BUFFER_LEN = 100_000  # replay buffer capacity (also used as the epsilon annealing horizon)
ALPHA = 0.05  # learning rate
TARGET_UPDATE_EPS = 0.95  # soft-update coefficient for the target network
REPLAY_BUFFER_SAMPLE = 128  # minibatch size sampled from the replay buffer
LOG_EVERY = 1000  # logging interval, in frames
MLP_SIZE = 64  # hidden-layer width of the value MLP
Create the neural network for the value function and wrap it in a value module
value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[MLP_SIZE, MLP_SIZE])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
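To see the value module in action (again optional), pass a reset TensorDict through it; the module reads "observation" and writes one Q-value per action under "action_value".
print(value_net(env.reset())["action_value"])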
Create the policy, with an epsilon-greedy exploration module for data collection
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=BUFFER_LEN, eps_init=EPS_0
)
policy_explore = Seq(policy, exploration_module)
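A short rollout is a cheap way to confirm that the exploratory policy runs end to end and emits valid actions:
rollout = env.rollout(max_steps=3, policy=policy_explore)
print(rollout)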
Declare the data collector and the replay buffer
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=-1,  # collect indefinitely; the training loop decides when to stop
    init_random_frames=INIT_RAND_STEPS,  # act randomly for the first frames
)
rb = ReplayBuffer(storage=LazyTensorStorage(BUFFER_LEN))
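To illustrate how the collector and the buffer interact, here is a minimal sketch (a training loop would typically iterate over the collector rather than pull a single batch):
data = next(iter(collector))  # one batch of FRAMES_PER_BATCH transitions
rb.extend(data)  # append the whole batch to the buffer
sample = rb.sample(REPLAY_BUFFER_SAMPLE)  # draw a random minibatch for training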
Declare the loss function, the optimizer, and the target-network updater
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)  # delay_value=True keeps a separate target network
optim = Adam(loss.parameters(), lr=ALPHA)
updater = SoftUpdate(loss, eps=TARGET_UPDATE_EPS)  # target <- eps * target + (1 - eps) * online
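Putting the pieces together, a single optimization step would look roughly like the sketch below (a minimal illustration of how loss, optim, exploration_module, and updater cooperate, not a full training loop):
sample = rb.sample(REPLAY_BUFFER_SAMPLE)
loss_vals = loss(sample)  # DQNLoss returns a TensorDict with a "loss" entry
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()
exploration_module.step(FRAMES_PER_BATCH)  # anneal epsilon by the number of new frames seen
updater.step()  # soft-update the target network toward the online weights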