
Optuna for Deep Reinforcement Learning in Python

Explore how to master hyperparameter tuning with Optuna. Learn how to define hyperparameters, set up your objective function, and utilize sampling and pruning techniques in deep reinforcement learning.
Aug 7, 2024  · 11 min read

Machine learning models have hyperparameters that significantly affect their performance. You can think of them as the right temperature to bake a cake or the right angle to flip a pancake. In reinforcement learning, hyperparameters affect the policy and reward approximations, so well-executed hyperparameter optimization can drastically improve the results.

This article will explore Optuna, a powerful open-source framework for optimizing hyperparameters. If you are new to the topic, consider starting with Understanding Machine Learning, which offers a no-code introduction. To practice while using Python, try our Machine Learning Fundamentals with Python course, which uses scikit-learn. Then, MLOps Concepts will help you understand how to transition machine learning models from local notebooks to production environments. Finally, try Preprocessing for Machine Learning in Python, a highly relevant course for this article because it teaches you to clean and prepare your data for machine learning models.

What is Optuna?

Optuna is an open-source hyperparameter optimization framework that automates hyperparameter search. It can be used with any machine learning or deep learning framework. Some of the key capabilities of Optuna include:

  • Easy to Parallelize and Scale Across Tasks: As computational demands increase, Optuna gives you the flexibility to leverage distributed training and to connect to backends like Kubernetes or Dask.
  • Automated Search for Optimal Hyperparameters: Optuna searches for the best combination of hyperparameters once you specify a range for each one.
  • State-of-the-Art Algorithms for Sampling and Pruning: Optuna ships with state-of-the-art algorithms, like MedianPruner and WilcoxonPruner, which prune trials using statistical rules.
  • Ease of Use: Optuna is easy to use and requires only 5-6 lines of code to implement, as the sketch right after this list shows.
  • Optuna Dashboard: Optuna lets you store and analyze the results of all your experiments with the Optuna Dashboard.
  • Optuna Hub: Allows you to share and use features made by contributors.
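
To make that last point concrete, here is a minimal, self-contained sketch that tunes a toy quadratic function; the function and the parameter name x are illustrative and not part of the RL example developed later in this article.

import optuna

def simple_objective(trial):
    # Suggest a value for x in [-10, 10]; Optuna searches for the value that minimizes the return value
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2

study = optuna.create_study(direction="minimize")
study.optimize(simple_objective, n_trials=20)
print(study.best_params)  # Expected to land close to {"x": 2}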

What is Reinforcement Learning (RL)?

Illustration of reinforcement learning: an agent finding the optimal path to its goal. Source: Image by Author

Reinforcement learning (RL) combines ideas from machine learning and optimal control. It concerns how an intelligent agent should act in a dynamic environment to maximize its cumulative reward. Put simply, reinforcement learning is about teaching an intelligent agent to make the decisions that maximize its rewards in an environment. Each step it takes on the way to the desired destination can yield a positive or negative reward.

Unlike supervised learning, reinforcement learning algorithms learn which decisions achieve the best results through trial and error, much like the process we use to reach our own goals. Deep reinforcement learning differs from classical reinforcement learning in that it uses deep neural networks to approximate complex value and policy functions, which lets it scale to large or continuous state spaces.

Using Optuna in Your Reinforcement Learning Model

In an environment, there can be many routes to the destination. Compared with supervised learning, deep reinforcement learning is far more sensitive to the choice of hyperparameters such as the learning rate, gamma, and the number of steps.

A poor choice of hyperparameters can lead to poor or unstable convergence. This challenge is compounded by the variability in performance across random seeds (used to initialize the network weights and the environment).

In this section, we will walk through setting up Optuna and related dependencies, defining the search space, and running the optimization for an agent in the MountainCarContinuous-v0 environment. MountainCarContinuous-v0 is one of the environments supported by RL Zoo, a training framework that provides scripts for training and evaluating agents and for plotting results. The environment models an agent (a car) that tries to achieve its goal (reaching the flag).

Representation of the tuned agent in the MountainCarContinuous-v0 environment: the car drives up the mountain to reach the flag. Image by Author.

How to get started with Optuna

To get started, you need to:

  1. Set up your model-dependent configuration.
  2. Define your hyperparameters and suggest a range of values for each.
  3. Define your objective function, which contains the model or function you want to optimize. You should also define a callback to report periodic evaluations.
  4. Create an Optuna study and run the optimization for a specified number of trials.

Setting up Optuna and Stable-Baselines3

After creating a new virtual environment, we install our packages. Note that Optuna supports Python 3.7 or newer. If you would like to review setting up a virtual environment, read our Virtual Environment in Python tutorial.

pip install optuna
pip install stable-baselines3
pip install sb3-contrib

Import the necessary packages

We will use the gym package to create the environment where the agent will perform actions.

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import A2C
import gym
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances
from typing import Any, Dict
import torch
import torch.nn as nn
print(optuna.__version__)

This project will train an Advantage Actor-Critic (A2C) agent on the MountainCarContinuous-v0 environment. A2C is one of the actor-critic algorithms: a hybrid architecture that combines a value-based component (the critic, which estimates the reward for an action taken) with a policy-based component (the actor, which controls the agent's actions), helping to stabilize training by reducing variance.
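
Before tuning anything, it can help to sanity-check the setup with a short, untuned run; the timestep budget below is an arbitrary placeholder, just enough to confirm that A2C and the environment work together.

# Quick baseline run with default A2C hyperparameters (not part of the tuning itself)
baseline_model = A2C("MlpPolicy", "MountainCarContinuous-v0", verbose=0)
baseline_model.learn(total_timesteps=10_000)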

Define the hyperparameters

Let us set the configurations to be used in this task.

N_TRIALS = 100  # Maximum number of trials
N_JOBS = 1 # Number of jobs to run in parallel
N_STARTUP_TRIALS = 5  # Stop random sampling after N_STARTUP_TRIALS
N_EVALUATIONS = 2  # Number of evaluations during the training
N_TIMESTEPS = 100000  # Training budget
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
N_EVAL_ENVS = 5  # Number of parallel environments used for evaluation
N_EVAL_EPISODES = 10  # Number of episodes per evaluation
TIMEOUT = int(60 * 15)  # 15 minutes

ENV_ID = "MountainCarContinuous-v0"

DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
    "env": ENV_ID,
}

We will suggest a range of values for the hyperparameters.

def a2c_hyper_params(trial: optuna.Trial) -> dict:
    """Sample A2C hyperparameters for Optuna trial."""
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2),
        "gamma": trial.suggest_float("gamma", 0.9, 0.9999),
        "n_steps": trial.suggest_int("n_steps", 5, 2048),
        "ent_coef": trial.suggest_float("ent_coef", 1e-8, 1e-2),
        "vf_coef": trial.suggest_float("vf_coef", 0.1, 1.0),
        "max_grad_norm": trial.suggest_float("max_grad_norm", 0.3, 10)
    }

Define the objective function

A callback function, TrialEvalCallback(), is defined because we want to report periodic evaluation results of the optimization task back to Optuna. Check the Google Colab notebook for the callback implementation.
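
If you are not following along in the notebook, the sketch below shows one way to implement such a callback, closely following the pattern used in the Stable-Baselines3 and Optuna examples; treat it as a reference implementation rather than the exact code from the notebook.

from stable_baselines3.common.callbacks import EvalCallback

class TrialEvalCallback(EvalCallback):
    """Evaluation callback that reports intermediate rewards to Optuna and supports pruning."""

    def __init__(self, eval_env, trial, n_eval_episodes=5, eval_freq=10000, deterministic=True, verbose=0):
        super().__init__(
            eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Run the parent evaluation, which updates self.last_mean_reward
            super()._on_step()
            self.eval_idx += 1
            # Report the intermediate result so the pruner can decide whether to stop this trial
            self.trial.report(self.last_mean_reward, self.eval_idx)
            if self.trial.should_prune():
                self.is_pruned = True
                return False  # Returning False stops training
        return True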

def objective(trial: optuna.Trial) -> float:
    """
    This will be used by Optuna to evaluate one set of hyperparameters at a time.
    Given a trial object, it will sample hyperparameters, evaluate them, and report the result.
    :param trial: Optuna trial object
    :return: Mean episodic reward after training
    """
    kwargs = DEFAULT_HYPERPARAMS.copy()

    # 1. Sample hyperparameters and update the keyword arguments
    kwargs.update(a2c_hyper_params(trial))

    # 2. Create the RL model
    model = A2C(**kwargs)

    # 3. Create envs used for evaluation using make_vec_env, ENV_ID and N_EVAL_ENVS
    eval_envs = make_vec_env(ENV_ID, n_envs=N_EVAL_ENVS)

    # 4. Create the TrialEvalCallback callback 
    eval_callback = TrialEvalCallback(
        eval_envs,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=EVAL_FREQ,
        deterministic=True,
        verbose=0,
    )
    nan_encountered = False
    try:
        # Train the model
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN
        print(e)
        nan_encountered = True
    finally:
        # Free memory
        model.env.close()
        eval_envs.close()
    # Tell the optimizer that the trial failed
    if nan_encountered:
        return float("nan")
    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()
    return eval_callback.last_mean_reward

Efficient sampling and pruning mechanisms

Pruning refers to stopping trials that are unlikely to produce an optimal result, in this case after a specific budget has been spent. In Optuna, pruning removes trials that perform poorly in the early stages of training, which cuts down the number of wasted trials and the time needed to reach the optimal region of the search space. Optuna offers several pruners, including MedianPruner, which uses the median stopping rule. You can read more about Optuna pruners in the Optuna documentation.

Implementing Optuna's MedianPruner 

MedianPruner stops a trial if its best intermediate result is worse than the median of the intermediate results reported by previous trials at the same step. The division by three below is used because we only want to start pruning after roughly one-third of the maximum budget has been used.

pruner = MedianPruner(
    n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
)

Implementing Optuna's TPESampler

Next, select the sampler; options include RandomSampler, TPESampler, CmaEsSampler, and others.

sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)

Create the study and start the optimization task

We are now ready to create an Optuna study.

study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
try:
    study.optimize(objective, n_trials=N_TRIALS, n_jobs=N_JOBS, timeout=TIMEOUT)
except KeyboardInterrupt:
    pass
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(" Params: ")
for key, value in trial.params.items():
    print(f" {key}: {value}")
print(" User attrs:")
for key, value in trial.user_attrs.items():
    print(f" {key}: {value}")

Advanced Features of Optuna

Let’s take a look at some of Optuna’s more advanced features.

Optuna's visualization tools

Optuna provides a dashboard that lets you visualize, track, and analyze previous and current runs. In the past, you might have had to connect a separate tracking tool, like MLflow, to manage the results.

Optuna Dashboard. Source: Optuna
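
To view your own runs in the dashboard, store the study in a database and point the dashboard at it; the study name and SQLite file below are illustrative, and optuna-dashboard is a separate package installed with pip.

# Persist trials to a local SQLite database so the dashboard can read them
study = optuna.create_study(
    study_name="a2c-mountaincar",          # illustrative study name
    storage="sqlite:///optuna_study.db",   # illustrative database file
    sampler=sampler,
    pruner=pruner,
    direction="maximize",
    load_if_exists=True,
)

Then launch the dashboard from the command line:

pip install optuna-dashboard
optuna-dashboard sqlite:///optuna_study.db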

Integration with distributed systems for large-scale experiments

Optuna's support for distributed optimization allows you to connect a distributed backend, like Kubernetes or Dask, and leverage more compute resources for large-scale experiments. You can also connect a storage backend to store and manage your session's results: for example, Optuna's DaskStorage or a shared database for distributed runs, or SQLite for local runs.
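
A minimal way to distribute the search is to run the same optimization script in several processes or on several machines, all pointing at one shared study; the connection string below is a hypothetical placeholder for your own database backend.

# Each worker executes this same code; trials are coordinated through the shared storage
study = optuna.create_study(
    study_name="a2c-mountaincar-distributed",              # illustrative study name
    storage="postgresql://user:password@db-host/optuna",   # hypothetical connection string
    direction="maximize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=25)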

Tips for Using Optuna in Deep Reinforcement Learning

Here are some tips to avoid common pitfalls in hyperparameter tuning:

  • Determine the Need for Hyperparameter Tuning: Consider whether hyperparameter tuning is really what you need. Sometimes you simply need to train for more timesteps.
  • Be Prepared for Inconsistent Results: Running the same experiment with the same hyperparameters multiple times in RL does not guarantee the same results.
  • Consider Search Space Size: As in a typical machine learning problem, you may never find the optimal parameters if the search space is too small; if it is too large, the search may take forever.

Optuna vs. HyperOpt and Other Tools

As a final section, let’s compare Optuna, HyperOpt, and Ray Tune in terms of ease of use, framework support, and other dimensions.

Feature | Optuna | HyperOpt | Ray Tune
Ease of use | User-friendly API | Heavier setup | Comprehensive but more complex API
Supports PyTorch, TensorFlow, and other ML frameworks | Yes | Yes | Yes
Distributed/parallel optimization | Yes | Limited support | Yes
Pruning/stopping | Yes | No, just early stopping | Yes
Visualization tools | Yes | No | Yes
Search algorithms | TPESampler, RandomSampler, CmaEsSampler, GridSampler, QMCSampler, NSGAIISampler, PartialFixedSampler, GPSampler | TPE, random search | Random search, grid search, HyperBand, BOHB, PBT, ASHA

Conclusion

In this article, we studied the concepts of hyperparameter optimization and reinforcement learning. You also learned how to set up Optuna and connect it to your favorite RL algorithm. You should get hands-on with these concepts and apply them to a project you've worked on or want to start. 

For further reading and to deepen your understanding, consider exploring Optuna's official documentation and this Reinforcement Learning Tutorial on Hyperparameter Optimization.

Also, consider additional DataCamp resources for further learning, including our What is Reinforcement Learning From AI Feedback? and What is Reinforcement Learning from Human Feedback? tutorials. If you are interviewing in the space, check out our 25 Machine Learning Projects for All Levels and Top 25 Machine Learning Interview Questions blog posts.

Author: Bunmi Akinremi
Machine Learning Engineer and Poet

Frequently Asked Questions

What is Optuna and how does it work?

Optuna is an open-source hyperparameter optimization framework designed to automate the tuning process for machine learning models. It requires you to define the hyperparameters you want to tune, the objective function, and run the study.

Why should I use Optuna for Deep Reinforcement Learning?

Optuna handles complex and high-dimensional search spaces, prunes unpromising trials, and easily integrates with popular deep learning frameworks.

How can I integrate Optuna into my existing reinforcement learning projects?

To integrate Optuna with deep reinforcement learning, you should select the hyperparameters and provide a search space for each, define the objective function that evaluates your model's performance, and create a study that runs the function for n trials. This article provides a step-by-step guide on how to achieve this integration.

Can Optuna be used with popular Deep Reinforcement Learning frameworks like TensorFlow and PyTorch?

Yes, Optuna supports TensorFlow and PyTorch frameworks.

What hyperparameters in Reinforcement Learning algorithms can be optimized?

You can optimize the learning rate, number of steps, gamma, activation function, and network architecture.
