Native JAX/Flax Fine-Tuning (Official Path)

Updated on 4/7/2026

[Translation Pending]\n\n# 第 10 期 | JAX/Flax 原生微调(官方推荐路线) 副标题:使用 Google 官方 gemma PyPI 包进行微调,JAX 生态优势、TPU 训练配置、与 PyTorch 路线的对比选择。

🎯 学习目标

  • 深入理解 Google JAX/Flax 生态系统在大型语言模型(LLM)微调中的独特优势和应用场景。
  • 掌握利用 Google DeepMind 官方 gemma PyPI 包,在 JAX/Flax 框架下对 Gemma 模型进行参数高效微调(PEFT)的流程。
  • 了解如何配置和利用 Google Cloud TPU 资源,为 Gemma 模型提供高性能的训练环境。
  • 对比 JAX/Flax 与 PyTorch 在 Gemma 微调路线上的异同,并根据实际需求做出明智的技术选型。

📖 核心概念讲解

10.1 JAX/Flax 生态在 LLM 训练中的优势

JAX 是 Google DeepMind 开发的一个高性能数值计算库,它结合了 NumPy 的 API、自动微分功能以及 XLA(Accelerated Linear Algebra)编译器的强大性能。Flax 是构建在 JAX 之上的神经网络库,提供了模块化的组件和函数式编程范式,非常适合大规模模型训练。对于 Gemma 这样的先进 LLM,JAX/Flax 提供了以下显著优势:

  1. 极致性能与 XLA 编译器: JAX 的核心在于 XLA 编译器。它能够将 JAX 代码即时(JIT)编译成针对特定硬件(如 CPU、GPU、TPU)高度优化的机器码。这意味着模型训练和推理的计算图在运行前就被优化,避免了 Python 解释器的开销,实现了接近原生 C++ 的性能。
  2. 原生 TPU 支持与优化: JAX 是 Google 内部为 TPU 设计和优化的框架。TPU(Tensor Processing Unit)是 Google 专门为机器学习工作负载设计的 ASIC。JAX 与 TPU 的结合提供了无与伦比的训练速度和效率,尤其是在处理大规模矩阵运算和并行计算时。它能够充分利用 TPU 的高带宽内存和高吞吐量计算能力。
  3. 函数式编程范式: JAX 鼓励纯函数式编程。模型定义、参数更新、数据处理等都被封装为纯函数,这使得代码更加模块化、可测试,并且天然支持并行化和分布式计算,减少了副作用和状态管理的复杂性。
  4. SPMD (Single Program, Multiple Data) 分布式训练: JAX 通过 jax.pmap(parallel map)提供了一种简洁高效的 SPMD 编程模型。开发者只需编写一份代码,JAX 就能自动在多个设备(如多个 TPU 核心)上复制模型和数据,并协调计算,极大地简化了大规模分布式训练的实现。
  5. 精细的内存控制: JAX 允许开发者对内存布局和操作进行更精细的控制,这对于训练拥有数十亿甚至上千亿参数的 LLM 至关重要。在 TPU 上,JAX 能够更好地管理 HBM(High Bandwidth Memory),从而在相同硬件条件下支持更大的模型或批次大小。
  6. 梯度变换与自动微分: JAX 的 jax.grad 提供了灵活的自动微分功能,可以轻松地计算任意阶导数,并支持复杂的微分规则。这对于实现各种优化器和高级训练技术(如混合精度训练、梯度累积)非常方便。
  7. 官方推荐与最佳兼容性: Gemma 是 Google DeepMind 的产物,其原生实现和官方推荐的微调路线就是基于 JAX/Flax。这意味着使用 JAX 进行 Gemma 微调能够获得最佳的兼容性、最新的特性支持以及最直接的性能优化。

10.2 Gemma 官方 gemma PyPI 包介绍

Google DeepMind 针对 Gemma 模型发布了一个官方的 PyPI 包,名为 gemma。这个库是专门为 Gemma 模型在 JAX/Flax 生态中提供便利的工具集,它与 Hugging Face transformers 库的 JAX/Flax 后端有所不同,是 Google 团队为 Gemma 模型量身定制并持续优化的。

