Agent
RL-based language model agent
Agent (The RL-based language model)
Agent
Agent (model:transformers.modeling_utils.PreTrainedModel)
The RL-based language model.
Parameter | Type | Details
---|---|---
model | PreTrainedModel | a pre-trained transformers model
Agent.forward
Agent.forward (input_ids:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], attention_mask:Optional[Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}]]=None)
Run a forward pass over a batch of token IDs, with an optional attention mask.
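Both arguments of `Agent.forward` are `(batch_size, seq_len)` tensors. The sketch below shows one way a ragged batch of token IDs could be brought into that shape. It assumes right padding and a pad token ID of 0 (real tokenizers define their own). `pad_batch` is a hypothetical helper and is not part of the library.

```python
import torch


def pad_batch(sequences: list[list[int]], pad_token_id: int = 0):
    """Right-pad lists of token IDs into (batch_size, seq_len) tensors.

    Hypothetical helper (not part of the library). Returns input_ids and an
    attention_mask in which 1 marks a real token and 0 marks padding.
    """
    seq_len = max(len(s) for s in sequences)
    input_ids = torch.full((len(sequences), seq_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), seq_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        # Copy the real tokens to the left; the remaining positions stay as padding.
        input_ids[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, : len(seq)] = 1
    return input_ids, attention_mask


input_ids, attention_mask = pad_batch([[5, 6, 7], [8, 9]], pad_token_id=0)
print(input_ids.tolist())       # [[5, 6, 7], [8, 9, 0]]
print(attention_mask.tolist())  # [[1, 1, 1], [1, 1, 0]]
```

A Hugging Face tokenizer called with `padding=True, return_tensors="pt"` returns both tensors directly. Whether its pad token and padding side suit the agent depends on the underlying model.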
Agent Objective
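Equation 2 of the InstructGPT paper (Ouyang et al., 2022, https://arxiv.org/abs/2203.02155). The RL policy \(\pi_\phi^{\mathrm{RL}}\) is trained to maximize the reward model score \(r_\theta(x, y)\), minus a penalty (weighted by \(\beta\)) on the log-ratio between the policy and the supervised fine-tuned reference model \(\pi^{\mathrm{SFT}}\), plus the log-likelihood of pretraining data weighted by \(\gamma\):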
\(\begin{aligned} \operatorname{objective~}(\phi)= & E_{(x, y) \sim D_{\pi_\phi^{\mathrm{RL}}}}\left[r_\theta(x, y)-\beta \log \left(\pi_\phi^{\mathrm{RL}}(y \mid x) / \pi^{\mathrm{SFT}}(y \mid x)\right)\right]+ \\ & \gamma E_{x \sim D_{\text {pretrain }}}\left[\log \left(\pi_\phi^{\mathrm{RL}}(x)\right)\right]\end{aligned}\)
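To make the roles of \(\beta\) and \(\gamma\) concrete, here is a minimal sketch that evaluates the equation above on precomputed quantities. It assumes the reward scores and sequence log-probabilities are already available as tensors and replaces each expectation with a batch mean. `instructgpt_objective` is a hypothetical helper, not the library's implementation, and the toy numbers are illustrative only.

```python
import torch


def instructgpt_objective(
    reward: torch.Tensor,            # r_theta(x, y), shape (batch_size,)
    logprob_rl: torch.Tensor,        # log pi_RL(y | x), shape (batch_size,)
    logprob_sft: torch.Tensor,       # log pi_SFT(y | x), shape (batch_size,)
    logprob_pretrain: torch.Tensor,  # log pi_RL(x) on pretraining text, shape (n,)
    beta: float,
    gamma: float,
) -> torch.Tensor:
    # First expectation: reward minus the beta-weighted log-ratio between the
    # RL policy and the SFT reference, averaged over sampled (x, y) pairs.
    rl_term = (reward - beta * (logprob_rl - logprob_sft)).mean()
    # Second expectation: gamma-weighted log-likelihood of pretraining data.
    pretrain_term = gamma * logprob_pretrain.mean()
    return rl_term + pretrain_term  # scalar to be maximized


obj = instructgpt_objective(
    reward=torch.tensor([1.0, 0.5]),
    logprob_rl=torch.tensor([-2.0, -3.0]),
    logprob_sft=torch.tensor([-2.5, -3.0]),
    logprob_pretrain=torch.tensor([-4.0, -6.0]),
    beta=0.02,
    gamma=0.5,
)
print(round(obj.item(), 4))  # -1.755
```

Because the objective is maximized, a training loop built on a standard optimizer would minimize its negative.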
AgentObjective
AgentObjective (model:transformers.modeling_utils.PreTrainedModel, sft_model:transformers.modeling_utils.PreTrainedModel, reward_model:Callable, gamma:float, beta:float)
Agent objective.
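Parameter | Type | Details
---|---|---
model | PreTrainedModel | the language model being trained (\(\pi_\phi^{\mathrm{RL}}\))
sft_model | PreTrainedModel | the reference (SFT) model (\(\pi^{\mathrm{SFT}}\))
reward_model | typing.Callable | the reward model (\(r_\theta\))
gamma | float | pretraining loss coefficient (\(\gamma\))
beta | float | KL penalty coefficient (\(\beta\))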
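`AgentObjective` needs three models. The sketch below builds plausible constructor arguments under stated assumptions. A tiny `nn.Sequential` stands in for a `PreTrainedModel`. The reference model is a frozen deep copy of the policy, since in the InstructGPT setup the RL policy is initialized from the SFT model, which stays fixed. `toy_reward_model` is a hypothetical callable assumed to map `(input_ids, attention_mask)` to one score per sequence. The library's actual reward-model call signature may differ.

```python
import copy

import torch
from torch import nn

# Toy stand-in for a PreTrainedModel (assumption): embeds token IDs and
# projects them back onto a small vocabulary to produce logits.
VOCAB_SIZE = 10
policy = nn.Sequential(nn.Embedding(VOCAB_SIZE, 8), nn.Linear(8, VOCAB_SIZE))

# Reference (SFT) model: a frozen copy of the policy's starting weights, so the
# beta-weighted log-ratio penalizes drift away from it.
sft_model = copy.deepcopy(policy).eval()
sft_model.requires_grad_(False)


def toy_reward_model(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical reward: fraction of non-padding tokens equal to ID 3, shape (batch_size,)."""
    hits = ((input_ids == 3) & attention_mask.bool()).sum(dim=-1).float()
    return hits / attention_mask.sum(dim=-1).clamp(min=1).float()


input_ids = torch.tensor([[3, 3, 1], [2, 3, 0]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
rewards = toy_reward_model(input_ids, attention_mask)
print(rewards.tolist())  # approximately [0.667, 0.5]
```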
AgentObjective.forward
AgentObjective.forward (input_ids:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}], attention_mask:typing.Annotated[torch.Tensor,{'__torchtyping__':True,'details':('batch_size','seq_len',),'cls_name':'TensorType'}])
Calculate the objective value given the input ids and attention mask.
Parameter | Type | Details
---|---|---
input_ids | `typing.Annotated[torch.Tensor, {'__torchtyping__': True, 'details': ('batch_size', 'seq_len',), 'cls_name': 'TensorType'}]` | token IDs
attention_mask | `typing.Annotated[torch.Tensor, {'__torchtyping__': True, 'details': ('batch_size', 'seq_len',), 'cls_name': 'TensorType'}]` | attention mask over input_ids
Returns | `typing.Annotated[torch.Tensor, {'__torchtyping__': True, 'details': (1,), 'cls_name': 'TensorType'}]` | a scalar objective value
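`AgentObjective.forward` receives only token IDs and a mask, so the log-probabilities in Equation 2 have to be recovered from model logits. One common way to do this for a causal language model is to shift by one position, gather the log-probability of each actual next token, and sum over non-padding positions. The sketch below assumes `logits` has shape `(batch_size, seq_len, vocab_size)`, as the `.logits` output of a Hugging Face causal LM does. `sequence_logprob` is a hypothetical helper, and the library's forward may compute this differently.

```python
import torch
import torch.nn.functional as F


def sequence_logprob(
    logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Sum of per-token log-probabilities for each sequence, ignoring padding.

    logits: (batch_size, seq_len, vocab_size) from a causal LM, where position t
    scores the token at position t + 1. Returns a tensor of shape (batch_size,).
    """
    # Shift so that logits[:, t] are paired with the token they predict.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Zero out predictions whose target is a padding token.
    mask = attention_mask[:, 1:].to(token_logprobs.dtype)
    return (token_logprobs * mask).sum(dim=-1)


# Uniform logits over a 4-token vocabulary: every prediction has log-prob log(1/4).
logits = torch.zeros(2, 3, 4)
input_ids = torch.tensor([[0, 1, 2], [3, 2, 0]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
out = sequence_logprob(logits, input_ids, attention_mask)
print(out.tolist())  # approximately [-2.7726, -1.3863]
```

The log-ratio term \(\log(\pi_\phi^{\mathrm{RL}}(y \mid x) / \pi^{\mathrm{SFT}}(y \mid x))\) is a difference of two such sums, so the same helper could be applied to the policy's and the reference model's logits and the results subtracted. Restricting the sum to the response \(y\) would additionally require zeroing the prompt positions in the mask.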