[RLlib] Offpolicy add metrics to multi-agent replay buffers. #49959
…odeReplayBuffer'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…chanism in 'DQN'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ermore, added a further keyword argument for the initialization of the buffer to get the number of iterations for smoothing. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ics to the 'PrioritizedEpisodeReplayBuffer'. Also added docstrings. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
… logic such that method overriding works for MA-buffers. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ethod overriding works for MA-buffers. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…'PrioritizedEpisodeReplayBuffer' such that method overriding works for MA-buffers. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ppened because no episodes were evicted, yet. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…playBuffer' for 'independent' sampling and fixed a small nit in the 'EpisodeReplayBuffer._update_sample_metrics'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…PrioritizedReplayBuffer'. Furthermore, I modified the base class' '_update_sample_metrics' to already include the 'resample' counters, such that no subclass needs to override it. Previously, this created a conflict when the 'MultiAgentPrioritizedReplayBuffer' inherited from both 'PrioritizedReplayBuffer' and 'MultiAgentReplayBuffer'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@@ -128,6 +124,7 @@ def __init__(
batch_length_T: int = 1,
alpha: float = 1.0,
metrics_num_episodes_for_smoothing: int = 100,
metrics_num_episodes_for_smoothing: int = 100,
Arg defined twice?
Great catch! Removed in the next commit.
@@ -468,6 +497,7 @@ def sample(
# Skip, if we are too far to the end and `episode_ts` + n_step would go
# beyond the episode's end.
if episode_ts + actual_n_step > len(episode):
num_resamples += 1
duplicate line?
Let me check this. Should be called once only.
Yup! Great catch again!
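For readers following the thread, here is a minimal sketch of why the counter must be incremented exactly once per rejected draw. It is not the buffer's actual `sample()` code: the function `sample_start_indices` and its parameters are hypothetical, and a single episode of fixed length is assumed.

```python
import random


def sample_start_indices(episode_len, n_step, batch_size, rng):
    """Draw `batch_size` start timesteps and count the rejected draws.

    Hypothetical helper for illustration only; not part of RLlib.
    """
    if n_step > episode_len:
        raise ValueError("n_step must not exceed the episode length")
    starts = []
    num_resamples = 0
    while len(starts) < batch_size:
        ts = rng.randrange(episode_len)
        # Reject draws that are too close to the end for a full n-step window.
        if ts + n_step > episode_len:
            num_resamples += 1  # one increment per rejected draw
            continue
        starts.append(ts)
    return starts, num_resamples


starts, num_resamples = sample_start_indices(10, 3, 5, random.Random(0))
```

A duplicated `num_resamples += 1` inside the `if` branch would double the reported resample count without changing which timesteps end up in the batch.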
# Increment counter.
B += 1

# Keep track of sampled indices for updating priorities later.
self._last_sampled_indices.append(idx)

# Add to the sampled timesteps counter of the buffer.
Duplicate line?
Again. I will remove this.
@@ -108,6 +129,7 @@ def __init__(
batch_size_B: int = 16,
batch_length_T: int = 64,
metrics_num_episodes_for_smoothing: int = 100,
metrics_num_episodes_for_smoothing: int = 100,
duplicate line?
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
module_to_num_steps_added[mid] += e_eps.agent_steps()

# Update the adding metrics.
self._update_add_metrics(
Just for safety: Can we make all these utility methods with forced keywords?
def _update_add_metrics(self, *, ...):
Sure. Great idea.
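The suggestion relies on Python's keyword-only parameters: everything after a bare `*` in the signature must be passed by name. A minimal sketch of the effect follows, assuming a hypothetical stand-in class `MetricsSketch` and hypothetical parameter names; the real method takes more arguments than shown here.

```python
class MetricsSketch:
    """Hypothetical stand-in for a buffer, for illustration only."""

    def __init__(self):
        self.num_episodes_added = 0
        self.num_steps_added = 0

    def _update_add_metrics(self, *, num_episodes_added, num_steps_added):
        # The bare `*` forces callers to name every argument.
        self.num_episodes_added += num_episodes_added
        self.num_steps_added += num_steps_added


buf = MetricsSketch()
buf._update_add_metrics(num_episodes_added=2, num_steps_added=50)

try:
    # Positional arguments are rejected, so two counters of the same
    # type cannot be swapped silently.
    buf._update_add_metrics(2, 50)
except TypeError:
    positional_call_rejected = True
else:
    positional_call_rejected = False
```

With many same-typed counters in one signature, this turns an easy-to-miss argument mix-up into an immediate `TypeError`.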
module_to_sampled_n_step = {
mid: sum(l) / len(l) for mid, l in module_to_sampled_n_steps.items()
}
self._update_sample_metrics(
same here, let's make these forced keyword args:
def _update_sample_metrics(self, *, ...):
# Increase index to the new length of `self._indices`.
j = len(self._indices)

# Update the adding metrics.
self._update_add_metrics(
same: def _update_add_metrics(self, *, ...)
agent_to_sampled_episode_idxs[sa_episode.agent_id].add(sa_episode.id_)
module_to_sampled_episode_idxs[module_id].add(sa_episode.id_)
# Get the corresponding index in the `env_to_agent_t` mapping.
ma_episode_ts = ma_episode.env_t_to_agent_t[agent_id].data.index(
Could this be super expensive, if we have long episodes?
Are we doing this only for the actual_n_steps metrics? How important is it to have these? It's mostly relevant for short episodes, correct? Where we sometimes sample too close to the end and can't fulfil the given n-step length?
No, this is done for every step. The idea is to check how varied the sample batch is, i.e. how many steps are duplicates. Here specifically: how many agent steps come from the same env step. The hash is built over the multi-agent episode's env step.
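On the cost question: `list.index` scans from the front, so each lookup is linear in the episode length. The sketch below contrasts that with a reverse lookup built once per episode. It assumes a plain Python list as a simplified stand-in for the `env_t_to_agent_t[agent_id].data` structure, with `None` marking env steps where the agent did not act; the names `env_step_linear` and `build_reverse_lookup` are hypothetical, and whether such a cache pays off in the buffer is not something this thread establishes.

```python
def env_step_linear(agent_ts_per_env_step, agent_ts):
    # Linear scan: cost grows with the episode length on every call.
    return agent_ts_per_env_step.index(agent_ts)


def build_reverse_lookup(agent_ts_per_env_step):
    # One pass over the episode; later lookups are dict accesses.
    # Assumes every agent timestep appears at most once.
    return {
        agent_ts: env_ts
        for env_ts, agent_ts in enumerate(agent_ts_per_env_step)
        if agent_ts is not None
    }


# Hypothetical mapping: position = env step, value = agent timestep.
mapping = [0, None, 1, 2, None, 3]
reverse = build_reverse_lookup(mapping)

env_ts_scan = env_step_linear(mapping, 2)
env_ts_dict = reverse[2]
```

Both lookups return the same env step; the difference only matters when many timesteps are sampled from long episodes, and the dictionary would need to be kept in sync as the episode grows.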
module_to_num_episodes_evicted[DEFAULT_MODULE_ID] += 1
agent_to_num_steps_evicted[
DEFAULT_AGENT_ID
] += evicted_eps.agent_steps()
nit: use evicted_eps_len?
] += evicted_eps.agent_steps()
module_to_num_steps_evicted[
DEFAULT_MODULE_ID
] += evicted_eps.agent_steps()
nit: use evicted_eps_len?
Yes, we could. I wanted to make explicit that this metric counts agent steps, which here is of course the same.
Why are these changes needed?
Multi-agent (episode) replay buffers in the new API stack were lacking metrics. This PR proposes adding metrics to the multi-agent (episode) replay buffers by:
… _update_add_metrics and _update_sample_metrics method needs to be overridden.
… sampled_n_step, agent_to_sampled_n_step, and module_to_sampled_n_step can be None and are then not added).
Related issue number
Checks
… git commit -s) in this PR.
… scripts/format.sh to lint the changes in this PR.
… method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.