TorchRL - PPO
!pip install tensordict torchrl gymnasium==0.29.1 pygame
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchrl.envs import Compose, ObservationNorm, DoubleToFloat, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor, OneHotCategorical, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss
from tensordict.nn import TensorDictModule
device = "cpu"
base_env = GymEnv('CartPole-v1', device=device)

# Hyperparameters
FRAMES_PER_BATCH = 1024   # frames collected per batch
TOTAL_FRAMES = 1048576    # total frames to collect (2**20)
GAMMA = 0.99              # discount factor
LAMBDA = 0.95             # GAE smoothing parameter
CLIP_EPSILON = 0.1        # PPO clipping threshold
ALPHA = 5e-5              # learning rate
ENTROPY_EPS = 5e-4        # entropy bonus coefficient
SUB_BATCH_SIZE = 64       # minibatch size per optimization step
OPTIM_STEPS = 8           # optimization epochs per collected batch
LOG_EVERY = 16            # run an evaluation rollout every LOG_EVERY batches
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),  # normalize observations
        DoubleToFloat(),                           # cast double-precision entries to float32
        StepCounter()                              # add a per-episode step counter
    )
)
# Estimate the normalization statistics (loc/scale) from 1024 collected frames
env.transform[0].init_stats(1024)
torch.manual_seed(0)
env.set_seed(0)
check_env_specs(env)
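As a quick sanity check (not part of the original cells), a short random rollout shows the entries the transformed environment produces:
# Optional: inspect the structure of a 3-step random rollout.
print(env.rollout(3))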
actor_net = nn.Sequential(
    nn.Linear(env.observation_spec["observation"].shape[-1], 16, device=device),
    nn.ReLU(),
    nn.Linear(16, 16, device=device),
    nn.ReLU(),
    # No activation on the output layer: the logits must be free to take any sign.
    nn.Linear(16, env.action_spec.shape[-1], device=device)
)
actor_module = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"])
actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True
)
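A one-off call of the untrained actor on a reset state (an optional check, not original code) confirms it writes the logits, the sampled action and its log-probability into the tensordict:
# Optional: run the policy once to check its output keys.
print(actor(env.reset()))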
value_net = nn.Sequential(
    nn.Linear(env.observation_spec["observation"].shape[-1], 16, device=device),
    nn.ReLU(),
    # No activation on the value head either, since value estimates can be negative.
    nn.Linear(16, 1, device=device)
)
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"]
)
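Similarly, the critic should add a state_value entry (optional check, not original code):
# Optional: run the critic once to check that it writes "state_value".
print(value_module(env.reset()))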
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
    split_trajs=True,
    reset_at_each_iter=True,
    device=device
)
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=FRAMES_PER_BATCH),
    sampler=SamplerWithoutReplacement()
)
advantage_module = GAE(
    gamma=GAMMA,
    lmbda=LAMBDA,  # GAE smoothing parameter, distinct from the discount GAMMA
    value_network=value_module,
    average_gae=True
)
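For reference: GAE builds the advantage from the TD residuals δ_t = r_t + γ·V(s_{t+1}) − V(s_t) through the recursion A_t = δ_t + γλ·A_{t+1}, which is why the module takes both the discount GAMMA and the smoothing parameter LAMBDA; with average_gae=True the computed advantages are also standardized over the batch.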
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=value_module,
    clip_epsilon=CLIP_EPSILON,
    entropy_bonus=bool(ENTROPY_EPS),
    entropy_coef=ENTROPY_EPS
)
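ClipPPOLoss implements the clipped surrogate objective E[min(r_t·A_t, clip(r_t, 1 − ε, 1 + ε)·A_t)], where r_t is the probability ratio between the current policy and the one that collected the data and ε = CLIP_EPSILON; entropy_bonus adds an entropy term weighted by ENTROPY_EPS so the policy does not collapse to a deterministic one too early.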
optim = torch.optim.Adam(loss_module.parameters(), lr=ALPHA)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, TOTAL_FRAMES // FRAMES_PER_BATCH)
rewards = []
for i, tensordict_data in enumerate(collector):
    for _ in range(OPTIM_STEPS):
        # Recompute advantages with the current value network, then refill the
        # buffer with the flattened batch for minibatch sampling.
        advantage_module(tensordict_data)
        replay_buffer.extend(tensordict_data.reshape(-1).cpu())
        for _ in range(FRAMES_PER_BATCH // SUB_BATCH_SIZE):
            data = replay_buffer.sample(SUB_BATCH_SIZE)
            loss = loss_module(data.to(device))
            loss_value = loss["loss_objective"] + loss["loss_critic"] + loss["loss_entropy"]
            loss_value.backward()
            optim.step()
            optim.zero_grad()
    # One scheduler step per collected batch, matching the T_max passed to
    # CosineAnnealingLR above.
    scheduler.step()
    if i % LOG_EVERY == 0:
        # Evaluate the current policy: roll it out for up to FRAMES_PER_BATCH
        # steps (the rollout ends early when the episode terminates) and log
        # the cumulative reward.
        with torch.no_grad():
            rollout = env.rollout(FRAMES_PER_BATCH, actor)
            reward_eval = rollout["next", "reward"].sum().item()
            print(f"batch {i}: evaluation reward = {reward_eval:.1f}")
            rewards.append(reward_eval)
            del rollout
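matplotlib is imported at the top but never used; the following is a minimal plotting sketch (an addition, not part of the original loop) for the evaluation returns gathered above:
# Plot the cumulative reward of the evaluation rollouts recorded every LOG_EVERY batches.
plt.plot(rewards)
plt.xlabel(f"evaluation round (every {LOG_EVERY} collected batches)")
plt.ylabel("cumulative rollout reward")
plt.title("PPO on CartPole-v1")
plt.show()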