gemma 包的核心功能包括:

  • 模型架构实现: 提供了 gemma.model.Gemma 等类,实现了 Gemma 模型的完整 JAX/Flax 架构,包括注意力机制、前馈网络、归一化层等。
  • 配置管理: 包含 gemma.config.GemmaConfig,用于定义模型的各种超参数,如模型尺寸、层数、隐藏层维度、注意力头数等。它还支持通过配置来启用或禁用 LoRA 等微调功能。
  • 分词器(Tokenizer): 提供了 gemma.tokenizer.GemmaTokenizer,用于处理文本的编码和解码,与 Gemma 模型预训练时使用的分词器保持一致。
  • 采样器(Sampler): 包含 gemma.sampler.SampleFn 等,用于在推理阶段根据模型输出的概率分布生成文本。
  • 权重加载与管理: 提供了便捷的函数来加载官方发布的 Gemma 模型权重,并支持将权重映射到 JAX/Flax 模型结构中。
  • 微调工具: 集成了对 LoRA 等参数高效微调方法的支持,允许用户在不修改核心模型结构的情况下,通过配置启用微调。

使用 gemma 包进行微调的优势在于,它提供了与 Gemma 模型原生设计最匹配的实现,能够确保在 JAX/Flax 和 TPU 环境下获得最佳性能和兼容性。

10.3 Gemma 微调策略:JAX/Flax 中的 LoRA

LoRA (Low-Rank Adaptation of Large Language Models) 是一种高效的参数微调(PEFT)技术,它通过在预训练模型的线性层旁插入小型的低秩矩阵来训练,从而显著减少了需要更新的参数量,降低了计算和内存开销,同时保持了良好的性能。在 JAX/Flax 中实现 LoRA 的核心思想是将 LoRA 模块无缝集成到现有的 Flax nn.Module 中。

LoRA 原理简述: 对于一个预训练的权重矩阵 $W_0 \in \mathbb{R}^{d \times k}$,LoRA 在其更新时不是直接修改 $W_0$,而是引入两个低秩矩阵 $A \in \mathbb{R}^{d \times r}$ 和 $B \in \mathbb{R}^{r \times k}$,其中 $r \ll \min(d, k)$。微调过程中,我们只训练 $A$ 和 $B$,而 $W_0$ 保持不变。更新后的权重矩阵变为 $W_0 + BA$,其中 $BA$ 的秩为 $r$。

JAX/Flax 中的 LoRA 实现考量:

  1. 模块化设计: 在 Flax 中,LoRA 可以被实现为一个自定义的 nn.Module,它接收一个原始的线性层作为输入,并在其上添加 LoRA 适配器。

  2. 参数管理: Flax 的 param_initapply 方法使得 LoRA 适配器的新增参数能够与原始模型参数分离管理。在训练时,可以只更新 LoRA 模块的参数。

  3. 集成到 Gemma 模型: gemma PyPI 包很可能已经在 gemma.model.Gemma 的内部实现了 LoRA 适配器的集成,通过 GemmaConfig 中的配置项(如 lora_ranklora_target_modules)来启用和控制 LoRA 行为。 例如,在 Gemma 的 GemmaConfig 中可能包含类似如下的配置:

    from gemma import config as gemma_config
    
    cfg = gemma_config.GemmaConfig(...)
    cfg.lora_rank = 8  # 设置 LoRA 的秩
    cfg.lora_target_modules = ['q_proj', 'v_proj'] # 指定哪些线性层应用 LoRA
    # ... 其他配置
    

    当模型使用此配置初始化时,Flax 内部会在指定的层中自动添加 LoRA 模块。

  4. 梯度计算与优化: JAX 的自动微分机制会正确处理包含 LoRA 适配器的计算图。在 jax.value_and_grad 计算梯度时,只会针对那些设置为可训练的 LoRA 参数计算梯度,然后通过 optax 优化器进行更新。

通过 LoRA,我们可以在保持高效训练的同时,显著减少微调所需的计算资源和存储空间,使得在有限的硬件条件下也能对大型 Gemma 模型进行个性化适配。

10.4 TPU 训练配置与分布式策略

Google Cloud TPU 是进行大规模 JAX/Flax 训练的理想硬件。本节将介绍如何在 Google Cloud 上配置 TPU 环境,并了解 JAX 在 TPU 上的分布式训练策略。

