
Microsoft's rStar-Math: A Guide With Implementation

Learn how to create a simplified implementation of the rStar-Math framework using a combination of neural networks, symbolic reasoning, and Monte Carlo Tree Search (MCTS).
Jan 23, 2025  · 12 min read

Microsoft’s rStar-Math paper presents an innovative approach to solving mathematical problems using a combination of reinforcement learning, symbolic reasoning, and Monte Carlo Tree Search (MCTS).

In this blog, I'll explore the rStar-Math framework and its core components. I'll then guide you step by step through a simplified implementation that demonstrates its key concepts using Gradio. While this demo is inspired by the paper, some complexities have been simplified for accessibility.

What Is Microsoft’s rStar-Math?

rStar-Math aims to bridge symbolic reasoning with the generalization capabilities of pre-trained neural models. The framework integrates components like Monte Carlo Tree Search (MCTS), pre-trained language models, and reinforcement learning to enable efficient exploration of problem-solving strategies.

The core idea is to represent mathematical reasoning as a search process over a structured tree of possible steps, where each node represents a partial solution or state.

rStar math framework

Source: Guan et al., 2025

Some of the reasons that make rStar-Math particularly interesting for me are:

  1. A policy model: a neural network that predicts the next action in solving a mathematical problem, guiding the exploration of MCTS.
  2. A reward model: a network that evaluates the success of actions taken during MCTS rollouts and provides feedback for training.
  3. Symbolic computation: libraries like SymPy handle precise mathematical operations, such as solving equations or computing derivatives.
  4. Monte Carlo Tree Search: an embedded MCTS algorithm systematically explores possible solution paths via simulation, balancing exploration of new paths against exploitation of promising ones (see the UCT formula after this list).
  5. A feedback mechanism: the policy and reward models are iteratively trained on the outcomes of MCTS rollouts, improving decision-making over time.
  6. A hierarchical tree structure: the reasoning process is organized as a tree where nodes represent states and edges represent transitions.
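
The exploration–exploitation balance in point 4 is typically implemented with the UCT (Upper Confidence bounds applied to Trees) rule, which is also what this demo's best_child method computes later on:

$$\mathrm{UCT}(c) = \frac{Q(c)}{N(c)} + w \sqrt{\frac{\ln N(p)}{N(c)}}$$

Here, $Q(c)$ is the child's accumulated reward, $N(c)$ its visit count, $N(p)$ the parent's visit count, and $w$ an exploration weight (the demo uses $w = 1.4$).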

Demo Project Overview: Math Problem Solver with Gradio

The demo shows how a policy model and a reward model, combined with symbolic reasoning via the sympy library, can tackle mathematical problems in a structured way. The key features of this implementation include:

  1. Policy model: A neural network that predicts the next action in the problem-solving process.
  2. Reward model: A network that evaluates the success of actions taken during MCTS rollouts.
  3. Symbolic reasoning: SymPy handles precise mathematical computation, such as solving equations.
  4. Monte Carlo Tree Search: A simplified version of MCTS explores possible solutions efficiently.
  5. Reinforcement learning loop: A basic training loop improves the policy and reward models based on feedback.
  6. Support for single and multi-variable equations: Users can input one or two equations to find solutions for variables like x and y.

To keep the demo simple and focused, certain advanced features described in the paper are beyond the scope of this tutorial:

  1. Scalability: The original paper uses large pre-trained models and extensive computational resources. The demo uses smaller neural networks and avoids complex pre-training.
  2. Advanced MCTS strategies: Techniques like adaptive UCT and diverse exploration strategies are not fully implemented.
  3. Task generalization: The implementation focuses solely on solving algebraic equations, whereas RStar is designed to generalize across broader mathematical tasks.
  4. Dataset: Instead of using a curated dataset for training, the demo relies on symbolic reasoning and user-provided inputs.

Step 1: Prerequisites

The demo is broken into several components, each reflecting a part of the RStar methodology. Before we begin, ensure you have the following installed:

  • Python 3.8+
  • Required libraries: Install the necessary Python packages using pip:
pip install gradio sympy torch numpy

Then, import these libraries:

import gradio as gr
import numpy as np
import torch
import re
import torch.nn as nn
import torch.optim as optim
from sympy import symbols, Eq, solve, N, sin, cos, tan, exp, log, E, sympify

Now that all the dependencies are installed, let’s set up the main components.

Step 2: Neural Networks for Policy and Reward

These networks are lightweight versions of the models described in the paper, used to predict the next action and evaluate success. The policy model predicts the next steps to solve the given equations. It uses a feedforward neural network to process encoded representations of the problem.

Similarly, the reward model evaluates partial solutions to guide the MCTS process. Both models are implemented using PyTorch.

class PolicyModel(nn.Module):
    """Predicts the next action given an encoded problem state."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)

class RewardModel(nn.Module):
    """Scores a (partial) solution with a single output value."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)
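
As a quick sanity check, you can instantiate both networks and confirm the shapes line up with what the solver expects: a 128-dimensional encoded problem in, four action logits out of the policy model, and a single score out of the reward model.

policy = PolicyModel(input_size=128, hidden_size=64, output_size=4)
reward = RewardModel(input_size=128, hidden_size=64, output_size=1)
x = torch.randn(128)      # dummy encoded problem
print(policy(x).shape)    # torch.Size([4])
print(reward(x).shape)    # torch.Size([1])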

Next, we build a node class for MCTS trees.

Step 3: TreeNode Class for Representing MCTS States

The TreeNode class represents nodes in the MCTS tree. Each node corresponds to a state in the search process, containing:

  • The state (e.g., equations or partial solutions).
  • A reference to its parent node.
  • A list of children nodes (expanded states).
  • Visits and Q-values, which track how often the node was explored and its accumulated rewards.
class TreeNode:
    """Represents a node in the MCTS tree."""
    def __init__(self, state, parent=None):
        self.state = state  # Current state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.q_value = 0.0  # Accumulated rewards
    def is_fully_expanded(self):
        # Simplification: a node counts as expanded once it has any children
        return len(self.children) > 0
    def best_child(self, exploration_weight=1.4):
        """Select the best child using UCT formula."""
        def uct_value(child):
            return (child.q_value / (child.visits + 1e-6)) + exploration_weight * np.sqrt(np.log(self.visits + 1) / (child.visits + 1e-6))
        return max(self.children, key=uct_value)
    def add_child(self, child_state):
        """Add a child node with the given state."""
        child = TreeNode(state=child_state, parent=self)
        self.children.append(child)
        return child
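
Before moving on, here is a tiny smoke test of the node API (the state tuples and statistics are illustrative, not produced by a real search):

root = TreeNode(state=("x + y - 3", None))
child = root.add_child(("Define variables", "x = symbols('x')"))
root.visits, child.visits, child.q_value = 2, 1, 0.5
print(root.best_child().state[0])  # Define variables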

Now that we have the basic structure in place, we will work next with the core components of the demo.

Step 4: The MathSolver Class

The MathSolver class is the core of the demo, combining symbolic reasoning with neural-guided search. It implements several key components:

A. Policy and reward models

The PolicyModel predicts the next steps to solve equations, while the RewardModel evaluates the success of partial or complete solutions.

class MathSolver:
    def __init__(self, dataset=None):
        self.dataset = dataset or []  # Dataset of math problems
        self.policy_model = PolicyModel(input_size=128, hidden_size=64, output_size=4)  
        self.reward_model = RewardModel(input_size=128, hidden_size=64, output_size=1)  
        self.policy_optimizer = optim.Adam(self.policy_model.parameters(), lr=0.001)
        self.reward_optimizer = optim.Adam(self.reward_model.parameters(), lr=0.001)
        self.execution_context = {}  

The __init__ method sets up the components required for solving mathematical problems. It optionally accepts a dataset of math problems and initializes two neural networks: the policy model, which predicts the next action, and the reward model, which evaluates the success of actions. It also creates an Adam optimizer for each model and an empty execution context in which generated code will later run.

We now have a policy and reward function in place. Next, we need to parse and encode the input equations.

B. Equation parsing and encoding

The equations are parsed using sympy and encoded into feature vectors for processing by the policy and reward models.

def encode_problem(self, problem):
    # Simple feature-based encoding: variable count, operator count, and
    # problem length, zero-padded to a fixed 128-dimensional vector
    variables = len(re.findall(r'[a-zA-Z]', problem))
    operators = len(re.findall(r'[\+\-\*/\^]', problem))
    problem_length = len(problem)
    return np.array([variables, operators, problem_length] + [0] * 125)

The encode_problem method converts a math problem into a fixed-size numerical representation for the models. It extracts features like the number of variables, the number of operators, and the problem length, encoding them into a 128-dimensional NumPy array. This coarse representation summarizes the problem's surface structure so the models can process it.
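
For example, encoding a simple equation yields a 128-dimensional vector whose first three entries are the extracted features (a quick check using the MathSolver instance defined above):

solver = MathSolver()
features = solver.encode_problem("x + y - 3")
print(features.shape)  # (128,)
print(features[:3])    # [2 2 9] -> 2 variables, 2 operators, 9 characters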

C. Policy model prediction

The following code generates the next steps for solving the given equations, including defining variables, creating equations, and solving them.

def policy_model_predict(self, equation1, equation2=None):
    # Note: in this simplified demo the steps are generated symbolically with
    # SymPy rather than sampled from the neural policy network
    try:
        equations = []
        if equation1:
            equations.append(sympify(equation1.strip()))  # Sympify only equations
        if equation2:
            equations.append(sympify(equation2.strip()))
        all_variables = set()
        for eq in equations:
            all_variables.update(eq.free_symbols)
        var_definitions = [f"{v} = symbols('{v}')" for v in all_variables]
        steps = [
            ("Define variables", "\n".join(var_definitions)),
            ("Define equation(s)", f"equations = {equations}"),
            ("Solve equation(s)", f"solution = solve(equations, {list(all_variables)})"),
            ("Print solution", "print(solution)")
        ]
        return steps
    except Exception as e:
        print(f"Error during policy model prediction: {e}")
        return []

The policy_model_predict function parses the input equations using SymPy's sympify to ensure they are valid mathematical expressions. It then identifies all the variables present in the equations and solves them using SymPy's solve function. This method serves as a high-level guide for the problem-solving workflow. 
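
For example, feeding in the two equations used later in this tutorial produces the four-step plan (variable ordering inside each step may vary, since free_symbols is a set):

solver = MathSolver()
for step, code in solver.policy_model_predict("x + y - 3", "x - y - 1"):
    print(f"{step}: {code}")
# Define variables: x = symbols('x') ...
# Define equation(s): equations = [x + y - 3, x - y - 1]
# Solve equation(s): solution = solve(equations, [x, y])
# Print solution: print(solution)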

D. Reward model prediction

The reward_model_predict method plays a vital role in reinforcement learning by providing feedback for actions taken during Monte Carlo Tree Search (MCTS) rollouts. 

def reward_model_predict(self, steps, success):
    # Encode the steps, score them with the reward network,
    # and negate the score when execution failed
    encoded_steps = self.encode_problem(str(steps))
    encoded_steps = torch.tensor(encoded_steps, dtype=torch.float32)
    reward = self.reward_model(encoded_steps)
    return reward.item() if success else -reward.item()

The function encodes the problem-solving steps into a numerical format, scores them with the reward network, and returns a positive reward on success and a negated reward on failure. This feedback is what MCTS back-propagates through the tree, and it is also the signal a full training loop would use to improve the policy model. With the policy and reward prediction functions in place, we can now work on the execution task.
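
A quick illustration of the sign convention (the magnitude is arbitrary here, since the reward network is untrained; the two calls differ only in sign):

solver = MathSolver()
steps = solver.policy_model_predict("x + y - 3")
print(solver.reward_model_predict(steps, success=True))   # raw model score
print(solver.reward_model_predict(steps, success=False))  # the same score, negated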

E. Execute code function

This method handles multi-variable solutions as tuples or dictionaries and converts symbolic results to numerical approximations using SymPy's N function. 

    def execute_code(self, code):
        try:
            # Ensure necessary imports and variables are in the execution context
            exec("from sympy import symbols, Eq, solve, N, sin, cos, tan, exp, log, E", self.execution_context)
            # Dynamically initialize variables in the context
            for var_def in self.execution_context.get('var_definitions', []):
                exec(var_def, self.execution_context)
            exec(code, self.execution_context)
            if "solution" in self.execution_context:
                symbolic_solution = self.execution_context["solution"]
                # Handle multi-variable solutions as tuples
                if isinstance(symbolic_solution, list):
                    self.execution_context["solution"] = [tuple(map(N, sol)) if isinstance(sol, tuple) else N(sol) for sol in symbolic_solution]
                elif isinstance(symbolic_solution, dict):
                    self.execution_context["solution"] = {k: N(v) for k, v in symbolic_solution.items()}
                else:
                    self.execution_context["solution"] = N(symbolic_solution)
            return True
        except Exception as e:
            print(f"Error executing code: {e}")
            return False

This method ensures accurate computation and supports flexible handling of various solution formats. In case of errors, it logs the issue and returns False, maintaining robust error handling.
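
Here is a small usage sketch: the var_definitions entry seeds the execution context, and the computed result lands in execution_context["solution"]:

solver = MathSolver()
solver.execution_context['var_definitions'] = ["x = symbols('x')"]
ok = solver.execute_code("solution = solve(Eq(x + 2, 5), x)")
print(ok, solver.execution_context["solution"])  # True [3.00000000000000]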

F. Monte Carlo Tree Search (MCTS)

The MCTS method iteratively selects the best states, expands the search tree, and simulates possible solutions. Rewards from simulations are back-propagated to improve decision-making.

def mcts(self, equation1, equation2=None, num_rollouts=10):
    root = TreeNode(state=(equation1, equation2))
    for _ in range(num_rollouts):
        # Selection
        node = root
        while node.is_fully_expanded() and node.children:
            node = node.best_child()
        # Expansion: the selection loop always stops at a node with no
        # children, so the steps generated here also drive the simulation below
        if not node.is_fully_expanded():
            steps = self.policy_model_predict(*node.state)
            for step, code in steps:
                child_state = (step, code)
                node.add_child(child_state)
        # Simulation
        success = True
        for step, code in steps:
            if not self.execute_code(code):
                success = False
                break
        # Backpropagation
        reward = self.reward_model_predict(steps, success)
        while node:
            node.visits += 1
            node.q_value += reward
            node = node.parent
    return root.best_child().state if root.children else None

The mcts method iteratively performs four key steps: 

  • Selection: Navigate down the tree by repeatedly choosing the best child node via UCT.
  • Expansion: Create new child nodes from the steps proposed by policy_model_predict.
  • Simulation: Execute the proposed steps and record whether they all succeed.
  • Backpropagation: Score the rollout with reward_model_predict and propagate the reward up the tree, updating visit counts and Q-values.

After a specified number of rollouts, the method returns the state of the best child node, representing the most promising solution explored during the search.
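
You can exercise the search loop directly. With an untrained reward model, the returned step is effectively arbitrary, but this confirms the rollout machinery runs end to end:

solver = MathSolver()
best_state = solver.mcts("x + y - 3", "x - y - 1", num_rollouts=5)
print(best_state)  # e.g. ('Define variables', "x = symbols('x')\ny = symbols('y')")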

G. Solution execution

The solve method orchestrates the entire process, from parsing equations to executing and validating solutions.

def solve(self, equation1, equation2=None):
    self.execution_context = {}
    steps = self.policy_model_predict(equation1, equation2)
    variables = set()
    for eq in [equation1, equation2] if equation2 else [equation1]:
        if eq:
            variables.update(sympify(eq.strip()).free_symbols)
    self.execution_context['var_definitions'] = [f"{v} = symbols('{v}')" for v in variables]
    steps_output = ["Best solution found:"]
    for step, code in steps:
        steps_output.append(f"Step: {step}")
        steps_output.append(f"Code: {code}")
        if self.execute_code(code):
            steps_output.append("Execution successful.")
        else:
            steps_output.append("Execution failed.")
    if "solution" in self.execution_context:
        final_answer = self.execution_context["solution"]
        if isinstance(final_answer, dict):
            for var, value in final_answer.items():
                steps_output.append(f"{var} = {value}")
        elif isinstance(final_answer, list):
            for solution in final_answer:
                if isinstance(solution, tuple):
                    for var, value in zip(variables, solution):
                        steps_output.append(f"{var} = {value}")
                else:
                    steps_output.append(f"Solution: {solution}")
        else:
            steps_output.append(f"Final Answer: {final_answer}")
    else:
        steps_output.append("No final answer found.")
    return "\n".join(steps_output)

The solve method processes one or two user-provided equations by initializing an execution context and generating steps via policy_model_predict. It executes each step, logs progress, and reports success or failure. Solutions, including single and multivariable results, are formatted with variable names and values for clarity. If no solution is found, an appropriate message is displayed.
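
Putting it together, solving the demo's two-equation system prints the step log followed by the values of both variables (output abridged; variable ordering may vary):

solver = MathSolver()
print(solver.solve("x + y - 3", "x - y - 1"))
# Best solution found:
# ...
# x = 2.00000000000000
# y = 1.00000000000000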

We have all the core components in place, so we can work on the Gradio app next.

Step 5: Creating a User-Friendly Interface with Gradio

The Gradio interface allows users to input equations (one or more), solve them, and view the results interactively.
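
Note that the click handler below references a solve_math_problem function that isn't defined in the snippet; a minimal version that wires the MathSolver into the interface could look like this (the helper name is chosen to match the handler):

solver = MathSolver()

def solve_math_problem(equation1, equation2):
    # Treat an empty second textbox as "no second equation"
    return solver.solve(equation1, equation2 or None)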

with gr.Blocks() as app:
    gr.Markdown("# Math Problem Solver with Advanced Multi-Step Reasoning and Learning")
    with gr.Row():
        equation1_input = gr.Textbox(label="Enter the first equation (e.g., x + y - 3)", placeholder="x + y - 3")
        equation2_input = gr.Textbox(label="Enter the second equation (optional, e.g., x - y - 1)", placeholder="x - y - 1")
    solve_button = gr.Button("Solve")
    solution_output = gr.Textbox(label="Solution", interactive=False)
    solve_button.click(solve_math_problem, inputs=[equation1_input, equation2_input], outputs=[solution_output])
app.launch(debug=True)

The above code creates a Gradio user interface for solving mathematical equations with advanced reasoning. The interface is wrapped in a gr.Blocks container, which contains two input fields using gr.Textbox: one for the first equation (mandatory) and another for the second equation (optional).

The output is displayed in a single gr.Textbox labeled "Solution". The app.launch() command launches the Gradio app in a browser, and the debug=True flag enables detailed logs to help troubleshoot errors.

Step 6: Testing and Validating

It’s time to test our Math Problem Solver app. Here are some tests I ran:

1. Single equation, single variable: I solved for a single variable x given one equation as input.

Solving an equation with a single variable

2. Multiple equations, multiple variables: I passed in two equations in two variables to find the values of x and y.

Solving equations with two variables

Possible Extensions

This demo is a basic version of what the rStar-Math method makes possible, and there is plenty of room to extend its capabilities:

  • Add pre-trained language models to enhance the policy model's reasoning capabilities.
  • Integrate advanced MCTS strategies or try diverse exploration techniques to improve the efficiency and accuracy of the search process.
  • Extend the demo to solve polynomial equations of higher degrees and complex systems of equations.
  • Train the models on a larger dataset of equations for improved generalization.
  • Extend the demo to handle additional reasoning tasks beyond math.

You can refer to the original repository of the rStar-Math paper on GitHub.

Conclusion

This demo showcases a practical implementation of multi-step reasoning for solving mathematical equations. By combining neural networks, symbolic reasoning, and MCTS, it provides a glimpse into how advanced AI techniques can tackle structured reasoning tasks. Future enhancements could bring it closer to the full capabilities of the rStar-Math framework.


Author: Aashi Dutt

I am a Google Developers Expert in ML (Gen AI), a Kaggle 3x Expert, and a Women Techmakers Ambassador with 3+ years of experience in tech. I co-founded a health-tech startup in 2020 and am pursuing a master's in computer science at Georgia Tech, specializing in machine learning.
