
Input contamination in baseline rollout (prompt+response fed into generation) #19

@Dhs2004

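I have a question about the ReMax baseline rollout.
In your `ray_trainer.py`, the ReMax implementation is placed after
`batch = batch.union(gen_batch_output)`.
As a result, when
`gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)`
executes, the `input_ids` actually passed into `generate_sequences` consist of `prompt_ids + response_ids` rather than the prompt alone. The baseline response is therefore generated as a continuation of the sampled response instead of from the original prompt.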
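To make the ordering problem concrete, here is a minimal, self-contained sketch that uses plain Python dicts in place of verl's `DataProto`. Everything in it is hypothetical: `fake_generate` stands in for `generate_sequences` (it appends two dummy response tokens to whatever `input_ids` it receives), `union` is a simplified merge rather than the real `DataProto.union`, and `run_remax_step` is an illustration, not the repository's code. It contrasts taking the baseline input from the merged batch (the ordering described above) with snapshotting the prompt-only batch before the union.

```python
from copy import deepcopy
from typing import Dict, List

Batch = Dict[str, list]


def fake_generate(batch: Batch, response_len: int = 2) -> Batch:
    """Hypothetical stand-in for generate_sequences: treats the incoming
    input_ids as the prompt and appends dummy response tokens to it."""
    prompts: List[List[int]] = batch["input_ids"]
    responses = [[900 + i for i in range(response_len)] for _ in prompts]
    return {
        "prompts": deepcopy(prompts),
        "responses": responses,
        "input_ids": [p + r for p, r in zip(prompts, responses)],
    }


def union(a: Batch, b: Batch) -> Batch:
    """Simplified union (assumption): keys from b are added to a copy of a."""
    merged = dict(a)
    merged.update(b)
    return merged


def run_remax_step(snapshot_before_union: bool) -> Batch:
    batch: Batch = {"uid": ["sample-0"]}
    gen_batch: Batch = {"input_ids": [[1, 2, 3]]}  # prompt-only
    # Snapshot the prompt-only batch before any union with rollout output.
    gen_baseline_batch = deepcopy(gen_batch)

    gen_batch_output = fake_generate(gen_batch)
    batch = union(batch, gen_batch_output)  # input_ids is now prompt + response

    if not snapshot_before_union:
        # Ordering described in the issue: baseline input taken from the merged batch.
        gen_baseline_batch = {"input_ids": deepcopy(batch["input_ids"])}

    return fake_generate(gen_baseline_batch)


print(run_remax_step(snapshot_before_union=False)["prompts"])  # [[1, 2, 3, 900, 901]]
print(run_remax_step(snapshot_before_union=True)["prompts"])   # [[1, 2, 3]]
```

In the first call the baseline "prompt" already contains the sampled response tokens `900, 901`. In the second, the baseline is generated from the original prompt only, which is what a ReMax baseline is meant to condition on.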
