I have a question about the ReMax implementation in ray_trainer.py.
The ReMax branch is placed after
batch = batch.union(gen_batch_output).
Therefore, when you execute
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
the input_ids actually passed into generate_sequences consist of
prompt_ids + response_ids.
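
To make the concern concrete, here is a toy sketch in plain Python. It is not the actual trainer or batch code: toy_generate and toy_union are hypothetical stand-ins, and the dict layout is an assumption made only for illustration. It shows that the length of input_ids seen by the baseline generation depends on whether the baseline batch is copied from the prompt-only gen_batch or taken from the batch after the union.

```python
from copy import deepcopy


def toy_generate(batch):
    """Hypothetical stand-in for generate_sequences: appends two fake response ids."""
    prompt = batch["input_ids"]
    response = [101, 102]
    # The returned input_ids hold prompt + response; the input batch is not mutated.
    return {"prompts": prompt, "responses": response, "input_ids": prompt + response}


def toy_union(a, b):
    """Hypothetical stand-in for union: merge the keys of b into a copy of a."""
    merged = dict(a)
    merged.update(b)
    return merged


gen_batch = {"input_ids": [1, 2, 3]}  # prompt only (assumed layout)
batch = {"uid": "sample-0"}           # remaining fields (assumed layout)

gen_batch_output = toy_generate(gen_batch)
batch = toy_union(batch, gen_batch_output)

# Case A: baseline batch copied from the prompt-only gen_batch.
baseline_from_gen_batch = deepcopy(gen_batch)

# Case B: baseline batch built from the batch after the union.
baseline_from_merged = {"input_ids": batch["input_ids"]}

print(len(baseline_from_gen_batch["input_ids"]))  # 3 (prompt only)
print(len(baseline_from_merged["input_ids"]))     # 5 (prompt + response)
```

In this sketch only Case B reproduces the behavior I am describing, so what matters is which batch gen_baseline_batch is derived from.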