Skip to content

Commit

Permalink
Merge pull request #598 from KafuuChikai/main
Browse files Browse the repository at this point in the history
[MISC] Update Drone Entity and Training Performance Enhancements
  • Loading branch information
Genesis-Embodied-AI authored Jan 18, 2025
2 parents 2d67eed + 3400463 commit 512ad1a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
12 changes: 10 additions & 2 deletions examples/drone/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ Train the drone hovering policy using the `HoverEnv` environment.
Run with:

```bash
python hover_train.py -e drone-hovering -B 8192 --max_iterations 500
python hover_train.py -e drone-hovering -B 8192 --max_iterations 300
```

Train with visualization:

```bash
python hover_train.py -e drone-hovering -B 8192 --max_iterations 300 -v
```

#### 3.2 Evaluation
Expand All @@ -68,12 +74,14 @@ Evaluate the trained drone hovering policy.
Run with:

```bash
python hover_eval.py -e drone-hovering --ckpt 500 --record
python hover_eval.py -e drone-hovering --ckpt 300 --record
```

**Note**: If you experience slow performance or encounter other issues
during evaluation, try removing the `--record` option.

For the latest updates, detailed documentation, and additional resources, visit this repository: [GenesisDroneEnv](https://github.com/KafuuChikai/GenesisDroneEnv).

## Technical Details

- The drone model used is the Crazyflie 2.X (`urdf/drones/cf2x.urdf`)
Expand Down
5 changes: 1 addition & 4 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,7 @@ def _at_target(self):

def step(self, actions):
self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
exec_actions = self.actions.cpu()
# exec_actions = self.last_actions.cpu() if self.simulate_action_latency else self.actions.cpu()
# target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
# self.drone.control_dofs_position(target_dof_pos)
exec_actions = self.actions

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
Expand Down
2 changes: 1 addition & 1 deletion examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("--ckpt", type=int, default=500)
parser.add_argument("--ckpt", type=int, default=300)
parser.add_argument("--record", action="store_true", default=False)
args = parser.parse_args()

Expand Down
4 changes: 2 additions & 2 deletions examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_train_cfg(exp_name, max_iterations):
"load_run": -1,
"log_interval": 1,
"max_iterations": max_iterations,
"num_steps_per_env": 24,
"num_steps_per_env": 100,
"policy_class_name": "ActorCritic",
"record_interval": -1,
"resume": False,
Expand Down Expand Up @@ -111,7 +111,7 @@ def main():
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("-v", "--vis", action="store_true", default=False)
parser.add_argument("-B", "--num_envs", type=int, default=8192)
parser.add_argument("--max_iterations", type=int, default=500)
parser.add_argument("--max_iterations", type=int, default=300)
args = parser.parse_args()

gs.init(logging_level="error")
Expand Down
23 changes: 15 additions & 8 deletions genesis/engine/entities/drone_entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import xml.etree.ElementTree as etxml

import numpy as np
import torch
import taichi as ti

import genesis as gs
Expand All @@ -12,7 +12,6 @@

@ti.data_oriented
class DroneEntity(RigidEntity):

def _load_URDF(self, morph, surface):
super()._load_URDF(morph, surface)

Expand All @@ -25,32 +24,40 @@ def _load_URDF(self, morph, surface):
self._COM_link_idx = self.get_link(morph.COM_link_name).idx

propellers_links = gs.List([self.get_link(name) for name in morph.propellers_link_names])
self._propellers_link_idxs = np.array([link.idx for link in propellers_links], dtype=gs.np_int)
self._propellers_link_idxs = torch.tensor(
[link.idx for link in propellers_links], dtype=gs.tc_int, device=gs.device
)
try:
self._propellers_vgeom_idxs = np.array([link.vgeoms[0].idx for link in propellers_links], dtype=gs.np_int)
self._propellers_vgeom_idxs = torch.tensor(
[link.vgeoms[0].idx for link in propellers_links], dtype=gs.tc_int, device=gs.device
)
self._animate_propellers = True
except Exception:
gs.logger.warning("No visual geometry found for propellers. Skipping propeller animation.")
self._animate_propellers = False

self._propellers_spin = np.array(morph.propellers_spin, dtype=gs.np_float)
self._propellers_spin = torch.tensor(morph.propellers_spin, dtype=gs.tc_float, device=gs.device)
self._model = morph.model

def _build(self):
super()._build()

self._propellers_revs = np.zeros(self._solver._batch_shape(self._n_propellers), dtype=gs.np_float)
self._propellers_revs = torch.zeros(
self._solver._batch_shape(self._n_propellers), dtype=gs.tc_float, device=gs.device
)
self._prev_prop_t = None

def set_propellels_rpm(self, propellels_rpm):
if self._prev_prop_t == self.sim.cur_step_global:
gs.raise_exception("`set_propellels_rpm` can only be called once per step.")
self._prev_prop_t = self.sim.cur_step_global

propellels_rpm = self.solver._process_dim(np.array(propellels_rpm, dtype=gs.np_float)).T
propellels_rpm = self.solver._process_dim(
torch.as_tensor(propellels_rpm, dtype=gs.tc_float, device=gs.device)
).T.contiguous()
if len(propellels_rpm) != len(self._propellers_link_idxs):
gs.raise_exception("Last dimension of `propellels_rpm` does not match `entity.n_propellers`.")
if np.any(propellels_rpm < 0):
if torch.any(propellels_rpm < 0):
gs.raise_exception("`propellels_rpm` cannot be negative.")
self._propellers_revs = (self._propellers_revs + propellels_rpm) % (60 / self.solver.dt)

Expand Down

0 comments on commit 512ad1a

Please sign in to comment.