Q-learning with Neural Networks

What is Q-learning?

Q-learning, like virtually all RL methods, is one type of algorithm used to calculate state-action values. It falls under the class of temporal difference (TD) algorithms, so called because they update an estimate using the difference between predictions made at successive time steps.

In part 2, where we used a Monte Carlo method to learn to play blackjack, we had to wait until the end of a game (episode) to update our state-action values. With TD algorithms, we make updates after every action taken. In most cases, that makes more sense. We make a prediction (based on previous experience), take an action based on that prediction, receive a reward and then update our prediction.

(Btw: Don’t confuse the “Q” in Q-learning with the $Q$ function we’ve discussed in the previous parts. The $Q$ function is always the name of the function that accepts states and actions and spits out the value of that state-action pair. RL methods involve a $Q$ function but aren’t necessarily Q-learning algorithms.)

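Here’s the tabular Q-learning update rule:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) \right]$$

Here $\alpha$ is the learning rate (step size) and $\gamma$ is the discount factor.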

So, like Monte Carlo, we could have a table that stores the Q-value for every possible state-action pair and iteratively update this table as we play games. Our policy $\pi$ would be based on choosing the action with the highest Q value for that given state.
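To make the table idea concrete, here is a minimal sketch of one tabular update, assuming the standard Q-learning rule with a learning rate `alpha` and discount factor `gamma`. The function name `q_update`, the 3-state/2-action table and all the numbers are made up for illustration.

```python
import numpy as np

def q_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.9):
    """Apply one tabular Q-learning update in place and return the new Q[s, a]."""
    td_target = r + gamma * np.max(Q[s_next])   # reward plus best value reachable from s_next
    td_error = td_target - Q[s, a]              # how far off the old estimate was
    Q[s, a] += alpha * td_error
    return Q[s, a]

# Hypothetical toy problem: 3 states, 2 actions, all values start at zero.
Q = np.zeros((3, 2))
Q[1] = [0.0, 2.0]          # pretend we already learned something about state 1

new_value = q_update(Q, s=0, a=1, r=1.0, s_next=1)

# The greedy policy picks the action with the highest Q-value in a state.
greedy_action = int(np.argmax(Q[0]))
```

Every state gets its own row in `Q`, which is exactly why this approach stops scaling once the number of distinct states gets large.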

But we’re done with tables. This is 2015, we have GPUs and stuff. Well, as I alluded to in part 2, our $Q(s,a)$ function doesn’t have to just be a lookup table. In fact, in most interesting problems, our state-action space is much too large to store in a table. Imagine a very simplified game of Pacman. If we implement it as a graphics-based game, the state would be the raw pixel data. In a tabular method, if the pixel data changes by just a single pixel, we have to store that as a completely separate entry in the table. Obviously that’s silly and wasteful. What we need is some way to generalize and pattern match between states. We need our algorithm to say “the value of this kind of state is X” rather than “the value of this exact, super specific state is X.”

That’s where neural networks come in. Or any other type of function approximator, even a simple linear model. We can use a neural network, instead of a lookup table, as our $Q(s,a)$ function. Just like before, it will accept a state and an action and spit out the value of that state-action pair.

Importantly, however, unlike a lookup table, a neural network also has a bunch of parameters associated with it. These are the weights. So our $Q$ function actually looks like this: $Q(s, a, \theta)$ where $\theta$ is a vector of parameters. And instead of iteratively updating values in a table, we will iteratively update the $\theta$ parameters of our neural network so that it learns to provide us with better estimates of state-action values.

Of course we can use gradient descent (backpropagation) to train our $Q$ neural network just like any other neural network.
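As a rough sketch of what "updating $\theta$" means, here is the simplest function approximator mentioned above, a linear model, trained with one gradient-descent step. This is not a neural network and not the code used later in this series; the feature vector, the shapes, the learning rate and the names `q_values` and `sgd_step` are all assumptions for illustration. It also assumes the model returns one value per action rather than taking the action as an input.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_actions = 4, 2

# theta holds one weight column per action, so Q(s, ., theta) = s @ theta.
theta = rng.normal(scale=0.1, size=(n_features, n_actions))

def q_values(state, theta):
    """Return the estimated Q-value of every action in `state`."""
    return state @ theta

def sgd_step(theta, state, action, target, lr=0.1):
    """One gradient-descent step on 0.5 * (target - Q(s, a, theta))**2."""
    error = target - q_values(state, theta)[action]
    new_theta = theta.copy()
    # For a linear model, the gradient of Q(s, a, theta) with respect to
    # the weight column of `action` is just the state vector.
    new_theta[:, action] += lr * error * state
    return new_theta

state = np.array([1.0, 0.0, 0.5, -0.5])   # made-up state features
target = 1.0                              # made-up target value

before = abs(target - q_values(state, theta)[0])
theta = sgd_step(theta, state, action=0, target=target)
after = abs(target - q_values(state, theta)[0])   # smaller than `before`
```

With a real neural network the gradient is computed by backpropagation instead of by hand, but the idea is the same: nudge $\theta$ so that the prediction for the action just taken moves toward the target.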

But what’s our target y vector (expected output vector)? Since the net is not a table, we don’t apply the tabular update rule directly. Instead, our target is simply $r_{t+1} + \gamma \max_{a'} Q(s', a')$ for the state-action pair that just happened. $\gamma$ is a parameter between 0 and 1 called the discount factor. Basically it determines how much each future reward is taken into consideration for updating our Q-value. If $\gamma$ is close to 0, we heavily discount future rewards and thus mostly care about immediate rewards. If it is close to 1, future rewards count almost as much as immediate ones. $s'$ refers to the new state after having taken action $a$ and $a'$ refers to the next actions possible in this new state. So $\max_{a'} Q(s', a')$ means we calculate the Q-values for every action available in the new state and take the maximum value to use in our new value update. (Note I may use $s' \text{ and } a'$ interchangeably with $s_{t+1} \text{ and } a_{t+1}$.)
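As a worked example with made-up numbers: suppose the action we just took earned a reward of $r_{t+1} = -1$, and our network currently estimates the Q-values of the four actions in the new state as $2.0$, $5.0$, $-1.0$ and $0.5$. With $\gamma = 0.9$ the target is

$$r_{t+1} + \gamma \max_{a'} Q(s', a') = -1 + 0.9 \times 5.0 = 3.5$$

whereas with $\gamma = 0.1$ the same transition gives

$$-1 + 0.1 \times 5.0 = -0.5$$

so a small discount factor lets the immediate reward dominate the target.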

One important note: our target for every state-action pair is $r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a')$ except when the state $s'$ is a terminal state. When we’ve reached a terminal state, the target is simply $r_{t+1}$, since there are no further actions or rewards to account for. A terminal state is the last state in an episode. In our case, there are 2 terminal states: the state where the player fell into the pit (and receives -10) and the state where the player has reached the goal (and receives +10). Any other state is non-terminal and the game is still in progress.
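The terminal-state rule is easy to get wrong in code, so here is a small sketch of the target calculation. The helper name `td_target`, the default `gamma` and the next-state Q-values are assumptions for illustration; only the -10 and +10 terminal rewards come from the game described above.

```python
import numpy as np

def td_target(reward, next_q_values, terminal, gamma=0.9):
    """Return the Q-learning target for a single transition."""
    if terminal:
        # The episode is over, so there is no future value to add.
        return reward
    return reward + gamma * float(np.max(next_q_values))

# Made-up Q(s', a') estimates for the four actions in the new state.
next_q = np.array([2.0, 5.0, -1.0, 0.5])

ongoing = td_target(-1.0, next_q, terminal=False)   # game still in progress
pit = td_target(-10.0, next_q, terminal=True)       # fell into the pit
goal = td_target(10.0, next_q, terminal=True)       # reached the goal
```

Note that in the two terminal cases the value of `next_q` is ignored entirely, whatever the network happens to predict for it.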

There are two keywords I need to mention as well: on-policy and off-policy methods. In on-policy methods, we learn the values of the same policy we are using to choose actions, so the updates to our values depend on that policy. In contrast, in off-policy methods the value update does not depend on the policy being followed. Q-learning is an off-policy method. It’s advantageous because with off-policy methods, we can follow one policy while learning about another. For example, with Q-learning, we could always take completely random actions and yet we would still learn about another policy, the one that takes the best action in every state. As a rule of thumb, if there’s ever a $\pi$ referenced in the value update part of the algorithm then it’s an on-policy method.
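To see the off-policy property in action, here is a small sketch in which the agent only ever takes random actions, yet the learned table still describes the greedy policy. The environment is a hypothetical five-cell corridor invented for this example (not the game from this post): the agent starts at the left end, each move costs -1, and reaching the right end gives +10 and ends the episode. The episode count, `alpha` and `gamma` are arbitrary choices.

```python
import numpy as np

N_STATES, GOAL = 5, 4      # hypothetical corridor: states 0..4, goal at the right end
LEFT, RIGHT = 0, 1

def step(state, action):
    """Move one cell and return (next_state, reward, terminal)."""
    next_state = max(state - 1, 0) if action == LEFT else state + 1
    if next_state == GOAL:
        return next_state, 10.0, True
    return next_state, -1.0, False

rng = np.random.default_rng(0)
Q = np.zeros((N_STATES, 2))
alpha, gamma = 0.5, 0.9

for _ in range(500):
    state, terminal = 0, False
    while not terminal:
        action = int(rng.integers(2))      # behaviour policy: purely random
        next_state, reward, terminal = step(state, action)
        # The target takes the max over next actions; it never looks at
        # which action the random policy will actually pick next.
        target = reward if terminal else reward + gamma * np.max(Q[next_state])
        Q[state, action] += alpha * (target - Q[state, action])
        state = next_state

# Greedy policy read off the learned table for the non-terminal states.
greedy_policy = [int(np.argmax(Q[s])) for s in range(GOAL)]
```

Even though the agent wandered at random the whole time, `greedy_policy` should select `RIGHT` in every non-terminal state, because the update only ever asks "what is the best I could do from the next state?".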