
Solving the Bellman equation with a numerical solver

This is a simple example of solving the Bellman equation with a numerical solver for a very small reinforcement learning environment. For such a small environment, we can compute the value function of a given policy exactly, without much effort. So let’s do that!

Pre-requisites

I’m assuming that you know what an agent, state, action, and reward are in reinforcement learning.

The Bellman equation

The total pay-off after starting in some state \(s_0\) is a sum of discounted rewards:

\(R(s_0) + \gamma R(s_1) + \gamma^2 R(s_2) + \dots\)

where \(R(s)\) is the reward the agent receives in state \(s\) and \(\gamma \in [0,1]\) is the discount factor.
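As a quick illustration, the discounted sum can be computed directly for a finite sequence of rewards. The sketch below assumes a hypothetical helper, `discounted_return`, and a made-up three-step reward sequence; neither comes from the original post.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * rewards[t] over a finite sequence of rewards."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


# Hypothetical trajectory: two zero-reward states followed by a +1 state.
rewards = [0.0, 0.0, 1.0]
print(discounted_return(rewards, gamma=0.9))  # ≈ 0.81, i.e. 0.9 ** 2
```

The +1 reward arrives two steps after the start, so it is discounted twice.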

The optimal policy is the one that maximizes the expected value of the total pay-off.

We’ll denote the expected value of the total pay-off under policy \(\pi\) and starting in state \(s_0\) as \(v^{\pi}(s_0)\):

\(v^{\pi}(s_0) = \mathbb{E} \big( R(s_0) + \gamma R(s_1) + \gamma^2 R(s_2) + \dots \big)\)

Here, the expectation is taken over the sequence of states the agent visits when it starts in state \(s_0\) and follows policy \(\pi\).

Computing this directly looks super tedious, because we need to account for all possible future states… But a neat math trick makes it way more tractable!

Using the linearity of expectation and factoring \(\gamma\) out of every term after the first, we get:

\(v^{\pi}(s_0) = \mathbb{E} \big( R(s_0) \big) + \gamma \Big( \underbrace{ \mathbb{E} \big( R(s_1) + \gamma R(s_2) + \dots \big) }_{\text{this is $v^{\pi}(s_1)$}} \Big)\)

We can get rid of the expected value on \(R(s_0)\). The immediate reward for being in state \(s_0\) is certain. Therefore, its expected value is simply equal to \(R(s_0)\):

\(v^{\pi}(s_0) = R(s_0) + \gamma v^{\pi} (s_1)\)

To make things more general, we denote any current state as \(s\) and any next state – a state immediately following the current state – as \(s'\):

\(v^{\pi}(s) = R(s) + \gamma v^{\pi} (s')\)

This makes the equation work for any current state.
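The recursive form also means the pay-off can be computed backwards along a trajectory: each step reuses the pay-off of the step after it instead of re-summing the whole tail. A minimal sketch for a single deterministic trajectory, so the expectation plays no role; the function name `returns_by_recursion` and the reward list are hypothetical:

```python
def returns_by_recursion(rewards, gamma):
    """Return G_t for every step t, using G_t = r_t + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    next_return = 0.0  # nothing is collected after the last step
    for t in reversed(range(len(rewards))):
        next_return = rewards[t] + gamma * next_return
        returns[t] = next_return
    return returns


# Hypothetical rewards along the path (0,0) -> (0,1) -> (0,2), where only (0,2) pays +1.
print(returns_by_recursion([0.0, 0.0, 1.0], gamma=0.9))  # ≈ [0.81, 0.9, 1.0]
```

Each entry is the pay-off from that point onward, which for a deterministic path is the value of the corresponding state.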

The last thing we need to do is to elaborate on the second term in the above equation. In general, several next states may be accessible from the current state, and each of them is entered with some transition probability \(P_{s, \pi(s)}(s')\): the probability of landing in \(s'\) after taking action \(\pi(s)\) in state \(s\). Hence, we sum over all possible next states, weighting each by its transition probability:

\(v^{\pi}(s) = R(s) + \gamma \sum_{s'} P_{s, \pi(s)}(s') \, v^{\pi}(s')\)

This is known as the Bellman equation.

The summation in the above equation formally runs over all states \(s'\), but only the next states that are achievable directly from the current state \(s\) have a non-zero transition probability. In practice, it is therefore a sum over a handful of elements, typically far fewer than the total number of states in the environment.
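A single evaluation of the right-hand side of the Bellman equation can be sketched with a dictionary that lists only the reachable next states, so unreachable states never enter the sum. The helper `bellman_backup`, the state names `"s1"`/`"s2"`, and the 0.8/0.2 probabilities are all hypothetical:

```python
def bellman_backup(reward, gamma, transitions, values):
    """Right-hand side of the Bellman equation for a single state s.

    transitions maps each reachable next state s' to its transition probability;
    unreachable states are simply absent (their probability is 0).
    values must contain an entry for every state listed in transitions.
    """
    return reward + gamma * sum(p * values[s_next] for s_next, p in transitions.items())


# Hypothetical stochastic action: reaches "s1" with probability 0.8, "s2" with 0.2.
transitions = {"s1": 0.8, "s2": 0.2}
values = {"s1": 1.0, "s2": 0.5}
print(bellman_backup(0.0, 0.9, transitions, values))  # ≈ 0.81 = 0.9 * (0.8 * 1.0 + 0.2 * 0.5)
```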

\(v^{\pi}(s)\) is known as the value function. You can think of it as the value of being in state \(s\) under the policy \(\pi\). For example, states that are one hop away from a state with a positive reward have a higher value than states that are a couple of hops away, because the reward reaches them discounted by \(\gamma\) fewer times.

The environment

Consider a simple environment with only 6 states:

 _______ _______ _______
|       |       |       |
| (0,0) | (0,1) | (0,2) |
|_______|_______|_______|
|       |       |       |
| (1,0) | (1,1) | (1,2) |
|_______|_______|_______|

with a policy \(\pi\):

 _______ _______ _______
|       |       |       |
|  ->   |  ->   |  +1   |
|_______|_______|_______|
|       |       |       |
|  ^    |  ^    |  <-   |
|_______|_______|_______|

and with all transition probabilities equal to 1. The +1 tile is a terminal state at which the agent receives the +1 reward and no further transition from that state is possible. You can already see that this can’t be the optimal policy because in \(s = (1,2)\) it would be best to move up, since the terminal state is one hop away in that direction. Instead, the current policy makes the terminal state reachable within three hops from \(s = (1,2)\). But it also isn’t a terribly bad policy – the terminal state is reachable from every other state within at most three hops.
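The policy and the hop counts above can be checked with a short simulation. This is a sketch under assumptions: the action names, the `MOVES` table and the helpers `step`/`rollout` are illustrative and not part of the original post, and grid boundaries are not checked because this policy never leaves the grid.

```python
# Row/column offsets for each action; row 0 is the top row of the diagrams.
MOVES = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}
# The policy from the diagram above: one action per non-terminal tile.
POLICY = {
    (0, 0): "right", (0, 1): "right",
    (1, 0): "up", (1, 1): "up", (1, 2): "left",
}
TERMINAL = (0, 2)


def step(state, action):
    """Deterministic transition: every action succeeds with probability 1."""
    dr, dc = MOVES[action]
    return (state[0] + dr, state[1] + dc)


def rollout(start):
    """Follow the policy from start until the terminal tile is reached."""
    path = [start]
    while path[-1] != TERMINAL:
        path.append(step(path[-1], POLICY[path[-1]]))
    return path


for s in POLICY:
    print(s, "->", len(rollout(s)) - 1, "hops")
```

For example, `rollout((1, 0))` goes up to (0,0) and then right twice, while `step((1, 2), "up")` reaches the terminal tile in a single hop, which is why this policy is not optimal.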

The agent travels the environment according to the policy starting from some initial position. The agent’s goal is to reach the +1 tile, irrespective of the starting position:

 _______ _______ _______
|       |       |       |
|    ___|_______|___✨  |
|___|___|_______|_______|
|   |   |       |       |
|   🤖  |       |       |
|_______|_______|_______|

The system of Bellman equations for this environment

In order to compute the value function for this policy, we solve the following system of six linear equations:

\(\begin{cases} v^{\pi}((0,0)) = R((0, 0)) + \gamma v^{\pi}((0,1)) \\ v^{\pi}((0,1)) = R((0, 1)) + \gamma v^{\pi}((0,2)) \\ v^{\pi}((0,2)) = R((0, 2)) \\ v^{\pi}((1,0)) = R((1, 0)) + \gamma v^{\pi}((0,0)) \\ v^{\pi}((1,1)) = R((1, 1)) + \gamma v^{\pi}((0,1)) \\ v^{\pi}((1,2)) = R((1, 2)) + \gamma v^{\pi}((1,1)) \end{cases}\)

All immediate rewards in this system are zero except \(R((0,2)) = +1\).

This system can easily be solved in your head starting from \(v^{\pi}((0,2)) = 1\) and successively computing the remaining values. Assuming \(\gamma = 0.9\), this leads to the following value function for each state:

 _______ _______ _______
|       |       |       |
| 0.81  | 0.9   |  1    |
|_______|_______|_______|
|       |       |       |
| 0.729 | 0.81  | 0.729 |
|_______|_______|_______|
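Step by step, the back-substitution behind these numbers looks like this (all other rewards are zero, so each value is simply \(\gamma = 0.9\) times the value of the state the policy moves to):

\(\begin{aligned} v^{\pi}((0,2)) &= 1 \\ v^{\pi}((0,1)) &= 0.9 \cdot v^{\pi}((0,2)) = 0.9 \\ v^{\pi}((0,0)) &= 0.9 \cdot v^{\pi}((0,1)) = 0.81 \\ v^{\pi}((1,1)) &= 0.9 \cdot v^{\pi}((0,1)) = 0.81 \\ v^{\pi}((1,0)) &= 0.9 \cdot v^{\pi}((0,0)) = 0.729 \\ v^{\pi}((1,2)) &= 0.9 \cdot v^{\pi}((1,1)) = 0.729 \end{aligned}\)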

But let’s write out this system of Bellman equations in matrix form:

\(\begin{bmatrix} 1 & -\gamma & 0 & 0 & 0 & 0 \\ 0 & 1 & -\gamma & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 \\ -\gamma & 0 & 0 & 1 & 0 & 0 \\ 0 & -\gamma & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & -\gamma & 1 \end{bmatrix} \begin{bmatrix} v_{0,0} \\ v_{0,1} \\ v_{0,2} \\ v_{1,0} \\ v_{1,1} \\ v_{1,2} \end{bmatrix} = \begin{bmatrix} 0 \\ 0 \\ 1 \\ 0 \\ 0 \\ 0 \end{bmatrix}\)

and let’s solve it with numpy.linalg.solve().
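In compact form, the matrix above is the Bellman equation with the values of all six states stacked into one vector \(v^{\pi}\) and the rewards into \(R\). Writing \(P_{\pi}\) (notation introduced here for illustration) for the \(6 \times 6\) matrix whose row \(s\) holds the probability of moving from \(s\) to each next state \(s'\) under \(\pi\), with a row of zeros for the terminal state, a short rearrangement gives:

\(v^{\pi} = R + \gamma P_{\pi} v^{\pi} \quad \Longleftrightarrow \quad (I - \gamma P_{\pi})\, v^{\pi} = R\)

For this policy, each non-terminal row of \(P_{\pi}\) contains a single 1 (e.g. row \((1,2)\), column \((1,1)\)), which is exactly where the \(-\gamma\) entries of the matrix above come from.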

Python code

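import numpy as np
import matplotlib.pyplot as plt

discount = 0.9  # gamma

# Each row of A is one Bellman equation, v(s) - gamma * v(s') = R(s), with states
# in row-major order: (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
A = np.eye(6, 6)
A[0, 1] = -discount  # (0,0) moves right to (0,1)
A[1, 2] = -discount  # (0,1) moves right to (0,2)
A[3, 0] = -discount  # (1,0) moves up to (0,0)
A[4, 1] = -discount  # (1,1) moves up to (0,1)
A[5, 4] = -discount  # (1,2) moves left to (1,1)
# Row 2 is the terminal state (0,2): v = R, so it has no off-diagonal entry.

b = np.array([0, 0, 1, 0, 0, 0])  # immediate rewards R(s)

value_function = np.linalg.solve(A, b)
value_function = np.reshape(value_function, (2, 3))  # back to the 2 x 3 grid

plt.figure(figsize=(4, 3))
# origin='upper' (the default) draws row 0 at the top, matching the grid diagrams above.
plt.imshow(value_function, origin='upper', cmap='coolwarm')
plt.xticks([])
plt.yticks([])
for i in range(2):
    for j in range(3):
        # Annotate each tile with its value, rounded to 3 decimals.
        plt.text(j, i, round(value_function[i, j], 3), fontsize=20, ha="center", va="center", color="w")
plt.savefig('RL-environement-with-value-functions.png', dpi=300, bbox_inches='tight')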

This results in the following value function:

Just like we’ve computed manually before!
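Setting the entries of `A` by hand works for six states but quickly becomes error-prone. A minimal sketch of building the same system from a policy table, assuming the same row-major state order and deterministic moves; `evaluate_policy`, `POLICY`, `MOVES`, `REWARDS` and `TERMINAL` are illustrative names, not from the original code:

```python
import numpy as np

N_ROWS, N_COLS = 2, 3
MOVES = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}
POLICY = {(0, 0): "right", (0, 1): "right", (1, 0): "up", (1, 1): "up", (1, 2): "left"}
REWARDS = {(0, 2): 1.0}  # every other tile has zero reward
TERMINAL = {(0, 2)}


def to_index(state):
    """Row-major position of a (row, col) tile, matching the order used for A and b."""
    return state[0] * N_COLS + state[1]


def evaluate_policy(policy, rewards, gamma):
    n = N_ROWS * N_COLS
    P = np.zeros((n, n))  # P[s, s'] = probability of moving s -> s' under the policy
    R = np.zeros(n)
    for row in range(N_ROWS):
        for col in range(N_COLS):
            s = (row, col)
            R[to_index(s)] = rewards.get(s, 0.0)
            if s in TERMINAL:
                continue  # no transitions out of the terminal tile: its row of P stays zero
            dr, dc = MOVES[policy[s]]
            P[to_index(s), to_index((row + dr, col + dc))] = 1.0  # deterministic move
    # Bellman equation in matrix form: (I - gamma * P) v = R
    v = np.linalg.solve(np.eye(n) - gamma * P, R)
    return v.reshape(N_ROWS, N_COLS)


# The rounded values should match the hand-computed grid.
print(np.round(evaluate_policy(POLICY, REWARDS, gamma=0.9), 3))
```

Only the policy table needs to change to evaluate a different policy; for instance, `{**POLICY, (1, 2): "up"}` should raise the value of tile (1,2) from 0.729 to 0.9.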