Trainer

RLHF Trainer

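\(L_t^{\text{CLIP+VF+S}}(\theta)=\hat{\mathbb{E}}_t\left[L_t^{\text{CLIP}}(\theta)-c_1 L_t^{\text{VF}}(\theta)+c_2 S\left[\pi_\theta\right]\left(s_t\right)\right]\)

where \(c_1\) and \(c_2\) are coefficients, \(L_t^{\text{VF}}\) is the squared-error value-function loss, and \(S\) is an entropy bonus.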
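As a sketch of how the three terms combine, the plain-Python function below forms the per-timestep combination and then takes the empirical mean over timesteps (the \(\hat{\mathbb{E}}_t\)). The helper names and the default coefficients c1 = 0.5 and c2 = 0.01 are illustrative assumptions, not values taken from instruct_goose. The value loss is a squared error, and the entropy is that of a categorical action distribution.

```python
import math
from typing import Sequence


def value_loss(value_pred: float, value_target: float) -> float:
    # L^VF_t: squared error between the predicted value and its target return
    return (value_pred - value_target) ** 2


def entropy(probs: Sequence[float]) -> float:
    # S[pi_theta](s_t): Shannon entropy (in nats) of a categorical distribution
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def ppo_objective(
    clip_terms: Sequence[float],
    vf_terms: Sequence[float],
    entropy_terms: Sequence[float],
    c1: float = 0.5,   # assumed value-loss coefficient (illustrative only)
    c2: float = 0.01,  # assumed entropy coefficient (illustrative only)
) -> float:
    # Per-timestep L^CLIP - c1 * L^VF + c2 * S, then the empirical mean over t
    per_step = [
        l_clip - c1 * l_vf + c2 * s
        for l_clip, l_vf, s in zip(clip_terms, vf_terms, entropy_terms)
    ]
    return sum(per_step) / len(per_step)


# The objective is maximized, so an optimizer minimizes its negative.
objective = ppo_objective(
    clip_terms=[1.0, 0.5],
    vf_terms=[value_loss(1.0, 3.0), 0.0],
    entropy_terms=[entropy([0.5, 0.5]), 0.0],
)
loss = -objective
```

The value term is subtracted because a large value error should lower the objective, while the entropy term is added to reward a less deterministic policy.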

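\(L^{\text{CLIP}}(\theta)=\hat{\mathbb{E}}_t\left[\min \left(r_t(\theta) \hat{A}_t, \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\right]\)

where \(\hat{A}_t\) is the estimated advantage at timestep \(t\) and \(\epsilon\) is the clipping range.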
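A minimal sketch of the clipped surrogate term, in plain Python over a batch of per-timestep ratios and advantages. The function name and the default clip range of 0.2 are assumptions for illustration.

```python
from typing import Sequence


def clipped_surrogate(
    ratios: Sequence[float],
    advantages: Sequence[float],
    epsilon: float = 0.2,  # assumed clip range (illustrative)
) -> float:
    """Empirical mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    terms = []
    for r, a in zip(ratios, advantages):
        clipped_r = min(max(r, 1.0 - epsilon), 1.0 + epsilon)  # clip(r, 1-eps, 1+eps)
        terms.append(min(r * a, clipped_r * a))  # keep the more pessimistic term
    return sum(terms) / len(terms)


value = clipped_surrogate([1.5, 0.5], [1.0, -1.0])
```

For the first sample (positive advantage) the gain is capped at \((1+\epsilon)\hat{A}_t = 1.2\). For the second (negative advantage) the clipped term, -0.8, is smaller than -0.5, so pushing the ratio below \(1-\epsilon\) earns no further credit. The mean is therefore about 0.2.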

\(r_t(\theta)=\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)}\)

In log space, the same ratio is computed as

\(r_t(\theta)=\exp\left(\log \pi_\theta\left(a_t \mid s_t\right) - \log \pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)\right)\)
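A small sketch of computing the ratio from log-probabilities, as language models usually expose per-token log-probabilities rather than raw probabilities. The function names are hypothetical.

```python
import math
from typing import List


def probability_ratio(logp_new: float, logp_old: float) -> float:
    # r_t = pi_new / pi_old, computed as exp(log pi_new - log pi_old);
    # working in log space avoids dividing two tiny probabilities
    return math.exp(logp_new - logp_old)


def sequence_ratios(logps_new: List[float], logps_old: List[float]) -> List[float]:
    # One ratio per generated token (one per timestep t)
    return [probability_ratio(n, o) for n, o in zip(logps_new, logps_old)]


stable = probability_ratio(-1000.0, -1001.0)
```

Here `math.exp(-1000.0)` underflows to 0.0, so dividing the raw probabilities would give 0/0. The log-space form returns e.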



RLHFTrainer

 RLHFTrainer (model:transformers.modeling_utils.PreTrainedModel,
              ref_model:transformers.modeling_utils.PreTrainedModel,
              config:instruct_goose.utils.RLHFConfig)

Initialize the trainer with the model to be trained, a reference model, and an RLHFConfig.

Parameter  Type             Details
model      PreTrainedModel  A pre-trained language model
ref_model  PreTrainedModel  A reference model
config     RLHFConfig       The RLHF configuration


RLHFTrainer.compute_loss

 RLHFTrainer.compute_loss (query_ids:TensorType["batch_size", "seq_len"],
                           query_attention_mask:TensorType["batch_size", "seq_len"],
                           response_ids:TensorType["batch_size", "seq_len"],
                           response_attention_mask:TensorType["batch_size", "seq_len"],
                           rewards:TensorType["batch_size"])

Calculate PPO’s loss.
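The documented inputs are token-level tensors of shape (batch_size, seq_len) plus one reward per sequence. As a hypothetical sketch only (not instruct_goose's implementation), the function below shows one common way such inputs are combined. It uses per-token log-probabilities of the response under the current and old policies and broadcasts one advantage per sequence to every response token. It also applies the response attention mask so that padding is ignored when averaging. Nested lists stand in for tensors to keep it dependency-free. The name `masked_ppo_clip_loss`, the 0.2 clip range, and treating the advantage as already computed (for example, reward minus a value baseline) are all assumptions.

```python
import math
from typing import List


def masked_ppo_clip_loss(
    logp_new: List[List[float]],     # (batch_size, seq_len) log pi_theta of response tokens
    logp_old: List[List[float]],     # (batch_size, seq_len) log pi_theta_old of the same tokens
    response_mask: List[List[int]],  # (batch_size, seq_len) 1 = real token, 0 = padding
    advantages: List[float],         # (batch_size,) one advantage per sequence (assumed precomputed)
    epsilon: float = 0.2,            # assumed clip range
) -> float:
    total, count = 0.0, 0
    for new_row, old_row, mask_row, adv in zip(logp_new, logp_old, response_mask, advantages):
        for lp_new, lp_old, keep in zip(new_row, old_row, mask_row):
            if not keep:
                continue  # padding positions do not contribute
            ratio = math.exp(lp_new - lp_old)
            clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
            total += min(ratio * adv, clipped * adv)  # same sequence advantage for every token
            count += 1
    # Negate the mean clipped objective so that minimizing the loss maximizes L^CLIP
    return -total / max(count, 1)


loss = masked_ppo_clip_loss(
    logp_new=[[-1.0, -2.0, -9.0], [-0.5, -7.0, -7.0]],
    logp_old=[[-1.0, -2.0, 0.0], [-0.5, 0.0, 0.0]],
    response_mask=[[1, 1, 0], [1, 0, 0]],
    advantages=[2.0, -1.0],
)
```

The padded positions carry very different log-probabilities, yet they do not affect the result. The three real tokens contribute 2 + 2 - 1, so the loss is -1.0. Averaging over all unmasked tokens in the batch is one normalization choice. Averaging per sequence first is another.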