feature(xrk): add q-transformer #783

Open
rongkunxue wants to merge 40 commits into opendilab:main from q_transformner

Changes from 1 commit

Commits (40)
c0416af
make it can use
rongkunxue Mar 22, 2024
8ab5da8
change config to fit
rongkunxue Mar 28, 2024
b12714e
good use
rongkunxue Mar 28, 2024
066ff45
change all framework
rongkunxue Mar 29, 2024
5988d14
good use for eval
rongkunxue Apr 2, 2024
0875c3f
add q_value
rongkunxue Apr 2, 2024
cf51545
change action_bin to 8 with best control; init q weight for middle ou…
rongkunxue Apr 10, 2024
90b3dbb
Merge branch 'opendilab:main' into q_transformner
rongkunxue Apr 10, 2024
0446efe
Merge branch 'opendilab:main' into q_transformner
rongkunxue Apr 15, 2024
f309121
polish code
rongkunxue Apr 15, 2024
8eff2ef
change it
rongkunxue Apr 15, 2024
191fe53
polish code for init
rongkunxue Apr 15, 2024
33554e7
polish config
rongkunxue Apr 15, 2024
81bea50
add more high and low with action_bin
rongkunxue Apr 15, 2024
4fe9db0
polish import
rongkunxue Apr 15, 2024
1839ded
polish import
rongkunxue Apr 15, 2024
be60d5c
Merge branch 'opendilab:main' into q_transformner
rongkunxue Apr 23, 2024
4e5dd58
Merge branch 'opendilab:main' into q_transformner
rongkunxue Jun 18, 2024
0e71001
add dataset for update
rongkunxue Jun 19, 2024
6023c65
add init
rongkunxue Jun 19, 2024
7095b38
polish qtransformer
rongkunxue Jun 20, 2024
ad1ccb1
episode
rongkunxue Jun 20, 2024
660a038
polish
rongkunxue Jun 20, 2024
68003c8
polish
rongkunxue Jun 20, 2024
4b228cb
polish
rongkunxue Jun 20, 2024
8e97624
polish
rongkunxue Jun 20, 2024
54688fa
polish
rongkunxue Jun 20, 2024
d8b3868
polish
rongkunxue Jun 20, 2024
509cd5a
polish
rongkunxue Jun 21, 2024
6e3cf36
polish
rongkunxue Jun 21, 2024
d536ab1
polish
rongkunxue Jun 21, 2024
0b54465
polish
rongkunxue Jun 21, 2024
140b70f
Merge branch 'opendilab:main' into q_transformner
rongkunxue Jul 1, 2024
c76e9b3
polish online
rongkunxue Jul 1, 2024
44d746e
polish to d4rl dataset
rongkunxue Jul 1, 2024
5d59b3d
add
rongkunxue Jul 4, 2024
b784bb2
add
rongkunxue Jul 4, 2024
f35338b
polish
rongkunxue Jul 4, 2024
7c8d64f
polish
rongkunxue Jul 17, 2024
a057051
make more head for the task
rongkunxue Jul 18, 2024
Commit 5988d141566b36737314cbd91d13d25cd661e425: good use for eval
rongkunxue committed Apr 2, 2024 (signed with the committer’s verified signature)
78 changes: 16 additions & 62 deletions ding/model/template/qtransformer.py
@@ -446,62 +446,30 @@ def state_append_actions(self,state,actions:Optional[Tensor] = None):
     def get_optimal_actions(
         self,
         encoded_state,
-        return_q_values = False,
         actions: Optional[Tensor] = None,
-        prob_random_action: float = 0.5,
         **kwargs
     ):
-        batch = encoded_state.shape[0]
-
-        if prob_random_action == 1:
-            return self.get_random_actions(batch)
-        prob_random_action = -1
-        sos_token = encoded_state
-        tokens = self.maybe_append_actions(sos_token, actions = actions)
-
-        action_bins = []
+        batch_size = encoded_state.shape[0]
+        action_bins = torch.empty(batch_size, self.num_actions, device=encoded_state.device,dtype=torch.long)
         cache = None
+        tokens = self.state_append_actions(encoded_state, actions = actions)
 
         for action_idx in range(self.num_actions):
-
             embed, cache = self.transformer(
                 tokens,
-                context = encoded_state,
+                context = None,
                 cache = cache,
                 return_cache = True
             )
-
-            last_embed = embed[:, action_idx]
-            bin_embeddings = self.action_bin_embeddings[action_idx]
-
-            q_values = einsum('b d, a d -> b a', last_embed, bin_embeddings)
-
-            selected_action_bins = q_values.argmax(dim = -1)
-
-            if prob_random_action > 0.:
-                random_mask = torch.zeros_like(selected_action_bins).float().uniform_(0., 1.) < prob_random_action
-                random_actions = self.get_random_actions(batch, 1)
-                random_actions = rearrange(random_actions, '... 1 -> ...')
-
-                selected_action_bins = torch.where(
-                    random_mask,
-                    random_actions,
-                    selected_action_bins
-                )
-
-            next_action_embed = bin_embeddings[selected_action_bins]
-
-            tokens, _ = pack((tokens, next_action_embed), 'b * d')
-
-            action_bins.append(selected_action_bins)
-
-        action_bins = torch.stack(action_bins, dim = -1)
-
-        if not return_q_values:
-            return action_bins
-
-        all_q_values = self.get_q_values(embed)
-        return action_bins, all_q_values
+            q_values = self.get_q_value_fuction(embed[:, 1:, :])
+            if action_idx ==0 :
+                special_idx=action_idx
+            else :
+                special_idx=action_idx-1
+            _, selected_action_indices = q_values[:,special_idx,:].max(dim=-1)
+            action_bins[:, action_idx] = selected_action_indices
+            now_actions=action_bins[:,0:action_idx+1]
+            tokens = self.state_append_actions(encoded_state, actions = now_actions)
+        return action_bins
 
     def forward(
         self,
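The rewritten loop above decodes the action bins autoregressively: for each action dimension it asks the transformer for per-bin Q-values conditioned on the state and the bins already selected, takes the argmax, writes it into action_bins, and rebuilds the token sequence with state_append_actions before decoding the next dimension. Below is a minimal self-contained sketch of that greedy per-dimension pattern; toy_q_values and all shapes are illustrative stand-ins, not code from this PR.

import torch

def toy_q_values(state_emb: torch.Tensor, chosen_bins: torch.Tensor, num_bins: int) -> torch.Tensor:
    # Stand-in for the transformer + Q head: per-bin scores for the next action
    # dimension, conditioned on the state embedding and the bins chosen so far.
    seed = state_emb.sum(dim=-1, keepdim=True) + chosen_bins.float().sum(dim=-1, keepdim=True)
    return torch.sin(seed + torch.arange(num_bins, dtype=torch.float32))  # (batch, num_bins)

def greedy_decode_actions(state_emb: torch.Tensor, num_actions: int, num_bins: int) -> torch.Tensor:
    # Same pattern as the diffed loop: fill one action dimension at a time,
    # re-conditioning on everything selected so far.
    batch = state_emb.shape[0]
    action_bins = torch.zeros(batch, num_actions, dtype=torch.long)
    for action_idx in range(num_actions):
        q = toy_q_values(state_emb, action_bins[:, :action_idx], num_bins)  # (batch, num_bins)
        action_bins[:, action_idx] = q.argmax(dim=-1)  # greedy bin for this dimension
    return action_bins

bins = greedy_decode_actions(torch.randn(4, 16), num_actions=6, num_bins=256)
print(bins.shape)  # torch.Size([4, 6])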
@@ -585,28 +553,14 @@ def embed_texts(self, texts: List[str]):
         return self.conditioner.embed_texts(texts)
 
     @torch.no_grad()
-    def get_optimal_actions(
+    def get_actions(
         self,
         state,
-        return_q_values = False,
-        actions: Optional[Tensor] = None,
         **kwargs
     ):
         encoded_state = self.state_encode(state)
-        return self.q_head.get_optimal_actions(encoded_state, return_q_values = return_q_values, actions = actions)
-
-    def get_actions(
-        self,
-        state,
-        prob_random_action = 0., # otherwise known as epsilon in RL
-        **kwargs,
-    ):
-        batch_size = state.shape[0]
-        assert 0. <= prob_random_action <= 1.
+        return self.q_head.get_optimal_actions(encoded_state)
-
-        if random() < prob_random_action:
-            return self.get_random_actions(batch_size = batch_size)
-        return self.get_optimal_actions(state, **kwargs)
 
     def forward(
         self,
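After this hunk the model exposes a single greedy entry point, get_actions, and the model-level epsilon-greedy branch (prob_random_action) is gone. If random exploration is still wanted it would have to be applied by the caller; the sketch below shows one way that could look, assuming hypothetical epsilon, num_actions, and num_bins arguments that are not part of this PR.

import torch
from random import random

def epsilon_greedy_actions(model, obs: torch.Tensor, epsilon: float,
                           num_actions: int, num_bins: int) -> torch.Tensor:
    # Hypothetical caller-side exploration, replacing the model-level
    # prob_random_action branch deleted above.
    assert 0.0 <= epsilon <= 1.0
    if random() < epsilon:
        # uniformly random bin index per action dimension
        return torch.randint(0, num_bins, (obs.shape[0], num_actions))
    # otherwise defer to the greedy per-dimension argmax shown in the diff
    return model.get_actions(obs)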
6 changes: 3 additions & 3 deletions ding/policy/qtransformer.py
@@ -414,7 +414,7 @@ def _discretize_action(self, actions):
 
     def _get_actions(self, obs):
         # evaluate to get action
-        action = self._eval_model.get_optimal_actions(obs)
+        action = self._target_model.get_actions(obs)
         action = 2*action/256.0-1
         return action

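The action = 2*action/256.0-1 line above converts selected bin indices back to continuous actions: with 256 bins, index b maps to 2*b/256 - 1, the left edge of its cell in [-1, 1). A small sketch of that mapping follows; the inverse shown here is only an illustrative assumption, not code copied from _discretize_action.

import torch

NUM_BINS = 256  # implied by the 2*action/256.0 - 1 line above

def bins_to_continuous(bins: torch.Tensor) -> torch.Tensor:
    # bin index in [0, 255] -> left cell edge in [-1, 1)
    return 2 * bins.float() / NUM_BINS - 1

def continuous_to_bins(actions: torch.Tensor) -> torch.Tensor:
    # assumed inverse: clamp to [-1, 1] and quantize back to an index in [0, 255]
    scaled = (actions.clamp(-1, 1) + 1) / 2 * NUM_BINS
    return scaled.long().clamp(max=NUM_BINS - 1)

b = torch.tensor([0, 128, 255])
print(bins_to_continuous(b))                      # tensor([-1.0000,  0.0000,  0.9922])
print(continuous_to_bins(bins_to_continuous(b)))  # tensor([  0, 128, 255])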
@@ -442,8 +442,8 @@ def _state_dict_learn(self) -> Dict[str, Any]:
             - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
         """
         ret = {
-            'model': self._model.state_dict(),
-            'ema_model': self._ema_model.state_dict(),
+            'model': self._learn_model.state_dict(),
+            'ema_model': self._target_model.state_dict(),
             'optimizer_q': self._optimizer_q.state_dict(),
         }
         if self._auto_alpha:
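For reference, the corrected _state_dict_learn produces a checkpoint keyed by 'model', 'ema_model', and 'optimizer_q'. The snippet below only illustrates that layout with stand-in modules; it is not code from the PR, and the restore pairing at the end is an assumption about how a matching load path would consume the keys, since none is shown in this diff.

import torch
import torch.nn as nn

# Stand-in modules only; nn.Linear takes the place of the actual learn/target Q-transformer.
learn_model, target_model = nn.Linear(4, 2), nn.Linear(4, 2)
optimizer_q = torch.optim.Adam(learn_model.parameters())

# Checkpoint layout matching the corrected _state_dict_learn above.
ckpt = {
    'model': learn_model.state_dict(),
    'ema_model': target_model.state_dict(),
    'optimizer_q': optimizer_q.state_dict(),
}

# Assumed restore pairing: each component reloads from the key it was saved under.
learn_model.load_state_dict(ckpt['model'])
target_model.load_state_dict(ckpt['ema_model'])
optimizer_q.load_state_dict(ckpt['optimizer_q'])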