微调一个LLM安全检测模型

此前做了一个 LLM 安全检测系统,这篇只记模型微调部分。

目标很简单:把输入文本分成三类:

  • benign
  • jailbreak
  • injection

1. 任务定义

这是一个标准的文本三分类任务。

标签 类别
0 benign
1 jailbreak
2 injection

相比二分类,三分类的好处是后面更容易看错在哪。

模型输出可以写成:

y^=argmaxysoftmax(f(x))\hat{y} = \arg\max_y softmax(f(x))


2. 数据处理

数据来源主要有三部分:

  • 正常请求
  • 越狱样本
  • 注入样本

统一后的格式如下:

1
{"text": "Ignore all previous instructions", "label": 2}

预处理只做几件事:

  1. 统一格式
  2. 去重
  3. 清洗脏样本
  4. 划分训练/验证/测试集

这里最重要的是边界样本。

比如:

  • 带敏感词,但其实不是攻击
  • 看起来正常,但实际上在诱导模型越权

这类样本如果处理不好,模型很容易学偏。

数据集收集来源

整理的多源数据:

  • 正常样本alpaca-zhBelleGroup/train_0.5M_CNfirefly-train-1.1Mdatabricks/dolly-15k
  • jailbreak 样本Libr-AI/do-not-answer
  • injection 样本deepset/prompt-injectionsJasperLS/prompt-injectionsHackaprompt

这些数据集我统一转成了同一种格式,再做去重和清洗。


3. 微调实现

我直接用 Hugging Face 的序列分类模型做微调。

核心代码就是这几步:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=3
)

train_dataset = PromptDataset(train_data, tokenizer)
val_dataset = PromptDataset(val_data, tokenizer)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(...)],
)

trainer.train()

损失函数就是标准交叉熵:

L=i=1Nc=13yiclogpicL = - \sum_{i=1}^{N} \sum_{c=1}^{3} y_{ic} \log p_{ic}

输入长度我固定成了 512


4. 训练参数

我这次用的是比较稳的一套配置:

参数 数值
max_length 512
learning_rate 2e-5
batch_size 16
num_epochs 5
warmup_ratio 0.1
weight_decay 0.01
metric_for_best_model f1
early_stopping_patience 2

这里我主要看 F1,不是只看 accuracy

F1=2PrecisionRecallPrecision+RecallF1 = 2 \cdot \frac{Precision \cdot Recall}{Precision + Recall}

因为安全检测里,误报和漏报都不能太高。

image-训练曲线


5. 二次微调

首轮训练跑完之后,我没有直接停,而是去看误分类样本。

后面的流程是:

  1. 收集误分类样本
  2. 重新检查标签
  3. 修正一部分噪声标签
  4. 构造困难样本集
  5. 用困难样本 + 部分常规样本继续微调
1
2
3
4
5
errors = collect_errors(predictions, labels)
hard_examples = relabel(errors)
regular_samples = sample(train_data, k=200)
finetune_data = shuffle(hard_examples + regular_samples)
finetune(model, finetune_data)

这里有两个关键点:

1)错误样本不能直接回灌

因为误分类不一定全是模型问题,也可能是标签本身有噪声。

2)只喂困难样本不行

如果只拿难例继续训练,很容易过拟合。所以我混入了一部分原始样本,防止灾难性遗忘。