prompting.validators.reward.nsfw

Reward model and event type for scoring completions for NSFW / hate-speech content.

Module Contents

Classes

NSFWRewardEvent – Reward event carrying the NSFW score for a single completion.

NSFWRewardModel – Reward model that flags NSFW / hate-speech completions.

class prompting.validators.reward.nsfw.NSFWRewardEvent

Bases: prompting.validators.reward.reward.BaseRewardEvent

Reward event holding the score that the NSFW reward model assigns to a single completion.

score: float

The NSFW reward score for the completion.

class prompting.validators.reward.nsfw.NSFWRewardModel(device)

Bases: prompting.validators.reward.reward.BaseRewardModel

Reward model that penalizes NSFW / hate-speech completions using the pretrained classifier referenced by nsfw_filter_model_path.

Parameters:

device (str) – torch device on which to run the filter model (e.g. 'cpu' or 'cuda')

property name: str

Name identifying this reward model.

Return type:

str

nsfw_filter_model_path = 'facebook/roberta-hate-speech-dynabench-r4-target'

Hugging Face model id of the pretrained hate-speech classifier used to score completions.
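
The checkpoint above is a standard Hugging Face classifier, so a sketch of how such a model is typically loaded is shown below. This is illustrative only and may not match the internal loading code of NSFWRewardModel:

    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Illustrative sketch: load the hate-speech classifier referenced by
    # nsfw_filter_model_path using the transformers auto classes.
    model_path = "facebook/roberta-hate-speech-dynabench-r4-target"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_path)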

reward(prompt, completion, name)

Score a single completion against the NSFW filter.

Parameters:

  • prompt (str) – the prompt that produced the completion

  • completion (str) – the completion text to score

  • name (str) –

Return type:

NSFWRewardEvent
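
A minimal usage sketch based on the signature above; the device string, the name value, and the example texts are illustrative assumptions:

    from prompting.validators.reward.nsfw import NSFWRewardModel

    model = NSFWRewardModel(device="cpu")  # any valid torch device string
    event = model.reward(
        prompt="Summarize the article in one sentence.",
        completion="The article describes recent progress in robotics.",
        name="nsfw",  # assumed value; the expected name is not documented here
    )
    print(event.score)  # float, see NSFWRewardEvent.score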

get_rewards(prompt, completions, name)

Score a batch of completions, producing one event per completion.

Parameters:

  • prompt (str) – the prompt that produced the completions

  • completions (List[str]) – the completion texts to score

  • name (str) –

Return type:

List[NSFWRewardEvent]
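
Batch scoring follows the same pattern. Reusing the model instance from the previous sketch, with illustrative inputs:

    events = model.get_rewards(
        prompt="Summarize the article in one sentence.",
        completions=["First candidate answer.", "Second candidate answer."],
        name="nsfw",  # assumed value, as above
    )
    scores = [event.score for event in events]  # one event per completion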

normalize_rewards(rewards)

Normalize a tensor of raw reward scores onto a common scale.

Parameters:

rewards (torch.FloatTensor) – raw per-completion reward scores

Return type:

torch.FloatTensor
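
A sketch of how normalization can follow batch scoring, reusing events from the sketch above. The exact normalization scheme (e.g. z-score or min-max) is an implementation detail not documented on this page:

    import torch

    # Collect raw per-completion scores into a float tensor, then normalize.
    raw = torch.tensor([event.score for event in events], dtype=torch.float32)
    normalized = model.normalize_rewards(raw)  # torch.FloatTensor in and out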