Your project is centered around developing a reinforcement learning (RL) simulation for stock trading in Python. This initiative is spearheaded by Quantum Trading, a fictional but ambitious trading firm looking to leverage cutting-edge machine learning techniques to gain a competitive edge in the financial markets. Quantum Trading is a small but highly specialised team of financial analysts, data scientists, and software engineers who are passionate about transforming the way trading decisions are made.
In the fast-paced world of financial markets, staying ahead of the curve is crucial. Traditional trading strategies, while effective, often rely on predefined rules tuned on historical data, and those rules may not adapt quickly to changing market conditions. Reinforcement learning, a subfield of machine learning in which an agent learns to make decisions by interacting with an environment, offers a promising alternative. It allows trading algorithms to learn and adapt as they interact with the market, improving their performance as they gain more experience.
By engaging with this project, you will gain valuable insights into the dynamic world of algorithmic trading and enhance your skill set in data science, finance, and machine learning. Remember, the journey of learning and experimentation is as important as the results. Good luck, and may your trading algorithms be ever profitable!
The Data
The provided data file, AAPL.csv, contains historical prices for AAPL (the ticker symbol for Apple Inc.), which you will use in your model. It has already been loaded for you in the sample code below and contains the two columns described in the table below.
Column | Description |
---|---|
Date | The date corresponding to the closing price |
Close | The closing price of the security on the given date |
Disclaimer: This project is for educational purposes only. It is not financial advice and should not be understood or construed as such.
# Make sure to run this cell to use gymnasium gym-anytrading stable-baselines3
!pip install gymnasium gym-anytrading stable-baselines3
# Import required packages
# Note that gym-anytrading is a gym environment specific for trading
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym
import gym_anytrading
from gym_anytrading.envs import Actions
from stable_baselines3 import PPO
# Load the data provided (Date becomes the index, Close is the price column)
data = pd.read_csv("AAPL.csv", index_col='Date', parse_dates=True)
# Lookback window, plus the first and one-past-last row the environment may trade over
window_size = 10
start_index, end_index = window_size, data.shape[0]
# Build the pre-registered 'stocks-v0' environment on the loaded prices.
# window_size: how many past observations the agent sees at each step.
# frame_bound: the (start, end) slice of the data the episode runs over.
env = gym.make(
    'stocks-v0',
    df=data,
    window_size=window_size,
    frame_bound=(start_index, end_index),
)
print("Observation Space:", env.observation_space)
# Initialise cash balance tracking and other variables
balance = 100000
balance_history = [balance]
shares_held = 0
action_stats = dict.fromkeys((Actions.Sell, Actions.Buy), 0)
observation, info = env.reset(seed=2024)
# Import required packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym
import gym_anytrading
from gym_anytrading.envs import Actions
from stable_baselines3 import PPO
# Function to load data and create the trading environment
def create_trading_env(file_path, window_size=10):
    """
    Read a price CSV and wrap it in a gym-anytrading 'stocks-v0' environment.

    Args:
        file_path: path to a CSV with a 'Date' column (used as the index).
        window_size: lookback window; also the first tradable row index.

    Returns:
        (env, data): the created environment and the loaded DataFrame.
    """
    prices_df = pd.read_csv(file_path, parse_dates=True, index_col='Date')
    # The episode runs from the first row with a full lookback window to the end.
    bounds = (window_size, len(prices_df))
    trading_env = gym.make(
        'stocks-v0',
        df=prices_df,
        window_size=window_size,
        frame_bound=bounds,
    )
    return trading_env, prices_df
# Function to train the PPO model
def train_model(env, timesteps=10000, seed=None):
    """
    Train the PPO model on the provided environment.

    Args:
        env: environment to train on (passed straight to PPO).
        timesteps: total number of environment steps to learn for.
        seed: optional random seed forwarded to PPO. Evaluation resets the
            env with a fixed seed, but training was previously unseeded, so
            the learned policy was not reproducible between runs. Pass an int
            to seed training; None keeps the original (unseeded) behaviour.

    Returns:
        The trained PPO model.
    """
    model = PPO('MlpPolicy', env, verbose=0, seed=seed)
    model.learn(total_timesteps=timesteps)
    return model
