# PPO on Gymnasium CartPole-v1 with TorchRL (exported from a Jupyter/Colab notebook).
!pip install tensordict torchrl gymnasium==0.29.1 pygame 
import torch
from torch import nn

import matplotlib.pyplot as plt
from torchrl.envs import Compose, ObservationNorm, DoubleToFloat, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs 
from torchrl.modules import ProbabilisticActor, OneHotCategorical, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss

from tensordict.nn import TensorDictModule
# (cell output hidden in the notebook export)
# Compute device used for the env, the networks and the collector below.
device = "cpu"
# Raw (untransformed) CartPole environment; transforms are stacked on it later.
base_env = GymEnv("CartPole-v1", device=device)
# --- Data collection -------------------------------------------------------
FRAMES_PER_BATCH = 1024      # env steps gathered per collector iteration
TOTAL_FRAMES = 1_048_576     # == 2**20, i.e. 1024 collector iterations

# --- Return / advantage estimation ----------------------------------------
GAMMA = 0.99                 # discount factor
LAMBDA = 0.95                # intended GAE lambda (bias/variance trade-off)

# --- PPO optimisation ------------------------------------------------------
CLIP_EPSILON = 0.1           # clip range for the PPO probability ratio
ALPHA = 5e-5                 # Adam learning rate
ENTROPY_EPS = 5e-4           # entropy-bonus coefficient (0 disables the bonus)
SUB_BATCH_SIZE = 64          # minibatch size drawn from the replay buffer
OPTIM_STEPS = 8              # passes over each collected batch

# --- Logging ---------------------------------------------------------------
LOG_EVERY = 16               # evaluate every LOG_EVERY collector iterations
# Seed before building/initialising the env: init_stats below estimates the
# normalisation statistics from rollouts of this env, so those rollouts must be
# inside the seeded region for the run (and ObservationNorm's loc/scale) to be
# reproducible.
torch.manual_seed(0)

# Transform stack: normalise observations, cast double -> float, and add a
# per-episode step counter.
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)
env.set_seed(0)

# Fit ObservationNorm (transform index 0) on 1024 environment steps (num_iter).
env.transform[0].init_stats(1024)
check_env_specs(env)

# Policy network: observation -> one logit per action.
# The output layer is deliberately linear. A trailing ReLU would clamp the logits
# at 0 and pass no gradient for negative pre-activations, which can leave the
# policy stuck (e.g. at uniform probabilities once all outputs die).
actor_net = nn.Sequential(
    nn.Linear(env.observation_spec["observation"].shape[-1], 16, device=device),
    nn.ReLU(),
    nn.Linear(16, 16, device=device),
    nn.ReLU(),
    nn.Linear(16, env.action_spec.shape[-1], device=device),
)
actor_module = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"])

# Samples a one-hot action from the "logits" entry. return_log_prob=True also
# stores the sample's log-probability, which the PPO loss uses for its ratio.
actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)
# Critic: observation -> scalar state-value estimate.
# The output layer is linear. A final ReLU passes zero gradient whenever its
# input is negative, so a critic whose output collapsed to 0 could stop learning.
value_net = nn.Sequential(
    nn.Linear(env.observation_spec["observation"].shape[-1], 16, device=device),
    nn.ReLU(),
    nn.Linear(16, 1, device=device),
)

# Wraps value_net so it reads "observation" from a tensordict and writes its
# output under ValueOperator's default value key (presumably "state_value" —
# confirm against the installed torchrl), which GAE and ClipPPOLoss consume.
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

# Runs `actor` in `env` and yields batches of FRAMES_PER_BATCH steps until
# TOTAL_FRAMES frames have been collected.
# split_trajs=False keeps each batch flat. With split_trajs=True the batch is
# regrouped per trajectory and zero-padded to the longest one. The training
# loop's reshape(-1) would then push padding entries into GAE and the replay
# buffer, and the padded size can exceed the buffer's FRAMES_PER_BATCH capacity.
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
    split_trajs=False,
    reset_at_each_iter=True,
    device=device,
)

# Holds up to FRAMES_PER_BATCH transitions and serves minibatches without
# replacement, so one sweep of sample() calls visits each stored item once.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=FRAMES_PER_BATCH),
    sampler=SamplerWithoutReplacement(),
)
# Generalised Advantage Estimation with value_module as the critic. It is called
# on each collected batch in the training loop to fill in advantage and
# value-target entries.
advantage_module = GAE(
    gamma=GAMMA,
    lmbda=LAMBDA,  # GAE lambda, distinct from the discount factor gamma
    value_network=value_module,
    average_gae=True,  # standardise advantages within each batch
)
# Clipped-surrogate PPO loss over the actor and critic. When ENTROPY_EPS is
# non-zero it also returns an entropy-bonus term weighted by ENTROPY_EPS.
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=value_module,
    clip_epsilon=CLIP_EPSILON,
    entropy_bonus=bool(ENTROPY_EPS),
    entropy_coef=ENTROPY_EPS,
)
# A single Adam optimiser over all loss-module parameters (actor and critic).
optim = torch.optim.Adam(loss_module.parameters(), lr=ALPHA)
# Cosine-anneal the learning rate across the run. The training loop steps the
# scheduler once per collector iteration, i.e. TOTAL_FRAMES // FRAMES_PER_BATCH times.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim,
    T_max=TOTAL_FRAMES // FRAMES_PER_BATCH,
)
from torchrl.envs.utils import ExplorationType, set_exploration_type

# PPO training loop: for each collected batch, run OPTIM_STEPS passes of
# minibatch updates, then step the LR schedule; evaluate every LOG_EVERY batches.
rewards = []  # evaluation returns as Python floats (e.g. for plotting)
for i, tensordict_data in enumerate(collector):
    for _ in range(OPTIM_STEPS):
        # Advantages depend on the critic, which changes every pass, so they are
        # recomputed before each pass and the buffer is refilled with them.
        advantage_module(tensordict_data)
        replay_buffer.extend(tensordict_data.reshape(-1).cpu())
        for _ in range(FRAMES_PER_BATCH // SUB_BATCH_SIZE):
            data = replay_buffer.sample(SUB_BATCH_SIZE)
            loss = loss_module(data.to(device))
            loss_value = (
                loss["loss_objective"] + loss["loss_critic"] + loss["loss_entropy"]
            )
            loss_value.backward()
            optim.step()
            optim.zero_grad()
    scheduler.step()

    if i % LOG_EVERY == 0:
        # Evaluate the deterministic (greedy) policy so the logged return is not
        # perturbed by action-sampling noise; no gradients are needed here.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            rollout = env.rollout(FRAMES_PER_BATCH, actor)
        reward_eval = rollout["next", "reward"].sum().item()
        print(f"iteration {i}: eval return {reward_eval:.1f}")
        rewards.append(reward_eval)
        del rollout