Trainer
RLHF Trainer
RLHFTrainer
RLHFTrainer (model:transformers.modeling_utils.PreTrainedModel, ref_model:transformers.modeling_utils.PreTrainedModel, config:instruct_goose.utils.RLHFConfig)
| | Type | Details |
|---|---|---|
| model | PreTrainedModel | A pre-trained language model |
| ref_model | PreTrainedModel | A reference model |
| config | RLHFConfig | |
RLHFTrainer.compute_loss
RLHFTrainer.compute_loss (query_ids:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], query_attention_mask:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], response_ids:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], response_attention_mask:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], rewards:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size',),'cls_name':'TensorType'}])
Calculate PPO’s loss.
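The page does not say which terms `compute_loss` combines. For orientation only, here is a minimal sketch of the standard PPO clipped-surrogate policy loss in plain Python. The function name `ppo_clip_loss`, its arguments, and the default `epsilon=0.2` are illustrative assumptions and are not part of the `RLHFTrainer` API.

```python
import math

def ppo_clip_loss(new_logprobs, old_logprobs, advantages, epsilon=0.2):
    """Mean PPO clipped-surrogate loss over a batch (a value to minimize).

    Hypothetical helper for illustration; not taken from instruct_goose.
    Each argument is a sequence of floats with one entry per sample.
    """
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        # Probability ratio pi_new / pi_old, computed from log-probabilities.
        ratio = math.exp(new_lp - old_lp)
        # Restrict the ratio to the interval [1 - epsilon, 1 + epsilon].
        clipped = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
        # Take the more pessimistic of the unclipped and clipped objectives.
        total += min(ratio * adv, clipped * adv)
    # Negate because the surrogate objective is maximized.
    return -total / len(advantages)
```

With identical old and new log-probabilities the ratio is 1, so the loss reduces to the negated mean advantage. Clipping only takes effect once the new policy moves more than `epsilon` away from the old one in the direction the advantage favors.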