# Function to execute trading actions in the environment using the trained model
def execute_trades(env, model, balance=100000, trade_percent=0.10, deterministic=False):
    """
    Execute trades based on the trained model and track balance changes.

    Runs one episode of ``env`` (reset with seed 2024). At every step it asks
    ``model`` for an action, steps the env, and mirrors the action in a simple
    cash account via ``perform_trade_action``. Any shares still held when the
    episode ends are sold at the last tick's price.

    Args:
        env: gym-anytrading environment (as returned by ``gym.make``).
        model: trained policy exposing ``predict(observation, deterministic=...)``.
        balance: starting cash balance.
        trade_percent: fraction of the current cash balance spent on each Buy.
        deterministic: forwarded to ``model.predict``; False keeps the
            original sampled-action behaviour.

    Returns:
        (final_balance, balance_history, action_stats). balance_history holds
        the starting balance, the balance after every step and, if a final
        liquidation happened, the balance after it, so balance_history[-1]
        always equals final_balance. action_stats counts Buy/Sell actions.
    """
    balance_history = [balance]
    shares_held = 0
    action_stats = {Actions.Sell: 0, Actions.Buy: 0}  # Track only Buy and Sell actions
    observation, info = env.reset(seed=2024)
    step = 0
    while True:
        action, _states = model.predict(observation, deterministic=deterministic)
        # predict() returns a NumPy array; use a plain int for comparisons,
        # dict keys and the Actions enum lookup.
        action_id = int(np.asarray(action).item())
        # NOTE(review): the price is read before env.step(); confirm whether this
        # matches the tick at which the environment itself books the trade.
        current_price = env.unwrapped.prices[env.unwrapped._current_tick]
        observation, reward, terminated, truncated, info = env.step(action)
        balance, shares_held = perform_trade_action(action_id, current_price, balance, shares_held, trade_percent)
        # Update action stats and balance history
        if action_id in (Actions.Buy.value, Actions.Sell.value):
            action_stats[Actions(action_id)] += 1
        balance_history.append(balance)
        print_action_details(step, action_id, current_price, shares_held, balance)
        step += 1
        if terminated or truncated:
            break
    # Final sell if there are still shares held
    if shares_held > 0:
        # Liquidate at the latest available price (the tick the episode ended on)
        # rather than the price read before the last step.
        final_price = env.unwrapped.prices[env.unwrapped._current_tick]
        balance, shares_held = final_sell(final_price, balance, shares_held)
        # Record the liquidation so the balance chart ends at the returned balance.
        balance_history.append(balance)
    return balance, balance_history, action_stats
# Helper function to perform trade actions
def perform_trade_action(action, current_price, balance, shares_held, trade_percent):
    """
    Execute a buy or sell action and update balance and shares held.

    A Buy spends ``trade_percent`` of the current cash balance on (fractional)
    shares at ``current_price``. A Sell liquidates the whole position. Any other
    action, a Sell with no shares, or a Buy at a non-positive or NaN price
    leaves the account unchanged.

    Returns:
        (balance, shares_held) after the action.
    """
    if action == Actions.Buy.value:
        # `not > 0` also rejects NaN. Dividing by such a price would create inf/NaN
        # share counts that corrupt every later balance.
        if not current_price > 0:
            return balance, shares_held
        trade_amount = balance * trade_percent
        shares_held += trade_amount / current_price
        balance -= trade_amount
    elif action == Actions.Sell.value and shares_held > 0:
        balance += shares_held * current_price
        shares_held = 0
    return balance, shares_held
# Helper function to print details of each action
def print_action_details(step, action, current_price, shares_held, balance):
    """
    Print one line describing this step's action (BUY, SELL, or HOLD for any
    other action), the price, and the resulting cash balance. BUY and SELL
    lines also show the shares held.
    """
    if action == Actions.Buy.value:
        label = "BUY"
    elif action == Actions.Sell.value:
        label = "SELL"
    else:
        print(f"{step}: HOLD | Price: ${current_price:.2f} | Balance: ${balance:.2f}")
        return
    print(
        f"{step}: {label} | Shares Held: {shares_held:.2f} "
        f"| Price: ${current_price:.2f} | Balance: ${balance:.2f}"
    )
