
Your project is centered around developing a reinforcement learning (RL) simulation for stock trading in Python. This initiative is spearheaded by Quantum Trading, a fictional but ambitious trading firm looking to leverage cutting-edge machine learning techniques to gain a competitive edge in the financial markets. Quantum Trading is a small but highly specialised team of financial analysts, data scientists, and software engineers who are passionate about transforming the way trading decisions are made.

In the fast-paced world of financial markets, staying ahead of the curve is crucial. Traditional trading strategies often rely on predefined rules that may not adapt quickly when market conditions change. Reinforcement learning, a subfield of machine learning in which an agent learns to make decisions by interacting with an environment and receiving rewards, offers an alternative. Instead of following hand-written rules, the agent learns a policy from experience and can, in principle, keep improving as it gathers more of it. In this project, that experience comes from trading through historical AAPL prices in a simulated environment.
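To make the agent-environment idea concrete, here is a minimal pure-Python sketch of the observe, act, reward loop. The `ToyPriceEnv` class, its reward (the next day's price change if the agent holds a share, otherwise zero) and the fixed policies are hypothetical illustrations, not the gym-anytrading environment used in the sample code.

```python
class ToyPriceEnv:
    """Hypothetical toy environment: the agent sees the latest price and
    chooses 0 (stay out) or 1 (hold one share) for the next step."""

    def __init__(self, prices):
        self.prices = prices
        self.t = 0

    def reset(self):
        self.t = 0
        return self.prices[self.t]  # observation: the current price

    def step(self, action):
        # Reward is the next price change if the agent holds a share, else 0
        change = self.prices[self.t + 1] - self.prices[self.t]
        reward = change if action == 1 else 0.0
        self.t += 1
        done = self.t == len(self.prices) - 1  # no more prices to step into
        return self.prices[self.t], reward, done


def run_episode(env, policy):
    """Observe, act, collect the reward, and repeat until the episode ends."""
    obs = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        action = policy(obs)
        obs, reward, done = env.step(action)
        total_reward += reward
    return total_reward


def always_hold(obs):
    return 1


def stay_out(obs):
    return 0


prices = [100.0, 101.0, 99.5, 102.0, 103.0]  # made-up prices
print(run_episode(ToyPriceEnv(prices), always_hold))  # 3.0
print(run_episode(ToyPriceEnv(prices), stay_out))     # 0.0
```

An RL algorithm such as PPO replaces the fixed policy with a parameterised one whose parameters are updated from the rewards it collects.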

By engaging with this project, you will gain valuable insights into the dynamic world of algorithmic trading and enhance your skill set in data science, finance, and machine learning. Remember, the journey of learning and experimentation is as important as the results. Good luck, and may your trading algorithms be ever profitable!

The Data

The provided file AAPL.csv contains historical prices for AAPL (the ticker symbol for Apple Inc.), which you will use in your model. The sample code below loads it for you. It contains two columns:

| Column | Description |
|--------|-------------|
| Date | The date corresponding to the closing price |
| Close | The closing price of the security on the given date |
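The sample code reads the file with `pd.read_csv("AAPL.csv", parse_dates=True, index_col='Date')`. The sketch below runs the same call on a small in-memory CSV so you can check the resulting structure without the real file; the three rows are made up and are not taken from AAPL.csv.

```python
import io

import pandas as pd

# Made-up rows in the same two-column layout as AAPL.csv (not real prices)
csv_text = """Date,Close
2024-01-02,100.0
2024-01-03,101.5
2024-01-04,99.0
"""

# parse_dates=True parses the index column (Date) into datetimes
data = pd.read_csv(io.StringIO(csv_text), parse_dates=True, index_col="Date")

print(type(data.index).__name__)  # DatetimeIndex
print(list(data.columns))         # ['Close']
print(data["Close"].iloc[0])      # 100.0
```

Because Date becomes the index, `Close` is the only column, and `data.index[step]` gives the date for row `step`, which is how the trading loop records buy and sell dates.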

Disclaimer: This project is for educational purposes only. It is not financial advice and should not be understood or construed as such.
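The sample code asserts that each observation from `stocks-v0` has shape `(window_size, 2)`. As far as I can tell from the gym-anytrading source (`StocksEnv._process_data` and `TradingEnv._get_observation`), the two features are the closing price and its day-over-day change, and each observation is the last `window_size` rows ending at the current tick. The numpy sketch below reproduces that idea; `make_signal_features` and `observation_at` are hypothetical helpers, not part of the library, so check your installed version if the details matter.

```python
import numpy as np


def make_signal_features(prices):
    """Stack each price with its change from the previous day (0 for the first day)."""
    prices = np.asarray(prices, dtype=np.float64)
    diff = np.insert(np.diff(prices), 0, 0.0)
    return np.column_stack((prices, diff))


def observation_at(features, tick, window_size):
    """Return the window_size rows ending at (and including) row `tick`."""
    return features[tick - window_size + 1 : tick + 1]


prices = [100.0, 101.0, 99.5, 102.0, 103.0, 104.5]  # made-up prices
features = make_signal_features(prices)
obs = observation_at(features, tick=4, window_size=3)
print(obs.shape)  # (3, 2)
print(obs)
```

This is also why `start_index` is set to `window_size`: the first tick needs `window_size` earlier rows to fill its lookback window.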

# Run this cell first to install gymnasium, gym-anytrading and stable-baselines3
!pip install gymnasium gym-anytrading stable-baselines3
# Import required packages
# gym-anytrading provides Gymnasium trading environments; importing it registers 'stocks-v0'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym
import gym_anytrading
from gym_anytrading.envs import Actions
from stable_baselines3 import PPO

# Load the data provided
data = pd.read_csv("AAPL.csv", parse_dates=True, index_col='Date')

# Set window size (lookback window), start and end index of data
window_size = 10
start_index = window_size
end_index = len(data)

# Create the environment using the pre-built 'stocks-v0', passing in the data loaded above
# window_size is how many past rows the environment shows the agent at each step
# frame_bound is the (start, end) range of rows the environment steps through;
# starting at window_size ensures the first lookback window is complete
env = gym.make('stocks-v0', df=data, window_size=window_size, frame_bound=(start_index, end_index))
print("Observation Space:", env.observation_space)

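# Create the PPO model with a multilayer-perceptron (MLP) policy
model = PPO("MlpPolicy", env, verbose=1)

# Train the model for 10,000 timesteps
model.learn(total_timesteps=10000)

# Initialise cash balance tracking and other variables
balance = 100000
balance_history = [balance]
shares_held = 0
action_stats = {Actions.Sell: 0, Actions.Buy: 0}
# Reset the environment after training: learn() steps this same env internally,
# so rewind it to the first trading day before the evaluation loop
observation, info = env.reset(seed=2024)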

# Setup for plotting
buy_dates = []
sell_dates = []
buy_prices = []
sell_prices = []

# Loop through the steps to simulate trading
for step in range(start_index, end_index):
    # Ensure the observation has the expected shape: (window_size, 2)
    assert observation.shape == (window_size, 2)
    # Predict the action using the trained model (deterministic for evaluation)
    action, _states = model.predict(observation, deterministic=True)
    # predict() returns a NumPy value, but Actions is a plain Enum,
    # so convert before comparing (1 == Actions.Buy is False)
    action = Actions(int(action))

    # Get the current stock price
    current_price = data['Close'].iloc[step]

    # Determine the action and adjust the balance/shares held
    if action == Actions.Buy and balance >= current_price:  # Can afford to buy
        shares_to_buy = balance // current_price  # Buy as many whole shares as possible
        balance -= shares_to_buy * current_price  # Deduct the cost
        shares_held += shares_to_buy  # Update shares held
        action_stats[Actions.Buy] += 1
        buy_dates.append(data.index[step])
        buy_prices.append(current_price)

    elif action == Actions.Sell and shares_held > 0:  # Can sell shares
        balance += shares_held * current_price  # Add the proceeds to balance
        shares_held = 0  # No shares left
        action_stats[Actions.Sell] += 1
        sell_dates.append(data.index[step])
        sell_prices.append(current_price)

    # Otherwise (Buy without enough cash, or Sell with no shares): no change

    # Record the portfolio value (cash + market value of shares) at each step
    balance_history.append(balance + shares_held * current_price)

    # Move to the next step in the environment (it expects the integer action value)
    observation, reward, terminated, truncated, info = env.step(action.value)
    # Stop when the episode is terminated or truncated (end of data reached)
    if terminated or truncated:
        break
    
# Sell any remaining shares at the last observed price
if shares_held > 0:
    balance += shares_held * current_price
    shares_held = 0
# Chart 1, a plot showing trading actions
fig, ax = plt.subplots()
ax.plot(data.index[start_index:end_index], data['Close'].iloc[start_index:end_index], label="Stock Price", color='blue')
ax.scatter(buy_dates, buy_prices, marker='^', color='green', label='Buy', alpha=1)
ax.scatter(sell_dates, sell_prices, marker='v', color='red', label='Sell', alpha=1)
ax.set_xlabel("Date")
ax.set_ylabel("Stock Price")
ax.set_title("Stock Price with Buy/Sell Actions")
ax.legend()
plt.xticks(rotation=45)
plt.show()

# Chart 2, a plot of the portfolio value (cash + shares) over time
fig2, ax2 = plt.subplots()
# balance_history[0] is the starting cash (plotted on the day before start_index);
# entry k is the value after trading on day start_index + k - 1
value_dates = data.index[start_index - 1:start_index - 1 + len(balance_history)]
ax2.plot(value_dates, balance_history, label="Portfolio value", color='orange')
ax2.set_xlabel("Date")
ax2.set_ylabel("Portfolio value")
ax2.set_title("Portfolio Value Over Time")
ax2.legend()
plt.xticks(rotation=45)
plt.show()

# Print out action stats to understand how many buys and sells occurred
print("Buy actions: ", action_stats[Actions.Buy])
print("Sell actions: ", action_stats[Actions.Sell])
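The trading loop goes all in on Buy (as many whole shares as the cash allows) and all out on Sell, and records cash plus the market value of any shares as the portfolio value. The pure-Python sketch below replays that bookkeeping on made-up prices and a fixed action sequence so the arithmetic can be checked by hand. The local `Actions` enum is a stand-in that mirrors gym-anytrading's `Actions` (`Sell = 0`, `Buy = 1`); because it is a plain `Enum`, an integer action has to be converted with `Actions(int(a))` before comparing, since `1 == Actions.Buy` is `False`. The `simulate` function is a hypothetical helper, not part of any library.

```python
from enum import Enum


class Actions(Enum):
    # Stand-in mirroring gym_anytrading.envs.Actions (a plain Enum, not IntEnum)
    Sell = 0
    Buy = 1


def simulate(prices, actions, balance=100_000.0):
    """All-in/all-out bookkeeping; returns final cash and per-step portfolio values."""
    shares_held = 0
    values = []
    for price, raw_action in zip(prices, actions):
        action = Actions(int(raw_action))  # int -> enum member
        if action == Actions.Buy and balance >= price:
            shares = balance // price  # whole shares only
            balance -= shares * price
            shares_held += shares
        elif action == Actions.Sell and shares_held > 0:
            balance += shares_held * price
            shares_held = 0
        values.append(balance + shares_held * price)  # mark to market
    if shares_held > 0:  # liquidate at the last price
        balance += shares_held * prices[-1]
    return balance, values


print(1 == Actions.Buy)  # False: plain Enum members do not equal ints
final_cash, values = simulate([100.0, 110.0, 90.0], [1, 0, 1], balance=1_000.0)
print(final_cash)  # 1100.0
print(values)      # [1000.0, 1100.0, 1100.0]
```

On the last day, 12 shares at 90 cost 1,080, leaving 20 in cash, so the portfolio is still worth 1,100; the leftover cash is why all-in buying rarely spends the full balance.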