Unsloth 到底快在哪#
第一次用 Unsloth 是我被一个单卡 LoRA 任务憋住:4090 24GB,要微调一个 13B 模型,vanilla LoRA OOM,QLoRA 勉强跑但一个 epoch 12 小时。同事甩了个 Unsloth 的链接:同样的卡、同样的模型、同样的数据,一个 epoch 3 小时,显存只用 18GB。
这种数量级的差距不是"优化"能解释的,肯定是底层重写了。去翻源码以后确认了:Unsloth 把 LoRA 训练里的几个关键 kernel 全部用 Triton 手写了一遍,顺便把反向传播路径做了手工推导和重排。官方论文里引用过具体的加速数字,我这里不重复那些数字(避免把论文指标当官方 benchmark),只谈原理和实操。
这篇文章讲清楚三件事:
- Unsloth 的加速机制到底是什么
- 怎么在自己的项目里用起来
- 哪些场景合适、哪些不合适,以及踩过的坑
一、加速机制拆解#
Unsloth 的性能提升来自四个方面,没有一个是魔法,都是把通用实现替换成针对 LoRA QLoRA 场景的定制路径。
1.1 手写 Triton kernel 替换 HuggingFace 的前反向#
HuggingFace Transformers 的前反向是 PyTorch 组合 + 少量 C++/CUDA op 拼出来的,灵活但开销大。典型 LLaMA 一个 decoder layer 的前向要触发几十个 kernel launch。
Unsloth 把几个关键 op 用 Triton 重写并融合:
- RMSNorm:融合平方求和 + rsqrt + mul
- RoPE:apply + 缓存融合
- SwiGLU:gate × silu × up 融合成一个 kernel
- Cross-entropy loss:融合 logits 计算 + log_softmax + gather + 反向
融合的直接收益是 kernel launch 次数大幅减少,HBM 往返也减少,两个都是现代 GPU 上非 compute-bound 场景的主要瓶颈。
1.2 手工推导的反向传播#
PyTorch 的 autograd 是通用的,但它对"通用"有代价——很多中间 tensor 要保存用于反向。Unsloth 对 LoRA 路径手工推导了反向,只保存真正必要的中间量,剩下的在反向时就地重算。
一个典型例子:RMSNorm 的反向只需要输入 x 和 rstd(反向里重算的平方和倒数),不需要保存 norm 后的激活。这种 trade-off 用计算换显存,在现代 GPU 上计算比显存便宜,划算。
1.3 4bit dequant 路径优化#
QLoRA 的核心操作是"读 4bit 权重 → 反量化成 fp16 → 和激活做 matmul"。bitsandbytes 的实现里 dequant 和 matmul 是两个独立 kernel,中间要把反量化结果写回 HBM 再读出来。
Unsloth 把 dequant 融合进 matmul 的 prologue:在 shared memory 里即时反量化再参与计算,避免中间 tensor 落盘。这是"4bit QLoRA 比 16bit LoRA 更快"这个反直觉现象的根源——Unsloth 的 4bit 路径比 bitsandbytes 原生快 2-3 倍。
1.4 只对 LoRA 路径求梯度#
vanilla peft 会把 base 模型的参数冻结,但 requires_grad=False 的 tensor 仍然会走完整 autograd 图。Unsloth 进一步把图裁剪,基础模型的反向只做到 “能把梯度传到 LoRA adapter” 的最小必要步骤,其他全部短路。
1.5 总结#
这几个优化单独看都不是颠覆性的,叠加起来:
- 显存:节省 30-60%(对比 bitsandbytes QLoRA)
- 速度:快 1.5-2.5x(对比 HF + peft)
代价是适用面变窄——Unsloth 只深度优化了特定的模型架构(主要是 LLaMA/Mistral/Gemma/Qwen 几个家族)和特定的训练方法(LoRA、QLoRA、DPO)。不在这个白名单里的场景要么用不了,要么退化到原生路径没加速。
二、支持范围#
官方支持的模型家族(以我实际测过的为准):
- LLaMA 2 / 3 / 3.1 / 3.2 / 3.3 全系
- Mistral / Mixtral
- Qwen 1.5 / 2 / 2.5 系列
- Gemma 1 / 2 / 3
- DeepSeek R1 Distill 系列
- Phi 3 / 4
支持的训练方法:
- SFT(LoRA / QLoRA / 全参有限支持)
- DPO / ORPO / KTO
- GRPO(推理模型训练)
- 继续预训练 CPT
不支持或退化:
- Encoder-Decoder 架构(T5、BART)
- 不常见的注意力变体
- 多机多卡训练(Unsloth 核心优化是单卡的,多卡支持较弱)
三、硬件要求#
- Ampere 及以后(RTX 30 / 40 / 50 系列,A100,H100,L40,L4 等)
- 推荐至少 16GB 显存
- Hopper 上效果最好(FP8、H100 的 wgmma)
Turing (T4, V100) 上 Unsloth 可以跑但优化受限,没必要折腾。
四、安装#
官方推荐 pip 安装,但版本锁得紧:
pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"
方括号里是你的 CUDA + GPU 架构组合:
cu121-ampere:CUDA 12.1 + Amperecu121-hopper:CUDA 12.1 + Hoppercu121-ada:CUDA 12.1 + Ada (40 系)
装错 arch 不会直接报错,但 kernel 编译会走 fallback 路径,速度降一半。安装后跑:
import unsloth
print(unsloth.__version__)
然后看一下 nvidia-smi 里的 CUDA / 驱动版本是不是匹配。
五、最小可用示例#
Unsloth 的 API 设计很"HuggingFace 化",几行替换就能让原本的脚本受益。
5.1 SFT LoRA 示例#
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
max_seq_length = 4096
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Llama-3.1-8B-Instruct-bnb-4bit",
max_seq_length = max_seq_length,
dtype = None, # None = 自动选 bf16/fp16
load_in_4bit = True,
)
# 注入 LoRA
model = FastLanguageModel.get_peft_model(
model,
r = 32,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 64,
lora_dropout = 0.0, # 0 比非 0 快很多
bias = "none", # "none" 比 "all" 快
use_gradient_checkpointing = "unsloth", # 特殊值,用 Unsloth 自己的 checkpointing
random_state = 42,
use_rslora = False,
loftq_config = None,
)
dataset = load_dataset("json", data_files="train.jsonl", split="train")
def format_example(ex):
messages = ex["conversations"]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
return {"text": text}
dataset = dataset.map(format_example)
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
packing = True, # 开启样本 packing
args = SFTConfig(
per_device_train_batch_size = 4,
gradient_accumulation_steps = 4,
num_train_epochs = 3,
learning_rate = 2e-4,
warmup_ratio = 0.03,
lr_scheduler_type = "cosine",
bf16 = True,
logging_steps = 10,
save_steps = 500,
output_dir = "/checkpoints/llama8b-unsloth",
optim = "adamw_8bit",
weight_decay = 0.01,
report_to = "none",
seed = 42,
),
)
trainer.train()
# 保存 LoRA
model.save_pretrained("/checkpoints/llama8b-unsloth/lora")
tokenizer.save_pretrained("/checkpoints/llama8b-unsloth/lora")
几个 Unsloth 专属的点:
FastLanguageModel.from_pretrained:替代 HF 的AutoModelForCausalLM,返回 patch 过的模型model_name前缀是unsloth/...:这些是 Unsloth 官方提前做好的 4bit 权重,加载更快,也可以用普通 HF 路径use_gradient_checkpointing = "unsloth":特殊字符串,启用 Unsloth 版本的 checkpointing,比 PyTorch 原生省更多显存optim = "adamw_8bit":8bit AdamW,优化器状态也压缩,进一步省显存packing = True:把多个短样本拼成一个 max_seq_length,提升显存利用率
5.2 DPO 示例#
from unsloth import FastLanguageModel, PatchDPOTrainer
PatchDPOTrainer() # 必须在 DPOTrainer 之前调用
from trl import DPOTrainer, DPOConfig
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "/checkpoints/llama8b-sft-merged",
max_seq_length = 4096,
load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha = 32,
lora_dropout = 0.0,
bias = "none",
use_gradient_checkpointing = "unsloth",
)
dpo_trainer = DPOTrainer(
model = model,
ref_model = None, # Unsloth 自动处理 ref
tokenizer = tokenizer,
train_dataset = dpo_dataset,
args = DPOConfig(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
num_train_epochs = 2,
learning_rate = 5e-6,
lr_scheduler_type = "cosine",
warmup_ratio = 0.1,
bf16 = True,
beta = 0.1,
loss_type = "sigmoid",
max_length = 4096,
max_prompt_length = 2048,
output_dir = "/checkpoints/llama8b-dpo",
),
)
dpo_trainer.train()
PatchDPOTrainer() 必须在 DPOTrainer 导入/使用前调用,这是 Unsloth 的 monkey patch 机制——它要在 TRL 的类上打补丁把关键 kernel 替换掉。
六、和 LLaMA Factory 的组合用法#
LLaMA Factory 0.8+ 集成了 Unsloth 路径,YAML 配置加一行就行:
### model
model_name_or_path: /models/Llama-3.1-8B-Instruct
use_unsloth: true
### method
stage: sft
finetuning_type: lora
lora_target: all
lora_rank: 32
lora_alpha: 64
### dataset
dataset: my_sft_data
template: llama3
cutoff_len: 4096
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 2e-4
num_train_epochs: 3
bf16: true
注意:
use_unsloth: true和deepspeed互斥,Unsloth 的多机支持弱use_unsloth: true和quantization_bit: 4同时生效时走 Unsloth 的 4bit 路径- 某些模型 + Unsloth 的组合不稳定,LLaMA Factory 里 WebUI 会有兼容性提示
我的日常做法:单卡任务必开 use_unsloth,多卡 DDP 不开。
七、显存和速度的经验数据#
下面这张表是我在 24GB/48GB/80GB 三档显存上测过的大致范围(bf16 + 4bit + packing):
| 模型 | 方案 | 24GB 能跑 | 48GB 能跑 | 80GB 能跑 |
|---|---|---|---|---|
| LLaMA 8B | LoRA bs=4 len=4096 | ✓ | ✓ | ✓ |
| LLaMA 8B | LoRA bs=8 len=4096 | 紧 | ✓ | ✓ |
| LLaMA 8B | LoRA bs=4 len=8192 | ✓ | ✓ | ✓ |
| Qwen 14B | QLoRA bs=2 len=4096 | ✓ | ✓ | ✓ |
| Qwen 14B | LoRA bs=2 len=4096 | ✗ | ✓ | ✓ |
| LLaMA 32B | QLoRA bs=1 len=4096 | 紧 | ✓ | ✓ |
| LLaMA 70B | QLoRA bs=1 len=2048 | ✗ | ✗ | ✓ |
单卡 24GB 能跑 14B QLoRA 是 Unsloth 最让人惊艳的点——用 HF + peft + bitsandbytes 直接 OOM。
八、合并与导出#
Unsloth 提供了方便的合并导出方法:
# 保存 16bit 合并后模型(用于 vLLM/SGLang 推理)
model.save_pretrained_merged(
"/models/llama8b-biz-merged",
tokenizer,
save_method = "merged_16bit",
)
# 只保存 LoRA
model.save_pretrained("/checkpoints/llama8b-lora")
# 保存到 GGUF(llama.cpp)
model.save_pretrained_gguf(
"/models/llama8b-biz-gguf",
tokenizer,
quantization_method = "q4_k_m", # 或 q5_k_m, q8_0, f16
)
save_method 常用值:
merged_16bit:合并后保存为 bf16/fp16merged_4bit:合并后保存为 4bit(适合部署在显存紧张的推理节点)lora:只保存 adaptermerged_4bit_forced:强制 4bit(某些模型默认不许)
GGUF 导出功能是 Unsloth 的一个大杀器——训完直接生成 llama.cpp 可以吃的格式,配合树莓派/Mac 本地部署非常丝滑。
九、调优 tips#
9.1 packing 开不开#
- 短样本多、长度差异大:开,显存利用率提升明显
- 样本长度已经接近 max_seq_length:开不开差不多
- 对序列内部位置很敏感的任务:关(packing 会把多个样本拼在一起,虽然有 attention mask 但个别模型会受影响)
9.2 lora_dropout 是不是该开#
Unsloth 明确说 lora_dropout=0 速度最快,因为非零 dropout 会走额外的 kernel。经验上数据量大(>20k)时 dropout=0 没问题;数据量小(<5k)且训多 epoch 开 0.05-0.1 防过拟合。
9.3 optim 选哪个#
adamw_8bit:bitsandbytes 的 8bit AdamW,省显存adamw_torch:PyTorch 原生paged_adamw_8bit:在显存紧张时把优化器状态 paged 到 CPU
默认 adamw_8bit,OOM 时换 paged_adamw_8bit。
9.4 gradient_accumulation#
Unsloth 的融合 kernel 对大 accumulation 也友好。显存不够就减 batch + 增 accumulation,保持有效 batch 不变。
十、踩坑合集#
坑 1:Unsloth 和 HF Transformers 版本冲突#
Unsloth 依赖特定 transformers 版本,升级 transformers 可能导致 monkey patch 失效。解法:
- 创建独立 conda env,不要和其他项目共享
pip install时指定 transformers 版本上限- 遇到报错第一反应是降级 transformers
坑 2:模型加载时某些 key 不匹配#
如果你要用的模型不在 Unsloth 官方预转换的 4bit 列表里,从 HF 原始仓库加载时偶尔会遇到 key 不匹配报错。解法:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "meta-llama/Llama-3.1-8B",
max_seq_length = 4096,
load_in_4bit = True,
device_map = "auto",
)
加 device_map="auto" 有时候能绕过。不行的话只能等 Unsloth 升级支持。
坑 3:不支持的 LoRA target#
Unsloth 对 lora_target 只支持常规的 7 个线性层。自定义的 target(比如 embedding、lm_head)用不了或退化。
坑 4:多卡 DDP 不稳定#
Unsloth 对多卡的支持长期处于"能跑但偶尔崩"状态。典型症状是训练中途 NCCL hang 或 loss 突然爆炸。多卡建议用 LLaMA Factory 默认路径(不开 Unsloth)+ DeepSpeed ZeRO。
坑 5:gradient_checkpointing 模式#
use_gradient_checkpointing = "unsloth" 是 Unsloth 的专属值,比 HuggingFace 的 True 更省显存但对某些模型有兼容性问题。遇到怪异崩溃时可以改回 True 试试。
坑 6:tokenizer 的 chat_template#
Unsloth 的 apply_chat_template 用的是 tokenizer 自带的,如果你的 tokenizer 没设置(比如一些 base 模型而不是 instruct),apply 会报错。解法:手动设一个 template,或者用 unsloth.chat_templates 里预设的。
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
坑 7:RTX 40 系 flash-attn 版本#
40 系 GPU 和某些 flash-attn 版本的 wgmma 代码路径不兼容,报 unknown architecture 之类的错。解法:装最新 flash-attn 或 pip install flash-attn --no-build-isolation。
坑 8:导出 GGUF 时调用 llama.cpp 失败#
GGUF 导出底层调用 llama.cpp/convert.py,需要系统里装有 llama.cpp 仓库。Unsloth 会尝试自动 clone,但有时网络问题失败。提前手动 clone:
git clone https://github.com/ggerganov/llama.cpp.git
cd llama.cpp && make
然后把路径告诉 Unsloth:
model.save_pretrained_gguf(
"/models/llama8b-gguf",
tokenizer,
quantization_method = "q4_k_m",
# save_method will pick up llama.cpp from PATH
)
坑 9:batch_size 太大静默退化#
有时候你设了 per_device_train_batch_size=8 但内部因为某个形状不匹配,Unsloth 偷偷降 batch size,表现是速度没提升显存也没涨。看日志第一行确认 actual batch size。
坑 10:LoRA 保存后 vLLM 加载失败#
Unsloth 保存的 LoRA 目录少了某些文件(比如 adapter_config.json 里的一个字段)让 vLLM 加载失败。解法:用 save_pretrained_merged 合并后保存完整模型再部署,别用动态 adapter。
十一、什么时候不该用 Unsloth#
Unsloth 的加速很诱人,但不是万能药。下面这些场景我会不用 Unsloth:
- 多机多卡训练:Unsloth 不是为多机设计的,跑起来不稳定
- 全参数 SFT:Unsloth 的收益主要在 LoRA / QLoRA 路径,全参几乎没差
- 非主流模型架构:支持列表之外的模型,退化到通用路径没意义
- 需要自定义训练 loop:Unsloth 的 monkey patch 假设你用 HF Trainer / TRL,自己写 loop 容易踩坑
- 生产化 CI/CD:Unsloth 版本更新快,API 偶尔 break,CI 里锁版本维护成本不低
最适合 Unsloth 的场景:
- 单卡 LoRA / QLoRA SFT / DPO
- 研究型快速实验
- 个人开发者、小团队
- 需要低门槛导出 GGUF 本地运行
十二、和 LLaMA Factory/Axolotl 的组合建议#
我日常的栈:
- 实验阶段:Unsloth 原生脚本,单卡 Jupyter 里快速试
- 训练主流程:LLaMA Factory +
use_unsloth: true,YAML 驱动可复现 - 多卡大任务:LLaMA Factory(不开 Unsloth)+ DeepSpeed ZeRO-2
- 导出 GGUF 给本地:Unsloth 的
save_pretrained_gguf
三者不是替代关系是组合关系。Unsloth 提供底层 kernel,LLaMA Factory 提供工作流,TRL 提供算法。最佳组合是三个都懂,按场景切换。
十三、一个实际例子:3090 训 Qwen 14B#
一个完整配置,单卡 RTX 3090 24GB 训 Qwen 14B QLoRA:
model_name_or_path: /models/Qwen2.5-14B-Instruct
use_unsloth: true
quantization_bit: 4
quantization_type: nf4
stage: sft
finetuning_type: lora
lora_target: all
lora_rank: 16
lora_alpha: 32
lora_dropout: 0.0
use_gradient_checkpointing: unsloth
dataset: my_sft
template: qwen
cutoff_len: 2048
max_samples: 15000
preprocessing_num_workers: 8
packing: true
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 2e-4
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.05
bf16: true
optim: adamw_8bit
weight_decay: 0.01
logging_steps: 10
save_steps: 500
save_total_limit: 3
output_dir: /checkpoints/qwen14b-biz-sft
实测 3090 上:
- 显存峰值 约 20GB
- 15000 条样本 × 3 epochs × cutoff 2048
- 训练时间 5-7 小时(具体取决于数据 packing 效率)
用原生 HF + peft 同样配置根本跑不起来(OOM)。
十四、上线 checklist#
[ ] conda env 独立,依赖版本锁定
[ ] GPU arch 和 pip install 参数匹配
[ ] use_gradient_checkpointing="unsloth"
[ ] lora_dropout=0, bias="none"(除非有特殊需求)
[ ] packing=True(短样本场景)
[ ] optim=adamw_8bit
[ ] 导出阶段用 merged_16bit 做推理,不用 adapter 动态挂
[ ] 用 vLLM/SGLang 跑 smoke test 确认合并模型能正常加载生成
[ ] eval 集跑过确认无退化
[ ] 训练 log / config / commit hash 归档
[ ] 如果用 LLaMA Factory,transformers 和 unsloth 版本组合测过
十五、收尾#
Unsloth 是那种用过就不想走回头路的工具——前提是你的场景对口:单卡 LoRA。它本来就不是通用训练框架,而是单卡 LoRA 的极致加速器。想清楚这个定位,别指望它做多卡大集群训练。
我自己的组合拳是:单卡试错 Unsloth,多卡生产 LLaMA Factory。两个一起用,95% 的微调场景够了。






