Source code for pytorch_seed_rl.functional.vtrace

# This file taken from
#     https://github.com/deepmind/scalable_agent/blob/
#         cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py
# and modified.

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Taken from
#   https://github.com/facebookresearch/torchbeast/blob/master/torchbeast/core/vtrace.py
# and modified (mostly documentation)
"""Functions to compute V-trace off-policy actor critic targets.

See Also
--------
`"IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
on arXiv <https://arxiv.org/abs/1802.01561>`__ by Espeholt, Soyer, Munos et al.

All exposed functions return a :py:class:`~VTraceFromLogitsReturns`.
"""

import collections

import torch
import torch.nn.functional as F


def _action_log_probs(policy_logits: torch.Tensor,
                      actions: torch.Tensor) -> torch.Tensor:
    """Select the logits by their actions using negative log likelihood loss.
    """
    return -F.nll_loss(F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1),
                       torch.flatten(actions).to(torch.long),
                       reduction="none",
                       ).view_as(actions)


VTraceFromLogitsReturns = collections.namedtuple(
    "VTraceFromLogitsReturns",
    [
        "vs",
        "pg_advantages",
        "log_rhos",
        "behavior_action_log_probs",
        "target_action_log_probs",
    ],
)


[docs]def from_logits(behavior_policy_logits: torch.Tensor, target_policy_logits: torch.Tensor, values: torch.Tensor, bootstrap_value: torch.Tensor, actions: torch.Tensor, discounts: torch.Tensor, rewards: torch.Tensor, clip_rho_threshold: float = 1.0, clip_pg_rho_threshold: float = 1.0, ) -> VTraceFromLogitsReturns: """V-trace for softmax policies. Parameters ---------- behavior_policy_logits: `torch.Tensor` The policies logits used for action sampling during interaction with the environment. target_policy_logits: `torch.Tensor` The policies logits returned by the learning model. values: `torch.Tensor` The values returned by the learning model. bootstrap_value: `torch.Tensor` The value used for bootstrapping (usually most recent value returned by learning model.) actions: `torch.Tensor` The actions used during interaction with the environment. discounts: `torch.Tensor` The discounted rewards. rewards: `torch.Tensor` The original rewards. clip_rho_threshold: `float`, Clipping value for Vtrace. See paper for details. clip_pg_rho_threshold: `float`, Clipping value for Vtrace. See paper for details. """ # prepare logits target_action_log_probs = _action_log_probs(target_policy_logits, actions) behavior_action_log_probs = _action_log_probs( behavior_policy_logits, actions) log_rhos = target_action_log_probs - behavior_action_log_probs vtrace_returns = _from_importance_weights( log_rhos=log_rhos, discounts=discounts, rewards=rewards, values=values, bootstrap_value=bootstrap_value, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, ) return VTraceFromLogitsReturns( log_rhos=log_rhos, behavior_action_log_probs=behavior_action_log_probs, target_action_log_probs=target_action_log_probs, **vtrace_returns._asdict(), )
_VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") @torch.no_grad() def _from_importance_weights(log_rhos: torch.Tensor, values: torch.Tensor, bootstrap_value: torch.Tensor, discounts: torch.Tensor, rewards: torch.Tensor, clip_rho_threshold: float = 1.0, clip_pg_rho_threshold: float = 1.0, ) -> _VTraceReturns: """V-trace from logarithmic importance weights. Parameters ---------- log_rhos: `torch.Tensor` Logarithmic importance weights calculated from behaviour and target policy. values: `torch.Tensor` The values returned by the learning model. bootstrap_value: `torch.Tensor` The value used for bootstrapping (usually most recent value returned by learning model.) discounts: `torch.Tensor` The discounted rewards. rewards: `torch.Tensor` The original rewards. clip_rho_threshold: `float`, Clipping value for Vtrace. See paper for details. clip_pg_rho_threshold: `float`, Clipping value for Vtrace. See paper for details. """ # pylint: disable=invalid-name with torch.no_grad(): rhos = torch.exp(log_rhos) if clip_rho_threshold is not None: clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) else: clipped_rhos = rhos cs = torch.clamp(rhos, max=1.0) # Append bootstrapped value to get [v1, ..., v_t+1] values_t_plus_1 = torch.cat( [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0 ) deltas = clipped_rhos * \ (rewards + discounts * values_t_plus_1 - values) acc = torch.zeros_like(bootstrap_value) result = [] for t in range(discounts.shape[0] - 1, -1, -1): acc = deltas[t] + discounts[t] * cs[t] * acc result.append(acc) result.reverse() vs_minus_v_xs = torch.stack(result) # Add V(x_s) to get v_s. vs = torch.add(vs_minus_v_xs, values) # Advantage for policy gradient. broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value vs_t_plus_1 = torch.cat( [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 ) if clip_pg_rho_threshold is not None: clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) else: clipped_pg_rhos = rhos pg_advantages = clipped_pg_rhos * \ (rewards + discounts * vs_t_plus_1 - values) # Make sure no gradients backpropagated through the returned values. return _VTraceReturns(vs=vs, pg_advantages=pg_advantages)