mirror of
https://github.com/AIDotNet/AntSK.git
synced 2026-02-18 06:20:11 +08:00
Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1c34ad5987 | ||
|
|
4e8039703e | ||
|
|
c2b97c7f82 | ||
|
|
1a621f5cbc | ||
|
|
836f898ffe | ||
|
|
0a9a737709 | ||
|
|
0df8c74ec2 | ||
|
|
5e3ff74eaa | ||
|
|
7c49ff0a6c | ||
|
|
63f5267bca | ||
|
|
e0c35aac06 | ||
|
|
27d52d3331 | ||
|
|
e7b2c6e193 | ||
|
|
7600397b79 | ||
|
|
874b8e5d7f | ||
|
|
3ac18086a1 | ||
|
|
16049c7413 | ||
|
|
7bb7a41bb3 | ||
|
|
6c37ed66b2 | ||
|
|
bc86f96159 | ||
|
|
8eb09fb783 | ||
|
|
eff5f69f0f | ||
|
|
e3f966d4f2 | ||
|
|
015f51b99c | ||
|
|
cd66b61014 | ||
|
|
f0bef7d2fa | ||
|
|
de051b047d | ||
|
|
ed6f5dada2 | ||
|
|
d2e3fde829 | ||
|
|
195551e9c1 | ||
|
|
855103c2a4 | ||
|
|
6150d543d3 | ||
|
|
d968d78982 | ||
|
|
0ec5d1f1cf | ||
|
|
31f44c1758 | ||
|
|
e5f63d605d | ||
|
|
7db62e3dc6 | ||
|
|
4408fa4345 | ||
|
|
c5e952b98e | ||
|
|
bedfeaf53d | ||
|
|
d605fd6685 | ||
|
|
e5e3f7cd8f | ||
|
|
657949694c | ||
|
|
10b6035f84 | ||
|
|
3f9fe27456 | ||
|
|
da3a0681e5 | ||
|
|
57b7948d86 | ||
|
|
40b8bd0439 | ||
|
|
6ed9cc9b70 | ||
|
|
e51bf35217 | ||
|
|
28f88438e7 | ||
|
|
85f4a330d5 | ||
|
|
21d7c719f1 | ||
|
|
4ef398bd57 | ||
|
|
dc70270362 | ||
|
|
97b7211cce | ||
|
|
3e762e13af | ||
|
|
e084317a46 | ||
|
|
531b4473e8 | ||
|
|
aefd0d2775 | ||
|
|
960468edf0 | ||
|
|
07ad1f58b5 | ||
|
|
095428be50 | ||
|
|
87fc8911fa | ||
|
|
58272e1ce8 | ||
|
|
700bbcb63f | ||
|
|
dde1d68876 | ||
|
|
71553a6153 | ||
|
|
d4f8de3e21 | ||
|
|
6cf5dea10d | ||
|
|
05379dfee6 | ||
|
|
5a6d49ff64 | ||
|
|
64ab940a26 | ||
|
|
55982ea36d | ||
|
|
21efcf2479 | ||
|
|
0dc7bfcadb | ||
|
|
22d99091e1 | ||
|
|
7558d3ffdc | ||
|
|
85ae41c44c | ||
|
|
91193850dd | ||
|
|
7cc04e3364 | ||
|
|
3da28090c6 | ||
|
|
1595ef2c0a | ||
|
|
83e3d81de7 | ||
|
|
18437ddda4 | ||
|
|
fd503171a1 | ||
|
|
7022139780 | ||
|
|
1e508e45af | ||
|
|
03d9ec2cad | ||
|
|
86fb48bab7 | ||
|
|
a4bc1e4a55 | ||
|
|
8681e15da5 | ||
|
|
ebc82f8b1b | ||
|
|
3bcd7bd7e1 | ||
|
|
b64d8669b1 | ||
|
|
0489044098 | ||
|
|
17e2062b72 | ||
|
|
4e4f5a698d | ||
|
|
b879d04bcd | ||
|
|
95f918f4c7 | ||
|
|
f0e1ad6088 | ||
|
|
61773af48d | ||
|
|
54cd04c3bf | ||
|
|
cd9f4ae11b | ||
|
|
3f9c748b41 | ||
|
|
d483005531 | ||
|
|
1d2db6a896 | ||
|
|
9a7a263055 | ||
|
|
6beb0b52c7 | ||
|
|
0ea167a204 | ||
|
|
6e6afa2a7c | ||
|
|
7a2a5d86bb |
@@ -1,8 +1,4 @@
|
||||
# 1. Define the Python image to use for getting pip
|
||||
FROM pytorch/pytorch AS python-base
|
||||
|
||||
# 2. Define the .NET SDK image to build your application
|
||||
FROM mcr.microsoft.com/dotnet/sdk:8.0 AS build
|
||||
FROM mcr.microsoft.com/dotnet/sdk:8.0 AS build
|
||||
WORKDIR /src
|
||||
COPY ["src/AntSK/AntSK.csproj", "AntSK/"]
|
||||
RUN dotnet restore "AntSK/AntSK.csproj"
|
||||
@@ -11,19 +7,11 @@ WORKDIR "/src/AntSK"
|
||||
RUN dotnet build "AntSK.csproj" -c Release -o /app/build
|
||||
RUN dotnet publish "AntSK.csproj" -c Release -o /app/publish
|
||||
|
||||
# 3. Define the final image that will contain both .NET runtime and Python
|
||||
FROM mcr.microsoft.com/dotnet/aspnet:8.0 AS final
|
||||
|
||||
# Copy the Python/pip installation from the official Python image
|
||||
COPY --from=python-base /usr/local /usr/local
|
||||
COPY --from=python-base /opt/conda/ /opt/conda/
|
||||
FROM registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk-base:v1.0.0 AS final
|
||||
WORKDIR /app
|
||||
COPY --from=build /app/publish .
|
||||
# Make sure the app and Python directories are in PATH
|
||||
ENV PATH="/app:/opt/conda/bin:/usr/local/bin:${PATH}"
|
||||
|
||||
ENV PATH="/app:/opt/conda/bin:/usr/local/bin:${PATH}"
|
||||
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
|
||||
RUN echo 'Asia/Shanghai' >/etc/timezone
|
||||
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
RUN apt update && apt install -y libpugixml-dev libtbb-dev
|
||||
ENTRYPOINT ["dotnet", "AntSK.dll"]
|
||||
|
||||
10
LICENSE
10
LICENSE
@@ -1,9 +1,17 @@
|
||||
Apache License
|
||||
AntSK License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
This project follows the Apache 2.0 agreement, in addition to the following additional terms
|
||||
1.This project can be used for commercial purposes, but it has the right to prohibit you from using it if it violates the following provisions
|
||||
2. Without authorization, you are not allowed to modify AntSK's logo and title information
|
||||
3. Without authorization, you are not allowed to modify the copyright information at the bottom of the page
|
||||
4. If you need authorization, you can contact WeChat: xuzeyu91 or Email:antskpro@qq.com
|
||||
|
||||
|
||||
Apache 2.0 License
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
|
||||
52
README.md
52
README.md
@@ -20,7 +20,7 @@
|
||||
|
||||
- **Online Search**: AntSK, real-time access to the latest information, ensuring users receive the most timely and relevant data.
|
||||
|
||||
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf types supported by **llama.cpp** and models offline running supported by **llamafactory**.
|
||||
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf types supported by **llama.cpp** and models offline running supported by **llamafactory** and **ollama**.
|
||||
|
||||
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
|
||||
|
||||
@@ -41,7 +41,9 @@ AntSK is suitable for various business scenarios, such as:
|
||||
### Online Demo
|
||||
[document](http://antsk.cn/)
|
||||
|
||||
[demo](https://antsk.ai-dotnet.com/)
|
||||
[demo](https://demo.antsk.cn/)
|
||||
and
|
||||
[demo1](https://antsk.ai-dotnet.com/)
|
||||
|
||||
```
|
||||
Default account: test
|
||||
@@ -84,7 +86,8 @@ version: '3.8'
|
||||
services:
|
||||
antsk:
|
||||
container_name: antsk
|
||||
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5ports:
|
||||
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.5.0
|
||||
ports:
|
||||
- 5000:5000
|
||||
networks:
|
||||
- antsk
|
||||
@@ -119,11 +122,6 @@ model/xxx.gguf
|
||||
"FileDir": {
|
||||
"DirectoryPath": "D:\\git\\AntBlazor\\model"
|
||||
},
|
||||
"LLamaSharp": {
|
||||
"RunType": "GPU",
|
||||
"ContextSize": 2048,
|
||||
"GpuLayerCount": 20
|
||||
},
|
||||
"Login": {
|
||||
"User": "admin",
|
||||
"Password": "xuzeyu"
|
||||
@@ -147,11 +145,8 @@ DBConnection.ConnectionStrings
|
||||
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint | APIKey
|
||||
KernelMemory.VectorDb
|
||||
|
||||
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
|
||||
LLamaSharp.RunType
|
||||
|
||||
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
|
||||
LLamaSharp.FileDirectory
|
||||
FileDir.DirectoryPath
|
||||
|
||||
//Default admin account password
|
||||
Login
|
||||
@@ -188,13 +183,6 @@ I'm using CodeFirst mode for the database, so as long as the database connection
|
||||
8. Many people ask about the difference between LLamaSharp and llamafactory. In fact, LLamaSharp is a .NET implementation of llama.cpp, but only supports local gguf models, while llamafactory supports a wider variety of models and uses Python implementation. The main difference lies here. Additionally, llamafactory has the ability to fine-tune models, which is an area we will focus on integrating in the future.
|
||||
```
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
[PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)
|
||||
|
||||
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
|
||||
|
||||
|
||||
## 💕 Contributors
|
||||
|
||||
This project exists thanks to all the people who contribute.
|
||||
@@ -204,15 +192,35 @@ This project exists thanks to all the people who contribute.
|
||||
</a>
|
||||
|
||||
## 🚨 Use Protocol
|
||||
This warehouse follows the [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) open source protocol.
|
||||
The Apache open source license allows the use of AntSK in commercial environments, provided that the license terms are followed. One of the main terms is to retain the copyright and license statements.
|
||||
|
||||
This warehouse follows the [AntSK License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) open source protocol.
|
||||
|
||||
This project follows the Apache 2.0 agreement, in addition to the following additional terms
|
||||
|
||||
1. This project can be used for commercial purposes, but it has the right to prohibit you from using it if it violates the following provisions
|
||||
|
||||
2. Without authorization, you are not allowed to modify AntSK's logo and title information
|
||||
|
||||
4. Without authorization, you are not allowed to modify the copyright information at the bottom of the page
|
||||
|
||||
6. If you need authorization, you can contact WeChat: **xuzeyu91**
|
||||
|
||||
If you plan to use AntSK in commercial projects, you need to ensure that you follow the following steps:
|
||||
1. Copyright statement containing Apache license. [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file).
|
||||
|
||||
1. Copyright statement containing AntSK license. [AntSK License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file).
|
||||
|
||||
2. If you modify the software source code, you need to clearly indicate these modifications in the source code.
|
||||
|
||||
3. Meet the above requirements
|
||||
|
||||
## 💕 Special thanks
|
||||
Helping enterprise AI application development, we recommend [AntBlazor](https://antblazor.com)
|
||||
|
||||
## ☎️Contact Me
|
||||
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
|
||||
|
||||
Additionally, you can also contact me via email: antskpro@qq.com
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
49
README.zh.md
49
README.zh.md
@@ -22,7 +22,7 @@
|
||||
|
||||
- **联网搜索**:AntSK,实时获取最新信息,确保用户接受到的资料总是最及时、最相关的。
|
||||
|
||||
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型,以及**llamafactory**所支持的模型离线运行
|
||||
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型,以及**llamafactory** 和 **ollama** 所支持的模型离线运行
|
||||
|
||||
- **国产信创**:AntSK支持国产模型,和国产数据库,可以在信创条件下运行
|
||||
|
||||
@@ -43,10 +43,11 @@ AntSK 适用于多种业务场景,例如:
|
||||
## ✏️功能示例
|
||||
### 在线演示
|
||||
|
||||
[文档地址](http://antsk.cn/)
|
||||
[体验地址1](https://demo.antsk.cn/)
|
||||
|
||||
[体验地址](https://antsk.ai-dotnet.com/)
|
||||
和
|
||||
|
||||
[体验地址2](https://antsk.ai-dotnet.com/)
|
||||
```
|
||||
默认账号:test
|
||||
|
||||
@@ -130,11 +131,6 @@ model/xxx.gguf
|
||||
"FileDir": {
|
||||
"DirectoryPath": "D:\\git\\AntBlazor\\model"
|
||||
},
|
||||
"LLamaSharp": {
|
||||
"RunType": "GPU",
|
||||
"ContextSize": 2048,
|
||||
"GpuLayerCount": 20
|
||||
},
|
||||
"Login": {
|
||||
"User": "admin",
|
||||
"Password": "xuzeyu"
|
||||
@@ -157,11 +153,8 @@ DBConnection.ConnectionStrings
|
||||
//Qdrant 和AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
|
||||
KernelMemory.VectorDb
|
||||
|
||||
//本地模型使用的运行方式 GUP CPU ,如果用在线API 这个随意使用一个即可
|
||||
LLamaSharp.RunType
|
||||
|
||||
//本地模型路径,用于在选择llama时可以快速选择目录下的模型,以及保存下载的模型
|
||||
LLamaSharp.FileDirectory
|
||||
FileDir.DirectoryPath
|
||||
|
||||
//默认管理员账号密码
|
||||
Login
|
||||
@@ -195,13 +188,6 @@ DB我使用的是CodeFirst模式,只要配置好数据库链接,表结构是
|
||||
7、点击保存,然后就可以开始聊天了
|
||||
8、很多人会问 LLamaSharp与llamafactory有什么区别?其实这两者LLamaSharp是llama.cpp的 dotnet实现,但是只支持本地gguf模型, 而llamafactory 支持的模型种类更多,但使用的是python的实现,其主要差异在这里,另外llamafactory具有模型微调的能力,这也是我们下一步需要重点集成的部分。
|
||||
```
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
[](https://github.com/AIDotNet/AntSK/pulls)
|
||||
|
||||
如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).
|
||||
|
||||
|
||||
## 💕 贡献者
|
||||
|
||||
@@ -213,18 +199,35 @@ DB我使用的是CodeFirst模式,只要配置好数据库链接,表结构是
|
||||
|
||||
## 🚨 使用协议
|
||||
|
||||
本仓库遵循 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 开源协议。
|
||||
Apache开源许可证允许在商业环境中使用AntSK,前提是需要遵守许可证的条款。主要条款之一是要保留版权声明和许可证声明。
|
||||
本仓库遵循 [AntSK License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 开源协议。
|
||||
|
||||
除以下附加条款外,该项目遵循Apache 2.0协议
|
||||
|
||||
1. 本项目可以用于商业目的,但如果违反以下规定,它有权禁止您使用
|
||||
|
||||
2. 未经授权,您不允许修改AntSK的徽标和标题信息
|
||||
|
||||
3. 未经授权,您不能修改页面底部的版权信息
|
||||
|
||||
4. 如果您需要授权,可以联系微信:xuzeyu91
|
||||
|
||||
如果您打算在商业项目中使用AntSK,您需要确保遵守以下步骤:
|
||||
|
||||
1、包含Apache许可证的版权声明。 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 。
|
||||
1. 包含AntSK许可证的版权声明。 [AntSK License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 。
|
||||
|
||||
2、如果您修改了软件源代码,您需要在源代码中明确标明这些修改。
|
||||
2. 如果您修改了软件源代码,您需要在源代码中明确标明这些修改。
|
||||
|
||||
3. 满足以上要求
|
||||
|
||||
## 💕 特别感谢
|
||||
助力企业级AI应用开发,推荐使用 [AntBlazor](https://antblazor.com)
|
||||
|
||||
|
||||
## ☎️联系我
|
||||
如有任何问题或建议,请通过以下方式关注我的公众号《许泽宇的技术分享》,发消息与我联系,我们也有AIDotnet交流群,可以发送进群等消息,然后我会拉你进交流群
|
||||
|
||||
另外您也可以通过邮箱与我联系:antskpro@qq.com
|
||||
|
||||

|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
@@ -3,9 +3,9 @@ version: '3.8'
|
||||
services:
|
||||
antsk:
|
||||
container_name: antsk
|
||||
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.6
|
||||
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.5.1
|
||||
# 如果需要pytorch环境需要使用下面这个镜像,镜像比较大
|
||||
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.6
|
||||
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.5.1
|
||||
ports:
|
||||
- 5000:5000
|
||||
networks:
|
||||
|
||||
@@ -32,9 +32,9 @@ services:
|
||||
- ./pg/data:/var/lib/postgresql/data
|
||||
antsk:
|
||||
container_name: antsk
|
||||
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.3.6
|
||||
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.5.1
|
||||
# 如果需要pytorch环境需要使用下面这个镜像,镜像比较大
|
||||
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.3.6
|
||||
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.5.1
|
||||
ports:
|
||||
- 5000:5000
|
||||
networks:
|
||||
|
||||
@@ -8,50 +8,44 @@
|
||||
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102,KMEXP00</NoWarn>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="AntDesign.Charts" Version="0.5.1" />
|
||||
<PackageReference Include="AntDesign.ProLayout" Version="0.19.0" />
|
||||
<PackageReference Include="AntDesign.Charts" Version="0.5.5" />
|
||||
<PackageReference Include="AntDesign.ProLayout" Version="0.20.3" />
|
||||
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
|
||||
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
|
||||
|
||||
<PackageReference Include="pythonnet" Version="3.0.3" />
|
||||
|
||||
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
|
||||
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.7.0" />
|
||||
|
||||
<PackageReference Include="AutoMapper" Version="8.1.0" />
|
||||
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
|
||||
<PackageReference Include="Markdig" Version="0.37.0" />
|
||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
||||
<PackageReference Include="SqlSugarCore" Version="5.1.4.158" />
|
||||
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftVersion)" />
|
||||
<PackageReference Include="SqlSugarCore" Version="5.1.4.166" />
|
||||
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
|
||||
<PackageReference Include="RestSharp" Version="111.2.0" />
|
||||
<PackageReference Include="NPOI" Version="2.7.0" />
|
||||
|
||||
<PackageReference Include="Microsoft.SemanticKernel" Version="1.14.1" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.14.1" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.14.1-alpha" />
|
||||
<PackageReference Include="RestSharp" Version="$(RestSharpVersion)" />
|
||||
<PackageReference Include="NPOI" Version="2.7.1" />
|
||||
|
||||
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SKVersion)" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="$(SKVersion)" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="$(SKVersion)-alpha" />
|
||||
<PackageReference Include="Microsoft.KernelMemory.Core" Version="$(KMVersion)" />
|
||||
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="$(KMVersion)" />
|
||||
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="$(KMVersion)" />
|
||||
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="$(KMVersion)" />
|
||||
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="$(KMVersion)" />
|
||||
|
||||
<PackageReference Include="LLamaSharp" Version="$(LLamaSharpVersion)" />
|
||||
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="$(LLamaSharpVersion)" />
|
||||
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="$(LLamaSharpVersion)" />
|
||||
<PackageReference Include="LLamaSharp.kernel-memory" Version="$(LLamaSharpVersion)" />
|
||||
<PackageReference Include="LLamaSharp.semantic-kernel" Version="$(LLamaSharpVersion)" />
|
||||
|
||||
<PackageReference Include="Serilog" Version="4.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.Console" Version="5.1.0-dev-00943" />
|
||||
<PackageReference Include="Serilog.Sinks.File" Version="5.0.1-dev-00972" />
|
||||
<PackageReference Include="Serilog" Version="4.0.1" />
|
||||
<PackageReference Include="Serilog.Sinks.Console" Version="6.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.File" Version="6.0.0" />
|
||||
<PackageReference Include="Serilog.Extensions.Logging" Version="8.0.1-dev-10391" />
|
||||
<PackageReference Include="Serilog.Settings.Configuration" Version="8.0.1" />
|
||||
<PackageReference Include="Serilog.Settings.Configuration" Version="8.0.2" />
|
||||
<PackageReference Include="Serilog.Sinks.Seq" Version="8.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.OpenTelemetry" Version="3.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.OpenTelemetry" Version="4.0.0" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />
|
||||
<ProjectReference Include="..\AntSk.LLM\AntSK.LLM.csproj" />
|
||||
<ProjectReference Include="..\AntSK.LLM\AntSK.LLM.csproj" />
|
||||
<ProjectReference Include="..\AntSK.OCR\AntSK.OCR.csproj" />
|
||||
<ProjectReference Include="..\MiddleWare\AntSK.BackgroundTask\AntSK.BackgroundTask.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
@@ -188,11 +188,6 @@
|
||||
<member name="M:AntSK.Domain.Domain.Other.KMExcelHandler.InvokeAsync(Microsoft.KernelMemory.Pipeline.DataPipeline,System.Threading.CancellationToken)">
|
||||
<inheritdoc />
|
||||
</member>
|
||||
<member name="F:AntSK.Domain.Domain.Other.LLamaConfig.dicLLamaWeights">
|
||||
<summary>
|
||||
避免模型重复加载,本地缓存
|
||||
</summary>
|
||||
</member>
|
||||
<member name="P:AntSK.Domain.Domain.Other.QAHandler.StepName">
|
||||
<inheritdoc />
|
||||
</member>
|
||||
@@ -924,6 +919,20 @@
|
||||
<param name="value"></param>
|
||||
<returns></returns>
|
||||
</member>
|
||||
<member name="M:AntSK.Domain.Utils.ConvertUtils.Unescape(System.String)">
|
||||
<summary>
|
||||
\uxxxx转中文,保留换行符号
|
||||
</summary>
|
||||
<param name="unicodeString"></param>
|
||||
<returns></returns>
|
||||
</member>
|
||||
<member name="M:AntSK.Domain.Utils.ConvertUtils.IsStream(System.String)">
|
||||
<summary>
|
||||
是否为流式请求
|
||||
</summary>
|
||||
<param name="value"></param>
|
||||
<returns></returns>
|
||||
</member>
|
||||
<member name="M:AntSK.Domain.Utils.RepoFiles.SamplePluginsPath">
|
||||
<summary>
|
||||
Scan the local folders from the repo, looking for "samples/plugins" folder.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
namespace AntSK.Domain.Common
|
||||
{
|
||||
[AttributeUsage(AttributeTargets.Method)]
|
||||
public class AntSkFunctionAttribute : Attribute
|
||||
public class AntSKFunctionAttribute : Attribute
|
||||
{
|
||||
// 自定义的ActionAttribute
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
using LLamaSharp.KernelMemory;
|
||||
using Microsoft.KernelMemory.AI;
|
||||
using Microsoft.KernelMemory.AI;
|
||||
using Microsoft.KernelMemory;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
using LLama.Common;
|
||||
using LLama;
|
||||
using LLamaSharp.KernelMemory;
|
||||
using Microsoft.KernelMemory.AI;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.KernelMemory.AI;
|
||||
using AntSK.Domain.Domain.Other.Bge;
|
||||
|
||||
namespace AntSK.Domain.Common.Embedding
|
||||
@@ -52,5 +44,10 @@ namespace AntSK.Domain.Common.Embedding
|
||||
{
|
||||
return BgeEmbeddingConfig.TokenCount(text);
|
||||
}
|
||||
|
||||
public IReadOnlyList<string> GetTokens(string text)
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,9 @@ namespace AntSK.Domain.Domain.Interface
|
||||
{
|
||||
public event LogMessageHandler LogMessageReceived;
|
||||
Task PipInstall();
|
||||
Task StartLLamaFactory(string modelName, string templateName);
|
||||
|
||||
Task PipInstallName(string name);
|
||||
Task StartLLamaFactory(string modelName);
|
||||
|
||||
void KillProcess();
|
||||
|
||||
|
||||
15
src/AntSK.Domain/Domain/Interface/IOllamaService.cs
Normal file
15
src/AntSK.Domain/Domain/Interface/IOllamaService.cs
Normal file
@@ -0,0 +1,15 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using static AntSK.Domain.Domain.Service.OllamaService;
|
||||
|
||||
namespace AntSK.Domain.Domain.Interface
|
||||
{
|
||||
public interface IOllamaService
|
||||
{
|
||||
public event LogMessageHandler LogMessageReceived;
|
||||
Task StartOllama(string modelName);
|
||||
}
|
||||
}
|
||||
@@ -13,9 +13,6 @@ namespace AntSK.Domain.Domain.Model.Enum
|
||||
[Display(Name = "Azure Open AI")]
|
||||
AzureOpenAI = 2,
|
||||
|
||||
[Display(Name = "LLama本地模型")]
|
||||
LLamaSharp = 3,
|
||||
|
||||
[Display(Name = "星火大模型")]
|
||||
SparkDesk = 4,
|
||||
|
||||
@@ -28,8 +25,11 @@ namespace AntSK.Domain.Domain.Model.Enum
|
||||
BgeEmbedding = 7,
|
||||
[Display(Name = "Bge Rerank")]
|
||||
BgeRerank = 8,
|
||||
[Display(Name = "StableDiffusion")]
|
||||
StableDiffusion = 9,
|
||||
|
||||
[Display(Name = "Ollama")]
|
||||
Ollama = 10,
|
||||
[Display(Name = "OllamaEmbedding")]
|
||||
OllamaEmbedding = 11,
|
||||
[Display(Name = "模拟输出")]
|
||||
Mock = 100,
|
||||
|
||||
@@ -42,7 +42,6 @@ namespace AntSK.Domain.Domain.Model.Enum
|
||||
{
|
||||
Chat = 1,
|
||||
Embedding = 2,
|
||||
Image=3,
|
||||
Rerank=4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using static Python.Runtime.Py;
|
||||
using AntSK.Domain.Utils;
|
||||
|
||||
namespace AntSK.Domain.Domain.Other.Bge
|
||||
{
|
||||
@@ -26,11 +27,7 @@ namespace AntSK.Domain.Domain.Other.Bge
|
||||
{
|
||||
if (model == null)
|
||||
{
|
||||
if (string.IsNullOrEmpty(Runtime.PythonDLL))
|
||||
{
|
||||
Runtime.PythonDLL = pythondllPath;
|
||||
}
|
||||
PythonEngine.Initialize();
|
||||
PyRunTime.InitRunTime(pythondllPath);
|
||||
try
|
||||
{
|
||||
using (GIL())// 初始化Python环境的Global Interpreter Lock)
|
||||
@@ -39,7 +36,7 @@ namespace AntSK.Domain.Domain.Other.Bge
|
||||
dynamic flagEmbedding = Py.Import("FlagEmbedding");
|
||||
|
||||
dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master");
|
||||
dynamic flagReranker = flagEmbedding.FlagReranker(model_dir, use_fp16: true);
|
||||
dynamic flagReranker = flagEmbedding.FlagReranker(model_dir, use_fp16: false);
|
||||
model = flagReranker;
|
||||
return model;
|
||||
}
|
||||
@@ -69,7 +66,7 @@ namespace AntSK.Domain.Domain.Other.Bge
|
||||
pyList.Append(item.ToPython()); // 将C# string转换为Python对象并添加到PyList中
|
||||
}
|
||||
PyObject result = model.compute_score(pyList, normalize: true);
|
||||
return result.As<double>();
|
||||
return result.ConvertToString().Trim('[').Trim(']').ConvertToDouble();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
using Microsoft.KernelMemory.AI.OpenAI;
|
||||
using Microsoft.KernelMemory.AI.OpenAI.GPT3;
|
||||
using Python.Runtime;
|
||||
using Serilog;
|
||||
using System;
|
||||
@@ -28,13 +27,7 @@ namespace AntSK.Domain.Domain.Other.Bge
|
||||
{
|
||||
if (model == null)
|
||||
{
|
||||
//Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
|
||||
if (string.IsNullOrEmpty(Runtime.PythonDLL))
|
||||
{
|
||||
Runtime.PythonDLL = pythondllPath;
|
||||
}
|
||||
PythonEngine.Initialize();
|
||||
PythonEngine.BeginAllowThreads();
|
||||
PyRunTime.InitRunTime(pythondllPath);
|
||||
try
|
||||
{
|
||||
using (GIL())// 初始化Python环境的Global Interpreter Lock)
|
||||
|
||||
28
src/AntSK.Domain/Domain/Other/Bge/PyRunTime.cs
Normal file
28
src/AntSK.Domain/Domain/Other/Bge/PyRunTime.cs
Normal file
@@ -0,0 +1,28 @@
|
||||
using Python.Runtime;
|
||||
|
||||
namespace AntSK.Domain.Domain.Other.Bge
|
||||
{
|
||||
public static class PyRunTime
|
||||
{
|
||||
static object lockobj = new object();
|
||||
|
||||
static bool isInit = false;
|
||||
|
||||
public static void InitRunTime(string pythonPath)
|
||||
{
|
||||
lock (lockobj)
|
||||
{
|
||||
if (!isInit)
|
||||
{
|
||||
if (string.IsNullOrEmpty(Runtime.PythonDLL))
|
||||
{
|
||||
Runtime.PythonDLL = pythonPath;
|
||||
}
|
||||
PythonEngine.Initialize();
|
||||
PythonEngine.BeginAllowThreads();
|
||||
isInit = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
using AntSK.Domain.Options;
|
||||
using LLama;
|
||||
using LLama.Common;
|
||||
using LLamaSharp.KernelMemory;
|
||||
|
||||
namespace AntSK.Domain.Domain.Other
|
||||
{
|
||||
public static class LLamaConfig
|
||||
{
|
||||
static object lockobj = new object();
|
||||
/// <summary>
|
||||
/// 避免模型重复加载,本地缓存
|
||||
/// </summary>
|
||||
static Dictionary<string, (LLamaWeights, ModelParams)> dicLLamaWeights = new Dictionary<string, (LLamaWeights, ModelParams)>();
|
||||
public static (LLamaWeights, ModelParams) GetLLamaConfig(string modelPath, LLamaSharpConfig config = null)
|
||||
{
|
||||
lock (lockobj)
|
||||
{
|
||||
if (dicLLamaWeights.ContainsKey(modelPath))
|
||||
{
|
||||
return dicLLamaWeights.GetValueOrDefault(modelPath);
|
||||
}
|
||||
else
|
||||
{
|
||||
InferenceParams infParams = new() { AntiPrompts = ["\n\n"] };
|
||||
LLamaSharpConfig lsConfig = new(modelPath) { DefaultInferenceParams = infParams };
|
||||
if (config != null)
|
||||
{
|
||||
lsConfig = config;
|
||||
}
|
||||
var parameters = new ModelParams(lsConfig.ModelPath)
|
||||
{
|
||||
ContextSize = LLamaSharpOption.ContextSize ?? 2048,
|
||||
Seed = lsConfig?.Seed ?? 0,
|
||||
GpuLayerCount = LLamaSharpOption.GpuLayerCount ?? 20,
|
||||
Embeddings = true
|
||||
};
|
||||
var weights = LLamaWeights.LoadFromFile(parameters);
|
||||
dicLLamaWeights.Add(modelPath, (weights, parameters));
|
||||
return (weights, parameters);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,6 @@ using AntSK.Domain.Utils;
|
||||
using AntSK.OCR;
|
||||
using DocumentFormat.OpenXml.Drawing.Diagrams;
|
||||
using LLama;
|
||||
using LLamaSharp.KernelMemory;
|
||||
using Markdig;
|
||||
using Microsoft.AspNetCore.Components;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
@@ -171,12 +170,6 @@ namespace AntSK.Domain.Domain.Service
|
||||
APIType = AzureOpenAIConfig.APITypes.EmbeddingGeneration,
|
||||
});
|
||||
break;
|
||||
|
||||
case Model.Enum.AIType.LLamaSharp:
|
||||
var (weights, parameters) = LLamaConfig.GetLLamaConfig(embedModel.ModelName);
|
||||
var embedder = new LLamaEmbedder(weights, parameters);
|
||||
memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
|
||||
break;
|
||||
case Model.Enum.AIType.BgeEmbedding:
|
||||
string pyDll = embedModel.EndPoint;
|
||||
string bgeEmbeddingModelName = embedModel.ModelName;
|
||||
@@ -185,6 +178,13 @@ namespace AntSK.Domain.Domain.Service
|
||||
case Model.Enum.AIType.DashScope:
|
||||
memory.WithDashScopeDefaults(embedModel.ModelKey);
|
||||
break;
|
||||
case Model.Enum.AIType.OllamaEmbedding:
|
||||
memory.WithOpenAITextEmbeddingGeneration(new OpenAIConfig()
|
||||
{
|
||||
APIKey = "NotNull",
|
||||
EmbeddingModel = embedModel.ModelName
|
||||
}, null, false, embeddingHttpClient);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,15 +211,15 @@ namespace AntSK.Domain.Domain.Service
|
||||
APIType = AzureOpenAIConfig.APITypes.TextCompletion,
|
||||
});
|
||||
break;
|
||||
|
||||
case Model.Enum.AIType.LLamaSharp:
|
||||
var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
|
||||
var context = weights.CreateContext(parameters);
|
||||
var executor = new StatelessExecutor(weights, parameters);
|
||||
memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
|
||||
break;
|
||||
case Model.Enum.AIType.LLamaFactory:
|
||||
|
||||
memory.WithOpenAITextGeneration(new OpenAIConfig()
|
||||
{
|
||||
APIKey = "NotNull",
|
||||
TextModel = chatModel.ModelName
|
||||
}, null, chatHttpClient);
|
||||
break;
|
||||
case Model.Enum.AIType.Ollama:
|
||||
memory.WithOpenAITextGeneration(new OpenAIConfig()
|
||||
{
|
||||
APIKey = "NotNull",
|
||||
|
||||
@@ -4,25 +4,15 @@ using AntSK.Domain.Domain.Interface;
|
||||
using AntSK.Domain.Domain.Other;
|
||||
using AntSK.Domain.Repositories;
|
||||
using AntSK.Domain.Utils;
|
||||
using LLama;
|
||||
using LLamaSharp.SemanticKernel.TextCompletion;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.SemanticKernel;
|
||||
using Microsoft.SemanticKernel.Plugins.Core;
|
||||
using Microsoft.SemanticKernel.TextGeneration;
|
||||
using RestSharp;
|
||||
using System;
|
||||
using ServiceLifetime = AntSK.Domain.Common.DependencyInjection.ServiceLifetime;
|
||||
using AntSK.LLM.Mock;
|
||||
using AntSK.Domain.Domain.Model.Enum;
|
||||
using AntSK.LLM.LLamaFactory;
|
||||
using System.Reflection;
|
||||
using DocumentFormat.OpenXml.Drawing;
|
||||
using Microsoft.KernelMemory;
|
||||
using OpenCvSharp.ML;
|
||||
using LLamaSharp.SemanticKernel.ChatCompletion;
|
||||
using Microsoft.SemanticKernel.ChatCompletion;
|
||||
using Amazon.Runtime.Internal.Util;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace AntSK.Domain.Domain.Service
|
||||
@@ -108,15 +98,30 @@ namespace AntSK.Domain.Domain.Service
|
||||
);
|
||||
break;
|
||||
|
||||
case Model.Enum.AIType.LLamaSharp:
|
||||
var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
|
||||
var ex = new StatelessExecutor(weights, parameters);
|
||||
builder.Services.AddKeyedSingleton<ITextGenerationService>("local-llama", new LLamaSharpTextCompletion(ex));
|
||||
builder.Services.AddKeyedSingleton<IChatCompletionService>("local-llama-chat", new LLamaSharpChatCompletion(ex));
|
||||
break;
|
||||
|
||||
case Model.Enum.AIType.SparkDesk:
|
||||
var options = new SparkDeskOptions { AppId = chatModel.EndPoint, ApiSecret = chatModel.ModelKey, ApiKey = chatModel.ModelName, ModelVersion = Sdcb.SparkDesk.ModelVersion.V3_5 };
|
||||
|
||||
var settings = chatModel.ModelKey.Split("|");
|
||||
|
||||
Sdcb.SparkDesk.ModelVersion modelVersion = Sdcb.SparkDesk.ModelVersion.V3_5;
|
||||
|
||||
switch (chatModel.ModelName)
|
||||
{
|
||||
case "V3_5":
|
||||
modelVersion = Sdcb.SparkDesk.ModelVersion.V3_5;
|
||||
break;
|
||||
case "V3":
|
||||
modelVersion = Sdcb.SparkDesk.ModelVersion.V3;
|
||||
break;
|
||||
case "V2":
|
||||
modelVersion = Sdcb.SparkDesk.ModelVersion.V2;
|
||||
break;
|
||||
case "V1_5":
|
||||
modelVersion = Sdcb.SparkDesk.ModelVersion.V1_5;
|
||||
break;
|
||||
}
|
||||
|
||||
SparkDeskOptions options = new SparkDeskOptions { AppId = settings[0], ApiSecret = settings[1], ApiKey = settings[2], ModelVersion = modelVersion };
|
||||
|
||||
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, chatModel.Id));
|
||||
builder.Services.AddKeyedSingleton<IChatCompletionService>("spark-desk-chat", new SparkDeskChatCompletion(options, chatModel.Id));
|
||||
break;
|
||||
@@ -132,7 +137,14 @@ namespace AntSK.Domain.Domain.Service
|
||||
case Model.Enum.AIType.LLamaFactory:
|
||||
builder.AddOpenAIChatCompletion(
|
||||
modelId: chatModel.ModelName,
|
||||
apiKey: "123",
|
||||
apiKey: "NotNull",
|
||||
httpClient: chatHttpClient
|
||||
);
|
||||
break;
|
||||
case AIType.Ollama:
|
||||
builder.AddOpenAIChatCompletion(
|
||||
modelId: chatModel.ModelName,
|
||||
apiKey: "NotNull",
|
||||
httpClient: chatHttpClient
|
||||
);
|
||||
break;
|
||||
@@ -147,7 +159,7 @@ namespace AntSK.Domain.Domain.Service
|
||||
public void ImportFunctionsByApp(Apps app, Kernel _kernel)
|
||||
{
|
||||
//插件不能重复注册,否则会异常
|
||||
if (_kernel.Plugins.Any(p => p.Name == "AntSkFunctions"))
|
||||
if (_kernel.Plugins.Any(p => p.Name == "AntSKFunctions"))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -158,7 +170,7 @@ namespace AntSK.Domain.Domain.Service
|
||||
//本地函数插件
|
||||
ImportNativeFunction(app, functions);
|
||||
|
||||
_kernel.ImportPluginFromFunctions("AntSkFunctions", functions);
|
||||
_kernel.ImportPluginFromFunctions("AntSKFunctions", functions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -7,6 +7,7 @@ using AntSK.LLamaFactory.Model;
|
||||
using Microsoft.AspNetCore.Mvc.ModelBinding;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Newtonsoft.Json;
|
||||
using Serilog;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
@@ -74,8 +75,45 @@ namespace AntSK.Domain.Domain.Service
|
||||
}, TaskCreationOptions.LongRunning);
|
||||
await cmdTask;
|
||||
}
|
||||
public async Task PipInstallName(string name)
|
||||
{
|
||||
|
||||
public async Task StartLLamaFactory(string modelName, string templateName)
|
||||
var cmdTask = Task.Factory.StartNew(() =>
|
||||
{
|
||||
|
||||
var isProcessComplete = false;
|
||||
|
||||
process = new Process
|
||||
{
|
||||
StartInfo = new ProcessStartInfo
|
||||
{
|
||||
FileName = "pip",
|
||||
Arguments = $"install {name} -i https://pypi.tuna.tsinghua.edu.cn/simple",
|
||||
UseShellExecute = false,
|
||||
RedirectStandardOutput = true,
|
||||
RedirectStandardError = true,
|
||||
WorkingDirectory = AppDomain.CurrentDomain.BaseDirectory,
|
||||
}
|
||||
};
|
||||
process.OutputDataReceived += (sender, eventArgs) =>
|
||||
{
|
||||
Log.Information($"{eventArgs.Data}");
|
||||
OnLogMessageReceived(eventArgs.Data);
|
||||
};
|
||||
process.ErrorDataReceived += (sender, eventArgs) =>
|
||||
{
|
||||
Log.Information($"{eventArgs.Data}");
|
||||
OnLogMessageReceived(eventArgs.Data);
|
||||
};
|
||||
process.Start();
|
||||
process.BeginOutputReadLine();
|
||||
process.BeginErrorReadLine();
|
||||
process.WaitForExit();
|
||||
OnLogMessageReceived("--------------------完成--------------------");
|
||||
}, TaskCreationOptions.LongRunning);
|
||||
await cmdTask;
|
||||
}
|
||||
public async Task StartLLamaFactory(string modelName)
|
||||
{
|
||||
var cmdTask = Task.Factory.StartNew(() =>
|
||||
{
|
||||
@@ -87,7 +125,7 @@ namespace AntSK.Domain.Domain.Service
|
||||
StartInfo = new ProcessStartInfo
|
||||
{
|
||||
FileName = "python",
|
||||
Arguments = "api_demo.py --model_name_or_path " + modelName + " --template " + templateName + " ",
|
||||
Arguments = "api_antsk.py --model_name_or_path " + modelName + " --template default ",
|
||||
UseShellExecute = false,
|
||||
RedirectStandardOutput = true,
|
||||
RedirectStandardError=true,
|
||||
|
||||
74
src/AntSK.Domain/Domain/Service/OllamaService.cs
Normal file
74
src/AntSK.Domain/Domain/Service/OllamaService.cs
Normal file
@@ -0,0 +1,74 @@
|
||||
using AntSK.Domain.Common.DependencyInjection;
|
||||
using AntSK.Domain.Domain.Interface;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Serilog;
|
||||
using AntSK.Domain.Utils;
|
||||
|
||||
namespace AntSK.Domain.Domain.Service
|
||||
{
|
||||
[ServiceDescription(typeof(IOllamaService), ServiceLifetime.Singleton)]
|
||||
public class OllamaService : IOllamaService
|
||||
{
|
||||
private Process process;
|
||||
public delegate Task LogMessageHandler(string message);
|
||||
public event LogMessageHandler LogMessageReceived;
|
||||
protected virtual async Task OnLogMessageReceived(string message)
|
||||
{
|
||||
LogMessageReceived?.Invoke(message);
|
||||
}
|
||||
|
||||
public async Task StartOllama(string modelName)
|
||||
{
|
||||
Console.OutputEncoding = Encoding.UTF8;
|
||||
var cmdTask = Task.Factory.StartNew(() =>
|
||||
{
|
||||
|
||||
var isProcessComplete = false;
|
||||
|
||||
process = new Process
|
||||
{
|
||||
StartInfo = new ProcessStartInfo
|
||||
{
|
||||
FileName = "ollama",
|
||||
Arguments = "run " + modelName,
|
||||
UseShellExecute = false,
|
||||
RedirectStandardOutput = true,
|
||||
RedirectStandardError = true,
|
||||
}
|
||||
};
|
||||
process.OutputDataReceived += (sender, eventArgs) =>
|
||||
{
|
||||
Log.Information($"{eventArgs.Data.ConvertToString()}");
|
||||
if (!eventArgs.Data.ConvertToString().Contains("The handle is invalid"))
|
||||
{
|
||||
OnLogMessageReceived(eventArgs.Data.ConvertToString());
|
||||
}
|
||||
};
|
||||
process.ErrorDataReceived += (sender, eventArgs) =>
|
||||
{
|
||||
Log.Error($"{eventArgs.Data.ConvertToString()}");
|
||||
if (!eventArgs.Data.ConvertToString().Contains("The handle is invalid"))
|
||||
{
|
||||
OnLogMessageReceived(eventArgs.Data.ConvertToString());
|
||||
}
|
||||
};
|
||||
process.StartInfo.StandardOutputEncoding = Encoding.UTF8;
|
||||
process.StartInfo.StandardErrorEncoding = Encoding.UTF8;
|
||||
|
||||
process.Start();
|
||||
process.BeginOutputReadLine();
|
||||
process.BeginErrorReadLine();
|
||||
process.WaitForExit();
|
||||
|
||||
OnLogMessageReceived("--------------------完成--------------------");
|
||||
}, TaskCreationOptions.LongRunning);
|
||||
await cmdTask;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
namespace AntSK.Domain.Options
|
||||
{
|
||||
public class LLamaSharpOption
|
||||
{
|
||||
public static string RunType { get; set; }
|
||||
public static uint? ContextSize { get; set; }
|
||||
public static int? GpuLayerCount { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ namespace AntSK.Domain.Repositories
|
||||
/// 图标
|
||||
/// </summary>
|
||||
[Required]
|
||||
public string Icon { get; set; }
|
||||
public string Icon { get; set; } = "windows";
|
||||
|
||||
/// <summary>
|
||||
/// 类型
|
||||
@@ -75,6 +75,7 @@ namespace AntSK.Domain.Repositories
|
||||
/// <summary>
|
||||
/// 知识库ID列表
|
||||
/// </summary>
|
||||
[SugarColumn(ColumnDataType = "varchar(1000)")]
|
||||
public string? KmsIdList { get; set; }
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace AntSK.Domain.Repositories
|
||||
/// 图标
|
||||
/// </summary>
|
||||
[Required]
|
||||
public string Icon { get; set; }
|
||||
public string Icon { get; set; } = "question-circle";
|
||||
/// <summary>
|
||||
/// 名称
|
||||
/// </summary>
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
using System.Security.Cryptography;
|
||||
using Newtonsoft.Json;
|
||||
using Serilog;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text.RegularExpressions;
|
||||
using System.Web;
|
||||
|
||||
namespace AntSK.Domain.Utils
|
||||
@@ -263,6 +266,50 @@ namespace AntSK.Domain.Utils
|
||||
return s.Equals(value, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// \uxxxx转中文,保留换行符号
|
||||
/// </summary>
|
||||
/// <param name="unicodeString"></param>
|
||||
/// <returns></returns>
|
||||
public static string Unescape(this string value)
|
||||
{
|
||||
if (value.IsNull())
|
||||
{
|
||||
return "";
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
Formatting formatting = Formatting.None;
|
||||
|
||||
object jsonObj = JsonConvert.DeserializeObject(value);
|
||||
string unescapeValue = JsonConvert.SerializeObject(jsonObj, formatting);
|
||||
return unescapeValue;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Log.Error(ex.ToString());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// 是否为流式请求
|
||||
/// </summary>
|
||||
/// <param name="value"></param>
|
||||
/// <returns></returns>
|
||||
public static bool IsStream(this string value)
|
||||
{
|
||||
// 正则表达式忽略空格的情况
|
||||
string pattern = @"\s*""stream""\s*:\s*true\s*";
|
||||
|
||||
// 使用正则表达式匹配
|
||||
bool contains = Regex.IsMatch(value, pattern);
|
||||
return contains;
|
||||
}
|
||||
|
||||
public static string AntSKCalculateSHA256(this BinaryData binaryData)
|
||||
{
|
||||
byte[] byteArray = SHA256.HashData(binaryData.ToMemory().Span);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
using Serilog;
|
||||
using Serilog;
|
||||
using System.Text;
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace AntSK.Domain.Utils
|
||||
@@ -17,12 +17,19 @@ namespace AntSK.Domain.Utils
|
||||
UriBuilder uriBuilder;
|
||||
Regex regex = new Regex(@"(https?)://([^/:]+)(:\d+)?/(.*)");
|
||||
Match match = regex.Match(_endPoint);
|
||||
if (Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") == "Development" && request.Content != null)
|
||||
string guid = Guid.NewGuid().ToString();
|
||||
var mediaType = request.Content.Headers.ContentType.MediaType;
|
||||
string requestBody = (await request.Content.ReadAsStringAsync()).Unescape();
|
||||
var uncaseBody = new StringContent(requestBody, Encoding.UTF8, mediaType);
|
||||
request.Content = uncaseBody;
|
||||
|
||||
if (Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT").ConvertToString() != "Production")
|
||||
{
|
||||
string requestBody = await request.Content.ReadAsStringAsync();
|
||||
//生产环境根据环境变量可去关闭日志
|
||||
//便于调试查看请求prompt
|
||||
Log.Information(requestBody);
|
||||
Log.Information("{Message}", $"【模型服务接口调用-{guid},host:{_endPoint}】:{Environment.NewLine}{requestBody}");
|
||||
}
|
||||
|
||||
if (match.Success)
|
||||
{
|
||||
string xieyi = match.Groups[1].Value;
|
||||
@@ -72,7 +79,11 @@ namespace AntSK.Domain.Utils
|
||||
|
||||
// 接着,调用基类的 SendAsync 方法将你的修改后的请求发出去
|
||||
HttpResponseMessage response = await base.SendAsync(request, cancellationToken);
|
||||
|
||||
if (Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT").ConvertToString() != "Production")
|
||||
{
|
||||
string responseContent = requestBody.IsStream() ? response.Content.ReadAsStringAsync().Result : response.Content.ReadAsStringAsync().Result.Unescape();
|
||||
Log.Information("{Message}", $"【模型服务接口返回-{guid},host:{_endPoint}】:{Environment.NewLine}{responseContent}");
|
||||
}
|
||||
return response;
|
||||
}
|
||||
}
|
||||
@@ -84,7 +95,7 @@ namespace AntSK.Domain.Utils
|
||||
{
|
||||
var handler = new OpenAIHttpClientHandler(endPoint.ConvertToString());
|
||||
var httpClient = new HttpClient(handler);
|
||||
httpClient.Timeout = TimeSpan.FromMinutes(5);
|
||||
httpClient.Timeout = TimeSpan.FromMinutes(10);
|
||||
return httpClient;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,16 +7,22 @@
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
||||
<PackageReference Include="RestSharp" Version="110.2.0" />
|
||||
<PackageReference Include="Cnblogs.KernelMemory.AI.DashScope" Version="0.1.0" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SKVersion)" />
|
||||
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftVersion)" />
|
||||
<PackageReference Include="RestSharp" Version="$(RestSharpVersion)" />
|
||||
<PackageReference Include="Cnblogs.KernelMemory.AI.DashScope" Version="0.3.0" />
|
||||
<PackageReference Include="Cnblogs.SemanticKernel.Connectors.DashScope" Version="0.3.2" />
|
||||
<PackageReference Include="Microsoft.SemanticKernel" Version="1.13.0" />
|
||||
<PackageReference Include="Sdcb.SparkDesk" Version="3.0.0" />
|
||||
<PackageReference Include="System.Drawing.Common" Version="8.0.0" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<None Update="OllamaEmbeddingModelList.txt">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
</None>
|
||||
<None Update="OllamaModelList.txt">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
</None>
|
||||
<None Update="StableDiffusion\Backend\CPU\stable-diffusion.dll">
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
</None>
|
||||
@@ -42,7 +48,7 @@
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
</None>
|
||||
<None Update="StableDiffusionModelList.txt">
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
|
||||
105
src/AntSK.LLM/OllamaModelList.txt
Normal file
105
src/AntSK.LLM/OllamaModelList.txt
Normal file
@@ -0,0 +1,105 @@
|
||||
gemma2
|
||||
gemma2:27b
|
||||
gemma:2b
|
||||
gemma:7b
|
||||
llama3
|
||||
llama3:70b
|
||||
yi:6b
|
||||
yi:9B
|
||||
yi:34B
|
||||
qwen2:0.5b
|
||||
qwen2:1.5b
|
||||
qwen2:7b
|
||||
qwen2:72b
|
||||
qwen:0.5b
|
||||
qwen:1.8b
|
||||
qwen:4b
|
||||
qwen:7b
|
||||
qwen:14b
|
||||
qwen:32b
|
||||
qwen:72b
|
||||
qwen:110b
|
||||
deepseek-coder:1.3b
|
||||
deepseek-coder:6.7b
|
||||
deepseek-coder:33b
|
||||
deepseek-coder-v2:16b
|
||||
deepseek-coder-v2:236b
|
||||
phi:2.7b
|
||||
phi3:mini
|
||||
phi3:medium
|
||||
phi3:medium-128k
|
||||
aya:8b
|
||||
aya:35b
|
||||
mistral:7b
|
||||
mixtral:8x22b
|
||||
mixtral:8x7b
|
||||
codegemma:2b
|
||||
codegemma:7b
|
||||
command-r:35b
|
||||
llava
|
||||
gemma:2b
|
||||
gemma:7b
|
||||
llama2:7b
|
||||
llama2:13b
|
||||
llama2:70b
|
||||
llama2-chinese:7b
|
||||
llama2-chinese:13b
|
||||
llama3.1:8b
|
||||
llama3.1:70b
|
||||
llama3.1:405b
|
||||
codellama:7b
|
||||
codellama:13b
|
||||
codellama:34b
|
||||
codellama:70b
|
||||
dolphin-mistral:7b
|
||||
dolphin-mixtral:8x22b
|
||||
dolphin-mixtral:8x7b
|
||||
llama2-uncensored:7b
|
||||
llama2-uncensored:70b
|
||||
tinyllama:1.1b
|
||||
openchat:7b
|
||||
orca-mini:3b
|
||||
orca-mini:7b
|
||||
orca-mini:13b
|
||||
orca-mini:70b
|
||||
mistral-openorca:7b
|
||||
dolphin-llama3:8b
|
||||
dolphin-llama3:70b
|
||||
starcoder:1b
|
||||
starcoder:3b
|
||||
starcoder:7b
|
||||
starcoder:15b
|
||||
starcoder2:3b
|
||||
starcoder2:7b
|
||||
starcoder2:15b
|
||||
zephyr:7b
|
||||
zephyr:141b
|
||||
nous-hermes2:10.7b
|
||||
nous-hermes2:34b
|
||||
vicuna:7b
|
||||
vicuna:13b
|
||||
vicuna:33b
|
||||
wizard-vicuna-uncensored:7b
|
||||
wizard-vicuna-uncensored:13b
|
||||
wizard-vicuna-uncensored:30b
|
||||
wizardlm2:7b
|
||||
codestral:22b
|
||||
tinydolphin:1.1b
|
||||
openhermes:v2.5
|
||||
neural-chat:7b
|
||||
codeqwen:7b
|
||||
phind-codellama:34b
|
||||
nous-hermes:7b
|
||||
nous-hermes:13b
|
||||
nous-hermes:13b
|
||||
starling-lm:7b
|
||||
llama3-gradient:8b
|
||||
llama3-gradient:70b
|
||||
yarn-llama2:7b
|
||||
yarn-llama2:13b
|
||||
llava-llama3:8b
|
||||
llama-pro:instruct
|
||||
everythinglm:13b
|
||||
llava-phi3:3.8b
|
||||
mistrallite:7b
|
||||
notus:7b
|
||||
@@ -50,7 +50,7 @@ namespace AntSK.LLM.SparkDesk
|
||||
parameters.Temperature = (float)chatExecutionSettings.Temperature;
|
||||
parameters.MaxTokens = chatExecutionSettings.MaxTokens ?? parameters.MaxTokens;
|
||||
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSkFunctions").ToList() ?? [];
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSKFunctions").ToList() ?? [];
|
||||
var functionDefs = functions.Select(func => new FunctionDef(func.Name, func.Description, func.Parameters.Select(p => new FunctionParametersDef(p.Name, p.ParameterType?.IsClass == true ? "object" : "string", p.Description, p.IsRequired)).ToList())).ToList();
|
||||
|
||||
List<ChatMessage> messages = GetSparkMessage(chatHistory);
|
||||
@@ -133,7 +133,7 @@ namespace AntSK.LLM.SparkDesk
|
||||
parameters.Temperature = (float)chatExecutionSettings.Temperature;
|
||||
parameters.MaxTokens = chatExecutionSettings.MaxTokens ?? parameters.MaxTokens;
|
||||
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSkFunctions").ToList() ?? [];
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSKFunctions").ToList() ?? [];
|
||||
var functionDefs = functions.Select(func => new FunctionDef(func.Name, func.Description, func.Parameters.Select(p => new FunctionParametersDef(p.Name, p.ParameterType?.IsClass == true ? "object" : "string", p.Description, p.IsRequired)).ToList())).ToList();
|
||||
List<ChatMessage> messages = GetSparkMessage(chatHistory);
|
||||
await foreach (StreamedChatResponse msg in _client.ChatAsStreamAsync(_options.ModelVersion, messages.ToArray(), parameters, functionDefs.Count > 0 ? [.. functionDefs] : null, cancellationToken: cancellationToken))
|
||||
@@ -67,7 +67,7 @@ namespace AntSK.LLM.SparkDesk
|
||||
parameters.Temperature = (float)chatExecutionSettings.Temperature;
|
||||
parameters.MaxTokens = chatExecutionSettings.MaxTokens ?? parameters.MaxTokens;
|
||||
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSkFunctions").ToList() ?? [];
|
||||
IList<KernelFunctionMetadata> functions = kernel?.Plugins.GetFunctionsMetadata().Where(x => x.PluginName == "AntSKFunctions").ToList() ?? [];
|
||||
var functionDefs = functions.Select(func => new FunctionDef(func.Name, func.Description, func.Parameters.Select(p => new FunctionParametersDef(p.Name, p.ParameterType?.IsClass == true ? "object" : "string", p.Description, p.IsRequired)).ToList())).ToList();
|
||||
|
||||
//var messages = GetHistories(prompt);
|
||||
19
src/AntSK.LLamaFactory/llamafactory/api_antsk.py
Normal file
19
src/AntSK.LLamaFactory/llamafactory/api_antsk.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llamafactory.api.app import create_app
|
||||
from llamafactory.chat import ChatModel
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,16 +0,0 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import ChatModel, create_app
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,6 @@
|
||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
||||
|
||||
from .cli import VERSION
|
||||
|
||||
|
||||
__version__ = VERSION
|
||||
108
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/app.py
Normal file
108
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/app.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
||||
from .chat import (
|
||||
create_chat_completion_response,
|
||||
create_score_evaluation_response,
|
||||
create_stream_chat_completion_response,
|
||||
)
|
||||
from .protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_available():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
|
||||
if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY")
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
@app.get(
|
||||
"/v1/models",
|
||||
response_model=ModelList,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=ChatCompletionResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if not chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if request.stream:
|
||||
generate = create_stream_chat_completion_response(request, chat_model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
else:
|
||||
return await create_chat_completion_response(request, chat_model)
|
||||
|
||||
@app.post(
|
||||
"/v1/score/evaluation",
|
||||
response_model=ScoreEvaluationResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
return await create_score_evaluation_response(request, chat_model)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_api() -> None:
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
219
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/chat.py
Normal file
219
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/chat.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionStreamResponseChoice,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Role,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_available():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = None
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
image = None
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
elif isinstance(message.content, list):
|
||||
for input_item in message.content:
|
||||
if input_item.type == "text":
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
||||
else:
|
||||
image_url = input_item.image_url.url
|
||||
if image_url.startswith("data:image"): # base64 image
|
||||
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
|
||||
image_path = io.BytesIO(image_data)
|
||||
elif os.path.isfile(image_url): # local file
|
||||
image_path = open(image_url, "rb")
|
||||
else: # web uri
|
||||
image_path = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools, image
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
completion_id: str,
|
||||
model: str,
|
||||
delta: "ChatCompletionMessage",
|
||||
index: Optional[int] = 0,
|
||||
finish_reason: Optional["Finish"] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
|
||||
return jsonify(chunk)
|
||||
|
||||
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
stop=request.stop,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
||||
|
||||
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
if request.n > 1:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
|
||||
)
|
||||
async for new_token in chat_model.astream_chat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
stop=request.stop,
|
||||
):
|
||||
if len(new_token) != 0:
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||
)
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
@@ -0,0 +1,20 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
@@ -39,15 +39,37 @@ class Function(BaseModel):
|
||||
arguments: str
|
||||
|
||||
|
||||
class FunctionDefinition(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
|
||||
|
||||
class FunctionAvailable(BaseModel):
|
||||
type: Literal["function", "code_interpreter"] = "function"
|
||||
function: Optional[FunctionDefinition] = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
|
||||
class ImageURL(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[ImageURL] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: str
|
||||
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
@@ -59,12 +81,13 @@ class ChatCompletionMessage(BaseModel):
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
tools: list = []
|
||||
tools: Optional[List[FunctionAvailable]] = None
|
||||
do_sample: bool = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: int = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -74,7 +97,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
@@ -87,7 +110,7 @@ class ChatCompletionResponseUsage(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
@@ -96,11 +119,11 @@ class ChatCompletionResponse(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
@@ -110,7 +133,7 @@ class ScoreEvaluationRequest(BaseModel):
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
id: str
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from vllm import AsyncLLMEngine
|
||||
|
||||
from ..data import Template
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]: ...
|
||||
|
||||
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
|
||||
@@ -2,12 +2,15 @@ import asyncio
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import get_infer_args
|
||||
from .hf_engine import HuggingfaceEngine
|
||||
from .vllm_engine import VllmEngine
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@@ -36,9 +39,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
async def achat(
|
||||
@@ -46,18 +50,20 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
return await self.engine.chat(messages, system, tools, **input_kwargs)
|
||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
generator = self.astream_chat(messages, system, tools, **input_kwargs)
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||
@@ -70,9 +76,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
def get_scores(
|
||||
@@ -89,3 +96,45 @@ class ChatModel:
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
@@ -2,25 +2,31 @@ import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model_and_tokenizer
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -30,55 +36,96 @@ class HuggingfaceEngine(BaseEngine):
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.model = load_model(
|
||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
) # must after fixing tokenizer to resize vocab
|
||||
self.generating_args = generating_args.to_dict()
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if (
|
||||
processor is not None
|
||||
and image is not None
|
||||
and not hasattr(processor, "image_seq_length")
|
||||
and template.image_token not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = template.image_token + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
pixel_values = None
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
if processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
batch_feature = image_processor(image, return_tensors="pt")
|
||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
||||
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
dict(
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
temperature=temperature if temperature is not None else generating_args["temperature"],
|
||||
top_p=top_p if top_p is not None else generating_args["top_p"],
|
||||
top_k=top_k if top_k is not None else generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None
|
||||
else generating_args["repetition_penalty"],
|
||||
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
|
||||
generating_args["do_sample"] = True
|
||||
generating_args["temperature"] = generating_args["temperature"] or 1.0
|
||||
|
||||
if not generating_args["temperature"]:
|
||||
generating_args["do_sample"] = False
|
||||
|
||||
if not generating_args["do_sample"]:
|
||||
generating_args.pop("temperature", None)
|
||||
generating_args.pop("top_p", None)
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
@@ -90,10 +137,14 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@staticmethod
|
||||
@@ -101,15 +152,17 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _chat(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@@ -134,15 +187,17 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _stream_chat(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@@ -198,6 +253,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@@ -207,11 +263,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
input_args = (
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.processor,
|
||||
self.template,
|
||||
self.generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
@@ -223,6 +281,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@@ -232,11 +291,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
input_args = (
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.processor,
|
||||
self.template,
|
||||
self.generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
@@ -0,0 +1,214 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import MultiModalData
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.generating_args = generating_args.to_dict()
|
||||
|
||||
engine_args = {
|
||||
"model": model_args.model_name_or_path,
|
||||
"trust_remote_code": True,
|
||||
"download_dir": model_args.cache_dir,
|
||||
"dtype": model_args.vllm_dtype,
|
||||
"max_model_len": model_args.vllm_maxlen,
|
||||
"tensor_parallel_size": get_device_count() or 1,
|
||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||
"disable_log_stats": True,
|
||||
"disable_log_requests": True,
|
||||
"enforce_eager": model_args.vllm_enforce_eager,
|
||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||
}
|
||||
|
||||
if model_args.visual_inputs:
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.image_feature_size = (image_size // patch_size) ** 2
|
||||
engine_args["image_input_type"] = "pixel_values"
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
|
||||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
|
||||
else:
|
||||
self.lora_request = None
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
|
||||
if (
|
||||
self.processor is not None
|
||||
and image is not None
|
||||
and not hasattr(self.processor, "image_seq_length")
|
||||
and self.template.image_token not in messages[0]["content"]
|
||||
): # llava-like models (TODO: paligemma models)
|
||||
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
||||
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if "max_new_tokens" in self.generating_args:
|
||||
max_tokens = self.generating_args["max_new_tokens"]
|
||||
elif "max_length" in self.generating_args:
|
||||
if self.generating_args["max_length"] > prompt_length:
|
||||
max_tokens = self.generating_args["max_length"] - prompt_length
|
||||
else:
|
||||
max_tokens = 1
|
||||
|
||||
if max_length:
|
||||
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
|
||||
|
||||
if max_new_tokens:
|
||||
max_tokens = max_new_tokens
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=num_return_sequences,
|
||||
repetition_penalty=(
|
||||
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
|
||||
)
|
||||
or 1.0, # repetition_penalty must > 0
|
||||
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
|
||||
stop=stop,
|
||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
max_tokens=max_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
result_generator = self.model.generate(
|
||||
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
)
|
||||
return result_generator
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
final_output = None
|
||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
||||
async for request_output in generator:
|
||||
final_output = request_output
|
||||
|
||||
results = []
|
||||
for output in final_output.outputs:
|
||||
results.append(
|
||||
Response(
|
||||
response_text=output.text,
|
||||
response_length=len(output.token_ids),
|
||||
prompt_length=len(final_output.prompt_token_ids),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generated_text = ""
|
||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
||||
async for result in generator:
|
||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||
generated_text = result.outputs[0].text
|
||||
yield delta_text
|
||||
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
raise NotImplementedError("vLLM engine does not support get_scores.")
|
||||
106
src/AntSK.LLamaFactory/llamafactory/llamafactory/cli.py
Normal file
106
src/AntSK.LLamaFactory/llamafactory/llamafactory/cli.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import launcher
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras.env import VERSION, print_env
|
||||
from .extras.logging import get_logger
|
||||
from .extras.misc import get_device_count
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
|
||||
USAGE = (
|
||||
"-" * 70
|
||||
+ "\n"
|
||||
+ "| Usage: |\n"
|
||||
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
|
||||
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
|
||||
+ "| llamafactory-cli eval -h: evaluate models |\n"
|
||||
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
|
||||
+ "| llamafactory-cli train -h: train models |\n"
|
||||
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
|
||||
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
|
||||
+ "| llamafactory-cli version: show version info |\n"
|
||||
+ "-" * 70
|
||||
)
|
||||
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
||||
+ " " * (21 - len(VERSION))
|
||||
+ "|\n|"
|
||||
+ " " * 56
|
||||
+ "|\n"
|
||||
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
||||
+ "-" * 58
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
API = "api"
|
||||
CHAT = "chat"
|
||||
ENV = "env"
|
||||
EVAL = "eval"
|
||||
EXPORT = "export"
|
||||
TRAIN = "train"
|
||||
WEBDEMO = "webchat"
|
||||
WEBUI = "webui"
|
||||
VER = "version"
|
||||
HELP = "help"
|
||||
|
||||
|
||||
def main():
|
||||
command = sys.argv.pop(1)
|
||||
if command == Command.API:
|
||||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
run_chat()
|
||||
elif command == Command.ENV:
|
||||
print_env()
|
||||
elif command == Command.EVAL:
|
||||
run_eval()
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
).format(
|
||||
nnodes=os.environ.get("NNODES", "1"),
|
||||
node_rank=os.environ.get("RANK", "0"),
|
||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
file_name=launcher.__file__,
|
||||
args=" ".join(sys.argv[1:]),
|
||||
),
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
elif command == Command.WEBUI:
|
||||
run_web_ui()
|
||||
elif command == Command.VER:
|
||||
print(WELCOME)
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}".format(command))
|
||||
@@ -0,0 +1,16 @@
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .data_utils import Role, split_dataset
|
||||
from .loader import get_dataset
|
||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"Role",
|
||||
"split_dataset",
|
||||
"get_dataset",
|
||||
"TEMPLATES",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
]
|
||||
221
src/AntSK.LLamaFactory/llamafactory/llamafactory/data/aligner.py
Normal file
221
src/AntSK.LLamaFactory/llamafactory/llamafactory/data/aligner.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import Features
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
|
||||
outputs.append(os.path.join(data_args.dataset_dir, image))
|
||||
else:
|
||||
outputs.append(image)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts alpaca format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||
|
||||
content = []
|
||||
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
||||
content.append(examples[dataset_attr.prompt][i])
|
||||
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
||||
): # pairwise example
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts sharegpt format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
||||
dataset_attr.function_tag: Role.FUNCTION.value,
|
||||
dataset_attr.system_tag: Role.SYSTEM.value,
|
||||
}
|
||||
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||
accept_tags = (odd_tags, even_tags)
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||
system = messages[0][dataset_attr.content_tag]
|
||||
messages = messages[1:]
|
||||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
aligned_messages = []
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
||||
): # pairwise example
|
||||
chosen = examples[dataset_attr.chosen][i]
|
||||
rejected = examples[dataset_attr.rejected][i]
|
||||
if (
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
response = [
|
||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||
]
|
||||
else: # normal example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
continue
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def align_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Aligned dataset:
|
||||
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
system: "..."
|
||||
tools: "...",
|
||||
images: [],
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||
else:
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
features = Features.from_dict(
|
||||
{
|
||||
"prompt": [
|
||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||
],
|
||||
"response": [
|
||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||
],
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset",
|
||||
)
|
||||
|
||||
return dataset.map(
|
||||
convert_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
features=features,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
for key in ("chosen", "rejected"):
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["{}_input_ids".format(key)],
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "{}_token_type_ids".format(key) in feature:
|
||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
||||
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
return super().__call__(concatenated_features)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "token_type_ids" in feature:
|
||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
||||
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
if "token_type_ids" in batch:
|
||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
@@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets
|
||||
|
||||
@@ -11,7 +10,7 @@ if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.hparams import DataArguments
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -26,25 +25,10 @@ class Role(str, Enum):
|
||||
OBSERVATION = "observation"
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
return
|
||||
|
||||
if len(data_files) != 1:
|
||||
logger.warning("Checksum failed: too many files.")
|
||||
return
|
||||
|
||||
with open(data_files[0], "rb") as f:
|
||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||
if sha1 != file_sha1:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = max_len - max_target_len
|
||||
max_source_len = max_len - min(max_target_len, target_len)
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
@@ -78,9 +62,9 @@ def split_dataset(
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
@@ -1,21 +1,24 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum, merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
@@ -56,14 +59,12 @@ def load_single_dataset(
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
|
||||
checksum(data_files, dataset_attr.file_sha1)
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
@@ -80,7 +81,9 @@ def load_single_dataset(
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
).to_hf_dataset()
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
else:
|
||||
@@ -104,30 +107,43 @@ def load_single_dataset(
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(max_samples))
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
# split: Optional[str] = "train", # TODO: add split
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
# Load from cache
|
||||
if data_args.cache_path is not None:
|
||||
if os.path.exists(data_args.cache_path):
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.cache_path)
|
||||
dataset = load_from_disk(data_args.tokenized_path)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
@@ -138,12 +154,15 @@ def get_dataset(
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args):
|
||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
tokenizer, template, data_args, training_args, stage
|
||||
data_args, training_args, stage, template, tokenizer, processor
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
@@ -156,15 +175,21 @@ def get_dataset(
|
||||
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.cache_path)
|
||||
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
|
||||
dataset.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
if stage == "pt":
|
||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||
else:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
@@ -20,23 +20,28 @@ class DatasetAttr:
|
||||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
""" extra configs """
|
||||
file_sha1: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
num_samples: Optional[int] = None
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
""" columns for the alpaca format """
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
""" rlhf columns """
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
""" columns for the sharegpt format """
|
||||
""" sharegpt columns """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
""" tags for the sharegpt format """
|
||||
""" sharegpt tags """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
@@ -53,22 +58,35 @@ class DatasetAttr:
|
||||
|
||||
|
||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if data_args.dataset is not None:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
if data_args.dataset is not None:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
|
||||
else:
|
||||
dataset_names = []
|
||||
|
||||
if data_args.dataset_dir == "ONLINE":
|
||||
dataset_info = None
|
||||
else:
|
||||
try:
|
||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||
|
||||
dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
if dataset_info is None:
|
||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||
dataset_list.append(dataset_attr)
|
||||
continue
|
||||
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
|
||||
@@ -85,18 +103,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
@@ -0,0 +1,84 @@
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||
|
||||
from .processors.feedback import preprocess_feedback_dataset
|
||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
||||
from .processors.pretrain import preprocess_pretrain_dataset
|
||||
from .processors.supervised import (
|
||||
preprocess_packed_supervised_dataset,
|
||||
preprocess_supervised_dataset,
|
||||
print_supervised_dataset_example,
|
||||
)
|
||||
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(
|
||||
preprocess_pretrain_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "kto":
|
||||
preprocess_func = partial(
|
||||
preprocess_feedback_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
@@ -0,0 +1,126 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_feedback_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
kl_response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = prompt + [response[1]]
|
||||
|
||||
if kl_response[0]["content"]:
|
||||
kl_messages = prompt + [kl_response[0]]
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
|
||||
def preprocess_feedback_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
model_inputs["kl_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
kl_response=kl_response[i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
@@ -0,0 +1,123 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_pairwise_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
|
||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {
|
||||
"chosen_input_ids": [],
|
||||
"chosen_attention_mask": [],
|
||||
"chosen_labels": [],
|
||||
"rejected_input_ids": [],
|
||||
"rejected_attention_mask": [],
|
||||
"rejected_labels": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"] = []
|
||||
model_inputs["rejected_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
||||
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
||||
model_inputs["chosen_labels"].append(chosen_labels)
|
||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
||||
)
|
||||
model_inputs["rejected_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
|
||||
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
||||
@@ -0,0 +1,36 @@
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
if data_args.template == "gemma":
|
||||
for i in range(len(result["input_ids"])):
|
||||
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,64 @@
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
r"""
|
||||
Finds the index of largest number that fits into the knapsack with the given capacity.
|
||||
"""
|
||||
index = bisect.bisect(numbers, capacity)
|
||||
return -1 if index == 0 else (index - 1)
|
||||
|
||||
|
||||
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
r"""
|
||||
An efficient greedy algorithm with binary search for the knapsack problem.
|
||||
"""
|
||||
numbers.sort() # sort numbers in ascending order for binary search
|
||||
knapsacks = []
|
||||
|
||||
while numbers:
|
||||
current_knapsack = []
|
||||
remaining_capacity = capacity
|
||||
|
||||
while True:
|
||||
index = search_for_fit(numbers, remaining_capacity)
|
||||
if index == -1:
|
||||
break # no more numbers fit in this knapsack
|
||||
|
||||
remaining_capacity -= numbers[index] # update the remaining capacity
|
||||
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
|
||||
|
||||
knapsacks.append(current_knapsack)
|
||||
|
||||
return knapsacks
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
r"""
|
||||
Processes visual inputs. (currently only supports a single image)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
||||
@@ -0,0 +1,169 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_supervised_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
messages = prompt + response
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
valid_num = 0
|
||||
batch_input_ids, batch_labels = [], []
|
||||
lengths = []
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
data_args=data_args,
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_labels = [], []
|
||||
for length in knapsack:
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
|
||||
if len(packed_input_ids) < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
|
||||
if len(packed_input_ids) != data_args.cutoff_len:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_unsupervised_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
else:
|
||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_unsupervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@@ -2,8 +2,8 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role, infer_max_len
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role, infer_max_len
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -26,6 +26,7 @@ class Template:
|
||||
format_separator: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
image_token: str
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
force_system: bool
|
||||
@@ -68,8 +69,8 @@ class Template:
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
@@ -195,7 +196,7 @@ class Llama2Template(Template):
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def _register_template(
|
||||
@@ -209,6 +210,7 @@ def _register_template(
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: List[str] = [],
|
||||
image_token: str = "<image>",
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
force_system: bool = False,
|
||||
@@ -246,7 +248,7 @@ def _register_template(
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
TEMPLATES[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
format_assistant=format_assistant or default_assistant_formatter,
|
||||
format_system=format_system or default_user_formatter,
|
||||
@@ -256,6 +258,7 @@ def _register_template(
|
||||
format_separator=format_separator or default_separator_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
image_token=image_token,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
force_system=force_system,
|
||||
@@ -276,7 +279,7 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
|
||||
|
||||
|
||||
def _jinja_escape(content: str) -> str:
|
||||
return content.replace("\n", r"\n").replace("'", r"\'")
|
||||
return content.replace("'", r"\'")
|
||||
|
||||
|
||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||
@@ -290,10 +293,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
||||
slot_items.append(placeholder)
|
||||
if slot_pieces[1]:
|
||||
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
|
||||
elif isinstance(slot, set):
|
||||
if "bos_token" in slot:
|
||||
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
||||
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
||||
slot_items.append("'" + tokenizer.bos_token + "'")
|
||||
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
|
||||
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
||||
slot_items.append("'" + tokenizer.eos_token + "'")
|
||||
elif isinstance(slot, dict):
|
||||
raise ValueError("Dict is not supported.")
|
||||
@@ -308,7 +311,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||
|
||||
jinja_template += (
|
||||
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
|
||||
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
|
||||
)
|
||||
|
||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
@@ -325,9 +328,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
||||
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
||||
jinja_template += "{% endif %}"
|
||||
|
||||
jinja_template += "{% if message['role'] == 'user' %}"
|
||||
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
||||
jinja_template += "{{ " + user_message + " }}"
|
||||
|
||||
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
||||
assistant_message = _convert_slots_to_jinja(
|
||||
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
|
||||
@@ -343,9 +348,9 @@ def get_template_and_fix_tokenizer(
|
||||
name: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = templates["vanilla"] # placeholder
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
else:
|
||||
template = templates.get(name, None)
|
||||
template = TEMPLATES.get(name, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(name))
|
||||
|
||||
@@ -385,7 +390,8 @@ _register_template(
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -414,7 +420,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
@@ -441,6 +447,18 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="breeze",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
default_system=(
|
||||
"You are a helpful AI assistant built by MediaTek Research. "
|
||||
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
|
||||
),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
@@ -490,6 +508,7 @@ _register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
@@ -500,6 +519,7 @@ _register_template(
|
||||
name="chatml_de",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
@@ -514,6 +534,26 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cohere",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
|
||||
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
|
||||
),
|
||||
default_system=(
|
||||
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
|
||||
"by providing thorough responses. You are trained by Cohere."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
@@ -522,6 +562,32 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="dbrx",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
||||
"You answer questions based on information available up to that point.\n"
|
||||
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
|
||||
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
|
||||
"from writing to coding (using markdown for code blocks — remember to use ``` with "
|
||||
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
|
||||
"capabilities. You avoid stereotyping and provide balanced perspectives on "
|
||||
"controversial topics. You do not provide song lyrics, poems, or news articles and "
|
||||
"do not divulge details of your training data.)\nThis is your system prompt, "
|
||||
"guiding your responses. Do not reference it, just respond to the user. If you find "
|
||||
"yourself talking about this message, stop. You should be responding appropriately "
|
||||
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
|
||||
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
@@ -554,6 +620,16 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="empty",
|
||||
format_user=StringFormatter(slots=["{{content}}"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
@@ -562,16 +638,39 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="fewshot",
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
@@ -601,17 +700,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -623,6 +713,33 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llama3",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
@@ -633,8 +750,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="olmo",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
@@ -643,12 +759,28 @@ _register_template(
|
||||
_register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="openchat-3.6",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
@@ -657,10 +789,22 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful AI assistant.",
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
@@ -688,7 +832,11 @@ _register_template(
|
||||
|
||||
|
||||
_register_template(
|
||||
name="vanilla",
|
||||
name="telechat",
|
||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
||||
stop_words=["<_end>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -742,12 +890,29 @@ _register_template(
|
||||
_register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="yi_vl",
|
||||
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"This is a chat between an inquisitive human and an AI assistant. "
|
||||
"Assume the role of the AI assistant. Read all the images carefully, "
|
||||
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
|
||||
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
|
||||
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
|
||||
),
|
||||
stop_words=["###"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
@@ -762,7 +927,7 @@ _register_template(
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
@@ -14,16 +14,17 @@ from transformers.utils import cached_file
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import CHOICES, SUBJECTS
|
||||
from ..hparams import get_eval_args
|
||||
from ..model import load_model_and_tokenizer
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .template import get_eval_template
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||
self.choice_inputs = [
|
||||
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||
@@ -117,6 +118,5 @@ class Evaluator:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
def run_eval() -> None:
|
||||
Evaluator().eval()
|
||||
@@ -1,14 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from ..data import Role
|
||||
from ..extras.constants import CHOICES
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalTemplate:
|
||||
system: str
|
||||
@@ -16,22 +12,29 @@ class EvalTemplate:
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
r"""
|
||||
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
|
||||
output: a tuple of (prompt, response)
|
||||
"""
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
|
||||
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Converts dataset examples to messages.
|
||||
"""
|
||||
messages = []
|
||||
for k in range(len(support_set)):
|
||||
prompt, response = self.parse_example(support_set[k])
|
||||
messages.append({"role": Role.USER, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": response})
|
||||
prompt, response = self._parse_example(support_set[k])
|
||||
messages.append({"role": Role.USER.value, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
||||
|
||||
prompt, response = self.parse_example(target_data)
|
||||
messages.append({"role": Role.USER, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": response})
|
||||
prompt, response = self._parse_example(target_data)
|
||||
messages.append({"role": Role.USER.value, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
||||
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
|
||||
return messages
|
||||
|
||||
@@ -39,7 +42,7 @@ class EvalTemplate:
|
||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||
|
||||
|
||||
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
|
||||
return eval_template
|
||||
|
||||
|
||||
register_eval_template(
|
||||
_register_eval_template(
|
||||
name="en",
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
@@ -58,10 +61,10 @@ register_eval_template(
|
||||
)
|
||||
|
||||
|
||||
register_eval_template(
|
||||
_register_eval_template(
|
||||
name="zh",
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix="\n",
|
||||
prefix=" ",
|
||||
)
|
||||
@@ -0,0 +1,217 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
|
||||
from .constants import TRAINER_LOG
|
||||
from .logging import LoggerHandler, get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||
safe_serialization=args.save_safetensors,
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = LoggerHandler(output_dir)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def _set_abort(self, signum, frame) -> None:
|
||||
self.aborted = True
|
||||
|
||||
def _reset(self, max_steps: int = 0) -> None:
|
||||
self.start_time = time.time()
|
||||
self.cur_steps = 0
|
||||
self.max_steps = max_steps
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
|
||||
def _timing(self, cur_steps: int) -> None:
|
||||
cur_time = time.time()
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
|
||||
self.cur_steps = cur_steps
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def _create_thread_pool(self, output_dir: str) -> None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _close_thread_pool(self) -> None:
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
"""
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
self._timing(cur_steps=state.global_step)
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
loss=state.log_history[-1].get("loss", None),
|
||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||
reward=state.log_history[-1].get("reward", None),
|
||||
accuracy=state.log_history[-1].get("rewards/accuracies", None),
|
||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||
epoch=state.log_history[-1].get("epoch", None),
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
|
||||
total_tokens=state.num_input_tokens_seen,
|
||||
)
|
||||
logs = {k: v for k, v in logs.items() if v is not None}
|
||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||
logger.info(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
|
||||
)
|
||||
)
|
||||
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
if self.do_train:
|
||||
return
|
||||
|
||||
if self.aborted:
|
||||
sys.exit(0)
|
||||
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if has_length(eval_dataloader):
|
||||
if self.max_steps == 0:
|
||||
self._reset(max_steps=len(eval_dataloader))
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
@@ -2,13 +2,24 @@ from collections import OrderedDict, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
||||
|
||||
|
||||
CHECKPOINT_NAMES = {
|
||||
SAFE_ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
}
|
||||
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
FILEEXT2TYPE = {
|
||||
@@ -24,28 +35,43 @@ IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
|
||||
PEFT_METHODS = {"lora"}
|
||||
|
||||
RUNNING_LOG = "running_log.txt"
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINER_LOG = "trainer_log.jsonl"
|
||||
|
||||
TRAINING_ARGS = "training_args.yaml"
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
"PPO": "ppo",
|
||||
"DPO": "dpo",
|
||||
"KTO": "kto",
|
||||
"Pre-Training": "pt",
|
||||
}
|
||||
|
||||
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
||||
|
||||
VISION_MODELS = set()
|
||||
|
||||
|
||||
class DownloadSource(str, Enum):
|
||||
DEFAULT = "hf"
|
||||
@@ -54,8 +80,8 @@ class DownloadSource(str, Enum):
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, Dict[DownloadSource, str]],
|
||||
module: Optional[str] = None,
|
||||
template: Optional[str] = None,
|
||||
vision: bool = False,
|
||||
) -> None:
|
||||
prefix = None
|
||||
for name, path in models.items():
|
||||
@@ -64,10 +90,23 @@ def register_model_group(
|
||||
else:
|
||||
assert prefix == name.split("-")[0], "prefix should be identical."
|
||||
SUPPORTED_MODELS[name] = path
|
||||
if module is not None:
|
||||
DEFAULT_MODULE[prefix] = module
|
||||
if template is not None:
|
||||
DEFAULT_TEMPLATE[prefix] = template
|
||||
if vision:
|
||||
VISION_MODELS.add(prefix)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Aya-23-8B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
|
||||
},
|
||||
"Aya-23-35B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
|
||||
},
|
||||
},
|
||||
template="cohere",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
@@ -85,7 +124,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
||||
},
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan",
|
||||
)
|
||||
|
||||
@@ -109,7 +147,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
},
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan2",
|
||||
)
|
||||
|
||||
@@ -129,7 +166,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
)
|
||||
|
||||
|
||||
@@ -148,7 +184,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +202,19 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Breeze-7B": {
|
||||
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
|
||||
},
|
||||
"Breeze-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
|
||||
},
|
||||
},
|
||||
template="breeze",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": {
|
||||
@@ -174,7 +222,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm2",
|
||||
)
|
||||
|
||||
@@ -190,7 +237,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm3",
|
||||
)
|
||||
|
||||
@@ -226,6 +272,73 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-7B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b",
|
||||
},
|
||||
"CodeGemma-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
|
||||
},
|
||||
"CodeGemma-1.1-2B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
|
||||
},
|
||||
"CodeGemma-1.1-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
|
||||
},
|
||||
},
|
||||
template="gemma",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Codestral-22B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
|
||||
},
|
||||
},
|
||||
template="mistral",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CommandR-35B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
|
||||
},
|
||||
"CommandR-Plus-104B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
|
||||
},
|
||||
"CommandR-35B-4bit-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
|
||||
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
|
||||
},
|
||||
"CommandR-Plus-104B-4bit-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
|
||||
},
|
||||
},
|
||||
template="cohere",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DBRX-132B-Base": {
|
||||
DownloadSource.DEFAULT: "databricks/dbrx-base",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
|
||||
},
|
||||
"DBRX-132B-Chat": {
|
||||
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
|
||||
},
|
||||
},
|
||||
template="dbrx",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeek-LLM-7B-Base": {
|
||||
@@ -246,18 +359,36 @@ register_model_group(
|
||||
},
|
||||
"DeepSeek-Math-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
|
||||
},
|
||||
"DeepSeek-Math-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||
},
|
||||
"DeepSeek-MoE-16B-v2-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
|
||||
},
|
||||
"DeepSeek-MoE-236B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
},
|
||||
"DeepSeek-MoE-16B-v2-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
},
|
||||
"DeepSeek-MoE-236B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
||||
},
|
||||
},
|
||||
template="deepseek",
|
||||
)
|
||||
@@ -298,6 +429,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
|
||||
},
|
||||
"Falcon-11B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
|
||||
},
|
||||
"Falcon-40B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
|
||||
@@ -319,7 +453,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon",
|
||||
)
|
||||
|
||||
@@ -342,11 +475,36 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "google/gemma-7b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
|
||||
},
|
||||
"Gemma-1.1-2B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
|
||||
},
|
||||
"Gemma-1.1-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
|
||||
},
|
||||
},
|
||||
template="gemma",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-4-9B": {
|
||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
|
||||
},
|
||||
"GLM-4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
||||
},
|
||||
"GLM-4-9B-1M-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
|
||||
},
|
||||
},
|
||||
template="glm4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
@@ -389,11 +547,20 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
||||
},
|
||||
},
|
||||
module="wqkv",
|
||||
template="intern2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Jambda-v0.1": {
|
||||
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
@@ -401,7 +568,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
||||
}
|
||||
},
|
||||
module="qkv_proj",
|
||||
)
|
||||
|
||||
|
||||
@@ -460,18 +626,72 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B": {
|
||||
"LLaMA3-8B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
||||
},
|
||||
"LLaMA3-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
||||
},
|
||||
"LLaMA3-8B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
||||
},
|
||||
"LLaMA3-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
||||
},
|
||||
"LLaMA3-8B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
||||
},
|
||||
"LLaMA3-70B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||
},
|
||||
"LLaVA1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||
},
|
||||
},
|
||||
template="vicuna",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B-v0.1": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
|
||||
},
|
||||
"Mistral-7B-Chat": {
|
||||
"Mistral-7B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
||||
},
|
||||
"Mistral-7B-v0.2": {
|
||||
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
|
||||
},
|
||||
"Mistral-7B-v0.2-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
|
||||
},
|
||||
"Mistral-7B-v0.3": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
|
||||
},
|
||||
"Mistral-7B-v0.3-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
},
|
||||
},
|
||||
template="mistral",
|
||||
)
|
||||
@@ -479,14 +699,22 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mixtral-8x7B": {
|
||||
"Mixtral-8x7B-v0.1": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
||||
},
|
||||
"Mixtral-8x7B-Chat": {
|
||||
"Mixtral-8x7B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
||||
},
|
||||
"Mixtral-8x22B-v0.1": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
|
||||
},
|
||||
"Mixtral-8x22B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
|
||||
},
|
||||
},
|
||||
template="mistral",
|
||||
)
|
||||
@@ -495,18 +723,18 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"OLMo-1B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1B",
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
|
||||
},
|
||||
"OLMo-7B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
|
||||
},
|
||||
"OLMo-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
|
||||
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
|
||||
},
|
||||
"OLMo-1.7-7B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
|
||||
},
|
||||
},
|
||||
module="att_proj",
|
||||
template="olmo",
|
||||
)
|
||||
|
||||
|
||||
@@ -514,13 +742,23 @@ register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
|
||||
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
|
||||
}
|
||||
},
|
||||
template="openchat",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"OpenChat3.6-8B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
|
||||
}
|
||||
},
|
||||
template="openchat-3.6",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Orion-14B-Base": {
|
||||
@@ -548,6 +786,33 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"PaliGemma-3B-pt-224": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
|
||||
},
|
||||
"PaliGemma-3B-pt-448": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
|
||||
},
|
||||
"PaliGemma-3B-pt-896": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
|
||||
},
|
||||
"PaliGemma-3B-mix-224": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
|
||||
},
|
||||
"PaliGemma-3B-mix-448": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
|
||||
},
|
||||
},
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {
|
||||
@@ -562,6 +827,37 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi3-4B-4k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
||||
},
|
||||
"Phi3-4B-128k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
||||
},
|
||||
"Phi3-7B-8k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
||||
},
|
||||
"Phi3-7B-128k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
||||
},
|
||||
"Phi3-14B-8k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
||||
},
|
||||
"Phi3-14B-128k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
||||
},
|
||||
},
|
||||
template="phi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {
|
||||
@@ -629,7 +925,6 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
||||
},
|
||||
},
|
||||
module="c_attn",
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
@@ -656,10 +951,26 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
|
||||
},
|
||||
"Qwen1.5-32B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
|
||||
},
|
||||
"Qwen1.5-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
|
||||
},
|
||||
"Qwen1.5-110B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
|
||||
},
|
||||
"Qwen1.5-Code-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
|
||||
@@ -680,10 +991,26 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
|
||||
},
|
||||
"Qwen1.5-32B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
|
||||
},
|
||||
"Qwen1.5-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
|
||||
},
|
||||
"Qwen1.5-110B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
},
|
||||
"Qwen1.5-Code-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
@@ -724,6 +1051,10 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-32B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
@@ -732,6 +1063,101 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-110B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-Code-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2-0.5B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
|
||||
},
|
||||
"Qwen2-1.5B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
|
||||
},
|
||||
"Qwen2-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
|
||||
},
|
||||
"Qwen2-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
|
||||
},
|
||||
"Qwen2-MoE-57B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
|
||||
},
|
||||
"Qwen2-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
||||
},
|
||||
"Qwen2-1.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
||||
},
|
||||
"Qwen2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
||||
},
|
||||
"Qwen2-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
|
||||
},
|
||||
"Qwen2-MoE-57B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
|
||||
},
|
||||
"Qwen2-0.5B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2-0.5B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2-1.5B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2-1.5B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2-72B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2-MoE-57B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
@@ -765,17 +1191,39 @@ register_model_group(
|
||||
models={
|
||||
"StarCoder2-3B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
|
||||
},
|
||||
"StarCoder2-7B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
|
||||
},
|
||||
"StarCoder2-15B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||
},
|
||||
"TeleChat-12B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
||||
},
|
||||
"TeleChat-12B-v2-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
|
||||
},
|
||||
},
|
||||
template="telechat",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Vicuna1.5-7B-Chat": {
|
||||
@@ -793,17 +1241,53 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XuanYuan-6B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
|
||||
},
|
||||
"XuanYuan-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
||||
},
|
||||
"XuanYuan-2-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
},
|
||||
"XuanYuan-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
||||
},
|
||||
"XuanYuan-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
},
|
||||
"XuanYuan-2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
},
|
||||
"XuanYuan-6B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-6B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
},
|
||||
},
|
||||
template="xuanyuan",
|
||||
@@ -840,6 +1324,30 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||
},
|
||||
"XVERSE-MoE-A4.2B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
||||
},
|
||||
"XVERSE-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-13B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-13B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-65B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="xverse",
|
||||
)
|
||||
@@ -898,11 +1406,49 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||
},
|
||||
"Yi-1.5-6B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
|
||||
},
|
||||
"Yi-1.5-9B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
|
||||
},
|
||||
"Yi-1.5-34B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
|
||||
},
|
||||
"Yi-1.5-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
||||
},
|
||||
"Yi-1.5-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
|
||||
},
|
||||
"Yi-1.5-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
|
||||
},
|
||||
},
|
||||
template="yi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"YiVL-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
||||
},
|
||||
"YiVL-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
||||
},
|
||||
},
|
||||
template="yi_vl",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yuan2-2B-Chat": {
|
||||
@@ -932,21 +1478,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
||||
},
|
||||
"Zephyr-141B-ORPO-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
|
||||
},
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Atom-7B": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
|
||||
},
|
||||
"Atom-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="atom",
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
import platform
|
||||
|
||||
import accelerate
|
||||
import datasets
|
||||
import peft
|
||||
import torch
|
||||
import transformers
|
||||
import trl
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
|
||||
|
||||
from .packages import is_vllm_available
|
||||
|
||||
|
||||
VERSION = "0.8.0"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
info = {
|
||||
"`llamafactory` version": VERSION,
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"PyTorch version": torch.__version__,
|
||||
"Transformers version": transformers.__version__,
|
||||
"Datasets version": datasets.__version__,
|
||||
"Accelerate version": accelerate.__version__,
|
||||
"PEFT version": peft.__version__,
|
||||
"TRL version": trl.__version__,
|
||||
}
|
||||
|
||||
if is_torch_cuda_available():
|
||||
info["PyTorch version"] += " (GPU)"
|
||||
info["GPU type"] = torch.cuda.get_device_name()
|
||||
|
||||
if is_torch_npu_available():
|
||||
info["PyTorch version"] += " (NPU)"
|
||||
info["NPU type"] = torch.npu.get_device_name()
|
||||
info["CANN version"] = torch.version.cann
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
info["DeepSpeed version"] = deepspeed.__version__
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes
|
||||
|
||||
info["Bitsandbytes version"] = bitsandbytes.__version__
|
||||
|
||||
if is_vllm_available():
|
||||
import vllm
|
||||
|
||||
info["vLLM version"] = vllm.__version__
|
||||
|
||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
||||
@@ -0,0 +1,68 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .constants import RUNNING_LOG
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
r"""
|
||||
Logger handler used in Web UI.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
super().__init__()
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
self.setLevel(logging.INFO)
|
||||
self.setFormatter(formatter)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||
if os.path.exists(self.running_log):
|
||||
os.remove(self.running_log)
|
||||
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _write_log(self, log_entry: str) -> None:
|
||||
with open(self.running_log, "a", encoding="utf-8") as f:
|
||||
f.write(log_entry + "\n\n")
|
||||
|
||||
def emit(self, record) -> None:
|
||||
if record.name == "httpx":
|
||||
return
|
||||
|
||||
log_entry = self.format(record)
|
||||
self.thread_pool.submit(self._write_log, log_entry)
|
||||
|
||||
def close(self) -> None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
return super().close()
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
r"""
|
||||
Gets a standard logger with a stream hander to stdout.
|
||||
"""
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def reset_logging() -> None:
|
||||
r"""
|
||||
Removes basic config of root logger. (unused in script)
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
@@ -30,7 +30,7 @@ except Exception:
|
||||
if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -58,14 +58,14 @@ class AverageMeter:
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
|
||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
else:
|
||||
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
|
||||
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
|
||||
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
|
||||
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
|
||||
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
|
||||
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
|
||||
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
@@ -81,7 +81,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
num_params = num_params * 2
|
||||
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||
num_bytes = param.quant_storage.itemsize
|
||||
elif hasattr(param, "element_size"): # for older pytorch version
|
||||
num_bytes = param.element_size()
|
||||
else:
|
||||
num_bytes = 1
|
||||
|
||||
num_params = num_params * 2 * num_bytes
|
||||
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
@@ -158,13 +165,15 @@ def get_current_device() -> torch.device:
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""
|
||||
Gets the number of available GPU devices.
|
||||
Gets the number of available GPU or NPU devices.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
else:
|
||||
return 0
|
||||
|
||||
return torch.cuda.device_count()
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
@@ -187,30 +196,47 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def is_gpu_or_npu_available() -> bool:
|
||||
r"""
|
||||
Checks if the GPU or NPU is available.
|
||||
"""
|
||||
return is_torch_npu_available() or is_torch_cuda_available()
|
||||
|
||||
|
||||
def has_tokenized_data(path: os.PathLike) -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
Collects GPU or NPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_mps_available():
|
||||
torch.mps.empty_cache()
|
||||
elif is_torch_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
||||
return
|
||||
return model_args.model_name_or_path
|
||||
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
model_args.model_name_or_path = snapshot_download(
|
||||
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
|
||||
)
|
||||
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
|
||||
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
||||
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||
@@ -1,30 +1,41 @@
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging import version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
def _is_package_available(name: str) -> bool:
|
||||
return importlib.util.find_spec(name) is not None
|
||||
|
||||
|
||||
def _get_package_version(name: str) -> str:
|
||||
def _get_package_version(name: str) -> "Version":
|
||||
try:
|
||||
return importlib.metadata.version(name)
|
||||
return version.parse(importlib.metadata.version(name))
|
||||
except Exception:
|
||||
return "0.0.0"
|
||||
return version.parse("0.0.0")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
def is_fastapi_available():
|
||||
return _is_package_available("fastapi")
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
|
||||
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
|
||||
|
||||
|
||||
def is_galore_available():
|
||||
return _is_package_available("galore_torch")
|
||||
|
||||
|
||||
def is_gradio_available():
|
||||
return _is_package_available("gradio")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _is_package_available("jieba")
|
||||
|
||||
@@ -37,6 +48,10 @@ def is_nltk_available():
|
||||
return _is_package_available("nltk")
|
||||
|
||||
|
||||
def is_pillow_available():
|
||||
return _is_package_available("PIL")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _is_package_available("requests")
|
||||
|
||||
@@ -45,14 +60,14 @@ def is_rouge_available():
|
||||
return _is_package_available("rouge_chinese")
|
||||
|
||||
|
||||
def is_sdpa_available():
|
||||
return _get_package_version("torch") > version.parse("2.1.1")
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return _is_package_available("sse_starlette")
|
||||
|
||||
|
||||
def is_unsloth_available():
|
||||
return _is_package_available("unsloth")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _is_package_available("uvicorn")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
@@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
@@ -20,8 +21,11 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
if len(scalars) == 0:
|
||||
return []
|
||||
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
smoothed = []
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
@@ -30,7 +34,33 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
return smoothed
|
||||
|
||||
|
||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
r"""
|
||||
Plots loss curves in LlamaBoard.
|
||||
"""
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
for log in trainer_log:
|
||||
if log.get("loss", None):
|
||||
steps.append(log["current_steps"])
|
||||
losses.append(log["loss"])
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||
r"""
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
@@ -52,6 +82,6 @@ def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
|
||||
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
|
||||
plt.savefig(figure_path, format="png", dpi=100)
|
||||
print("Figure saved at:", figure_path)
|
||||
@@ -26,11 +26,11 @@ class DataArguments:
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the model inputs after tokenization."},
|
||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||
)
|
||||
reserved_label_len: int = field(
|
||||
default=1,
|
||||
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
|
||||
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
@@ -84,9 +84,9 @@ class DataArguments:
|
||||
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
|
||||
},
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
tokenized_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the pre-processed datasets."},
|
||||
metadata={"help": "Path to save or load the tokenized datasets."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@@ -9,22 +8,35 @@ class FreezeArguments:
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
|
||||
name_module_trainable: str = field(
|
||||
default="all",
|
||||
freeze_trainable_layers: int = field(
|
||||
default=2,
|
||||
metadata={
|
||||
"help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the available modules. \
|
||||
LLaMA choices: ["mlp", "self_attn"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||
Qwen choices: ["mlp", "attn"], \
|
||||
InternLM2 choices: ["feed_forward", "attention"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
"help": (
|
||||
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
|
||||
"Positive numbers mean the last n layers are set as trainable, "
|
||||
"negative numbers mean the first n layers are set as trainable."
|
||||
)
|
||||
},
|
||||
)
|
||||
num_layer_trainable: int = field(
|
||||
default=2,
|
||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
|
||||
freeze_trainable_modules: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the available modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of modules apart from hidden layers to be set as trainable "
|
||||
"for freeze (partial-parameter) fine-tuning. "
|
||||
"Use commas to separate multiple modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +49,11 @@ class LoraArguments:
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
|
||||
"help": (
|
||||
"Name(s) of modules apart from LoRA layers to be set as trainable "
|
||||
"and saved in the final checkpoint. "
|
||||
"Use commas to separate multiple modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
@@ -55,17 +71,21 @@ class LoraArguments:
|
||||
lora_target: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": """Name(s) of target modules to apply LoRA. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the available modules. \
|
||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
"help": (
|
||||
"Name(s) of target modules to apply LoRA. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||
)
|
||||
loraplus_lr_embedding: float = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
|
||||
)
|
||||
use_rslora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||
@@ -83,20 +103,36 @@ class LoraArguments:
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO and DPO training.
|
||||
Arguments pertaining to the PPO, DPO and KTO training.
|
||||
"""
|
||||
|
||||
dpo_beta: float = field(
|
||||
pref_beta: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."},
|
||||
metadata={"help": "The beta parameter in the preference loss."},
|
||||
)
|
||||
dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
|
||||
pref_ftx: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||
)
|
||||
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
|
||||
default="sigmoid",
|
||||
metadata={"help": "The type of DPO loss to use."},
|
||||
)
|
||||
dpo_ftx: float = field(
|
||||
dpo_label_smoothing: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
|
||||
)
|
||||
kto_chosen_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the desirable losses in KTO training."},
|
||||
)
|
||||
kto_rejected_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
|
||||
)
|
||||
simpo_gamma: float = field(
|
||||
default=0.5,
|
||||
metadata={"help": "The target reward margin term in SimPO loss."},
|
||||
)
|
||||
ppo_buffer_size: int = field(
|
||||
default=1,
|
||||
@@ -106,10 +142,6 @@ class RLHFArguments:
|
||||
default=4,
|
||||
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
|
||||
)
|
||||
ppo_score_norm: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."},
|
||||
@@ -160,11 +192,16 @@ class GaloreArguments:
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use gradient low-Rank projection."},
|
||||
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
|
||||
)
|
||||
galore_target: str = field(
|
||||
default="mlp,attn",
|
||||
metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
|
||||
default="all",
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
galore_rank: int = field(
|
||||
default=16,
|
||||
@@ -189,7 +226,60 @@ class GaloreArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
Arguments pertaining to the BAdam optimizer.
|
||||
"""
|
||||
|
||||
use_badam: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the BAdam optimizer."},
|
||||
)
|
||||
badam_mode: Literal["layer", "ratio"] = field(
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
},
|
||||
)
|
||||
badam_update_ratio: float = field(
|
||||
default=0.05,
|
||||
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
||||
)
|
||||
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
||||
default="adjacent",
|
||||
metadata={
|
||||
"help": (
|
||||
"The mode of the mask for BAdam optimizer. "
|
||||
"`adjacent` means that the trainable parameters are adjacent to each other, "
|
||||
"`scatter` means that trainable parameters are randomly choosed from the weight."
|
||||
)
|
||||
},
|
||||
)
|
||||
badam_verbose: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": (
|
||||
"The verbosity level of BAdam optimizer. "
|
||||
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
@@ -198,7 +288,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||
)
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
)
|
||||
@@ -210,6 +300,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||
)
|
||||
freeze_vision_tower: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
||||
)
|
||||
train_mm_proj_only: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
|
||||
)
|
||||
plot_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the training loss curves."},
|
||||
@@ -221,37 +319,40 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.name_module_trainable = split_arg(self.name_module_trainable)
|
||||
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
self.galore_target = split_arg(self.galore_target)
|
||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
||||
|
||||
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
|
||||
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
||||
|
||||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
||||
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
r"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
||||
|
||||
return cls(**json.loads(text))
|
||||
if self.train_mm_proj_only and self.finetuning_type != "full":
|
||||
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -31,11 +31,11 @@ class GeneratingArguments:
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
|
||||
)
|
||||
max_length: int = field(
|
||||
default=512,
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
||||
)
|
||||
max_new_tokens: int = field(
|
||||
default=512,
|
||||
default=1024,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||
)
|
||||
repetition_penalty: float = field(
|
||||
@@ -46,6 +46,10 @@ class GeneratingArguments:
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
|
||||
)
|
||||
default_system: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Default system message to use in chat completion."},
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
@@ -15,14 +15,19 @@ class ModelArguments:
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to the adapter weight or identifier from huggingface.co/models. "
|
||||
"Use commas to separate multiple adapters."
|
||||
)
|
||||
},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=False,
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||
)
|
||||
resize_vocab: bool = field(
|
||||
@@ -33,6 +38,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
new_special_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
@@ -53,22 +62,38 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in int4 training."},
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
flash_attn: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."},
|
||||
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
|
||||
default="auto",
|
||||
metadata={"help": "Enable FlashAttention for faster training and inference."},
|
||||
)
|
||||
shift_attn: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
use_unsloth: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
visual_inputs: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
)
|
||||
disable_gradient_checkpointing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable gradient checkpointing."},
|
||||
@@ -81,13 +106,17 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
|
||||
)
|
||||
train_from_scratch: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to randomly initialize the model weights."},
|
||||
)
|
||||
infer_backend: Literal["huggingface", "vllm"] = field(
|
||||
default="huggingface",
|
||||
metadata={"help": "Backend engine used at inference."},
|
||||
)
|
||||
vllm_maxlen: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Maximum input length of the vLLM engine."},
|
||||
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
|
||||
)
|
||||
vllm_gpu_util: float = field(
|
||||
default=0.9,
|
||||
@@ -97,6 +126,14 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
||||
)
|
||||
vllm_max_lora_rank: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations in the vLLM engine."},
|
||||
)
|
||||
offload_folder: str = field(
|
||||
default="offload",
|
||||
metadata={"help": "Path to offload model weights."},
|
||||
@@ -121,6 +158,10 @@ class ModelArguments:
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_device: Literal["cpu", "auto"] = field(
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
@@ -158,9 +199,15 @@ class ModelArguments:
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.visual_inputs and self.use_unsloth:
|
||||
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
|
||||
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
if self.new_special_tokens is not None: # support multiple special tokens
|
||||
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
|
||||
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
||||
|
||||
@@ -6,13 +6,14 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies
|
||||
from ..extras.packages import is_unsloth_available
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -64,10 +65,16 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if model_args.resize_vocab:
|
||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
|
||||
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||
|
||||
@@ -75,6 +82,35 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["Seq2SeqTrainingArguments"] = None,
|
||||
) -> None:
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
require_version("badam", "To fix: pip install badam")
|
||||
|
||||
if finetuning_args.plot_loss:
|
||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||
|
||||
if training_args is not None and training_args.predict_with_generate:
|
||||
require_version("jieba", "To fix: pip install jieba")
|
||||
require_version("nltk", "To fix: pip install nltk")
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
@@ -119,21 +155,24 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if (
|
||||
finetuning_args.stage == "ppo"
|
||||
and training_args.report_to
|
||||
and training_args.report_to[0] not in ["wandb", "tensorboard"]
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
|
||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||
raise ValueError("Cannot use device map for quantized models in training.")
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if model_args.quantization_bit is not None:
|
||||
require_version("peft>=0.9.1.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
|
||||
|
||||
if model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
@@ -149,18 +188,33 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_badam
|
||||
and finetuning_args.badam_mode == "layer"
|
||||
and training_args.parallel_mode.value == "distributed"
|
||||
):
|
||||
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||
|
||||
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
if model_args.visual_inputs and data_args.packing:
|
||||
raise ValueError("Cannot use packing in MLLM fine-tuning.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and model_args.quantization_bit is None
|
||||
and model_args.resize_vocab
|
||||
and finetuning_args.additional_target is None
|
||||
):
|
||||
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
|
||||
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
@@ -202,16 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
if last_checkpoint is None and any(
|
||||
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
|
||||
):
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args.resume_from_checkpoint = last_checkpoint
|
||||
logger.info(
|
||||
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
|
||||
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
|
||||
|
||||
if (
|
||||
finetuning_args.stage in ["rm", "ppo"]
|
||||
@@ -230,10 +283,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
elif training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
|
||||
model_args.device_map = {"": get_current_device()}
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
|
||||
|
||||
# Log on each process the small summary:
|
||||
# Log on each process the small summary
|
||||
logger.info(
|
||||
"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank,
|
||||
@@ -261,18 +315,25 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
if finetuning_args.stage != "sft":
|
||||
raise ValueError("vLLM engine only supports auto-regressive models.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("vLLM engine does not support quantization.")
|
||||
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
raise ValueError("vLLM engine does not support RoPE scaling.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
|
||||
|
||||
model_args.device_map = "auto"
|
||||
if finetuning_args.stage == "rm" and model_args.visual_inputs:
|
||||
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
||||
model_args.device_map = {"": torch.device("cpu")}
|
||||
else:
|
||||
model_args.device_map = "auto"
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -289,6 +350,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
model_args.device_map = "auto"
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
from llamafactory.train.tuner import run_exp
|
||||
|
||||
|
||||
def launch():
|
||||
run_exp()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
launch()
|
||||
@@ -0,0 +1,12 @@
|
||||
from .loader import load_config, load_model, load_tokenizer
|
||||
from .model_utils.misc import find_all_linear_modules
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_config",
|
||||
"load_model",
|
||||
"load_tokenizer",
|
||||
"find_all_linear_modules",
|
||||
"load_valuehead_params",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user