
TRPO Details

The original paper: Schulman, John, et al. “Trust region policy optimization.” International conference on machine learning. PMLR, 2015.

Overview

This derivation comes from Appendix A.1 of this paper: Yang, Jiachen, et al. “Adaptive incentive design with multi-agent meta-gradient reinforcement learning.” arXiv preprint arXiv:2112.10859 (2021).

The objective is

\[J(\pi) := \mathbb{E}_{\pi} \left[ \sum\limits_{t=0}^{\infty} \gamma^t\cdot r(s_t, a_t) \right].\]

The problem is

\[\max\limits_{\pi} J(\pi).\]

The performance difference lemma (see below) shows that

\[\begin{aligned} J(\textcolor{blue}{\pi'}) =& J(\pi) + \mathbb{E}_{\pi'} \left[ \sum\limits_{t=0}^\infty \gamma^t\cdot A_{\pi} (s_t,a_t) \right] \\ =& J(\pi) + \sum\limits_{s} d_{\textcolor{blue}{\pi'}}(s) \sum\limits_{a} \textcolor{blue}{\pi'}(a\mid s) \cdot A_{\pi} (s,a), \end{aligned}\]

where $d_{\pi}$ denotes the discounted state visitation frequencies and $A_{\pi}$ is the advantage function under policy $\pi$.
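For concreteness, I read $d_{\pi}$ here as the unnormalized discounted visitation frequencies (the $\rho_\pi$ of the TRPO paper) and $A_{\pi}$ as the usual advantage:

\[d_{\pi}(s) := \sum\limits_{t=0}^{\infty} \gamma^t \cdot {\Pr}^{\pi}(s_t = s), \qquad A_{\pi}(s,a) := Q_{\pi}(s,a) - V_{\pi}(s),\]

which makes the second line above hold without any $\frac{1}{1-\gamma}$ factor.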

TRPO makes a local approximation, whereby $d_{\textcolor{blue}{\pi'}}$ is replaced by $d_{\pi}$.

One can define

\[L_\pi(\textcolor{blue}{\pi'}) := J(\pi)+\sum_s d_\pi(s) \sum_a \textcolor{blue}{\pi'}(a \mid s) \cdot A_\pi(s, a)\]

and derive the lower bound $J(\textcolor{blue}{\pi'}) \geq L_\pi(\textcolor{blue}{\pi'})-c \cdot D_{\mathrm{KL}}^{\max }(\pi, \textcolor{blue}{\pi'})$, where $D_{\mathrm{KL}}^{\max }$ is the KL divergence maximized over states and $c$ depends on $\pi$. The KL divergence penalty can be replaced by a constraint, so the problem becomes
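If I recall Theorem 1 of the TRPO paper correctly, the constant can be taken as

\[c = \frac{4 \varepsilon \gamma}{(1-\gamma)^2}, \qquad \varepsilon = \max_{s, a}\left|A_{\pi}(s, a)\right|.\]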

\[\begin{aligned} & \max _{\textcolor{blue}{\theta'}} \sum_s d_\theta(s) \sum_a \textcolor{blue}{\pi'}_{\textcolor{blue}{\theta'}}(a \mid s) \cdot A_\theta(s, a) \\ & \text { s.t. } \bar{D}_{\mathrm{KL}}^\theta(\theta, \textcolor{blue}{\theta'}) \leq \delta, \end{aligned}\]

where $\bar{D}_{\mathrm{KL}}^\theta$ is the KL divergence averaged over states $s \sim d_\theta$. Using importance sampling, the summation over actions $\sum_a(\cdot)$ is replaced by $\mathbb{E}_{a \sim q}\left[\frac{1}{q(a \mid s)}(\cdot)\right]$. It is convenient to choose $q=\pi_\theta$, which results in:

\[\begin{aligned} & \max _{\textcolor{blue}{\theta'}} \mathbb{E}_{s \sim d_\theta, a \sim \pi_\theta}\left[\frac{\textcolor{blue}{\pi'}_{\textcolor{blue}{\theta'}}(a \mid s)}{\pi_\theta(a \mid s)} A_\theta(s, a)\right] \\ & \text { s.t. } \mathbb{E}_{s \sim d_\theta}\left[D_{\mathrm{KL}}\left(\pi_\theta(\cdot \mid s), \textcolor{blue}{\pi'}_{\textcolor{blue}{\theta'}}(\cdot \mid s)\right)\right] \leq \delta . \end{aligned}\]
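As a sanity check on these two quantities, here is a minimal NumPy sketch (my own illustration, not the paper's code) of the sample-based surrogate objective and the state-averaged KL constraint for a categorical policy. All names below (`old_dist`, `new_dist`, `actions`, `advantages`, `delta`) are hypothetical stand-ins for data collected under the old policy $\pi_\theta$.

```python
import numpy as np

def surrogate_objective(new_probs, old_probs, advantages):
    """Estimate E_{s~d_theta, a~pi_theta}[ pi_theta'(a|s) / pi_theta(a|s) * A_theta(s,a) ]."""
    ratio = new_probs / old_probs
    return np.mean(ratio * advantages)

def mean_kl(old_dist, new_dist, eps=1e-8):
    """Estimate E_{s~d_theta}[ KL( pi_theta(.|s) || pi_theta'(.|s) ) ] for categorical policies."""
    return np.mean(
        np.sum(old_dist * (np.log(old_dist + eps) - np.log(new_dist + eps)), axis=-1)
    )

# Fake sampled data standing in for (s, a, A_theta(s, a)) tuples.
rng = np.random.default_rng(0)
n_samples, n_actions = 32, 4
old_dist = rng.dirichlet(np.ones(n_actions), size=n_samples)   # pi_theta(.|s_i)
new_dist = rng.dirichlet(np.ones(n_actions), size=n_samples)   # pi_theta'(.|s_i)
actions = rng.integers(0, n_actions, size=n_samples)           # sampled a_i
advantages = rng.normal(size=n_samples)                        # A_theta(s_i, a_i) estimates

old_probs = old_dist[np.arange(n_samples), actions]
new_probs = new_dist[np.arange(n_samples), actions]

delta = 0.01
print("surrogate objective:", surrogate_objective(new_probs, old_probs, advantages))
print("mean KL:", mean_kl(old_dist, new_dist), "(constraint: <=", delta, ")")
```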

During online learning, the $\textcolor{blue}{\theta'}$ being optimized coincides with the old $\theta$ at the start of each iteration, so the gradient of the objective evaluated at $\textcolor{blue}{\theta'} = \theta$ is

\[\mathbb{E}_{\pi_\theta}\left[\frac{\nabla_\theta \pi_\theta(a \mid s)}{\pi_\theta(a \mid s)} A_\theta(s, a)\right] .\]
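To spell out the step: differentiating the importance-weighted objective with respect to $\textcolor{blue}{\theta'}$ and evaluating at $\textcolor{blue}{\theta'} = \theta$ (so that $\textcolor{blue}{\pi'}_{\textcolor{blue}{\theta'}} = \pi_\theta$) gives, per sample,

\[\left.\nabla_{\textcolor{blue}{\theta'}} \left[ \frac{\textcolor{blue}{\pi'}_{\textcolor{blue}{\theta'}}(a \mid s)}{\pi_\theta(a \mid s)}\, A_\theta(s, a) \right]\right|_{\textcolor{blue}{\theta'} = \theta} = \frac{\nabla_\theta \pi_\theta(a \mid s)}{\pi_\theta(a \mid s)}\, A_\theta(s, a) = \nabla_\theta \log \pi_\theta(a \mid s) \cdot A_\theta(s, a),\]

which is the standard score-function (REINFORCE-style) policy gradient estimator.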

Performance Difference Lemma

In this section, the $\pi$ and $\pi'$ from the previous text have been swapped.

For all policies $\pi, \textcolor{blue}{\pi^\prime}$ and states $s_0$,

\[\begin{aligned} V^\pi(s_0) - V^{\textcolor{blue}{\pi^\prime}}(s_0) =& \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau \mid s_0) } \left[\sum_{t=0}^\infty \gamma^t A^{\textcolor{blue}{\pi'}}(s_t,a_t)\right] \\ =& \frac{1}{1-\gamma}\,\mathbb{E}_{s\sim d_{s_0}^\pi }\,\mathbb{E}_{a\sim \pi(\cdot\mid s) } \left[ A^{\textcolor{blue}{\pi^\prime}}(s,a)\right]. \end{aligned}\]
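Here $d^\pi_{s_0}$ denotes the discounted state visitation distribution started from $s_0$; in this section it is normalized to a probability distribution, which is why the $\frac{1}{1-\gamma}$ factor appears (in contrast to the unnormalized $d_\pi$ used earlier):

\[d^\pi_{s_0}(s) := (1-\gamma) \sum\limits_{t=0}^{\infty} \gamma^t \cdot {\Pr}^{\pi}(s_t = s \mid s_0).\]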

Kakade, Sham, and John Langford. “Approximately optimal approximate reinforcement learning.” Proceedings of the Nineteenth International Conference on Machine Learning. 2002.

Proof

The proof is given in the appendix of “On the theory of policy gradient methods: Optimality, approximation, and distribution shift” (Agarwal et al.); I have transcribed it here with additional details.

Let $\Pr^\pi(\tau \mid s_0 = s)$ denote the probability of observing a trajectory $\tau$ when starting in state $s$ and following the policy $\pi$. Using a telescoping argument, we have:

\[\begin{aligned} &V^\pi(s) - V^{\textcolor{blue}{\pi'}}(s) \\ =& \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s) } \left[\sum_{t=0}^\infty \gamma^t r(s_t,a_t)\right] - V^{\textcolor{blue}{\pi'}}(s) \\ =& \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s) } \left[\sum_{t=0}^\infty \gamma^t \left(r(s_t,a_t)+V^{\textcolor{blue}{\pi'}}(s_t)-V^{\textcolor{blue}{\pi'}}(s_t) \right)\right]-V^{\textcolor{blue}{\pi'}}(s)\\ \stackrel{(a)}{=}& \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s) } \left[\sum_{t=0}^\infty \gamma^t \left(r(s_t,a_t)+\gamma V^{\textcolor{blue}{\pi'}}(s_{t+1})-V^{\textcolor{blue}{\pi'}}(s_t)\right)\right]\\ \stackrel{(b)}{=}&\mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s) } \left[\sum_{t=0}^\infty \gamma^t \left(r(s_t,a_t)+\gamma \mathbb{E}[V^{\textcolor{blue}{\pi'}}(s_{t+1})|s_t,a_t]-V^{\textcolor{blue}{\pi'}}(s_t)\right)\right]\\ \stackrel{(c)}{=}& \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s) } \left[\sum_{t=0}^\infty \gamma^t A^{\textcolor{blue}{\pi'}}(s_t,a_t)\right] \\ =& \frac{1}{1-\gamma}\mathbb{E}_{s'\sim d^\pi_s }\,\mathbb{E}_{a\sim \pi(\cdot | s')} \left[ A^{\textcolor{blue}{\pi'}}(s',a) \right], \end{aligned}\]

where $(a)$ rearranges terms in the summation and cancels the $V^{\textcolor{blue}{\pi'}}(s_0)$ term with the $-V^{\textcolor{blue}{\pi'}}(s)$ outside the summation, $(b)$ uses the tower property of conditional expectations, $(c)$ applies the definition of the advantage function, and the final equality follows from the definition of $d^\pi_s$. $\blacksquare$

Details

$(a)$:

\[- a_0 +\sum\limits_{k=0}^{\infty} \left(a_k - b_k \right) = \sum\limits_{k=0}^{\infty} \left(a_{k+1} - b_k\right).\]
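In the proof above, this identity is applied (as I read it) with $a_t = b_t = \gamma^t V^{\textcolor{blue}{\pi'}}(s_t)$, so that $a_0 = V^{\textcolor{blue}{\pi'}}(s_0) = V^{\textcolor{blue}{\pi'}}(s)$ and

\[- V^{\textcolor{blue}{\pi'}}(s) + \sum\limits_{t=0}^{\infty} \gamma^t \left( V^{\textcolor{blue}{\pi'}}(s_t) - V^{\textcolor{blue}{\pi'}}(s_t) \right) = \sum\limits_{t=0}^{\infty} \left( \gamma^{t+1} V^{\textcolor{blue}{\pi'}}(s_{t+1}) - \gamma^t V^{\textcolor{blue}{\pi'}}(s_t) \right) = \sum\limits_{t=0}^{\infty} \gamma^t \left( \gamma V^{\textcolor{blue}{\pi'}}(s_{t+1}) - V^{\textcolor{blue}{\pi'}}(s_t) \right).\]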

$(b)$: The tower property of conditional expectations (or law of total expectation); its concrete application to this proof is spelled out after the list below. If $\mathcal{H} \subseteq \mathcal{G}$, then

\[\mathbb{E}\left[\mathbb{E}\left[X\mid \mathcal{G} \right] \mid \mathcal{H} \right] = \mathbb{E}\left[X\mid \mathcal{H} \right].\]

Correspondingly,

  • $\mathcal{G}$ is the information generated by the whole trajectory $\tau \sim {\Pr}^\pi(\tau \mid s_0=s)$,
  • $\mathcal{H}$ is the information generated by $(s_t,a_t)$.
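Concretely, my reading of how this is used in step $(b)$: for each $t$, the expectation over the whole trajectory can be taken in two stages, first conditioning on $(s_t, a_t)$,

\[\mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s)} \left[ V^{\textcolor{blue}{\pi'}}(s_{t+1}) \right] = \mathbb{E}_{\tau \sim {\Pr}^\pi(\tau|s_0=s)} \left[ \mathbb{E}\left[ V^{\textcolor{blue}{\pi'}}(s_{t+1}) \mid s_t, a_t \right] \right],\]

so $\gamma V^{\textcolor{blue}{\pi'}}(s_{t+1})$ can be replaced by $\gamma\, \mathbb{E}[V^{\textcolor{blue}{\pi'}}(s_{t+1}) \mid s_t, a_t]$ inside the outer expectation.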

$(c)$: This uses the definition of the advantage function, and step $(b)$ is necessary for it. Note that

\[Q^{\pi}(s, a) \ne r(s, a) + \gamma \cdot V^{\pi}(s').\]

But

\[Q^{\pi}(s, a) = r(s, a) + \gamma \cdot \sum\limits_{s'} P(s' \mid s,a) \cdot V^{\pi}(s').\]
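Putting this together, the quantity produced by step $(b)$ is exactly the advantage used in step $(c)$:

\[A^{\textcolor{blue}{\pi'}}(s_t, a_t) = Q^{\textcolor{blue}{\pi'}}(s_t, a_t) - V^{\textcolor{blue}{\pi'}}(s_t) = r(s_t, a_t) + \gamma\, \mathbb{E}\left[ V^{\textcolor{blue}{\pi'}}(s_{t+1}) \mid s_t, a_t \right] - V^{\textcolor{blue}{\pi'}}(s_t).\]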

Other proofs

Check other proofs here and here.

The Lower Bound

Importance Sampling
