HuggingFace peft LoRA 微调 LLaMA

news/2024/12/27 10:36:22 标签: llama, LoRA微调llama, peft

1. 安装必要库

pip install transformers peft accelerate

2. 加载 LLaMA 模型和分词器

Hugging Face Transformers 加载预训练的 LLaMA 模型和分词器。

from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载 LLaMA 模型和分词器
model_name = "meta-llama/Llama-2-7b-hf"  # 替换为适合的模型
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

# 设置 pad_token 为 eos_token(如果模型没有 pad_token)
tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))  # 调整词汇表大小

3. 配置 LoRA 微调

使用 PEFT 配置 LoRA 参数。

from peft import get_peft_model, LoraConfig, TaskType

# 定义 LoRA 配置
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # 因果语言模型任务
    inference_mode=False,
    r=8,  # LoRA 的秩
    lora_alpha=16,
    lora_dropout=0.05
)

# 将 LoRA 应用于模型
model = get_peft_model(model, lora_config)

# 检查模型被正确标记为 trainable
print(model)

4. 定义数据集加载器

使用自定义数据集加载器和 Hugging Face 提供的 DataCollator 进行批量处理。

数据集预处理流程及其代码如下链接:训练数据格式为<input,output>,为什么微调大模型时,模型所需的输入数据input_ids有时仅包含了input,而有时包含了input和output呢?-CSDN博客

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# 自定义数据集(之前定义的 FineTuneDataset)
dataset = FineTuneDataset(data_path="./train.jsonl", tokenizer=tokenizer, max_length=1024)

# 定义数据批处理器
data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True)

5. 配置 TrainingArguments

设置训练超参数,包括学习率、批次大小、保存频率等。

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./llama_lora_finetuned",   # 输出模型路径
    evaluation_strategy="steps",          # 每隔多少步进行验证
    save_strategy="steps",                # 保存检查点的策略
    logging_dir="./logs",                 # 日志文件路径
    per_device_train_batch_size=8,        # 每个设备的训练批次大小
    gradient_accumulation_steps=4,        # 梯度累积
    learning_rate=2e-4,                   # 学习率
    num_train_epochs=3,                   # 训练轮数
    save_steps=500,                       # 每隔多少步保存模型
    logging_steps=100,                    # 日志记录频率
    fp16=True,                            # 使用混合精度训练
    push_to_hub=False                     # 如果需要保存到 Hugging Face Hub
)

6. 定义模型和 Trainer

from transformers import Trainer

# 定义 Trainer
trainer = Trainer(
    model=model,                          # 微调的模型
    args=training_args,                   # 训练参数
    train_dataset=dataset,                # 训练数据集
    data_collator=data_collator,          # 数据批处理器
)

7. 启动训练

trainer.train()
trainer.save_model("./llama_lora_finetuned")
tokenizer.save_pretrained("./llama_lora_finetuned")


http://www.niftyadmin.cn/n/5801619.html

相关文章

Linux下编译 libwebsockets简介和使用示例

目录 1:简单介绍: 2:项目地址 3:编译 3.1:集成介绍 3.2:编译 4:客户端服务端示例: 4.1 客户端示例 4.2 服务端示例: 1:简单介绍: Linux下…

在HTML中使用Vue如何使用嵌套循环把集合中的对象集合中的对象元素取出来(我的意思是集合中还有一个集合那种)

在 Vue.js 中处理嵌套集合(即集合中的对象包含另一个集合)时,使用多重 v-for 指令来遍历这些层次结构。每个 v-for 指令可以用于迭代一个特定级别的数据集,并且可以在模板中嵌套多个 v-for 来访问更深层次的数据。 例如&#xff…

V-Ray 来到 Blender:为艺术家提供专业级渲染

Chaos 正式宣布将其行业领先的渲染引擎 V-Ray 集成到 Blender 中。这一备受期待的开发为 Blender 用户带来了专业级渲染功能,使他们能够直接在他们最喜欢的 3D 平台中制作令人惊叹的、逼真的图像和动画。 渲染 强大的可缩放渲染 使用 V-Ray 将您的渲染提升到一个…

数据库-mysql高阶语句

mysql高阶语句 mysql高阶语句对复杂条件的查询 1、使用select语句时,如何按照顺序对结果进行排序 select name from info order by score desc; #我们查询的是name,按照成绩实现升序的操作select id,name from info order by score desc; #从大到小…

Springboot3国际化

国际化实现步骤 Spring Boot 3 提供了强大的国际化支持,使得应用程序可以根据用户的语言和区域偏好适配不同的语言和地区需求。 添加国际化资源文件: 国际化资源文件通常放在 src/main/resources 目录下,并按照不同的语言和地区命名&#xf…

微机接口课设——基于Proteus和8086的打地鼠设计(8255、8253、8259)

原理图设计 汇编代码 ; I/O 端口地址定义 IOY0 EQU 0600H IOY1 EQU 0640H IOY2 EQU 0680HMY8255_A EQU IOY000H*2 ; 8255 A 口端口地址 MY8255_B EQU IOY001H*2 ; 8255 B 口端口地址 MY8255_C EQU IOY002H*2 ; 8255 C 口端口地址 MY8255_MODE EQU IOY003H*2 ; …

RTOS 基础知识

**实时操作系统(RTOS, Real-Time Operating System)**是一种专为实时性要求设计的操作系统,具有确定性和高效性。RTOS 的系统架构围绕任务调度、时间管理和资源管理展开,以确保系统能够在规定时间内响应外部事件。以下是RTOS的系统…

绝美的数据处理图-三坐标轴-散点图-堆叠图-数据可视化图

clc clear close all %% 读取数据 load(MyColor.mat) %读取颜色包for iloop 1:25 %提取工作表数据data0(iloop) {readtable(data.xlsx,sheet,iloop)}; end%% 解析数据 countzeros(23,14); for iloop 1:25index(iloop) { cell2mat(table2array(data0{1,iloop}(1,1)))};data(i…