10.4.1 Google Cloud TPU 环境设置

  1. Google Cloud 项目与配额: 确保你有一个有效的 Google Cloud 项目,并且拥有足够的 TPU 配额(例如 tpu-v4-8tpu-v5e-8)。
  2. 安装 Google Cloud SDK: 在本地机器上安装 gcloud 命令行工具,并进行认证:
    gcloud init
    gcloud auth login
    gcloud config set project your-gcp-project-id
    
  3. 创建 TPU VM 实例: TPU VM 是一个运行在 Google Cloud 上的虚拟机,它直接连接到 TPU 芯片。
    export ZONE="us-central1-a" # 选择一个支持 TPU 的区域
    export TPU_TYPE="tpu-v4-8" # 例如 tpu-v4-8, tpu-v5e-8
    export VM_NAME="gemma-tpu-vm"
    export ACCELERATOR_TYPE="v4-8" # 对应 TPU_TYPE
    export RUNTIME_VERSION="tpu-vm-tf-2.15.0-pjrt" # 选择一个合适的运行时版本,通常是pjrt版本
    
    gcloud compute tpus tpu-vm create ${VM_NAME} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION} \
        --project=${YOUR_GCP_PROJECT_ID}
    
    这将创建一个带有 8 个 TPU v4 核心的 VM。
  4. SSH 连接到 TPU VM:
    gcloud compute tpus tpu-vm ssh ${VM_NAME} --zone=${ZONE}
    
    进入 VM 后,可以像操作普通 Linux 服务器一样安装软件包和运行代码。

10.4.2 JAX 在 TPU 上的分布式策略 (SPMD)

JAX 采用 SPMD(Single Program, Multiple Data)范式进行分布式训练。这意味着你编写一份代码,这份代码会在每个设备(如 TPU 核心)上运行,但处理不同的数据分片。

核心概念:

  • jax.devices() 获取当前可用的所有 JAX 设备(例如 TPU 核心)。
  • jax.local_devices() 获取当前机器上的本地设备。
  • jax.device_count() 获取设备总数。
  • jax.pmap JAX 中用于 SPMD 并行计算的核心原语。它将一个函数映射到所有设备上并行执行。pmap 会自动处理数据的分发和结果的聚合。
    • 当一个函数被 jax.pmap 装饰时,它会在每个设备上独立运行。
    • 函数的输入参数会被沿第一个维度自动分片(默认行为),发送到每个设备。
    • 函数内部的计算会在每个设备上并行进行。
    • axis_name 参数用于在 pmap 内部进行跨设备通信(例如 jax.lax.pmean 用于计算全局平均)。
  • 数据并行: 这是最常见的分布式训练方式。每个设备拥有模型的完整副本,但处理不同批次的数据。梯度在设备之间进行平均,然后用于更新模型参数。JAX 的 pmap 结合 jax.lax.pmean 可以高效实现数据并行。
  • 模型并行: 对于超大型模型,模型的不同部分可以分布在不同的设备上。这通常更复杂,需要仔细设计模型架构和通信模式。JAX 提供了 shard_map 和更底层的 SPMD primitives 来支持模型并行。
  • 参数分片 (Sharding): JAX XLA 编译器能够自动对模型参数进行分片,以适应设备的内存限制。通过 jax.experimental.mesh_utils.create_device_meshjax.sharding.NamedSharding 可以更精细地控制参数和数据在设备网格上的分布。

SPMD 示例结构:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec as P

# 获取设备
devices = jax.devices()
num_devices = len(devices)
print(f"Number of JAX devices: {num_devices}")

# 创建设备网格 (用于更复杂的 sharding)
mesh = mesh_utils.create_device_mesh((num_devices,))
# 定义 sharding 规则
params_sharding = NamedSharding(mesh, P()) # 所有设备都有完整参数副本 (数据并行)
data_sharding = NamedSharding(mesh, P('batch')) # 批次维度在设备间分片

# 示例:一个简单的训练步骤
def train_step(params, batch, opt_state):
    def loss_fn(params):
        # 假设这是一个简单的模型和损失计算
        logits = jnp.dot(batch['inputs'], params['w']) + params['b']
        loss = jnp.mean((logits - batch['labels'])**2)
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # 跨设备平均梯度 (数据并行)
    grads = jax.lax.pmean(grads, axis_name='devices')

    # 更新参数 (使用 optax)
    # updates, new_opt_state = optimizer.update(grads, opt_state, params)
    # new_params = optax.apply_updates(params, updates)

    return loss #, new_params, new_opt_state

