LLM 模型大小估算与硬件选择(LLM Hardware and Model Size) #
模型大小估算 #
Note:什么是字节(Byte)?
字节(Byte) 是计算机中数据存储的基本单位之一。一个字节等于 8 位(bit)。字节用于表示一个字符或者基本的数据单元,在不同的数据类型中,它代表了数据的存储容量。
- FP32(32-bit Floating Point):标准的浮点数表示方法,广泛用于训练深度学习模型,尤其是在需要高精度的计算中。占用字节数:4 字节。
- FP16(16-bit Floating Point):通常用于加速推理和减少内存占用,特别是在现代 GPU(如 NVIDIA 的 Volta 和 Ampere 架构)中,FP16 被广泛应用。占用字节数:2 字节。
- INT8(8-bit Integer):常用于量化模型,将浮点数转换为 8 位整数,以减少内存和计算需求,特别是在边缘设备或移动设备上推理时。占用字节数:1 字节。
- Char(Character):通常用于表示 ASCII 或其他单字节字符编码。占用字节数:1 字节。
Note:字节(Byte)、MB(兆字节)和GB(千兆字节)之间的关系如下
- 1 字节 (Byte) 是计算机存储数据的基本单位。
- 1 千字节 (Kilobyte, KB) = 1,024 Byte
- 1 兆字节 (Megabyte, MB) = 1,024 KB = 1,024 × 1,024 Byte
- 1 千兆字节 (Gigabyte, GB) = 1,024 MB = 1,024 × 1,024 × 1,024 Byte
- 1 太字节 (Terabyte, TB) = 1,024 GB = 1,024 × 1,024 × 1,024 × 1,024 Byte
Note:假设一个 LLM 说它是 7B(如 LLaMA-7B),表示它有 7 Billion(70 亿, 7x10^9) 个参数。如果以 FP32 精度储存,则推理(inference)所需要的内存大致为:
\[ 7 \times 10^9 \times 4 Bytes \approx 26 GB \approx 7B \times 4 Bytes = 28 GB \]
训练阶段 #
在估算大型语言模型(LLM)训练阶段的模型大小时,需要综合考虑以下因素及其物理意义:
模型参数(Parameters):参数是模型的核心组成部分,直接影响模型的容量和内存占用。LLM的参数主要集中在:
- Transformer层的权重矩阵(自注意力层、前馈网络)
- 词嵌入矩阵(输入/输出嵌入)
对于包含 \(N\) 个Transformer层的模型,参数总量可近似为:
\[ Params≈12Nd_{model}^2+Vd_{model} \]- \(d_{model}\) :隐藏层维度(如4096)
- \(V\) :(如50,000)
- 示例:GPT-3( \(N=96, d_{model}=12288, V=50257)\) : \[ Params=12×96×12288^2+50257×12288≈175亿 \] 若参数用FP16存储(2字节/参数),则175B参数的模型需: \[ 175×10^9×2字节=350 GB \]
Note:像 LLaMA-7B 这样的标注通常指的是 模型的参数量,也就是 训练好的权重(weights)和偏置(biases) 的总数量。它不包括 梯度(gradients)、优化器状态(optimizer states)或中间激活值(activations)等训练过程中需要的额外存储。
梯度(Gradients):反向传播时需保存每个参数的梯度,其大小与参数数量一致。
优化器状态(Optimizer States):优化器(如Adam)需额外保存动量(momentum)和方差(variance),显存占用通常是参数的数倍。并且一般采用精度较高的 FP32 储存。
- 例如,Adam优化器的显存需求: \[ 优化器状态=Params×(2×4字节)(FP32存储m和v) \]
中间激活值(Activations):在 前向传播 和 反向传播 阶段,网络每一层的 激活值(即每一层的输出) 都需要存储。尤其是对于大规模模型,中间激活值 会占据大量的显存。
- 内存需求:中间激活值的内存需求取决于模型的 输入数据大小(例如 batch size) 和每一层的输出维度。每一层的输出通常是 矩阵或张量,这些需要存储在内存中,直到反向传播结束。
- 影响因素:影响中间激活值大小的因素有 批大小(batch size) 和 每层的激活维度(例如 Transformer 模型的 \(d_model\) 大小)。
总训练显存估算可以总结为: \[ 总显存=参数+梯度+优化器状态+激活值+其他 \]
Note:假设我们有一个 7B参数的模型,使用 FP32存储参数,训练时使用 Adam优化器(假设使用了动量和平方梯度),批大小为 32,每层输出维度为 d_model=4096。
\[ 7B \times 4 \text{字节} = 28 \text{GB} \]
- 参数存储:
\[ 7B \times 4 \text{字节} = 28 \text{GB} \]
- 梯度存储:
\[ 7B \times 2 \times 4 \text{字节} = 56 \text{GB} \]
- 优化器状态(假设Adam优化器):
\[ 48 \times 32 \times 2048 \times 4096 \times 4 \text{字节} \approx 1.4 TB \]
- 激活存储(假设每层大小是 [batch_size, seq_len, d_model],假设有 48 层和 batch_size = 32,seq_len = 2048):
总内存需求(仅计算主要内存需求,不考虑中间优化等):
\[ 28 \text{GB} + 28 \text{GB} + 56 \text{GB} + 1.4 \text{TB} = 1.48 \text{TB} \]
Note:Checkpoint 通常保存以下关键信息,以便后续恢复训练或进行推理(Inference)。
\[ Checkpoint 大小 ≈ 模型参数 + 优化器状态 + 额外训练信息 \]例如:
checkpoint = { "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "epoch": epoch } torch.save(checkpoint, "checkpoint.pth")
推理阶段 #
在 LLM inference(推理) 阶段,估算 模型大小和显存需求 时,主要考虑 参数(parameters)、参数类型(precision)、梯度(gradients)、优化器状态(optimizer state)、中间激活值(activations)。
- 模型参数(Parameters):参数(weights)是模型的核心部分,它们在推理阶段不更新,仅用于计算。
- 中间激活值(Activations):在推理(inference)过程中,激活值(activations)只需要保留当前计算所需的部分,而不需要像训练时那样保存所有层的激活值。这是因为推理阶段不需要进行反向传播(backpropagation),所以不需要存储完整的计算图和梯度信息。
- KV-Cache(Key-Value Cache)
总推理显存估算可以总结为: \[ 总显存=参数+激活值+KQ Cache+其他 \]
例如:
部分 | 计算公式 | 显存需求 |
---|---|---|
参数存储 | 7B × 2B | 14GB |
激活值 | 2048 × 4096 × 2B | 16MB |
KV-Cache | 2 × 2048 × 32 × 128 × 2B | 16MB |
其他开销 | OS + CUDA 预留 | 2GB |
合计 | — | 16GB(接近 RTX 4090 限制) |
硬件选择 #
硬件基础知识
组件 作用 对 LLM 的影响 CPU(中央处理器) 处理系统任务,调度 GPU 计算 影响数据加载、预处理速度 GPU(图形处理器) 执行并行计算,加速矩阵运算 影响 LLM 训练 & 推理速度 TPU(张量处理单元) Google 专用 AI 计算芯片,比 GPU 更快 用于 Google Cloud LLM 训练 RAM(内存) 存储临时数据(不等于显存) 数据加载、训练时数据预处理 VRAM(显存) GPU 的专用内存,存放 LLM 参数 影响 LLM 最大可运行模型大小 存储(SSD/HDD) 存放 LLM 训练数据、模型参数 训练时的 I/O 速度 带宽(PCIe/NVLink) 设备间数据传输速度 影响分布式训练效率 LLM 的 训练 和 推理(inference) 需要不同的硬件资源:
对比项 训练 推理 计算量(FLOPs) 极高,多 GPU 并行训练 较低,单 GPU 可完成 显存需求 参数 + 梯度存储 + Adam 优化器 仅参数存储 + 计算缓存 适用 GPU A100, H100, TPU 4090, A100 推荐存储 高速 NVMe SSD(>8TB) 适量存储(<2TB) 推荐 CPU 高核心数(AMD EPYC 64 核) 普通 CPU(i9, Ryzen 9) 云计算 vs. 本地部署
选项 优点 缺点 云计算(AWS, Google Cloud) 高性能(H100, TPU) 成本高,GPU 计费 本地服务器(A100, 4090) 长期成本低 初期购置费用高 边缘 AI(M2, Jetson) 低功耗,适合移动端 只能运行小型模型