# Helper function for final sell action
def final_sell(current_price, balance, shares_held):
    """
    Sell every remaining share at ``current_price`` and report it.

    Returns:
        (new_balance, 0): the cash balance after the sale and the new
        (empty) share count.
    """
    proceeds = shares_held * current_price
    new_balance = balance + proceeds
    print(
        f"\nFinal SELL | {shares_held:.2f} shares at ${current_price:.2f} "
        f"| Final Balance: ${new_balance:.2f}"
    )
    return new_balance, 0
# Function to plot results (price actions and balance history)
def plot_results(env, balance_history):
    """
    Show two figures: the environment's own trade rendering, and the
    cash balance recorded at each step.
    """
    # Figure 1: price chart with the agent's positions, drawn by the env itself.
    _, actions_ax = plt.subplots()
    env.unwrapped.render_all()
    actions_ax.set_title("PPO Agent - Trading Actions")
    plt.show()

    # Figure 2: balance trajectory.
    balance_fig, balance_ax = plt.subplots()
    balance_ax.plot(balance_history)
    balance_ax.set_title("PPO Agent - Balance Over Time")
    balance_ax.set_xlabel("Steps")
    balance_ax.set_ylabel("Balance ($)")
    balance_fig.tight_layout()
    plt.show()
# Main function to run the full simulation
def main():
    """
    Run the full PPO pipeline on AAPL.csv: build the env, train, trade,
    print summary statistics, and plot the results.
    """
    env, _data = create_trading_env("AAPL.csv")
    model = train_model(env, timesteps=10000)

    results = execute_trades(env, model)
    final_balance, balance_history, action_stats = results

    print("Action Stats:", action_stats)
    print(f"Final Balance: ${final_balance:.2f}")

    plot_results(env, balance_history)
# Run the simulation
if __name__ == "__main__":
    # Run the pipeline at module level instead of calling main(). main() keeps
    # its results in local variables, so the chart cell below would otherwise
    # plot the starter-code globals: an env that was only reset and a
    # one-element balance_history. Rebinding env / balance_history /
    # action_stats here makes those charts show the simulation that ran.
    # Plotting is left to the chart cell below, which avoids drawing everything twice.
    env, data = create_trading_env("AAPL.csv", window_size=window_size)
    model = train_model(env, timesteps=10000)
    balance, balance_history, action_stats = execute_trades(env, model)
    print("Action Stats:", action_stats)
    print(f"Final Balance: ${balance:.2f}")
# Create your two charts below. Note, do not change the fig and ax variable names.
# Both charts read the module-level `env` and `balance_history`; these must hold
# the results of the simulation run above, not the starter-code placeholders.
# Chart 1: A plot showing trading actions
fig, ax = plt.subplots()
# presumably draws via pyplot onto the current axes (ax) - confirm with gym-anytrading
env.unwrapped.render_all()  # Render trading actions (buy/sell)
ax.set_title("PPO Agent - Trading Actions")
plt.show()
# Sanity-check balance_history with a short summary instead of dumping every value.
print(f"Length of balance_history: {len(balance_history)}")
if len(balance_history) <= 1:
    print(
        "Warning: balance_history only holds the starting balance - "
        "the simulation results were not stored in this variable."
    )
else:
    print(
        f"Balance history: start ${balance_history[0]:.2f} | end ${balance_history[-1]:.2f} "
        f"| min ${min(balance_history):.2f} | max ${max(balance_history):.2f}"
    )
# Chart 2: A plot of the balance_history over time
fig2, ax2 = plt.subplots()
ax2.plot(range(len(balance_history)), balance_history, label="Balance Over Time", color='blue')
ax2.set_title("PPO Agent - Balance Over Time")
ax2.set_xlabel("Steps")
ax2.set_ylabel("Balance ($)")
ax2.legend()
fig2.tight_layout()
plt.show()