Compare commits


58 Commits
0.3.4 ... 0.4.0

Author SHA1 Message Date
zyxucp
6beb0b52c7 Merge pull request #92 from AIDotNet/feature_llamafactory
update llamafactory 0.8.0
2024-06-08 18:47:13 +08:00
zyxucp
0ea167a204 update llamafactory 0.8.0 2024-06-08 18:29:37 +08:00
zyxucp
6e6afa2a7c Update docker-compose.simple.yml 2024-06-08 11:36:19 +08:00
zyxucp
7a2a5d86bb Update docker-compose.yml 2024-06-08 11:36:04 +08:00
zyxucp
a1a36c3494 update nuget 2024-06-08 11:31:24 +08:00
zyxucp
4f350081dd update llamasharp 2024-06-08 11:23:02 +08:00
zyxucp
b3ea0c4e1a add llamasharp configuration 2024-06-08 11:04:14 +08:00
zyxucp
e72a6acd03 fix chat context handling 2024-05-30 13:08:37 +08:00
zyxucp
9bb8ab89fe Update README.zh.md 2024-05-29 22:54:41 +08:00
zyxucp
e78da66d1a Update README.md 2024-05-29 22:54:25 +08:00
zyxucp
9ee21fd5e5 AddServiceDefaults 2024-05-29 21:26:41 +08:00
zyxucp
a22c04c9b2 Merge pull request #91 from AIDotNet/feature_aspire
Feature aspire
2024-05-29 17:29:00 +08:00
zyxucp
3bb5bfaca7 add otel 2024-05-29 16:34:54 +08:00
zyxucp
c4bf5ee7e5 fix: add OTEL 2024-05-29 15:06:16 +08:00
zyxucp
5e1e688f84 fix seq 2024-05-29 14:20:07 +08:00
zyxucp
80d9bf68f3 fix seq 2024-05-29 13:52:47 +08:00
zyxucp
65f2e3e363 add Serilog.Sinks.Seq 2024-05-29 13:20:11 +08:00
zyxucp
68d27ff2bc update Serilog 2024-05-29 13:03:00 +08:00
zyxucp
034da30811 add Serilog 2024-05-29 12:14:12 +08:00
zyxucp
3db0cdcd19 add aspire 2024-05-29 00:01:30 +08:00
zyxucp
42181a6f1d add aspire 2024-05-28 22:23:55 +08:00
zyxucp
ec8cbf2550 add: CORS (cross-origin) handling 2024-05-27 22:19:22 +08:00
zyxucp
9a1bd079da fix: remove the default prompt 2024-05-26 19:41:58 +08:00
zyxucp
4213c4379c update: handle openapi requests without a systemPrompt 2024-05-26 19:38:32 +08:00
zyxucp
05cda17e2e style: style adjustments 2024-05-26 00:50:23 +08:00
zyxucp
cda6e54f0b Merge branch 'main' of github.com:AIDotNet/AntSK 2024-05-25 23:11:40 +08:00
zyxucp
51d8ba6408 update km / sk versions 2024-05-25 23:11:33 +08:00
zyxucp
b571c7d22d Update README.md 2024-05-24 22:01:03 +08:00
zyxucp
a0c91f565e fix: openapi chat context bug 2024-05-24 21:47:53 +08:00
zyxucp
280c750165 Update README.md 2024-05-23 14:47:20 +08:00
zyxucp
fec9337fda margin 2024-05-23 14:34:46 +08:00
zyxucp
b84f252f2f update readme 2024-05-23 14:17:54 +08:00
zyxucp
5c998ccce2 Update README.en.md 2024-05-23 13:53:36 +08:00
zyxucp
0e3cfd2cfb Update README.md 2024-05-23 13:53:33 +08:00
zyxucp
4040831a23 Update README.md 2024-05-23 13:52:17 +08:00
zyxucp
a3a2308659 Update docker-compose.yml 2024-05-23 13:46:03 +08:00
zyxucp
6d43c71d13 Update docker-compose.simple.yml 2024-05-23 13:45:42 +08:00
zyxucp
8315b6f37f fix: style adjustments 2024-05-23 12:07:37 +08:00
zyxucp
7bc708e6ae margin 2024-05-23 11:33:15 +08:00
zyxucp
e6f2c5c2fe update: upgrade SK and KM versions 2024-05-23 11:29:23 +08:00
zyxucp
5cab781362 Merge pull request #90 from yc-2503/main
fix: the first message in the chat window was not passed to the LLM
2024-05-14 22:20:32 +08:00
Chason
02d7994bae fix: the first message in the chat window was lost 2024-05-14 20:32:11 +08:00
zyxucp
b740957157 fix: adjust KM version 2024-05-12 20:49:13 +08:00
zyxucp
2480ec1272 margin 2024-05-12 19:07:51 +08:00
zyxucp
35c98a0d14 update ant blazor / sk / km 2024-05-12 19:07:27 +08:00
zyxucp
0964a5ad5b Merge pull request #88 from yc-2503/main
bugfix: "jsonbody parameter does not exist" error when invoking a function
2024-05-09 23:33:24 +08:00
Chason
a95131efe9 fix: "jsonbody parameter does not exist" error when invoking a function
The KernelParameterMetadata constructor already sets the parameter name to jsonbody, but the name was later overridden with "json参数字符串" (json parameter string)
2024-05-09 17:41:43 +08:00
Chason
7783cdf3c4 bugfix: syntax error 2024-05-09 10:55:41 +08:00
zyxucp
7a65f33cb6 Update README.md 2024-05-09 01:33:37 +08:00
zyxucp
6efd01db3f Merge pull request #87 from yc-2503/main
fix: correct the return string in the conversation summary
2024-05-08 13:26:15 +08:00
Chason
1e2322b573 Merge pull request #1 from yc-2503/yc-2503-patch-1
fix: correct the conversation summary
2024-05-07 19:55:12 +08:00
Chason
2cb2241a66 fix: correct the conversation summary 2024-05-07 19:54:29 +08:00
zyxucp
64efdd7881 add logo 2024-05-01 14:09:21 +08:00
zyxucp
be28e32803 update nuget versions 2024-05-01 13:05:11 +08:00
zyxucp
468422baee fix: async chat issue 2024-04-30 21:53:50 +08:00
zyxucp
7b1c6c8c64 fix: adjust async handling 2024-04-30 17:53:16 +08:00
zyxucp
7ff0ea0bfe Update README.en.md 2024-04-29 21:42:57 +08:00
zyxucp
6bed4356f0 Update README.md 2024-04-29 18:17:06 +08:00
194 changed files with 10090 additions and 4832 deletions

README.en.md
View File

@@ -1,214 +0,0 @@
[简体中文](./README.md) | English
# AntSK
## AI Knowledge Base/Intelligent Agent built on .Net8+AntBlazor+SemanticKernel
## ⭐Core Features
- **Semantic Kernel**: Utilizes advanced natural language processing technology to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.
- **Kernel Memory**: Capable of continuous learning and storing knowledge points, AntSK has long-term memory function, accumulates experience, and provides a more personalized interaction experience.
- **Knowledge Base**: Import knowledge base through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and perform knowledge base Q&A.
- **GPT Generation**: This platform supports creating personalized GPT models, enabling users to build their own GPT models.
- **API Interface Publishing**: Exposes internal functions in the form of APIs, enabling developers to integrate AntSK into other applications and enhance application intelligence.
- **API Plugin System**: Open API plugin system that allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functionality.
- **.Net Plugin System**: Open dll plugin system that allows third-party developers or service providers to easily integrate their business functions by compiling standard-format code into dlls, continuously enhancing application functionality.
- **Online Search**: AntSK retrieves the latest information in real time, ensuring users receive the most timely and relevant data.
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf models supported by **llama.cpp** and offline model running supported by **llamafactory**.
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
- **Model Fine-Tuning**: Planned; model fine-tuning based on llamafactory.
## ⛪Application Scenarios
AntSK is suitable for various business scenarios, such as:
- Enterprise knowledge management system
- Automatic customer service and chatbots
- Enterprise search engine
- Personalized recommendation system
- Intelligent writing assistance
- Education and online learning platforms
- Other interesting AI Apps
## ✏Function Examples
### Online Demo
```
https://antsk.ai-dotnet.com/
```
```
Default account: test
Default password: test
Due to the low configuration of the cloud server, the local model cannot be run, so the system settings permissions have been closed. You can simply view the interface. If you want to use the local model, please download and use it on your own.
```
### Other Function Examples
[Video Demonstration](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
## ❓How to get started?
Here I am using Postgres as the data and vector storage because Semantic Kernel and Kernel Memory support it, but you can also use other options.
By default, the supported models are openai, azure openai, and local gguf models run via llama. If you need to use other models, you can integrate them using one-api.
The Login configuration in the configuration file is the default login account and password.
The following configuration file needs to be configured
## 1⃣Using docker-compose
A pg version of **appsettings.json** and a simplified version (Sqlite+disk) **docker-compose.simple.yml** are provided.
Download **docker-compose.yml** from the project root directory and place the configuration file **appsettings.json** in the same directory.
The pg image has already been prepared. You can modify the default username and password in docker-compose.yml; the database connection in your **appsettings.json** must then match.
Then execute the following command in that directory to start AntSK:
```
docker-compose up -d
```
## 2⃣How to mount local models and model download directory in docker
```
# Non-host version, do not use local proxy
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5
ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # Local configuration file needs to be placed in the same directory
- D://model:/app/model
networks:
antsk:
```
In this example, the local Windows folder D://model is mounted into the container at /app/model. In that case, the model path in your appsettings.json should be configured as
```
model/xxx.gguf
```
## 3⃣What the configuration settings mean
```
{
"DBConnection": {
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"LLamaSharp": {
"RunType": "GPU",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
},
"Login": {
"User": "admin",
"Password": "xuzeyu"
},
"BackgroundTaskBroker": {
"ImportKMSTask": {
"WorkerCount": 1
}
}
}
```
```
// Supports various databases, you can check SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGauss, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection.DbType
// Connection string, need to use the corresponding string according to the different DB types
DBConnection.ConnectionStrings
//The type of vector storage, supporting Postgres, Disk, Memory, Qdrant, Redis, AzureAISearch
//Postgres and Redis require ConnectionString configuration
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint|APIKey
KernelMemory.VectorDb
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
LLamaSharp.RunType
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
LLamaSharp.FileDirectory
//Default admin account password
Login
//Import asynchronous processing thread count. A higher count can be used for online API, but for local models, 1 is recommended to avoid memory overflow issues.
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
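These settings are plain appsettings.json values. A minimal sketch of reading them at startup and pushing them into the static options class that appears later in this compare (the binding code here is an illustrative assumption, not AntSK's actual startup code):
```
using Microsoft.Extensions.Configuration;

// Load appsettings.json (assumed to sit next to the binary).
var config = new ConfigurationBuilder()
    .AddJsonFile("appsettings.json", optional: false)
    .Build();

// Push values into the static options class used elsewhere in the app
// (mirrors the LLamaSharpOption class visible further down this page).
LLamaSharpOption.RunType = config["LLamaSharp:RunType"];             // "GPU" or "CPU"
LLamaSharpOption.FileDirectory = config["LLamaSharp:FileDirectory"]; // local model folder

public static class LLamaSharpOption
{
    public static string RunType { get; set; } = "CPU";
    public static string FileDirectory { get; set; } = "";
}
```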
## ⚠Fixing Style Issues:
Run the following in AntSK/src/AntSK:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
Then navigate to AntSK/src/AntSK/bin/Release/net8.0/publish and run:
```
dotnet AntSK.dll
```
The styles should now be applied after starting.
I'm using CodeFirst mode for the database, so as long as the database connection is properly configured, the table structure will be created automatically.
## ✔Using llamafactory
```
1. First, ensure that Python and pip are installed in your environment. This step is not necessary if using an image, such as version v0.2.3.2, which already includes the complete Python environment.
2. Go to the model add page and select llamafactory.
3. Click "Initialize" to check whether the 'pip install' environment setup is complete.
4. Choose a model that you like.
5. Click "Start" to begin downloading the model from the tower. This may involve a somewhat lengthy wait.
6. After the model has finished downloading, enter http://localhost:8000/ in the request address. The default port is 8000.
7. Click "Save" and start chatting.
8. Many people ask about the difference between LLamaSharp and llamafactory. LLamaSharp is a .NET implementation of llama.cpp and only supports local gguf models, while llamafactory supports a wider variety of models but is implemented in Python; that is the main difference. Additionally, llamafactory can fine-tune models, which is an area we will focus on integrating in the future.
```
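Once the llamafactory side is running on port 8000, it exposes an OpenAI-compatible /v1/chat/completions endpoint (see the FastAPI code further down this compare). A minimal sketch of calling it from .NET; the model id is the placeholder the llamafactory API advertises via /v1/models, and the prompt is illustrative:
```
using System.Net.Http.Json;

using var http = new HttpClient { BaseAddress = new Uri("http://localhost:8000") };

// OpenAI-style chat completion request against the local llamafactory API.
var response = await http.PostAsJsonAsync("/v1/chat/completions", new
{
    model = "gpt-3.5-turbo", // placeholder id returned by /v1/models
    messages = new[] { new { role = "user", content = "Hello from AntSK" } }
});
response.EnsureSuccessStatusCode();
Console.WriteLine(await response.Content.ReadAsStringAsync());
```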
## 🤝 Contributing
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
## 💕 Contributors
This project exists thanks to all the people who contribute.
<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>
## 🚨 Code of Conduct
This project has adopted the code of conduct defined by the Contributor Covenant to clarify expected behavior in our community.
For more information see the [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
To learn more or get started with **AntSK**, follow my official WeChat account and join the discussion group.
## ☎Contact Me
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
![Official WeChat Account](https://github.com/AIDotNet/Avalonia-Assistant/blob/main/img/gzh.jpg)
---
We appreciate your interest in **AntSK** and look forward to collaborating with you to create an intelligent future!

README.md
View File

@@ -1,100 +1,90 @@
中文|[English](https://github.com/AIDotNet/AntSK/blob/main/README.en.md)
[简体中文](./README.zh.md) | English
# AntSK
## 使用.Net8+Blazor+SemanticKernel 打造的AI知识库/智能体
## AI Knowledge Base/Intelligent Agent built on .Net8+AntBlazor+SemanticKernel
## ⭐核心功能
## ⭐Core Features
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **Semantic Kernel**: Utilizes advanced natural language processing technology to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **Kernel Memory**: Capable of continuous learning and storing knowledge points, AntSK has long-term memory function, accumulates experience, and provides a more personalized interaction experience.
- **知识库**:通过文档(WordPDFExcelTxtMarkdownJsonPPT等形式导入知识库可以进行知识库问答支持本地bge-embedding 向量模型 以及bge-rerank 重排模型。
- **Knowledge Base**: Import knowledge base through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and perform knowledge base Q&A.
- **文生图**:集成**StableDiffusion** 本地模型,可以进行文生图。
- **GPT Generation**: This platform supports creating personalized GPT models, enabling users to build their own GPT models.
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API Interface Publishing**: Exposes internal functions in the form of APIs, enabling developers to integrate AntSK into other applications and enhance application intelligence.
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **API Plugin System**: Open API plugin system that allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functionality.
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **.Net Plugin System**: Open dll plugin system that allows third-party developers or service providers to easily integrate their business functions by compiling standard-format code into dlls, continuously enhancing application functionality.
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **Online Search**: AntSK retrieves the latest information in real time, ensuring users receive the most timely and relevant data.
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf models supported by **llama.cpp** and offline model running supported by **llamafactory**.
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **Model Fine-Tuning**: Planned; model fine-tuning based on llamafactory.
- **模型微调**规划中基于llamafactory进行模型微调
## ⛪Application Scenarios
## ⛪应用场景
AntSK is suitable for various business scenarios, such as:
- Enterprise knowledge management system
- Automatic customer service and chatbots
- Enterprise search engine
- Personalized recommendation system
- Intelligent writing assistance
- Education and online learning platforms
- Other interesting AI Apps
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
- 自动客服与聊天机器人
- 企业级搜索引擎
- 个性化推荐系统
- 智能辅助写作
- 教育与在线学习平台
- 其他有意思的AI App
## ✏Function Examples
### Online Demo
[document](http://antsk.cn/)
## ✏️功能示例
### 在线演示
[文档地址](http://antsk.cn/)
[体验地址](https://antsk.ai-dotnet.com/)
[demo](https://antsk.ai-dotnet.com/)
```
默认账号:test
Default account: test
默认密码:test
Default password: test
由于云服务器配置较低,无法运行本地模型,所以把系统设置权限关闭了,大家看看界面即可,要使用本地模型,请下载自行使用
请勿在演示站点上传敏感信息
Due to the low configuration of the cloud server, the local model cannot be run, so the system settings permissions have been closed. You can simply view the interface. If you want to use the local model, please download and use it on your own.
```
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
### Other Function Examples
[Video Demonstration](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
[在线文档http://antsk.cn](http://antsk.cn)
## ❓How to get started?
## ❓如何开始?
Here I am using Postgres as the data and vector storage because Semantic Kernel and Kernel Memory support it, but you can also use other options.
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
By default, the supported models are openai, azure openai, and local gguf models run via llama. If you need to use other models, you can integrate them using one-api.
模型默认支持openai、azure openai、讯飞星火、阿里云积、 和llama支持的gguf本地模型 以及llamafactory的本地模型,如果需要使用其他模型可以使用one-api进行集成。
The Login configuration in the configuration file is the default login account and password.
配置文件中的Login配置是默认的登录账号和密码
The following configuration file needs to be configured
需要配置如下的配置文件
## 1⃣Using docker-compose
## 1⃣使用docker-compose
A pg version of **appsettings.json** and a simplified version (Sqlite+disk) **docker-compose.simple.yml** are provided.
提供了pg版本 **appsettings.json** 和 简化版本(**Sqlite+disk** **docker-compose.simple.yml**
Download **docker-compose.yml** from the project root directory and place the configuration file **appsettings.json** in the same directory.
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
The pg image has already been prepared. You can modify the default username and password in docker-compose.yml; the database connection in your **appsettings.json** must then match.
这里已经把pg的镜像做好了。在docker-compose.yml中可以修改默认账号密码然后你的**appsettings.json**的数据库连接需要保持一致。
然后你可以进入到目录后执行
Then execute the following command in that directory to start AntSK:
```
docker-compose up -d
```
来启动AntSK
## 2⃣如何在docker中挂载本地模型和模型下载的目录
## 2⃣How to mount local models and model download directory in docker
```
# 非 host 版本, 不使用本机代理
# Non-host version, do not use local proxy
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.3.1
ports:
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5
ports:
- 5000:5000
networks:
- antsk
@@ -104,32 +94,35 @@ services:
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- ./appsettings.json:/app/appsettings.json # Local configuration file needs to be placed in the same directory
- D://model:/app/model
- D://model:/root/.cache/modelscope/hub/AI-ModelScope # Must be mounted when using Llamafactory, otherwise the initialized environment is lost after a restart
networks:
antsk:
```
以这个为示例意思是把windows本地D://model的文件夹挂载进 容器内/app/model 如果是这样你的appsettings.json中的模型地址应该配置为
In this example, the local Windows folder D://model is mounted into the container at /app/model. In that case, the model path in your appsettings.json should be configured as
```
model/xxx.gguf
```
## 3⃣配置文件的一些含义
## 3⃣What the configuration settings mean
```
{
"DBConnection": {
"DbType": "Sqlite",
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"FileDir": {
"DirectoryPath": "D:\\git\\AntBlazor\\model"
},
"LLamaSharp": {
"RunType": "GPU",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
"RunType": "GPU",
"ContextSize": 2048,
"GpuLayerCount": 20
},
"Login": {
"User": "admin",
@@ -143,86 +136,85 @@ model/xxx.gguf
}
```
```
//支持多种数据库,具体可以查看SqlSugarMySqlSqlServerSqliteOraclePostgreSQLDmKdbndpOscarMySqlConnectorAccessOpenGaussQuestDBHGClickHouseGBaseOdbcOceanBaseForOracleTDengineGaussDBOceanBaseTidbVastbasePolarDBCustom
// Supports various databases, you can check SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGauss, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
// Connection string, need to use the corresponding string according to the different DB types
DBConnection.ConnectionStrings
//向量存储的类型,支持 PostgresDiskMemoryQdrantRedisAzureAISearch
//Postgres、Redis需要配置 ConnectionString
//Qdrant AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
//The type of vector storage, supporting Postgres, Disk, Memory, Qdrant, Redis, AzureAISearch
//Postgres and Redis require ConnectionString configuration
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint|APIKey
KernelMemory.VectorDb
//本地模型使用的运行方式 GPU 还是 CPU ,如果用在线API 这个随意使用一个即可
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
LLamaSharp.RunType
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
LLamaSharp.FileDirectory
//默认管理员账号密码
//Default admin account password
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
//Import asynchronous processing thread count. A higher count can be used for online API, but for local models, 1 is recommended to avoid memory overflow issues.
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠️找不到样式问题解决:
AntSK/src/AntSK下执行:
## ⚠️Fixing Style Issues:
Run the following in AntSK/src/AntSK:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
再去AntSK/src/AntSK/bin/Release/net8.0/publish
Then navigate to AntSK/src/AntSK/bin/Release/net8.0/publish and run:
```
dotnet AntSK.dll
```
然后启动就有样式了
The styles should now be applied after starting.
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的
I'm using CodeFirst mode for the database, so as long as the database connection is properly configured, the table structure will be created automatically.
## ✔️使用llamafactory
## ✔️Using llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如p0.2.4版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型
5、点击启动,这会开始从魔塔下载模型,你可能需要有一个较为漫长的等待
6、等待模型下载完毕后,在请求地址输入 http://localhost:8000/ 这里默认是使用8000端口
7、点击保存,然后就可以开始聊天了
8、很多人会问 LLamaSharpllamafactory有什么区别其实这两者LLamaSharp是llama.cpp的 dotnet实现但是只支持本地gguf模型 而llamafactory 支持的模型种类更多但使用的是python的实现其主要差异在这里另外llamafactory具有模型微调的能力这也是我们下一步需要重点集成的部分。
1. First, ensure that Python and pip are installed in your environment. This step is not necessary if using an image, such as the p0.2.4 version, which already includes the complete Python environment.
2. Go to the model add page and select llamafactory.
3. Click "Initialize" to check whether the 'pip install' environment setup is complete.
4. Choose a model that you like.
5. Click "Start" to begin downloading the model from the tower. This may involve a somewhat lengthy wait.
6. After the model has finished downloading, enter http://localhost:8000/ in the request address. The default port is 8000.
7. Click "Save" and start chatting.
8. Many people ask about the difference between LLamaSharp and llamafactory. LLamaSharp is a .NET implementation of llama.cpp and only supports local gguf models, while llamafactory supports a wider variety of models but is implemented in Python; that is the main difference. Additionally, llamafactory can fine-tune models, which is an area we will focus on integrating in the future.
```
## 🤝 贡献
## 🤝 Contributing
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)

如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).


## 💕 贡献者
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
## 💕 Contributors
This project exists thanks to all the people who contribute.
这个项目的存在要感谢所有的贡献者。

<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>

## 🚨 行为准则
该项目采用了贡献者公约定义的行为准则,以阐明我们社区的预期行为。有关更多信息,请参见 .NET Foundation 行为准则。 [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
想了解更多信息或开始使用 **AntSK**,可以关注我的公众号以及加入交流群。
## ☎️联系我
如有任何问题或建议请通过以下方式关注我的公众号《许泽宇的技术分享》发消息与我联系我们也有AIDotnet交流群可以发送进群等消息然后我会拉你进交流群
![公众号](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
## 🌟 Star History
<a href="https://github.com/AIDotNet/AntSK/stargazers" target="_blank" style="display: block" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
</picture>
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>
## 🚨 License
This repository is licensed under the [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file).
The Apache open source license allows the use of AntSK in commercial environments, provided that the license terms are followed. One of the main terms is to retain the copyright and license statements.
If you plan to use AntSK in commercial projects, make sure that you follow these steps:
1. Include the copyright notice and the [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file).
2. If you modify the software source code, you need to clearly indicate these modifications in the source code.
## ☎Contact Me
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
![Official WeChat Account](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
---
We appreciate your interest in **AntSK** and look forward to collaborating with you to create an intelligent future!

README.zh.md Normal file
View File

@@ -0,0 +1,238 @@
中文|[English](./README.md)
# AntSK
## 使用.Net8+Blazor+SemanticKernel 打造的AI知识库/智能体
## ⭐核心功能
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库问答支持本地bge-embedding 向量模型 以及bge-rerank 重排模型。
- **文生图**:集成**StableDiffusion** 本地模型,可以进行文生图。
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **模型微调**规划中基于llamafactory进行模型微调
## ⛪应用场景
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
- 自动客服与聊天机器人
- 企业级搜索引擎
- 个性化推荐系统
- 智能辅助写作
- 教育与在线学习平台
- 其他有意思的AI App
## ✏️功能示例
### 在线演示
[文档地址](http://antsk.cn/)
[体验地址](https://antsk.ai-dotnet.com/)
```
默认账号test
默认密码test
由于云服务器配置较低,无法运行本地模型,所以把系统设置权限关闭了,大家看看界面即可,要使用本地模型,请下载自行使用
请勿在演示站点上传敏感信息
```
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
[在线文档http://antsk.cn](http://antsk.cn)
## ❓如何开始?
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
模型默认支持openai、azure openai、讯飞星火、阿里云积、 和llama支持的gguf本地模型 以及llamafactory的本地模型,如果需要使用其他模型可以使用one-api进行集成。
配置文件中的Login配置是默认的登录账号和密码
需要配置如下的配置文件
## 1⃣使用docker-compose
提供了pg版本 **appsettings.json** 和 简化版本(**Sqlite+disk** **docker-compose.simple.yml**
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
这里已经把pg的镜像做好了。在docker-compose.yml中可以修改默认账号密码然后你的**appsettings.json**的数据库连接需要保持一致。
然后你可以进入到目录后执行
```
docker-compose up -d
```
来启动AntSK
## 2⃣如何在docker中挂载本地模型和模型下载的目录
```
# 非 host 版本, 不使用本机代理
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.3.1
ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- D://model:/app/model
- D://model:/root/.cache/modelscope/hub/AI-ModelScope #使用Llamafactory时需要挂载 否则初始化的环境重启后会丢失
networks:
antsk:
```
以这个为示例意思是把windows本地D://model的文件夹挂载进 容器内/app/model 如果是这样你的appsettings.json中的模型地址应该配置为
```
model/xxx.gguf
```
## 3⃣配置文件的一些含义
```
{
"DBConnection": {
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"FileDir": {
"DirectoryPath": "D:\\git\\AntBlazor\\model"
},
"LLamaSharp": {
"RunType": "GPU",
"ContextSize": 2048,
"GpuLayerCount": 20
},
"Login": {
"User": "admin",
"Password": "xuzeyu"
},
"BackgroundTaskBroker": {
"ImportKMSTask": {
"WorkerCount": 1
}
}
}
```
```
//支持多种数据库具体可以查看SqlSugarMySqlSqlServerSqliteOraclePostgreSQLDmKdbndpOscarMySqlConnectorAccessOpenGaussQuestDBHGClickHouseGBaseOdbcOceanBaseForOracleTDengineGaussDBOceanBaseTidbVastbasePolarDBCustom
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
DBConnection.ConnectionStrings
//向量存储的类型,支持 Postgres、Disk、Memory、Qdrant、Redis、AzureAISearch
//Postgres、Redis需要配置 ConnectionString
//Qdrant 和AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
KernelMemory.VectorDb
//本地模型使用的运行方式 GPU 还是 CPU ,如果用在线API 这个随意使用一个即可
LLamaSharp.RunType
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
LLamaSharp.FileDirectory
//默认管理员账号密码
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠️找不到样式问题解决:
AntSK/src/AntSK下执行:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
再去AntSK/src/AntSK/bin/Release/net8.0/publish下
```
dotnet AntSK.dll
```
然后启动就有样式了
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的
## ✔使用llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如p0.2.4版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型
5、点击启动,这会开始从魔塔下载模型,你可能需要有一个较为漫长的等待
6、等待模型下载完毕后在请求地址输入 http://localhost:8000/ 这里默认是使用8000端口
7、点击保存然后就可以开始聊天了
8、很多人会问 LLamaSharp与llamafactory有什么区别其实这两者LLamaSharp是llama.cpp的 dotnet实现但是只支持本地gguf模型 而llamafactory 支持的模型种类更多但使用的是python的实现其主要差异在这里另外llamafactory具有模型微调的能力这也是我们下一步需要重点集成的部分。
```
## 🤝 贡献
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)

如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).


## 💕 贡献者
这个项目的存在要感谢所有的贡献者。

<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>

## 🚨 使用协议
本仓库遵循 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 开源协议。
Apache开源许可证允许在商业环境中使用AntSK前提是需要遵守许可证的条款。主要条款之一是要保留版权声明和许可证声明。
如果您打算在商业项目中使用AntSK您需要确保遵守以下步骤
1、包含Apache许可证的版权声明。 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 。
2、如果您修改了软件源代码您需要在源代码中明确标明这些修改。
## ☎️联系我
如有任何问题或建议请通过以下方式关注我的公众号《许泽宇的技术分享》发消息与我联系我们也有AIDotnet交流群可以发送进群等消息然后我会拉你进交流群
![公众号](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
## 🌟 Star History
<a href="https://github.com/AIDotNet/AntSK/stargazers" target="_blank" style="display: block" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
</picture>
</a>

View File

@@ -3,9 +3,9 @@ version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.1
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.9
# If you need the pytorch environment, use the image below instead (it is much larger)
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.1
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.9
ports:
- 5000:5000
networks:
@@ -15,5 +15,7 @@ services:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # Local configuration file; must be placed in the same directory
- /AntSK/model:/app/model
- /AntSK/model:/root/.cache/modelscope/hub/AI-ModelScope # LLamaFactory model files
networks:
antsk:

View File

@@ -1,6 +1,20 @@
# Non-host version; does not use the local proxy
version: '3.8'
services:
aspire-dashboard:
container_name: aspire-dashboard
image: mcr.microsoft.com/dotnet/aspire-dashboard:8.0
networks:
- antsk
environment:
- DOTNET_DASHBOARD_UNSECURED_ALLOW_ANONYMOUS=true
- ASPIRE_ALLOW_UNSECURED_TRANSPORT=true
- DASHBOARD_OTLP_AUTHMODE=ApiKey
- DASHBOARD_OTLP_PRIMARYAPIKEY=antsk
ports:
- 18888:18888
- 18889:18889
restart: unless-stopped
antskpg:
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/pg:v0.5.0
container_name: antskpg
@@ -18,9 +32,9 @@ services:
- ./pg/data:/var/lib/postgresql/data
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.1
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.9
# If you need the pytorch environment, use the image below instead (it is much larger)
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.1
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.9
ports:
- 5000:5000
networks:
@@ -30,7 +44,15 @@ services:
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
- ASPNETCORE_FORWARDEDHEADERS_ENABLED=true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES=true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES=true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY=in_memory
- OTEL_EXPORTER_OTLP_ENDPOINT=http://aspire-dashboard:18889
- OTEL_SERVICE_NAME=antsk
volumes:
- ./appsettings.json:/app/appsettings.json # Local configuration file; must be placed in the same directory
- /AntSK/model:/app/model
- /AntSK/model:/root/.cache/modelscope/hub/AI-ModelScope # LLamaFactory model files
networks:
antsk:
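The OTEL_* variables above point the container at the bundled Aspire dashboard's OTLP endpoint (port 18889, API key "antsk"). For orientation, a minimal sketch of the .NET-side wiring that consumes them, using the standard OpenTelemetry SDK calls (AntSK's actual AddServiceDefaults code may differ):
```
using Microsoft.AspNetCore.Builder;
using OpenTelemetry.Metrics;
using OpenTelemetry.Trace;

var builder = WebApplication.CreateBuilder(args);

// OTEL_EXPORTER_OTLP_ENDPOINT and OTEL_SERVICE_NAME are picked up from the
// environment by the OTLP exporter, so nothing is hard-coded here.
builder.Services.AddOpenTelemetry()
    .WithTracing(tracing => tracing.AddAspNetCoreInstrumentation().AddOtlpExporter())
    .WithMetrics(metrics => metrics.AddAspNetCoreInstrumentation().AddOtlpExporter());

var app = builder.Build();
app.Run();
```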

View File

@@ -0,0 +1,20 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsAspireHost>true</IsAspireHost>
<UserSecretsId>32ac67c8-178a-4eeb-871d-879023582e06</UserSecretsId>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Aspire.Hosting.AppHost" Version="8.0.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK\AntSK.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,5 @@
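// Aspire AppHost entry point: register the AntSK project as a resource and run the orchestrator.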
var builder = DistributedApplication.CreateBuilder(args);
builder.AddProject<Projects.AntSK>("antsk");
builder.Build().Run();

View File

@@ -0,0 +1,8 @@
{
"Logging": {
"LogLevel": {
"Default": "Information",
"Microsoft.AspNetCore": "Warning"
}
}
}

View File

@@ -0,0 +1,9 @@
{
"Logging": {
"LogLevel": {
"Default": "Information",
"Microsoft.AspNetCore": "Warning",
"Aspire.Hosting.Dcp": "Warning"
}
}
}

View File

@@ -0,0 +1,26 @@
services:
aspire-dashboard:
container_name: "aspire-dashboard"
image: "mcr.microsoft.com/dotnet/aspire-dashboard:8.0"
environment:
DOTNET_DASHBOARD_UNSECURED_ALLOW_ANONYMOUS: "true"
ports:
- target: 18888
published: 18888
restart: unless-stopped
antsk:
container_name: "antsk"
image: "antsk:latest"
environment:
OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES: "true"
OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES: "true"
OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY: "in_memory"
ASPNETCORE_FORWARDEDHEADERS_ENABLED: "true"
OTEL_EXPORTER_OTLP_ENDPOINT: "http://aspire-dashboard:18889"
OTEL_SERVICE_NAME: "antsk"
ports:
- target: 8080
published: 10000
- target: 8443
published: 10001
restart: unless-stopped

View File

@@ -0,0 +1,17 @@
{
"projectPath": ".",
"outputPath": "aspirate-output",
"containerImageTags": [
"latest"
],
"containerBuilder": "docker",
"outputFormat": "compose",
"privateRegistryEmail": "aspir8@aka.ms",
"includeDashboard": true,
"secrets": {
"salt": "fjamZa3pQbM1UyY4",
"hash": "QR\u002BSEr3p2SwD/w2oPE21vrWh/EerhNyVyTkr0atIREw=",
"secrets": {}
},
"processAllComponents": true
}

View File

@@ -0,0 +1,26 @@
{
"resources": {
"antsk": {
"type": "project.v0",
"path": "../AntSK/AntSK.csproj",
"env": {
"OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES": "true",
"OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES": "true",
"OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY": "in_memory",
"ASPNETCORE_FORWARDEDHEADERS_ENABLED": "true"
},
"bindings": {
"http": {
"scheme": "http",
"protocol": "tcp",
"transport": "http"
},
"https": {
"scheme": "https",
"protocol": "tcp",
"transport": "http"
}
}
}
}
}

View File

@@ -5,30 +5,30 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<DocumentationFile>AntSK.Domain.xml</DocumentationFile>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102</NoWarn>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102,KMEXP00</NoWarn>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AntDesign.Charts" Version="0.5.1" />
<PackageReference Include="AntDesign.ProLayout" Version="0.18.2" />
<PackageReference Include="AntDesign.ProLayout" Version="0.19.0" />
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
<PackageReference Include="AutoMapper" Version="8.1.0" />
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
<PackageReference Include="Markdig" Version="0.37.0" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.152" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.158" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
<PackageReference Include="RestSharp" Version="110.2.0" />
<PackageReference Include="RestSharp" Version="111.2.0" />
<PackageReference Include="NPOI" Version="2.7.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.8.0" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.8.0" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.8.0-alpha" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.14.1" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.14.1" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.14.1-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="$(KMVersion)" />
@@ -40,8 +40,14 @@
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="$(LLamaSharpVersion)" />
<PackageReference Include="Serilog" Version="4.0.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="5.1.0-dev-00943" />
<PackageReference Include="Serilog.Sinks.File" Version="5.0.1-dev-00972" />
<PackageReference Include="Serilog.Extensions.Logging" Version="8.0.1-dev-10391" />
<PackageReference Include="Serilog.Settings.Configuration" Version="8.0.1" />
<PackageReference Include="Serilog.Sinks.Seq" Version="8.0.0" />
<PackageReference Include="Serilog.Sinks.OpenTelemetry" Version="3.0.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />

View File

@@ -5,6 +5,7 @@ using DocumentFormat.OpenXml.Office2016.Drawing.ChartDrawing;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.OpenApi.Models;
using SqlSugar;
using Swashbuckle.AspNetCore.SwaggerGen;
@@ -19,6 +20,12 @@ namespace AntSK.Domain.Common.DependencyInjection
{
public static class InitExtensions
{
private static ILogger _logger;
public static void InitLog(ILogger logger)
{
_logger = logger;
}
/// <summary>
/// Create the database tables using CodeFirst
/// </summary>
@@ -52,6 +59,8 @@ namespace AntSK.Domain.Common.DependencyInjection
}
//Install the vector extension (pgvector)
_repository.GetDB().Ado.ExecuteCommandAsync($"CREATE EXTENSION IF NOT EXISTS vector;");
_logger.LogInformation("初始化表结构完成");
}
return app;
}
@@ -72,7 +81,7 @@ namespace AntSK.Domain.Common.DependencyInjection
llamafactoryStart.Value = "false";
_dic_Repository.Insert(llamafactoryStart);
}
_logger.LogInformation("初始化数据库初始数据完成");
}
return app;
}
@@ -99,7 +108,7 @@ namespace AntSK.Domain.Common.DependencyInjection
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
_logger.LogError(ex.Message + " ---- " + ex.StackTrace);
}
return app;
}

View File

@@ -1,4 +1,6 @@
using System;
using Amazon.Runtime.Internal.Util;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
@@ -7,7 +9,7 @@ using System.Threading.Tasks;
namespace AntSK.Domain.Common.LLamaFactory
{
public class ProcessWrapper
public class ProcessWrapper(ILogger<ProcessWrapper> _logger)
{
private Process process;
@@ -41,7 +43,7 @@ namespace AntSK.Domain.Common.LLamaFactory
isProcessComplete = true;
}
}
Console.WriteLine(result);
_logger.LogInformation(result);
}
start.WaitForExit();
}

View File

@@ -1,8 +0,0 @@
<Project>
<!-- See https://aka.ms/dotnet/msbuild/customize for more details on customizing your build -->
<PropertyGroup>
<KMVersion>0.37.240420.2</KMVersion>
<LLamaSharpVersion>0.11.2</LLamaSharpVersion>
</PropertyGroup>
</Project>

View File

@@ -1,27 +1,31 @@
using AntSK.BackgroundTask;
using Amazon.Runtime.Internal.Util;
using AntSK.BackgroundTask;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace AntSK.Domain.Domain.Other
{
public class BackGroundTaskHandler : IBackgroundTaskHandler<ImportKMSTaskReq>
{
private readonly IServiceScopeFactory _scopeFactory;
private readonly ILogger<BackGroundTaskHandler> _logger;
public BackGroundTaskHandler(IServiceScopeFactory scopeFactory)
public BackGroundTaskHandler(IServiceScopeFactory scopeFactory, ILogger<BackGroundTaskHandler> logger)
{
_scopeFactory = scopeFactory;
_logger = logger;
}
public async Task ExecuteAsync(ImportKMSTaskReq item)
{
using (var scope = _scopeFactory.CreateScope())
{
Console.WriteLine("ExecuteAsync.开始执行后台任务");
_logger.LogInformation("ExecuteAsync.开始执行后台任务");
var importKMSService = scope.ServiceProvider.GetRequiredService<IImportKMSService>();
//Cannot use async here
importKMSService.ImportKMSTask(item);
Console.WriteLine("ExecuteAsync.后台任务执行完成");
_logger.LogInformation("ExecuteAsync.后台任务执行完成");
}
}

View File

@@ -1,5 +1,7 @@
using Microsoft.KernelMemory.AI.OpenAI.GPT3;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.AI.OpenAI.GPT3;
using Python.Runtime;
using Serilog;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -85,13 +87,13 @@ namespace AntSK.Domain.Domain.Other.Bge
// return len;
//}
var tokenCount1 = GPT3Tokenizer.Encode(queryStr).Count;
var tokenCount1 = DefaultGPTTokenizer.StaticCountTokens(queryStr);
return tokenCount1;
}
public static void Dispose()
{
Console.WriteLine("python dispose");
Log.Information("python dispose");
}
}
}

View File

@@ -1,4 +1,5 @@
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Configuration;
@@ -134,7 +135,7 @@ namespace AntSK.Domain.Domain.Other
PartitionNumber = partitionNumber,
SectionNumber = sectionNumber,
Tags = pipeline.Tags,
ContentSHA256 = textData.CalculateSHA256(),
ContentSHA256 = textData.AntSKCalculateSHA256(),
};
newFiles.Add(destFile, destFileDetails);
destFileDetails.MarkProcessedBy(this);

View File

@@ -1,4 +1,5 @@
using LLama;
using AntSK.Domain.Options;
using LLama;
using LLama.Common;
using LLamaSharp.KernelMemory;
@@ -29,10 +30,10 @@ namespace AntSK.Domain.Domain.Other
}
var parameters = new ModelParams(lsConfig.ModelPath)
{
ContextSize = lsConfig?.ContextSize ?? 2048,
ContextSize = LLamaSharpOption.ContextSize ?? 2048,
Seed = lsConfig?.Seed ?? 0,
GpuLayerCount = lsConfig?.GpuLayerCount ?? 20,
EmbeddingMode = true
GpuLayerCount = LLamaSharpOption.GpuLayerCount ?? 20,
Embeddings = true
};
var weights = LLamaWeights.LoadFromFile(parameters);
dicLLamaWeights.Add(modelPath, (weights, parameters));
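Two things change in this hunk: the per-model config values give way to the static LLamaSharpOption defaults, and the LLamaSharp ModelParams flag is renamed from EmbeddingMode to Embeddings in the newer package version. Abridged (modelPath stands in for lsConfig.ModelPath, and Seed is omitted):
```
using LLama.Common;

var parameters = new ModelParams(modelPath)
{
    ContextSize = LLamaSharpOption.ContextSize ?? 2048,
    GpuLayerCount = LLamaSharpOption.GpuLayerCount ?? 20,
    Embeddings = true // renamed from EmbeddingMode in newer LLamaSharp releases
};
```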

View File

@@ -151,7 +151,7 @@ namespace AntSK.Domain.Domain.Other
PartitionNumber = partitionNumber,
SectionNumber = sectionNumber,
Tags = pipeline.Tags,
ContentSHA256 = textData.CalculateSHA256(),
ContentSHA256 = textData.AntSKCalculateSHA256(),
};
newFiles.Add(destFile, destFileDetails);
destFileDetails.MarkProcessedBy(this);

View File

@@ -330,22 +330,18 @@ namespace AntSK.Domain.Domain.Service
public async Task<ChatHistory> GetChatHistory(List<Chats> MessageList, ChatHistory history)
{
if (MessageList.Count > 1)
foreach (var item in MessageList)
{
foreach (var item in MessageList)
if (item.IsSend)
{
if (item.IsSend)
{
history.AddUserMessage(item.Context);
}
else
{
history.AddAssistantMessage(item.Context);
}
history.AddUserMessage(item.Context);
}
else
{
history.AddAssistantMessage(item.Context);
}
}
return history;
}
}
}
}
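Read alongside PR #90 in the commit list: the refactor above drops the old MessageList.Count > 1 guard, which skipped the history entirely when the window held a single message. Reconstructed from the added lines, the method now forwards every message:
```
public async Task<ChatHistory> GetChatHistory(List<Chats> MessageList, ChatHistory history)
{
    // Every message is forwarded, including the very first one.
    foreach (var item in MessageList)
    {
        if (item.IsSend)
        {
            history.AddUserMessage(item.Context);
        }
        else
        {
            history.AddAssistantMessage(item.Context);
        }
    }
    return history;
}
```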

View File

@@ -8,6 +8,7 @@ using System.Text.RegularExpressions;
using Microsoft.SemanticKernel;
using HtmlAgilityPack;
using System.Collections.Generic;
using Serilog;
namespace AntSK.Domain.Domain.Service
{
@@ -115,7 +116,7 @@ namespace AntSK.Domain.Domain.Service
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
Log.Error(ex.Message + " ---- " + ex.StackTrace);
}
}
}

View File

@@ -5,6 +5,7 @@ using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Excel;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Repositories;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Handlers;
using System.Text;
@@ -15,7 +16,8 @@ namespace AntSK.Domain.Domain.Service
public class ImportKMSService(
IKMService _kMService,
IKmsDetails_Repositories _kmsDetails_Repositories,
IKmss_Repositories _kmss_Repositories
IKmss_Repositories _kmss_Repositories,
ILogger<ImportKMSService> _logger
) : IImportKMSService
{
@@ -140,13 +142,13 @@ namespace AntSK.Domain.Domain.Service
req.KmsDetail.Status = Model.Enum.ImportKmsStatus.Success;
_kmsDetails_Repositories.Update(req.KmsDetail);
//_kmsDetails_Repositories.GetList(p => p.KmsId == req.KmsId);
Console.WriteLine("后台导入任务成功:" + req.KmsDetail.DataCount);
_logger.LogInformation("后台导入任务成功:" + req.KmsDetail.DataCount);
}
catch (Exception ex)
{
req.KmsDetail.Status = Model.Enum.ImportKmsStatus.Fail;
_kmsDetails_Repositories.Update(req.KmsDetail);
Console.WriteLine("后台导入任务异常:" + ex.Message);
_logger.LogError("后台导入任务异常:" + ex.Message);
}
}
}

View File

@@ -296,7 +296,7 @@ namespace AntSK.Domain.Domain.Service
{
DocumentId = item.GetDocumentId(),
Text = item.GetPartitionText(),
Url = item.GetWebPageUrl(),
Url = item.GetWebPageUrl(KmsConstantcs.KmsIndex),
LastUpdate = item.GetLastUpdate().LocalDateTime.ToString("yyyy-MM-dd HH:mm:ss"),
File = item.GetFileName()
}));

View File

@@ -22,6 +22,8 @@ using Microsoft.KernelMemory;
using OpenCvSharp.ML;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
using Amazon.Runtime.Internal.Util;
using Microsoft.Extensions.Logging;
namespace AntSK.Domain.Domain.Service
{
@@ -33,17 +35,20 @@ namespace AntSK.Domain.Domain.Service
private readonly FunctionService _functionService;
private readonly IServiceProvider _serviceProvider;
private Kernel _kernel;
private readonly ILogger<KernelService> _logger;
public KernelService(
IApis_Repositories apis_Repositories,
IAIModels_Repositories aIModels_Repositories,
FunctionService functionService,
IServiceProvider serviceProvider)
IServiceProvider serviceProvider,
ILogger<KernelService> logger)
{
_apis_Repositories = apis_Repositories;
_aIModels_Repositories = aIModels_Repositories;
_functionService = functionService;
_serviceProvider = serviceProvider;
_logger = logger;
}
/// <summary>
@@ -178,7 +183,6 @@ namespace AntSK.Domain.Domain.Service
var getParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.Query}"
}
@@ -217,7 +221,6 @@ namespace AntSK.Domain.Domain.Service
//Handle the json body
var postParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.JsonBody}"
}
@@ -226,7 +229,7 @@ namespace AntSK.Domain.Domain.Service
{
try
{
Console.WriteLine(jsonBody);
_logger.LogInformation(jsonBody);
RestClient client = new RestClient();
RestRequest request = new RestRequest(api.Url, Method.Post);
foreach (var header in api.Header.ConvertToString().Split("\n"))
@@ -305,8 +308,8 @@ namespace AntSK.Domain.Domain.Service
KernelFunction sunFun = _kernel.Plugins.GetFunction("ConversationSummaryPlugin", "SummarizeConversation");
var summary = await _kernel.InvokeAsync(sunFun, new() { ["input"] = $"内容是:{history.ToString()} {Environment.NewLine} 请注意用中文总结" });
string his = summary.GetValue<string>();
var msg = $"history{Environment.NewLine}{history.ToString()}{Environment.NewLine} user{questions}{Environment.NewLine}"; ;
var msg = $"history{Environment.NewLine}{his}{Environment.NewLine} user{questions}{Environment.NewLine}";
return msg;
}
}
}
}
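This file carries the PR #88 fix from the commit list: the KernelParameterMetadata constructor already names the parameter "jsonbody", and the removed Name override renamed it afterwards, which made invocations fail with "jsonbody parameter does not exist". The last hunk likewise switches the returned message to use the summarized text (his) rather than the full history. The corrected parameter definition in isolation (description shortened here for illustration):
```
using Microsoft.SemanticKernel;

// Keep the name set by the constructor ("jsonbody"); customize only
// the type and the description.
var parameters = new List<KernelParameterMetadata>
{
    new KernelParameterMetadata("jsonbody")
    {
        ParameterType = typeof(string),
        Description = "Background document plus the JSON format to extract"
    }
};
```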

View File

@@ -1,9 +1,11 @@
using AntSK.Domain.Common.DependencyInjection;
using Amazon.Runtime.Internal.Util;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Options;
using AntSK.LLamaFactory.Model;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
@@ -17,7 +19,7 @@ using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(ILLamaFactoryService), ServiceLifetime.Singleton)]
public class LLamaFactoryService : ILLamaFactoryService
public class LLamaFactoryService(ILogger<LLamaFactoryService> _logger) : ILLamaFactoryService
{
private Process process;
@@ -26,7 +28,7 @@ namespace AntSK.Domain.Domain.Service
private readonly object _syncLock = new object();
private List<LLamaModel> modelList = new List<LLamaModel>();
public LLamaFactoryService() { }
public delegate Task LogMessageHandler(string message);
public event LogMessageHandler LogMessageReceived;
protected virtual async Task OnLogMessageReceived(string message)
@@ -56,12 +58,12 @@ namespace AntSK.Domain.Domain.Service
};
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
@@ -85,7 +87,7 @@ namespace AntSK.Domain.Domain.Service
StartInfo = new ProcessStartInfo
{
FileName = "python",
Arguments = "api_demo.py --model_name_or_path " + modelName + " --template " + templateName + " ",
Arguments = "api_antsk.py --model_name_or_path " + modelName + " --template " + templateName + " ",
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError=true,
@@ -97,12 +99,12 @@ namespace AntSK.Domain.Domain.Service
process.StartInfo.EnvironmentVariables["USE_MODELSCOPE_HUB"] = Environment.GetEnvironmentVariable("USE_MODELSCOPE_HUB") ?? "1";
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
@@ -137,7 +139,7 @@ namespace AntSK.Domain.Domain.Service
if (process1.ProcessName.ToLower() == "python")
{
process1.Kill();
System.Console.WriteLine("kill python");
_logger.LogInformation("kill python");
}
}
}

View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Options
{
public class FileDirOption
{
public static string DirectoryPath { get; set; } = Directory.GetCurrentDirectory();
}
}

View File

@@ -3,6 +3,7 @@
public class LLamaSharpOption
{
public static string RunType { get; set; }
public static string FileDirectory { get; set; } = Directory.GetCurrentDirectory();
public static uint? ContextSize { get; set; }
public static int? GpuLayerCount { get; set; }
}
}

View File

@@ -1,4 +1,5 @@
using System.Web;
using System.Security.Cryptography;
using System.Web;
namespace AntSK.Domain.Utils
{
@@ -261,5 +262,11 @@ namespace AntSK.Domain.Utils
{
return s.Equals(value, StringComparison.OrdinalIgnoreCase);
}
public static string AntSKCalculateSHA256(this BinaryData binaryData)
{
byte[] byteArray = SHA256.HashData(binaryData.ToMemory().Span);
return Convert.ToHexString(byteArray).ToLowerInvariant();
}
}
}
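The handlers earlier in this compare switched from CalculateSHA256 to this AntSKCalculateSHA256 extension (presumably renamed to avoid colliding with Kernel Memory's built-in helper). A quick usage sketch:
```
// BinaryData lives in the System namespace (System.Memory.Data package).
BinaryData data = BinaryData.FromString("hello antsk");
string sha = data.AntSKCalculateSHA256(); // lowercase hex SHA-256 digest
Console.WriteLine(sha);
```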

View File

@@ -1,4 +1,6 @@
using System.Text.RegularExpressions;

using Serilog;
using System.Text.RegularExpressions;
namespace AntSK.Domain.Utils
{
@@ -19,7 +21,7 @@ namespace AntSK.Domain.Utils
{
string requestBody = await request.Content.ReadAsStringAsync();
//For debugging: log the request prompt
Console.WriteLine(requestBody);
Log.Information(requestBody);
}
if (match.Success)
{

View File

@@ -0,0 +1,19 @@
import os
import uvicorn
from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
if __name__ == "__main__":
main()

View File

@@ -1,16 +0,0 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
__version__ = VERSION

View File

@@ -0,0 +1,108 @@
import os
from contextlib import asynccontextmanager
from typing import Optional
from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
create_stream_chat_completion_response,
)
from .protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelCard,
ModelList,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@app.get(
"/v1/models",
response_model=ModelList,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post(
"/v1/chat/completions",
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return await create_chat_completion_response(request, chat_model)
@app.post(
"/v1/score/evaluation",
response_model=ScoreEvaluationResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
return await create_score_evaluation_response(request, chat_model)
return app
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
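
Once the server is up, the endpoints follow the OpenAI chat-completions shape. A minimal client sketch, assuming the defaults above (port 8000) and that API_KEY was exported as my-secret-key when starting the server:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer my-secret-key"},  # omit if no API_KEY is set
    json={
        "model": "gpt-3.5-turbo",  # the stub id returned by /v1/models
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])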

View File

@@ -0,0 +1,219 @@
import base64
import io
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
Function,
FunctionCall,
Role,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = None
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if image_url.startswith("data:image"): # base64 image
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image = Image.open(image_path).convert("RGB")
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
return input_messages, system, tools, image
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
stop=request.stop,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
if request.n > 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
async for new_token in chat_model.astream_chat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
stop=request.stop,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
)
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)
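
_process_request enforces strict user/assistant alternation (an optional leading system message is popped off first) and accepts OpenAI-style multimodal content lists, where image URLs may be data: URIs, local file paths, or web URLs. A hedged sketch of a request body it would accept; the base64 payload is a placeholder:

request_body = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},  # popped into `system`
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this picture."},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
            ],
        },
    ],
}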

View File

@@ -0,0 +1,20 @@
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
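
Both helpers exclude unset fields, which keeps serialized payloads minimal on either pydantic major version. A quick sketch with a hypothetical model:

from typing import Optional
from pydantic import BaseModel
from llamafactory.api.common import dictify, jsonify

class Point(BaseModel):
    x: int
    y: Optional[int] = None

p = Point(x=1)
print(dictify(p))  # {'x': 1} -- y was never set, so it is excluded
print(jsonify(p))  # '{"x": 1}'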

View File

@@ -1,6 +1,6 @@
import time
from enum import Enum, unique
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
@@ -39,15 +39,37 @@ class Function(BaseModel):
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
id: str
type: Literal["function"] = "function"
function: Function
class ImageURL(BaseModel):
url: str
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
class ChatMessage(BaseModel):
role: Role
content: str
content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
@@ -59,12 +81,13 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
@@ -74,7 +97,7 @@ class ChatCompletionResponseChoice(BaseModel):
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
@@ -87,7 +110,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
@@ -96,11 +119,11 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
choices: List[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
@@ -110,7 +133,7 @@ class ScoreEvaluationRequest(BaseModel):
class ScoreEvaluationResponse(BaseModel):
id: Literal["scoreeval-default"] = "scoreeval-default"
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]
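
With the protocol changes above, a tools payload can be built from the new FunctionAvailable/FunctionDefinition models. A hedged sketch; the weather function is hypothetical:

from llamafactory.api.protocol import (
    ChatCompletionRequest,
    ChatMessage,
    FunctionAvailable,
    FunctionDefinition,
    Role,
)

tool = FunctionAvailable(
    function=FunctionDefinition(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )
)
request = ChatCompletionRequest(
    model="gpt-3.5-turbo",
    messages=[ChatMessage(role=Role.USER, content="What's the weather in Paris?")],
    tools=[tool],
)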

View File

@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]: ...
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...

View File

@@ -2,12 +2,15 @@ import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from .base_engine import BaseEngine, Response
@@ -36,9 +39,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
return task.result()
async def achat(
@@ -46,18 +50,20 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs)
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs)
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -70,9 +76,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
yield new_token
def get_scores(
@@ -89,3 +96,45 @@ class ChatModel:
**input_kwargs,
) -> List[float]:
return await self.engine.get_scores(batch_input, **input_kwargs)
def run_chat() -> None:
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
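
Outside the CLI loop, the same streaming interface can be driven programmatically. A minimal sketch, assuming LLaMA-Factory is installed and the infer arguments are supplied on the command line (e.g. --model_name_or_path and --template, as the AntSK service does):

from llamafactory.chat import ChatModel

chat_model = ChatModel()  # infer args are parsed from the command line
messages = [{"role": "user", "content": "Hello!"}]
print("Assistant: ", end="", flush=True)
for new_text in chat_model.stream_chat(messages):
    print(new_text, end="", flush=True)
print()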

View File

@@ -2,25 +2,31 @@ import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine):
def __init__(
self,
@@ -30,55 +36,96 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
generating_args = generating_args.copy()
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
temperature=temperature if temperature is not None else generating_args["temperature"],
top_p=top_p if top_p is not None else generating_args["top_p"],
top_k=top_k if top_k is not None else generating_args["top_k"],
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
generating_args["do_sample"] = True
generating_args["temperature"] = generating_args["temperature"] or 1.0
if not generating_args["temperature"]:
generating_args["do_sample"] = False
if not generating_args["do_sample"]:
generating_args.pop("temperature", None)
generating_args.pop("top_p", None)
if max_length:
generating_args.pop("max_new_tokens", None)
@@ -90,10 +137,14 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs = dict(
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
return gen_kwargs, prompt_length
@staticmethod
@@ -101,15 +152,17 @@ class HuggingfaceEngine(BaseEngine):
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -134,15 +187,17 @@ class HuggingfaceEngine(BaseEngine):
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -198,6 +253,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -207,11 +263,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:
@@ -223,6 +281,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -232,11 +291,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:

View File

@@ -0,0 +1,214 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
config = load_config(model_args) # may download model from ms hub
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": model_args.vllm_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if model_args.visual_inputs:
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = SamplingParams(
n=num_return_sequences,
repetition_penalty=(
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search,
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,
skip_special_tokens=True,
)
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
request_id=request_id,
lora_request=self.lora_request,
)
return result_generator
async def start(self) -> None:
pass
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
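
As the kwargs handling in _generate shows, request-level sampling parameters take precedence over the generating_args defaults. A hedged sketch; engine construction is omitted and `engine` is assumed to be an already-initialized VllmEngine:

import asyncio

async def demo(engine) -> None:
    # request-level kwargs override the generating_args defaults
    responses = await engine.chat(
        [{"role": "user", "content": "Hi"}],
        temperature=0.2,     # overrides generating_args["temperature"]
        max_new_tokens=256,  # overrides the max_length/max_new_tokens defaults
    )
    print(responses[0].response_text, responses[0].finish_reason)

# asyncio.run(demo(engine)) once an engine has been constructed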

View File

@@ -0,0 +1,106 @@
import os
import random
import subprocess
import sys
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli eval -h: evaluate models |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
logger = get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main():
command = sys.argv.pop(1)
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}".format(command))
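
The train branch re-launches itself through torchrun when FORCE_TORCHRUN is set or more than one device is visible, and every knob is an environment variable. A hedged sketch of forcing that path on a two-GPU node:

import os

os.environ["FORCE_TORCHRUN"] = "1"    # "true" (any case) also works
os.environ["NPROC_PER_NODE"] = "2"    # defaults to get_device_count()
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"   # otherwise a random port in 20001-29999
# then invoke: llamafactory-cli train <training args>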

View File

@@ -0,0 +1,16 @@
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]

View File

@@ -0,0 +1,221 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
return outputs
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)
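
To make the alignment concrete, a hedged sketch of what convert_alpaca produces for a plain instruction/input/output batch; the data is hypothetical and the default alpaca column names are assumed:

examples = {
    "instruction": ["Translate to French"],
    "input": ["Good morning"],
    "output": ["Bonjour"],
}
# With the default alpaca columns, convert_alpaca yields roughly:
# prompt   -> [[{"role": "user", "content": "Translate to French\nGood morning"}]]
# response -> [[{"role": "assistant", "content": "Bonjour"}]]
# system   -> [""], tools -> [""], images -> [[]]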

View File

@@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
target_feature = {
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)
return batch
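
A hedged sketch of the 2n layout the pairwise collator builds (token ids are made up): the first n rows of the batch are the chosen sequences and the last n the rejected ones:

features = [{
    "chosen_input_ids": [1, 2, 3],
    "chosen_attention_mask": [1, 1, 1],
    "chosen_labels": [-100, 2, 3],
    "rejected_input_ids": [1, 2],
    "rejected_attention_mask": [1, 1],
    "rejected_labels": [-100, 2],
}]
# collator(features) pads to the longest sequence and returns a batch of
# 2 rows: row 0 is the chosen example, row 1 the rejected one.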

View File

@@ -1,6 +1,5 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -11,7 +10,7 @@ if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
logger = get_logger(__name__)
@@ -26,25 +25,10 @@ class Role(str, Enum):
OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - max_target_len
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
@@ -78,9 +62,9 @@ def split_dataset(
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
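
A worked example of infer_max_len after the change above: when the pair already fits (source_len=100, target_len=50, max_len=512), the unused target budget now stays with the source instead of being discarded:

# max_target_len = max(int(512 * 50 / 150), 16) = 170
# max_source_len = 512 - min(170, 50) = 462   (previously 512 - 170 = 342)
src, tgt = infer_max_len(source_len=100, target_len=50, max_len=512, reserved_label_len=16)
assert (src, tgt) == (462, 170)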

View File

@@ -1,21 +1,24 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
@@ -56,14 +59,12 @@ def load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
raise ValueError("File {} not found.".format(local_path))
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.file_sha1)
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
else:
raise NotImplementedError
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
if dataset_attr.load_from == "ms_hub":
try:
@@ -80,7 +81,9 @@ def load_single_dataset(
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
@@ -104,30 +107,43 @@ def load_single_dataset(
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
return align_dataset(dataset, dataset_attr, data_args)
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
@@ -138,12 +154,15 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
@@ -156,15 +175,21 @@ def get_dataset(
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
sys.exit(0)
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
if stage == "pt":
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
else:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset

View File

@@ -20,23 +20,28 @@ class DatasetAttr:
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
num_samples: Optional[int] = None
""" common columns """
system: Optional[str] = None
""" columns for the alpaca format """
tools: Optional[str] = None
images: Optional[str] = None
""" rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
""" columns for the sharegpt format """
""" sharegpt columns """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
""" sharegpt tags """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
@@ -53,22 +58,35 @@ class DatasetAttr:
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
if data_args.dataset is not None:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
else:
dataset_names = []
if data_args.dataset_dir == "ONLINE":
dataset_info = None
else:
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None:
load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
@@ -85,18 +103,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system"]
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

View File

@@ -0,0 +1,84 @@
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_feedback_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
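
get_preprocess_and_print_func freezes everything except the batch itself with functools.partial, so dataset.map always receives a one-argument batched callable regardless of stage. The same dispatch pattern in miniature (function and stage names here are invented for illustration):

    from functools import partial
    from typing import Any, Callable, Dict, List

    def preprocess_toy(examples: Dict[str, List[Any]], suffix: str) -> Dict[str, List[Any]]:
        # the extra `suffix` argument is frozen by partial below
        return {"text": [t + suffix for t in examples["text"]]}

    def get_func(stage: str) -> Callable[..., Dict[str, List[Any]]]:
        if stage == "sft":
            return partial(preprocess_toy, suffix=" <eos>")
        raise ValueError("unknown stage: {}".format(stage))

    func = get_func("sft")
    print(func({"text": ["hello"]}))  # {'text': ['hello <eos>']}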

View File

@@ -0,0 +1,126 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, kl_response_ids = template.encode_oneturn(
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
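
The kl_response trick above builds the mismatched prompt-response pairs that KTO needs for its KL estimate by simply reversing the batch of responses, so each prompt is scored against another example's response. With toy data:

    responses = ["resp_a", "resp_b", "resp_c"]
    kl_responses = responses[::-1]  # ['resp_c', 'resp_b', 'resp_a']

    # every prompt now pairs with an unrelated response (the middle
    # element of an odd-length batch is the one exception)
    for i, (resp, kl) in enumerate(zip(responses, kl_responses)):
        print(i, resp, kl)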

View File

@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
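
Chosen and rejected sequences share the prompt prefix, whose label positions are filled with IGNORE_INDEX (-100) so the loss only covers response tokens. The masking in isolation (token ids are made up):

    IGNORE_INDEX = -100

    prompt_ids = [101, 7592]  # hypothetical prompt token ids
    chosen_ids = [2204, 102]  # hypothetical response token ids

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
    print(chosen_input_ids)  # [101, 7592, 2204, 102]
    print(chosen_labels)     # [-100, -100, 2204, 102]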

View File

@@ -0,0 +1,36 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
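
With packing enabled, the pretraining processor concatenates every tokenized example and re-splits the stream into blocks of cutoff_len tokens, dropping the short tail. The grouping step alone, with made-up ids:

    from itertools import chain

    tokenized = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}
    block_size = 4

    concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    print(result["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- the trailing 9 is dropped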

View File

@@ -0,0 +1,64 @@
import bisect
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of the largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = search_for_fit(numbers, remaining_capacity)
if index == -1:
break # no more numbers fit in this knapsack
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
knapsacks.append(current_knapsack)
return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
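
A quick trace of the knapsack helpers above: sorting once lets search_for_fit use bisect to grab the largest remaining length that still fits, so each bin is filled greedily from the top. Reproducing both functions with made-up sequence lengths:

    import bisect

    def search_for_fit(numbers, capacity):
        index = bisect.bisect(numbers, capacity)
        return -1 if index == 0 else (index - 1)

    def greedy_knapsack(numbers, capacity):
        numbers.sort()
        knapsacks = []
        while numbers:
            current, remaining = [], capacity
            while True:
                index = search_for_fit(numbers, remaining)
                if index == -1:
                    break  # nothing left that fits
                remaining -= numbers[index]
                current.append(numbers.pop(index))
            knapsacks.append(current)
        return knapsacks

    print(greedy_knapsack([7, 3, 2, 5, 8], capacity=10))  # [[8, 2], [7, 3], [5]]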

View File

@@ -0,0 +1,169 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels = [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=None,
data_args=data_args,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
valid_num += 1
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_labels = [], []
for length in knapsack:
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of a packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["labels"].append(packed_labels)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
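
In the supervised processors, each encoded (source, target) turn extends input_ids in full while only the target tokens survive into labels, unless train_on_prompt keeps the source too. A two-turn toy trace:

    IGNORE_INDEX = -100
    encoded_pairs = [([11, 12], [13]), ([14], [15, 16])]  # hypothetical (source_ids, target_ids)

    input_ids, labels = [], []
    for source_ids, target_ids in encoded_pairs:
        input_ids += source_ids + target_ids
        labels += [IGNORE_INDEX] * len(source_ids) + target_ids

    print(input_ids)  # [11, 12, 13, 14, 15, 16]
    print(labels)     # [-100, -100, 13, -100, 15, 16]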

View File

@@ -0,0 +1,92 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
return input_ids, labels
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .data_utils import Role, infer_max_len
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
@@ -26,6 +26,7 @@ class Template:
format_separator: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
force_system: bool
@@ -68,8 +69,8 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
system: Optional[str],
tools: Optional[str],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
@@ -195,7 +196,7 @@ class Llama2Template(Template):
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
TEMPLATES: Dict[str, Template] = {}
def _register_template(
@@ -209,6 +210,7 @@ def _register_template(
format_separator: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
@@ -246,7 +248,7 @@ def _register_template(
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
@@ -256,6 +258,7 @@ def _register_template(
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
@@ -276,7 +279,7 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'")
return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
@@ -290,10 +293,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
@@ -308,7 +311,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
@@ -325,9 +328,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -343,9 +348,9 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
template = TEMPLATES["empty"] # placeholder
else:
template = templates.get(name, None)
template = TEMPLATES.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
@@ -385,7 +390,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
)
@@ -414,7 +420,7 @@ _register_template(
_register_template(
name="baichuan",
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
@@ -441,6 +447,18 @@ _register_template(
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
default_system=(
"You are a helpful AI assistant built by MediaTek Research. "
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
),
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -490,6 +508,7 @@ _register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -500,6 +519,7 @@ _register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
@@ -514,6 +534,26 @@ _register_template(
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
),
default_system=(
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
"by providing thorough responses. You are trained by Cohere."
),
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -522,6 +562,32 @@ _register_template(
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -554,6 +620,16 @@ _register_template(
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -562,16 +638,39 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -601,17 +700,8 @@ _register_template(
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
)
@@ -623,6 +713,33 @@ _register_template(
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -633,8 +750,7 @@ _register_template(
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
)
@@ -643,12 +759,28 @@ _register_template(
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="openchat-3.6",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
stop_words=["<|eot_id|>"],
replace_eos=True,
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -657,10 +789,22 @@ _register_template(
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -688,7 +832,11 @@ _register_template(
_register_template(
name="vanilla",
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
stop_words=["<_end>"],
replace_eos=True,
)
@@ -742,12 +890,29 @@ _register_template(
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
),
stop_words=["###"],
efficient_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -762,7 +927,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
default_system="You are Zephyr, a helpful assistant.",
)
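
All of the registrations above go through _register_template into the module-level TEMPLATES dict, and formatters fill {{placeholders}} inside their slot strings. A toy re-implementation of the slot substitution (not the real Formatter classes):

    def apply_slots(slots, **kwargs):
        # each string slot may contain {{name}} placeholders filled from kwargs
        pieces = []
        for slot in slots:
            for name, value in kwargs.items():
                slot = slot.replace("{{" + name + "}}", value)
            pieces.append(slot)
        return "".join(pieces)

    qwen_user = ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    print(apply_slots(qwen_user, content="Hello!"))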

View File

@@ -14,16 +14,17 @@ from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issues in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
@@ -117,6 +118,5 @@ class Evaluator:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()
def run_eval() -> None:
Evaluator().eval()
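
choice_inputs keeps only the last token id of each encoded prefix-plus-letter string, which is later compared against the logits at the final position. Simulated with a stand-in encoder (the real code uses the loaded tokenizer; ids here are just character codes):

    CHOICES = ["A", "B", "C", "D"]

    def encode(text):
        # stand-in for tokenizer.encode(..., add_special_tokens=False):
        # one hypothetical id per character
        return [ord(ch) for ch in text]

    prefix = " "
    choice_inputs = [encode(prefix + ch)[-1] for ch in CHOICES]
    print(choice_inputs)  # [65, 66, 67, 68] -- ids of the bare letters, prefix stripped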

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
answer: str
prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages
@@ -39,7 +42,7 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template
register_eval_template(
_register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
)
register_eval_template(
_register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
prefix=" ",
)
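
_parse_example and format_example together turn the k support examples plus the target into an alternating user/assistant message list, with the subject-specific system text prepended to the first user turn. A toy run (questions and subject invented):

    CHOICES = ["A", "B"]

    def parse_example(example):
        candidates = ["\n{}. {}".format(ch, example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + ["\nAnswer:"]), example["answer"]

    support = {"question": "1+1=?", "A": "2", "B": "3", "answer": "A"}
    target = {"question": "2+2=?", "A": "4", "B": "5", "answer": "A"}

    messages = []
    for ex in (support, target):
        prompt, response = parse_example(ex)
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
    messages[0]["content"] = "The following are questions about math.\n\n" + messages[0]["content"]
    print(messages)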

View File

@@ -0,0 +1,217 @@
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, output_dir: str) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
self.aborted = False
self.do_train = False
""" Web UI """
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def _create_thread_pool(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _close_thread_pool(self) -> None:
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
)
)
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if has_length(eval_dataloader):
if self.max_steps == 0:
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)
self._timing(cur_steps=self.cur_steps + 1)
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
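
_timing turns elapsed wall-clock time into an ETA with a simple per-step average. The arithmetic on its own (numbers invented):

    from datetime import timedelta

    elapsed, cur_steps, max_steps = 120.0, 40, 100
    avg_time_per_step = elapsed / cur_steps if cur_steps != 0 else 0
    remaining = (max_steps - cur_steps) * avg_time_per_step

    print(str(timedelta(seconds=int(elapsed))))    # 0:02:00
    print(str(timedelta(seconds=int(remaining))))  # 0:03:00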

View File

@@ -2,13 +2,24 @@ from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
}
CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
@@ -24,28 +35,43 @@ IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
METHODS = ["full", "freeze", "lora"]
PEFT_METHODS = ["lora"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINER_LOG = "trainer_log.jsonl"
TRAINING_ARGS = "training_args.yaml"
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"KTO": "kto",
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
VISION_MODELS = set()
class DownloadSource(str, Enum):
DEFAULT = "hf"
@@ -54,8 +80,8 @@ class DownloadSource(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
module: Optional[str] = None,
template: Optional[str] = None,
vision: bool = False,
) -> None:
prefix = None
for name, path in models.items():
@@ -64,10 +90,23 @@ def register_model_group(
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
if vision:
VISION_MODELS.add(prefix)
register_model_group(
models={
"Aya-23-8B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
},
"Aya-23-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
},
},
template="cohere",
)
register_model_group(
@@ -85,7 +124,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
module="W_pack",
template="baichuan",
)
@@ -109,7 +147,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
},
},
module="W_pack",
template="baichuan2",
)
@@ -129,7 +166,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
module="query_key_value",
)
@@ -148,7 +184,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
module="query_key_value",
)
@@ -167,6 +202,19 @@ register_model_group(
)
register_model_group(
models={
"Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
},
"Breeze-7B-Chat": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
},
},
template="breeze",
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
@@ -174,7 +222,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
module="query_key_value",
template="chatglm2",
)
@@ -190,7 +237,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
module="query_key_value",
template="chatglm3",
)
@@ -226,6 +272,73 @@ register_model_group(
)
register_model_group(
models={
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
"CodeGemma-1.1-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"Codestral-22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
},
},
template="mistral",
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DBRX-132B-Base": {
DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
},
"DBRX-132B-Chat": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
},
},
template="dbrx",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
@@ -246,18 +359,36 @@ register_model_group(
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-16B-v2-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
},
"DeepSeek-MoE-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-MoE-16B-v2-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
},
"DeepSeek-MoE-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
},
template="deepseek",
)
@@ -298,6 +429,9 @@ register_model_group(
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-11B": {
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
@@ -319,7 +453,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
module="query_key_value",
template="falcon",
)
@@ -342,11 +475,36 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
},
"Gemma-1.1-2B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
},
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"GLM-4-9B": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
},
"GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
},
"GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
},
},
template="glm4",
)
register_model_group(
models={
"InternLM-7B": {
@@ -389,11 +547,20 @@ register_model_group(
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
},
module="wqkv",
template="intern2",
)
register_model_group(
models={
"Jambda-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
},
)
register_model_group(
models={
"LingoWhale-8B": {
@@ -401,7 +568,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
module="qkv_proj",
)
@@ -460,18 +626,72 @@ register_model_group(
register_model_group(
models={
"Mistral-7B": {
"LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
"LLaMA3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
},
"LLaMA3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
},
},
template="llama3",
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
vision=True,
)
register_model_group(
models={
"Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
"Mistral-7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
},
"Mistral-7B-v0.3-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
},
},
template="mistral",
)
@@ -479,14 +699,22 @@ register_model_group(
register_model_group(
models={
"Mixtral-8x7B": {
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
"Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
)
@@ -495,18 +723,18 @@ register_model_group(
register_model_group(
models={
"OLMo-1B": {
DownloadSource.DEFAULT: "allenai/OLMo-1B",
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
},
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B",
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
},
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
},
module="att_proj",
template="olmo",
)
@@ -514,13 +742,23 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
}
},
template="openchat",
)
register_model_group(
models={
"OpenChat3.6-8B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
}
},
template="openchat-3.6",
)
register_model_group(
models={
"Orion-14B-Base": {
@@ -548,6 +786,33 @@ register_model_group(
)
register_model_group(
models={
"PaliGemma-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
vision=True,
)
register_model_group(
models={
"Phi-1.5-1.3B": {
@@ -562,6 +827,37 @@ register_model_group(
)
register_model_group(
models={
"Phi3-4B-4k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi3-4B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi3-7B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi3-7B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi3-14B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
},
"Phi3-14B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
},
template="phi",
)
register_model_group(
models={
"Qwen-1.8B": {
@@ -629,7 +925,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
},
module="c_attn",
template="qwen",
)
@@ -656,10 +951,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -680,10 +991,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -724,6 +1051,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
@@ -732,6 +1063,101 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
},
"Qwen2-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
},
"Qwen2-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
},
"Qwen2-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
},
"Qwen2-MoE-57B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
},
"Qwen2-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
},
"Qwen2-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
},
"Qwen2-MoE-57B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2-0.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
},
"Qwen2-1.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2-1.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
},
"Qwen2-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
},
"Qwen2-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
},
"Qwen2-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
},
"Qwen2-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
},
"Qwen2-MoE-57B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
},
},
template="qwen",
)
@@ -765,17 +1191,39 @@ register_model_group(
models={
"StarCoder2-3B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
},
"StarCoder2-7B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
},
"StarCoder2-15B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
},
}
)
register_model_group(
models={
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
},
"TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
},
},
template="telechat",
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": {
@@ -793,17 +1241,53 @@ register_model_group(
register_model_group(
models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan-2-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan-2-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
@@ -840,6 +1324,30 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
@@ -898,11 +1406,49 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
"Yi-1.5-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
},
"Yi-1.5-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
},
"Yi-1.5-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
},
"Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
},
"Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
},
"Yi-1.5-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
},
template="yi",
)
register_model_group(
models={
"YiVL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
},
"YiVL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
},
},
template="yi_vl",
vision=True,
)
register_model_group(
models={
"Yuan2-2B-Chat": {
@@ -932,21 +1478,9 @@ register_model_group(
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
register_model_group(
models={
"Atom-7B": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
},
"Atom-7B-Chat": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
},
},
template="atom",
)
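register_model_group keys DEFAULT_TEMPLATE and VISION_MODELS by the prefix before the first dash of each model name, so template lookup reduces to a prefix match. A self-contained sketch of that mechanism, with illustrative names rather than the project's API:

from collections import defaultdict

DEFAULT_TEMPLATE = defaultdict(str)
VISION_MODELS = set()

def register(name: str, template: str = "", vision: bool = False) -> None:
    prefix = name.split("-")[0]  # "LLaVA1.5-7B-Chat" -> "LLaVA1.5"
    if template:
        DEFAULT_TEMPLATE[prefix] = template
    if vision:
        VISION_MODELS.add(prefix)

register("LLaVA1.5-7B-Chat", template="vicuna", vision=True)
assert DEFAULT_TEMPLATE["LLaVA1.5"] == "vicuna"
assert "LLaVA1.5" in VISION_MODELS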

View File

@@ -0,0 +1,55 @@
import platform
import accelerate
import datasets
import peft
import torch
import transformers
import trl
from transformers.integrations import is_deepspeed_available
from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
from .packages import is_vllm_available
VERSION = "0.8.0"
def print_env() -> None:
info = {
"`llamafactory` version": VERSION,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": torch.__version__,
"Transformers version": transformers.__version__,
"Datasets version": datasets.__version__,
"Accelerate version": accelerate.__version__,
"PEFT version": peft.__version__,
"TRL version": trl.__version__,
}
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
info["GPU type"] = torch.cuda.get_device_name()
if is_torch_npu_available():
info["PyTorch version"] += " (NPU)"
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann
if is_deepspeed_available():
import deepspeed # type: ignore
info["DeepSpeed version"] = deepspeed.__version__
if is_bitsandbytes_available():
import bitsandbytes
info["Bitsandbytes version"] = bitsandbytes.__version__
if is_vllm_available():
import vllm
info["vLLM version"] = vllm.__version__
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")

View File

@@ -0,0 +1,68 @@
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from .constants import RUNNING_LOG
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self, output_dir: str) -> None:
super().__init__()
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
self.setLevel(logging.INFO)
self.setFormatter(formatter)
os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log):
os.remove(self.running_log)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _write_log(self, log_entry: str) -> None:
with open(self.running_log, "a", encoding="utf-8") as f:
f.write(log_entry + "\n\n")
def emit(self, record) -> None:
if record.name == "httpx":
return
log_entry = self.format(record)
self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None:
self.thread_pool.shutdown(wait=True)
return super().close()
def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream handler to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
return logger
def reset_logging() -> None:
r"""
Removes the basic config of the root logger (unused in scripts).
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
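A minimal usage sketch for the LoggerHandler class above (the output directory is hypothetical): every record except httpx traffic is mirrored asynchronously to running_log.txt, and close() drains the single-worker pool:

import logging

handler = LoggerHandler(output_dir="saves/test-run")  # hypothetical directory
logger = logging.getLogger("llamafactory")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info("training started")  # appended to saves/test-run/running_log.txt
handler.close()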

View File

@@ -30,7 +30,7 @@ except Exception:
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
from ..hparams import ModelArguments
logger = get_logger(__name__)
@@ -58,14 +58,14 @@ class AverageMeter:
def check_dependencies() -> None:
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
@@ -81,7 +81,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else:
num_bytes = 1
num_params = num_params * 2 * num_bytes
all_param += num_params
if param.requires_grad:
@@ -158,13 +165,15 @@ def get_current_device() -> torch.device:
def get_device_count() -> int:
r"""
Gets the number of available GPU devices.
Gets the number of available GPU or NPU devices.
"""
if not torch.cuda.is_available():
if is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
else:
return 0
return torch.cuda.device_count()
def get_logits_processor() -> "LogitsProcessorList":
r"""
@@ -187,30 +196,47 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
return is_torch_npu_available() or is_torch_cuda_available()
def has_tokenized_data(path: os.PathLike) -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def torch_gc() -> None:
r"""
Collects GPU memory.
Collects GPU or NPU memory.
"""
gc.collect()
if torch.cuda.is_available():
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_cuda_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
return
return model_args.model_name_or_path
try:
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
def use_modelscope() -> bool:
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
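The rewritten use_modelscope accepts both numeric and boolean spellings of the flag, case-insensitively; a quick standalone check of the new parsing, using only the standard library:

import os

def use_modelscope() -> bool:
    # "1", "true", "True", "TRUE" all enable the hub; anything else disables it
    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]

os.environ["USE_MODELSCOPE_HUB"] = "True"
assert use_modelscope()
os.environ["USE_MODELSCOPE_HUB"] = "0"
assert not use_modelscope()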

View File

@@ -1,30 +1,41 @@
import importlib.metadata
import importlib.util
from typing import TYPE_CHECKING
from packaging import version
if TYPE_CHECKING:
from packaging.version import Version
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def _get_package_version(name: str) -> str:
def _get_package_version(name: str) -> "Version":
try:
return importlib.metadata.version(name)
return version.parse(importlib.metadata.version(name))
except Exception:
return "0.0.0"
return version.parse("0.0.0")
def is_fastapi_availble():
def is_fastapi_available():
return _is_package_available("fastapi")
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available():
return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
@@ -37,6 +48,10 @@ def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available():
return _is_package_available("PIL")
def is_requests_available():
return _is_package_available("requests")
@@ -45,14 +60,14 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available():
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return _is_package_available("uvicorn")
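Returning a parsed Version makes the flash_attn check semantic rather than textual: the old startswith("2") test compared characters, while packaging orders releases numerically. A short illustration:

from packaging import version

assert version.parse("2.5.8") > version.parse("2.0.0")
assert version.parse("2.0.0.post1") > version.parse("2.0.0")
# lexical comparison would get this wrong: "10.0.0" < "2.0.0" as strings
assert version.parse("10.0.0") > version.parse("2.0.0")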

View File

@@ -1,7 +1,7 @@
import json
import math
import os
from typing import List
from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
@@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
@@ -20,8 +21,11 @@ def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
if len(scalars) == 0:
return []
last = scalars[0]
smoothed = list()
smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
@@ -30,7 +34,33 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
for log in trainer_log:
if log.get("loss", None):
steps.append(log["current_steps"])
losses.append(log["loss"])
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)
@@ -52,6 +82,6 @@ def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
plt.savefig(figure_path, format="png", dpi=100)
print("Figure saved at:", figure_path)

View File

@@ -26,11 +26,11 @@ class DataArguments:
)
cutoff_len: int = field(
default=1024,
metadata={"help": "The cutoff length of the model inputs after tokenization."},
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
reserved_label_len: int = field(
default=1,
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
)
train_on_prompt: bool = field(
default=False,
@@ -84,9 +84,9 @@ class DataArguments:
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
},
)
cache_path: Optional[str] = field(
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the pre-processed datasets."},
metadata={"help": "Path to save or load the tokenized datasets."},
)
def __post_init__(self):

View File

@@ -1,5 +1,4 @@
import json
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -9,22 +8,35 @@ class FreezeArguments:
Arguments pertaining to the freeze (partial-parameter) training.
"""
name_module_trainable: str = field(
default="all",
freeze_trainable_layers: int = field(
default=2,
metadata={
"help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
Use "all" to specify all the available modules. \
LLaMA choices: ["mlp", "self_attn"], \
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Qwen choices: ["mlp", "attn"], \
InternLM2 choices: ["feed_forward", "attention"], \
Others choices: the same as LLaMA."""
"help": (
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
"Positive numbers mean the last n layers are set as trainable, "
"negative numbers mean the first n layers are set as trainable."
)
},
)
num_layer_trainable: int = field(
default=2,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
freeze_trainable_modules: str = field(
default="all",
metadata={
"help": (
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the available modules."
)
},
)
freeze_extra_modules: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from hidden layers to be set as trainable "
"for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules."
)
},
)
@@ -37,7 +49,11 @@ class LoraArguments:
additional_target: Optional[str] = field(
default=None,
metadata={
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
"help": (
"Name(s) of modules apart from LoRA layers to be set as trainable "
"and saved in the final checkpoint. "
"Use commas to separate multiple modules."
)
},
)
lora_alpha: Optional[int] = field(
@@ -55,17 +71,21 @@ class LoraArguments:
lora_target: str = field(
default="all",
metadata={
"help": """Name(s) of target modules to apply LoRA. \
Use commas to separate multiple modules. \
Use "all" to specify all the available modules. \
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
Others choices: the same as LLaMA."""
"help": (
"Name(s) of target modules to apply LoRA. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
)
loraplus_lr_embedding: float = field(
default=1e-6,
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
)
use_rslora: bool = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
@@ -83,20 +103,36 @@ class LoraArguments:
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
Arguments pertaining to the PPO, DPO and KTO training.
"""
dpo_beta: float = field(
pref_beta: float = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."},
metadata={"help": "The beta parameter in the preference loss."},
)
dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
pref_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_ftx: float = field(
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
simpo_gamma: float = field(
default=0.5,
metadata={"help": "The target reward margin term in SimPO loss."},
)
ppo_buffer_size: int = field(
default=1,
@@ -106,10 +142,6 @@ class RLHFArguments:
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
)
ppo_score_norm: bool = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
@@ -160,11 +192,16 @@ class GaloreArguments:
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use gradient low-Rank projection."},
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="mlp,attn",
metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
default="all",
metadata={
"help": (
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
galore_rank: int = field(
default=16,
@@ -189,7 +226,60 @@ class GaloreArguments:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
},
)
badam_update_ratio: float = field(
default=0.05,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": (
"The mode of the mask for BAdam optimizer. "
"`adjacent` means that the trainable parameters are adjacent to each other, "
"`scatter` means that the trainable parameters are randomly chosen from the weights."
)
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": (
"The verbosity level of BAdam optimizer. "
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
)
},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
@@ -198,7 +288,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
@@ -210,6 +300,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
@@ -221,37 +319,40 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
return [item.strip() for item in arg.split(",")]
return arg
self.name_module_trainable = split_arg(self.name_module_trainable)
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target)
self.galore_target = split_arg(self.galore_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
@classmethod
def load_from_json(cls, json_path: str):
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
return cls(**json.loads(text))
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")

View File

@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict
from typing import Any, Dict, Optional
@dataclass
@@ -31,11 +31,11 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
)
max_length: int = field(
default=512,
default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: int = field(
default=512,
default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: float = field(
@@ -46,6 +46,10 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)

View File

@@ -15,14 +15,19 @@ class ModelArguments:
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=False,
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
@@ -33,6 +38,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -53,22 +62,38 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: bool = field(
default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."},
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
@@ -81,13 +106,17 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum input length of the vLLM engine."},
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
@@ -97,6 +126,14 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=8,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations in the vLLM engine."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -121,6 +158,10 @@ class ModelArguments:
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
@@ -158,9 +199,15 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
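__post_init__ normalizes the comma-separated adapter and token strings into lists; the same splitting in isolation, with hypothetical values:

adapter_name_or_path = "saves/lora-a, saves/lora-b"
adapters = [path.strip() for path in adapter_name_or_path.split(",")]
assert adapters == ["saves/lora-a", "saves/lora-b"]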

View File

@@ -6,13 +6,14 @@ from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies
from ..extras.packages import is_unsloth_available
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -64,10 +65,16 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
@@ -75,6 +82,35 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
@@ -119,21 +155,24 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.use_dora:
if model_args.quantization_bit is not None:
require_version("peft>=0.9.1.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
@@ -149,18 +188,33 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -202,16 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and any(
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
@@ -230,10 +283,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
# Log on each process the small summary
logger.info(
"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
@@ -261,18 +315,25 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.adapter_name_or_path is not None:
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support quantization.")
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
@@ -289,6 +350,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"

View File

@@ -0,0 +1,9 @@
from llamafactory.train.tuner import run_exp
def launch():
run_exp()
if __name__ == "__main__":
launch()

View File

@@ -0,0 +1,12 @@
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.valuehead import load_valuehead_params
__all__ = [
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]
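
A hypothetical consumer of this public surface (argument values are placeholders; `ModelArguments` and `FinetuningArguments` are assumed to be importable from the sibling `hparams` package, matching the TYPE_CHECKING imports in the modules below):

    from llamafactory.hparams import FinetuningArguments, ModelArguments
    from llamafactory.model import load_model, load_tokenizer

    model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
    finetuning_args = FinetuningArguments()
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
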

View File

@@ -0,0 +1,275 @@
import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
cast_trainable_params_to_fp32: bool,
) -> None:
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
cast_trainable_params_to_fp32: bool,
) -> None:
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
else:
config = model.config
num_layers = (
getattr(config, "num_hidden_layers", None)
or getattr(config, "num_layers", None)
or getattr(config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
)
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
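# To make the three branches above concrete, a worked example with num_layers = 32:
#   freeze_trainable_layers = 8 with use_llama_pro:
#       stride = 32 // 8 = 4 -> trainable_layer_ids = 3, 7, 11, ..., 31 (every 4th block)
#   freeze_trainable_layers = 8 (last-n case):
#       trainable_layer_ids = 24, 25, ..., 31
#   freeze_trainable_layers = -8 (first-n case):
#       trainable_layer_ids = 0, 1, ..., 7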
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
)
trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
return model
def init_adapter(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.adapter_name_or_path is None:
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
cast_trainable_params_to_fp32 = False
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if is_trainable and finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
if is_trainable and finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
if finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
return model
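
The LoRA branch ultimately reduces to the standard peft flow; a standalone sketch against a tiny public checkpoint (chosen here only for illustration):

    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # tiny test checkpoint
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        target_modules=["c_attn"],  # attention projection in GPT-2-style models
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # only the adapter weights remain trainable
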

View File

@@ -0,0 +1,186 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: this function modifies model_args in place.
"""
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer.
Note: this function modifies model_args in place.
"""
init_kwargs = _get_init_kwargs(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False)
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:d}".format(all_param)
logger.info(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
return model

View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla attention implementation.")

View File

@@ -0,0 +1,94 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)
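
The forward hook used for `lm_head` upcasting is ordinary PyTorch machinery; a self-contained illustration:

    import torch

    head = torch.nn.Identity()
    head.register_forward_hook(lambda module, args, output: output.to(torch.float32))
    out = head(torch.zeros(2, dtype=torch.bfloat16))
    print(out.dtype)  # torch.float32 -- the hook replaces the module output
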

View File

@@ -0,0 +1,58 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))

View File

@@ -0,0 +1,323 @@
import math
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
else:
groupsz = q_len
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
logger = get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")

View File

@@ -0,0 +1,74 @@
from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply lora or galore.
"""
forbidden_modules = {"lm_head"}
if model.config.model_type == "chatglm":
forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2":
forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector")
if freeze_vision_tower:
forbidden_modules.add("vision_tower")
module_names = set()
for name, module in model.named_modules():
if any(forbidden_module in name for forbidden_module in forbidden_modules):
continue
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model

View File

@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled():
return
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -0,0 +1,150 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@@ -0,0 +1,47 @@
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
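
Concretely, for a trainable run the linear scaling factor is rounded up from the length ratio; e.g.:

    import math

    max_position_embeddings = 4096   # from the base config
    model_max_length = 16384         # requested training length
    print(float(math.ceil(model_max_length / max_position_embeddings)))
    # 4.0 -> config.rope_scaling = {"type": "linear", "factor": 4.0}
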

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"use_gradient_checkpointing": "unsloth",
}
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
return model
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
if not is_trainable:
FastLanguageModel.for_inference(model)
return model

View File

@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING, Dict, Optional

import torch
from transformers.utils import cached_file

from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Optional[Dict[str, torch.Tensor]]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`,
    or None if no value head weights are found.
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
    err_text = ""

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:
        err_text = str(err)

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception as err:
        err_text = str(err)

    logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
    logger.info("Ignore the above message if you are not resuming the training of a value head model.")
    return None


def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    if getattr(model.config, "model_type", None) == "llava":
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "internlm2":
        setattr(model, "lm_head", model.output)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

View File

@@ -0,0 +1,84 @@
from typing import TYPE_CHECKING, Tuple

import torch
import transformers.models
from transformers.activations import ACT2FN

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
    def __init__(self, config: "LlavaConfig") -> None:
        super().__init__()
        self.config = config
        if config is None:
            return

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        if hidden_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.linear_1.weight.dtype

            logger.warning_once("The hidden states seem to have been silently cast to float32.")
            hidden_states = hidden_states.to(target_dtype)

        return hidden_states


class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
        super().__init__(config=None)

        self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.act = ACT2FN[projector_hidden_act]


def autocast_projector_dtype(
    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
    def _mm_projector_forward_post_hook(
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        return output.to(model_args.compute_dtype)

    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
        logger.info("Casting multimodal projector outputs to {}.".format(model_args.compute_dtype))
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

    if getattr(config, "is_yi_vl_derived_model", None):
        logger.info("Detected Yi-VL model, applying projector patch.")
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
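The dtype cast in autocast_projector_dtype rides on PyTorch's forward-hook mechanism: a hook that returns a value replaces the module's output. A self-contained toy demonstration of that mechanism (not repo code):

# Toy demo of the forward-hook cast used above (hypothetical module, not repo code).
import torch

proj = torch.nn.Linear(4, 4).float()  # stand-in for the multimodal projector

def cast_output(module, args, output):
    # Mirrors _mm_projector_forward_post_hook: returning a tensor replaces the output.
    return output.to(torch.float16)

proj.register_forward_hook(cast_output)
out = proj(torch.randn(2, 4))
print(out.dtype)  # torch.float16, while proj's weights remain float32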

View File

@@ -0,0 +1,143 @@
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict

import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import ModelArguments


logger = get_logger(__name__)


def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)


def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
    is_trainable: bool,
) -> None:
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

    if is_torch_npu_available():
        use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
        torch.npu.set_compile_mode(jit_compile=use_jit_compile)

    configure_attn_implementation(config, model_args)
    configure_rope(config, model_args, is_trainable)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)

    if model_args.use_cache and not is_trainable:
        setattr(config, "use_cache", True)
        logger.info("Using KV cache for faster generation.")

    if getattr(config, "model_type", None) == "qwen":
        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, model_args.compute_dtype == dtype)

    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():  # cast dtype and device if not using zero3 or fsdp
        init_kwargs["torch_dtype"] = model_args.compute_dtype

        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
            if "device_map" not in init_kwargs and model_args.device_map:
                init_kwargs["device_map"] = model_args.device_map

            if init_kwargs.get("device_map", None) == "auto":  # .get() avoids a KeyError when no device map was set
                init_kwargs["offload_folder"] = model_args.offload_folder


def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
    gen_config = model.generation_config  # check and fix generation config
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)

    if model_args.visual_inputs:
        autocast_projector_dtype(model, model_args)

    if is_trainable:
        prepare_model_for_training(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)

    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
        logger.warning("Cannot properly tag the model.")


def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)

    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))

View File

@@ -0,0 +1,233 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model

from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin

    from ...hparams import FinetuningArguments


class CustomDPOTrainer(DPOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.processor = processor
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False

        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # dpo hyperparams
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma

        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""
        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
        """
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        sft_loss = -chosen_logps
        odds_ratio_loss = -F.logsigmoid(log_odds)
        orpo_loss = sft_loss + self.beta * odds_ratio_loss
        return orpo_loss

    def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""
        Computes SimPO loss for batched log probabilities of the policy model.
        """
        pi_logratios = chosen_logps - rejected_logps
        gamma_logratios = self.simpo_gamma / self.beta
        logits = pi_logratios - gamma_logratios
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
        policy_rejected_logps: "torch.Tensor",
        reference_chosen_logps: Optional["torch.Tensor"],
        reference_rejected_logps: Optional["torch.Tensor"],
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""
        Computes loss for preference learning.
        """
        if not self.finetuning_args.use_ref_model:
            if self.loss_type == "orpo":
                losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
            elif self.loss_type == "simpo":
                losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
            else:
                raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))

            chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

        return losses, chosen_rewards, rejected_rewards

    def concatenated_forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""
        Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
        Otherwise computes the average log probabilities.
        """
        if self.finetuning_args.use_ref_model:
            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error

        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)

        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            all_logps = all_logps / valid_length

        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""
        Computes log probabilities of the reference model.
        """
        if not self.finetuning_args.use_ref_model:
            return None, None

        if self.ref_model is None:
            ref_model = model
            ref_context = get_ref_context(self.accelerator, model)
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)

        return reference_chosen_logps, reference_rejected_logps

    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: Dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
        r"""
        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
        """
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
        if self.loss_type == "orpo":
            metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
            metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()

        return losses.mean(), metrics
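simpo_loss above computes -logsigmoid(beta * ((chosen_logps - rejected_logps) - gamma/beta)) on length-averaged log probabilities. A standalone numeric check of that formula with hypothetical values (not repo code):

# Standalone numeric check of simpo_loss (hypothetical beta/gamma/logps).
import torch
import torch.nn.functional as F

beta, simpo_gamma = 2.0, 0.5
chosen_logps = torch.tensor([-0.8])    # length-averaged, as in concatenated_forward
rejected_logps = torch.tensor([-1.5])

logits = (chosen_logps - rejected_logps) - simpo_gamma / beta  # 0.7 - 0.25 = 0.45
loss = -F.logsigmoid(beta * logits)  # sigmoid argument: 2.0 * 0.45 = 0.9
print(loss)  # ~= tensor([0.3412])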

Some files were not shown because too many files have changed in this diff.