# 使用 jax.pmap 编译训练步骤
# 'devices' 是 axis_name,用于 pmean 等操作
p_train_step = jax.pmap(train_step, axis_name='devices')

# 假设初始化参数和优化器状态,并准备数据
# params = ...
# opt_state = ...
# batched_data = jax.tree_map(lambda x: x.reshape(num_devices, -1, *x.shape[1:]), data) # 将数据分片到每个设备

# 运行训练
# losses = p_train_step(params, batched_data, opt_state)
# print(f"Losses on each device: {losses}")

在 Gemma 微调中,gemma 包会利用这些 JAX 原语在后台处理分布式训练的细节,开发者通常只需关注模型、数据和训练循环的逻辑。

10.5 JAX/Flax 与 PyTorch 路线的对比选择

在 Gemma 微调中,开发者面临 JAX/Flax 和 PyTorch 两种主要的技术栈选择。两者各有优劣,选择哪种取决于具体的需求、团队经验和硬件环境。

特性/框架 JAX/Flax PyTorch
硬件亲和性 原生 TPU 优化,GPU 支持良好 原生 GPU 优化,TPU 支持通过 PyTorch/XLA
性能 XLA 编译器带来极致性能,尤其在 TPU 上 性能优异,但通常略逊于 JAX+TPU 的组合
编程范式 函数式编程,不可变状态,纯函数 面向对象编程,更直观,状态可变
学习曲线 相对较陡峭,需要适应函数式编程和 JIT 编译 相对平缓,与传统 Python 开发模式更接近
生态系统 相对较小,但在 Google 内部和研究领域强大 庞大且成熟,拥有丰富的库、工具和社区
调试 纯函数和 JIT 编译有时使调试复杂,需借助 pdbprint 调试技巧 动态图易于调试,有丰富的调试工具
分布式训练 jax.pmap (SPMD) 简洁高效,TPU 优化 torch.distributed 灵活强大,支持多种后端
内存管理 精细控制,高效利用 TPU HBM 自动管理,易于使用,但有时不如 JAX 精细
Gemma 官方支持 原生和推荐,有官方 gemma PyPI 包 通过 Hugging Face transformers 库支持
模型部署 通过 JAX/Flax 或 ONNX/TF SavedModel 导出 广泛支持 ONNX、TorchScript、TensorRT 等
适用场景 追求极致性能、大规模分布式训练、TPU 优先;研究和前沿模型开发 快速迭代、广泛工具链、GPU 优先;工业界应用和生产环境

选择建议:

  • 如果目标是最大化 Gemma 在 TPU 上的训练效率和性能,并且愿意投入学习 JAX 的函数式编程范式,那么 JAX/Flax 是官方推荐且最佳的选择。
  • 如果团队已经熟悉 PyTorch 生态,拥有大量 GPU 资源,或者更看重快速开发和丰富的社区支持,那么 PyTorch 路线(通过 Hugging Face transformers 库)会是更便捷的选择。

本教程专注于官方推荐的 JAX/Flax 路线,旨在帮助读者充分利用 Gemma 的原生优势。


💻 实战演示

本实战演示将指导您在 Google Cloud TPU VM 上,使用官方 gemma PyPI 包对 Gemma 模型进行 LoRA 微调。

前提条件:

  1. 一个有效的 Google Cloud 项目,并已设置好计费。
  2. 拥有足够的 TPU 配额(例如 tpu-v4-8tpu-v5e-8)。
  3. 已在本地安装 Google Cloud SDK 并完成认证。
  4. 已通过 Kaggle 认证并接受 Gemma 模型的条款,以便下载模型权重。

场景一:环境配置与 Gemma 官方库安装

