Mirror of https://github.com/AIDotNet/AntSK.git, synced 2026-02-18 06:20:11 +08:00

Compare commits (144 commits)
| SHA1 |
|---|
| bf0e62634f |
| 469ef9aab2 |
| f7df26030d |
| f61cbe9780 |
| 56b62cff2a |
| 0aec21cf03 |
| ff4f6be5fc |
| 964a5022c8 |
| b8c6a6a626 |
| 57fc9a9b7e |
| 068c126a23 |
| 0ed1662c7b |
| a09377814f |
| 3cc952bb2a |
| 63c968742b |
| a4d6f2a6fd |
| d6de64853d |
| c7c1911eb1 |
| dec8b5bef7 |
| db7271b519 |
| fcbee1f64f |
| 2d9443a0a1 |
| 6e3dd00d6f |
| 7a0656cd81 |
| 3b89d9e974 |
| cd174308cf |
| aacef47626 |
| fcef01a41f |
| 7d19b694fa |
| 966a31b156 |
| 7538393742 |
| 3658188be2 |
| fcd9fb9079 |
| d1168e16d6 |
| 25a6f00dd2 |
| 13c474f084 |
| 8d78270007 |
| a94b59c156 |
| ae6d61ee6d |
| 74a7c94619 |
| bdd34ac786 |
| 0e754b4732 |
| 48a9fcfabf |
| 7ae2b9ac3b |
| a6bfe3c69b |
| 6b146f4750 |
| 1387e716d1 |
| fadc350047 |
| 8af4994a7d |
| b505b61bfe |
| 5d086c9383 |
| dd0e367dc8 |
| 7110eea912 |
| 154e67ef98 |
| 4a04423373 |
| 470ea50ebb |
| 2a30f3f221 |
| 315f58fdba |
| f027682175 |
| 2618dffcd6 |
| ee42d5870b |
| 27500b1e08 |
| 4d71a98724 |
| b8ba0ab391 |
| b589612913 |
| ec4b440469 |
| adbecb3b25 |
| 277aacc34d |
| dba98f7968 |
| a2b0f3f3c2 |
| b5e527afdb |
| e84b05a39b |
| 631c563e71 |
| 47b304e46f |
| 27ccfc5e88 |
| 69441167d3 |
| 4b97594217 |
| 9ee601f88c |
| 75b1a299e3 |
| 2613c463a1 |
| 78e0350d36 |
| bc2425fc3f |
| 3e861bc72f |
| b41c464753 |
| 2817275091 |
| e76b0cf326 |
| cf34103e15 |
| 1588fd7d7a |
| 4507ccde6c |
| 73fffd766f |
| e529146c5b |
| 7050e52009 |
| 4f3238c4f6 |
| b7d27c5d50 |
| fe94aa0564 |
| ae60a9aced |
| 4f686b0871 |
| c61840b7e8 |
| 9adce95367 |
| eef943458e |
| f5c80689d4 |
| 5eaee3130a |
| 5846473f28 |
| 94c019b484 |
| 7e1140c022 |
| ea9044719a |
| 8a96095448 |
| fcc8f8751b |
| af09ae7c3e |
| e8e6a36d7b |
| 4f89d54ef0 |
| 2f9e2fb114 |
| b6098024b8 |
| 1700131066 |
| 189536471a |
| f534e0bcc3 |
| e203a18e92 |
| 575a69bf4d |
| 69fd3a0367 |
| 8f7e70298e |
| 0fa3f5a554 |
| f420012752 |
| c1ca916549 |
| dfcf2bdc85 |
| 72e7acfb7d |
| 9d06c127dc |
| 0460b388ab |
| 45b84ae898 |
| 41b1cb6f2d |
| 159aaab38e |
| dc351238f6 |
| e6491b39c6 |
| 91b4ed8940 |
| ab99098afd |
| d14ce2faa0 |
| ca293691a8 |
| cf8955b9b6 |
| 512828fdc9 |
| 91299a96e7 |
| 4876d9e727 |
| a856f2a0e3 |
| 0e8113e7b0 |
| 34a953589d |
| 504ea5a238 |
**.gitignore** (vendored, 3 lines changed)
```diff
@@ -337,6 +337,9 @@ ASALocalRun/
 /AntSK/appsettings.Development.json
 /AntSK.db
 **/tmp-memory-files/*
 **/tmp-memory-vectors/*
 /src/AntSK/AntSK.db
 /src/AntSK/appsettings.Development.json
+/src/AntSK.db
+/src/AntSK/llama_models
+/src/AntSK/AntSK.xml
```
**README.en.md** (118 lines changed)

````diff
@@ -1,41 +1,42 @@
 [简体中文](./README.md) | English

 # AntSK

-## Based on AI knowledge base/agent created by Net8+AntBlazor+SemanticKernel
+## An AI knowledge base/intelligent agent built with .Net 8+AntBlazor+SemanticKernel

-## Core functions
+## Core Features

-- **Semantic Kernel**: It uses advanced natural language processing technology to accurately understand, process and respond to complex semantic queries, and provides users with accurate information retrieval and recommendation services.
+- **Semantic Kernel**: Utilizes advanced natural language processing technologies to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.

-- **Kernel Memory**: It has the ability to continuously learn and store knowledge points. AntSK has a long-term memory function to accumulate experience and provide a more personalized interactive experience.
+- **Kernel Memory**: Capable of continuous learning and knowledge storage, AntSK has a long-term memory function, accumulating experience to offer more personalized interaction experiences.

-- **Knowledge base**: Knowledge base documents can be created by importing documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and other forms.
+- **Knowledge base**: Import knowledge into the database through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and manage knowledge base documents.

-- **API plug-in system**: an open API plug-in system that allows third-party developers or service providers to easily integrate their services into AntSK and continuously enhance application functions.
-
-- **Online search**: AntSK can obtain the latest information in real time to ensure that the information received by users is always the most timely and relevant.
-
-- **GPTs generation**: This platform supports the creation of personalized GPT models; try building your own GPT models.
-
-- **API interface publishing**: internal functions are provided externally in the form of APIs, so that developers can easily integrate AntSK into other applications to enhance application intelligence.
-
-- **Model management**: Adapt and manage different models from different vendors.
+- **GPTs Generation**: The platform supports the creation of personalized GPT models; try building your own GPT model.
+
+- **API Interface Release**: Internal functions are provided as APIs for developers to integrate AntSK into other applications, enhancing application intelligence.
+
+- **API Plugin System**: An open API plugin system allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functions.
+
+- **.Net Plugin System**: An open dll plugin system allows third-party developers or service providers to integrate their business functions into AntSK by generating dlls from standard-format code, continuously enhancing application functions.
+
+- **Internet Search**: AntSK can retrieve the latest information in real time, ensuring that the information users receive is always timely and relevant.
+
+- **Model management**: Adapts and manages different models from various vendors. It also supports offline running of models in 'gguf' format supported by llama.cpp.
+
+- **National Information Creation (Xinchuang)**: AntSK supports domestic models and databases, and can operate under Xinchuang conditions.
@@ -43,11 +44,11 @@
-AntSK is applicable to a variety of business scenarios, such as:
+AntSK is suitable for various business scenarios, such as:

-- Enterprise level knowledge management system
+- Corporate knowledge management systems

-- Automatic customer service and chat robot
+- Automated customer service and chatbots

 - Enterprise Search Engine
@@ -57,7 +58,7 @@ AntSK is applicable to a variety of business scenarios, such as:
 - Education and online learning platform

-- Other interesting AI Apps
+- Other interesting AI applications
@@ -115,25 +116,71 @@ Let's see the effect
 ## How do I get started?

-Here I use Postgres as data storage and vector storage, because both the Semantic Kernel and Kernel Memory support it. Of course, you can switch to other ones.
+Here I am using Postgres as a data and vector store, because the Semantic Kernel and Kernel Memory both support it, though you can switch to others.

-The model supports openai by default. If you need to use azure openai, the dependency injection of SK needs adjusting; you can also use one-api for integration.
+The model by default supports openai, azure openai, and local models in 'gguf' format supported by llama. If you need to use other models, you can integrate them using one-api.

-Login is the default login account and password
+The login configuration in the configuration file is the default account and password.

-The following configuration files need to be configured
+The following configuration files are needed:

 ## Using Docker Compose
-An appsettings.json for the pg version and a simplified version (Sqlite+disk) docker-compose.simple.yml are provided.
-
-Download docker-compose.yml from the project root directory, then place the configuration file appsettings.json in the same directory.
-
-The pg image has already been prepared. You can modify the default account and password in docker-compose.yml; your appsettings.json database connection needs to be consistent with it.
-
-Then you can enter the directory and execute
+An appsettings.json for the pg version and a simplified version (Sqlite+disk) docker-compose.simple.yml are provided.
+
+Download docker-compose.yml from the project root directory, then place the configuration file appsettings.json in the same directory.
+
+The pg image has already been prepared. You can modify the default account and password in docker-compose.yml; your appsettings.json database connection needs to be consistent with it.
+
+Then you can enter the directory and execute
 ```
 docker compose up -d
 ```
-To start AntSK
+to start AntSK.

-Some meanings of configuration files
+How to mount local models and model download directories in docker
+
+```
+# Non-host version, does not use local proxy
+version: '3.8'
+services:
+  antsk:
+    container_name: antsk
+    image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.1
+    ports:
+      - 5000:5000
+    networks:
+      - antsk
+    depends_on:
+      - antskpg
+    restart: always
+    environment:
+      - ASPNETCORE_URLS=http://*:5000
+    volumes:
+      - ./appsettings.json:/app/appsettings.json # local configuration file must be placed in the same directory
+      - D://model:/app/model
+networks:
+  antsk:
+```
+Using this as an example, the meaning is to mount the local Windows folder D://model into the container at /app/model. If so, your appsettings.json model directory should be configured as
+```
+model/xxx.gguf
+```
+Some meanings of the configuration file
+
+Solving the missing style issue: execute under AntSK/src/AntSK:
+```
+dotnet clean
+dotnet build
+dotnet publish "AntSK.csproj"
+```
+Then go to AntSK/src/AntSK/bin/Release/net8.0/publish and run
+```
+dotnet AntSK.dll
+```
@@ -142,12 +189,6 @@ Some meanings of configuration files
   "DbType": "Sqlite",
   "ConnectionStrings": "Data Source=AntSK.db;"
 },
-"OpenAIOption": {
-  "EndPoint": "http://localhost:5000/llama/",
-  "Key": "NotNull",
-  "Model": "gpt4-turbo",
-  "EmbeddingModel": "text-embedding-ada-002"
-},
 "KernelMemory": {
   "VectorDb": "Disk",
   "ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
@@ -156,7 +197,8 @@ Some meanings of configuration files
 "LLamaSharp": {
   "RunType": "GPU",
   "Chat": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
-  "Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf"
+  "Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
+  "FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
 },
 "Login": {
   "User": "admin",
@@ -176,10 +218,6 @@ Some meanings of configuration files
 DBConnection.DbType
 //Connection string; use the string matching the configured DB type
 DBConnection.ConnectionStrings
 //You can use an online API that conforms to the OpenAI format (domestic models via a one-api adapter), or AntSK's built-in llama API, with the IP and port being the AntSK startup address
 OpenAIOption.EndPoint
 //Model key; if using a local model it can stay at the default NotNull. Chinese characters cannot be used here
 OpenAIOption.Key
 //The type of vector storage; supports Postgres, Disk, and Memory, where Postgres requires ConnectionString to be configured
 KernelMemory.VectorDb
 //The run mode used by the local model: GPU or CPU. If using an online API, either value will do
````
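The configuration keys above map onto typed option classes inside AntSK; the XML documentation further down references, for example, `AntSK.Domain.Options.DBConnectionOption.DbType`. As a minimal sketch of how such sections are typically bound in a .Net 8 host (the property shapes here are illustrative assumptions, not AntSK's actual classes):

```csharp
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;

// Hypothetical option shapes mirroring the appsettings.json keys documented above.
public class DBConnectionOption
{
    public string DbType { get; set; } = "Sqlite";
    public string ConnectionStrings { get; set; } = "Data Source=AntSK.db;";
}

public class LLamaSharpOption
{
    public string RunType { get; set; } = "CPU";    // "GPU" or "CPU"
    public string Chat { get; set; } = "";          // path to the chat .gguf model
    public string Embedding { get; set; } = "";     // path to the embedding .gguf model
    public string FileDirectory { get; set; } = ""; // where downloaded models are kept
}

public static class Program
{
    public static void Main(string[] args)
    {
        var builder = WebApplication.CreateBuilder(args);
        // Bind each appsettings.json section to its option class.
        builder.Services.Configure<DBConnectionOption>(builder.Configuration.GetSection("DBConnection"));
        builder.Services.Configure<LLamaSharpOption>(builder.Configuration.GetSection("LLamaSharp"));
        var app = builder.Build();
        app.Run();
    }
}
```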
**README.md** (42 lines changed)
````diff
@@ -10,15 +10,19 @@
 - **Knowledge base**: Import knowledge through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and manage knowledge base documents.

-- **API Plugin System**: An open API plugin system that lets third-party developers or service providers easily integrate their services into AntSK, continuously enhancing application functions.
-
-- **Internet Search**: AntSK retrieves the latest information in real time, ensuring the material users receive is always the most timely and relevant.
-
 - **GPTs Generation**: The platform supports creating personalized GPT models; try building your own GPT model.

 - **API Interface Release**: Internal functions are exposed as APIs, making it easy for developers to integrate AntSK into other applications and enhance application intelligence.

-- **Model management**: Adapts and manages different models from different vendors.
+- **API Plugin System**: An open API plugin system that lets third-party developers or service providers easily integrate their services into AntSK, continuously enhancing application functions.
+
+- **.Net Plugin System**: An open dll plugin system that lets third-party developers or service providers integrate their business functions into AntSK by generating dlls from standard-format code, continuously enhancing application functions.
+
+- **Internet Search**: AntSK retrieves the latest information in real time, ensuring the material users receive is always the most timely and relevant.
+
+- **Model management**: Adapts and manages different models from different vendors, and supports offline running of the gguf models supported by llama.cpp as well as the models supported by llamafactory.
+
+- **National Information Creation (Xinchuang)**: AntSK supports domestic models and domestic databases, and can run under Xinchuang conditions.
@@ -33,6 +37,16 @@ AntSK is suitable for various business scenarios, for example:
 ## Feature examples

+### Online demo
+```
+https://antsk.ai-dotnet.com/
+```
+Default account: admin
+
+Default password: xuzeyu
+
+### Other feature examples
 [Video demo](https://www.bilibili.com/video/BV1zH4y1h7Y9/)

 First, a knowledge base needs to be created
@@ -63,9 +77,9 @@
 Here I use Postgres for data storage and vector storage, because both Semantic Kernel and Kernel Memory support it; of course you can also switch to something else.

-The model supports openai by default; using azure openai requires adjusting SK's dependency injection, and one-api can also be used for integration.
+The model supports openai, azure openai, and the gguf local models supported by llama by default; to use other models, one-api can be used for integration.

-Login is the default login account and password
+The Login section in the configuration file is the default login account and password

 The following configuration file needs to be configured
@@ -83,7 +97,7 @@ docker-compose up -d
 ```
 to start AntSK

-## How to mount local models in docker
+## How to mount local models, and the model download directory, in docker
 ```
 # Non-host version, does not use the local proxy
 version: '3.8'
@@ -126,7 +140,8 @@ model/xxx.gguf
 "LLamaSharp": {
   "RunType": "GPU",
   "Chat": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
-  "Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf"
+  "Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
+  "FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
 },
 "Login": {
   "User": "admin",
@@ -144,18 +159,19 @@ model/xxx.gguf
 DBConnection.DbType
 //Connection string; use the string matching the configured DB type
 DBConnection.ConnectionStrings
 //You can use an online API conforming to the openai format (domestic models via a one-api adapter), or AntSK's built-in llama api; the ip and port are the AntSK startup address
 OpenAIOption.EndPoint
 //Model key; with a local model it can stay at the default NotNull (Chinese characters cannot be used here)
 OpenAIOption.Key
 //Vector storage type; supports Postgres, Disk, and Memory, where Postgres requires ConnectionString to be configured
 KernelMemory.VectorDb
 //Run mode used by the local model: GPU or CPU; with an online API either value will do
 LLamaSharp.RunType
+//Model path of the local chat model; note the different drive/path conventions on linux and windows
+LLamaSharp.Chat
+//Model path of the local embedding model; note the different drive/path conventions on linux and windows
+LLamaSharp.Embedding
+//Local model directory, used to quickly pick a model from the directory when llama is selected, and to save downloaded models
+LLamaSharp.FileDirectory
 //Default administrator account and password
 Login
 //Thread count for asynchronous import; it can be higher with an online API, but 1 is recommended for local models, otherwise memory overflow will crash the process
````
```diff
@@ -3,7 +3,7 @@ version: '3.8'
 services:
   antsk:
     container_name: antsk
-    image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.1.6.2
+    image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.1
     ports:
       - 5000:5000
     networks:
```

```diff
@@ -18,7 +18,7 @@ services:
       - ./pg/data:/var/lib/postgresql/data
   antsk:
     container_name: antsk
-    image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.1.6.2
+    image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.2.1
     ports:
       - 5000:5000
     networks:
```
```diff
@@ -9,21 +9,24 @@
   </PropertyGroup>
   <ItemGroup>
     <PackageReference Include="AntDesign.Charts" Version="0.5.1" />
-    <PackageReference Include="AntDesign.ProLayout" Version="0.17.3" />
+    <PackageReference Include="AntDesign.ProLayout" Version="0.18.0" />
     <PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />

     <PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />

     <PackageReference Include="AutoMapper" Version="8.1.0" />
     <PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
-    <PackageReference Include="MarkdownSharp" Version="2.0.5" />
+    <PackageReference Include="Markdig" Version="0.36.2" />
     <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
-    <PackageReference Include="SqlSugarCore" Version="5.1.4.143" />
+    <PackageReference Include="SqlSugarCore" Version="5.1.4.145" />
     <PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
     <PackageReference Include="RestSharp" Version="110.2.0" />

-    <PackageReference Include="Microsoft.SemanticKernel" Version="1.4.0" />
-    <PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.4.0" />
-    <PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.4.0-alpha" />
-    <PackageReference Include="Microsoft.KernelMemory.Core" Version="0.30.240227.1" />
-    <PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.30.240227.1" />
+    <PackageReference Include="Microsoft.SemanticKernel" Version="1.6.2" />
+    <PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.6.2" />
+    <PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.6.2-alpha" />
+    <PackageReference Include="Microsoft.KernelMemory.Core" Version="0.34.240313.1" />
+    <PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.34.240313.1" />

     <PackageReference Include="LLamaSharp" Version="0.10.0" />
     <PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.10.0" />
```

```diff
@@ -34,6 +37,8 @@
   </ItemGroup>
   <ItemGroup>
+    <ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />
     <ProjectReference Include="..\AntSk.LLM\AntSK.LLM.csproj" />
     <ProjectReference Include="..\MiddleWare\AntSK.BackgroundTask\AntSK.BackgroundTask.csproj" />
   </ItemGroup>
```
```diff
@@ -17,6 +17,27 @@
     <param name="assemblies">Assembly collection</param>
     <returns></returns>
 </member>
+<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.CodeFirst(Microsoft.AspNetCore.Builder.WebApplication)">
+    <summary>
+    Create database tables via code-first
+    </summary>
+    <param name="services"></param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.LoadFun(Microsoft.AspNetCore.Builder.WebApplication)">
+    <summary>
+    Load the plugins stored in the database
+    </summary>
+    <param name="services"></param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.AddAntSKSwagger(Microsoft.Extensions.DependencyInjection.IServiceCollection)">
+    <summary>
+    Swagger initialization
+    </summary>
+    <param name="serviceCollection"></param>
+    <returns></returns>
+</member>
 <member name="F:AntSK.Domain.Common.DependencyInjection.ServiceLifetime.Scoped">
     <summary>
     Scoped
@@ -48,6 +69,36 @@
     <param name="value"></param>
     <returns></returns>
 </member>
+<member name="T:AntSK.Domain.Domain.Model.Enum.AIType">
+    <summary>
+    AI type
+    </summary>
+</member>
+<member name="T:AntSK.Domain.Domain.Model.Enum.AIModelType">
+    <summary>
+    Model type
+    </summary>
+</member>
+<member name="P:AntSK.Domain.Domain.Model.MessageInfo.IsSend">
+    <summary>
+    true when sending, false when receiving
+    </summary>
+</member>
+<member name="P:AntSK.Domain.Domain.Model.PageList`1.PageIndex">
+    <summary>
+    Current page, starting from 1
+    </summary>
+</member>
+<member name="P:AntSK.Domain.Domain.Model.PageList`1.PageSize">
+    <summary>
+    Items per page
+    </summary>
+</member>
+<member name="P:AntSK.Domain.Domain.Model.PageList`1.TotalCount">
+    <summary>
+    Total count
+    </summary>
+</member>
 <member name="F:AntSK.Domain.Domain.Other.LLamaConfig.dicLLamaWeights">
     <summary>
     Local cache to avoid loading a model repeatedly
@@ -62,6 +113,11 @@
     <param name="history"></param>
     <returns></returns>
 </member>
+<member name="M:AntSK.Domain.Domain.Service.FunctionService.SearchMarkedMethods">
+    <summary>
+    Scan assemblies for method delegates; later to be generated via Source Generators
+    </summary>
+</member>
 <member name="M:AntSK.Domain.Domain.Service.KernelService.GetKernelByApp(AntSK.Domain.Repositories.Apps)">
     <summary>
     Get a kernel instance; with dependency injection it is hard to import different plugins per user, so a new kernel is created each time
@@ -77,6 +133,20 @@
     <param name="app"></param>
     <param name="_kernel"></param>
 </member>
+<member name="M:AntSK.Domain.Domain.Service.KernelService.ImportApiFunction(AntSK.Domain.Repositories.Apps,System.Collections.Generic.List{Microsoft.SemanticKernel.KernelFunction})">
+    <summary>
+    Import API plugins
+    </summary>
+    <param name="app"></param>
+    <param name="functions"></param>
+</member>
+<member name="M:AntSK.Domain.Domain.Service.KernelService.ImportNativeFunction(AntSK.Domain.Repositories.Apps,System.Collections.Generic.List{Microsoft.SemanticKernel.KernelFunction})">
+    <summary>
+    Import native plugins
+    </summary>
+    <param name="app"></param>
+    <param name="functions"></param>
+</member>
 <member name="M:AntSK.Domain.Domain.Service.KernelService.RegisterPluginsWithKernel(Microsoft.SemanticKernel.Kernel)">
     <summary>
     Register default plugins
@@ -92,61 +162,6 @@
     <param name="history"></param>
     <returns></returns>
 </member>
-<member name="M:AntSK.Domain.Map.MapperExtend.ToDTOList``1(System.Object)">
-    <summary>
-    Convert an Entity collection to a DTO collection
-    </summary>
-    <typeparam name="T"></typeparam>
-    <param name="value"></param>
-    <returns></returns>
-</member>
-<member name="M:AntSK.Domain.Map.MapperExtend.ToDTO``1(System.Object)">
-    <summary>
-    Convert an Entity to a DTO
-    </summary>
-    <typeparam name="T"></typeparam>
-    <param name="value"></param>
-    <returns></returns>
-</member>
-<member name="M:AntSK.Domain.Map.MapperExtend.MapTo``1(System.Object,``0)">
-    <summary>
-    Map onto an existing object, suitable for update scenarios; to filter out null values, configure it in AutoMapProfile
-    </summary>
-    <typeparam name="T"></typeparam>
-    <param name="self"></param>
-    <param name="result"></param>
-    <returns></returns>
-</member>
-<member name="T:AntSK.Domain.Model.Enum.AIType">
-    <summary>
-    AI type
-    </summary>
-</member>
-<member name="T:AntSK.Domain.Model.Enum.AIModelType">
-    <summary>
-    Model type
-    </summary>
-</member>
-<member name="P:AntSK.Domain.Model.MessageInfo.IsSend">
-    <summary>
-    true when sending, false when receiving
-    </summary>
-</member>
-<member name="P:AntSK.Domain.Model.PageList`1.PageIndex">
-    <summary>
-    Current page, starting from 1
-    </summary>
-</member>
-<member name="P:AntSK.Domain.Model.PageList`1.PageSize">
-    <summary>
-    Items per page
-    </summary>
-</member>
-<member name="P:AntSK.Domain.Model.PageList`1.TotalCount">
-    <summary>
-    Total count
-    </summary>
-</member>
 <member name="P:AntSK.Domain.Options.DBConnectionOption.DbType">
     <summary>
     sqlite connection string
@@ -237,6 +252,11 @@
     Chat model ID
     </summary>
 </member>
+<member name="P:AntSK.Domain.Repositories.Apps.EmbeddingModelID">
+    <summary>
+    Embedding model Id
+    </summary>
+</member>
 <member name="P:AntSK.Domain.Repositories.Apps.Temperature">
     <summary>
     Temperature
@@ -252,6 +272,11 @@
     Plugin list
     </summary>
 </member>
+<member name="P:AntSK.Domain.Repositories.Apps.NativeFunctionList">
+    <summary>
+    Native function list
+    </summary>
+</member>
 <member name="P:AntSK.Domain.Repositories.Apps.KmsIdList">
     <summary>
     Knowledge base ID list
@@ -262,6 +287,11 @@
     API call secret key
     </summary>
 </member>
+<member name="P:AntSK.Domain.Repositories.Funs.Path">
+    <summary>
+    Interface description
+    </summary>
+</member>
 <member name="P:AntSK.Domain.Repositories.KmsDetails.FileName">
     <summary>
     File name
@@ -734,11 +764,143 @@
     <param name="stream"></param>
     <returns></returns>
 </member>
 <member name="M:AntSK.Domain.Utils.ConvertUtils.ToQueryString(System.Collections.Generic.Dictionary{System.String,System.String})">
     <summary>
     Convert json parameters to querystring parameters
     </summary>
     <param name="parameters"></param>
     <returns></returns>
 </member>
 <member name="M:AntSK.Domain.Utils.RepoFiles.SamplePluginsPath">
     <summary>
     Scan the local folders from the repo, looking for "samples/plugins" folder.
     </summary>
     <returns>The full path to samples/plugins</returns>
 </member>
+<member name="T:AntSK.Domain.Utils.XmlCommentHelper">
+    <summary>
+    Comment helper class
+    </summary>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.LoadAll">
+    <summary>
+    Load all xml doc files next to the current dll files
+    </summary>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.LoadXml(System.String[])">
+    <summary>
+    Load from xml strings
+    </summary>
+    <param name="xmls"></param>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.Load(System.String[])">
+    <summary>
+    Load from files
+    </summary>
+    <param name="xmlFiles"></param>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.Load(System.IO.Stream[])">
+    <summary>
+    Load from streams
+    </summary>
+    <param name="streams"></param>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetTypeComment(System.Type,System.String,System.Boolean)">
+    <summary>
+    Read the comment on a type
+    </summary>
+    <param name="type">Type</param>
+    <param name="xPath">Comment path</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetFieldOrPropertyComment(System.Reflection.MemberInfo,System.String,System.Boolean)">
+    <summary>
+    Read the comment on a field or property
+    </summary>
+    <param name="fieldOrPropertyInfo">Field or property</param>
+    <param name="xPath">Comment path</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMethodComment(System.Reflection.MethodInfo,System.String,System.Boolean)">
+    <summary>
+    Read the comment on a method
+    </summary>
+    <param name="methodInfo">Method</param>
+    <param name="xPath">Comment path</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMethodReturnComment(System.Reflection.MethodInfo,System.Boolean)">
+    <summary>
+    Read the return-value comment of a method
+    </summary>
+    <param name="methodInfo">Method</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetParameterComment(System.Reflection.ParameterInfo,System.Boolean)">
+    <summary>
+    Read a parameter's comment
+    </summary>
+    <param name="parameterInfo">Parameter</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetParameterComments(System.Reflection.MethodInfo,System.Boolean)">
+    <summary>
+    Read the comments of all parameters of a method
+    </summary>
+    <param name="methodInfo">Method</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetComment(System.String,System.String,System.Boolean)">
+    <summary>
+    Read the comment of the node with the given name
+    </summary>
+    <param name="name">Node name</param>
+    <param name="xPath">Comment path</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetSummary(System.String,System.Boolean)">
+    <summary>
+    Read the summary comment of the given node
+    </summary>
+    <param name="name">Node name</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetExample(System.String,System.Boolean)">
+    <summary>
+    Read the example comment of the given node
+    </summary>
+    <param name="name">Node name</param>
+    <param name="humanize">Readability optimization (e.g., strip xml markup)</param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForMethod(System.Reflection.MethodInfo)">
+    <summary>
+    Get the node name for a method
+    </summary>
+    <param name="method"></param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForType(System.Type)">
+    <summary>
+    Get the node name for a type
+    </summary>
+    <param name="type"></param>
+    <returns></returns>
+</member>
+<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForFieldOrProperty(System.Reflection.MemberInfo)">
+    <summary>
+    Get the node name for a field or property
+    </summary>
+    <param name="fieldOrPropertyInfo"></param>
+    <returns></returns>
+</member>
 </members>
 </doc>
```
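The XmlCommentHelper members added above read XML doc comments back at runtime. A rough usage sketch following the signatures documented above; whether the helper is instantiated or static, and the exact xPath values, are assumptions here, and `SomeService` is a hypothetical type:

```csharp
using System;
using System.Reflection;
using AntSK.Domain.Utils;

// Assumed instance usage; signatures follow the XML documentation above.
var helper = new XmlCommentHelper();
helper.LoadAll(); // load all XML doc files that ship next to the current dlls

MethodInfo method = typeof(SomeService).GetMethod("SomeMethod")!;
string summary = helper.GetMethodComment(method, "summary", true); // humanize: strip xml markup
string returns = helper.GetMethodReturnComment(method, true);
var paramDocs = helper.GetParameterComments(method, true);
Console.WriteLine(summary);
```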
**src/AntSK.Domain/Common/AntSkFunctionAttribute.cs** (new file, 8 lines)

```diff
@@ -0,0 +1,8 @@
+namespace AntSK.Domain.Common
+{
+    [AttributeUsage(AttributeTargets.Method)]
+    public class AntSkFunctionAttribute : Attribute
+    {
+        // Custom action attribute
+    }
+}
```
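For context, FunctionService (further down in this change set) discovers plugin methods not by this attribute but by a DescriptionAttribute whose text contains "AntSK". A minimal sketch of what a method in a .Net plugin dll might look like under that convention; the method itself is hypothetical:

```csharp
using System.ComponentModel;

public class WeatherPlugin
{
    // The "AntSK:" prefix is stripped by FunctionService when it builds the
    // function description; return and parameter descriptions are read the same way.
    [Description("AntSK: Query the weather for a city")]
    [return: Description("A human-readable weather report")]
    public static string GetWeather([Description("City name")] string city)
    {
        return $"Sunny in {city}"; // hypothetical stub
    }
}
```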
**src/AntSK.Domain/Common/DependencyInjection/InitExtensions.cs** (new file, 164 lines)

```diff
@@ -0,0 +1,164 @@
+using AntSK.Domain.Domain.Model.Constant;
+using AntSK.Domain.Domain.Service;
+using AntSK.Domain.Repositories;
+using DocumentFormat.OpenXml.Office2016.Drawing.ChartDrawing;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.OpenApi.Models;
+using SqlSugar;
+using Swashbuckle.AspNetCore.SwaggerGen;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Common.DependencyInjection
+{
+    public static class InitExtensions
+    {
+        /// <summary>
+        /// Create database tables via code-first
+        /// </summary>
+        /// <param name="services"></param>
+        /// <returns></returns>
+        public static WebApplication CodeFirst(this WebApplication app)
+        {
+            using (var scope = app.Services.CreateScope())
+            {
+                // Get the repository service
+                var _repository = scope.ServiceProvider.GetRequiredService<IApps_Repositories>();
+
+                // Create the database (if it does not exist)
+                _repository.GetDB().DbMaintenance.CreateDatabase();
+
+                // Get all assemblies in the current AppDomain
+                var assemblies = AppDomain.CurrentDomain.GetAssemblies();
+
+                // Find classes marked with [SugarTable] in all assemblies
+                foreach (var assembly in assemblies)
+                {
+                    // Get all types in the assembly that carry the SugarTable attribute
+                    var entityTypes = assembly.GetTypes()
+                        .Where(type => TypeIsEntity(type));
+
+                    // Initialize a database table for each type found
+                    foreach (var type in entityTypes)
+                    {
+                        _repository.GetDB().CodeFirst.InitTables(type);
+                    }
+                }
+            }
+            return app;
+        }
+
+        public static WebApplication InitDbData(this WebApplication app)
+        {
+            using (var scope = app.Services.CreateScope())
+            {
+                // Initialize the dictionary entries
+                var _dic_Repository = scope.ServiceProvider.GetRequiredService<IDics_Repositories>();
+                var llamafactoryStart = _dic_Repository.GetFirst(p => p.Type == LLamaFactoryConstantcs.LLamaFactorDic && p.Key == LLamaFactoryConstantcs.IsStartKey);
+                if (llamafactoryStart == null)
+                {
+                    llamafactoryStart = new Dics();
+                    llamafactoryStart.Id = Guid.NewGuid().ToString();
+                    llamafactoryStart.Type = LLamaFactoryConstantcs.LLamaFactorDic;
+                    llamafactoryStart.Key = LLamaFactoryConstantcs.IsStartKey;
+                    llamafactoryStart.Value = "false";
+                    _dic_Repository.Insert(llamafactoryStart);
+                }
+            }
+            return app;
+        }
+
+        /// <summary>
+        /// Load the plugins stored in the database
+        /// </summary>
+        /// <param name="services"></param>
+        /// <returns></returns>
+        public static WebApplication LoadFun(this WebApplication app)
+        {
+            try
+            {
+                using (var scope = app.Services.CreateScope())
+                {
+                    // Load each registered plugin dll by its stored path
+                    var funRep = scope.ServiceProvider.GetRequiredService<IFuns_Repositories>();
+                    var functionService = scope.ServiceProvider.GetRequiredService<FunctionService>();
+                    var funs = funRep.GetList();
+                    foreach (var fun in funs)
+                    {
+                        functionService.FuncLoad(fun.Path);
+                    }
+                }
+            }
+            catch (Exception ex)
+            {
+                Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
+            }
+            return app;
+        }
+
+        private static bool TypeIsEntity(Type type)
+        {
+            // Check whether the type carries the SugarTable attribute
+            return type.GetCustomAttributes(typeof(SugarTable), inherit: false).Length > 0;
+        }
+
+        /// <summary>
+        /// Swagger initialization
+        /// </summary>
+        /// <param name="serviceCollection"></param>
+        /// <returns></returns>
+        public static IServiceCollection AddAntSKSwagger(this IServiceCollection serviceCollection)
+        {
+            serviceCollection.AddSwaggerGen(c =>
+            {
+                c.SwaggerDoc("v1", new() { Title = "AntSK.Api", Version = "v1" });
+                // Add Api-layer XML comments (true shows controller comments)
+                var xmlFile = $"{Assembly.GetExecutingAssembly().GetName().Name}.xml";
+                var xmlPath = Path.Combine(AppContext.BaseDirectory, xmlFile);
+                c.IncludeXmlComments(xmlPath, true);
+                // Add Domain-layer XML comments (true shows controller comments)
+                var xmlFile1 = $"{Assembly.GetExecutingAssembly().GetName().Name.Replace("Api", "Domain")}.xml";
+                var xmlPath1 = Path.Combine(AppContext.BaseDirectory, xmlFile1);
+                c.IncludeXmlComments(xmlPath1, true);
+                c.DocInclusionPredicate((docName, apiDes) =>
+                {
+                    if (!apiDes.TryGetMethodInfo(out MethodInfo method))
+                        return false;
+                    var version = method.DeclaringType.GetCustomAttributes(true).OfType<ApiExplorerSettingsAttribute>().Select(m => m.GroupName);
+                    if (docName == "v1" && !version.Any())
+                        return true;
+                    var actionVersion = method.GetCustomAttributes(true).OfType<ApiExplorerSettingsAttribute>().Select(m => m.GroupName);
+                    if (actionVersion.Any())
+                        return actionVersion.Any(v => v == docName);
+                    return version.Any(v => v == docName);
+                });
+                c.AddSecurityDefinition("Bearer", new OpenApiSecurityScheme()
+                {
+                    Description = "Directly enter bearer {token} in the box below (note that there is a space between bearer and token)",
+                    Name = "Authorization",
+                    In = ParameterLocation.Header,
+                    Type = SecuritySchemeType.ApiKey,
+                });
+                c.AddSecurityRequirement(new OpenApiSecurityRequirement
+                {
+                    {
+                        new OpenApiSecurityScheme
+                        {
+                            Reference = new OpenApiReference()
+                            {
+                                Id = "Bearer",
+                                Type = ReferenceType.SecurityScheme
+                            }
+                        }, Array.Empty<string>()
+                    }
+                });
+            });
+            return serviceCollection;
+        }
+    }
+}
```
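A sketch of how these startup extensions would chain in Program.cs; the exact call site in AntSK may differ:

```csharp
using AntSK.Domain.Common.DependencyInjection;
using Microsoft.AspNetCore.Builder;

var builder = WebApplication.CreateBuilder(args);
builder.Services.AddAntSKSwagger(); // Swagger with XML comments and Bearer auth

var app = builder.Build();
app.CodeFirst()    // create tables for every [SugarTable] entity
   .InitDbData()   // seed the llamafactory "isstart" dictionary entry
   .LoadFun();     // load previously imported .Net plugin dlls
app.Run();
```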
**src/AntSK.Domain/Common/LLamaFactory/ProcessWrapper.cs** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Common.LLamaFactory
+{
+    public class ProcessWrapper
+    {
+        private Process process;
+
+        public static bool isProcessComplete = false;
+
+        public void StartProcess(string arguments, string workingDirectory)
+        {
+            process = new Process
+            {
+                StartInfo = new ProcessStartInfo
+                {
+                    FileName = "python",
+                    Arguments = arguments,
+                    UseShellExecute = false,
+                    RedirectStandardOutput = true,
+                    RedirectStandardError = true,
+                    CreateNoWindow = true,
+                    WorkingDirectory = workingDirectory
+                }
+            };
+            using (Process start = Process.Start(process.StartInfo))
+            {
+                using (StreamReader reader = start.StandardOutput)
+                {
+                    string result = reader.ReadToEnd();
+                    if (result != null)
+                    {
+                        if (result.Contains(":8000"))
+                        {
+                            isProcessComplete = true;
+                        }
+                    }
+                    Console.WriteLine(result);
+                }
+                start.WaitForExit();
+            }
+        }
+
+        public string WaitForProcessExit()
+        {
+            process.WaitForExit();
+            return process.StandardOutput.ReadToEnd();
+        }
+
+        public void KillProcess()
+        {
+            try
+            {
+                if (!process.HasExited)
+                {
+                    process.Kill();
+                }
+            }
+            catch (InvalidOperationException)
+            {
+                // Process already exited.
+            }
+        }
+    }
+}
```
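ProcessWrapper shells out to python; StartProcess scans stdout for ":8000", the port a LLaMA-Factory API server would listen on. A hedged usage sketch; the script path and arguments are illustrative only:

```csharp
using System;
using AntSK.Domain.Common.LLamaFactory;

var wrapper = new ProcessWrapper();
// Blocks until the python process exits; stdout is scanned for ":8000"
// to flag that the API server came up.
wrapper.StartProcess("src/api_demo.py", "/opt/LLaMA-Factory");
if (ProcessWrapper.isProcessComplete)
{
    Console.WriteLine("LLaMA-Factory API is listening on :8000");
}
wrapper.KillProcess();
```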
```diff
@@ -1,16 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-
-namespace AntSK.Domain.Domain.Dto
-{
-    public class RelevantSource
-    {
-        public string SourceName { get; set; }
-
-        public string Text { get; set; }
-        public float Relevance { get; set; }
-    }
-}
```
```diff
@@ -1,4 +1,4 @@
-using AntSK.Domain.Domain.Dto;
+using AntSK.Domain.Domain.Model.Dto;
 using AntSK.Domain.Repositories;
 using Microsoft.SemanticKernel;
 using System;
@@ -13,6 +13,6 @@ namespace AntSK.Domain.Domain.Interface
 {
     IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, string history);

-    IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, List<RelevantSource> relevantSources = null);
+    IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, string filePath, List<RelevantSource> relevantSources = null);
 }
 }
```
```diff
@@ -1,4 +1,4 @@
-using AntSK.Domain.Model;
+using AntSK.Domain.Domain.Model;

 namespace AntSK.Domain.Domain.Interface
 {
```
```diff
@@ -1,11 +1,24 @@
-using AntSK.Domain.Domain.Dto;
+using AntDesign;
+using AntSK.Domain.Domain.Model.Dto;
 using AntSK.Domain.Repositories;
 using Microsoft.KernelMemory;

 namespace AntSK.Domain.Domain.Interface
 {
     public interface IKMService
     {
+        MemoryServerless GetMemory(Apps app);
+
         MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null);
-        Task<List<KMFile>> GetDocumentByFileID(string kmsid, string fileid);
+
+        Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId);
+
+        Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg);
+
+        List<UploadFileItem> FileList { get; }
+
+        bool BeforeUpload(UploadFileItem file);
+
+        void OnSingleCompleted(UploadInfo fileinfo);
     }
 }
```
**src/AntSK.Domain/Domain/Interface/ILLamaFactoryService.cs** (new file, 21 lines)

```diff
@@ -0,0 +1,21 @@
+using AntSK.LLamaFactory.Model;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static AntSK.Domain.Domain.Service.LLamaFactoryService;
+
+namespace AntSK.Domain.Domain.Interface
+{
+    public interface ILLamaFactoryService
+    {
+        public event LogMessageHandler LogMessageReceived;
+
+        Task PipInstall();
+
+        Task StartLLamaFactory(string modelName, string templateName);
+
+        void KillProcess();
+
+        List<LLamaModel> GetLLamaFactoryModels();
+    }
+}
```
**src/AntSK.Domain/Domain/Model/Constant/KmsConstantcs.cs** (new file, 15 lines)

```diff
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.Constant
+{
+    public class KmsConstantcs
+    {
+        public const string KmsIdTag = "kmsid";
+        public const string KmsIndex = "kms";
+        public const string KmsSearchNull = "知识库未搜索到相关内容"; // "No relevant content found in the knowledge base"
+    }
+}
```
```diff
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.Constant
+{
+    public class LLamaFactoryConstantcs
+    {
+        public const string LLamaFactorDic = "llamafactory";
+        public const string IsStartKey = "isstart";
+    }
+}
```
```diff
@@ -1,15 +1,14 @@
-namespace AntSK.Domain.Domain.Dto
+namespace AntSK.Domain.Domain.Model.Dto
 {
     public class KMFile
     {
         public string DocumentId { get; set; }
         public string Text { get; set; }

-        public string Url { get; set; }
+        public string? Url { get; set; }

         public string LastUpdate { get; set; }

         public string Schema { get; set; }

         public string File { get; set; }
     }
 }
```
```diff
@@ -1,4 +1,4 @@
-namespace AntSK.Domain.Domain.Dto
+namespace AntSK.Domain.Domain.Model.Dto.OpenAPI
 {
     public class OpenAIModel
     {
@@ -1,6 +1,6 @@
 using Newtonsoft.Json;

-namespace AntSK.Domain.Domain.Dto
+namespace AntSK.Domain.Domain.Model.Dto.OpenAPI
 {
     public class OpenAIResult
     {
```
**src/AntSK.Domain/Domain/Model/Dto/RelevantSource.cs** (new file, 16 lines)

```diff
@@ -0,0 +1,16 @@
+
+namespace AntSK.Domain.Domain.Model.Dto
+{
+    public class RelevantSource
+    {
+        public string SourceName { get; set; }
+
+        public string Text { get; set; }
+        public float Relevance { get; set; }
+
+        public override string ToString()
+        {
+            return $"[file:{SourceName};Relevance:{(Relevance * 100):F2}%]:{Text}";
+        }
+    }
+}
```
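The ToString override is what SendKmsByAppAsync (below) feeds into the prompt for each matched chunk. A quick check of the format:

```csharp
using System;
using AntSK.Domain.Domain.Model.Dto;

var source = new RelevantSource { SourceName = "manual.pdf", Text = "AntSK setup steps", Relevance = 0.8765f };
Console.WriteLine(source);
// prints: [file:manual.pdf;Relevance:87.65%]:AntSK setup steps
```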
**src/AntSK.Domain/Domain/Model/Enum/AIModelType.cs** (new file, 40 lines)

```diff
@@ -0,0 +1,40 @@
+using System.ComponentModel.DataAnnotations;
+
+namespace AntSK.Domain.Domain.Model.Enum
+{
+    /// <summary>
+    /// AI type
+    /// </summary>
+    public enum AIType
+    {
+        [Display(Name = "Open AI")]
+        OpenAI = 1,
+
+        [Display(Name = "Azure Open AI")]
+        AzureOpenAI = 2,
+
+        [Display(Name = "LLama本地模型")]  // LLama local model
+        LLamaSharp = 3,
+
+        [Display(Name = "星火大模型")]      // Spark large model
+        SparkDesk = 4,
+
+        [Display(Name = "灵积大模型")]      // DashScope large model
+        DashScope = 5,
+
+        [Display(Name = "LLamaFactory")]
+        LLamaFactory = 6,
+
+        [Display(Name = "模拟输出")]        // mock output
+        Mock = 100,
+    }
+
+    /// <summary>
+    /// Model type
+    /// </summary>
+    public enum AIModelType
+    {
+        Chat = 1,
+        Embedding = 2,
+    }
+}
```
**src/AntSK.Domain/Domain/Model/Enum/AppType.cs** (new file, 14 lines)

```diff
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.Enum
+{
+    public enum AppType
+    {
+        chat = 1,
+        kms = 2
+    }
+}
```
```diff
@@ -1,9 +1,8 @@
-namespace AntSK.Domain.Model
+namespace AntSK.Domain.Domain.Model.Enum
 {
     public enum HttpMethodType
     {
         Get = 1,
         Post = 2,
     }
 }
@@ -1,4 +1,4 @@
-namespace AntSK.Domain.Model.Enum
+namespace AntSK.Domain.Domain.Model.Enum
 {
     public enum ImportKmsStatus
     {
```
**src/AntSK.Domain/Domain/Model/Fun/FunDto.cs** (new file, 23 lines)

```diff
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.Fun
+{
+    public class FunDto
+    {
+        public string Name { get; set; }
+
+        public string Description { get; set; }
+
+        public FunType FunType { get; set; }
+    }
+
+    public enum FunType
+    {
+        System = 1,
+        Import = 2
+    }
+}
```
```diff
@@ -1,6 +1,6 @@
 using AntSK.Domain.Repositories;

-namespace AntSK.Domain.Model
+namespace AntSK.Domain.Domain.Model
 {
     public class ImportKMSTaskDTO
     {
@@ -1,4 +1,4 @@
-namespace AntSK.Domain.Model
+namespace AntSK.Domain.Domain.Model
 {
     public class MessageInfo
     {
@@ -10,7 +10,11 @@
     /// true when sending, false when receiving
     /// </summary>
     public bool IsSend { get; set; } = false;

     public DateTime CreateTime { get; set; }
+
+    public string? FilePath { get; set; }
+
+    public string? FileName { get; set; }
 }
 }
```
```diff
@@ -1,4 +1,4 @@
-namespace AntSK.Domain.Model
+namespace AntSK.Domain.Domain.Model
 {
     public class PageList<T>
     {
```
**src/AntSK.Domain/Domain/Model/hfmirror/HfModel.cs** (new file, 40 lines)

```diff
@@ -0,0 +1,40 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.hfmirror
+{
+    public class HfModel
+    {
+        public List<HfModels> models { get; set; }
+        public int numItemsPerPage { get; set; }
+        public int numTotalItems { get; set; }
+        public int pageIndex { get; set; }
+    }
+    public class HfModels
+    {
+        public string Author { get; set; }
+        public HfAuthorData AuthorData { get; set; }
+        public int Downloads { get; set; }
+        public bool Gated { get; set; }
+        public string Id { get; set; }
+        public DateTime LastModified { get; set; }
+        public int Likes { get; set; }
+        public string PipelineTag { get; set; }
+        public bool Private { get; set; }
+        public string RepoType { get; set; }
+        public bool IsLikedByUser { get; set; }
+    }
+
+    public class HfAuthorData
+    {
+        public string AvatarUrl { get; set; }
+        public string Fullname { get; set; }
+        public string Name { get; set; }
+        public string Type { get; set; }
+        public bool IsHf { get; set; }
+        public bool IsEnterprise { get; set; }
+    }
+}
```
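HfModel mirrors a model-search JSON response from an hf-mirror style endpoint. A hedged deserialization sketch with Newtonsoft.Json (already referenced in the csproj above); the URL and response shape here are assumptions based only on the property names:

```csharp
using System;
using System.Net.Http;
using AntSK.Domain.Domain.Model.hfmirror;
using Newtonsoft.Json;

using var http = new HttpClient();
// Hypothetical endpoint; Newtonsoft.Json matches property names case-insensitively,
// so camelCase JSON maps onto the PascalCase properties above.
string json = await http.GetStringAsync("https://hf-mirror.com/models-json?search=qwen");
var page = JsonConvert.DeserializeObject<HfModel>(json);
foreach (var m in page!.models)
    Console.WriteLine($"{m.Id} (downloads: {m.Downloads})");
```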
**src/AntSK.Domain/Domain/Model/hfmirror/HfModelDetail.cs** (new file, 17 lines)

```diff
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Domain.Model.hfmirror
+{
+    public class HfModelDetail
+    {
+        public string Name { get; set; }
+        public string Size { get; set; }
+        public string Path { get; set; }
+
+        public string Time { get; set; }
+    }
+}
```
```diff
@@ -1,6 +1,6 @@
 using AntSK.BackgroundTask;
 using AntSK.Domain.Domain.Interface;
-using AntSK.Domain.Model;
+using AntSK.Domain.Domain.Model;
 using Microsoft.Extensions.DependencyInjection;

 namespace AntSK.Domain.Domain.Other
@@ -31,7 +31,7 @@ namespace AntSK.Domain.Domain.Other
 {
     ContextSize = lsConfig?.ContextSize ?? 2048,
     Seed = lsConfig?.Seed ?? 0,
-    GpuLayerCount = lsConfig?.GpuLayerCount ?? 20,
+    GpuLayerCount = lsConfig?.GpuLayerCount ?? 10,
     EmbeddingMode = true
 };
 var weights = LLamaWeights.LoadFromFile(parameters);
```
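These parameters feed LLamaSharp 0.10 (per the csproj above); lowering the default GpuLayerCount from 20 to 10 offloads fewer layers to the GPU, trading throughput for VRAM headroom. A minimal standalone sketch of the same load pattern, with an illustrative model path:

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams(@"D:\model\qwen1_5-1_8b-chat-q8_0.gguf")
{
    ContextSize = 2048,
    Seed = 0,
    GpuLayerCount = 10,  // layers offloaded to the GPU; 0 = pure CPU
    EmbeddingMode = true // also expose the model for embeddings
};
using var weights = LLamaWeights.LoadFromFile(parameters);
```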
```diff
@@ -3,23 +3,22 @@ using AntSK.Domain.Domain.Interface;
 using AntSK.Domain.Repositories;
 using Microsoft.SemanticKernel.Connectors.OpenAI;
 using Microsoft.SemanticKernel;
 using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
 using AntSK.Domain.Utils;
+using AntSK.Domain.Domain.Model.Dto;
+using AntSK.Domain.Domain.Model.Constant;
+using DocumentFormat.OpenXml.Drawing;
+using System.Reflection.Metadata;
 using Microsoft.KernelMemory;
-using AntSK.Domain.Model;
-using MarkdownSharp;
-using AntSK.Domain.Domain.Dto;
-using System.Collections.Generic;
+using Markdig;

 namespace AntSK.Domain.Domain.Service
 {
     [ServiceDescription(typeof(IChatService), ServiceLifetime.Scoped)]
     public class ChatService(
         IKernelService _kernelService,
-        IKMService _kMService ,
+        IKMService _kMService,
         IKmsDetails_Repositories _kmsDetails_Repositories
         ) : IChatService
     {
@@ -40,12 +39,11 @@ namespace AntSK.Domain.Domain.Service
             var _kernel = _kernelService.GetKernelByApp(app);
             var temperature = app.Temperature / 100; // stored as 0~100, must be scaled down
             OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
-            if (!string.IsNullOrEmpty(app.ApiFunctionList))
+            if (!string.IsNullOrEmpty(app.ApiFunctionList) || !string.IsNullOrEmpty(app.NativeFunctionList)) // native plugins must be included here too
             {
                 _kernelService.ImportFunctionsByApp(app, _kernel);
                 settings.ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions;
             }

             var func = _kernel.CreateFunctionFromPrompt(app.Prompt, settings);
             var chatResult = _kernel.InvokeStreamingAsync(function: func, arguments: new KernelArguments() { ["input"] = $"{history}{Environment.NewLine} user:{questions}" });
             await foreach (var content in chatResult)
@@ -54,52 +52,51 @@ namespace AntSK.Domain.Domain.Service
             }
         }

-        public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, List<RelevantSource> relevantSources = null)
+        public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, string filePath, List<RelevantSource> relevantSources = null)
         {
+            var relevantSourceList = await _kMService.GetRelevantSourceList(app.KmsIdList, questions);
             var _kernel = _kernelService.GetKernelByApp(app);
-            // knowledge base Q&A
-            var filters = new List<MemoryFilter>();
-            var kmsidList = app.KmsIdList.Split(",");
-            // only the first knowledge base's configuration is used
-            var _memory = _kMService.GetMemoryByKMS(kmsidList.FirstOrDefault());
-            foreach (var kmsid in kmsidList)
-            {
-                filters.Add(new MemoryFilter().ByTag("kmsid", kmsid));
-            }
-            var xlresult = await _memory.SearchAsync(questions, index: "kms", filters: filters);
-            string dataMsg = "";
-            if (xlresult != null)
-            {
-                foreach (var item in xlresult.Results)
-                {
-                    foreach (var part in item.Partitions)
-                    {
-                        dataMsg += $"[file:{item.SourceName};Relevance:{(part.Relevance * 100).ToString("F2")}%]:{part.Text}{Environment.NewLine}";
-                        if (relevantSources.IsNotNull())
-                        {
-                            var markdown = new Markdown();
-                            string sourceName = item.SourceName;
-                            var fileDetail = _kmsDetails_Repositories.GetFirst(p => p.FileGuidName == item.SourceName);
-                            if (fileDetail.IsNotNull())
-                            {
-                                sourceName = fileDetail.FileName;
-                            }
-                            relevantSources.Add(new RelevantSource() { SourceName = sourceName, Text = markdown.Transform(part.Text), Relevance = part.Relevance });
-                        }
-                    }
-                }
+            if (!string.IsNullOrWhiteSpace(filePath))
+            {
+                var memory = _kMService.GetMemory(app);
+                var fileId = Guid.NewGuid().ToString();
+                var result = await memory.ImportDocumentAsync(new Microsoft.KernelMemory.Document(fileId).AddFile(filePath)
+                    .AddTag(KmsConstantcs.KmsIdTag, app.Id)
+                    , index: KmsConstantcs.KmsIndex);
+
+                var filters = new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, app.Id);
+
+                var searchResult = await memory.SearchAsync(questions, index: KmsConstantcs.KmsIndex, filters: [filters]);
+                relevantSourceList.AddRange(searchResult.Results.SelectMany(item => item.Partitions.Select(part => new RelevantSource()
+                {
+                    SourceName = item.SourceName,
+                    Text = Markdown.ToHtml(part.Text),
+                    Relevance = part.Relevance
+                })));
+            }
+
+            var dataMsg = new StringBuilder();
+            if (relevantSourceList.Any())
+            {
+                relevantSources?.AddRange(relevantSourceList);
+                foreach (var item in relevantSourceList)
+                {
+                    dataMsg.AppendLine(item.ToString());
+                }

                 KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask");
                 var chatResult = _kernel.InvokeStreamingAsync(function: jsonFun,
                     arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = history, ["questions"] = questions });

-                MessageInfo info = null;
-                var markdown1 = new Markdown();
                 await foreach (var content in chatResult)
                 {
                     yield return content;
                 }
             }
+            else
+            {
+                yield return new StreamingTextContent(KmsConstantcs.KmsSearchNull);
+            }
         }
     }
 }
```
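A sketch of consuming the reworked streaming API; `chatService`, `app`, and the question text are placeholders:

```csharp
using AntSK.Domain.Domain.Model.Dto;
using Microsoft.SemanticKernel;

// filePath lets a chat turn carry an uploaded document: it is imported into
// Kernel Memory under the app's kmsid tag before the search runs.
var sources = new List<RelevantSource>();
await foreach (StreamingKernelContent chunk in
    chatService.SendKmsByAppAsync(app, "How do I configure Postgres?", history: "", filePath: null, sources))
{
    Console.Write(chunk);
}
// 'sources' now holds the matched chunks, e.g. for a citation panel.
```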
122
src/AntSK.Domain/Domain/Service/FunctionService.cs
Normal file
122
src/AntSK.Domain/Domain/Service/FunctionService.cs
Normal file
@@ -0,0 +1,122 @@
using System.ComponentModel;
using System.Reflection;
using System.Runtime.Loader;
using System.Xml;
using AntSK.Domain.Common;
using AntSK.Domain.Utils;
using System.Text.RegularExpressions;
using Microsoft.SemanticKernel;
using HtmlAgilityPack;
using System.Collections.Generic;

namespace AntSK.Domain.Domain.Service
{
    public class FunctionService
    {
        private readonly Dictionary<string, MethodInfo> _methodCache;
        private readonly Dictionary<string, (string Description, (Type ParameterType, string Description) ReturnType, (string ParameterName, Type ParameterType, string Description)[] Parameters)> _methodInfos;

        private readonly IServiceProvider _serviceProvider;
        private Assembly[] _assemblies;
        private readonly AssemblyLoadContext loadContext;

        public FunctionService(IServiceProvider serviceProvider, Assembly[] assemblies)
        {
            _methodCache = [];
            _methodInfos = [];
            _serviceProvider = serviceProvider;
            _assemblies = assemblies;
            loadContext = new AssemblyLoadContext("AntSKLoadContext", true);
        }

        public Dictionary<string, MethodInfo> Functions => _methodCache;
        public Dictionary<string, (string Description, (Type ParameterType, string Description) ReturnType, (string ParameterName, Type ParameterType, string Description)[] Parameters)> MethodInfos => _methodInfos;

        /// <summary>
        /// Scan the assemblies for marked method delegates; later this can be generated with Source Generators
        /// </summary>
        public void SearchMarkedMethods()
        {
            var markedMethods = new List<MethodInfo>();

            _methodCache.Clear();
            _methodInfos.Clear();

            foreach (var assembly in _assemblies)
            {
                // Collect the methods marked with the attribute
                foreach (var type in assembly.GetTypes())
                {
                    markedMethods.AddRange(type.GetMethods().Where(m =>
                    {
                        DescriptionAttribute da = (DescriptionAttribute)m.GetCustomAttributes(typeof(DescriptionAttribute), true).FirstOrDefault();
                        return da != null && da.Description.Contains("AntSK");
                    }));
                }
            }

            // Dynamically loaded assemblies
            var loadedAssemblies = loadContext.Assemblies.ToList();
            foreach (var assembly in loadedAssemblies)
            {
                // Collect the methods marked with the attribute
                foreach (var type in assembly.GetTypes())
                {
                    markedMethods.AddRange(type.GetMethods().Where(m =>
                    {
                        DescriptionAttribute da = (DescriptionAttribute)m.GetCustomAttributes(typeof(DescriptionAttribute), true).FirstOrDefault();
                        return da != null && da.Description.Contains("AntSK");
                    }));
                }
            }

            // Build the method invocations
            foreach (var method in markedMethods)
            {
                var key = $"{method.DeclaringType.Assembly.GetName().Name}_{method.DeclaringType.Name}_{method.Name}";
                string pattern = "[^a-zA-Z0-9_]";
                // Replace any character outside [a-zA-Z0-9_] with '_'
                key = Regex.Replace(key, pattern, "_");
                _methodCache.TryAdd(key, method);

                var description = method.GetCustomAttribute<DescriptionAttribute>().Description.ConvertToString().Replace("AntSK:", "");
                var returnType = method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>().Description.ConvertToString();
                var parameters = method.GetParameters().Select(x => (x.Name, x.ParameterType, x.GetCustomAttribute<DescriptionAttribute>()?.Description)).ToArray();
                // _methodInfos is the predefined dictionary that stores each method's metadata
                _methodInfos.TryAdd(key, (description, (method.ReflectedType, returnType), parameters));
            }
        }

        public void FuncLoad(string pluginPath)
        {
            try
            {
                if (File.Exists(pluginPath))
                {
                    string directory = Path.GetDirectoryName(pluginPath);
                    string fileName = Path.GetFileName(pluginPath);
                    var resolver = new AssemblyDependencyResolver(directory);

                    // Create a custom AssemblyLoadContext
                    loadContext.Resolving += (context, assemblyName) =>
                    {
                        string assemblyPath = resolver.ResolveAssemblyToPath(assemblyName);
                        if (assemblyPath != null)
                        {
                            return context.LoadFromAssemblyPath(assemblyPath);
                        }
                        return null;
                    };
                    // Load your assembly
                    Assembly pluginAssembly = loadContext.LoadFromAssemblyPath(pluginPath);
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
            }
        }
    }
}
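FunctionService above discovers plugin methods purely by convention: any public method whose DescriptionAttribute text contains "AntSK" is cached, and the attribute texts on the method, its return value, and its parameters become the metadata later handed to Semantic Kernel. A hypothetical method the scan would pick up (the class and member names are invented for illustration; only the attribute convention comes from the code above):

using System.ComponentModel;

public class WeatherFunctions
{
    [Description("AntSK:Query the current weather for a city")]
    [return: Description("A plain-text weather summary")]
    public string GetWeather([Description("City name")] string city)
    {
        return $"Sunny in {city}"; // a real implementation would call a weather API
    }
}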
@@ -1,6 +1,7 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Repositories;
using Microsoft.KernelMemory;

@@ -29,8 +30,8 @@ namespace AntSK.Domain.Domain.Service
{
    var importResult = _memory.ImportDocumentAsync(new Document(fileid)
        .AddFile(req.FilePath)
        .AddTag("kmsid", req.KmsId)
        , index: "kms").Result;
        .AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
        , index: KmsConstantcs.KmsIndex).Result;
    // Query the number of document chunks
    var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
    string fileGuidName = Path.GetFileName(req.FilePath);
@@ -43,8 +44,8 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Url:
    {
        // Import a URL
        var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { "kmsid", req.KmsId } }
            , index: "kms").Result;
        var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
            , index: KmsConstantcs.KmsIndex).Result;
        // Query the number of document chunks
        var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
        req.KmsDetail.Url = req.Url;
@@ -54,8 +55,8 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Text:
    // Import plain text
    {
        var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { "kmsid", req.KmsId } }
            , index: "kms").Result;
        var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
            , index: KmsConstantcs.KmsIndex).Result;
        // Query the number of document chunks
        var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
        req.KmsDetail.Url = req.Url;
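The three import paths above now share one tagging convention: every document, web page, or text fragment is stamped with its knowledge-base id under KmsConstantcs.KmsIdTag inside the KmsConstantcs.KmsIndex index, and searches filter on the same tag. A sketch of that round trip under the convention (the id, document name, and content are placeholders):

using Microsoft.KernelMemory;

async Task ImportAndSearchAsync(MemoryServerless memory, string kmsId)
{
    // Import stamped with the knowledge-base id
    await memory.ImportTextAsync("some content", "file-001",
        new TagCollection { { KmsConstantcs.KmsIdTag, kmsId } },
        index: KmsConstantcs.KmsIndex);

    // Search scoped to the same knowledge base via the tag filter
    var result = await memory.SearchAsync("some question",
        index: KmsConstantcs.KmsIndex,
        filters: new List<MemoryFilter> { new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, kmsId) });
}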
@@ -1,71 +1,125 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Dto;
using AntDesign;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Options;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using DocumentFormat.OpenXml.Drawing.Diagrams;
using LLama;
using LLamaSharp.KernelMemory;
using Markdig;
using Microsoft.AspNetCore.Components;
using Microsoft.Extensions.Configuration;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.ContentStorage.DevTools;
using Microsoft.KernelMemory.FileSystem.DevTools;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.KernelMemory.MemoryStorage.DevTools;
using Microsoft.KernelMemory.Postgres;

namespace AntSK.Domain.Domain.Service
{
    [ServiceDescription(typeof(IKMService), ServiceLifetime.Scoped)]
    public class KMService(
        IConfiguration _config,
        IKmss_Repositories _kmss_Repositories,
        IAIModels_Repositories _aIModels_Repositories
        ) : IKMService
        IKmss_Repositories _kmss_Repositories,
        IAIModels_Repositories _aIModels_Repositories,
        IMessageService? _message
        ) : IKMService
    {
        public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
        {
            // Get the KMS configuration
            var kms = _kmss_Repositories.GetFirst(p => p.Id == kmsID);
            var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.ChatModelID);
            var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.EmbeddingModelID);
        private MemoryServerless _memory;

        // HTTP proxy
        private List<UploadFileItem> _fileList = [];

        public List<UploadFileItem> FileList => _fileList;

        public MemoryServerless GetMemory(Apps app)
        {
            var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
            var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == app.EmbeddingModelID);
            var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
            var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);

            // Search configuration
            if (searchClientConfig.IsNull())
            var searchClientConfig = new SearchClientConfig
            {
                searchClientConfig = new SearchClientConfig
                {
                    MaxAskPromptSize = 2048,
                    MaxMatchesCount = 3,
                    AnswerTokens = 1000,
                    EmptyAnswer = "知识库未搜索到相关内容"
                };
            }
                MaxAskPromptSize = 2048,
                MaxMatchesCount = 3,
                AnswerTokens = 1000,
                EmptyAnswer = KmsConstantcs.KmsSearchNull
            };

            var memory = new KernelMemoryBuilder()
                .WithSearchClientConfig(searchClientConfig)
                .WithCustomTextPartitioningOptions(new TextPartitioningOptions
                {
                    MaxTokensPerLine = kms.MaxTokensPerLine,
                    MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
                    OverlappingTokens = kms.OverlappingTokens
                });
            // Load the chat model
            WithTextGenerationByAIType(memory, chatModel, chatHttpClient);
            var memoryBuild = new KernelMemoryBuilder()
                .WithSearchClientConfig(searchClientConfig)
                //.WithCustomTextPartitioningOptions(new TextPartitioningOptions
                //{
                //    MaxTokensPerLine = app.MaxTokensPerLine,
                //    MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
                //    OverlappingTokens = kms.OverlappingTokens
                //})
                ;
            // Load the chat model
            WithTextGenerationByAIType(memoryBuild, chatModel, chatHttpClient);
            // Load the embedding model
            WithTextEmbeddingGenerationByAIType(memory, embedModel, embeddingHttpClient);
            WithTextEmbeddingGenerationByAIType(memoryBuild, embedModel, embeddingHttpClient);
            // Load the vector store
            WithMemoryDbByVectorDB(memory, _config);

            var result = memory.Build<MemoryServerless>();
            return result;
            WithMemoryDbByVectorDB(memoryBuild);

            _memory = memoryBuild.Build<MemoryServerless>();
            return _memory;
        }

        private void WithTextEmbeddingGenerationByAIType(IKernelMemoryBuilder memory, AIModels embedModel, HttpClient embeddingHttpClient)
        public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
        {
            //if (_memory.IsNull())
            {
                // Get the KMS configuration
                var kms = _kmss_Repositories.GetFirst(p => p.Id == kmsID);
                var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.ChatModelID);
                var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.EmbeddingModelID);

                // HTTP proxy
                var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
                var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);

                // Search configuration
                if (searchClientConfig.IsNull())
                {
                    searchClientConfig = new SearchClientConfig
                    {
                        MaxAskPromptSize = 2048,
                        MaxMatchesCount = 3,
                        AnswerTokens = 1000,
                        EmptyAnswer = KmsConstantcs.KmsSearchNull
                    };
                }

                var memoryBuild = new KernelMemoryBuilder()
                    .WithSearchClientConfig(searchClientConfig)
                    .WithCustomTextPartitioningOptions(new TextPartitioningOptions
                    {
                        MaxTokensPerLine = kms.MaxTokensPerLine,
                        MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
                        OverlappingTokens = kms.OverlappingTokens
                    });
                // Load the chat model
                WithTextGenerationByAIType(memoryBuild, chatModel, chatHttpClient);
                // Load the embedding model
                WithTextEmbeddingGenerationByAIType(memoryBuild, embedModel, embeddingHttpClient);
                // Load the vector store
                WithMemoryDbByVectorDB(memoryBuild);

                _memory = memoryBuild.Build<MemoryServerless>();
                return _memory;
            }
            //else {
            //    return _memory;
            //}
        }

        private void WithTextEmbeddingGenerationByAIType(IKernelMemoryBuilder memory, AIModels embedModel,
            HttpClient embeddingHttpClient)
        {
            switch (embedModel.AIType)
            {
@@ -76,6 +130,7 @@ namespace AntSK.Domain.Domain.Service
                        EmbeddingModel = embedModel.ModelName
                    }, null, false, embeddingHttpClient);
                    break;

                case Model.Enum.AIType.AzureOpenAI:
                    memory.WithAzureOpenAITextEmbeddingGeneration(new AzureOpenAIConfig()
                    {
@@ -86,15 +141,20 @@ namespace AntSK.Domain.Domain.Service
                        APIType = AzureOpenAIConfig.APITypes.EmbeddingGeneration,
                    });
                    break;

                case Model.Enum.AIType.LLamaSharp:
                    var (weights, parameters) = LLamaConfig.GetLLamaConfig(embedModel.ModelName);
                    var embedder = new LLamaEmbedder(weights, parameters);
                    memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
                    break;
                case Model.Enum.AIType.DashScope:
                    memory.WithDashScopeDefaults(embedModel.ModelKey);
                    break;
            }
        }

        private void WithTextGenerationByAIType(IKernelMemoryBuilder memory, AIModels chatModel, HttpClient chatHttpClient)
        private void WithTextGenerationByAIType(IKernelMemoryBuilder memory, AIModels chatModel,
            HttpClient chatHttpClient)
        {
            switch (chatModel.AIType)
            {
@@ -105,6 +165,7 @@ namespace AntSK.Domain.Domain.Service
                        TextModel = chatModel.ModelName
                    }, null, chatHttpClient);
                    break;

                case Model.Enum.AIType.AzureOpenAI:
                    memory.WithAzureOpenAITextGeneration(new AzureOpenAIConfig()
                    {
@@ -115,20 +176,27 @@ namespace AntSK.Domain.Domain.Service
                        APIType = AzureOpenAIConfig.APITypes.TextCompletion,
                    });
                    break;

                case Model.Enum.AIType.LLamaSharp:
                    var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
                    var context = weights.CreateContext(parameters);
                    var executor = new StatelessExecutor(weights, parameters);
                    memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
                    break;
                case Model.Enum.AIType.DashScope:
                    memory.WithDashScopeTextGeneration(new Cnblogs.KernelMemory.AI.DashScope.DashScopeConfig
                    {
                        ApiKey = chatModel.ModelKey,
                    });
                    break;
            }
        }

        private void WithMemoryDbByVectorDB(IKernelMemoryBuilder memory, IConfiguration _config)
        private void WithMemoryDbByVectorDB(IKernelMemoryBuilder memory)
        {
            string VectorDb = _config["KernelMemory:VectorDb"].ConvertToString();
            string ConnectionString = _config["KernelMemory:ConnectionString"].ConvertToString();
            string TableNamePrefix = _config["KernelMemory:TableNamePrefix"].ConvertToString();
            string VectorDb = KernelMemoryOption.VectorDb.ConvertToString();
            string ConnectionString = KernelMemoryOption.ConnectionString.ConvertToString();
            string TableNamePrefix = KernelMemoryOption.TableNamePrefix.ConvertToString();
            switch (VectorDb)
            {
                case "Postgres":
@@ -138,14 +206,16 @@ namespace AntSK.Domain.Domain.Service
                        TableNamePrefix = TableNamePrefix
                    });
                    break;

                case "Disk":
                    memory.WithSimpleFileStorage(new SimpleFileStorageConfig()
                    memory.WithSimpleVectorDb(new SimpleVectorDbConfig()
                    {
                        StorageType = FileSystemTypes.Disk
                        StorageType = FileSystemTypes.Disk,
                    });
                    break;

                case "Memory":
                    memory.WithSimpleFileStorage(new SimpleFileStorageConfig()
                    memory.WithSimpleVectorDb(new SimpleVectorDbConfig()
                    {
                        StorageType = FileSystemTypes.Volatile
                    });
@@ -153,36 +223,103 @@ namespace AntSK.Domain.Domain.Service
            }
        }

        public async Task<List<KMFile>> GetDocumentByFileID(string kmsid, string fileid)
        public async Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId)
        {
            var _memory = GetMemoryByKMS(kmsid);
            var memories = await _memory.ListIndexesAsync();
            var memoryDbs = _memory.Orchestrator.GetMemoryDbs();
            List<KMFile> docTextList = new List<KMFile>();
            var memory = GetMemoryByKMS(kmsId);
            var memories = await memory.ListIndexesAsync();
            var memoryDbs = memory.Orchestrator.GetMemoryDbs();
            var docTextList = new List<KMFile>();

            foreach (var memoryIndex in memories)
            {
                foreach (var memoryDb in memoryDbs)
                {
                    var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileid) }, 100, true).ToListAsync();
                    foreach (var item in items)
                    var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileId) }, 100, true).ToListAsync();
                    docTextList.AddRange(items.Select(item => new KMFile()
                    {
                        KMFile file = new KMFile()
                        {
                            Text = item.Payload.FirstOrDefault(p => p.Key == "text").Value.ConvertToString(),
                            Url = item.Payload.FirstOrDefault(p => p.Key == "url").Value.ConvertToString(),
                            LastUpdate = item.Payload.FirstOrDefault(p => p.Key == "last_update").Value.ConvertToString(),
                            Schema = item.Payload.FirstOrDefault(p => p.Key == "schema").Value.ConvertToString(),
                            File = item.Payload.FirstOrDefault(p => p.Key == "file").Value.ConvertToString(),
                        };
                        docTextList.Add(file);
                    }
                        DocumentId = item.GetDocumentId(),
                        Text = item.GetPartitionText(),
                        Url = item.GetWebPageUrl(),
                        LastUpdate = item.GetLastUpdate().LocalDateTime.ToString("yyyy-MM-dd HH:mm:ss"),
                        File = item.GetFileName()
                    }));
                }
            }

            return docTextList;
        }

        public async Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg)
        {
            var result = new List<RelevantSource>();
            if (string.IsNullOrWhiteSpace(kmsIdListStr))
                return result;
            var kmsIdList = kmsIdListStr.Split(",");
            if (!kmsIdList.Any()) return result;

            var memory = GetMemoryByKMS(kmsIdList.FirstOrDefault()!);

            var filters = kmsIdList.Select(kmsId => new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, kmsId)).ToList();

            var searchResult = await memory.SearchAsync(msg, index: KmsConstantcs.KmsIndex, filters: filters);
            if (!searchResult.NoResult)
            {
                foreach (var item in searchResult.Results)
                {
                    result.AddRange(item.Partitions.Select(part => new RelevantSource()
                    {
                        SourceName = item.SourceName,
                        Text = Markdown.ToHtml(part.Text),
                        Relevance = part.Relevance
                    }));
                }
            }

            return result;
        }

        public bool BeforeUpload(UploadFileItem file)
        {
            List<string> types = new List<string>() {
                "text/plain",
                "application/msword",
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                "application/vnd.ms-excel",
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                "application/vnd.ms-powerpoint",
                "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                "application/pdf",
                "application/json",
                "text/x-markdown",
                "text/markdown"
            };

            string[] exceptExts = [".md", ".pdf"];
            var validTypes = types.Contains(file.Type) || exceptExts.Contains(file.Ext);
            if (!validTypes && file.Ext != ".md")
            {
                _message.Error("文件格式错误,请重新选择!");
            }
            var IsLt500K = file.Size < 1024 * 1024 * 100;
            if (!IsLt500K)
            {
                _message.Error("文件需不大于100MB!");
            }

            return validTypes && IsLt500K;
        }

        public void OnSingleCompleted(UploadInfo fileinfo)
        {
            if (fileinfo.File.State == UploadState.Success)
            {
                // File list
                _fileList.Add(new UploadFileItem()
                {
                    FileName = fileinfo.File.FileName,
                    Url = fileinfo.File.Url = fileinfo.File.Response
                });
            }
        }
    }
}
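GetMemory and GetMemoryByKMS above now share one builder pipeline: search configuration, then chat model, then embedding model, then vector store. A condensed sketch of that pipeline, with the OpenAI connectors standing in for the AIType switches used in the service (API keys and model names are placeholders):

using Microsoft.KernelMemory;

var memory = new KernelMemoryBuilder()
    .WithSearchClientConfig(new SearchClientConfig
    {
        MaxAskPromptSize = 2048,
        MaxMatchesCount = 3,
        AnswerTokens = 1000
    })
    .WithOpenAITextGeneration(new OpenAIConfig { APIKey = "sk-...", TextModel = "gpt-4" })
    .WithOpenAITextEmbeddingGeneration(new OpenAIConfig { APIKey = "sk-...", EmbeddingModel = "text-embedding-ada-002" })
    .Build<MemoryServerless>();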
@@ -1,7 +1,7 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.LLM.SparkDesk;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Model;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using LLama;
@@ -11,7 +11,13 @@ using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Plugins.Core;
using Microsoft.SemanticKernel.TextGeneration;
using RestSharp;
using System;
using ServiceLifetime = AntSK.Domain.Common.DependencyInjection.ServiceLifetime;
using AntSK.LLM.Mock;
using AntSK.Domain.Domain.Model.Enum;
using AntSK.LLM.LLamaFactory;
using System.Reflection;
using DocumentFormat.OpenXml.Drawing;

namespace AntSK.Domain.Domain.Service
{
@@ -20,15 +26,22 @@ namespace AntSK.Domain.Domain.Service
    {
        private readonly IApis_Repositories _apis_Repositories;
        private readonly IAIModels_Repositories _aIModels_Repositories;
        private readonly FunctionService _functionService;
        private readonly IServiceProvider _serviceProvider;
        private Kernel _kernel;

        public KernelService(
            IApis_Repositories apis_Repositories,
            IAIModels_Repositories aIModels_Repositories
            )
            IAIModels_Repositories aIModels_Repositories,
            FunctionService functionService,
            IServiceProvider serviceProvider)
        {
            _apis_Repositories = apis_Repositories;
            _aIModels_Repositories = aIModels_Repositories;

            _functionService = functionService;
            _serviceProvider = serviceProvider;
        }

        /// <summary>
        /// Get a kernel instance. With DI it is hard to import a different plugin set per user, so a new kernel is created on each call
        /// </summary>
@@ -37,20 +50,26 @@ namespace AntSK.Domain.Domain.Service
        /// <returns></returns>
        public Kernel GetKernelByApp(Apps app)
        {
            var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
            //if (_kernel.IsNull())
            {
                var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);

            var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
                var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);

            var builder = Kernel.CreateBuilder();
            WithTextGenerationByAIType(builder, chatModel, chatHttpClient);
                var builder = Kernel.CreateBuilder();
                WithTextGenerationByAIType(builder, app, chatModel, chatHttpClient);

            var kernel = builder.Build();
            RegisterPluginsWithKernel(kernel);
            return kernel;
                _kernel = builder.Build();
                RegisterPluginsWithKernel(_kernel);
                return _kernel;
            }
            //else
            //{
            //    return _kernel;
            //}
        }

        private void WithTextGenerationByAIType(IKernelBuilder builder, AIModels chatModel, HttpClient chatHttpClient)
        private void WithTextGenerationByAIType(IKernelBuilder builder, Apps app, AIModels chatModel, HttpClient chatHttpClient)
        {
            switch (chatModel.AIType)
            {
@@ -60,6 +79,7 @@ namespace AntSK.Domain.Domain.Service
                    apiKey: chatModel.ModelKey,
                    httpClient: chatHttpClient);
                    break;

                case Model.Enum.AIType.AzureOpenAI:
                    builder.AddAzureOpenAIChatCompletion(
                        deploymentName: chatModel.ModelName,
@@ -67,11 +87,32 @@ namespace AntSK.Domain.Domain.Service
                        endpoint: chatModel.EndPoint
                        );
                    break;

                case Model.Enum.AIType.LLamaSharp:
                    var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
                    var ex = new StatelessExecutor(weights, parameters);
                    builder.Services.AddKeyedSingleton<ITextGenerationService>("local-llama", new LLamaSharpTextCompletion(ex));
                    break;

                case Model.Enum.AIType.SparkDesk:
                    var options = new SparkDeskOptions { AppId = chatModel.EndPoint, ApiSecret = chatModel.ModelKey, ApiKey = chatModel.ModelName, ModelVersion = Sdcb.SparkDesk.ModelVersion.V3_5 };
                    builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, app.Id));
                    break;

                case Model.Enum.AIType.DashScope:
                    builder.Services.AddDashScopeChatCompletion(chatModel.ModelKey, chatModel.ModelName);
                    break;

                case Model.Enum.AIType.Mock:
                    builder.Services.AddKeyedSingleton<ITextGenerationService>("mock", new MockTextCompletion());
                    break;
                case Model.Enum.AIType.LLamaFactory:
                    builder.AddOpenAIChatCompletion(
                        modelId: chatModel.ModelName,
                        apiKey: "123",
                        httpClient: chatHttpClient
                        );
                    break;
            }
        }

@@ -82,25 +123,57 @@ namespace AntSK.Domain.Domain.Service
        /// <param name="_kernel"></param>
        public void ImportFunctionsByApp(Apps app, Kernel _kernel)
        {
            // Enable automatic plugin invocation
            var apiIdList = app.ApiFunctionList.Split(",");
            var apiList = _apis_Repositories.GetList(p => apiIdList.Contains(p.Id));
            List<KernelFunction> functions = new List<KernelFunction>();
            var plugin = _kernel.Plugins.FirstOrDefault(p => p.Name == "ApiFunctions");
            // A plugin must not be registered twice, otherwise an exception is thrown
            if (_kernel.Plugins.Any(p => p.Name == "AntSkFunctions"))
            {
                return;
            }
            List<KernelFunction> functions = new List<KernelFunction>();

            // API plugins
            ImportApiFunction(app, functions);
            // Native function plugins
            ImportNativeFunction(app, functions);

            _kernel.ImportPluginFromFunctions("AntSkFunctions", functions);
        }

        /// <summary>
        /// Import API plugins
        /// </summary>
        /// <param name="app"></param>
        /// <param name="functions"></param>
        private void ImportApiFunction(Apps app, List<KernelFunction> functions)
        {
            if (!string.IsNullOrWhiteSpace(app.ApiFunctionList))
            {
                // Enable automatic plugin invocation
                var apiIdList = app.ApiFunctionList.Split(",");
                var apiList = _apis_Repositories.GetList(p => apiIdList.Contains(p.Id));

                foreach (var api in apiList)
                {
                    var returnType = new KernelReturnParameterMetadata() { Description = api.OutputPrompt };
                    switch (api.Method)
                    {
                        case HttpMethodType.Get:
                            functions.Add(_kernel.CreateFunctionFromMethod((string msg) =>

                            var getParametes = new List<KernelParameterMetadata>() {
                                new KernelParameterMetadata("jsonbody"){
                                    Name="json参数字符串",
                                    ParameterType=typeof(string),
                                    Description=$"需要根据背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串,参考如下格式:{Environment.NewLine}{api.Query}"
                                }
                            };
                            functions.Add(_kernel.CreateFunctionFromMethod((string jsonbody) =>
                            {
                                try
                                {
                                    Console.WriteLine(msg);
                                    // Convert the JSON into query parameters
                                    var queryString = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, string>>(jsonbody);
                                    RestClient client = new RestClient();
                                    RestRequest request = new RestRequest(api.Url, Method.Get);
                                    foreach (var header in api.Header.Split("\n"))
                                    foreach (var header in api.Header.ConvertToString().Split("\n"))
                                    {
                                        var headerArray = header.Split(":");
                                        if (headerArray.Length == 2)
@@ -109,13 +182,9 @@ namespace AntSK.Domain.Domain.Service
                                        }
                                    }
                                    // Parameter extraction still needs another pass here; to be iterated later
                                    foreach (var query in api.Query.Split("\n"))
                                    foreach (var q in queryString)
                                    {
                                        var queryArray = query.Split("=");
                                        if (queryArray.Length == 2)
                                        {
                                            request.AddQueryParameter(queryArray[0], queryArray[1]);
                                        }
                                        request.AddQueryParameter(q.Key, q.Value);
                                    }
                                    var result = client.Execute(request);
                                    return result.Content;
@@ -124,17 +193,25 @@ namespace AntSK.Domain.Domain.Service
                                {
                                    return "调用失败:" + ex.Message;
                                }
                            }, api.Name, $"{api.Describe}"));
                            }, api.Name, api.Describe, getParametes, returnType));
                            break;
                        case HttpMethodType.Post:
                            functions.Add(_kernel.CreateFunctionFromMethod((string msg) =>
                            // Handle the JSON body
                            var postParametes = new List<KernelParameterMetadata>() {
                                new KernelParameterMetadata("jsonbody"){
                                    Name="json参数字符串",
                                    ParameterType=typeof(string),
                                    Description=$"需要根据背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串,参考如下格式:{Environment.NewLine}{api.JsonBody}"
                                }
                            };
                            functions.Add(_kernel.CreateFunctionFromMethod((string jsonBody) =>
                            {
                                try
                                {
                                    Console.WriteLine(msg);
                                    Console.WriteLine(jsonBody);
                                    RestClient client = new RestClient();
                                    RestRequest request = new RestRequest(api.Url, Method.Post);
                                    foreach (var header in api.Header.Split("\n"))
                                    foreach (var header in api.Header.ConvertToString().Split("\n"))
                                    {
                                        var headerArray = header.Split(":");
                                        if (headerArray.Length == 2)
@@ -143,7 +220,7 @@ namespace AntSK.Domain.Domain.Service
                                        }
                                    }
                                    // Parameter extraction still needs another pass here; to be iterated later
                                    request.AddJsonBody(api.JsonBody);
                                    request.AddJsonBody(jsonBody.ConvertToString());
                                    var result = client.Execute(request);
                                    return result.Content;
                                }
@@ -151,24 +228,51 @@ namespace AntSK.Domain.Domain.Service
                                {
                                    return "调用失败:" + ex.Message;
                                }
                            }, api.Name, $"{api.Describe}"));
                            }, api.Name, api.Describe, postParametes, returnType));
                            break;
                    }
                }
                _kernel.ImportPluginFromFunctions("ApiFunctions", functions);
            }

        }

        /// <summary>
        /// Import native plugins
        /// </summary>
        /// <param name="app"></param>
        /// <param name="functions"></param>
        private void ImportNativeFunction(Apps app, List<KernelFunction> functions)
        {
            if (!string.IsNullOrWhiteSpace(app.NativeFunctionList)) // TODO: check whether the app has native function plugins enabled
            {
                var nativeIdList = app.NativeFunctionList.Split(",");

                _functionService.SearchMarkedMethods();
                using var scope = _serviceProvider.CreateScope();

                foreach (var func in _functionService.Functions)
                {
                    if (nativeIdList.Contains(func.Key))
                    {
                        var methodInfo = _functionService.MethodInfos[func.Key];
                        var parameters = methodInfo.Parameters.Select(x => new KernelParameterMetadata(x.ParameterName) { ParameterType = x.ParameterType, Description = x.Description });
                        var returnType = new KernelReturnParameterMetadata() { ParameterType = methodInfo.ReturnType.ParameterType, Description = methodInfo.ReturnType.Description };
                        var target = ActivatorUtilities.CreateInstance(scope.ServiceProvider, func.Value.DeclaringType);
                        functions.Add(_kernel.CreateFunctionFromMethod(func.Value, target, func.Key, methodInfo.Description, parameters, returnType));
                    }
                }
            }
        }

        /// <summary>
        /// Register the default plugins
        /// </summary>
        /// <param name="kernel"></param>
        void RegisterPluginsWithKernel(Kernel kernel)
        private void RegisterPluginsWithKernel(Kernel kernel)
        {
            kernel.ImportPluginFromObject(new ConversationSummaryPlugin(), "ConversationSummaryPlugin");
            kernel.ImportPluginFromObject(new TimePlugin(), "TimePlugin");
            kernel.ImportPluginFromPromptDirectory(Path.Combine(RepoFiles.SamplePluginsPath(), "KMSPlugin"));
            //kernel.ImportPluginFromObject(new TimePlugin(), "TimePlugin");
            kernel.ImportPluginFromPromptDirectory(System.IO.Path.Combine(RepoFiles.SamplePluginsPath(), "KMSPlugin"));
        }

        /// <summary>
@@ -183,8 +287,8 @@ namespace AntSK.Domain.Domain.Service
            KernelFunction sunFun = _kernel.Plugins.GetFunction("ConversationSummaryPlugin", "SummarizeConversation");
            var summary = await _kernel.InvokeAsync(sunFun, new() { ["input"] = $"内容是:{history.ToString()} {Environment.NewLine} 请注意用中文总结" });
            string his = summary.GetValue<string>();
            var msg = $"history:{history.ToString()}{Environment.NewLine} user:{questions}"; ;
            var msg = $"history:{Environment.NewLine}{history.ToString()}{Environment.NewLine} user:{questions}{Environment.NewLine}"; ;
            return msg;
        }
    }
}
}
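The API-plugin hunks above move from an opaque (string msg) delegate to functions carrying explicit parameter and return metadata, so the model knows it must supply a single "jsonbody" JSON string. A reduced sketch of that registration shape, assuming only the CreateFunctionFromMethod overload used in the diff; the helper name, delegate body, and field names are illustrative:

using Microsoft.SemanticKernel;

KernelFunction BuildGetFunction(Kernel kernel, string name, string describe, string inputPrompt, string querySample)
{
    var parameters = new List<KernelParameterMetadata>
    {
        new KernelParameterMetadata("jsonbody")
        {
            ParameterType = typeof(string),
            Description = $"Extract a JSON string based on:{Environment.NewLine}{inputPrompt}{Environment.NewLine}Reference format:{Environment.NewLine}{querySample}"
        }
    };
    var returnType = new KernelReturnParameterMetadata { Description = describe };

    return kernel.CreateFunctionFromMethod(
        (string jsonbody) => $"called with {jsonbody}", // the real code issues the HTTP request here
        name, describe, parameters, returnType);
}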
164
src/AntSK.Domain/Domain/Service/LLamaFactoryService.cs
Normal file
@@ -0,0 +1,164 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Options;
using AntSK.LLamaFactory.Model;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;

namespace AntSK.Domain.Domain.Service
{
    [ServiceDescription(typeof(ILLamaFactoryService), ServiceLifetime.Singleton)]
    public class LLamaFactoryService : ILLamaFactoryService
    {
        private Process process;

        public static bool isProcessComplete = false;

        private readonly object _syncLock = new object();
        private List<LLamaModel> modelList = new List<LLamaModel>();

        public LLamaFactoryService() { }
        public delegate Task LogMessageHandler(string message);
        public event LogMessageHandler LogMessageReceived;
        protected virtual async Task OnLogMessageReceived(string message)
        {
            LogMessageReceived?.Invoke(message);
        }

        public async Task PipInstall()
        {
            var cmdTask = Task.Factory.StartNew(() =>
            {
                var isProcessComplete = false;

                process = new Process
                {
                    StartInfo = new ProcessStartInfo
                    {
                        FileName = "pip",
                        Arguments = "install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple",
                        UseShellExecute = false,
                        RedirectStandardOutput = true,
                        RedirectStandardError = true,
                        WorkingDirectory = AppDomain.CurrentDomain.BaseDirectory,
                    }
                };
                process.OutputDataReceived += (sender, eventArgs) =>
                {
                    Console.WriteLine($"{eventArgs.Data}");
                    OnLogMessageReceived(eventArgs.Data);
                };
                process.ErrorDataReceived += (sender, eventArgs) =>
                {
                    Console.WriteLine($"{eventArgs.Data}");
                    OnLogMessageReceived(eventArgs.Data);
                };
                process.Start();
                process.BeginOutputReadLine();
                process.BeginErrorReadLine();
                process.WaitForExit();
            }, TaskCreationOptions.LongRunning);
        }

        public async Task StartLLamaFactory(string modelName, string templateName)
        {
            var cmdTask = Task.Factory.StartNew(() =>
            {
                var isProcessComplete = false;

                process = new Process
                {
                    StartInfo = new ProcessStartInfo
                    {
                        FileName = "python",
                        Arguments = "api_demo.py --model_name_or_path " + modelName + " --template " + templateName + " ",
                        UseShellExecute = false,
                        RedirectStandardOutput = true,
                        RedirectStandardError = true,
                        WorkingDirectory = Path.Combine(Path.GetDirectoryName(System.Reflection.Assembly.GetEntryAssembly().Location), "llamafactory"),
                    }
                };
                process.StartInfo.Environment["CUDA_VISIBLE_DEVICES"] = "0";
                process.StartInfo.Environment["API_PORT"] = "8000";
                process.StartInfo.EnvironmentVariables["USE_MODELSCOPE_HUB"] = "1";
                process.OutputDataReceived += (sender, eventArgs) =>
                {
                    Console.WriteLine($"{eventArgs.Data}");
                    OnLogMessageReceived(eventArgs.Data);
                };
                process.ErrorDataReceived += (sender, eventArgs) =>
                {
                    Console.WriteLine($"{eventArgs.Data}");
                    OnLogMessageReceived(eventArgs.Data);
                };
                process.Start();
                process.BeginOutputReadLine();
                process.BeginErrorReadLine();
                process.WaitForExit();
            }, TaskCreationOptions.LongRunning);
        }

        private void Process_OutputDataReceived(object sender, DataReceivedEventArgs e)
        {
            throw new NotImplementedException();
        }

        public string WaitForProcessExit()
        {
            process.WaitForExit();
            return process.StandardOutput.ReadToEnd();
        }

        public void KillProcess()
        {
            try
            {
                Process[] processes = Process.GetProcesses();
                foreach (Process process1 in processes)
                {
                    if (process1.ProcessName.ToLower() == "python")
                    {
                        process1.Kill();
                        System.Console.WriteLine("kill python");
                    }
                }
            }
            catch (InvalidOperationException ex)
            {
                // Process already exited.
            }
        }

        public List<LLamaModel> GetLLamaFactoryModels()
        {
            if (modelList.Count == 0)
            {
                string jsonString = File.ReadAllText(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "modelList.json"));

                // Deserialize the JSON string into the corresponding C# objects
                var Models = JsonConvert.DeserializeObject<List<LLamaFactoryModel>>(jsonString);
                foreach (var model in Models)
                {
                    foreach (var m in model.Models)
                    {
                        modelList.Add(new LLamaModel() { Name = m.Key, ModelScope = m.Value.MODELSCOPE });
                    }
                }
            }
            return modelList;
        }
    }
}
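LLamaFactoryService surfaces the pip/python console output through its LogMessageReceived event, so a caller can show install and startup progress live. A usage sketch, assuming the members shown above; the model and template values are placeholders, not project defaults:

var llamaFactory = new LLamaFactoryService();
llamaFactory.LogMessageReceived += async (string message) =>
{
    Console.WriteLine($"[llamafactory] {message}"); // e.g. append to a UI log panel instead
};
await llamaFactory.PipInstall();
await llamaFactory.StartLLamaFactory("Qwen-1_8B-Chat", "qwen");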
@@ -1,13 +0,0 @@
using AutoMapper;

namespace AntSK.Domain.Map
{
    public class AutoMapProfile : Profile
    {
        public AutoMapProfile()
        {

        }
    }
}
@@ -1,48 +0,0 @@
using AutoMapper;

namespace AntSK.Domain.Map
{
    public static class MapperExtend
    {
        /// <summary>
        /// Convert an entity collection to a DTO collection
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="value"></param>
        /// <returns></returns>
        public static List<T> ToDTOList<T>(this object value)
        {
            if (value == null)
                return new List<T>();

            return Mapper.Map<List<T>>(value);
        }
        /// <summary>
        /// Convert an entity to a DTO
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="value"></param>
        /// <returns></returns>
        public static T ToDTO<T>(this object value)
        {
            if (value == null)
                return default(T);

            return Mapper.Map<T>(value);
        }

        /// <summary>
        /// Map onto an existing object; suited to update scenarios. To filter out null values, configure it in AutoMapProfile
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="self"></param>
        /// <param name="result"></param>
        /// <returns></returns>
        public static T MapTo<T>(this object self, T result)
        {
            if (self == null)
                return default(T);
            return (T)Mapper.Map(self, result, self.GetType(), typeof(T));
        }
    }
}
@@ -1,30 +0,0 @@
using AutoMapper;
using Microsoft.Extensions.DependencyInjection;

namespace AntSK.Domain.Map
{
    public static class MapperRegister
    {
        public static void AddMapper(this IServiceCollection services)
        {
            var config = new MapperConfiguration(cfg =>
            {
                cfg.CreateMissingTypeMaps = true;
                cfg.ValidateInlineMaps = false;
                cfg.ShouldMapMethod = m => false;
                cfg.AddProfile<AutoMapProfile>();
            });

            IMapper mapper = config.CreateMapper();

            // Initialize entity mapping
            Mapper.Initialize(cfg =>
            {
                cfg.CreateMissingTypeMaps = true;
                cfg.ValidateInlineMaps = false;
                cfg.ShouldMapMethod = m => false;
                cfg.AddProfile<AutoMapProfile>();
            });
        }
    }
}
@@ -1,21 +0,0 @@
namespace AntSK.Domain.Model.Enum
{
    /// <summary>
    /// AI type
    /// </summary>
    public enum AIType
    {
        OpenAI = 1,
        AzureOpenAI = 2,
        LLamaSharp = 3
    }

    /// <summary>
    /// Model type
    /// </summary>
    public enum AIModelType
    {
        Chat = 1,
        Embedding = 2,
    }
}
@@ -6,5 +6,7 @@
        public static string Chat { get; set; }

        public static string Embedding { get; set; }

        public static string FileDirectory { get; set; } = Directory.GetCurrentDirectory();
    }
}
@@ -1,4 +1,4 @@
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;

@@ -20,6 +20,7 @@ namespace AntSK.Domain.Repositories
        /// </summary>
        [Required]
        public string Describe { get; set; }

        /// <summary>
        /// Icon
        /// </summary>
@@ -38,6 +39,11 @@ namespace AntSK.Domain.Repositories
        [Required]
        public string? ChatModelID { get; set; }

        /// <summary>
        /// Embedding model Id
        /// </summary>
        public string? EmbeddingModelID { get; set; }

        /// <summary>
        /// Temperature
        /// </summary>
@@ -52,16 +58,23 @@ namespace AntSK.Domain.Repositories
        /// <summary>
        /// Plugin list
        /// </summary>
        [SugarColumn(ColumnDataType = "varchar(1000)")]
        public string? ApiFunctionList { get; set; }

        /// <summary>
        /// Native function list
        /// </summary>
        [SugarColumn(ColumnDataType = "varchar(1000)")]
        public string? NativeFunctionList { get; set; }

        /// <summary>
        /// Knowledge base ID list
        /// </summary>
        public string? KmsIdList { get; set; }

        /// <summary>
        /// API access secret key
        /// </summary>
        public string? SecretKey { get; set; }
    }
}
}
20
src/AntSK.Domain/Repositories/AI/Fun/Funs.cs
Normal file
@@ -0,0 +1,20 @@
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;

namespace AntSK.Domain.Repositories
{
    [SugarTable("Funs")]
    public partial class Funs
    {
        [SugarColumn(IsPrimaryKey = true)]
        public string Id { get; set; }

        /// <summary>
        /// Interface description
        /// </summary>
        [Required]
        public string Path { get; set; }

    }
}
11
src/AntSK.Domain/Repositories/AI/Fun/Funs_Repositories.cs
Normal file
@@ -0,0 +1,11 @@

using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Repositories.Base;

namespace AntSK.Domain.Repositories
{
    [ServiceDescription(typeof(IFuns_Repositories), ServiceLifetime.Scoped)]
    public class Funs_Repositories : Repository<Funs>, IFuns_Repositories
    {
    }
}
@@ -0,0 +1,8 @@
using AntSK.Domain.Repositories.Base;

namespace AntSK.Domain.Repositories
{
    public interface IFuns_Repositories : IRepository<Funs>
    {
    }
}
@@ -1,4 +1,4 @@
using AntSK.Domain.Model.Enum;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;

namespace AntSK.Domain.Repositories

@@ -1,4 +1,4 @@
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
using SqlSugar;
using System.Linq.Expressions;

@@ -1,5 +1,6 @@
using AntSK.Domain.Map;
using AntSK.Domain.Model;
using AntSK.Domain.Common.Map;
using AntSK.Domain.Domain.Model;

using SqlSugar;
using System.Linq.Expressions;

@@ -1,4 +1,4 @@
using AntSK.Domain.Model.Enum;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;
16
src/AntSK.Domain/Repositories/Setting/Dic/Dics.cs
Normal file
@@ -0,0 +1,16 @@
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;

namespace AntSK.Domain.Repositories
{
    [SugarTable("Dics")]
    public partial class Dics
    {
        [SugarColumn(IsPrimaryKey = true)]
        public string Id { get; set; }
        public string Type { get; set; }
        public string Key { get; set; }
        public string Value { get; set; }
    }
}
@@ -0,0 +1,11 @@

using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Repositories.Base;

namespace AntSK.Domain.Repositories
{
    [ServiceDescription(typeof(IDics_Repositories), ServiceLifetime.Scoped)]
    public class Dics_Repositories : Repository<Dics>, IDics_Repositories
    {
    }
}
@@ -0,0 +1,8 @@
using AntSK.Domain.Repositories.Base;

namespace AntSK.Domain.Repositories
{
    public interface IDics_Repositories : IRepository<Dics>
    {
    }
}
@@ -1,4 +1,6 @@
namespace AntSK.Domain.Utils
using System.Web;

namespace AntSK.Domain.Utils
{
    public static class ConvertUtils
    {
@@ -231,5 +233,22 @@
        {
            return $"{Environment.NewLine}```json{Environment.NewLine}{s}{Environment.NewLine}```{Environment.NewLine}";
        }

        /// <summary>
        /// Convert JSON parameters to query-string parameters
        /// </summary>
        /// <param name="parameters"></param>
        /// <returns></returns>
        public static string ToQueryString(this Dictionary<string, string> parameters)
        {
            var nameValueCollection = HttpUtility.ParseQueryString(string.Empty);

            foreach (var param in parameters)
            {
                nameValueCollection[param.Key] = param.Value;
            }

            return nameValueCollection.ToString();
        }
    }
}
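Usage sketch for the new ToQueryString extension: a JSON object deserialized into a dictionary becomes a URL-encoded query string (the keys and values here are placeholders):

var parameters = new Dictionary<string, string>
{
    ["city"] = "Shanghai",
    ["unit"] = "metric"
};
string query = parameters.ToQueryString(); // "city=Shanghai&unit=metric"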
375
src/AntSK.Domain/Utils/XmlCommentHelper.cs
Normal file
@@ -0,0 +1,375 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using System.Xml.XPath;

namespace AntSK.Domain.Utils
{
    /// <summary>
    /// XML comment helper
    /// </summary>
    public class XmlCommentHelper
    {
        private static Regex RefTagPattern = new Regex(@"<(see|paramref) (name|cref)=""([TPF]{1}:)?(?<display>.+?)"" ?/>");
        private static Regex CodeTagPattern = new Regex(@"<c>(?<display>.+?)</c>");
        private static Regex ParaTagPattern = new Regex(@"<para>(?<display>.+?)</para>", RegexOptions.Singleline);

        List<XPathNavigator> navigators = new List<XPathNavigator>();

        /// <summary>
        /// Load all XML documentation files found next to the current DLLs
        /// </summary>
        public void LoadAll()
        {
            var files = Directory.GetFiles(Directory.GetCurrentDirectory());
            foreach (var file in files)
            {
                if (string.Equals(Path.GetExtension(file), ".xml", StringComparison.OrdinalIgnoreCase))
                {
                    Load(file);
                }
            }
        }
        /// <summary>
        /// Load from XML strings
        /// </summary>
        /// <param name="xmls"></param>
        public void LoadXml(params string[] xmls)
        {
            foreach (var xml in xmls)
            {
                Load(new MemoryStream(Encoding.UTF8.GetBytes(xml)));
            }
        }
        /// <summary>
        /// Load from files
        /// </summary>
        /// <param name="xmlFiles"></param>
        public void Load(params string[] xmlFiles)
        {
            foreach (var xmlFile in xmlFiles)
            {
                var doc = new XPathDocument(xmlFile);
                navigators.Add(doc.CreateNavigator());
            }
        }
        /// <summary>
        /// Load from streams
        /// </summary>
        /// <param name="streams"></param>
        public void Load(params Stream[] streams)
        {
            foreach (var stream in streams)
            {
                var doc = new XPathDocument(stream);
                navigators.Add(doc.CreateNavigator());
            }
        }

        /// <summary>
        /// Read the comment of a type
        /// </summary>
        /// <param name="type">The type</param>
        /// <param name="xPath">Comment path</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetTypeComment(Type type, string xPath = "summary", bool humanize = true)
        {
            var typeMemberName = GetMemberNameForType(type);
            return GetComment(typeMemberName, xPath, humanize);
        }
        /// <summary>
        /// Read the comment of a field or property
        /// </summary>
        /// <param name="fieldOrPropertyInfo">The field or property</param>
        /// <param name="xPath">Comment path</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetFieldOrPropertyComment(MemberInfo fieldOrPropertyInfo, string xPath = "summary", bool humanize = true)
        {
            var fieldOrPropertyMemberName = GetMemberNameForFieldOrProperty(fieldOrPropertyInfo);
            return GetComment(fieldOrPropertyMemberName, xPath, humanize);
        }
        /// <summary>
        /// Read the comment of a method
        /// </summary>
        /// <param name="methodInfo">The method</param>
        /// <param name="xPath">Comment path</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetMethodComment(MethodInfo methodInfo, string xPath = "summary", bool humanize = true)
        {
            var methodMemberName = GetMemberNameForMethod(methodInfo);
            return GetComment(methodMemberName, xPath, humanize);
        }
        /// <summary>
        /// Read the return-value comment of a method
        /// </summary>
        /// <param name="methodInfo">The method</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetMethodReturnComment(MethodInfo methodInfo, bool humanize = true)
        {
            return GetMethodComment(methodInfo, "returns", humanize);
        }
        /// <summary>
        /// Read the comment of a parameter
        /// </summary>
        /// <param name="parameterInfo">The parameter</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetParameterComment(ParameterInfo parameterInfo, bool humanize = true)
        {
            if (!(parameterInfo.Member is MethodInfo methodInfo)) return string.Empty;

            var methodMemberName = GetMemberNameForMethod(methodInfo);
            return GetComment(methodMemberName, $"param[@name='{parameterInfo.Name}']", humanize);
        }
        /// <summary>
        /// Read the comments of all parameters of a method
        /// </summary>
        /// <param name="methodInfo">The method</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public Dictionary<string, string> GetParameterComments(MethodInfo methodInfo, bool humanize = true)
        {
            var parameterInfos = methodInfo.GetParameters();
            Dictionary<string, string> dict = new Dictionary<string, string>();
            foreach (var parameterInfo in parameterInfos)
            {
                dict[parameterInfo.Name] = GetParameterComment(parameterInfo, humanize);
            }
            return dict;
        }
        /// <summary>
        /// Read the comment of the node with the given name
        /// </summary>
        /// <param name="name">Node name</param>
        /// <param name="xPath">Comment path</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetComment(string name, string xPath, bool humanize = true)
        {
            foreach (var _xmlNavigator in navigators)
            {
                var typeSummaryNode = _xmlNavigator.SelectSingleNode($"/doc/members/member[@name='{name}']/{xPath.Trim('/', '\\')}");

                if (typeSummaryNode != null)
                {
                    return humanize ? Humanize(typeSummaryNode.InnerXml) : typeSummaryNode.InnerXml;
                }
            }

            return string.Empty;
        }
        /// <summary>
        /// Read the summary comment of the given node
        /// </summary>
        /// <param name="name">Node name</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetSummary(string name, bool humanize = true)
        {
            return GetComment(name, "summary", humanize);
        }
        /// <summary>
        /// Read the example comment of the given node
        /// </summary>
        /// <param name="name">Node name</param>
        /// <param name="humanize">Readability optimization (e.g. strip XML tags)</param>
        /// <returns></returns>
        public string GetExample(string name, bool humanize = true)
        {
            return GetComment(name, "example", humanize);
        }
        /// <summary>
        /// Get the member node name of a method
        /// </summary>
        /// <param name="method"></param>
        /// <returns></returns>
        public string GetMemberNameForMethod(MethodInfo method)
        {
            var builder = new StringBuilder("M:");

            builder.Append(QualifiedNameFor(method.DeclaringType));
            builder.Append($".{method.Name}");

            var parameters = method.GetParameters();
            if (parameters.Any())
            {
                var parametersNames = parameters.Select(p =>
                {
                    return p.ParameterType.IsGenericParameter
                        ? $"`{p.ParameterType.GenericParameterPosition}"
                        : QualifiedNameFor(p.ParameterType, expandGenericArgs: true);
                });
                builder.Append($"({string.Join(",", parametersNames)})");
            }

            return builder.ToString();
        }
        /// <summary>
        /// Get the member node name of a type
        /// </summary>
        /// <param name="type"></param>
        /// <returns></returns>
        public string GetMemberNameForType(Type type)
        {
            var builder = new StringBuilder("T:");
            builder.Append(QualifiedNameFor(type));

            return builder.ToString();
        }
        /// <summary>
        /// Get the member node name of a field or property
        /// </summary>
        /// <param name="fieldOrPropertyInfo"></param>
        /// <returns></returns>
        public string GetMemberNameForFieldOrProperty(MemberInfo fieldOrPropertyInfo)
        {
            var builder = new StringBuilder((fieldOrPropertyInfo.MemberType & MemberTypes.Field) != 0 ? "F:" : "P:");
            builder.Append(QualifiedNameFor(fieldOrPropertyInfo.DeclaringType));
            builder.Append($".{fieldOrPropertyInfo.Name}");

            return builder.ToString();
        }

        private string QualifiedNameFor(Type type, bool expandGenericArgs = false)
        {
            if (type.IsArray)
                return $"{QualifiedNameFor(type.GetElementType(), expandGenericArgs)}[]";

            var builder = new StringBuilder();

            if (!string.IsNullOrEmpty(type.Namespace))
                builder.Append($"{type.Namespace}.");

            if (type.IsNested)
            {
                builder.Append($"{string.Join(".", GetNestedTypeNames(type))}.");
            }

            if (type.IsConstructedGenericType && expandGenericArgs)
            {
                var nameSansGenericArgs = type.Name.Split('`').First();
                builder.Append(nameSansGenericArgs);

                var genericArgsNames = type.GetGenericArguments().Select(t =>
                {
                    return t.IsGenericParameter
                        ? $"`{t.GenericParameterPosition}"
                        : QualifiedNameFor(t, true);
                });

                builder.Append($"{{{string.Join(",", genericArgsNames)}}}");
            }
            else
            {
                builder.Append(type.Name);
            }

            return builder.ToString();
        }
        private IEnumerable<string> GetNestedTypeNames(Type type)
        {
            if (!type.IsNested || type.DeclaringType == null) yield break;

            foreach (var nestedTypeName in GetNestedTypeNames(type.DeclaringType))
            {
                yield return nestedTypeName;
            }

            yield return type.DeclaringType.Name;
        }
        private string Humanize(string text)
        {
            if (text == null)
                throw new ArgumentNullException("text");

            // Call DecodeXml last so entities like < and > do not break valid XML
            text = NormalizeIndentation(text);
            text = HumanizeRefTags(text);
            text = HumanizeCodeTags(text);
            text = HumanizeParaTags(text);
            text = DecodeXml(text);
            return text;
        }
        private string NormalizeIndentation(string text)
        {
            string[] lines = text.Split('\n');
            string padding = GetCommonLeadingWhitespace(lines);

            int padLen = padding == null ? 0 : padding.Length;

            // remove leading padding from each line
            for (int i = 0, l = lines.Length; i < l; ++i)
            {
                string line = lines[i].TrimEnd('\r'); // remove trailing '\r'

                if (padLen != 0 && line.Length >= padLen && line.Substring(0, padLen) == padding)
                    line = line.Substring(padLen);

                lines[i] = line;
            }

            // remove leading empty lines, but not all leading padding
            // remove all trailing whitespace, regardless
            return string.Join("\r\n", lines.SkipWhile(x => string.IsNullOrWhiteSpace(x))).TrimEnd();
        }
        private string GetCommonLeadingWhitespace(string[] lines)
        {
            if (null == lines)
                throw new ArgumentException("lines");

            if (lines.Length == 0)
                return null;

            string[] nonEmptyLines = lines
                .Where(x => !string.IsNullOrWhiteSpace(x))
                .ToArray();

            if (nonEmptyLines.Length < 1)
                return null;

            int padLen = 0;

            // use the first line as a seed, and see what is shared over all nonEmptyLines
            string seed = nonEmptyLines[0];
            for (int i = 0, l = seed.Length; i < l; ++i)
            {
                if (!char.IsWhiteSpace(seed, i))
                    break;

                if (nonEmptyLines.Any(line => line[i] != seed[i]))
                    break;

                ++padLen;
            }

            if (padLen > 0)
                return seed.Substring(0, padLen);

            return null;
        }
        private string HumanizeRefTags(string text)
        {
            return RefTagPattern.Replace(text, (match) => match.Groups["display"].Value);
        }
        private string HumanizeCodeTags(string text)
        {
            return CodeTagPattern.Replace(text, (match) => "{" + match.Groups["display"].Value + "}");
        }
        private string HumanizeParaTags(string text)
        {
            return ParaTagPattern.Replace(text, (match) => "<br>" + match.Groups["display"].Value);
        }
        private string DecodeXml(string text)
        {
            return System.Net.WebUtility.HtmlDecode(text);
        }
    }
}

21  src/AntSK.LLamaFactory/AntSK.LLamaFactory.csproj  Normal file
@@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>
  <ItemGroup>
    <Content Include="llamafactory\**">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>
  <ItemGroup>
    <None Update="modelList.json">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="requirements.txt">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>
</Project>

27  src/AntSK.LLamaFactory/Model/LLamaFactoryModel.cs  Normal file
@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AntSK.LLamaFactory.Model
{
    public class ModelInfo
    {
        public string DEFAULT { get; set; }
        public string MODELSCOPE { get; set; }
    }


    public class LLamaFactoryModel
    {
        public Dictionary<string, ModelInfo> Models { get; set; }
        public string Template { get; set; }
    }

    public class LLamaModel
    {
        public string Name { get; set; }
        public string ModelScope { get; set; }
    }
}
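
For orientation, these classes appear to map entries of the `modelList.json` file shipped by the csproj above. A hypothetical entry, sketched as a Python literal (key names follow the properties above; the concrete values are illustrative only, not taken from the repo):

    # Hypothetical example only -- real entries live in modelList.json.
    entry = {
        "Models": {
            "Qwen1.5-0.5B-Chat": {
                "DEFAULT": "Qwen/Qwen1.5-0.5B-Chat",     # e.g. a Hugging Face repo id
                "MODELSCOPE": "qwen/Qwen1.5-0.5B-Chat",  # e.g. a ModelScope repo id
            }
        },
        "Template": "qwen",
    }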

0  src/AntSK.LLamaFactory/llamafactory/__init__.py  Normal file

16  src/AntSK.LLamaFactory/llamafactory/api_demo.py  Normal file
@@ -0,0 +1,16 @@
import os

import uvicorn

from llmtuner import ChatModel, create_app


def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)


if __name__ == "__main__":
    main()
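
`ChatModel()` parses its arguments from the command line, so the script above can also be driven programmatically. A minimal launch sketch, assuming the `llmtuner` dependencies are installed and that `/path/to/model` is replaced with a real local checkpoint (both are placeholders here):

    import os
    import sys

    # Hypothetical values; substitute a real checkpoint and template.
    os.environ["API_PORT"] = "8000"
    sys.argv += ["--model_name_or_path", "/path/to/model", "--template", "default"]

    from api_demo import main  # the script above

    main()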

27  src/AntSK.LLamaFactory/llamafactory/api_start.py  Normal file
@@ -0,0 +1,27 @@
import subprocess
import shlex
import os


class Start(object):

    def __init__(self, model_name_or_path):
        self.model_name_or_path = model_name_or_path

    def StartCommand(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        os.environ['API_PORT'] = '8000'
        # Build the command to execute (use the configured model path; the
        # original hard-coded a local dev path and ignored the constructor argument)
        command = (
            'python api_demo.py'
            ' --model_name_or_path ' + self.model_name_or_path +
            ' --template default'
        )

        # Use shlex.split() to split the command string safely
        command = shlex.split(command)

        # Run the command; pass the argument list directly (shell=True with a
        # list would only execute the first element on POSIX)
        subprocess.run(command)


if __name__ == "__main__":
    start = Start('model_name_or_path')
    start.StartCommand()

49  src/AntSK.LLamaFactory/llamafactory/cli_demo.py  Normal file
@@ -0,0 +1,49 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc


try:
    import platform

    if platform.system() != "Windows":
        import readline  # noqa: F401
except ImportError:
    print("Install `readline` for a better experience.")


def main():
    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()

10  src/AntSK.LLamaFactory/llamafactory/evaluate.py  Normal file
@@ -0,0 +1,10 @@
from llmtuner import Evaluator


def main():
    evaluator = Evaluator()
    evaluator.eval()


if __name__ == "__main__":
    main()

9  src/AntSK.LLamaFactory/llamafactory/export_model.py  Normal file
@@ -0,0 +1,9 @@
from llmtuner import export_model


def main():
    export_model()


if __name__ == "__main__":
    main()

11  src/AntSK.LLamaFactory/llamafactory/llmtuner/__init__.py  Normal file
@@ -0,0 +1,11 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams

from .api import create_app
from .chat import ChatModel
from .eval import Evaluator
from .train import export_model, run_exp
from .webui import create_ui, create_web_demo


__version__ = "0.5.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
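
A minimal sketch of using this exported surface directly (the model path and template are placeholders; `ChatModel` accepts an argument dict in lieu of CLI flags, per `get_infer_args` used further below):

    from llmtuner import ChatModel

    chat_model = ChatModel({"model_name_or_path": "/path/to/model", "template": "default"})
    reply = chat_model.chat([{"role": "user", "content": "Hello"}])[0]
    print(reply.response_text)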
@@ -0,0 +1,4 @@
from .app import create_app


__all__ = ["create_app"]

224  src/AntSK.LLamaFactory/llamafactory/llmtuner/api/app.py  Normal file
@@ -0,0 +1,224 @@
import json
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, Sequence

from pydantic import BaseModel

from ..chat import ChatModel
from ..data import Role as DataRole
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from .protocol import (
    ChatCompletionMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionResponseUsage,
    ChatCompletionStreamResponse,
    Finish,
    Function,
    FunctionCall,
    ModelCard,
    ModelList,
    Role,
    ScoreEvaluationRequest,
    ScoreEvaluationResponse,
)


if is_fastapi_availble():
    from fastapi import FastAPI, HTTPException, status
    from fastapi.middleware.cors import CORSMiddleware


if is_starlette_available():
    from sse_starlette import EventSourceResponse


if is_uvicorn_available():
    import uvicorn


@asynccontextmanager
async def lifespan(app: "FastAPI"):  # collects GPU memory
    yield
    torch_gc()


def dictify(data: "BaseModel") -> Dict[str, Any]:
    try:  # pydantic v2
        return data.model_dump(exclude_unset=True)
    except AttributeError:  # pydantic v1
        return data.dict(exclude_unset=True)


def jsonify(data: "BaseModel") -> str:
    try:  # pydantic v2
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except AttributeError:  # pydantic v1
        return data.json(exclude_unset=True, ensure_ascii=False)


def create_app(chat_model: "ChatModel") -> "FastAPI":
    app = FastAPI(lifespan=lifespan)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    role_mapping = {
        Role.USER: DataRole.USER.value,
        Role.ASSISTANT: DataRole.ASSISTANT.value,
        Role.SYSTEM: DataRole.SYSTEM.value,
        Role.FUNCTION: DataRole.FUNCTION.value,
        Role.TOOL: DataRole.OBSERVATION.value,
    }

    @app.get("/v1/models", response_model=ModelList)
    async def list_models():
        model_card = ModelCard(id="gpt-3.5-turbo")
        return ModelList(data=[model_card])

    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
    async def create_chat_completion(request: ChatCompletionRequest):
        if not chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if len(request.messages) == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

        if request.messages[0].role == Role.SYSTEM:
            system = request.messages.pop(0).content
        else:
            system = ""

        if len(request.messages) % 2 == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

        input_messages = []
        for i, message in enumerate(request.messages):
            if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
            elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

            input_messages.append({"role": role_mapping[message.role], "content": message.content})

        tool_list = request.tools
        if isinstance(tool_list, list) and len(tool_list):
            try:
                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
            except Exception:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
        else:
            tools = ""

        if request.stream:
            if tools:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

            generate = stream_chat_completion(input_messages, system, tools, request)
            return EventSourceResponse(generate, media_type="text/event-stream")

        responses = await chat_model.achat(
            input_messages,
            system,
            tools,
            do_sample=request.do_sample,
            temperature=request.temperature,
            top_p=request.top_p,
            max_new_tokens=request.max_tokens,
            num_return_sequences=request.n,
        )

        prompt_length, response_length = 0, 0
        choices = []
        for i, response in enumerate(responses):
            if tools:
                result = chat_model.engine.template.format_tools.extract(response.response_text)
            else:
                result = response.response_text

            if isinstance(result, tuple):
                name, arguments = result
                function = Function(name=name, arguments=arguments)
                response_message = ChatCompletionMessage(
                    role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
                )
                finish_reason = Finish.TOOL
            else:
                response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
                finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH

            choices.append(
                ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
            )
            prompt_length = response.prompt_length
            response_length += response.response_length

        usage = ChatCompletionResponseUsage(
            prompt_tokens=prompt_length,
            completion_tokens=response_length,
            total_tokens=prompt_length + response_length,
        )

        return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

    async def stream_chat_completion(
        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
    ):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
        )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield jsonify(chunk)

        async for new_token in chat_model.astream_chat(
            messages,
            system,
            tools,
            do_sample=request.do_sample,
            temperature=request.temperature,
            top_p=request.top_p,
            max_new_tokens=request.max_tokens,
        ):
            if len(new_token) == 0:
                continue

            choice_data = ChatCompletionResponseStreamChoice(
                index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
            yield jsonify(chunk)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
        )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield jsonify(chunk)
        yield "[DONE]"

    @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        if chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if len(request.messages) == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

        scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
        return ScoreEvaluationResponse(model=request.model, scores=scores)

    return app


if __name__ == "__main__":
    chat_model = ChatModel()
    app = create_app(chat_model)
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
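
A client-side sketch for the `/v1/chat/completions` endpoint above, assuming the server is running on localhost:8000 and the `requests` package is installed (the server does not validate the model name):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "gpt-3.5-turbo",  # arbitrary for this server
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])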

116  src/AntSK.LLamaFactory/llamafactory/llmtuner/api/protocol.py  Normal file
@@ -0,0 +1,116 @@
import time
from enum import Enum, unique
from typing import List, Optional

from pydantic import BaseModel, Field
from typing_extensions import Literal


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"


@unique
class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"


class ModelCard(BaseModel):
    id: str
    object: Literal["model"] = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"


class ModelList(BaseModel):
    object: Literal["list"] = "list"
    data: List[ModelCard] = []


class Function(BaseModel):
    name: str
    arguments: str


class FunctionCall(BaseModel):
    id: Literal["call_default"] = "call_default"
    type: Literal["function"] = "function"
    function: Function


class ChatMessage(BaseModel):
    role: Role
    content: str


class ChatCompletionMessage(BaseModel):
    role: Optional[Role] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    tools: list = []
    do_sample: bool = True
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatCompletionMessage
    finish_reason: Finish


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: ChatCompletionMessage
    finish_reason: Optional[Finish] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: Literal["chatcmpl-default"] = "chatcmpl-default"
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


class ChatCompletionStreamResponse(BaseModel):
    id: Literal["chatcmpl-default"] = "chatcmpl-default"
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


class ScoreEvaluationRequest(BaseModel):
    model: str
    messages: List[str]
    max_length: Optional[int] = None


class ScoreEvaluationResponse(BaseModel):
    id: Literal["scoreeval-default"] = "scoreeval-default"
    object: Literal["score.evaluation"] = "score.evaluation"
    model: str
    scores: List[float]
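
A minimal sketch exercising the request model above; the `hasattr` branch mirrors the pydantic v1/v2 compatibility pattern used in `app.py`:

    request = ChatCompletionRequest(
        model="gpt-3.5-turbo",
        messages=[ChatMessage(role=Role.USER, content="Hi")],
    )
    payload = request.model_dump() if hasattr(request, "model_dump") else request.dict()
    print(payload["messages"][0]["content"])  # -> "Hi"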
@@ -0,0 +1,5 @@
from .base_engine import BaseEngine
from .chat_model import ChatModel


__all__ = ["BaseEngine", "ChatModel"]
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

from ..extras.packages import is_vllm_available

if is_vllm_available():
    from vllm import AsyncLLMEngine


@dataclass
class Response:
    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class BaseEngine(ABC):
    model: Union["PreTrainedModel", "AsyncLLMEngine"]
    tokenizer: "PreTrainedTokenizer"
    can_generate: bool
    template: "Template"
    generating_args: Dict[str, Any]

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None: ...

    @abstractmethod
    async def start(
        self,
    ) -> None: ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]: ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]: ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]: ...
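
To make the contract concrete, here is a do-nothing stand-in that satisfies the abstract interface (a hypothetical sketch for illustration, not one of the real engines below):

    class EchoEngine(BaseEngine):
        def __init__(self, *args) -> None:
            self.can_generate = True

        async def start(self) -> None:
            pass

        async def chat(self, messages, system=None, tools=None, **input_kwargs):
            # Echo the last user message back as a single Response.
            return [Response(messages[-1]["content"], 0, 0, "stop")]

        async def stream_chat(self, messages, system=None, tools=None, **input_kwargs):
            yield messages[-1]["content"]

        async def get_scores(self, batch_input, **input_kwargs):
            return [0.0] * len(batch_input)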
@@ -0,0 +1,91 @@
import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()
        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        return await self.engine.chat(messages, system, tools, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        generator = self.astream_chat(messages, system, tools, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        return await self.engine.get_scores(batch_input, **input_kwargs)
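
The synchronous wrappers above bridge into the background event loop, so `ChatModel` can be used from plain blocking code. A minimal usage sketch (model path and template are placeholders):

    chat_model = ChatModel({"model_name_or_path": "/path/to/model", "template": "default"})
    for token in chat_model.stream_chat([{"role": "user", "content": "Tell me a joke"}]):
        print(token, end="", flush=True)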

263  src/AntSK.LLamaFactory/llamafactory/llmtuner/chat/hf_engine.py  Normal file
@@ -0,0 +1,263 @@
import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple

import torch
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


class HuggingfaceEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        self.model, self.tokenizer = load_model_and_tokenizer(
            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
        self.generating_args = generating_args.to_dict()

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)

        do_sample = input_kwargs.pop("do_sample", None)
        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature or generating_args["temperature"],
                top_p=top_p or generating_args["top_p"],
                top_k=top_k or generating_args["top_k"],
                num_return_sequences=num_return_sequences or 1,
                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
            generating_args["do_sample"] = True

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[float]:
        max_length = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

        if getattr(model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores

    async def start(self) -> None:
        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            input_kwargs,
        )
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            input_kwargs,
        )
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)
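
The semaphore-plus-executor idiom used throughout this engine makes blocking `model.generate` calls awaitable while capping concurrency. Reduced to its core as a standalone sketch:

    import asyncio
    import concurrent.futures

    _semaphore = asyncio.Semaphore(1)  # limit concurrent generations

    async def run_blocking(fn, *args):
        # Serialize access, then off-load the blocking call to a worker thread.
        loop = asyncio.get_running_loop()
        async with _semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, fn, *args)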

149  src/AntSK.LLamaFactory/llamafactory/llmtuner/chat/vllm_engine.py  Normal file
@@ -0,0 +1,149 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence

from transformers.utils.versions import require_version

from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams

if TYPE_CHECKING:
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
        self.can_generate = finetuning_args.stage == "sft"
        engine_args = AsyncEngineArgs(
            model=model_args.model_name_or_path,
            trust_remote_code=True,
            max_model_len=model_args.vllm_maxlen,
            tensor_parallel_size=get_device_count() or 1,
            gpu_memory_utilization=model_args.vllm_gpu_util,
            disable_log_stats=True,
            disable_log_requests=True,
            enforce_eager=model_args.vllm_enforce_eager,
        )
        self.model = AsyncLLMEngine.from_engine_args(engine_args)
        self.tokenizer = load_tokenizer(model_args)
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
        self.generating_args = generating_args.to_dict()

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_length = len(prompt_ids)

        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

        generating_args = self.generating_args.copy()
        generating_args.update(
            dict(
                temperature=temperature or generating_args["temperature"],
                top_p=top_p or generating_args["top_p"],
                top_k=top_k or generating_args["top_k"],
                num_return_sequences=num_return_sequences or 1,
                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
            )
        )

        if max_length:
            generating_args["max_new_tokens"] = max_length - prompt_length

        if max_new_tokens:
            generating_args["max_new_tokens"] = max_new_tokens

        sampling_params = SamplingParams(
            n=generating_args["num_return_sequences"],
            repetition_penalty=generating_args["repetition_penalty"],
            temperature=generating_args["temperature"],
            top_p=generating_args["top_p"],
            top_k=generating_args["top_k"],
            use_beam_search=generating_args["num_beams"] > 1,
            length_penalty=generating_args["length_penalty"],
            stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            max_tokens=generating_args["max_new_tokens"],
            skip_special_tokens=True,
        )
        result_generator = self.model.generate(
            prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
        )
        return result_generator

    async def start(self) -> None:
        pass

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text):]
            generated_text = result.outputs[0].text
            yield delta_text

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        raise NotImplementedError("vLLM engine does not support get_scores.")
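
vLLM reports cumulative text per step, so `stream_chat` above derives per-step deltas by slicing off what was already emitted. The same computation in isolation:

    generated = ""
    for cumulative in ["He", "Hell", "Hello"]:  # stand-in for result.outputs[0].text
        delta = cumulative[len(generated):]
        generated = cumulative
        print(delta, end="")  # prints "Hello" one fragment at a time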
@@ -0,0 +1,6 @@
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset


__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

133  src/AntSK.LLamaFactory/llamafactory/llmtuner/data/aligner.py  Normal file
@@ -0,0 +1,133 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Features

from .utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

    from ..hparams import DataArguments
    from .parser import DatasetAttr


def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})

        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
            response = [
                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
            ]
        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
        else:
            response = []

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["tools"].append("")

    return outputs


def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    for i, messages in enumerate(examples[dataset_attr.messages]):
        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
            system = messages[0][dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = examples[dataset_attr.system][i] if dataset_attr.system else ""

        messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
        if len(messages) == 0:
            continue

        aligned_messages = []
        for turn_idx, message in enumerate(messages):
            if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                raise ValueError("Invalid role tag in {}.".format(messages))

            aligned_messages.append(
                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
            )

        outputs["prompt"].append(aligned_messages[:-1])
        outputs["response"].append(aligned_messages[-1:])
        outputs["system"].append(system)
        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")

    return outputs


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        system: "..."
        tools: "..."
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)

    column_names = list(next(iter(dataset)).keys())
    features = Features.from_dict(
        {
            "prompt": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "tools": {"dtype": "string", "_type": "Value"},
        }
    )
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=True,
        remove_columns=column_names,
        features=features,
        **kwargs,
    )
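
To make the alignment concrete, here is what `convert_alpaca` would produce for a toy alpaca-style row (column names here are hypothetical; the real ones come from `DatasetAttr`):

    # Input batch with one row:
    examples = {"instruction": ["Say hi"], "output": ["Hi!"]}
    # Aligned output (one list entry per row):
    # prompt   -> [[{"role": "user", "content": "Say hi"}]]
    # response -> [[{"role": "assistant", "content": "Hi!"}]]
    # system   -> [""]    tools -> [""]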

187  src/AntSK.LLamaFactory/llamafactory/llmtuner/data/formatter.py  Normal file
@@ -0,0 +1,187 @@
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union


SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]


JSON_FORMAT_PROMPT = (
    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)


TOOL_SYSTEM_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}]).\n"
    "Action Input: the input to the tool{format_prompt}.\n"
    "```\n"
)


def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    tool_text = ""
    tool_names = []
    for tool in tools:
        param_text = ""
        for name, param in tool["parameters"]["properties"].items():
            required = ", required" if name in tool["parameters"].get("required", []) else ""
            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
            items = (
                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
            )
            param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                name=name,
                type=param.get("type", ""),
                required=required,
                desc=param.get("description", ""),
                enum=enum,
                items=items,
            )

        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
            name=tool["name"], desc=tool.get("description", ""), args=param_text
        )
        tool_names.append(tool["name"])

    return TOOL_SYSTEM_PROMPT.format(
        tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
    )


def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
    action_match = re.search(regex, content)
    if not action_match:
        return content

    tool_name = action_match.group(1).strip()
    tool_input = action_match.group(2).strip().strip('"').strip("```")
    try:
        arguments = json.loads(tool_input)
    except json.JSONDecodeError:
        return content

    return tool_name, json.dumps(arguments, ensure_ascii=False)


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[Literal["default"]] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS: ...

    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
        raise NotImplementedError


@dataclass
class EmptyFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if has_placeholder:
            raise ValueError("Empty formatter should not contain any placeholder.")

    def apply(self, **kwargs) -> SLOTS:
        return self.slots


@dataclass
class StringFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if not has_placeholder:
            raise ValueError("A placeholder is required in the string formatter.")

    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError("Expected a string, got {}".format(value))

                    slot = slot.replace("{{" + name + "}}", value, 1)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements


@dataclass
class FunctionFormatter(Formatter):
    def __post_init__(self):
        has_name, has_args = False, False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if "{{name}}" in slot:
                has_name = True
            if "{{arguments}}" in slot:
                has_args = True

        if not has_name or not has_args:
            raise ValueError("Name and arguments placeholders are required in the function formatter.")

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            function = json.loads(content)
            name = function["name"]
            arguments = json.dumps(function["arguments"], ensure_ascii=False)
        except Exception:
            name, arguments = "", ""

        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements


@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
        if self.tool_format is None:
            raise ValueError("Tool format was not found.")

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            if not len(tools):
                return [""]

            if self.tool_format == "default":
                return [default_tool_formatter(tools)]
            else:
                raise NotImplementedError
        except Exception:
            return [""]

    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
        if self.tool_format == "default":
            return default_tool_extractor(content)
        else:
            raise NotImplementedError
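
A minimal sketch of the slot formatters above: placeholders of the form `{{name}}` are substituted once per keyword argument, and `apply` returns the slot list rather than a joined string:

    fmt = StringFormatter(slots=["Human: {{content}}\nAssistant:"])
    print(fmt.apply(content="What is AntSK?"))
    # -> ['Human: What is AntSK?\nAssistant:']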
170
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/loader.py
Normal file
170
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/loader.py
Normal file
@@ -0,0 +1,170 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union

from datasets import load_dataset, load_from_disk

from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum, merge_dataset


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments, ModelArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]:
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "file":
        data_files = []
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
                if data_path is None:
                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
                    raise ValueError("File types should be identical.")
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File not found.")

        if data_path is None:
            raise ValueError("File extension must be txt, csv, json or jsonl.")

        checksum(data_files, dataset_attr.file_sha1)
    else:
        raise NotImplementedError

    if dataset_attr.load_from == "ms_hub":
        try:
            from modelscope import MsDataset
            from modelscope.utils.config_ds import MS_DATASETS_CACHE

            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
            dataset = MsDataset.load(
                dataset_name=data_path,
                subset_name=data_name,
                data_dir=data_dir,
                data_files=data_files,
                split=data_args.split,
                cache_dir=cache_dir,
                token=model_args.ms_hub_token,
                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            ).to_hf_dataset()
        except ImportError:
            raise ImportError("Please install modelscope via `pip install modelscope -U`")
    else:
        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
            kwargs = {"trust_remote_code": True}
        else:
            kwargs = {}

        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=data_args.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            **kwargs,
        )

    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

    if data_args.max_samples is not None:  # truncate dataset
        num_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(num_samples))

    return align_dataset(dataset, dataset_attr, data_args)


def get_dataset(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"],
    # split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    # Load from cache
    if data_args.cache_path is not None:
        if os.path.exists(data_args.cache_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.cache_path)
            if data_args.streaming:
                dataset = dataset.to_iterable_dataset()
            return dataset

        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    with training_args.main_process_first(desc="load dataset"):
        all_datasets = []
        for dataset_attr in get_dataset_list(data_args):
            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
        dataset = merge_dataset(all_datasets, data_args, training_args)

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func, print_function = get_preprocess_and_print_func(
            tokenizer, template, data_args, training_args, stage
        )
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset",
            )

        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
            if training_args.should_save:
                dataset.save_to_disk(data_args.cache_path)
                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

        return dataset
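
A minimal sketch of the "file" branch in load_single_dataset above: for a single local JSON file it reduces to one load_dataset call. This assumes FILEEXT2TYPE (defined in extras/constants.py, not shown in this hunk) maps the "json" extension to the datasets builder of the same name; the file path is hypothetical.

from datasets import load_dataset

# What the "file" branch boils down to for data/alpaca_demo.json (hypothetical path):
dataset = load_dataset("json", data_files=["data/alpaca_demo.json"], split="train")
print(dataset[0])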
119
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/parser.py
Normal file
@@ -0,0 +1,119 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope


if TYPE_CHECKING:
    from ..hparams import DataArguments


@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.
    """

    """ basic configs """
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    """ extra configs """
    file_sha1: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: bool = False
    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    """ columns """
    system: Optional[str] = None
    """ columns for the alpaca format """
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    """ columns for the sharegpt format """
    messages: Optional[str] = "conversations"
    tools: Optional[str] = None
    """ tags for the sharegpt format """
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))


def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
    dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
    try:
        with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
            dataset_info = json.load(f)
    except Exception as err:
        if data_args.dataset is not None:
            raise ValueError(
                "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
            )
        dataset_info = None

    if data_args.interleave_probs is not None:
        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]

    dataset_list: List[DatasetAttr] = []
    for name in dataset_names:
        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]

        if has_hf_url or has_ms_url:
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

        dataset_attr.set_attr("file_sha1", dataset_info[name])
        dataset_attr.set_attr("subset", dataset_info[name])
        dataset_attr.set_attr("folder", dataset_info[name])
        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")

        if "columns" in dataset_info[name]:
            column_names = ["system"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:
                column_names.extend(["messages", "tools"])

            for column_name in column_names:
                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            for tag in tag_names:
                dataset_attr.set_attr(tag, dataset_info[name]["tags"])

        dataset_list.append(dataset_attr)

    return dataset_list
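
A hypothetical dataset_info.json entry that exercises the parsing above: with no hub URLs present, get_dataset_list takes the "file" branch and copies the remaining keys onto the DatasetAttr via set_attr. Every value below is illustrative, not taken from the repository.

dataset_info = {
    "alpaca_demo": {
        "file_name": "alpaca_demo.json",
        "file_sha1": "0123456789abcdef0123456789abcdef01234567",
        "formatting": "alpaca",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
}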
276
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/preprocess.py
Normal file
@@ -0,0 +1,276 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple

from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments
    from .template import Template


logger = get_logger(__name__)


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
    if not data_args.packing:
        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)

    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
    block_size = data_args.cutoff_len
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
    total_length = (total_length // block_size) * block_size
    # split by chunks of cutoff_len
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    if data_args.template == "gemma":
        for i in range(len(result["input_ids"])):
            result["input_ids"][i][0] = tokenizer.bos_token_id

    return result
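
# Worked example of the blocking arithmetic above: 2,100 concatenated tokens with
# cutoff_len=1024 give two full blocks, and the 52-token remainder is dropped:
#     total_length = (2100 // 1024) * 1024   # -> 2048
#     range(0, 2048, 1024)                   # -> blocks [0:1024] and [1024:2048]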


def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            continue

        messages = examples["prompt"][i] + examples["response"][i]
        input_ids, labels = [], []
        for turn_idx, (source_ids, target_ids) in enumerate(
            template.encode_multiturn(
                tokenizer,
                messages,
                examples["system"][i],
                examples["tools"][i],
                data_args.cutoff_len,
                data_args.reserved_label_len,
            )
        ):
            if data_args.train_on_prompt:
                source_mask = source_ids
            elif turn_idx != 0 and template.efficient_eos:
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
            else:
                source_mask = [IGNORE_INDEX] * len(source_ids)

            input_ids += source_ids + target_ids
            labels += source_mask + target_ids

        if template.efficient_eos:
            input_ids += [tokenizer.eos_token_id]
            labels += [tokenizer.eos_token_id]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)

    return model_inputs


def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    input_ids, labels = [], []
    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            continue

        messages = examples["prompt"][i] + examples["response"][i]
        for source_ids, target_ids in template.encode_multiturn(
            tokenizer, messages, examples["system"][i], examples["tools"][i]
        ):
            if data_args.train_on_prompt:
                source_mask = source_ids
            elif len(input_ids) != 0 and template.efficient_eos:
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
            else:
                source_mask = [IGNORE_INDEX] * len(source_ids)

            input_ids += source_ids + target_ids
            labels += source_mask + target_ids

        if template.efficient_eos:
            input_ids += [tokenizer.eos_token_id]
            labels += [tokenizer.eos_token_id]

    total_length = len(input_ids)
    block_size = data_args.cutoff_len
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
    total_length = (total_length // block_size) * block_size
    # split by chunks of cutoff_len
    for i in range(0, total_length, block_size):
        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
            model_inputs["input_ids"].append(input_ids[i : i + block_size])
            model_inputs["attention_mask"].append([1] * block_size)
            model_inputs["labels"].append(labels[i : i + block_size])

    return model_inputs


def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1:
            continue

        if len(examples["response"][i]) == 1:
            messages = examples["prompt"][i] + examples["response"][i]
        else:
            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]

        input_ids, labels = template.encode_oneturn(
            tokenizer,
            messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )

        if template.efficient_eos:
            labels += [tokenizer.eos_token_id]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)

    return model_inputs


def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
            continue

        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
        prompt_ids, chosen_ids = template.encode_oneturn(
            tokenizer,
            chosen_messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )
        _, rejected_ids = template.encode_oneturn(
            tokenizer,
            rejected_messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )

        if template.efficient_eos:
            chosen_ids += [tokenizer.eos_token_id]
            rejected_ids += [tokenizer.eos_token_id]

        model_inputs["prompt_ids"].append(prompt_ids)
        model_inputs["chosen_ids"].append(chosen_ids)
        model_inputs["rejected_ids"].append(rejected_ids)

    return model_inputs


def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(
        "labels:\n{}".format(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
        )
    )


def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("prompt_ids:\n{}".format(example["prompt_ids"]))
    print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
    print("chosen_ids:\n{}".format(example["chosen_ids"]))
    print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
    print("rejected_ids:\n{}".format(example["rejected_ids"]))
    print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))


def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))


def get_preprocess_and_print_func(
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            preprocess_func = partial(
                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
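
A small sketch of the prompt masking these preprocessors share: source tokens become IGNORE_INDEX so the loss covers only the response. The token ids are made up, and -100 is assumed here because it is the usual Hugging Face ignore index; the real constant lives in extras/constants.py.

IGNORE_INDEX = -100                                    # assumed value
source_ids, target_ids = [101, 102, 103], [201, 202]   # hypothetical prompt/response tokens
input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids
print(input_ids)  # [101, 102, 103, 201, 202]
print(labels)     # [-100, -100, -100, 201, 202]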
773
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/template.py
Normal file
@@ -0,0 +1,773 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

from ..extras.logging import get_logger
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from .formatter import SLOTS, Formatter


logger = get_logger(__name__)


@dataclass
class Template:
    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    default_system: str
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids += query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query    resp
        Turn t: sep + query       resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)

    def _convert_elements_to_ids(
        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
    ) -> List[int]:
        r"""
        Converts elements to token ids.
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
            elif isinstance(elem, dict):
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))

        return token_ids

    def _make_pairs(
        self,
        encoded_messages: Sequence[List[int]],
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        encoded_pairs = []
        total_length = 0
        for i in range(0, len(encoded_messages), 2):
            if total_length >= cutoff_len:
                break

            max_source_len, max_target_len = infer_max_len(
                source_len=len(encoded_messages[i]),
                target_len=len(encoded_messages[i + 1]),
                max_len=(cutoff_len - total_length),
                reserved_label_len=reserved_label_len,
            )
            source_ids = encoded_messages[i][:max_source_len]
            target_ids = encoded_messages[i + 1][:max_target_len]
            total_length += len(source_ids) + len(target_ids)
            encoded_pairs.append((source_ids, target_ids))

        return encoded_pairs


@dataclass
class Llama2Template(Template):
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query    resp
        Turn t: sep + query       resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            system_text = ""
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                system_text = self.format_system.apply(content=(system + tool_text))[0]
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=system_text + message["content"])
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)


templates: Dict[str, Template] = {}


def _register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
    format_system: Optional["Formatter"] = None,
    format_function: Optional["Formatter"] = None,
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_separator: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: List[str] = [],
    efficient_eos: bool = False,
    replace_eos: bool = False,
    force_system: bool = False,
) -> None:
    r"""
    Registers a chat template.

    To add the following chat template:
    ```
    [HUMAN]:
    user prompt here
    [AI]:
    model response here

    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    ```

    The corresponding code should be:
    ```
    _register_template(
        name="custom",
        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
        format_separator=EmptyFormatter(slots=["\n\n"]),
        efficient_eos=True,
    )
    ```
    """
    eos_slots = [] if efficient_eos else [{"eos_token"}]
    template_class = Llama2Template if name.startswith("llama2") else Template
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
    default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
    default_tool_formatter = ToolFormatter(tool_format="default")
    default_separator_formatter = EmptyFormatter()
    templates[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
        format_system=format_system or default_user_formatter,
        format_function=format_function or default_function_formatter,
        format_observation=format_observation or format_user or default_user_formatter,
        format_tools=format_tools or default_tool_formatter,
        format_separator=format_separator or default_separator_formatter,
        default_system=default_system,
        stop_words=stop_words,
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        force_system=force_system,
    )


def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    is_added = tokenizer.eos_token_id is None
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

    if is_added:
        logger.info("Add eos token: {}".format(tokenizer.eos_token))
    else:
        logger.info("Replace eos token: {}".format(tokenizer.eos_token))

    if num_added_tokens > 0:
        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")


def _jinja_escape(content: str) -> str:
    return content.replace("\n", r"\n").replace("'", r"\'")


def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
    slot_items = []
    for slot in slots:
        if isinstance(slot, str):
            slot_pieces = slot.split("{{content}}")
            if slot_pieces[0]:
                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
            if len(slot_pieces) > 1:
                slot_items.append(placeholder)
                if slot_pieces[1]:
                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
        elif isinstance(slot, set):
            if "bos_token" in slot:
                slot_items.append("'" + tokenizer.bos_token + "'")
            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
                slot_items.append("'" + tokenizer.eos_token + "'")
        elif isinstance(slot, dict):
            raise ValueError("Dict is not supported.")

    return " + ".join(slot_items)


def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    jinja_template = ""

    if template.default_system:
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

    jinja_template += (
        "{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
    )

    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
    if isinstance(template, Llama2Template):
        pass
    elif template.force_system:
        jinja_template += "{{ " + system_message + " }}"
    else:
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

    jinja_template += "{% for message in messages %}"
    jinja_template += "{% set content = message['content'] %}"
    if isinstance(template, Llama2Template):
        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
        jinja_template += "{% endif %}"
    jinja_template += "{% if message['role'] == 'user' %}"
    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
    jinja_template += "{{ " + user_message + " }}"
    jinja_template += "{% elif message['role'] == 'assistant' %}"
    assistant_message = _convert_slots_to_jinja(
        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
    )
    jinja_template += "{{ " + assistant_message + " }}"
    jinja_template += "{% endif %}"
    jinja_template += "{% endfor %}"
    return jinja_template


def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
) -> Template:
    if name is None:
        template = templates["vanilla"]  # placeholder
    else:
        template = templates.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

    try:
        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
    except ValueError:
        logger.info("Cannot add this chat template to tokenizer.")

    return template


_register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
    ),
)


_register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_separator=EmptyFormatter(slots=["###"]),
    default_system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    stop_words=["</s>"],
    efficient_eos=True,
)


_register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)


_register_template(
    name="baichuan",
    format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
    efficient_eos=True,
)


_register_template(
    name="baichuan2",
    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
    efficient_eos=True,
)


_register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    force_system=True,
)


_register_template(
    name="bluelm",
    format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)


_register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    efficient_eos=True,
    force_system=True,
)


_register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    force_system=True,
)


_register_template(
    name="chatglm3_system",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(
        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
    ),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    default_system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


_register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
)


_register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
)


_register_template(
    name="codegeex2",
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
    default_system=(
        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
        "developed by Deepseek Company, and you only answer questions related to computer science. "
        "For politically sensitive questions, security and privacy issues, "
        "and other non-computer science questions, you will refuse to answer\n"
    ),
    stop_words=["<|EOT|>"],
    efficient_eos=True,
)


_register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
    format_system=StringFormatter(slots=["{{content}}\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
)


_register_template(
    name="falcon",
    format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    efficient_eos=True,
)


_register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
    efficient_eos=True,
    force_system=True,
)


_register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
    format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
    stop_words=["<eoa>"],
    efficient_eos=True,
)


_register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed "
        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
        "by the user such as English and 中文."
    ),
    stop_words=["<|im_end|>"],
    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
)


_register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    default_system=(
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
)


_register_template(
    name="llama2_zh",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)


_register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
)


_register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


_register_template(
    name="solar",
    format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
    format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
    efficient_eos=True,
)


_register_template(
    name="starchat",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
    force_system=True,
)


_register_template(
    name="vanilla",
)


_register_template(
    name="vicuna",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
)


_register_template(
    name="xuanyuan",
    format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
    default_system=(
        "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
    ),
)


_register_template(
    name="xverse",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)


_register_template(
    name="yayi",
    format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
    format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    stop_words=["<|End|>"],
)


_register_template(
    name="yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


_register_template(
    name="yuan",
    format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<eod>"],
    replace_eos=True,
)


_register_template(
    name="zephyr",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
    default_system="You are a friendly chatbot who always responds in the style of a pirate",
)


_register_template(
    name="ziya",
    format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
    format_separator=EmptyFormatter(slots=["\n"]),
)
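
A hedged usage sketch for get_template_and_fix_tokenizer with the "qwen" registration above; it assumes transformers is installed and llamafactory is importable, and the checkpoint name is only illustrative.

from transformers import AutoTokenizer
from llamafactory.llmtuner.data.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")  # hypothetical checkpoint
template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
print(template.stop_words)  # ['<|im_end|>'] per the registration above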
94
src/AntSK.LLamaFactory/llamafactory/llmtuner/data/utils.py
Normal file
@@ -0,0 +1,94 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from datasets import concatenate_datasets, interleave_datasets

from ..extras.logging import get_logger


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from llmtuner.hparams import DataArguments


logger = get_logger(__name__)


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
    if file_sha1 is None:
        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
        return

    if len(data_files) != 1:
        logger.warning("Checksum failed: too many files.")
        return

    with open(data_files[0], "rb") as f:
        sha1 = hashlib.sha1(f.read()).hexdigest()
        if sha1 != file_sha1:
            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))


def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len


def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=training_args.seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
        raise ValueError("Unknown mixing strategy.")


def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
    if training_args.do_train:
        if data_args.val_size > 1e-6:  # Split the dataset
            if data_args.streaming:
                val_set = dataset.take(int(data_args.val_size))
                train_set = dataset.skip(int(data_args.val_size))
                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
                return {"train_dataset": train_set, "eval_dataset": val_set}
            else:
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
                return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            if data_args.streaming:
                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
            return {"train_dataset": dataset}
    else:  # do_eval or do_predict
        return {"eval_dataset": dataset}
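
Note the val_size convention in split_dataset above: values greater than 1 are treated as absolute example counts, while values in (0, 1] are fractions of the dataset. A one-line restatement with illustrative numbers:

val_size = 1000  # held-out examples; use 0.1 for a 10% split instead
test_size = int(val_size) if val_size > 1 else val_size
print(test_size)  # 1000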
@@ -0,0 +1,4 @@
from .evaluator import Evaluator


__all__ = ["Evaluator"]
122
src/AntSK.LLamaFactory/llamafactory/llmtuner/eval/evaluator.py
Normal file
@@ -0,0 +1,122 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

import inspect
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import load_model_and_tokenizer
from .template import get_eval_template


class Evaluator:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = [
            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
        ]

    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
        logits = self.model(**batch_input).logits
        lengths = torch.sum(batch_input["attention_mask"], dim=-1)
        word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
        choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
def eval(self) -> None:
|
||||
mapping = cached_file(
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
filename="mapping.json",
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.hf_hub_token,
|
||||
)
|
||||
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
dataset = load_dataset(
|
||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
name=subject,
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
download_mode=self.eval_args.download_mode,
|
||||
token=self.model_args.hf_hub_token,
|
||||
**kwargs,
|
||||
)
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, outputs, labels = [], [], []
|
||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = (
|
||||
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
)
|
||||
messages = self.eval_template.format_example(
|
||||
target_data=dataset[self.data_args.split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
)
|
||||
|
||||
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
|
||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||
labels.append(messages[-1]["content"])
|
||||
|
||||
for i in trange(
|
||||
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
|
||||
):
|
||||
batch_input = self.tokenizer.pad(
|
||||
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
||||
).to(self.model.device)
|
||||
preds = self.batch_inference(batch_input)
|
||||
outputs += preds
|
||||
|
||||
corrects = np.array(outputs) == np.array(labels)
|
||||
category_name = categorys[subject]["category"]
|
||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
|
||||
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
for category_name, category_correct in category_corrects.items()
|
||||
if len(category_correct)
|
||||
]
|
||||
)
|
||||
print(score_info)
|
||||
if self.eval_args.save_dir is not None:
|
||||
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
||||
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
src/AntSK.LLamaFactory/llamafactory/llmtuner/eval/template.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple

from ..data import Role
from ..extras.constants import CHOICES


if TYPE_CHECKING:
    from datasets import Dataset


@dataclass
class EvalTemplate:
    system: str
    choice: str
    answer: str
    prefix: str

    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
    ) -> List[Dict[str, str]]:
        messages = []
        for k in range(len(support_set)):
            prompt, response = self.parse_example(support_set[k])
            messages.append({"role": Role.USER, "content": prompt})
            messages.append({"role": Role.ASSISTANT, "content": response})

        prompt, response = self.parse_example(target_data)
        messages.append({"role": Role.USER, "content": prompt})
        messages.append({"role": Role.ASSISTANT, "content": response})
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages


eval_templates: Dict[str, "EvalTemplate"] = {}


def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)


def get_eval_template(name: str) -> "EvalTemplate":
    eval_template = eval_templates.get(name, None)
    assert eval_template is not None, "Template {} does not exist.".format(name)
    return eval_template


register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer: ",
    prefix=" ",
)


register_eval_template(
    name="zh",
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
    prefix="\n",
)
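A quick sketch (not from the diff) of what the registered templates produce; the example item is hypothetical:

template = get_eval_template("en")
prompt, answer = template.parse_example(
    {"question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
)
# prompt == "2 + 2 = ?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer: "
# answer == "B"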
src/AntSK.LLamaFactory/llamafactory/llmtuner/extras/callbacks.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import json
import os
import time
from datetime import timedelta
from typing import TYPE_CHECKING

from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class LogCallback(TrainerCallback):
    def __init__(self, runner=None):
        self.runner = runner
        self.in_training = False
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""

    def timing(self):
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            self.in_training = True
            self.start_time = time.time()
            self.max_steps = state.max_steps
            if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
                logger.warning("Previous log file in this folder will be deleted.")
                os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if state.is_local_process_zero:
            self.in_training = False
            self.cur_steps = 0
            self.max_steps = 0

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if state.is_local_process_zero:
            self.cur_steps = state.global_step
            self.timing()
            if self.runner is not None and self.runner.aborted:
                control.should_epoch_stop = True
                control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_predict(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
    ):
        r"""
        Event called after a successful prediction.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if not state.is_local_process_zero:
            return

        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if self.runner is not None:
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
                )
            )

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.
        """
        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
            if self.max_steps == 0:
                self.max_steps = len(eval_dataloader)
            self.cur_steps += 1
            self.timing()
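A minimal sketch (not part of the diff) of how these callbacks plug into a Hugging Face trainer; `model`, `training_args`, and `train_dataset` are assumed to exist:

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # progress/ETA logging; FixValueHeadModelCallback only applies to value-head models
    callbacks=[LogCallback(), FixValueHeadModelCallback()],
)
trainer.train()  # progress records are appended to <output_dir>/trainer_log.jsonl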
src/AntSK.LLamaFactory/llamafactory/llmtuner/extras/constants.py (new file, 952 lines)
@@ -0,0 +1,952 @@
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional


CHOICES = ["A", "B", "C", "D"]

DATA_CONFIG = "dataset_info.json"

DEFAULT_MODULE = defaultdict(str)

DEFAULT_TEMPLATE = defaultdict(str)

FILEEXT2TYPE = {
    "arrow": "arrow",
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}

IGNORE_INDEX = -100

LAYERNORM_NAMES = {"norm", "ln"}

LOG_FILE_NAME = "trainer_log.jsonl"

METHODS = ["full", "freeze", "lora"]

PEFT_METHODS = ["lora"]

SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

SUPPORTED_MODELS = OrderedDict()

TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "Pre-Training": "pt",
}

V_HEAD_WEIGHTS_NAME = "value_head.bin"

V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"


class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"


def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]],
    module: Optional[str] = None,
    template: Optional[str] = None,
) -> None:
    prefix = None
    for name, path in models.items():
        if prefix is None:
            prefix = name.split("-")[0]
        else:
            assert prefix == name.split("-")[0], "prefix should be identical."
        SUPPORTED_MODELS[name] = path
    if module is not None:
        DEFAULT_MODULE[prefix] = module
    if template is not None:
        DEFAULT_TEMPLATE[prefix] = template


register_model_group(
    models={
        "Baichuan-7B-Base": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
            DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
        },
        "Baichuan-13B-Base": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
        },
        "Baichuan-13B-Chat": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
        },
    },
    module="W_pack",
    template="baichuan",
)


register_model_group(
    models={
        "Baichuan2-7B-Base": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
        },
        "Baichuan2-13B-Base": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
        },
        "Baichuan2-7B-Chat": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
        },
        "Baichuan2-13B-Chat": {
            DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
            DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
        },
    },
    module="W_pack",
    template="baichuan2",
)


register_model_group(
    models={
        "BLOOM-560M": {
            DownloadSource.DEFAULT: "bigscience/bloom-560m",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
        },
        "BLOOM-3B": {
            DownloadSource.DEFAULT: "bigscience/bloom-3b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
        },
        "BLOOM-7B1": {
            DownloadSource.DEFAULT: "bigscience/bloom-7b1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
        },
    },
    module="query_key_value",
)


register_model_group(
    models={
        "BLOOMZ-560M": {
            DownloadSource.DEFAULT: "bigscience/bloomz-560m",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
        },
        "BLOOMZ-3B": {
            DownloadSource.DEFAULT: "bigscience/bloomz-3b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
        },
        "BLOOMZ-7B1-mt": {
            DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
            DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
        },
    },
    module="query_key_value",
)


register_model_group(
    models={
        "BlueLM-7B-Base": {
            DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
            DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
        },
        "BlueLM-7B-Chat": {
            DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
            DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
        },
    },
    template="bluelm",
)


register_model_group(
    models={
        "ChatGLM2-6B-Chat": {
            DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
        }
    },
    module="query_key_value",
    template="chatglm2",
)


register_model_group(
    models={
        "ChatGLM3-6B-Base": {
            DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
        },
        "ChatGLM3-6B-Chat": {
            DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
        },
    },
    module="query_key_value",
    template="chatglm3",
)


register_model_group(
    models={
        "ChineseLLaMA2-1.3B": {
            DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
        },
        "ChineseLLaMA2-7B": {
            DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
        },
        "ChineseLLaMA2-13B": {
            DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
        },
        "ChineseLLaMA2-1.3B-Chat": {
            DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
        },
        "ChineseLLaMA2-7B-Chat": {
            DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
        },
        "ChineseLLaMA2-13B-Chat": {
            DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
        },
    },
    template="llama2_zh",
)


register_model_group(
    models={
        "DeepSeek-LLM-7B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
        },
        "DeepSeek-LLM-67B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
        },
        "DeepSeek-LLM-7B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
        },
        "DeepSeek-LLM-67B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
        },
        "DeepSeek-Math-7B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
        },
        "DeepSeek-Math-7B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
        },
        "DeepSeek-MoE-16B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
        },
        "DeepSeek-MoE-16B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
        },
    },
    template="deepseek",
)


register_model_group(
    models={
        "DeepSeekCoder-6.7B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
        },
        "DeepSeekCoder-7B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
        },
        "DeepSeekCoder-33B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
        },
        "DeepSeekCoder-6.7B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
        },
        "DeepSeekCoder-7B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
        },
        "DeepSeekCoder-33B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
        },
    },
    template="deepseekcoder",
)


register_model_group(
    models={
        "Falcon-7B": {
            DownloadSource.DEFAULT: "tiiuae/falcon-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
        },
        "Falcon-40B": {
            DownloadSource.DEFAULT: "tiiuae/falcon-40b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
        },
        "Falcon-180B": {
            DownloadSource.DEFAULT: "tiiuae/falcon-180b",
            DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
        },
        "Falcon-7B-Chat": {
            DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
        },
        "Falcon-40B-Chat": {
            DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
        },
        "Falcon-180B-Chat": {
            DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
            DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
        },
    },
    module="query_key_value",
    template="falcon",
)

register_model_group(
    models={
        "Gemma-2B": {
            DownloadSource.DEFAULT: "google/gemma-2b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
        },
        "Gemma-7B": {
            DownloadSource.DEFAULT: "google/gemma-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
        },
        "Gemma-2B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2b-it",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
        },
        "Gemma-7B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-7b-it",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
        },
    },
    template="gemma",
)

register_model_group(
    models={
        "InternLM-7B": {
            DownloadSource.DEFAULT: "internlm/internlm-7b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
        },
        "InternLM-20B": {
            DownloadSource.DEFAULT: "internlm/internlm-20b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
        },
        "InternLM-7B-Chat": {
            DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
        },
        "InternLM-20B-Chat": {
            DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
        },
    },
    template="intern",
)


register_model_group(
    models={
        "InternLM2-7B": {
            DownloadSource.DEFAULT: "internlm/internlm2-7b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
        },
        "InternLM2-20B": {
            DownloadSource.DEFAULT: "internlm/internlm2-20b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
        },
        "InternLM2-7B-Chat": {
            DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
        },
        "InternLM2-20B-Chat": {
            DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
        },
    },
    module="wqkv",
    template="intern2",
)


register_model_group(
    models={
        "LingoWhale-8B": {
            DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
            DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
        }
    },
    module="qkv_proj",
)


register_model_group(
    models={
        "LLaMA-7B": {
            DownloadSource.DEFAULT: "huggyllama/llama-7b",
            DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
        },
        "LLaMA-13B": {
            DownloadSource.DEFAULT: "huggyllama/llama-13b",
            DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
        },
        "LLaMA-30B": {
            DownloadSource.DEFAULT: "huggyllama/llama-30b",
            DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
        },
        "LLaMA-65B": {
            DownloadSource.DEFAULT: "huggyllama/llama-65b",
            DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
        },
    }
)


register_model_group(
    models={
        "LLaMA2-7B": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
        },
        "LLaMA2-13B": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
        },
        "LLaMA2-70B": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
        },
        "LLaMA2-7B-Chat": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
        },
        "LLaMA2-13B-Chat": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
        },
        "LLaMA2-70B-Chat": {
            DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
            DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
        },
    },
    template="llama2",
)


register_model_group(
    models={
        "Mistral-7B": {
            DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
        },
        "Mistral-7B-Chat": {
            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
        },
        "Mistral-7B-v0.2-Chat": {
            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
        },
    },
    template="mistral",
)


register_model_group(
    models={
        "Mixtral-8x7B": {
            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
        },
        "Mixtral-8x7B-Chat": {
            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
        },
    },
    template="mistral",
)


register_model_group(
    models={
        "OLMo-1B": {
            DownloadSource.DEFAULT: "allenai/OLMo-1B",
        },
        "OLMo-7B": {
            DownloadSource.DEFAULT: "allenai/OLMo-7B",
            DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
        },
        "OLMo-7B-Chat": {
            DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
        },
    },
    module="att_proj",
    template="olmo",
)


register_model_group(
    models={
        "OpenChat3.5-7B-Chat": {
            DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
            DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
        }
    },
    template="openchat",
)


register_model_group(
    models={
        "Orion-14B-Base": {
            DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
            DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
        },
        "Orion-14B-Chat": {
            DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
            DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
        },
        "Orion-14B-Long-Chat": {
            DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
            DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
        },
        "Orion-14B-RAG-Chat": {
            DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
            DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
        },
        "Orion-14B-Plugin-Chat": {
            DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
            DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
        },
    },
    template="orion",
)


register_model_group(
    models={
        "Phi-1.5-1.3B": {
            DownloadSource.DEFAULT: "microsoft/phi-1_5",
            DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
        },
        "Phi-2-2.7B": {
            DownloadSource.DEFAULT: "microsoft/phi-2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
        },
    }
)


register_model_group(
    models={
        "Qwen-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
        },
        "Qwen-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
        },
        "Qwen-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
        },
        "Qwen-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
        },
        "Qwen-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
        },
        "Qwen-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
        },
        "Qwen-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
        },
        "Qwen-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
        },
        "Qwen-1.8B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
        },
        "Qwen-1.8B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
        },
        "Qwen-7B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
        },
        "Qwen-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
        },
        "Qwen-14B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
        },
        "Qwen-14B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
        },
        "Qwen-72B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
        },
        "Qwen-72B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
        },
    },
    module="c_attn",
    template="qwen",
)


register_model_group(
    models={
        "Qwen1.5-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
        },
        "Qwen1.5-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
        },
        "Qwen1.5-4B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
        },
        "Qwen1.5-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
        },
        "Qwen1.5-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
        },
        "Qwen1.5-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
        },
        "Qwen1.5-0.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
        },
        "Qwen1.5-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
        },
        "Qwen1.5-4B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
        },
        "Qwen1.5-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
        },
        "Qwen1.5-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
        },
        "Qwen1.5-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
        },
        "Qwen1.5-0.5B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-0.5B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
        },
        "Qwen1.5-1.8B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-1.8B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
        },
        "Qwen1.5-4B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-4B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
        },
        "Qwen1.5-7B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
        },
        "Qwen1.5-14B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-14B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
        },
        "Qwen1.5-72B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-72B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
        },
    },
    template="qwen",
)


register_model_group(
    models={
        "SOLAR-10.7B": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
        },
        "SOLAR-10.7B-Chat": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
        },
    },
    template="solar",
)


register_model_group(
    models={
        "Skywork-13B-Base": {
            DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
            DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
        }
    }
)


register_model_group(
    models={
        "StarCoder2-3B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
        },
        "StarCoder2-7B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
        },
        "StarCoder2-15B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
        },
    }
)


register_model_group(
    models={
        "Vicuna1.5-7B-Chat": {
            DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
            DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
        },
        "Vicuna1.5-13B-Chat": {
            DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
            DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
        },
    },
    template="vicuna",
)


register_model_group(
    models={
        "XuanYuan-70B": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
        },
        "XuanYuan-70B-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
        },
        "XuanYuan-70B-int8-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
        },
        "XuanYuan-70B-int4-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
        },
    },
    template="xuanyuan",
)


register_model_group(
    models={
        "XVERSE-7B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
        },
        "XVERSE-13B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
        },
        "XVERSE-65B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
        },
        "XVERSE-65B-2": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
        },
        "XVERSE-7B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
        },
        "XVERSE-13B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
        },
        "XVERSE-65B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
        },
    },
    template="xverse",
)


register_model_group(
    models={
        "Yayi-7B": {
            DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
        },
        "Yayi-13B": {
            DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
        },
    },
    template="yayi",
)


register_model_group(
    models={
        "Yi-6B": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B",
        },
        "Yi-9B": {
            DownloadSource.DEFAULT: "01-ai/Yi-9B",
            DownloadSource.MODELSCOPE: "01ai/Yi-9B",
        },
        "Yi-34B": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B",
        },
        "Yi-6B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
        },
        "Yi-34B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
        },
        "Yi-6B-int8-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
        },
        "Yi-6B-int4-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
        },
        "Yi-34B-int8-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
        },
        "Yi-34B-int4-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
        },
    },
    template="yi",
)


register_model_group(
    models={
        "Yuan2-2B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
        },
        "Yuan2-51B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
        },
        "Yuan2-102B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
        },
    },
    template="yuan",
)


register_model_group(
    models={
        "Zephyr-7B-Alpha-Chat": {
            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
            DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
        },
        "Zephyr-7B-Beta-Chat": {
            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
            DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
        },
    },
    template="zephyr",
)


register_model_group(
    models={
        "Atom-7B": {
            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
        },
        "Atom-7B-Chat": {
            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
        },
    },
    template="atom",
)
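A short sketch (not part of the diff) of how this registry is consumed: `SUPPORTED_MODELS` maps a display name to per-hub repo ids, while `DEFAULT_MODULE` and `DEFAULT_TEMPLATE` are keyed by the name prefix:

paths = SUPPORTED_MODELS["Qwen1.5-7B-Chat"]
repo_id = paths[DownloadSource.DEFAULT]    # "Qwen/Qwen1.5-7B-Chat" on Hugging Face
prefix = "Qwen1.5-7B-Chat".split("-")[0]   # "Qwen1.5"
chat_template = DEFAULT_TEMPLATE[prefix]   # "qwen"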
src/AntSK.LLamaFactory/llamafactory/llmtuner/extras/logging.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import logging
import sys


class LoggerHandler(logging.Handler):
    r"""
    Logger handler used in Web UI.
    """

    def __init__(self):
        super().__init__()
        self.log = ""

    def reset(self):
        self.log = ""

    def emit(self, record):
        if record.name == "httpx":
            return
        log_entry = self.format(record)
        self.log += log_entry
        self.log += "\n\n"


def get_logger(name: str) -> logging.Logger:
    r"""
    Gets a standard logger with a stream handler to stdout.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)

    return logger


def reset_logging() -> None:
    r"""
    Removes basic config of root logger. (unused in script)
    """
    root = logging.getLogger()
    list(map(root.removeHandler, root.handlers))
    list(map(root.removeFilter, root.filters))
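A brief sketch (not from the diff) of the intended split: `get_logger` for ordinary stdout logging, `LoggerHandler` to mirror records into an in-memory buffer for the Web UI:

import logging

logger = get_logger(__name__)
ui_handler = LoggerHandler()
ui_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
logging.getLogger().addHandler(ui_handler)  # records propagate up to the root logger
logger.info("training started")             # now also appended to ui_handler.log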
src/AntSK.LLamaFactory/llamafactory/llmtuner/extras/misc.py (new file, 216 lines)
@@ -0,0 +1,216 @@
import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple

import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)
from transformers.utils.versions import require_version

from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger


_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
except Exception:
    _is_bf16_available = False


if TYPE_CHECKING:
    from trl import AutoModelForCausalLMWithValueHead

    from llmtuner.hparams import ModelArguments


logger = get_logger(__name__)


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def check_dependencies() -> None:
    if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
    else:
        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
        require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        from safetensors import safe_open
        from safetensors.torch import save_file

        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param

    os.remove(path_to_checkpoint)
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info("Value head model saved at: {}".format(output_dir))


def get_current_device() -> torch.device:
    r"""
    Gets the current available device.
    """
    if is_torch_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_mps_available():
        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"

    return torch.device(device)


def get_device_count() -> int:
    r"""
    Gets the number of available GPU devices.
    """
    if not torch.cuda.is_available():
        return 0

    return torch.cuda.device_count()


def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def try_download_model_from_ms(model_args: "ModelArguments") -> None:
    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
        return

    try:
        from modelscope import snapshot_download

        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        model_args.model_name_or_path = snapshot_download(
            model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
        )
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")


def use_modelscope() -> bool:
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
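For context, a minimal sketch (not part of the diff) combining the helpers above; `model` is any loaded `torch.nn.Module`:

trainable, total = count_parameters(model)
print("trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
    trainable, total, 100 * trainable / total))
compute_dtype = infer_optim_dtype(model_dtype=torch.bfloat16)  # falls back to fp16/fp32 if bf16 is unsupported
device = get_current_device()  # honors LOCAL_RANK for multi-GPU launches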
@@ -0,0 +1,61 @@
import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def _get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except Exception:
        return "0.0.0"


def is_fastapi_availble():
    return _is_package_available("fastapi")


def is_flash_attn2_available():
    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


def is_galore_available():
    return _is_package_available("galore_torch")


def is_jieba_available():
    return _is_package_available("jieba")


def is_matplotlib_available():
    return _is_package_available("matplotlib")


def is_nltk_available():
    return _is_package_available("nltk")


def is_requests_available():
    return _is_package_available("requests")


def is_rouge_available():
    return _is_package_available("rouge_chinese")


def is_starlette_available():
    return _is_package_available("sse_starlette")


def is_unsloth_available():
    return _is_package_available("unsloth")


def is_uvicorn_available():
    return _is_package_available("uvicorn")


def is_vllm_available():
    return _is_package_available("vllm")
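These probes exist to guard optional imports, so a missing extra fails with a clear message instead of an ImportError deep inside a call; a hedged sketch of typical use:

if is_vllm_available():
    from vllm import LLM  # imported only when the optional dependency is installed
else:
    raise ImportError("vLLM backend requested but the `vllm` package is not installed.")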
@@ -0,0 +1,197 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    Cache,
    LlamaAttention,
    LlamaFlashAttention2,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_torch_attn_forward(
    self: "LlamaAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional["Cache"] = None,
    output_attentions: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
    attn_output = attn_output.transpose(1, 2).contiguous()

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # fixed: keep the reshape result
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            ),
            dim=2,  # fixed: concatenate along the head dimension, not the default batch dimension
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
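The `shift` helper above implements LongLoRA's shifted sparse attention: sequences are split into groups, groups are folded into the batch dimension so attention runs locally, and the second half of the heads is rolled by half a group so tokens near group boundaries can still exchange information. A toy standalone demo of the same roll-and-regroup operation (all dimensions are illustrative):

```python
import torch

bsz, q_len, n_heads, head_dim = 1, 8, 4, 2
groupsz = 4                   # corresponds to group_size_ratio = 0.5
num_groups = q_len // groupsz

state = torch.arange(bsz * q_len * n_heads * head_dim, dtype=torch.float32)
state = state.reshape(bsz, q_len, n_heads, head_dim)

# Roll the second half of the heads back by half a group along the sequence axis.
shifted = torch.cat(
    (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
    dim=2,
)
# Fold groups into the batch dimension: attention now runs within each group.
grouped = shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(grouped.shape)  # torch.Size([2, 4, 4, 2])
```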
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
    self: "LlamaFlashAttention2",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # LlamaFlashAttention2 does not support output_attentions
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # FlashAttention eventually requires the shape (bsz, seq_len, n_heads, head_dim);
    # we keep (bsz, n_heads, seq_len, head_dim) until RoPE and the KV cache are applied.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
    key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
    value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)

    dropout_rate = self.attention_dropout if self.training else 0.0

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once("The input hidden states seem to have been silently cast to float32.")
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    attn_output: torch.Tensor = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
    )

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # fixed: keep the reshape result
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            ),
            dim=2,  # fixed: concatenate along the head dimension, not the default batch dimension
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    attn_weights = None  # output_attentions is always False in this code path

    return attn_output, attn_weights, past_key_value


def apply_llama_patch() -> None:
    LlamaAttention.forward = llama_torch_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward
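Both forwards only take the shifted path when `self.config.group_size_ratio` is set and the model is training. A hedged usage sketch of how the patch might be enabled — the checkpoint name is a placeholder and `0.25` is an illustrative ratio, not a value prescribed by this diff:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Rebind the forward methods on the attention classes (a class-level monkey patch).
apply_llama_patch()

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config.group_size_ratio = 0.25  # each attention group covers a quarter of the sequence
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)
model.train()  # the shift branch only runs while self.training is True
```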
@@ -0,0 +1,38 @@
from typing import Tuple

import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
# note: "BLock" mirrors the (typo'd) upstream transformers class name at this version


def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    current_hidden_states = self.w2(current_hidden_states)
    return current_hidden_states


# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(  # fixed annotation: this forward returns both the hidden states and the router logits
    self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    topk_weight = topk_weight.to(hidden_states.dtype)

    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
    y = torch.empty_like(hidden_states)
    flat_topk_idx = topk_idx.view(-1)
    for i in range(self.num_experts):
        expert = self.experts[i]
        y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


def patch_mixtral_replace_moe_impl() -> None:
    MixtralBLockSparseTop2MLP.forward = mlp_forward
    MixtralSparseMoeBlock.forward = moe_forward
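Replacing the stock sparse top-2 dispatch with this dense per-expert loop is a common workaround for sharded training setups (for example DeepSpeed ZeRO-3, where experts that receive no tokens on a rank can stall gradient synchronization) — that motivation is my reading of the patch, not something stated in this diff. A minimal usage sketch with a placeholder checkpoint:

```python
from transformers import AutoModelForCausalLM

# Class-level monkey patch: affects new and already-built Mixtral instances alike.
patch_mixtral_replace_moe_impl()

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
```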
@@ -0,0 +1,57 @@
import json
import math
import os
from typing import List

from transformers.trainer import TRAINER_STATE_NAME

from .logging import get_logger
from .packages import is_matplotlib_available


if is_matplotlib_available():
    import matplotlib.pyplot as plt


logger = get_logger(__name__)


def smooth(scalars: List[float]) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
        plt.savefig(figure_path, format="png", dpi=100)
        print("Figure saved at:", figure_path)
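The smoothing above is the EMA recurrence s_i = w·s_{i-1} + (1 - w)·x_i, where the weight w is scaled by the series length through a sigmoid so that short runs are smoothed only lightly. A tiny worked check (values are illustrative):

```python
# For scalars = [1.0, 3.0], len = 2, so
# w = 1.8 * (1 / (1 + exp(-0.1)) - 0.5) ≈ 0.045
# s_0 = w * 1.0 + (1 - w) * 1.0 = 1.0   (the first point is unchanged)
# s_1 = w * 1.0 + (1 - w) * 3.0 ≈ 2.91  (light smoothing for a short series)
print(smooth([1.0, 3.0]))  # -> [1.0, ~2.91]
```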
@@ -0,0 +1,18 @@
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args


__all__ = [
    "DataArguments",
    "EvaluationArguments",
    "FinetuningArguments",
    "GeneratingArguments",
    "ModelArguments",
    "get_eval_args",
    "get_infer_args",
    "get_train_args",
]
@@ -0,0 +1,100 @@
from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class DataArguments:
    r"""
    Arguments pertaining to the data we are going to use for training and evaluation.
    """

    template: Optional[str] = field(
        default=None,
        metadata={"help": "Which template to use for constructing prompts in training and inference."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    cutoff_len: int = field(
        default=1024,
        metadata={"help": "The cutoff length of the model inputs after tokenization."},
    )
    reserved_label_len: int = field(
        default=1,
        metadata={"help": "The minimum cutoff length reserved for the label after tokenization."},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether to disable the mask on the prompt or not."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    buffer_size: int = field(
        default=16384,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
        default="concat",
        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
    )
    interleave_probs: Optional[str] = field(
        default=None,
        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."},
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
        },
    )
    val_size: float = field(
        default=0.0,
        metadata={"help": "Size of the development set; should be an integer or a float in range `[0,1)`."},
    )
    packing: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether or not to pack the sequences in training. Will be automatically enabled in pre-training."
        },
    )
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the pre-processed datasets."},
    )

    def __post_init__(self):
        if self.reserved_label_len >= self.cutoff_len:
            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")

        if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
            raise ValueError("Streaming mode should have an integer val size.")

        if self.streaming and self.max_samples is not None:
            raise ValueError("`max_samples` is incompatible with `streaming`.")
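A quick sketch of how this dataclass behaves when constructed directly — the `__post_init__` hook enforces the cross-field constraints at construction time (the values below are illustrative):

```python
args = DataArguments(dataset="alpaca_en", cutoff_len=2048)
print(args.mix_strategy)  # "concat" (the default)

# Streaming with a fractional validation split is rejected by __post_init__:
try:
    DataArguments(streaming=True, val_size=0.1)
except ValueError as e:
    print(e)  # Streaming mode should have an integer val size.
```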
@@ -0,0 +1,48 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional

from datasets import DownloadMode


@dataclass
class EvaluationArguments:
    r"""
    Arguments pertaining to the evaluation parameters.
    """

    task: str = field(
        metadata={"help": "Name of the evaluation task."},
    )
    task_dir: str = field(
        default="evaluation",
        metadata={"help": "Path to the folder containing the evaluation datasets."},
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "The batch size per GPU for evaluation."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed to be used with data loaders."},
    )
    lang: Literal["en", "zh"] = field(
        default="en",
        metadata={"help": "Language used in evaluation."},
    )
    n_shot: int = field(
        default=5,
        metadata={"help": "Number of exemplars for few-shot learning."},
    )
    save_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save the evaluation results."},
    )
    download_mode: DownloadMode = field(
        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
        metadata={"help": "Download mode used for the evaluation datasets."},
    )

    def __post_init__(self):
        if self.save_dir is not None and os.path.exists(self.save_dir):
            raise ValueError("`save_dir` already exists, use another one.")
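Dataclasses like these are exactly what transformers' `HfArgumentParser` consumes, with each `metadata["help"]` string becoming the CLI help text. A hedged sketch of standalone command-line parsing — the repository's own entry point presumably goes through `get_eval_args` from the parser module instead:

```python
from transformers import HfArgumentParser

parser = HfArgumentParser(EvaluationArguments)
# e.g. python eval.py --task mmlu --n_shot 5 --lang en
(eval_args,) = parser.parse_args_into_dataclasses()
print(eval_args.task, eval_args.batch_size)
```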
@@ -0,0 +1,257 @@
import json
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional


@dataclass
class FreezeArguments:
    r"""
    Arguments pertaining to the freeze (partial-parameter) training.
    """

    name_module_trainable: str = field(
        default="all",
        metadata={
            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
                    Use commas to separate multiple modules. \
                    Use "all" to specify all the available modules. \
                    LLaMA choices: ["mlp", "self_attn"], \
                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
                    Qwen choices: ["mlp", "attn"], \
                    InternLM2 choices: ["feed_forward", "attention"], \
                    Others choices: the same as LLaMA."""
        },
    )
    num_layer_trainable: int = field(
        default=2,
        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
    )


@dataclass
class LoraArguments:
    r"""
    Arguments pertaining to the LoRA training.
    """

    additional_target: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
        },
    )
    lora_alpha: Optional[int] = field(
        default=None,
        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."},
    )
    lora_rank: int = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
    )
    lora_target: str = field(
        default="all",
        metadata={
            "help": """Name(s) of target modules to apply LoRA. \
                    Use commas to separate multiple modules. \
                    Use "all" to specify all the available modules. \
                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
                    InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
                    Others choices: the same as LLaMA."""
        },
    )
    use_rslora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layers."},
    )
    use_dora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the weight-decomposed LoRA method (DoRA)."},
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weights."},
    )


@dataclass
class RLHFArguments:
    r"""
    Arguments pertaining to the PPO and DPO training.
    """

    dpo_beta: float = field(
        default=0.1,
        metadata={"help": "The beta parameter for the DPO loss."},
    )
    dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},
    )
    dpo_ftx: float = field(
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    ppo_buffer_size: int = field(
        default=1,
        metadata={"help": "The number of mini-batches to make the experience buffer in a PPO optimization step."},
    )
    ppo_epochs: int = field(
        default=4,
        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
    )
    ppo_logger: Optional[str] = field(
        default=None,
        metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
    )
    ppo_score_norm: bool = field(
        default=False,
        metadata={"help": "Use score normalization in PPO training."},
    )
    ppo_target: float = field(
        default=6.0,
        metadata={"help": "Target KL value for adaptive KL control in PPO training."},
    )
    ppo_whiten_rewards: bool = field(
        default=False,
        metadata={"help": "Whiten the rewards before computing advantages in PPO training."},
    )
    ref_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
    )
    ref_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reference model."},
    )
    ref_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reference model."},
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reward model used for the PPO training."},
    )
    reward_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reward model."},
    )
    reward_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reward model."},
    )
    reward_model_type: Literal["lora", "full", "api"] = field(
        default="lora",
        metadata={"help": "The type of the reward model in PPO training. A LoRA reward model only supports LoRA training."},
    )


@dataclass
class GaloreArguments:
    r"""
    Arguments pertaining to the GaLore algorithm.
    """

    use_galore: bool = field(
        default=False,
        metadata={"help": "Whether or not to use gradient low-rank projection (GaLore)."},
    )
    galore_target: str = field(
        default="mlp,attn",
        metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
    )
    galore_rank: int = field(
        default=16,
        metadata={"help": "The rank of GaLore gradients."},
    )
    galore_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the GaLore projection."},
    )
    galore_scale: float = field(
        default=0.25,
        metadata={"help": "GaLore scaling coefficient."},
    )
    galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
        default="std",
        metadata={"help": "Type of GaLore projection."},
    )
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise updates to further save memory."},
    )


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
    r"""
    Arguments pertaining to which fine-tuning techniques we are going to use.
    """

    pure_bf16: bool = field(
        default=False,
        metadata={"help": "Whether or not to train the model in purely bf16 precision (without AMP)."},
    )
    stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
    finetuning_type: Literal["lora", "freeze", "full"] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."},
    )
    use_llama_pro: bool = field(
        default=False,
        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
    )
    plot_loss: bool = field(
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},
    )

    def __post_init__(self):
        def split_arg(arg):
            if isinstance(arg, str):
                return [item.strip() for item in arg.split(",")]
            return arg

        self.name_module_trainable = split_arg(self.name_module_trainable)
        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
        self.lora_target = split_arg(self.lora_target)
        self.additional_target = split_arg(self.additional_target)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

        if self.use_galore and self.finetuning_type == "lora":
            raise ValueError("Cannot use LoRA with GaLore together.")

    def save_to_json(self, json_path: str):
        r"""Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        r"""Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()

        return cls(**json.loads(text))
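A short sketch of the JSON round trip these helper methods provide (the file name is illustrative). Note that `__post_init__` normalizes comma-separated strings into lists, and `split_arg` passes lists through unchanged, so reloading a saved instance is stable:

```python
args = FinetuningArguments(finetuning_type="lora", lora_target="q_proj,v_proj")
print(args.lora_target)  # ['q_proj', 'v_proj']
print(args.lora_alpha)   # 16, i.e. lora_rank * 2 by default

args.save_to_json("finetuning_args.json")
reloaded = FinetuningArguments.load_from_json("finetuning_args.json")
assert reloaded.lora_target == args.lora_target
```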