# Source code for pytorch_seed_rl.nets.atari_net

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# taken from https://github.com/facebookresearch/torchbeast/blob/master/torchbeast/monobeast.py
#   and modified (mostly documentation)
#
# pylint: disable=empty-docstring,invalid-name
"""
"""
import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F


class AtariNet(Module):
    """Neural network architecture intended for usage with :py:mod:`gym` environments.

    This network architecture is copied from the torchbeast project,
    which mimics the neural network used in the IMPALA paper.

    See Also
    --------
    * `"IMPALA: Scalable Distributed Deep-RL with Importance Weighted
      Actor-Learner Architectures" on arXiv <https://arxiv.org/abs/1802.01561>`__
      by Espeholt, Soyer, Munos et al.
    * `Torchbeast implementation
      <https://github.com/facebookresearch/torchbeast/blob/master/torchbeast/monobeast.py>`__

    Parameters
    ----------
    observation_shape: `tuple`
        The shape of the tensors this neural network processes. The first
        entry is used as the number of input channels. If a height and a
        width follow (``(C, H, W)``), the fully connected layer is sized
        accordingly; otherwise the 84x84 default (3136 features) is used.
    num_actions: `int`
        The number of discrete actions this neural network can return.
    use_lstm: `bool`
        Set True, if an LSTM shall be included with this neural network.

    Raises
    ------
    ValueError
        If the spatial size given in ``observation_shape`` is too small
        for the convolutional stack.
    """

    # Flattened conv output for 84x84 frames (64 channels * 7 * 7).
    _DEFAULT_CONV_FEATURES = 3136

    def __init__(self,
                 observation_shape: tuple,
                 num_actions: int,
                 use_lstm: bool = False):
        super().__init__()

        # ATTRIBUTES
        self.observation_shape = observation_shape
        self.num_actions = num_actions

        # Feature extraction.
        self.conv1 = nn.Conv2d(
            in_channels=self.observation_shape[0],
            out_channels=32,
            kernel_size=8,
            stride=4,
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer. Its input size follows from the frame size;
        # for the usual 84x84 Atari frames this is 3136.
        if len(observation_shape) >= 3:
            conv_features = self._conv_feature_count(
                observation_shape[1], observation_shape[2])
        else:
            conv_features = self._DEFAULT_CONV_FEATURES
        self.fc = nn.Linear(conv_features, 512)

        # FC output size + one-hot of last action + last reward.
        core_output_size = self.fc.out_features + num_actions + 1

        self.use_lstm = use_lstm
        if use_lstm:
            self.core = nn.LSTM(core_output_size, core_output_size, 2)

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def _conv_feature_count(self, height: int, width: int) -> int:
        """Return the flattened feature count the conv stack yields for a frame."""
        for conv in (self.conv1, self.conv2, self.conv3):
            # Unpadded, undilated convolution: out = floor((in - k) / s) + 1.
            height = (height - conv.kernel_size[0]) // conv.stride[0] + 1
            width = (width - conv.kernel_size[1]) // conv.stride[1] + 1
            if height < 1 or width < 1:
                raise ValueError(
                    "observation_shape {} is too small for the convolutional "
                    "layers".format(tuple(self.observation_shape)))
        return self.conv3.out_channels * height * width

    def initial_state(self, batch_size: int):
        """Return zero :py:obj:`torch.Tensor` states with the shape of the LSTM block.

        Parameters
        ----------
        batch_size: `int`
            Size of the batch dimension of the returned states.

        Returns
        -------
        `tuple` of :py:obj:`torch.Tensor`
            Two zero tensors of shape ``(num_layers, batch_size, hidden_size)``
            (hidden and cell state). An empty tuple, if the LSTM has not been
            activated during initialization.
        """
        if not self.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs: dict, core_state: tuple = ()):
        """Forward step of the neural network.

        Parameters
        ----------
        inputs: `dict` of :py:obj:`torch.Tensor`
            Expects a dictionary as returned by a step of
            :py:class:`~pytorch_seed_rl.environments.atari_wrappers.DictObservationsEnv`.
            The keys ``"frame"`` (``[T, B, C, H, W]``), ``"last_action"``,
            ``"reward"`` and ``"done"`` are read.
        core_state: `tuple` of :py:obj:`torch.Tensor`
            LSTM state as returned by :py:meth:`initial_state` or by a previous
            call. If empty while the LSTM is active, a zero state is used.

        Returns
        -------
        `tuple`
            A dictionary with ``policy_logits`` (``[T, B, num_actions]``),
            ``baseline`` (``[T, B]``) and the sampled ``action`` (``[T, B]``),
            followed by the updated core state (empty tuple without LSTM).
        """
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.leaky_relu(self.conv3(x))
        x = x.view(T * B, -1)
        x = F.leaky_relu(self.fc(x))

        one_hot_last_action = F.one_hot(
            inputs["last_action"].view(T * B), self.num_actions
        ).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat(
            [x, clipped_reward, one_hot_last_action], dim=-1)

        if self.use_lstm:
            if core_state is None or len(core_state) == 0:
                # No state supplied (e.g. the default argument): start from
                # zeros, matching the device and dtype of the LSTM input.
                core_state = tuple(
                    s.to(core_input) for s in self.initial_state(B))

            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for c_input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(c_input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input
            core_state = tuple()

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        # Sample one action per (time, batch) entry from the policy.
        probs = F.softmax(policy_logits, dim=1)
        action = torch.multinomial(probs, num_samples=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action),
            core_state,
        )