[RLlib] Split AddStates... connectors into 2 pieces (AddTimeDimToBatchAndZeroPad and AddStatesFromEpisodesToBatch) #49835

Merged (9 commits) on Jan 20, 2025
14 changes: 10 additions & 4 deletions rllib/algorithms/algorithm.py
@@ -2977,6 +2977,16 @@ def log_result(self, result: ResultDict) -> None:

@override(Trainable)
def cleanup(self) -> None:
# Stop all Learners.
if hasattr(self, "learner_group") and self.learner_group is not None:
self.learner_group.shutdown()
Collaborator:

This sometimes does not work. At least in my workspace running on GPU, an algo.stop() did not pull down the BackendExecutor workers.


# Stop all aggregation actors.
if hasattr(self, "_aggregator_actor_manager") and (
self._aggregator_actor_manager is not None
):
self._aggregator_actor_manager.clear()

# Stop all EnvRunners.
if hasattr(self, "env_runner_group") and self.env_runner_group is not None:
self.env_runner_group.stop()
@@ -2986,10 +2996,6 @@ def cleanup(self) -> None:
):
self.eval_env_runner_group.stop()

# Stop all Learners.
if hasattr(self, "learner_group") and self.learner_group is not None:
self.learner_group.shutdown()

@OverrideToImplementCustomLogic
@classmethod
@override(Trainable)
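Taken together, the two hunks above reorder Algorithm.cleanup(): Learners are now shut down first, the aggregation-actor teardown moves here from IMPALA's own cleanup override (removed in the impala.py diff below), and the EnvRunner groups are stopped afterwards. A sketch of the resulting method, reconstructed from the visible hunks only (the elided condition around the eval-EnvRunner stop is not shown):

```python
@override(Trainable)
def cleanup(self) -> None:
    # Stop all Learners first.
    if hasattr(self, "learner_group") and self.learner_group is not None:
        self.learner_group.shutdown()

    # Stop all aggregation actors (moved here from IMPALA's cleanup override).
    if hasattr(self, "_aggregator_actor_manager") and (
        self._aggregator_actor_manager is not None
    ):
        self._aggregator_actor_manager.clear()

    # Stop all EnvRunners; the eval EnvRunner group is stopped afterwards
    # (its guard condition is elided in the diff above).
    if hasattr(self, "env_runner_group") and self.env_runner_group is not None:
        self.env_runner_group.stop()
```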
10 changes: 8 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -964,6 +964,7 @@ def build_env_to_module_connector(self, env, device=None):
from ray.rllib.connectors.env_to_module import (
AddObservationsFromEpisodesToBatch,
AddStatesFromEpisodesToBatch,
AddTimeDimToBatchAndZeroPad,
AgentToModuleMapping,
BatchIndividualItems,
EnvToModulePipeline,
@@ -1017,7 +1018,9 @@ def build_env_to_module_connector(self, env, device=None):
if self.add_default_connectors_to_env_to_module_pipeline:
# Append OBS handling.
pipeline.append(AddObservationsFromEpisodesToBatch())
# Append STATE_IN/STATE_OUT (and time-rank) handler.
# Append time-rank handler.
pipeline.append(AddTimeDimToBatchAndZeroPad())
# Append STATE_IN/STATE_OUT handler.
pipeline.append(AddStatesFromEpisodesToBatch())
# If multi-agent -> Map from AgentID-based data to ModuleID based data.
if self.is_multi_agent:
@@ -1139,6 +1142,7 @@ def build_learner_connector(
AddColumnsFromEpisodesToTrainBatch,
AddObservationsFromEpisodesToBatch,
AddStatesFromEpisodesToBatch,
AddTimeDimToBatchAndZeroPad,
AgentToModuleMapping,
BatchIndividualItems,
LearnerConnectorPipeline,
@@ -1183,7 +1187,9 @@
)
# Append all other columns handling.
pipeline.append(AddColumnsFromEpisodesToTrainBatch())
# Append STATE_IN/STATE_OUT (and time-rank) handler.
# Append time-rank handler.
pipeline.append(AddTimeDimToBatchAndZeroPad(as_learner_connector=True))
# Append STATE_IN/STATE_OUT handler.
pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
# If multi-agent -> Map from AgentID-based data to ModuleID based data.
if self.is_multi_agent:
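Both builder methods now append the new piece directly in front of AddStatesFromEpisodesToBatch, so time-rank handling and zero-padding are separated from state handling. A minimal sketch of assembling the same env-to-module ordering by hand; the class names and import path come from the diff above, while the connectors= constructor argument of EnvToModulePipeline is an assumption:

```python
from ray.rllib.connectors.env_to_module import (
    AddObservationsFromEpisodesToBatch,
    AddStatesFromEpisodesToBatch,
    AddTimeDimToBatchAndZeroPad,
    EnvToModulePipeline,
)

# Same ordering as appended by build_env_to_module_connector() above, built
# explicitly: OBS handling, then the time-rank/zero-padding piece, then the
# (now padding-free) STATE_IN/STATE_OUT piece.
pipeline = EnvToModulePipeline(
    connectors=[
        AddObservationsFromEpisodesToBatch(),
        AddTimeDimToBatchAndZeroPad(),
        AddStatesFromEpisodesToBatch(),
    ]
)
```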
11 changes: 0 additions & 11 deletions rllib/algorithms/impala/impala.py
@@ -130,7 +130,6 @@ def __init__(self, algo_class=None):
self.vtrace_clip_rho_threshold = 1.0
self.vtrace_clip_pg_rho_threshold = 1.0
self.learner_queue_size = 3
self.max_requests_in_flight_per_env_runner = 1
self.timeout_s_sampler_manager = 0.0
self.timeout_s_aggregator_manager = 0.0
self.broadcast_interval = 1
@@ -758,16 +757,6 @@ def _func(actor, p):

time.sleep(0.01)

@override(Algorithm)
def cleanup(self) -> None:
super().cleanup()

# Stop all aggregation actors.
if hasattr(self, "_aggregator_actor_manager") and (
self._aggregator_actor_manager is not None
):
self._aggregator_actor_manager.clear()

def _sample_and_get_connector_states(self):
def _remote_sample_get_state_and_metrics(_worker):
_episodes = _worker.sample()
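With the IMPALA-specific default of max_requests_in_flight_per_env_runner = 1 removed, IMPALA falls back to the generic AlgorithmConfig default. A sketch of restoring the old behavior explicitly; the IMPALAConfig class name and the env_runners() setting are assumed to be the current public API and are not part of this diff:

```python
from ray.rllib.algorithms.impala import IMPALAConfig

# Restore the previous IMPALA-specific behavior of at most one in-flight
# sample request per EnvRunner, now that the hard-coded default is gone.
config = IMPALAConfig().env_runners(max_requests_in_flight_per_env_runner=1)
```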
4 changes: 4 additions & 0 deletions rllib/connectors/common/__init__.py
@@ -4,6 +4,9 @@
from ray.rllib.connectors.common.add_states_from_episodes_to_batch import (
AddStatesFromEpisodesToBatch,
)
from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import (
AddTimeDimToBatchAndZeroPad,
)
from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping
from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems
from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
@@ -12,6 +15,7 @@
__all__ = [
"AddObservationsFromEpisodesToBatch",
"AddStatesFromEpisodesToBatch",
"AddTimeDimToBatchAndZeroPad",
"AgentToModuleMapping",
"BatchIndividualItems",
"NumpyToTensor",
@@ -24,6 +24,7 @@ class AddObservationsFromEpisodesToBatch(ConnectorV2):
[
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
@@ -34,6 +35,7 @@
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddColumnsFromEpisodesToTrainBatch,
AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
102 changes: 6 additions & 96 deletions rllib/connectors/common/add_states_from_episodes_to_batch.py
@@ -12,11 +12,6 @@
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.zero_padding import (
create_mask_and_seq_lens,
split_and_zero_pad,
)
from ray.rllib.utils.spaces.space_utils import BatchedNdArray
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI

@@ -35,6 +30,7 @@ class AddStatesFromEpisodesToBatch(ConnectorV2):
[
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
Expand All @@ -45,6 +41,7 @@ class AddStatesFromEpisodesToBatch(ConnectorV2):
[0 or more user defined ConnectorV2 pieces],
AddObservationsFromEpisodesToBatch,
AddColumnsFromEpisodesToTrainBatch,
AddTimeDimToBatchAndZeroPad,
AddStatesFromEpisodesToBatch,
AgentToModuleMapping, # only in multi-agent setups!
BatchIndividualItems,
@@ -160,7 +157,7 @@ def get_initial_state(self):
output_batch = connector(
rl_module=rl_module,
batch={},
episodes=[episode.to_numpy()],
episodes=[episode],
shared_data={},
)
check(
@@ -173,7 +170,7 @@ def get_initial_state(self):
# predictions).
# Also note that the different STATE_IN timesteps are already present
# as one batched item per episode in the list.
(episode.id_,): [[rl_module_init_state, -3.0]],
(episode.id_,): [rl_module_init_state, -3.0],
},
)
"""
@@ -217,61 +214,6 @@ def __call__(
if not rl_module.is_stateful() or Columns.STATE_IN in batch:
return batch

# Make all inputs (other than STATE_IN) have an additional T-axis.
# Since data has not been batched yet (we are still operating on lists in the
# batch), we add this time axis as 0 (not 1). When we batch, the batch axis will
# be 0 and the time axis will be 1.
# Also, let module-to-env pipeline know that we had added a single timestep
# time rank to the data (to remove it again).
if not self._as_learner_connector:
for column in batch.keys():
self.foreach_batch_item_change_in_place(
batch=batch,
column=column,
func=lambda item, eps_id, aid, mid: (
item
if mid is not None and not rl_module[mid].is_stateful()
# Expand on axis 0 (the to-be-time-dim) if item has not been
# batched yet, otherwise axis=1 (the time-dim).
else tree.map_structure(
lambda s: np.expand_dims(
s, axis=(1 if isinstance(s, BatchedNdArray) else 0)
),
item,
)
),
)
shared_data["_added_single_ts_time_rank"] = True
else:
# Before adding STATE_IN to the `data`, zero-pad existing data and batch
# into max_seq_len chunks.
for column, column_data in batch.copy().items():
# Do not zero-pad INFOS column.
if column == Columns.INFOS:
continue
for key, item_list in column_data.items():
# Multi-agent case AND RLModule is not stateful -> Do not zero-pad
# for this model.
assert isinstance(key, tuple)
mid = None
if len(key) == 3:
eps_id, aid, mid = key
if not rl_module[mid].is_stateful():
continue
column_data[key] = split_and_zero_pad(
item_list,
max_seq_len=self._get_max_seq_len(rl_module, module_id=mid),
)
# TODO (sven): Remove this hint/hack once we are not relying on
# SampleBatch anymore (which has to set its property
# zero_padded=True when shuffling).
shared_data[
(
"_zero_padded_for_mid="
f"{mid if mid is not None else DEFAULT_MODULE_ID}"
)
] = True

for sa_episode in self.single_agent_episode_iterator(
episodes,
# If Learner connector, get all episodes (for train batch).
@@ -280,8 +222,8 @@ def __call__(
agents_that_stepped_only=not self._as_learner_connector,
):
if self._as_learner_connector:
# Multi-agent case: Extract correct single agent RLModule (to get the
# state for individually).
# Multi-agent case: Extract correct single agent RLModule (to get its
# individual state).
if sa_episode.module_id is not None:
sa_module = rl_module[sa_episode.module_id]
else:
@@ -372,24 +314,6 @@ def __call__(
single_agent_episode=sa_episode,
)

# Also, create the loss mask (b/c of our now possibly zero-padded data)
# as well as the seq_lens array and add these to `data` as well.
mask, seq_lens = create_mask_and_seq_lens(len(sa_episode), max_seq_len)
self.add_n_batch_items(
batch=batch,
column=Columns.SEQ_LENS,
items_to_add=seq_lens,
num_items=len(seq_lens),
single_agent_episode=sa_episode,
)
if not shared_data.get("_added_loss_mask_for_valid_episode_ts"):
self.add_n_batch_items(
batch=batch,
column=Columns.LOSS_MASK,
items_to_add=mask,
num_items=len(mask),
single_agent_episode=sa_episode,
)
else:
assert not sa_episode.is_numpy

Expand Down Expand Up @@ -422,17 +346,3 @@ def __call__(
)

return batch

def _get_max_seq_len(self, rl_module, module_id=None):
if module_id:
mod = rl_module[module_id]
else:
mod = next(iter(rl_module.values()))
if "max_seq_len" not in mod.model_config:
raise ValueError(
"You are using a stateful RLModule and are not providing a "
"'max_seq_len' key inside your `model_config`. You can set this "
"dict and/or override keys in it via `config.rl_module("
"model_config={'max_seq_len': [some int]})`."
)
return mod.model_config["max_seq_len"]
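The zero-padding, loss-mask, and max_seq_len logic removed above now lives in the new AddTimeDimToBatchAndZeroPad connector. A small sketch of what those helpers do on the Learner side, using the same imports this file used to have; the exact outputs in the comments are assumptions about the helpers' behavior, not taken from this diff:

```python
import numpy as np

from ray.rllib.utils.postprocessing.zero_padding import (
    create_mask_and_seq_lens,
    split_and_zero_pad,
)

# A 7-timestep episode column, split into max_seq_len=4 chunks; the last chunk
# gets zero-padded to full length.
items = [np.array(float(t)) for t in range(7)]
chunks = split_and_zero_pad(items, max_seq_len=4)
# Roughly: [[0, 1, 2, 3], [4, 5, 6, 0]] (second chunk padded with one zero).

# Matching loss mask and sequence lengths for the same episode.
mask, seq_lens = create_mask_and_seq_lens(7, 4)
# Roughly: mask=[[T, T, T, T], [T, T, T, F]], seq_lens=[4, 3].
```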