Module nets
Collection of neural networks intended for reinforcement learning.
Currently contains a single neural network architecture, intended for use with
gym environments.
- Exposed Networks:
AtariNet
Networks
Neural network for Atari games (nets.AtariNet)
class pytorch_seed_rl.nets.AtariNet(*args: Any, **kwargs: Any)
Bases: torch.nn.Module
Neural network architecture intended for use with gym environments.
This network architecture is copied from the torchbeast project, which mimics the neural network used in the IMPALA paper.
See also
“IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures” on arXiv by Espeholt, Soyer, Munos et al.
- Parameters
observation_shape (tuple) – The shape of the tensors this neural network processes.
num_actions (int) – The number of discrete actions this neural network can return.
use_lstm (bool) – Set to True if an LSTM should be included in this neural network.
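To illustrate how observation_shape relates to the size of the network, the following minimal sketch computes the number of flattened convolutional features for a given observation shape. The layer layout (kernel sizes 8, 4, 3 with strides 4, 2, 1 and 32, 64, 64 output channels, no padding) is an assumption based on the torchbeast-style Atari network and is not stated on this page; the helper names are hypothetical.

```python
from typing import Sequence, Tuple


def conv2d_out(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded convolution along one spatial axis."""
    return (size - kernel) // stride + 1


# ASSUMPTION: (kernel, stride, out_channels) per layer, following the
# torchbeast-style layout. This page does not document the layers.
ASSUMED_CONV_LAYERS: Sequence[Tuple[int, int, int]] = (
    (8, 4, 32),
    (4, 2, 64),
    (3, 1, 64),
)


def flattened_feature_size(observation_shape: Tuple[int, int, int]) -> int:
    """Number of features after the assumed conv stack, for a
    (channels, height, width) observation shape."""
    channels, height, width = observation_shape
    for kernel, stride, out_channels in ASSUMED_CONV_LAYERS:
        height = conv2d_out(height, kernel, stride)
        width = conv2d_out(width, kernel, stride)
        channels = out_channels
    return channels * height * width


# Four stacked 84x84 frames, a common Atari preprocessing choice.
features = flattened_feature_size((4, 84, 84))
print(features)  # 64 channels * 7 * 7 spatial positions = 3136
```

Under these assumptions, changing observation_shape changes the input size of the first fully connected layer, which is why the shape has to be known when the network is constructed.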
initial_state(batch_size: int)
Returns a zero-filled torch.Tensor with the shape of the LSTM block. Returns None if the LSTM was not activated during initialization.
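The documented contract can be sketched without torch by looking only at shapes. The function below is a hypothetical stand-in, not the library's implementation: it returns None when no LSTM is used and otherwise the shape that PyTorch's torch.nn.LSTM expects for a hidden or cell state, (num_layers, batch_size, hidden_size). The default values for num_layers and hidden_size are placeholders chosen for illustration.

```python
from typing import Optional, Tuple


def lstm_state_shape(
    batch_size: int,
    use_lstm: bool,
    num_layers: int = 2,     # hypothetical value, not taken from this page
    hidden_size: int = 256,  # hypothetical value, not taken from this page
) -> Optional[Tuple[int, int, int]]:
    """Shape of a zero initial state, or None if no LSTM is used."""
    if not use_lstm:
        # Mirrors the documented behaviour: no LSTM, no initial state.
        return None
    # torch.nn.LSTM hidden and cell states are laid out as
    # (num_layers, batch, hidden_size).
    return (num_layers, batch_size, hidden_size)


print(lstm_state_shape(batch_size=8, use_lstm=True))
print(lstm_state_shape(batch_size=8, use_lstm=False))
```

Callers that support both configurations should check for None before passing the result on as a recurrent state.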
forward(inputs: dict, core_state: tuple = ())
Forward step of the neural network.
- Parameters
inputs (dict of torch.Tensor) – Expects a dictionary as returned by a step of DictObservationsEnv