The N Implementation Details of RLHF with PPO

Reinforcement Learning from Human Feedback (RLHF) is pivotal in the modern application of language modeling, as exemplified by ChatGPT. This blog post delves into an in-depth exploration of RLHF, attempting to reproduce the results from OpenAI's inaugural RLHF paper, published in 2019. Our detailed examination provides valuable insights into the implementation details of RLHF, which often go unnoticed.

Reinforcement Learning from Human Feedback (RLHF) has been an impactful technique for training modern language models such as ChatGPT. In our quest to learn more about RLHF, this blog post closely examines OpenAI’s inaugural RLHF paper published in 2019 together with its open-source codebase, available at openai/lm-human-preferences. Despite being based on TensorFlow 1.x, the codebase released by OpenAI is thoroughly evaluated and benchmarked, making it a good place to study the engineering details of RLHF implementations.

We aim to:

  1. reproduce OpenAI’s results in stylistic tasks and match the learning curves of openai/lm-human-preferences, using the modern PyTorch and JAX frameworks in conjunction with HuggingFace Transformers that are predominantly used by the open-source community nowadays;
  2. present a checklist of implementation details, similar to the spirit of The 37 Implementation Details of Proximal Policy Optimization and Debugging RL, Without the Agonizing Pain;
  3. provide a simple-to-read and minimal reference implementation of RLHF.

This work is just for educational / learning purposes. For advanced users requiring more features, such as running larger models with parameter-efficient fine-tuning, huggingface/trl would be a great choice.

Here are the important links:

Matching Learning Curves

Our main contribution is to reproduce OpenAI’s results in stylistic tasks, such as sentiment and descriptiveness. As shown in the figure below, our codebase (orange curves) can produce nearly identical learning curves as OpenAI’s codebase (blue curves).

A note on running openai/lm-human-preferences

To make a direct comparison, we ran the original RLHF code at openai/lm-human-preferences, which will offer valuable metrics to help validate and diagnose our reproduction. We were able to set the original TensorFlow 1.x code up, but it requires a hyper-specific setup:

General Implementation Details

We now take a technical deep dive into the implementation details that are relevant to reproducing OpenAI’s work. In this section, we talk about basic details, such as how rewards/values are generated and how responses are generated. Here are these details in no particular order:

  1. The reward model and policy’s value head take input as the concatenation of query and response
    1. The reward model and the policy’s value head do not look only at the response. Instead, they take the concatenation of the query and response, query_response, as input (lm_human_preferences/rewards.py#L105-L107).
    2. So, for example, if query = "he was quiet for a minute, his eyes unreadable." and response = "He looked at his left hand, which held the arm that held his arm out in front of him.", then the reward model and the policy’s value head do a forward pass on query_response = "he was quiet for a minute, his eyes unreadable. He looked at his left hand, which held the arm that held his arm out in front of him." and produce rewards and values of shape (B, T, 1), where B is the batch size, T is the sequence length, and 1 is the output dimension of the reward or value head (lm_human_preferences/rewards.py#L105-L107, lm_human_preferences/policy.py#L111).
    3. The T means that each token has a reward associated with it and its previous context. For example, the eyes token would have a reward corresponding to he was quiet for a minute, his eyes.
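The shapes described above can be illustrated with a toy sketch. It is not OpenAI's model: the embedding lookup below is a hypothetical stand-in for the transformer's hidden states (a real transformer would make each position depend on its preceding context), and all sizes are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
B, query_len, response_len, d_model, vocab = 2, 3, 4, 8, 100  # made-up sizes

# Arbitrary token ids for the queries and responses.
query = rng.integers(0, vocab, size=(B, query_len))
response = rng.integers(0, vocab, size=(B, response_len))

# The reward model sees the query and response joined along the time axis.
query_response = np.concatenate([query, response], axis=1)  # (B, T)
T = query_response.shape[1]

# Stand-in for the final hidden states of the transformer: (B, T, d_model).
embedding = rng.normal(size=(vocab, d_model))
hidden_states = embedding[query_response]

# A scalar head maps every position to one number: (B, T, 1).
head_weight = rng.normal(size=(d_model, 1))
head_bias = np.zeros(1)
rewards = hidden_states @ head_weight + head_bias

print(query_response.shape, rewards.shape)
```

The point is only that the head is applied at every position of the concatenated sequence, so one scalar per token comes out.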
  2. Pad with a special padding token and truncate inputs.
    1. OpenAI sets a fixed input length for queries, query_length; it pads sequences that are too short with pad_token (lm_human_preferences/language/datasets.py#L66-L67) and truncates sequences that are too long (lm_human_preferences/language/datasets.py#L57). (See here for a general introduction to the concept.) When padding the inputs, OpenAI uses a token beyond the vocabulary (lm_human_preferences/language/encodings.py#L56).
      1. Note on HF’s transformers (padding token): according to (transformers#2630#issuecomment-578159876), padding tokens were not used during the pre-training of GPT and GPT-2; therefore, the gpt2 models in transformers have no official padding token associated with their tokenizer. A common practice is to set tokenizer.pad_token = tokenizer.eos_token, but in this work, we distinguish these two special tokens to match OpenAI’s original setting, so we use tokenizer.add_special_tokens({"pad_token": "[PAD]"}).

      Note that having no padding token is the default for decoder models, since they are pretrained with “packing”: many sequences are concatenated, separated by the EOS token, and split into chunks that always have the maximum length before being fed to the model.

    2. When putting everything together, here is an example:
     import transformers
     tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
     tokenizer.add_special_tokens({"pad_token": "[PAD]"})
     query_length = 5
     texts = [
         "usually, he would",
         "she thought about it",
     ]    
     tokens = []
     for text in texts:
         tokens.append(tokenizer.encode(text)[:query_length])
        
     print("tokens", tokens)
     inputs = tokenizer.pad(
         {"input_ids": tokens},
         padding="max_length",
         max_length=query_length,
         return_tensors="pt",
         return_attention_mask=True,
     )
     print("inputs", inputs)
        
     """prints are
     tokens [[23073, 11, 339, 561], [7091, 1807, 546, 340]]
     inputs {'input_ids': tensor([[23073,    11,   339,   561, 50257],
             [ 7091,  1807,   546,   340, 50257]]), 'attention_mask': tensor([[1, 1, 1, 1, 0],
             [1, 1, 1, 1, 0]])}
     """
    
  3. Adjust position indices correspondingly for padding tokens
    1. When calculating the logits, OpenAI’s code masks out padding tokens properly. This is achieved by finding the token indices corresponding to the padding tokens (lm_human_preferences/language/model.py#L296-L297) and then adjusting their position indices correspondingly (lm_human_preferences/language/model.py#L320).
    2. For example, if query=[23073, 50259, 50259] and response=[11, 339, 561] (where 50259 is OpenAI’s padding token), it creates position indices [[0 1 1 1 2 3]] and logits as follows. Note how the logits corresponding to the padding tokens remain the same as before! This is the effect we should be aiming for in our reproduction.

       all_logits [[[ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
           -35.36577 ]
         [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
           -35.36577 ]
         [ -35.28693   -34.2875    -38.16074  ...  -41.595802  -41.082108
           -35.36577 ]
         [-111.303955 -110.94471  -112.90624  ... -113.13064  -113.7788
          -109.17345 ]
         [-111.51512  -109.61077  -114.90231  ... -118.43514  -111.56671
          -112.12478 ]
         [-122.69775  -121.84468  -128.27417  ... -132.28055  -130.39604
          -125.707756]]] (1, 6, 50257)
      
    3. Note on HF’s transformers (position_ids and padding_side): we can replicate the exact logits using Hugging Face’s transformers by 1) left-padding the query and 2) passing in the appropriate position_ids:

       import torch
       import transformers
       tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
       tokenizer.add_special_tokens({"pad_token": "[PAD]"})
       pad_id = tokenizer.pad_token_id
       query = torch.tensor([
           [pad_id, pad_id, 23073],
       ])
       response = torch.tensor([
           [11, 339, 561],
       ])
       temperature = 1.0
              
       # query and response are already integer tensors, so no conversion is needed
       context_length = query.shape[1]
       query_response = torch.cat((query, response), 1)
       pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
       def forward(policy, query_responses, tokenizer):
           attention_mask = query_responses != tokenizer.pad_token_id
           position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
           input_ids = query_responses.clone()
           input_ids[~attention_mask] = 0
           return policy(
               input_ids=input_ids,
               attention_mask=attention_mask,
               position_ids=position_ids,
               return_dict=True,
               output_hidden_states=True,
           )
       output = forward(pretrained_model, query_response, tokenizer)
       logits = output.logits
       logits /= temperature
       print(logits)
              
       """
       tensor([[[ -26.9395,  -26.4709,  -30.0456,  ...,  -33.2208,  -33.2884,
                  -27.4360],
                [ -27.1677,  -26.7330,  -30.2386,  ...,  -33.6813,  -33.6931,
                  -27.5928],
                [ -35.2869,  -34.2875,  -38.1608,  ...,  -41.5958,  -41.0821,
                  -35.3658],
                [-111.3040, -110.9447, -112.9062,  ..., -113.1306, -113.7788,
                 -109.1734],
                [-111.5152, -109.6108, -114.9024,  ..., -118.4352, -111.5668,
                 -112.1248],
                [-122.6978, -121.8447, -128.2742,  ..., -132.2805, -130.3961,
                 -125.7078]]], grad_fn=<DivBackward0>)
       """
      
    4. Note on HF’s transformers (position_ids during generate): during generate we should not pass in position_ids because the position_ids are already adjusted in transformers (see huggingface/transformers#/7552).

    Usually, we almost never pass position_ids in transformers. All the masking and shifting logic is already implemented, e.g., in the generate function.

  4. Response generation samples a fixed-length response without padding.
    1. During response generation, OpenAI uses top_k=0, top_p=1.0 and simply samples from the categorical distribution over the whole vocabulary (lm_human_preferences/language/sample.py#L43); the code keeps sampling until a fixed-length response is generated (lm_human_preferences/policy.py#L103). Notably, even if it encounters EOS (end-of-sequence) tokens, it keeps sampling.
    2. Note on HF’s transformers (sampling could stop at eos_token): in transformers, generation can stop at eos_token (src/transformers/generation/utils.py#L2248-L2256), which is not the same as OpenAI’s setting. To align the setting, we need to set pretrained_model.generation_config.eos_token_id = None and pretrained_model.generation_config.pad_token_id = None. Note that transformers.GenerationConfig(eos_token_id=None, pad_token_id=None, ...) does not work because pretrained_model.generation_config would override it and set an eos_token.

       import torch
       import transformers
       tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
       tokenizer.add_special_tokens({"pad_token": "[PAD]"})
       pad_id = tokenizer.pad_token_id
       query = torch.tensor([
           [pad_id, pad_id, 23073],
       ])
       response = torch.tensor([
           [11, 339, 561],
       ])
       response_length = 4
       temperature = 0.7
       pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
       pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to
       pretrained_model.generation_config.pad_token_id = None  # generate tokens without truncation / padding
       generation_config = transformers.GenerationConfig(
           max_new_tokens=response_length,
           min_new_tokens=response_length,
           temperature=temperature,
           top_k=0,  # an integer; 0 disables top-k filtering
           top_p=1.0,
           do_sample=True,
       )
       context_length = query.shape[1]
       attention_mask = query != tokenizer.pad_token_id
       input_ids = query.clone()
       input_ids[~attention_mask] = 0  # set padding tokens to 0
       output = pretrained_model.generate(
           input_ids=input_ids,
           attention_mask=attention_mask,
           # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on.
           generation_config=generation_config,
           return_dict_in_generate=True,
       )
       print(output.sequences)
              
       """
       tensor([[    0,     0, 23073, 16851,    11,   475,   991]])
       """
      
    3. Note that in a more recent codebase, https://github.com/openai/summarize-from-feedback, OpenAI does stop sampling when encountering the EOS token (summarize_from_feedback/utils/experiment_helpers.py#L19). However, in this work we aim for a 1:1 replication, so we keep the setting in which sampling continues even after eos_token is encountered.
  5. Learning rate annealing for reward model and policy training.
    1. As Ziegler et al. (2019) suggested, the reward model is trained for a single epoch to avoid overfitting the limited amount of human annotation data (e.g., the descriptiveness task only had about 5000 labels). During this single epoch, the learning rate is annealed to zero (lm_human_preferences/train_reward.py#L249).
    2. Similar to reward model training, the policy’s learning rate is annealed to zero (lm_human_preferences/train_policy.py#L172-L173).
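As a rough illustration of annealing the learning rate to zero, here is a minimal sketch. It assumes a linear decay, and the function name, base learning rate, and step count are all hypothetical rather than taken from OpenAI's code.

```python
def annealed_lr(base_lr, step, total_steps):
    """Linearly decay from base_lr at step 0 to zero at total_steps (linear decay is an assumption)."""
    frac = 1.0 - step / total_steps
    return base_lr * frac

base_lr = 1e-5    # hypothetical base learning rate
total_steps = 5   # hypothetical number of updates in the single epoch

schedule = [annealed_lr(base_lr, step, total_steps) for step in range(total_steps + 1)]
print(schedule)
```

In a PyTorch training loop, one way to apply such a schedule would be to overwrite `param_group["lr"]` on the optimizer before each update.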
  6. Use different seeds for different processes
    1. When spawning 8 GPU processes to do data parallelism, OpenAI sets a different random seed per process (lm_human_preferences/utils/core.py#L108-L111). Implementation-wise, this is done via local_seed = args.seed + process_rank * 100003. The seed is going to make the model produce different responses and get different scores, for example.
      1. Note: We believe the dataset shuffling has a bug — the dataset is shuffled using the same seed for some reason (lm_human_preferences/lm_tasks.py#L94-L97).
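The per-process seeding rule quoted above can be sketched as follows. The base seed and the use of Python's `random` module are assumptions for illustration; only the formula `seed + process_rank * 100003` comes from the text.

```python
import random

def local_seed(seed, process_rank):
    # Offset the shared base seed by a rank-dependent multiple of 100003.
    return seed + process_rank * 100003

base_seed = 1    # hypothetical value of args.seed
world_size = 8   # one process per GPU

seeds = [local_seed(base_seed, rank) for rank in range(world_size)]
print(seeds)

# Each process would seed its own generator, so sampled values differ across ranks.
first_draws = [random.Random(s).random() for s in seeds]
print(first_draws)
```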

Reward Model Implementation Details

In this section, we discuss reward-model-specific implementation details. We talk about details such as reward normalization and layer initialization. Here are these details in no particular order:

  1. The reward model only outputs the value at the last token.
    1. Notice that the rewards obtained after the forward pass on the concatenation of query and response will have the shape (B, T, 1), where B is the batch size, T is the sequence length (which is always the same; it is query_length + response_length = 64 + 24 = 88 in OpenAI’s setting for stylistic tasks, see launch.py#L9-L11), and 1 is the reward head dimension of 1. For RLHF purposes, the original codebase extracts the reward of the last token (lm_human_preferences/rewards.py#L132), so that the rewards will only have shape (B, 1).
    2. Note that in a more recent codebase openai/summarize-from-feedback, OpenAI stops sampling when encountering EOS token (summarize_from_feedback/utils/experiment_helpers.py#L19). When extracting rewards, it is going to identify the last_response_index, the index before the EOS token (#L11-L13), and extract the reward at that index (summarize_from_feedback/reward_model.py#L59). However in this work we just stick with the original setting.
  2. Reward head layer initialization
    1. The weight of the reward head is initialized according to \( \mathcal{N}\left(0,1 /\left(\sqrt{d_{\text {model }}+1}\right)\right) \) (lm_human_preferences/language/model.py#L368, lm_human_preferences/language/model.py#L251-L252). This aligns with the settings in Stiennon et al., 2020 (summarize_from_feedback/query_response_model.py#L106-L107). (P.S., Stiennon et al., 2020 had a typo on page 17 saying the distribution is \( \mathcal{N}\left(0,1 /\left(d_{\text {model }}+1\right)\right) \) without the square root.)
    2. The bias of the reward head is set to 0 (lm_human_preferences/language/model.py#L254).
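A minimal sketch of this initialization is shown below. It reads the second argument of \( \mathcal{N}(\cdot,\cdot) \) as a standard deviation, which is an assumption, and uses the GPT-2 small hidden size of 768; the variable names are illustrative.

```python
import numpy as np

d_model = 768  # hidden size of GPT-2 small
rng = np.random.default_rng(0)

# Assumption: 1 / sqrt(d_model + 1) is used as the standard deviation.
std = 1.0 / np.sqrt(d_model + 1)
reward_head_weight = rng.normal(loc=0.0, scale=std, size=(d_model, 1))
reward_head_bias = np.zeros(1)  # the bias starts at 0

print(std, reward_head_weight.std())
```

With PyTorch, the same idea could be written with `torch.nn.init.normal_(layer.weight, std=std)` on a linear layer.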
  3. Reward model normalization before and after
    1. In the paper, Ziegler et al. (2019) mentioned that “to keep the scale of the reward model consistent across training, we normalize it so that it has mean 0 and variance 1 for \( x \sim \mathcal{D}, y \sim \rho(·|x) \).” To perform the normalization process, the code first creates a reward_gain and reward_bias, such that the reward can be calculated by reward = reward * reward_gain + reward_bias (lm_human_preferences/rewards.py#L50-L51).
    2. When performing the normalization process, the code first sets reward_gain=1, reward_bias=0 (lm_human_preferences/train_reward.py#L211), followed by collecting sampled queries from the target dataset (e.g., bookcorpus, tldr, cnndm), completed responses, and evaluated rewards. It then gets the empirical mean and std of the evaluated reward (lm_human_preferences/train_reward.py#L162-L167) and tries to compute what the reward_gain and reward_bias should be.
    3. Let us use \( \mu_{\mathcal{D}} \) to denote the empirical mean, \( \sigma_{\mathcal{D}} \) the empirical std, \(g\) the reward_gain, \(b\) reward_bias, \( \mu_{\mathcal{T}} = 0\) target mean and \( \sigma_{\mathcal{T}}=1\) target std. Then we have the following formula.
    \[\begin{aligned}g*\mathcal{N}(\mu_{\mathcal{D}}, \sigma_{\mathcal{D}}) + b &= \mathcal{N}(g*\mu_{\mathcal{D}}, g*\sigma_{\mathcal{D}}) + b\\&= \mathcal{N}(g*\mu_{\mathcal{D}} + b, g*\sigma_{\mathcal{D}}) \\&= \mathcal{N}(\mu_{\mathcal{T}}, \sigma_{\mathcal{T}}) \\g &= \frac{\sigma_{\mathcal{T}}}{\sigma_{\mathcal{D}}} \\b &= \mu_{\mathcal{T}} - g*\mu_{\mathcal{D}}\end{aligned}\]
    4. The normalization process is then applied before and after reward model training (lm_human_preferences/train_reward.py#L232-L234, lm_human_preferences/train_reward.py#L252-L254).

    5. Note that the responses \( y \sim \rho(·|x) \) we generate for normalization purposes come from the pre-trained language model \(\rho \). The model \(\rho \) is fixed as a reference and is not updated during reward learning (lm_human_preferences/train_reward.py#L286C1-L286C31).
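The gain and bias formulas derived above can be checked numerically with a short sketch. The raw rewards here are synthetic draws, not real reward model outputs, and the function name is hypothetical.

```python
import numpy as np

def compute_gain_and_bias(rewards, target_mean=0.0, target_std=1.0):
    """Solve g = sigma_T / sigma_D and b = mu_T - g * mu_D from empirical statistics."""
    mean, std = rewards.mean(), rewards.std()
    gain = target_std / std
    bias = target_mean - gain * mean
    return gain, bias

rng = np.random.default_rng(0)
raw_rewards = rng.normal(loc=3.0, scale=2.5, size=1000)  # synthetic unnormalized rewards

gain, bias = compute_gain_and_bias(raw_rewards)
normalized = raw_rewards * gain + bias  # reward = reward * reward_gain + reward_bias

print(normalized.mean(), normalized.std())
```

After applying the gain and bias, the empirical mean and std of the same samples land on the targets 0 and 1 up to floating-point error.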

Policy Training Implementation Details

In this section, we delve into details such as layer initialization, data post-processing, and dropout settings. We also explore techniques such as rejection sampling, reward “whitening”, and adaptive KL. Here are these details in no particular order:

  1. Scale the logits by sampling temperature.
    1. When calculating the log probability of responses, the model first outputs the logits of the tokens in the responses, followed by dividing the logits with the sampling temperature (lm_human_preferences/policy.py#L121). I.e., logits /= self.temperature
    2. In an informal test, we found that without this scaling, the KL would rise faster than expected, and performance would deteriorate.
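To make the temperature scaling concrete, here is a small NumPy sketch of computing response log probabilities from temperature-scaled logits. The helper names and the random logits are hypothetical; only the `logits /= temperature` step comes from the text.

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax over the vocabulary axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def response_logprobs(logits, response_tokens, temperature):
    scaled = logits / temperature          # the scaling step described above
    all_logprobs = log_softmax(scaled)     # (B, T, V)
    picked = np.take_along_axis(all_logprobs, response_tokens[..., None], axis=-1)
    return picked.squeeze(-1)              # (B, T)

rng = np.random.default_rng(0)
B, T, V = 1, 3, 5                          # made-up sizes
logits = rng.normal(size=(B, T, V))
response_tokens = rng.integers(0, V, size=(B, T))

lp_scaled = response_logprobs(logits, response_tokens, temperature=0.7)
lp_unscaled = response_logprobs(logits, response_tokens, temperature=1.0)
print(lp_scaled)
print(lp_unscaled)
```

If responses are sampled at temperature 0.7 but log probabilities are computed at temperature 1.0, the two sets of values disagree, which is one plausible reason the KL estimate drifts without the scaling.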
  2. Value head layer initialization
    1. The weight of the value head is initialized according to \(\mathcal{N}\left(0,0\right)\) (lm_human_preferences/language/model.py#L368, lm_human_preferences/language/model.py#L251-L252). This is equivalent to initializing the weight to 0.
    2. The bias of the value head is set to 0 (lm_human_preferences/language/model.py#L254).
  3. Select query texts that start and end with a period
    1. This is done as part of the data preprocessing:
      1. Try to select text only after start_text="." (lm_human_preferences/language/datasets.py#L51).
      2. Try to select text just before end_text="." (lm_human_preferences/language/datasets.py#L61).
      3. Then pad the text (lm_human_preferences/language/datasets.py#L66-L67).
    2. When running openai/lm-human-preferences, we found that OpenAI’s datasets were partially corrupted/lost (openai/lm-human-preferences/issues/17#issuecomment-104405149), so we had to replace them with similar HF datasets, which may or may not cause a performance difference.
    3. For the book dataset, we used https://huggingface.co/datasets/bookcorpus. We found it unnecessary to extract sentences that start and end with periods because the dataset is already pre-processed this way (e.g., "usually , he would be tearing around the living room , playing with his toys ."). To this end, we set start_text=None, end_text=None for the sentiment and descriptiveness tasks.
  4. Disable dropout
    1. Ziegler et al. (2019) suggested, “We do not use dropout for policy training.” This is also done in the code (lm_human_preferences/policy.py#L48).
  5. Rejection sampling
    1. Ziegler et al. (2019) suggested, “We use rejection sampling to ensure there is a period between tokens 16 and 24 and then truncate at that period (This is a crude approximation for ‘end of sentence.’ We chose it because it is easy to integrate into the RL loop, and even a crude approximation is sufficient for the intended purpose of making the human evaluation task somewhat easier). During the RL finetuning, we penalize continuations that don’t have such a period by giving them a fixed reward of −1.”
    2. Specifically, this is achieved with the following steps:
      1. Token truncation: We want to truncate at the first occurrence of truncate_token that appears at or after position truncate_after in the responses (lm_human_preferences/train_policy.py#L378)
      2. Run reward model on truncated response: After the response has been truncated by the token truncation process, the code then runs the reward model on the truncated response.
      3. Rejection sampling: if there is no period between tokens 16 and 24, replace the score of the response with a fixed low value such as -1 (lm_human_preferences/train_policy.py#L384, lm_human_preferences/train_policy.py#L384-L402).
      4. To give some examples in descriptiveness:

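The truncation and rejection steps listed above can be sketched on toy data. The token ids, the shortened response length of 6, and truncate_after=2 are hypothetical stand-ins for the 24-token responses and truncate_after=16 in the text; the function names are illustrative and this is not OpenAI's implementation.

```python
import numpy as np

PERIOD_TOKEN = 13    # hypothetical id of the truncate_token "."
PAD_TOKEN = 50257    # hypothetical padding id

def truncate_response(response, truncate_token, truncate_after, pad_token):
    """Keep tokens up to the first truncate_token at index >= truncate_after; pad the rest."""
    response = np.array(response)
    hits = np.flatnonzero(response[truncate_after:] == truncate_token)
    if hits.size == 0:
        return response, False            # no valid period found: candidate for rejection
    end = truncate_after + int(hits[0])
    response[end + 1:] = pad_token
    return response, True

def final_score(reward_model_score, passed, penalty=-1.0):
    # Responses that fail the check get a fixed low score instead of the model score.
    return reward_model_score if passed else penalty

good, good_ok = truncate_response([5, 13, 7, 13, 9, 4], PERIOD_TOKEN, 2, PAD_TOKEN)
bad, bad_ok = truncate_response([5, 13, 7, 8, 9, 4], PERIOD_TOKEN, 2, PAD_TOKEN)

print(good.tolist(), good_ok, final_score(0.4, good_ok))
print(bad.tolist(), bad_ok, final_score(0.4, bad_ok))
```

In the second response the only period sits before truncate_after, so it does not count and the score is replaced by the penalty.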

  1. Discount factor = 1
    1. The discount parameter \(\gamma\) is set to 1 (lm_human_preferences/train_policy.py#L56), which means that future rewards are given the same weight as immediate rewards.
  2. Terminology of the training loop: batches and minibatches in PPO
    1. OpenAI uses the following training loop (lm_human_preferences/train_policy.py#L184-L192). Note: we additionally added micro_batch_size to handle gradient accumulation. At each epoch, the loop shuffles the batch indices.

      
       import numpy as np
       batch_size = 8
       nminibatches = 2
       gradient_accumulation_steps = 2
       mini_batch_size = batch_size // nminibatches
       micro_batch_size = mini_batch_size // gradient_accumulation_steps
       data = np.arange(batch_size).astype(np.float32)
       print("data:", data)
       print("batch_size:", batch_size)
       print("mini_batch_size:", mini_batch_size)
       print("micro_batch_size:", micro_batch_size)
       for epoch in range(4):
           batch_inds = np.random.permutation(batch_size)
           print("epoch:", epoch, "batch_inds:", batch_inds)
           for mini_batch_start in range(0, batch_size, mini_batch_size):
               mini_batch_end = mini_batch_start + mini_batch_size
               mini_batch_inds = batch_inds[mini_batch_start:mini_batch_end]
                      
               # `optimizer.zero_grad()` would reset the accumulated gradients here
               for micro_batch_start in range(0, mini_batch_size, micro_batch_size):
                   micro_batch_end = micro_batch_start + micro_batch_size 
                   micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                   print("____⏩ a forward pass on", data[micro_batch_inds])
               # `optimizer.step()`
               print("⏪ a backward pass on", data[mini_batch_inds])
              
       # data: [0. 1. 2. 3. 4. 5. 6. 7.]
       # batch_size: 8
       # mini_batch_size: 4
       # micro_batch_size: 2
       # epoch: 0 batch_inds: [6 4 0 7 3 5 1 2]
       # ____⏩ a forward pass on [6. 4.]
       # ____⏩ a forward pass on [0. 7.]
       # ⏪ a backward pass on [6. 4. 0. 7.]
       # ____⏩ a forward pass on [3. 5.]
       # ____⏩ a forward pass on [1. 2.]
       # ⏪ a backward pass on [3. 5. 1. 2.]
       # epoch: 1 batch_inds: [6 7 3 2 0 4 5 1]
       # ____⏩ a forward pass on [6. 7.]
       # ____⏩ a forward pass on [3. 2.]
       # ⏪ a backward pass on [6. 7. 3. 2.]
       # ____⏩ a forward pass on [0. 4.]
       # ____⏩ a forward pass on [5. 1.]
       # ⏪ a backward pass on [0. 4. 5. 1.]
       # epoch: 2 batch_inds: [1 4 5 6 0 7 3 2]
       # ____⏩ a forward pass on [1. 4.]
       # ____⏩ a forward pass on [5. 6.]
       # ⏪ a backward pass on [1. 4. 5. 6.]
       # ____⏩ a forward pass on [0. 7.]
       # ____⏩ a forward pass on [3. 2.]
       # ⏪ a backward pass on [0. 7. 3. 2.]
       # epoch: 3 batch_inds: [7 2 4 1 3 0 6 5]
       # ____⏩ a forward pass on [7. 2.]
       # ____⏩ a forward pass on [4. 1.]
       # ⏪ a backward pass on [7. 2. 4. 1.]
       # ____⏩ a forward pass on [3. 0.]
       # ____⏩ a forward pass on [6. 5.]
       # ⏪ a backward pass on [3. 0. 6. 5.]
      
  3. Per-token KL penalty
    • The code adds a per-token KL penalty (lm_human_preferences/train_policy.py#L150-L153) to the rewards, in order to discourage the policy from deviating too far from the original policy.
    • Using the "usually, he would" as an example, it gets tokenized to [23073, 11, 339, 561]. Say we use [23073] as the query and [11, 339, 561] as the response. Then under the default gpt2 parameters, the response tokens will have log probabilities of the reference policy logprobs=[-3.3213, -4.9980, -3.8690] .
      • During the first PPO update epoch and minibatch update, the active policy has not changed yet, so it has the same log probabilities new_logprobs=[-3.3213, -4.9980, -3.8690], and the per-token KL penalty is kl = new_logprobs - logprobs = [0., 0., 0.].
      • However, after the first gradient backward pass, we could have new_logprobs=[-3.6528, -5.0406, -3.2339], so the per-token KL penalty becomes kl = new_logprobs - logprobs = [-0.3315, -0.0426, 0.6351].
      • Then non_score_reward = -beta * kl, where beta is the KL penalty coefficient \(\beta\), and it is added to the score obtained from the reward model to create the rewards used for training. The score is only given at the end of the episode; it could look like [0.4], and we have rewards = [-beta * -0.3315, -beta * -0.0426, -beta * 0.6351 + 0.4].
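The worked example above can be reproduced with a few lines of NumPy. Two things are assumptions here: the new_logprobs values are chosen so that they match the quoted kl vector, and the penalty is subtracted (non_score_reward = -beta * kl), which is our reading of the linked code. The value beta = 0.15 is the initial KL coefficient mentioned elsewhere in this post.

```python
import numpy as np

ref_logprobs = np.array([-3.3213, -4.9980, -3.8690])   # reference policy
new_logprobs = np.array([-3.6528, -5.0406, -3.2339])   # active policy after one update
beta = 0.15                                            # KL penalty coefficient
score = 0.4                                            # reward model score for the whole response

kl = new_logprobs - ref_logprobs       # per-token KL estimate
non_score_reward = -beta * kl          # assumption: the penalty enters with a negative sign
rewards = non_score_reward.copy()
rewards[-1] += score                   # the score is only added at the last token

print(kl)
print(rewards)
```

Tokens that became more likely than under the reference policy (positive kl) are penalized, while the score only affects the final position.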
  4. Per-minibatch reward and advantage whitening, with optional mean shifting
    1. OpenAI implements a whiten function that looks like the one below; it normalizes the values by subtracting their mean and then dividing by their standard deviation. Optionally, whiten can add the mean back to the whitened values with shift_mean=False.
     def whiten(values, shift_mean=True):
         mean, var = torch.mean(values), torch.var(values, unbiased=False)
         whitened = (values - mean) * torch.rsqrt(var + 1e-8)
         if not shift_mean:
             whitened += mean
         return whitened
    
    2. In each minibatch, OpenAI then whitens the rewards with whiten(rewards, shift_mean=False), which keeps the original mean (lm_human_preferences/train_policy.py#L325), and whitens the advantages with whiten(advantages), which shifts the mean to zero (lm_human_preferences/train_policy.py#L338).
    3. Optimization note: if the number of minibatches is one (which is the case in this reproduction), we only need to whiten the rewards and calculate and whiten the advantages once, since their values won’t change.
    4. TensorFlow vs PyTorch note: different behavior of tf.moments vs torch.var. Whitening behaves differently in torch and tf because the variance calculation is different:

       import numpy as np
       import tensorflow as tf
       import torch
              
       def whiten_tf(values, shift_mean=True):
           mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank)))
           mean = tf.Print(mean, [mean], 'mean', summarize=100)
           var = tf.Print(var, [var], 'var', summarize=100)
           whitened = (values - mean) * tf.rsqrt(var + 1e-8)
           if not shift_mean:
               whitened += mean
           return whitened
              
       def whiten_pt(values, shift_mean=True, unbiased=True):
           mean, var = torch.mean(values), torch.var(values, unbiased=unbiased)
           print("mean", mean)
           print("var", var)
           whitened = (values - mean) * torch.rsqrt(var + 1e-8)
           if not shift_mean:
               whitened += mean
           return whitened
              
       rewards = np.array([
           [1.2, 1.3, 1.4],
           [1.5, 1.6, 1.7],
           [1.8, 1.9, 2.0],
       ])
              
       with tf.Session() as sess:
           print(sess.run(whiten_tf(tf.constant(rewards, dtype=tf.float32), shift_mean=False)))
           print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=True))
           print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=False))
      
       mean[1.5999999]
       var[0.0666666627]
       [[0.05080712 0.4381051  0.8254035 ]
        [1.2127019  1.6000004  1.9872988 ]
        [2.3745968  2.7618952  3.1491938 ]]
       mean tensor(1.6000, dtype=torch.float64)
       var tensor(0.0750, dtype=torch.float64)
       tensor([[0.1394, 0.5046, 0.8697],
               [1.2349, 1.6000, 1.9651],
               [2.3303, 2.6954, 3.0606]], dtype=torch.float64)
       mean tensor(1.6000, dtype=torch.float64)
       var tensor(0.0667, dtype=torch.float64)
       tensor([[0.0508, 0.4381, 0.8254],
               [1.2127, 1.6000, 1.9873],
               [2.3746, 2.7619, 3.1492]], dtype=torch.float64)
              
      
  5. Clipped value function
    1. As done in the original PPO (baselines/ppo2/model.py#L68-L75), the value function is clipped (lm_human_preferences/train_policy.py#L343-L348) in a similar fashion as the policy objective.
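A minimal sketch of a clipped value loss in the style of PPO is given below. The function name, the clip range of 0.2, and the toy numbers are assumptions for illustration and are not copied from either codebase.

```python
import numpy as np

def clipped_value_loss(values, old_values, returns, cliprange_value=0.2):
    # Restrict the new value prediction to stay within cliprange_value of the old one.
    values_clipped = old_values + np.clip(values - old_values, -cliprange_value, cliprange_value)
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (values_clipped - returns) ** 2
    # Take the elementwise maximum (the pessimistic choice), then average.
    return 0.5 * np.mean(np.maximum(loss_unclipped, loss_clipped))

old_values = np.array([0.0, 1.0])   # value predictions at rollout time
values = np.array([0.5, 1.1])       # current value predictions
returns = np.array([1.0, 1.0])      # empirical returns

loss = clipped_value_loss(values, old_values, returns)
print(loss)
```

In the first element the prediction moved by 0.5, more than the clip range, so the clipped prediction 0.2 is used and yields the larger squared error.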
  6. Adaptive KL
    • The KL divergence penalty coefficient \(\beta\) is modified adaptively based on the KL divergence between the current policy and the reference policy. If the KL divergence is outside a predefined target range, the penalty coefficient is adjusted to bring it closer to the target range (lm_human_preferences/train_policy.py#L115-L124). It’s implemented as follows:

        import numpy as np

        class AdaptiveKLController:
            def __init__(self, init_kl_coef, hparams):
                self.value = init_kl_coef
                self.hparams = hparams
              
            def update(self, current, n_steps):
                target = self.hparams.target
                proportional_error = np.clip(current / target - 1, -0.2, 0.2)
                mult = 1 + proportional_error * n_steps / self.hparams.horizon
                self.value *= mult
      
    • For the sentiment and descriptiveness tasks examined in this work, we have init_kl_coef=0.15, hparams.target=6, hparams.horizon=10000.
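To see how the coefficient moves with these settings, here is a small usage sketch. It restates the controller with explicit arguments instead of an hparams object (the class name is hypothetical), and n_steps = 512 as well as the observed KL values are made up.

```python
class AdaptiveKLSketch:
    """Same update rule as the controller above, with explicit arguments (illustrative only)."""

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        # Clip the proportional error to [-0.2, 0.2], as in the update rule above.
        error = max(-0.2, min(0.2, current / self.target - 1))
        self.value *= 1 + error * n_steps / self.horizon

ctl = AdaptiveKLSketch(init_kl_coef=0.15, target=6, horizon=10000)
n_steps = 512  # hypothetical number of samples per update

ctl.update(current=12.0, n_steps=n_steps)   # observed KL above target: beta grows
beta_after_high_kl = ctl.value

ctl.update(current=3.0, n_steps=n_steps)    # observed KL below target: beta shrinks
beta_after_low_kl = ctl.value

print(beta_after_high_kl, beta_after_low_kl)
```

Because the error is clipped, a single update changes the coefficient by at most about 1% with these numbers.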

PyTorch Adam optimizer numerical issues w.r.t RLHF

### pytorch adam implementation:
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
step_size = lr / bias_correction1
bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size)

### tensorflow adam implementation:
lr_t = lr * _dispatch_sqrt((1 - beta2 ** step)) / (1 - beta1 ** step)
denom = exp_avg_sq.sqrt().add_(eps)
param.addcdiv_(exp_avg, denom, value=-lr_t)
\[\begin{aligned}\text{pytorch adam :}\quad \theta_t & =\theta_{t-1}-\alpha \cdot \hat{m}_t /\left(\sqrt{\hat{v}_t}+\varepsilon\right) \\& =\theta_{t-1}- \alpha \underbrace{\left[m_t /\left(1-\beta_1^t\right)\right]}_{=\hat{m}_t} /\left[\sqrt{\underbrace{v_t /\left(1-\beta_2^t\right)}_{=\hat{v}_t} }+\varepsilon\right]\\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right]\frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\varepsilon \sqrt{1-\beta_2^t}}}\end{aligned}\] \[\begin{aligned}\text{tensorflow adam:}\quad \theta_t & =\theta_{t-1}-\alpha_t m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}-\underbrace{\left[\alpha \sqrt{1-\beta_2^t} /\left(1-\beta_1^t\right)\right]}_{=\alpha_t} m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right] \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\hat{\varepsilon}}} \end{aligned}\]
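The difference highlighted in green can be checked numerically with a scalar sketch of the first update step, following the two code listings above. The function names and hyperparameter values are illustrative assumptions, and the moment estimates start from zero.

```python
import math

def first_step_pytorch_style(grad, lr, beta1, beta2, eps):
    step = 1
    exp_avg = (1 - beta1) * grad
    exp_avg_sq = (1 - beta2) * grad * grad
    step_size = lr / (1 - beta1 ** step)
    # Epsilon is added after the second moment has been bias-corrected.
    denom = math.sqrt(exp_avg_sq) / math.sqrt(1 - beta2 ** step) + eps
    return -step_size * exp_avg / denom

def first_step_tensorflow_style(grad, lr, beta1, beta2, eps):
    step = 1
    exp_avg = (1 - beta1) * grad
    exp_avg_sq = (1 - beta2) * grad * grad
    lr_t = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    # Epsilon is added to the uncorrected second moment.
    denom = math.sqrt(exp_avg_sq) + eps
    return -lr_t * exp_avg / denom

lr, beta1, beta2, eps = 1e-5, 0.9, 0.999, 1e-5   # illustrative values

small_pt = first_step_pytorch_style(1e-5, lr, beta1, beta2, eps)
small_tf = first_step_tensorflow_style(1e-5, lr, beta1, beta2, eps)
large_pt = first_step_pytorch_style(1.0, lr, beta1, beta2, eps)
large_tf = first_step_tensorflow_style(1.0, lr, beta1, beta2, eps)

print("small gradient:", small_pt, small_tf)
print("large gradient:", large_pt, large_tf)
```

For a gradient much larger than epsilon the two updates nearly coincide, but when the gradient is comparable to epsilon the TensorFlow-style step is much smaller early in training, since its epsilon effectively acts like eps / sqrt(1 - beta2^t).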

Limitations

Note that this work does not try to reproduce the summarization work on CNN/DM or TL;DR, because we found the training to be time-consuming and brittle.

The particular training run we had showed poor GPU utilization (around 30%), so it took almost 4 days to complete. That is highly expensive: only AWS sells p3dn.24xlarge, and at $31.212 per hour a single run costs roughly $3,000.

Additionally, training was brittle. While the reward goes up, we found it difficult to reproduce the “smart copier” behavior reported by Ziegler et al. (2019). Below are some sample outputs; clearly, the agent overfits somehow.

QUERY: The modern take on Death & Taxes still plays it. I know as a Tron player I find Mindcensor 
pretty frustrating. ⏎ However, the answer to your question is probably that the decks that 
are currently taking up the lion's share of the metagame don't really care about its effect.
It has a chance to act as an instant speed Stone Rain + Gut Shot in response to a fetch 
crack, but after that it's a 2/1 flyer that dies to all relevant removal. ⏎ It's especially 
powerful against Tron since so much of the deck's consistency is in its tutor effects -- 
Sylvan Scrying, Expedition Map, and Eye of Ugin. This combined with Pyroclasm and Oblivion 
Stone being the only creature removal the deck has available to it before Karn make 
Mindcensor fairly potent. ⏎ I'm having trouble thinking of any other popular decks (ignoring
the interaction with fetches) that Mindcensor really sticks it to, however.

PPO GENERATED RESPONSE: Mindcensor is a very good card, but it's not really worth playing.<|endoftext|>Image 
caption The family was forced from their home after they were not allowed to leave the 
house⏎⏎A family of four has been jailed for five years after admitting breaching home rule 
by refusing to leave their home.⏎

reward: -0.48
kl: +11.37
total reward: -0.82

QUERY: To make a long story short, my SO (27/m) and I live in my hometown. We've been together for 
4 years. His family lives across the country. A lot of the time we spend with my parents but
besides that, we aren't very social. His family not being here doesn't make me feel like I 
can't be with mine. I just don't like to. ⏎ I find my extended family (so not my parents) 
extremely irritating. Especially the ones around my age. They're dramatic and rude. Every 
family event is some Jerry Springer episode I don't want to participate in. I don't have any
issues with anyone in particular. I just find them annoying. I'm never uncivil or openly 
telling them how I feel. If there's get togethers, I tend to avoid them. I don't ask anyone 
to hang out or come over. ⏎ My SO and I have experienced multiple miscarriages due to 
unknown reasons. This has made me very sensitive and a bit of a loaner when it comes to the 
girls in my family my age. I'm not jealous that my family has children. I just don't seem to
be able to relate to them. When I explained our issues, they reacted very unsympathetic and 
just weren't very supportive. I guess that was the time that I started to pull away and not 
try to keep up any relationships. ⏎ My SO and I are pregnant, half way! This is something 
we're so joyed about. However, when we told everyone, they just didn't really give a 
reaction. No congrats or that they were happy for us. Since then, I completely pulled away. 
I talked to my parents about this and I just don't feel the need to contact anyone. The 
reason I'm posting is because I discussed deleting some relatives off social media with my 
mom and she was telling me how it's really rude. I don't think it is. We don't talk. There's
no relationship. Honestly, I think they're just friends with me for gossip purposes. A few 
times I tried reaching out to some of my cousins with no reply. ⏎ I guess what I'm asking is
if I should be trying harder to establish something with my family or if it's perfectly OK 
just to delete them and not talk to them again? Our primary contact is Facebook even though 
we live in the same town.

PPO GENERATED RESPONSE:  We're dating now. We've met, we're dating. We're dating now. We're dating now. We're dating
now. We're dating now. We're dating now.⏎⏎We've met⏎⏎We're dating now.⏎⏎We're dating now.⏎

reward: +2.09
kl: +19.45
total reward: +1.50

QUERY: This tower totally decided to screw me over on the final stage. So a bit of background I 
don't have the strongest squad but pretty much I have: ⏎ Guardian Loch (Sacred Jewel) ⏎ 
Anima Tiara (Flesh Armor) ⏎ Lord Cayena (sacred Jewel) ⏎ Oracle Twins (sacred Jewel) ⏎ 
Guardian Logan (Burny soul stone) ⏎ and was using a friends maxed michelle. ⏎ So throughout 
the earlier stages I've seen this guy about 5 times but have never caught him. So I had a 
little trouble with stage 9 but felt like with the additional 50% atk boost and bb from 
michelle I could do it. Also on stage 9 the reason I died was because I didn't take any 
stimulants and this guy managed to paralyze 3 members of my squad. So I replaced mega cures 
for stimulants. ⏎ Round 3 I meet him. He gets a lucky crit on my caynea and takes her down. 
I have to use my one revive, take him down and  capture  him. I'm pretty excited. ⏎ Round 7 
I see the warning danger sign. I face him again and manage to take him down without any of 
my units dying no capture. At this point I'm starting to get low on healing items. Also I've
only heard of meeting him twice on a run. But then I checked this thread and saw that he 
always appears on the last stage. I thought that maybe it glitched and he wouldn't appear on
the final round. ⏎ Nope. ⏎ This guy shows up on the final round showing up 3 times in one 
run. Worst part was I manged to get him down to a sliver of health so that as long as one 
unit hit him I would've defeated him with the next turn. This was also the turn he decided 
to use soul desecration twice. I definitely felt like my soul got desecrated watching all my
units die. So at this point it was give up the captured unit and try all over again (knowing
my luck never capture him again) or gem for the last hit. I was really looking forward to 
getting the gem for completion (f2p gotta scrap up everything right?) but I let it go.

PPO GENERATED RESPONSE:  ⏎⏎When you capture a unit try to kill it.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill 
them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎

reward: +0.79
kl: +10.82
total reward: +0.46
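As a sanity check on the three samples above, the logged numbers can be related to each other. The sketch below assumes that the total reward is the reward-model score minus a KL penalty, i.e. `total = reward - beta * kl`, and back-solves the implied coefficient `beta` for each sample. Both the relation and the helper name `implied_kl_coef` are assumptions made for illustration, not something the logs state.

```python
# (reward, kl, total reward) triples copied from the three samples above.
samples = [
    (-0.48, 11.37, -0.82),
    (2.09, 19.45, 1.50),
    (0.79, 10.82, 0.46),
]


def implied_kl_coef(reward, kl, total):
    """Solve total = reward - beta * kl for beta (assumed relation)."""
    return (reward - total) / kl


betas = [implied_kl_coef(*s) for s in samples]
for (reward, kl, total), beta in zip(samples, betas):
    print(f"reward={reward:+.2f} kl={kl:+.2f} total={total:+.2f} -> implied beta={beta:.4f}")
```

Under that assumption, all three samples imply nearly the same coefficient, which suggests the gap between reward and total reward is explained by the KL penalty. The second sample is the instructive one: it earns the highest reward while also having the largest KL, consistent with the policy drifting far from the initial model to exploit the reward model.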

Conclusion

In this work, we took a deep dive into OpenAI’s original RLHF codebase and compiled a list of its implementation details. We also created a minimal base that reproduces the same learning curves as OpenAI’s original RLHF codebase when the dataset and hyperparameters are controlled. Furthermore, we identified surprising implementation details, such as the Adam optimizer’s setting, which causes aggressive updates in early RLHF training.

For attribution in academic contexts, please cite this work as
        PLACEHOLDER FOR ACADEMIC ATTRIBUTION
  
BibTeX citation
        PLACEHOLDER FOR BIBTEX