```python
import torch
import torch.nn.functional as F
from contextlib import nullcontext


class DPOTrainer:
    def __init__(self, model, reference_model, tokenizer, beta: float = 0.1):
        self.model = model
        self.reference_model = reference_model
        self.tokenizer = tokenizer
        self.beta = beta
        # Freeze the reference model; it only supplies baseline log-probs.
        for param in self.reference_model.parameters():
            param.requires_grad = False

    def compute_dpo_loss(self, prompts, chosen, rejected):
        """Compute the DPO loss for a batch of preference pairs."""
        # Log-probabilities under the policy model
        pi_logprobs_chosen = self.get_logprobs(self.model, prompts, chosen)
        pi_logprobs_rejected = self.get_logprobs(self.model, prompts, rejected)

        # Log-probabilities under the frozen reference model
        with torch.no_grad():
            ref_logprobs_chosen = self.get_logprobs(
                self.reference_model, prompts, chosen
            )
            ref_logprobs_rejected = self.get_logprobs(
                self.reference_model, prompts, rejected
            )

        # Log-ratios of chosen vs. rejected responses
        pi_logratios = pi_logprobs_chosen - pi_logprobs_rejected
        ref_logratios = ref_logprobs_chosen - ref_logprobs_rejected

        # DPO loss
        losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))

        # Implicit rewards, detached from the graph; these are not part of
        # the loss, but are useful for logging reward margins and accuracy.
        chosen_rewards = self.beta * (
            pi_logprobs_chosen - ref_logprobs_chosen
        ).detach()
        rejected_rewards = self.beta * (
            pi_logprobs_rejected - ref_logprobs_rejected
        ).detach()

        return losses.mean(), chosen_rewards, rejected_rewards

    def get_logprobs(self, model, prompts, responses):
        """Sum the per-token log-probabilities of each response given its prompt."""
        inputs = self.tokenizer(
            [p + r for p, r in zip(prompts, responses)],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Skip gradient tracking when scoring with the reference model.
        ctx = torch.no_grad() if model is self.reference_model else nullcontext()
        with ctx:
            outputs = model(**inputs)

        # Causal LM shift: the logits at position t predict the token at t + 1.
        logits = outputs.logits[:, :-1, :]
        labels = inputs["input_ids"][:, 1:]
        mask = inputs["attention_mask"][:, 1:]

        # Log-probability of each realized token, with padding zeroed out
        logprobs = F.log_softmax(logits, dim=-1)
        selected_logprobs = torch.gather(
            logprobs, 2, labels.unsqueeze(-1)
        ).squeeze(-1) * mask

        # Keep only the response tokens. This assumes tokenizing the prompt
        # alone yields the same prefix as tokenizing prompt + response.
        prompt_lens = [len(self.tokenizer(p)["input_ids"]) for p in prompts]
        response_logprobs = []
        for i, prompt_len in enumerate(prompt_lens):
            # After the shift, the response starts at index prompt_len - 1.
            response_logprobs.append(
                selected_logprobs[i, prompt_len - 1:].sum()
            )
        return torch.stack(response_logprobs)
```
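The loss computed above is the standard DPO objective (Rafailov et al., 2023):

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

where $y_w$ and $y_l$ are the chosen and rejected responses and $\sigma$ is the logistic sigmoid.

Below is a minimal usage sketch. The checkpoint name (`gpt2`), the learning rate, and the toy preference pair are illustrative assumptions, not prescribed by the class above; any causal LM with a Hugging Face-style interface should work.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint: policy and reference start from the same
# weights, but only the policy is updated.
model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

trainer = DPOTrainer(model, ref_model, tokenizer, beta=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

# A toy preference pair (hypothetical data)
prompts = ["Explain DPO in one sentence. "]
chosen = ["DPO fine-tunes a policy directly on preference pairs."]
rejected = ["DPO is a kind of database."]

loss, chosen_rewards, rejected_rewards = trainer.compute_dpo_loss(
    prompts, chosen, rejected
)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# A positive mean margin means the policy puts relatively more probability
# mass on the chosen responses than the reference model does.
print(loss.item(), (chosen_rewards - rejected_rewards).mean().item())
```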