首先,我们需要在 Google Cloud 上创建一个 TPU VM,并安装必要的软件。

  1. 创建 TPU VM 实例并 SSH 连接: 在本地终端执行以下命令。请替换 YOUR_GCP_PROJECT_ID 为您的项目 ID。

    # 1. 设置环境变量
    export ZONE="us-central1-a" # 选择一个支持 TPU 的区域,例如 us-central1-a, europe-west4-a
    export TPU_TYPE="tpu-v4-8" # 例如 tpu-v4-8, tpu-v5e-8
    export VM_NAME="gemma-tpu-vm-lora"
    export ACCELERATOR_TYPE="v4-8" # 对应 TPU_TYPE
    export RUNTIME_VERSION="tpu-vm-tf-2.15.0-pjrt" # 推荐使用 pjrt 版本
    
    # 2. 创建 TPU VM
    echo "Creating TPU VM: ${VM_NAME} in zone ${ZONE} with accelerator type ${ACCELERATOR_TYPE}..."
    gcloud compute tpus tpu-vm create ${VM_NAME} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION} \
        --project=${YOUR_GCP_PROJECT_ID} \
        --provision-network-bandwidth=20gbps # 可选,更高带宽
    
    echo "TPU VM created. Connecting via SSH..."
    # 3. SSH 连接到 TPU VM
    gcloud compute tpus tpu-vm ssh ${VM_NAME} --zone=${ZONE}
    

    成功连接后,您将进入 TPU VM 的 shell 环境。

  2. 在 TPU VM 中安装依赖: 在 TPU VM 内部执行以下命令。

    # 1. 更新系统包
    sudo apt update && sudo apt upgrade -y
    
    # 2. 安装 Python 依赖
    # jax[tpu] 会安装 JAX 和针对 TPU 的后端
    # gemma 是官方库
    # ml_collections 用于配置管理
    # tqdm 用于显示进度条
    # einops 用于张量操作
    pip install --upgrade pip
    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip install gemma ml_collections tqdm einops
    
  3. 下载 Gemma 模型权重: Gemma 模型权重需要通过 Kaggle 认证才能下载。请确保您已在 Kaggle 上接受了 Gemma 的条款。 在本地机器上(不是 TPU VM),通过 Kaggle API 下载权重,然后上传到 TPU VM。

    • 本地操作:
      1. 安装 Kaggle API 客户端:pip install kaggle
      2. 从 Kaggle 网站生成 API Token (在个人资料页 -> Account -> API -> Create New API Token)。下载 kaggle.json 文件。
      3. kaggle.json 移动到 ~/.kaggle/ 目录。
      4. 下载 Gemma 模型(例如 gemma-2b-it):
        kaggle models download google/gemma/gemma-2b-it
        # 下载的模型文件会解压到当前目录下的 gemma-2b-it 文件夹
        
      5. 将下载的权重文件上传到 TPU VM。假设您想上传 gemma-2b-it 文件夹到 VM 的 /home/user/gemma_weights/ 目录:
        gcloud compute tpus tpu-vm scp --recurse gemma-2b-it ${VM_NAME}:~/gemma_weights/ --zone=${ZONE}
        
    • TPU VM 操作(验证): SSH 连接到 TPU VM 后,检查权重文件是否已存在:
      ls -l ~/gemma_weights/gemma-2b-it/
      # 应该能看到 model.safetensors, tokenizer.json 等文件
      

场景二:使用 gemma 库进行简单的 LoRA 微调

我们将编写一个 Python 脚本,使用 gemma 库加载 Gemma 模型,定义一个简单的指令微调数据集,并使用 LoRA 进行训练。

在 TPU VM 中创建一个名为 finetune_gemma_lora.py 的文件,并粘贴以下代码:

import os
import json
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec as P
import optax
import ml_collections
from tqdm import tqdm

from gemma import config as gemma_config
from gemma import model as gemma_model
from gemma import tokenizer as gemma_tokenizer
from gemma import sampler as gemma_sampler

# --- 1. 配置参数 ---
# 定义模型和训练配置
def get_config():
    config = ml_collections.ConfigDict()

    # Gemma 模型配置
    config.model_name = "gemma-2b-it" # 确保与下载的权重匹配
    config.model_path = os.path.expanduser(f"~/gemma_weights/{config.model_name}/")
    config.model_config_path = os.path.join(config.model_path, "config.json")
    config.model_weights_path = os.path.join(config.model_path, "model.safetensors")
    config.tokenizer_path = os.path.join(config.model_path, "tokenizer.json")

    # LoRA 微调配置
    config.lora_rank = 8
    config.lora_target_modules = ["q_proj", "v_proj"] # 注意力机制的 Q 和 V 投影层

    # 训练配置
    config.batch_size = 4 # 每个设备上的批次大小
    config.learning_rate = 1e-4
    config.num_epochs = 3
    config