Module nets

Collection of neural networks intended for reinforcement learning.

Currently contains a single neural network architecture intended for use with gym environments.

Exposed Networks:
  • AtariNet

Networks

Neural network for atari games (nets.AtariNet)

class pytorch_seed_rl.nets.AtariNet(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Neural network architecture intended for use with gym environments.

This network architecture is copied from the torchbeast project, which mimics the neural network used in the IMPALA paper.
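To illustrate how observation_shape determines the size of the convolutional features, the following sketch assumes the torso follows the three-layer convolution stack used by torchbeast's AtariNet (32, 64 and 64 filters, kernel sizes 8, 4 and 3, strides 4, 2 and 1, no padding). These layer sizes and the helper names conv_output_size and torso_feature_size are assumptions for illustration, not read from pytorch_seed_rl.

```python
# Assumed torso layers as (out_channels, kernel_size, stride), no padding.
TORSO_LAYERS = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]


def conv_output_size(size: int, kernel: int, stride: int) -> int:
    """Spatial size after an unpadded, undilated convolution."""
    return (size - kernel) // stride + 1


def torso_feature_size(observation_shape: tuple) -> int:
    """Flattened feature count for a (C, H, W) observation."""
    channels, height, width = observation_shape
    for out_channels, kernel, stride in TORSO_LAYERS:
        height = conv_output_size(height, kernel, stride)
        width = conv_output_size(width, kernel, stride)
        channels = out_channels
    return channels * height * width


print(torso_feature_size((4, 84, 84)))  # 3136
```

For the usual 84x84 Atari frames this yields 3136 features, which, if the assumption above holds, matches the fixed input size of the fully connected layer that follows the torso; a different observation_shape would change that number.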

Parameters
  • observation_shape (tuple) – The shape of the tensors this neural network processes.

  • num_actions (int) – The number of discrete actions this neural network can return.

  • use_lstm (bool) – Set to True to include an LSTM in this neural network.

initial_state(batch_size: int)

Returns zero-filled torch.Tensor objects shaped to match the LSTM block.

Returns None if the LSTM was not enabled during initialization.
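A minimal sketch of what such an initial state could look like, assuming a two-layer torch.nn.LSTM core. The function name initial_state_sketch and the default num_layers and hidden_size values are hypothetical. torch.nn.LSTM expects its (h, c) state as two tensors of shape (num_layers, batch_size, hidden_size), so the sketch returns a pair of zero tensors, or None when use_lstm is False.

```python
from typing import Optional, Tuple

import torch


def initial_state_sketch(
    batch_size: int,
    use_lstm: bool = True,
    num_layers: int = 2,
    hidden_size: int = 256,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Return a zero-filled (h, c) LSTM state, or None without an LSTM."""
    if not use_lstm:
        return None
    # torch.nn.LSTM expects h and c of shape (num_layers, batch, hidden).
    shape = (num_layers, batch_size, hidden_size)
    return torch.zeros(shape), torch.zeros(shape)


# The zero state can be passed straight to a matching nn.LSTM.
lstm = torch.nn.LSTM(input_size=16, hidden_size=256, num_layers=2)
sequence = torch.randn(5, 8, 16)  # (T, B, features)
output, (h, c) = lstm(sequence, initial_state_sketch(batch_size=8))
print(output.shape)  # torch.Size([5, 8, 256])
```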

forward(inputs: dict, core_state: tuple = ())

Performs a forward step of the neural network.

Parameters

inputs (dict of torch.Tensor) – Awaits a dictionary as returned by an step of DictObservationsEnv
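The following simplified, hypothetical stand-in (AtariNetSketch) shows how a torchbeast-style forward pass might consume a dictionary of tensors together with an optional core_state. It assumes the observation is stored under a "frame" key with shape (T, B, C, H, W); the keys actually returned by DictObservationsEnv are not specified here. It also omits extra core inputs such as the previous reward and action, which torchbeast's version concatenates, so treat it as a sketch rather than the pytorch_seed_rl implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AtariNetSketch(nn.Module):
    """Hypothetical, simplified stand-in; not the pytorch_seed_rl class."""

    def __init__(self, observation_shape, num_actions, use_lstm=False):
        super().__init__()
        self.use_lstm = use_lstm
        # Assumed torchbeast-style convolutional torso.
        self.conv = nn.Sequential(
            nn.Conv2d(observation_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened feature size from a dummy observation.
        with torch.no_grad():
            n_flat = self.conv(torch.zeros(1, *observation_shape)).shape[1]
        self.fc = nn.Linear(n_flat, 512)
        if use_lstm:
            self.core = nn.LSTM(512, 512, num_layers=2)
        self.policy = nn.Linear(512, num_actions)
        self.baseline = nn.Linear(512, 1)

    def forward(self, inputs, core_state=()):
        frame = inputs["frame"]  # assumed key, shape (T, B, C, H, W)
        T, B = frame.shape[:2]
        # Merge time and batch dims and scale uint8 pixels to [0, 1].
        x = frame.flatten(0, 1).float() / 255.0
        x = F.relu(self.fc(self.conv(x)))  # (T * B, 512)
        if self.use_lstm:
            # An empty core_state lets nn.LSTM start from zeros.
            x, core_state = self.core(x.view(T, B, -1), core_state or None)
            x = x.flatten(0, 1)
        logits = self.policy(x)
        baseline = self.baseline(x)
        # Sample one discrete action per (time, batch) entry.
        action = torch.multinomial(F.softmax(logits, dim=1), num_samples=1)
        outputs = {
            "policy_logits": logits.view(T, B, -1),
            "baseline": baseline.view(T, B),
            "action": action.view(T, B),
        }
        return outputs, core_state


net = AtariNetSketch((4, 84, 84), num_actions=6, use_lstm=True)
inputs = {"frame": torch.randint(0, 256, (1, 2, 4, 84, 84), dtype=torch.uint8)}
outputs, state = net(inputs)
print(outputs["policy_logits"].shape)  # torch.Size([1, 2, 6])
```

The returned state can be fed back as core_state on the next call, which is how a recurrent policy carries memory across steps; without an LSTM the empty tuple is simply passed through.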