WIP: feature(pu): adapt to unizero-multitask ddp, and adapt ppo to support jericho config #858

Open
wants to merge 8 commits into base: main

Conversation

puyuan1996
Collaborator

Description

Related Issue

TODO

Check List

  • merge the latest version of the source branch/repo and resolve all conflicts
  • pass style check
  • pass all the tests

if param.grad is not None:
    allreduce(param.grad.data)
else:
    # If the gradient is None, create a zero tensor with the same shape as the
    # parameter and allreduce it, so every rank still joins the collective.
    zero_grad = torch.zeros_like(param.data)
    allreduce(zero_grad)
Member

remove the commented-out code and add an English comment; then these modifications can be merged

# dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# TODO:
import datetime
# Initialize the process group with an enlarged 60000-second timeout instead of the default.
dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=60000))
Member

why add this?

# if self._rank == 0:
#     self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)

self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
Member

add an argument named only_monitor_rank0 to control the logic, defaulting to True
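
A minimal sketch of what this could look like, assuming the learner class from the diff above; the helper name _setup_monitor is hypothetical, while get_simple_monitor_type and TickTime come from the existing code:

def _setup_monitor(self, only_monitor_rank0: bool = True) -> None:
    # With the default only_monitor_rank0=True, only rank 0 builds the monitor.
    if only_monitor_rank0 and self._rank != 0:
        self._monitor = None
        return
    self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)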

for k in engine.log_buffer:
    engine.log_buffer[k].clear()
return
# if engine.rank != 0:
Member

also pass the only_monitor_rank0 argument to the hook class
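
A rough sketch of the hook side under the same assumption; the class name LogShowHook and the constructor signature are illustrative, while engine.rank and engine.log_buffer are taken from the snippet above:

class LogShowHook:
    def __init__(self, only_monitor_rank0: bool = True) -> None:
        self._only_monitor_rank0 = only_monitor_rank0

    def __call__(self, engine) -> None:
        if self._only_monitor_rank0 and engine.rank != 0:
            # Non-zero ranks simply drop their buffered logs instead of printing them.
            for k in engine.log_buffer:
                engine.log_buffer[k].clear()
            return
        # Rank 0 (or every rank, when the flag is disabled) continues with normal logging.
        ...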

    self._global_state_encoder = nn.Identity()
elif len(global_obs_shape) == 3:
    self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
    self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
Member

why use BN here rather than LN as the default?
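
If LN is preferred, the change would just swap the norm_type argument (assuming 'LN' is an accepted norm_type for ConvEncoder, as for other DI-engine encoders):

self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='LN')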

agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
agent_state = agent_state.unsqueeze(0)
if single_step and len(global_state.shape) == 2:
    global_state = global_state.unsqueeze(0)
Member

add shape comments
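
One possible way to annotate the shapes; the T/B/A/N dimension names follow the usual QMix convention and are an assumption here:

# agent_state: (B, A, N) -> (T=1, B, A, N) in single-step mode
agent_state = agent_state.unsqueeze(0)
if single_step and len(global_state.shape) == 2:
    # global_state: (B, M) -> (T=1, B, M); M is the flat global observation size
    global_state = global_state.unsqueeze(0)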

@@ -205,7 +214,10 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
    agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
    agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
    if self.mixer:
        global_state_embedding = self._global_state_encoder(global_state)
        if len(global_state.shape) == 5:
Member

add some comments
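
For example, the 5-dimensional case could be documented and handled as below; reading it as an image-like (T, B, C, H, W) global state is an assumption based on the ConvEncoder path introduced above:

if len(global_state.shape) == 5:
    # Image-like global state: (T, B, C, H, W). Merge T and B so the conv encoder
    # sees a standard 4D batch, then restore the (T, B, ...) layout afterwards.
    T, B, C, H, W = global_state.shape
    global_state_embedding = self._global_state_encoder(global_state.reshape(T * B, C, H, W)).reshape(T, B, -1)
else:
    global_state_embedding = self._global_state_encoder(global_state)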

@@ -265,12 +265,17 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
        >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
    """
    if self.share_encoder:
        x = self.encoder(x)
    # import ipdb; ipdb.set_trace()
Member

modify the corresponding API comments, and use isinstance(x, dict) to control the logic
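
A rough sketch of that control flow; the dict layout (an 'observation' key holding the raw tensor) is an assumption for illustration, not the actual jericho observation format:

if self.share_encoder:
    if isinstance(x, dict):
        # Dict-style observation: encode only the raw observation tensor.
        x = self.encoder(x['observation'])
    else:
        x = self.encoder(x)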

Labels
None yet
Projects
None yet
2 participants