WIP: feature(pu): adapt to unizero-multitask ddp, and adapt ppo to support jericho config #858
base: main
Conversation
…n mask), e.g. detective env
if param.grad is not None:
    allreduce(param.grad.data)
else:
    # If the gradient is None, create a zero tensor with the same shape as the parameter and perform allreduce
Remove the commented-out code and add an English comment; then these modifications can be merged.
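For reference, the pattern under review looks roughly like the sketch below, assuming `torch.distributed` is already initialized and that the diff's `allreduce` helper wraps `dist.all_reduce`:

```python
import torch
import torch.distributed as dist

def allreduce_gradients(model: torch.nn.Module) -> None:
    """Sync gradients across ranks, tolerating parameters with grad=None."""
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data)
        else:
            # If the gradient is None (e.g. a head masked out for this task),
            # allreduce a zero tensor of the same shape so every rank still
            # enters the collective and none of them deadlocks.
            zero_grad = torch.zeros_like(param.data)
            dist.all_reduce(zero_grad)
            param.grad = zero_grad
```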
# dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# TODO:
import datetime
dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=60000))
Why add this timeout?
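For context, the questioned change amounts to something like this sketch; the 60000-second value comes from the diff, while the motivation (keeping slow multi-task setup or long evaluation phases from tripping the ~30-minute default collective timeout) is an assumption:

```python
import datetime
import torch.distributed as dist

def init_ddp(backend: str, rank: int, world_size: int) -> None:
    # Raising the timeout from the default (30 minutes) to 60000 s (~16.7 h)
    # prevents long idle phases on one rank from aborting the process group.
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=60000),
    )
```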
# if self._rank == 0:
#     self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)

self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
Add an argument named only_monitor_rank0 to control this logic, defaulting to True.
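A sketch of that suggestion; `_setup_monitor` is a hypothetical method name, while `get_simple_monitor_type`, `TickTime`, and `expire=10` are taken from the diff:

```python
def _setup_monitor(self, only_monitor_rank0: bool = True) -> None:
    # With the suggested default (True), only rank 0 builds the monitor;
    # setting it to False restores per-rank monitoring.
    if not only_monitor_rank0 or self._rank == 0:
        self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(
            TickTime(), expire=10
        )
    else:
        self._monitor = None
```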
for k in engine.log_buffer:
    engine.log_buffer[k].clear()
return
# if engine.rank != 0:
Also pass the only_monitor_rank0 argument to the hook class.
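A sketch of how the flag could be threaded through; the hook class and its constructor here are hypothetical, and only the buffer-clearing body comes from the diff:

```python
class LogShowHook:
    """Hypothetical hook carrying the only_monitor_rank0 flag."""

    def __init__(self, only_monitor_rank0: bool = True) -> None:
        self._only_monitor_rank0 = only_monitor_rank0

    def __call__(self, engine) -> None:
        if self._only_monitor_rank0 and engine.rank != 0:
            # Unmonitored ranks just drop their buffered stats and return.
            for k in engine.log_buffer:
                engine.log_buffer[k].clear()
            return
        # rank 0 (or every rank, when the flag is False) renders logs here
```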
    self._global_state_encoder = nn.Identity()
elif len(global_obs_shape) == 3:
    self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
    self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
Why use BN rather than LN as the default here?
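One possible shape of the fix, as a sketch: expose `norm_type` instead of hard-coding 'BN', defaulting to 'LN' since LayerNorm does not depend on batch statistics (RL batches are often small or correlated). `ConvEncoder` and `Mixer` are the classes used in the diff; the standalone function is purely illustrative:

```python
import torch.nn as nn

def build_global_state_encoder(global_obs_shape, hidden_size_list,
                               activation=None, norm_type: str = 'LN'):
    # Mirrors the branch in the diff, but with the normalization made
    # configurable and defaulting to LayerNorm as the reviewer suggests.
    if len(global_obs_shape) == 1:
        return nn.Identity()
    elif len(global_obs_shape) == 3:
        return ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list,
                           activation=activation, norm_type=norm_type)
    else:
        raise ValueError(f'unsupported global_obs_shape: {global_obs_shape}')
```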
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
agent_state = agent_state.unsqueeze(0)
if single_step and len(global_state.shape) == 2:
    global_state = global_state.unsqueeze(0)
add shape comments
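A sketch of the shape comments being requested, under the usual QMIX layout assumption (T = time, B = batch, A = agent_num, N = per-agent obs dim, M = global obs dim):

```python
import torch

def add_time_dim(agent_state: torch.Tensor, global_state: torch.Tensor,
                 single_step: bool = True):
    if single_step:
        # agent_state: (B, A, N) -> (1, B, A, N); prepend a T=1 time axis
        agent_state = agent_state.unsqueeze(0)
        if len(global_state.shape) == 2:
            # global_state: (B, M) -> (1, B, M); image-like global states
            # (B, C, H, W) keep their layout and are handled downstream.
            global_state = global_state.unsqueeze(0)
    return agent_state, global_state
```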
@@ -205,7 +214,10 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
    agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
    agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
    if self.mixer:
        global_state_embedding = self._global_state_encoder(global_state)
        if len(global_state.shape) == 5:
add some comments
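A sketch of what a comment could document here, assuming a 5-D global state means an image observation with a time axis, (T, B, C, H, W), which must be folded to 4-D before a conv encoder:

```python
import torch

def encode_global_state(global_state: torch.Tensor, encoder: torch.nn.Module) -> torch.Tensor:
    # Assumption: a 5-D global state is image-like with a time axis,
    # i.e. (T, B, C, H, W); conv encoders expect 4-D (batch, C, H, W),
    # so fold T into the batch and unfold it again after encoding.
    if len(global_state.shape) == 5:
        T, B = global_state.shape[:2]
        embedding = encoder(global_state.reshape(T * B, *global_state.shape[2:]))
        return embedding.reshape(T, B, -1)  # back to (T, B, embed_dim)
    return encoder(global_state)
```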
@@ -265,12 +265,17 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
    >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
    """
    if self.share_encoder:
        x = self.encoder(x)
        # import ipdb;ipdb.set_trace()
Modify the corresponding API comments, and use isinstance(x, dict) to control the logic.
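A sketch of the suggested branch; the `'observation'` key is an assumed name for Jericho-style dict observations (token ids plus an action mask), not the repo's actual schema:

```python
import torch
import torch.nn as nn
from typing import Dict, Union

def encode_obs(encoder: nn.Module,
               x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
    # Dict observations carry extra fields (e.g. an action mask) that the
    # encoder should not see; encode only the observation tensor itself.
    if isinstance(x, dict):
        return encoder(x['observation'])
    return encoder(x)
```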
Description
Related Issue
TODO
Check List