Your project is centered around developing a reinforcement learning (RL) simulation for stock trading in Python. This initiative is spearheaded by Quantum Trading, a fictional but ambitious trading firm looking to leverage cutting-edge machine learning techniques to gain a competitive edge in the financial markets. Quantum Trading is a small but highly specialised team of financial analysts, data scientists, and software engineers who are passionate about transforming the way trading decisions are made.
In the fast-paced world of financial markets, staying ahead of the curve is crucial. Traditional trading strategies, while effective, often rely on predefined rules derived from historical data, and those rules may not adapt quickly to changing market conditions. Reinforcement learning offers a promising alternative. It is a subfield of machine learning in which an agent learns to make decisions by interacting with an environment and receiving rewards for its actions. Instead of following fixed rules, a trading agent learns a policy from that feedback, and its performance can improve as it gains more experience.
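The agent-environment interaction described above can be sketched in a few lines. This is a minimal illustration under stated assumptions. `ToyPriceEnv` and `run_episode` are hypothetical names, not part of any library. The toy environment only mimics Gymnasium's conventions: `reset()` returns `(observation, info)` and `step()` returns `(observation, reward, terminated, truncated, info)`. A random policy stands in for the learning agent.

```python
import random


class ToyPriceEnv:
    """Hypothetical toy environment that follows Gymnasium's reset/step conventions.

    The observation is the current price. Action 1 means "be long for the next
    step" and action 0 means "stay out of the market".
    """

    def __init__(self, prices):
        self.prices = prices
        self.t = 0

    def reset(self, seed=None):
        # seed is accepted for API familiarity but unused in this toy
        self.t = 0
        return self.prices[self.t], {}  # (observation, info)

    def step(self, action):
        # Reward is the next price change if the agent is long, otherwise zero.
        change = self.prices[self.t + 1] - self.prices[self.t]
        reward = change if action == 1 else 0.0
        self.t += 1
        terminated = self.t == len(self.prices) - 1  # no more prices to step into
        return self.prices[self.t], reward, terminated, False, {}


def run_episode(env, policy):
    """Let `policy` interact with `env` for one episode and return the total reward."""
    observation, info = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        action = policy(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done = terminated or truncated
    return total_reward


# A random policy stands in for the learning agent here.
rng = random.Random(0)
toy_env = ToyPriceEnv([100.0, 101.0, 99.5, 102.0, 103.5])
random_return = run_episode(toy_env, lambda obs: rng.choice([0, 1]))
```

An RL algorithm such as PPO replaces the random `policy` with a parameterised one. It updates those parameters from the rewards collected in loops like this.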
By engaging with this project, you will gain valuable insights into the dynamic world of algorithmic trading and enhance your skill set in data science, finance, and machine learning. Remember, the journey of learning and experimentation is as important as the results. Good luck, and may your trading algorithms be ever profitable!
The Data
The provided file AAPL.csv contains historical prices for AAPL (the ticker symbol for Apple Inc.), which you will use in your model. It is already loaded for you in the sample code below and contains the two columns described in the table.
| Column | Description |
|---|---|
| Date | The date corresponding to the closing price |
| Close | The closing price of the security on the given date |
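The sketch below parses a tiny in-memory CSV with this layout. It shows what `pd.read_csv(..., parse_dates=True, index_col='Date')` produces. The rows use made-up placeholder prices, not real AAPL quotes. Only the column names `Date` and `Close` come from the table above. As far as I know, `Close` is also the column gym-anytrading's `stocks-v0` environment reads prices from. The name `sample` is chosen so that running this does not overwrite the notebook's `data`.

```python
import io

import pandas as pd

# Made-up placeholder rows with the same two columns as AAPL.csv (not real prices).
csv_text = """Date,Close
2020-01-02,100.00
2020-01-03,101.50
2020-01-06,99.75
"""

# Same call as in the notebook, but reading from a string instead of a file.
sample = pd.read_csv(io.StringIO(csv_text), parse_dates=True, index_col="Date")

# 'Date' becomes a DatetimeIndex, so 'Close' is the only remaining column.
print(type(sample.index).__name__)  # DatetimeIndex
print(list(sample.columns))         # ['Close']

# Cheap sanity checks that are also worth running on the real file.
assert sample.index.is_monotonic_increasing, "rows should be in date order"
assert sample["Close"].notna().all(), "some closing prices are missing"
```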
Disclaimer: This project is for educational purposes only. It is not, and should not be understood or construed as, financial advice.
# Make sure to run this cell to install gymnasium, gym-anytrading and stable-baselines3
!pip install gymnasium gym-anytrading stable-baselines3
# Import required packages
# Note that gym-anytrading provides Gymnasium environments designed for trading
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym
import gym_anytrading
from gym_anytrading.envs import Actions
from stable_baselines3 import PPO
# Load the data provided
data = pd.read_csv("AAPL.csv", parse_dates=True, index_col='Date')
# Set window size (lookback window), start and end index of data
window_size = 10
start_index = window_size
end_index = len(data)
# Create the environment using the pre-built 'stocks-v0', passing in the data loaded above
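# window_size is the look-back window: how many past rows the environment shows the agent at each step
# frame_bound is the (start, end) range of rows the environment trades over; here it spans the whole dataset, starting at window_size so the first observation has a full window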
env = gym.make('stocks-v0', df=data, window_size=window_size, frame_bound=(start_index, end_index))
print("Observation Space:", env.observation_space)
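The comments above describe `window_size` and `frame_bound` only loosely. The sketch below mirrors how I understand gym-anytrading's `StocksEnv` to use them:

- the `Close` prices are sliced to `frame_bound[0] - window_size : frame_bound[1]`;
- each feature row is `[price, change from the previous price]`;
- an observation is the `window_size` rows ending at the current tick;
- an episode runs from tick `window_size` to the last index.

All three function names are hypothetical. The details may differ between gym-anytrading versions, so check them against the installed source. If this reading is right, the printed observation space should be a `Box` of shape `(window_size, 2)`.

```python
import numpy as np


def episode_ticks(n_rows, window_size, frame_bound):
    """Hypothetical helper: (start_tick, end_tick, n_steps) for one episode."""
    start, end = frame_bound
    if start < window_size or end > n_rows:
        raise ValueError("expected window_size <= frame_bound[0] and frame_bound[1] <= n_rows")
    n_prices = end - (start - window_size)  # length of the sliced price series
    start_tick = window_size                # first tick with a full look-back window
    end_tick = n_prices - 1                 # the episode terminates when this tick is reached
    return start_tick, end_tick, end_tick - start_tick


def signal_features(prices):
    """Hypothetical helper: pair each price with its change from the previous row."""
    prices = np.asarray(prices, dtype=np.float64)
    diff = np.insert(np.diff(prices), 0, 0.0)  # first row has no previous price
    return np.column_stack((prices, diff))


def observation_at(features, tick, window_size):
    """Hypothetical helper: the window_size feature rows ending at `tick` (inclusive)."""
    return features[tick - window_size + 1 : tick + 1]


# With a hypothetical 100-row file, window_size=10 and frame_bound=(10, 100):
start_tick, end_tick, n_steps = episode_ticks(100, 10, (10, 100))  # (10, 99, 89)
first_obs = observation_at(signal_features(np.arange(100.0)), start_tick, 10)  # shape (10, 2)
```

Under these assumptions, an episode has `len(data) - window_size - 1` steps. That is one fewer than the `range(len(data) - window_size)` used by the evaluation loop, which is why that loop relies on `if done: break`.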
# Initialise cash balance tracking and other variables
balance = 100000
balance_history = [balance]
shares_held = 0
action_stats = {Actions.Sell: 0, Actions.Buy: 0}
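gym-anytrading encodes `Actions.Sell` as `0` and `Actions.Buy` as `1`. As I understand its `step()` logic, an action only counts as a trade when it flips the agent's current position: Buy while Short, or Sell while Long. Episodes start Short.

The sketch below re-implements that rule with local copies of the enums. `LocalActions`, `LocalPositions` and `apply_action` are hypothetical names, chosen so they do not shadow the imported `Actions`. This is a simplified reading of the environment, not the library's code, and it ignores rewards and fees.

```python
from enum import Enum


class LocalActions(Enum):
    # Local copy of the encoding used by gym_anytrading.envs.Actions.
    Sell = 0
    Buy = 1


class LocalPositions(Enum):
    # Local copy of gym-anytrading's position states.
    Short = 0
    Long = 1


def apply_action(position, action):
    """Return (new_position, traded): a trade happens only when the action flips the position."""
    traded = (
        (action == LocalActions.Buy.value and position == LocalPositions.Short)
        or (action == LocalActions.Sell.value and position == LocalPositions.Long)
    )
    if traded:
        position = LocalPositions.Long if position == LocalPositions.Short else LocalPositions.Short
    return position, traded


# Starting Short, the actions Buy, Buy, Sell give a flip to Long, a no-op
# (already Long), and a flip back to Short.
demo_position = LocalPositions.Short
trade_flags = []
for a in [LocalActions.Buy.value, LocalActions.Buy.value, LocalActions.Sell.value]:
    demo_position, traded = apply_action(demo_position, a)
    trade_flags.append(traded)
```

Inside the environment, repeated Buy actions while already Long therefore do nothing. The manual cash simulation later in the notebook buys again with 10% of its cash each time, so the two accounts of "trades" can diverge.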
observation, info = env.reset(seed=2024)
# Start coding here by training the PPO model
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)
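`total_timesteps` counts environment steps, not episodes. In Stable-Baselines3, PPO collects experience in rollouts of `n_steps` steps per environment (2048 by default). It keeps collecting until at least `total_timesteps` steps have been gathered, so the number actually collected is rounded up to a whole number of rollouts.

The arithmetic sketch below assumes those defaults and a single environment. Both helper names are hypothetical. The row count in the example is a placeholder, not the size of AAPL.csv.

```python
import math


def collected_timesteps(total_timesteps, n_steps=2048, n_envs=1):
    """Steps PPO actually collects: whole rollouts of n_steps per environment."""
    per_rollout = n_steps * n_envs
    n_rollouts = math.ceil(total_timesteps / per_rollout)
    return n_rollouts * per_rollout


def approx_episodes(timesteps, n_rows, window_size):
    """Rough number of passes over the data, assuming frame_bound=(window_size, n_rows)."""
    steps_per_episode = n_rows - window_size - 1
    return timesteps / steps_per_episode


steps = collected_timesteps(10_000)  # 10240: 5 rollouts of 2048 steps
passes = approx_episodes(steps, n_rows=1_000, window_size=10)  # hypothetical 1,000-row file
```

With these placeholder numbers, 10,000 requested timesteps amount to roughly ten passes over the data. Raising `total_timesteps` is therefore a natural first setting to experiment with.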
balance = 100000
shares_held = 0
balance_history = [balance]
action_history = []
price_history = []
positions = []
observation, info = env.reset()
# model.predict returns the action as a numpy array (0-d for a single observation)
# In gym-anytrading, Actions.Sell.value == 0 and Actions.Buy.value == 1
action_counts = {Actions.Sell.value: 0, Actions.Buy.value: 0}
for i in range(len(data) - window_size):
    action, _states = model.predict(observation)
    action = int(action)  # convert the scalar numpy array to a plain int
    action_counts[action] += 1
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    # Price on the row the agent observed when choosing this action
    current_price = data['Close'].iloc[i + window_size]
    price_history.append(current_price)
    # Print the first few actions
    if i < 10:
        print(f"Step {i}: Action={action} ({Actions(action).name.upper()}), Price=${current_price:.2f}, Balance=${balance:.2f}")
    # Position sizing: spend 10% of the current cash on whole shares
    shares_to_buy = int(balance * 0.1 // current_price)
    if action == Actions.Buy.value and shares_to_buy > 0:  # BUY
        cost = shares_to_buy * current_price
        balance -= cost
        shares_held += shares_to_buy
        action_history.append(1)
        positions.append(('buy', current_price, i))
        if i < 10:
            print(f"  -> BOUGHT {shares_to_buy} shares at ${current_price:.2f}")
    elif action == Actions.Sell.value and shares_held > 0:  # SELL
        proceeds = shares_held * current_price
        balance += proceeds
        action_history.append(2)
        positions.append(('sell', current_price, i))
        if i < 10:
            print(f"  -> SOLD {shares_held} shares at ${current_price:.2f}")
        shares_held = 0
    else:
        action_history.append(0)  # no trade this step
        if i < 10 and action == Actions.Buy.value:
            print(f"  -> Cannot buy: 10% of balance (${balance * 0.1:.2f}) is below the price ${current_price:.2f}")
        elif i < 10 and action == Actions.Sell.value:
            print("  -> Cannot sell: no shares held")
    balance_history.append(balance)
    if done:
        break
# Liquidate any remaining position at the last available closing price
if shares_held > 0:
    final_price = data['Close'].iloc[-1]
    balance += shares_held * final_price
    balance_history[-1] = balance
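The cash bookkeeping in the loop can be pulled out into small pure functions. That makes the rules (buy with 10% of the current cash, sell the whole position) easy to test without the environment. The sketch below is a hypothetical refactor, not part of the original notebook, and the prices in it are made up. It also covers the edge case where 10% of the cash is worth less than one share, so nothing is bought.

```python
def buy(balance, shares_held, price, fraction=0.1):
    """Spend up to `fraction` of the cash balance on whole shares at `price`."""
    shares_to_buy = int(balance * fraction // price)
    cost = shares_to_buy * price
    return balance - cost, shares_held + shares_to_buy, shares_to_buy


def sell_all(balance, shares_held, price):
    """Sell the entire position at `price`; returns (new_balance, 0)."""
    return balance + shares_held * price, 0


# Made-up prices for illustration only.
demo_balance, demo_shares = 100_000.0, 0
demo_balance, demo_shares, bought = buy(demo_balance, demo_shares, 150.0)  # 10,000 // 150 = 66 shares
demo_balance, demo_shares = sell_all(demo_balance, demo_shares, 160.0)     # 90,100 + 66 * 160 = 100,660
```

Because the functions return new values instead of mutating globals, the main loop could call them and assign the results back to `balance` and `shares_held`.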
# Show results
print(f"\n=== TRADING SUMMARY ===")
print(f"Final portfolio value: ${balance:,.2f}")
print(f"Action distribution (0=Sell, 1=Buy): {action_counts}")
print(f"Total trades executed: {len(positions)}")
print(f"Buy trades: {len([p for p in positions if p[0]=='buy'])}")
print(f"Sell trades: {len([p for p in positions if p[0]=='sell'])}")
print(f"Shares held before final liquidation: {shares_held}")
# Create your two charts below. Note: do not change the fig and ax variable names.
#CHART 1: STOCK PRICE WITH TRADING ACTIONS
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(price_history, label='Stock Price', color='black', linewidth=1)
buy_points = [pos[2] for pos in positions if pos[0] == 'buy']
buy_prices = [pos[1] for pos in positions if pos[0] == 'buy']
ax.scatter(buy_points, buy_prices, color='green', marker='^', label='Buy', s=50, alpha=0.7)
sell_points = [pos[2] for pos in positions if pos[0] == 'sell']
sell_prices = [pos[1] for pos in positions if pos[0] == 'sell']
ax.scatter(sell_points, sell_prices, color='red', marker='v', label='Sell', s=50, alpha=0.7)
ax.set_title('AAPL Stock Price with PPO Trading Actions')
ax.set_xlabel('Trading Period')
ax.set_ylabel('Price ($)')
ax.legend()
ax.grid(True, alpha=0.3)
#CHART 2: CASH BALANCE OVER TIME
fig2, ax2 = plt.subplots(figsize=(12, 6))
ax2.plot(balance_history, label='Cash Balance', color='blue', linewidth=2)
ax2.set_title('Cash Balance Over Time')
ax2.set_xlabel('Trading Period')
ax2.set_ylabel('Cash Balance ($)')
ax2.legend()
ax2.grid(True, alpha=0.3)
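The second chart plots `balance_history`, which tracks cash only. It falls when shares are bought and jumps when they are sold, so it does not reflect total value while a position is open. A mark-to-market value adds the holdings valued at each step's price.

The sketch below is a minimal version under stated assumptions. It takes hypothetical, equal-length per-step lists, and its function names are illustrative. The loop above does not record shares per step, so producing that list would need one extra list filled with `shares_held` on each iteration.

```python
def portfolio_values(cash_history, shares_history, prices):
    """Mark-to-market value per step: cash plus shares held valued at that step's price."""
    if not (len(cash_history) == len(shares_history) == len(prices)):
        raise ValueError("each history needs exactly one entry per step")
    return [cash + shares * price
            for cash, shares, price in zip(cash_history, shares_history, prices)]


def total_return(values):
    """Fractional change from the first value to the last."""
    return values[-1] / values[0] - 1.0


# Made-up three-step example: buy 10 shares at 100, then the price rises to 110.
demo_cash = [10_000.0, 9_000.0, 9_000.0]
demo_shares_held = [0, 10, 10]
demo_prices = [100.0, 100.0, 110.0]
demo_values = portfolio_values(demo_cash, demo_shares_held, demo_prices)  # [10000.0, 10000.0, 10100.0]
demo_return = total_return(demo_values)  # about 0.01, i.e. +1%
```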