projects.objectnav_baselines.models.object_nav_models#
Baseline models for use in the object navigation task.
Object navigation is currently available as a Task in AI2-THOR and Facebook's Habitat.
ObjectNavActorCritic#
class ObjectNavActorCritic(VisualNavActorCritic)
Baseline recurrent actor-critic model for object navigation.
Attributes
- `action_space`: The space of actions available to the agent. Currently only discrete actions are allowed (so this space will always be of type `gym.spaces.Discrete`).
- `observation_space`: The observation space expected by the agent. This observation space should include (optionally) 'rgb' images and 'depth' images and is required to have a component corresponding to the goal `goal_sensor_uuid`.
- `goal_sensor_uuid`: The uuid of the sensor of the goal object. See `GoalObjectTypeThorSensor` as an example of such a sensor.
- `hidden_size`: The hidden size of the GRU RNN.
- `object_type_embedding_dim`: The dimensionality of the embedding corresponding to the goal object type.
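To make the observation-space contract above concrete, here is a minimal, dependency-free sketch of the check it implies: the goal sensor's uuid must be present, while 'rgb' and 'depth' are optional. The function name and return shape are illustrative, not part of the actual API.

```python
def check_observation_space(obs_space_keys, goal_sensor_uuid):
    """Sketch of the contract: the goal sensor component is required,
    visual components ('rgb', 'depth') are optional."""
    if goal_sensor_uuid not in obs_space_keys:
        raise ValueError(f"observation space must contain '{goal_sensor_uuid}'")
    return {
        "uses_rgb": "rgb" in obs_space_keys,
        "uses_depth": "depth" in obs_space_keys,
    }

# An observation space with rgb images and a goal component passes:
print(check_observation_space({"rgb", "goal_object_type_ind"}, "goal_object_type_ind"))
# → {'uses_rgb': True, 'uses_depth': False}
```

In the real model the same information comes from a `gym.spaces.Dict` observation space rather than a plain set of keys.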
ObjectNavActorCritic.__init__#
| __init__(action_space: gym.spaces.Discrete, observation_space: SpaceDict, goal_sensor_uuid: str, hidden_size=512, num_rnn_layers=1, rnn_type="GRU", add_prev_actions=False, action_embed_size=6, multiple_beliefs=False, beliefs_fusion: Optional[FusionType] = None, auxiliary_uuids: Optional[List[str]] = None, rgb_uuid: Optional[str] = None, depth_uuid: Optional[str] = None, object_type_embedding_dim=8, trainable_masked_hidden_state: bool = False, backbone="gnresnet18", resnet_baseplanes=32)
Initializer.
See class documentation for parameter definitions.
ObjectNavActorCritic.is_blind#
| @property
| is_blind() -> bool
True if the model is blind (i.e., neither 'depth' nor 'rgb' is an input observation type).
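The blindness check reduces to testing whether any visual input uuid is configured. The following is a hypothetical, stripped-down sketch of that logic (the class name and constructor are illustrative):

```python
from typing import Optional

class NavModelSketch:
    """Hypothetical stand-in showing how a blindness check can work:
    the model is 'blind' when neither visual input uuid is set."""

    def __init__(self, rgb_uuid: Optional[str] = None, depth_uuid: Optional[str] = None):
        self.rgb_uuid = rgb_uuid
        self.depth_uuid = depth_uuid

    @property
    def is_blind(self) -> bool:
        # Blind iff no visual observation ('rgb' or 'depth') is wired in.
        return self.rgb_uuid is None and self.depth_uuid is None

print(NavModelSketch().is_blind)                 # → True (no visual inputs)
print(NavModelSketch(rgb_uuid="rgb").is_blind)   # → False
```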
ObjectNavActorCritic.get_object_type_encoding#
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.
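Conceptually, this method looks up a learned embedding row for each goal object-type index in the batch. A dependency-free sketch of that lookup is shown below; the observation key, table values, and 4-dimensional embedding are illustrative (the real model uses a learned `nn.Embedding` with `object_type_embedding_dim=8` by default).

```python
# Toy "embedding table": one row per known object type.
# In the real model these rows are learned parameters.
embedding_table = [
    [0.1, 0.2, 0.3, 0.4],  # object type 0
    [0.5, 0.6, 0.7, 0.8],  # object type 1
]

def get_object_type_encoding(observations):
    """Sketch: map each goal object-type index in the batch to its embedding row.
    'goal_object_type_ind' stands in for the configured goal_sensor_uuid."""
    return [embedding_table[i] for i in observations["goal_object_type_ind"]]

batch = {"goal_object_type_ind": [1, 0]}
print(get_object_type_encoding(batch))
# → [[0.5, 0.6, 0.7, 0.8], [0.1, 0.2, 0.3, 0.4]]
```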
ResnetTensorObjectNavActorCritic#
class ResnetTensorObjectNavActorCritic(VisualNavActorCritic)
ResnetTensorObjectNavActorCritic.is_blind#
| @property
| is_blind() -> bool
True if the model is blind (i.e., neither 'depth' nor 'rgb' is an input observation type).
ResnetTensorGoalEncoder#
class ResnetTensorGoalEncoder(nn.Module)
ResnetTensorGoalEncoder.get_object_type_encoding#
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.
ResnetDualTensorGoalEncoder#
class ResnetDualTensorGoalEncoder(nn.Module)
ResnetDualTensorGoalEncoder.get_object_type_encoding#
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.