Compare commits

..

259 Commits
0.2.4 ... 0.4.3

Author SHA1 Message Date
zyxucp
1e508e45af fix modellist 2024-06-30 17:11:21 +08:00
zyxucp
03d9ec2cad Merge pull request #94 from duyanming/main
解决内容较多的时候需要等结束转Markdown的不好体验。
2024-06-30 17:02:45 +08:00
zyxucp
86fb48bab7 Merge pull request #95 from AIDotNet/feature_ollama
Feature ollama
2024-06-30 17:01:39 +08:00
zyxucp
a4bc1e4a55 fix 2024-06-30 17:00:24 +08:00
zyxucp
8681e15da5 add ollama 2024-06-30 16:59:46 +08:00
zyxucp
ebc82f8b1b add ollamatype 2024-06-30 15:55:42 +08:00
duyanming
3bcd7bd7e1 1、生成结果的同时转化为 Markdown 文本,解决内容较多的时候需要等结束转Markdown的不好体验。
2、去掉模拟延迟,解决体验问题。仿佛生成很慢
2024-06-30 14:16:36 +08:00
zyxucp
b64d8669b1 fix AntDesign.ProLayout bug 2024-06-29 22:43:02 +08:00
zyxucp
0489044098 fix rerank 2024-06-29 10:57:58 +08:00
zyxucp
17e2062b72 margin 2024-06-29 10:57:18 +08:00
zyxucp
4e4f5a698d update nuget 2024-06-29 10:56:47 +08:00
zyxucp
b879d04bcd update nuget 2024-06-23 23:57:01 +08:00
zyxucp
95f918f4c7 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-06-19 23:08:36 +08:00
zyxucp
f0e1ad6088 fix 处理星火模型秘钥在模型列表显示的问题,以及增加星火模型版本 2024-06-19 23:08:18 +08:00
zyxucp
61773af48d Update docker-compose.simple.yml 2024-06-12 21:57:21 +08:00
zyxucp
54cd04c3bf Update docker-compose.yml 2024-06-12 21:57:03 +08:00
zyxucp
cd9f4ae11b Update README.md 2024-06-12 21:05:58 +08:00
zyxucp
3f9c748b41 update nuget 2024-06-12 11:20:59 +08:00
zyxucp
d483005531 add api地址 2024-06-10 22:01:57 +08:00
zyxucp
1d2db6a896 Update docker-compose.simple.yml 2024-06-08 18:49:09 +08:00
zyxucp
9a7a263055 Update docker-compose.yml 2024-06-08 18:48:52 +08:00
zyxucp
6beb0b52c7 Merge pull request #92 from AIDotNet/feature_llamafactory
update llamafactory 0.8.0
2024-06-08 18:47:13 +08:00
zyxucp
0ea167a204 update llamafactory 0.8.0 2024-06-08 18:29:37 +08:00
zyxucp
6e6afa2a7c Update docker-compose.simple.yml 2024-06-08 11:36:19 +08:00
zyxucp
7a2a5d86bb Update docker-compose.yml 2024-06-08 11:36:04 +08:00
zyxucp
a1a36c3494 update nuget 2024-06-08 11:31:24 +08:00
zyxucp
4f350081dd update llamasharp 2024-06-08 11:23:02 +08:00
zyxucp
b3ea0c4e1a add llamasharp 配置 2024-06-08 11:04:14 +08:00
zyxucp
e72a6acd03 fix 处理聊天上下文 2024-05-30 13:08:37 +08:00
zyxucp
9bb8ab89fe Update README.zh.md 2024-05-29 22:54:41 +08:00
zyxucp
e78da66d1a Update README.md 2024-05-29 22:54:25 +08:00
zyxucp
9ee21fd5e5 AddServiceDefaults 2024-05-29 21:26:41 +08:00
zyxucp
a22c04c9b2 Merge pull request #91 from AIDotNet/feature_aspire
Feature aspire
2024-05-29 17:29:00 +08:00
zyxucp
3bb5bfaca7 add otel 2024-05-29 16:34:54 +08:00
zyxucp
c4bf5ee7e5 fix 增加OTEL 2024-05-29 15:06:16 +08:00
zyxucp
5e1e688f84 fix seq 2024-05-29 14:20:07 +08:00
zyxucp
80d9bf68f3 fix seq 2024-05-29 13:52:47 +08:00
zyxucp
65f2e3e363 add Serilog.Sinks.Seq 2024-05-29 13:20:11 +08:00
zyxucp
68d27ff2bc update Serilog 2024-05-29 13:03:00 +08:00
zyxucp
034da30811 add Serilog 2024-05-29 12:14:12 +08:00
zyxucp
3db0cdcd19 add aspire 2024-05-29 00:01:30 +08:00
zyxucp
42181a6f1d add aspire 2024-05-28 22:23:55 +08:00
zyxucp
ec8cbf2550 add 增加跨域处理 2024-05-27 22:19:22 +08:00
zyxucp
9a1bd079da fix 删除默认提示词 2024-05-26 19:41:58 +08:00
zyxucp
4213c4379c update 处理openapi 没有systemPrompt的问题 2024-05-26 19:38:32 +08:00
zyxucp
05cda17e2e style 样式修改 2024-05-26 00:50:23 +08:00
zyxucp
cda6e54f0b Merge branch 'main' of github.com:AIDotNet/AntSK 2024-05-25 23:11:40 +08:00
zyxucp
51d8ba6408 update km、sk 版本 2024-05-25 23:11:33 +08:00
zyxucp
b571c7d22d Update README.md 2024-05-24 22:01:03 +08:00
zyxucp
a0c91f565e fix 修复openapi聊天上下文bug 2024-05-24 21:47:53 +08:00
zyxucp
280c750165 Update README.md 2024-05-23 14:47:20 +08:00
zyxucp
fec9337fda margin 2024-05-23 14:34:46 +08:00
zyxucp
b84f252f2f update 更新readme 2024-05-23 14:17:54 +08:00
zyxucp
5c998ccce2 Update README.en.md 2024-05-23 13:53:36 +08:00
zyxucp
0e3cfd2cfb Update README.md 2024-05-23 13:53:33 +08:00
zyxucp
4040831a23 Update README.md 2024-05-23 13:52:17 +08:00
zyxucp
a3a2308659 Update docker-compose.yml 2024-05-23 13:46:03 +08:00
zyxucp
6d43c71d13 Update docker-compose.simple.yml 2024-05-23 13:45:42 +08:00
zyxucp
8315b6f37f fix 样式修改 2024-05-23 12:07:37 +08:00
zyxucp
7bc708e6ae margin 2024-05-23 11:33:15 +08:00
zyxucp
e6f2c5c2fe update 升级SK KM版本 2024-05-23 11:29:23 +08:00
zyxucp
5cab781362 Merge pull request #90 from yc-2503/main
fix: 对话窗口的第一条对话没有传给大模型问题
2024-05-14 22:20:32 +08:00
Chason
02d7994bae fix: 对话窗口的第一条对话丢失 2024-05-14 20:32:11 +08:00
zyxucp
b740957157 fix 调整KM版本 2024-05-12 20:49:13 +08:00
zyxucp
2480ec1272 margin 2024-05-12 19:07:51 +08:00
zyxucp
35c98a0d14 update 更新ant blazor \sk \km 2024-05-12 19:07:27 +08:00
zyxucp
0964a5ad5b Merge pull request #88 from yc-2503/main
bugfix: 调用function时 报错 jsonbody 参数不存在
2024-05-09 23:33:24 +08:00
Chason
a95131efe9 fix: 调用function时 报错 jsonbody 参数不存在
KernelParameterMetadata 的构造函数已指定参数名 jsonbody, 后续却又将参数名改为 json参数字符串
2024-05-09 17:41:43 +08:00
Chason
7783cdf3c4 bugfix: 语法错误 2024-05-09 10:55:41 +08:00
zyxucp
7a65f33cb6 Update README.md 2024-05-09 01:33:37 +08:00
zyxucp
6efd01db3f Merge pull request #87 from yc-2503/main
fix: 修正 会话总结 中的返回字符串
2024-05-08 13:26:15 +08:00
Chason
1e2322b573 Merge pull request #1 from yc-2503/yc-2503-patch-1
fix: 修正会话总结
2024-05-07 19:55:12 +08:00
Chason
2cb2241a66 fix: 修正会话总结 2024-05-07 19:54:29 +08:00
zyxucp
64efdd7881 add logo 2024-05-01 14:09:21 +08:00
zyxucp
be28e32803 update 更新nuget版本 2024-05-01 13:05:11 +08:00
zyxucp
468422baee fix 处理异步聊天问题 2024-04-30 21:53:50 +08:00
zyxucp
7b1c6c8c64 fix 修改异步 2024-04-30 17:53:16 +08:00
zyxucp
7ff0ea0bfe Update README.en.md 2024-04-29 21:42:57 +08:00
zyxucp
6bed4356f0 Update README.md 2024-04-29 18:17:06 +08:00
zyxucp
73b65f7305 Merge pull request #84 from AIDotNet/feature_llamasharp
Feature llamasharp
2024-04-28 20:38:01 +08:00
zyxucp
0ea52eced9 fix 修改聊天为chathistory 2024-04-28 20:37:37 +08:00
zyxucp
498e9ba9f6 Merge branch 'main' into feature_llamasharp 2024-04-28 20:24:30 +08:00
zyxucp
125695665b add 修改水印影响样式问题 2024-04-28 16:53:10 +08:00
zyxucp
0e08b3ae85 add 水印 2024-04-28 15:37:51 +08:00
zyxucp
7cb8f99e7e fix 处理聊天对话 2024-04-27 23:03:46 +08:00
zyxucp
d15cb527d0 add 教程视频 2024-04-24 23:56:45 +08:00
zyxucp
9cb36174fd Update README.md 2024-04-23 14:16:50 +08:00
zyxucp
6265f94ef2 fix 修改文档导入问答单独的index 2024-04-23 13:32:57 +08:00
zyxucp
09d90b654c fix 处理文档问答问题 2024-04-23 13:28:38 +08:00
zyxucp
64e2bca2e6 Merge pull request #80 from AIDotNet/fix_chatkmsbug
fix 修改聊天记录知识库保存bug
2024-04-23 12:54:49 +08:00
zyxucp
328ece6d73 fix 修改聊天记录知识库保存bug 2024-04-23 12:53:51 +08:00
zyxucp
fabb8c2044 fix 处理合并 2024-04-23 11:51:08 +08:00
zyxucp
6ca75df880 fix 处理合并 2024-04-23 11:50:46 +08:00
zyxucp
3d4dfaced1 margin 2024-04-23 11:48:43 +08:00
zyxucp
d532bf3bb6 add 聊天历史记录搜索 2024-04-22 23:50:08 +08:00
zyxucp
e1fd288875 fix 样式修改 2024-04-22 23:37:20 +08:00
zyxucp
91eae9cfa8 fix 修改聊天记录存储 2024-04-22 23:31:08 +08:00
zyxucp
b0059942d3 Merge pull request #79 from AIDotNet/feature_chat
Feature chat
2024-04-22 23:17:58 +08:00
zyxucp
a716982878 add 聊天记录 2024-04-22 23:17:16 +08:00
zyxucp
3d4e48f9f5 fix 修改错误 2024-04-22 22:26:45 +08:00
zyxucp
1f212d3156 update semantic kernel to 1.8.0 2024-04-22 22:19:46 +08:00
zyxucp
7d91ef6ba1 Update LICENSE 2024-04-22 22:14:44 +08:00
zyxucp
2a450b00de add 处理在有用户时使用chats表次存储聊天记录,匿名访问时使用localstorage存储聊天记录 2024-04-22 22:03:34 +08:00
zyxucp
3a97068248 Update README.md 2024-04-21 12:12:46 +08:00
zeyu xu
1d9d95899a update 更新antsk logo 2024-04-21 11:17:51 +08:00
zyxucp
7ae8e52b57 Merge pull request #77 from AIDotNet/feature_bge
update kernelMemory nuget版本
2024-04-21 11:01:06 +08:00
zeyu xu
f5c195a1d0 update kernelMemory nuget版本 2024-04-21 11:00:36 +08:00
zyxucp
78a6b662d3 Update docker-compose.simple.yml 2024-04-20 23:30:27 +08:00
zyxucp
5f814eb76c Update docker-compose.yml 2024-04-20 23:30:07 +08:00
zyxucp
d9e5ebb464 Update README.md 2024-04-20 23:29:48 +08:00
zyxucp
bce0e9183c Update README.md 2024-04-20 23:29:26 +08:00
zyxucp
c40a7bcf22 Merge pull request #76 from AIDotNet/feature_bge
Feature bge
2024-04-20 23:19:10 +08:00
zeyu xu
97a7d447ab add rerank kms 2024-04-20 23:18:07 +08:00
zeyu xu
f803b9538b fix 调整目录 2024-04-20 21:17:27 +08:00
zeyu xu
1ac34c1702 add 应用增加rerank 2024-04-20 21:09:34 +08:00
zeyu xu
e07b480da1 add bgemodel 2024-04-20 21:02:44 +08:00
zeyu xu
9036af57e3 重命名 2024-04-20 20:56:55 +08:00
zeyu xu
93288f9b5c add bgererank 模型下载 2024-04-20 20:56:00 +08:00
zyxucp
f40dd8b013 Merge pull request #75 from AIDotNet/feature_menu
add 模型管理页面 文字超长的样式处理
2024-04-20 10:42:23 +08:00
zeyu xu
c6b83d0695 add 模型管理页面 文字超长的样式处理 2024-04-20 10:41:48 +08:00
zyxucp
592c850198 Merge pull request #74 from AIDotNet/feature_menu
add 单独剥离模型管理菜单
2024-04-20 10:31:43 +08:00
zeyu xu
4a3930ac7b add 单独剥离模型管理菜单 2024-04-20 10:31:19 +08:00
zyxucp
c05ba0af3e Update README.md 2024-04-19 23:20:44 +08:00
zyxucp
630ee51df6 Update docker-compose.simple.yml 2024-04-19 23:20:26 +08:00
zyxucp
d0e75e26c3 Update docker-compose.yml 2024-04-19 23:20:04 +08:00
zyxucp
62c36c3072 Merge pull request #73 from AIDotNet/feature_deldimensions
add DelDimensions
2024-04-19 23:09:33 +08:00
zeyu xu
baef309064 add DelDimensions 2024-04-19 23:08:50 +08:00
zeyu xu
d717cbad9c update nuget sqlsugar 2024-04-19 22:12:02 +08:00
zeyu xu
5ef0624605 fix 修改WithLogCallback 日志输出 2024-04-19 22:04:08 +08:00
zyxucp
af2930a371 fix 修改WithLog 2024-04-19 18:37:03 +08:00
zyxucp
98f0f9fe84 Update README.md 2024-04-18 22:15:30 +08:00
zeyu xu
28a23271e9 fix 文件名修改 2024-04-18 22:05:23 +08:00
zeyu xu
f1ba0bdf10 add 模型删除校验 2024-04-18 21:49:09 +08:00
zeyu xu
0d5513f374 add Directory.Build.props 2024-04-18 21:23:25 +08:00
zyxucp
4812cc308c Update README.md 2024-04-17 22:59:46 +08:00
zyxucp
584f7faded add 环境变量 2024-04-17 18:25:58 +08:00
zyxucp
08dcef2d8b Update docker-compose.simple.yml 2024-04-16 21:56:58 +08:00
zyxucp
68218733a2 Update docker-compose.yml 2024-04-16 21:56:40 +08:00
zyxucp
eb64cbf3d4 update 升级km nuget版本 2024-04-16 17:01:32 +08:00
zyxucp
f0e8a55522 Merge pull request #72 from AIDotNet/feature_qa1
fix 修改切分使用服务注入
2024-04-16 13:56:39 +08:00
zyxucp
5ec5a0bde4 fix 修改切分使用服务注入 2024-04-16 13:54:24 +08:00
zyxucp
1cc56dd553 Merge pull request #71 from AIDotNet/feature_qa
Feature qa
2024-04-15 23:27:02 +08:00
zeyu xu
64e949a88b add qa切片 2024-04-15 23:26:23 +08:00
zeyu xu
a2390a7c97 add qa问答 2024-04-15 23:21:11 +08:00
zeyu xu
559661bb6c add qa 参数 2024-04-15 21:30:06 +08:00
zyxucp
79326de263 Merge pull request #70 from IntptrMax/Add-stable-diffusion-reference-for-Windows
Add stable diffusion reference for windows
2024-04-15 10:37:03 +08:00
IntptrMax
3815891b28 remove unused stable-diffusion.dll 2024-04-15 08:51:17 +08:00
IntptrMax
42d474382a Merge branch 'AIDotNet:main' into Add-stable-diffusion-reference-for-Windows 2024-04-15 08:36:29 +08:00
IntptrMax
fe691f2d44 1.Update some Stable Diffusion code 2024-04-15 08:34:16 +08:00
IntptrMax
3ee41a8ab1 1. Add references for Stable Diffusion for windows.
2. Update load lib code
2024-04-15 08:22:12 +08:00
zeyu xu
7ca41dff8a add docker file py 2024-04-13 10:56:47 +08:00
zeyu xu
ba2e86993e add dll 2024-04-13 10:55:42 +08:00
zyxucp
13878046a2 忽略文件 2024-04-12 12:13:43 +08:00
zeyu xu
49ff8bf54f fix 删除空引用 2024-04-11 23:35:53 +08:00
zeyu xu
e9cc5a3993 add loadding 2024-04-11 22:32:00 +08:00
zyxucp
b213964b63 Merge pull request #67 from IntptrMax/UpdateStableDiffusion
Update Stable Diffusion
2024-04-11 21:53:44 +08:00
IntptrMax
bfbed44270 Update Stable Diffusion 2024-04-11 15:11:17 +08:00
zyxucp
9b07d88392 Update Dockerfile-py 2024-04-10 23:31:20 +08:00
zyxucp
3f8ed109f9 fix nuget 2024-04-10 23:23:18 +08:00
zyxucp
3f969627a4 fix 修改ocr runtime 2024-04-10 22:51:37 +08:00
zyxucp
d92970819a Merge pull request #65 from AIDotNet/feature_ocr
add runtime
2024-04-10 22:39:28 +08:00
zyxucp
23e756fa9b add runtime 2024-04-10 22:37:32 +08:00
zyxucp
5f58126fbf Merge pull request #64 from AIDotNet/feature_ocr
Feature ocr
2024-04-10 22:25:36 +08:00
zyxucp
dcfd0ffb8f add ocr 2024-04-10 22:23:59 +08:00
zyxucp
17221d056c add img 2024-04-10 21:57:45 +08:00
zyxucp
4a9dcfada4 fix 修改默认key 2024-04-10 21:54:26 +08:00
zyxucp
bb6c2bb020 fix 升级LLamaSharp 2024-04-09 12:32:39 +08:00
zyxucp
a8760a34de fix 修改地址错误问题 2024-04-09 12:20:15 +08:00
zeyu xu
1e432a5782 fix 删除pynet 2024-04-08 14:50:34 +08:00
zeyu xu
cb861ef2bb fix OCR 2024-04-08 12:38:14 +08:00
zeyu xu
7cee8fd87a add OCR 和文档查询上限 2024-04-08 12:10:32 +08:00
zeyu xu
8ce0e5d348 add ocr 2024-04-07 22:31:57 +08:00
zyxucp
90bce7c89f Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-04-07 15:59:03 +08:00
zyxucp
b840d0bcce fix add gpu avx 2024-04-07 15:58:26 +08:00
zyxucp
bfa6d28289 Update README.md 2024-04-07 14:40:47 +08:00
zeyu xu
f6e6ca9747 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-04-07 11:44:15 +08:00
zeyu xu
75f8d39648 fix 修改类型 2024-04-07 11:44:05 +08:00
zyxucp
9a939eba5a fix 修改异步为同步 2024-04-07 11:03:18 +08:00
zyxucp
4e93efe821 fix 修改异步为同步 2024-04-07 10:49:35 +08:00
zyxucp
8bdbee80a0 fix 修改类型 2024-04-07 10:37:16 +08:00
zyxucp
6bdf5dcc03 fix 修复PG字段报错问题 2024-04-07 10:05:15 +08:00
zeyu xu
0bf0a9d78a fix chat style 2024-04-06 11:50:15 +08:00
zeyu xu
38e9fea601 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-04-06 11:47:59 +08:00
zeyu xu
d2366b3b46 update Semantic Kernel and fix kmsdetaillist style 2024-04-06 11:47:49 +08:00
zyxucp
3aff93083a Update README.md 2024-04-06 11:04:58 +08:00
zeyu xu
eb998199db fix 删除不要的Controller 2024-04-06 00:03:53 +08:00
zeyu xu
1dd794af1b fix 修改安装向量插件 2024-04-06 00:03:17 +08:00
zeyu xu
08c9923e7e update docker-compose.yal 2024-04-05 23:58:45 +08:00
zyxucp
06b109ca87 Merge pull request #62 from AIDotNet/feature_kms
Feature kms
2024-04-05 22:27:21 +08:00
zeyu xu
9b039335c7 fix 修改style 2024-04-05 22:26:37 +08:00
zeyu xu
041378e5fd add 文档搜索测试 2024-04-05 22:16:44 +08:00
zeyu xu
6dc5ae10e3 add 搜索测试布局 2024-04-05 21:19:41 +08:00
zeyu xu
5807f4c283 fix 修改切片详情样式 2024-04-05 21:09:38 +08:00
zyxucp
8ef4445908 Merge pull request #61 from AIDotNet/feature_kms
Feature kms
2024-04-05 20:23:49 +08:00
zeyu xu
8a0609e970 add excel导入 2024-04-05 20:23:03 +08:00
zeyu xu
9f33b5009b add excel 导入 2024-04-05 19:50:58 +08:00
zeyu xu
50e66db8a1 Merge branch 'feature_kms' of github.com:AIDotNet/AntSK into feature_kms 2024-04-05 19:18:26 +08:00
zeyu xu
c3e83b569a fix 升级nuget 2024-04-05 19:18:20 +08:00
zyxucp
85d1c5ea7e add npoi 2024-04-05 19:17:49 +08:00
zeyu xu
ec1d126a02 add excel导入 2024-04-05 19:13:30 +08:00
zyxucp
e857695e70 Update docker-compose.simple.yml 2024-04-05 18:45:56 +08:00
zyxucp
fa9b2051fe Update docker-compose.yml 2024-04-05 18:45:41 +08:00
zyxucp
d450efcffe Merge pull request #60 from AIDotNet/feature_kms
fix 修改提示词上限
2024-04-05 15:51:22 +08:00
zeyu xu
2a6c84c200 fix 修改提示词上限 2024-04-05 15:50:46 +08:00
zyxucp
138a952ace Merge pull request #59 from AIDotNet/feature_kms
Feature kms
2024-04-05 15:41:14 +08:00
zeyu xu
eb6528ecd2 add 修改message结构,减少localstore存储 2024-04-05 15:39:56 +08:00
zeyu xu
2c30bbfa09 fix 细节调整 2024-04-05 15:29:39 +08:00
zeyu xu
c5a78c2135 add modeldownchange 2024-04-05 15:12:29 +08:00
zeyu xu
f03362ee41 fix 修改dropdown Trigger.Click 2024-04-05 15:04:44 +08:00
zeyu xu
fad3167d97 add kms settings 2024-04-05 15:00:37 +08:00
zeyu xu
ad949681dd add change 2024-04-05 14:41:58 +08:00
zeyu xu
27999d76b0 fix 修改知识库函数 2024-04-05 14:26:35 +08:00
zeyu xu
83278352d6 add kms 配置 2024-04-05 14:12:25 +08:00
zeyu xu
fcc56f5fef fix bge embedding 无法切片问题 2024-04-04 00:37:08 +08:00
zyxucp
4ebe2ecc32 fix 修改初始化,增加完成标识 2024-04-02 13:53:32 +08:00
zeyu xu
e684cba527 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-04-02 13:34:38 +08:00
zeyu xu
888dc19ee0 fix bgeembedding 2024-04-02 13:34:24 +08:00
zyxucp
731aea702f fix 修改提示词 2024-04-02 11:17:53 +08:00
zyxucp
09e22bc76a Update README.md 2024-04-02 00:07:08 +08:00
zyxucp
74406d88a0 Merge pull request #58 from AIDotNet/feature_StableDiffusion
fix 修改为静态类
2024-04-01 23:57:12 +08:00
zeyu xu
e5f9d97560 fix 修改为静态类 2024-04-01 23:56:44 +08:00
zyxucp
59e768aaea Merge pull request #57 from AIDotNet/feature_StableDiffusion
Feature stable diffusion
2024-04-01 23:39:22 +08:00
zeyu xu
6a7cb24a5b add sd 2024-04-01 23:08:53 +08:00
zeyu xu
1db40d534c add apptype 2024-04-01 22:14:18 +08:00
zeyu xu
11d6e30f7e add sd function 2024-04-01 22:03:00 +08:00
zeyu xu
9d5214aaae add sdmodel 2024-04-01 21:57:18 +08:00
zeyu xu
010b906271 add sd 2024-04-01 21:35:51 +08:00
zeyu xu
16bf944edf add sd 2024-04-01 21:31:15 +08:00
zeyu xu
5bae5a099a margin 2024-04-01 21:01:29 +08:00
zyxucp
f771ea9521 Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-04-01 13:54:53 +08:00
zyxucp
994efbf37c update nuget 2024-04-01 13:54:20 +08:00
zyxucp
938cd86c88 Update README.md 2024-03-31 13:24:21 +08:00
zeyu xu
1339cbadbc fix 修改menukey 2024-03-31 13:07:30 +08:00
zeyu xu
bd0ad570ad add 增加使用文档 2024-03-31 13:07:08 +08:00
zeyu xu
234e649a7e fix 优化部分内容 2024-03-31 12:38:17 +08:00
zyxucp
c431dbc842 Update README.md 2024-03-31 00:28:16 +08:00
zyxucp
76283060d9 Update docker-compose.simple.yml 2024-03-30 23:28:52 +08:00
zyxucp
75ba506db4 Update docker-compose.yml 2024-03-30 23:28:33 +08:00
zeyu xu
0c8ad5fe8d add loadding 2024-03-30 19:50:29 +08:00
zeyu xu
68ce0db011 fix 样式修改 2024-03-30 17:35:40 +08:00
zeyu xu
c36de1a1e9 add 选项控制 2024-03-30 17:25:58 +08:00
zeyu xu
44ef759abd fix 修改控件 2024-03-30 14:47:29 +08:00
longdream
0c3d9844be Merge pull request #52 from longdream/main
bge embedding模型添加,bge用的CPU。
2024-03-29 21:51:35 +08:00
longdream
854c62a4ca 合并 2024-03-29 21:50:17 +08:00
longdream
5ed4fd5299 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 20:00:53 +08:00
longdream
af5ec43571 修改设置界面 2024-03-29 20:00:49 +08:00
junlong
d7b56d1590 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 15:34:08 +08:00
longdream
b925f8890b 修改token长度 2024-03-28 23:06:21 +08:00
longdream
5d80ee994a 解决线程冲突问题 2024-03-28 19:04:11 +08:00
longdream
f73bd2dfda 增减embedding 2024-03-27 22:53:45 +08:00
longdream
f340ee1088 embedding封装 2024-03-26 23:14:55 +08:00
longdream
edad2644aa 删除没必要的py文件 2024-03-26 20:48:49 +08:00
longdream
8a56a0393a Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-26 20:48:07 +08:00
junlong
bd5ca06d8f test 2024-03-25 16:55:41 +08:00
junlong
e0985ecec3 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-25 16:48:21 +08:00
junlong
e56b74d4af 删除chat以外的文件 2024-03-25 16:48:11 +08:00
longdream
849b18f677 Merge branch 'AIDotNet:main' into main 2024-03-22 19:36:20 +08:00
junlong
344128e49d Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-21 19:38:03 +08:00
junlong
56fc9dd517 test 2024-03-21 19:37:56 +08:00
279 changed files with 15159 additions and 5552 deletions

4
.gitignore vendored
View File

@@ -324,10 +324,6 @@ ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
*.dll
*.pdb
# MFractors (Xamarin productivity tool) working folder
.mfractor/
**/bin/

View File

@@ -22,4 +22,5 @@ WORKDIR /app
COPY --from=build /app/publish .
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
RUN echo 'Asia/Shanghai' >/etc/timezone
RUN apt update && apt install -y libpugixml-dev libtbb-dev
ENTRYPOINT ["dotnet", "AntSK.dll"]

View File

@@ -25,4 +25,5 @@ ENV PATH="/app:/opt/conda/bin:/usr/local/bin:${PATH}"
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
RUN echo 'Asia/Shanghai' >/etc/timezone
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
ENTRYPOINT ["dotnet", "AntSK.dll"]
RUN apt update && apt install -y libpugixml-dev libtbb-dev
ENTRYPOINT ["dotnet", "AntSK.dll"]

View File

@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright [2024] [许泽宇]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,214 +0,0 @@
[简体中文](./README.md) | English
# AntSK
## AI Knowledge Base/Intelligent Agent built on .Net8+AntBlazor+SemanticKernel
## ⭐Core Features
- **Semantic Kernel**: Utilizes advanced natural language processing technology to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.
- **Kernel Memory**: Capable of continuous learning and storing knowledge points, AntSK has long-term memory function, accumulates experience, and provides a more personalized interaction experience.
- **Knowledge Base**: Import knowledge base through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and perform knowledge base Q&A.
- **GPT Generation**: This platform supports creating personalized GPT models, enabling users to build their own GPT models.
- **API Interface Publishing**: Exposes internal functions in the form of APIs, enabling developers to integrate AntSK into other applications and enhance application intelligence.
- **API Plugin System**: Open API plugin system that allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functionality.
- **.Net Plugin System**: Open dll plugin system that allows third-party developers or service providers to easily integrate their business functions by generating dll in standard format code, continuously enhancing application functionality.
- **Online Search**: AntSK, real-time access to the latest information, ensuring users receive the most timely and relevant data.
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf types supported by **llama.cpp** and models offline running supported by **llamafactory**.
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
- **Model Fine-Tuning**: Planned based on llamafactory for model fine-tuning.
## ⛪Application Scenarios
AntSK is suitable for various business scenarios, such as:
- Enterprise knowledge management system
- Automatic customer service and chatbots
- Enterprise search engine
- Personalized recommendation system
- Intelligent writing assistance
- Education and online learning platforms
- Other interesting AI Apps
## ✏Function Examples
### Online Demo
```
https://antsk.ai-dotnet.com/
```
```
Default account: test
Default password: test
Due to the low configuration of the cloud server, the local model cannot be run, so the system settings permissions have been closed. You can simply view the interface. If you want to use the local model, please download and use it on your own.
```
### Other Function Examples
[Video Demonstration](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
## ❓How to get started?
Here I am using Postgres as the data and vector storage because Semantic Kernel and Kernel Memory support it, but you can also use other options.
The model by default supports the local model of openai, azure openai, and llama. If you need to use other models, you can integrate them using one-api.
The Login configuration in the configuration file is the default login account and password.
The following configuration file needs to be configured
## 1⃣Using docker-compose
Provided the pg version **appsettings.json** and simplified version (Sqlite+disk) **docker-compose.simple.yml**
Download **docker-compose.yml** from the project root directory and place the configuration file **appsettings.json** in the same directory.
The pg image has already been prepared. You can modify the default username and password in docker-compose.yml, and then the database connection in your **appsettings.json** needs to be consistent.
Then you can execute the following command in the directory to start AntSK
```
docker-compose up -d
```
## 2⃣How to mount local models and model download directory in docker
```
# Non-host version, do not use local proxy
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # Local configuration file needs to be placed in the same directory
- D://model:/app/model
networks:
antsk:
```
Taking this as an example, it means mounting the local D://model folder of Windows into the container /app/model. If so, the model address in your appsettings.json should be configured as
```
model/xxx.gguf
```
## 3⃣Some meanings of configuration file
```
{
"DBConnection": {
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"LLamaSharp": {
"RunType": "GPU",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
},
"Login": {
"User": "admin",
"Password": "xuzeyu"
},
"BackgroundTaskBroker": {
"ImportKMSTask": {
"WorkerCount": 1
}
}
}
```
```
// Supports various databases, you can check SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGauss, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection.DbType
// Connection string, need to use the corresponding string according to the different DB types
DBConnection.ConnectionStrings
//The type of vector storage, supporting Postgres, Disk, Memory, Qdrant, Redis, AzureAISearch
//Postgres and Redis require ConnectionString configuration
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint | APIKey
KernelMemory.VectorDb
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
LLamaSharp.RunType
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
LLamaSharp.FileDirectory
//Default admin account password
Login
//Import asynchronous processing thread count. A higher count can be used for online API, but for local models, 1 is recommended to avoid memory overflow issues.
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠Fixing Style Issues:
Run the following in AntSK/src/AntSK:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
Then navigate to AntSK/src/AntSK/bin/Release/net8.0/publish and run:
```
dotnet AntSK.dll
```
The styles should now be applied after starting.
I'm using CodeFirst mode for the database, so as long as the database connection is properly configured, the table structure will be created automatically.
## ✔Using llamafactory
```
1. First, ensure that Python and pip are installed in your environment. This step is not necessary if using an image, such as version v0.2.3.2, which already includes the complete Python environment.
2. Go to the model add page and select llamafactory.
3. Click "Initialize" to check whether the 'pip install' environment setup is complete.
4. Choose a model that you like.
5. Click "Start" to begin downloading the model from the tower. This may involve a somewhat lengthy wait.
6. After the model has finished downloading, enter http://localhost:8000/ in the request address. The default port is 8000.
7. Click "Save" and start chatting.
8. Many people ask about the difference between LLamaSharp and llamafactory. In fact, LLamaSharp is a .NET implementation of llama.cpp, but only supports local gguf models, while llamafactory supports a wider variety of models and uses Python implementation. The main difference lies here. Additionally, llamafactory has the ability to fine-tune models, which is an area we will focus on integrating in the future.
```
## 🤝 Contributing
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
## 💕 Contributors
This project exists thanks to all the people who contribute.
<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>
## 🚨 Code of Conduct
This project has adopted the code of conduct defined by the Contributor Covenant to clarify expected behavior in our community.
For more information see the [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
To learn more or get started with **AntSK**, follow my official WeChat account and join the discussion group.
## ☎Contact Me
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
![Official WeChat Account](https://github.com/AIDotNet/Avalonia-Assistant/blob/main/img/gzh.jpg)
---
We appreciate your interest in **AntSK** and look forward to collaborating with you to create an intelligent future!

217
README.md
View File

@@ -1,92 +1,90 @@
中文|[English](https://github.com/AIDotNet/AntSK/blob/main/README.en.md)
[简体中文](./README.zh.md) | English
# AntSK
## 使用.Net8+Blazor+SemanticKernel 打造的AI知识库/智能体
## AI Knowledge Base/Intelligent Agent built on .Net8+AntBlazor+SemanticKernel
## ⭐核心功能
## ⭐Core Features
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **Semantic Kernel**: Utilizes advanced natural language processing technology to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **Kernel Memory**: Capable of continuous learning and storing knowledge points, AntSK has long-term memory function, accumulates experience, and provides a more personalized interaction experience.
- **知识库**:通过文档(WordPDFExcelTxtMarkdownJsonPPT)等形式导入知识库,可以进行知识库问答。
- **Knowledge Base**: Import knowledge base through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and perform knowledge base Q&A.
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **GPT Generation**: This platform supports creating personalized GPT models, enabling users to build their own GPT models.
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **API Interface Publishing**: Exposes internal functions in the form of APIs, enabling developers to integrate AntSK into other applications and enhance application intelligence.
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **API Plugin System**: Open API plugin system that allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functionality.
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **.Net Plugin System**: Open dll plugin system that allows third-party developers or service providers to easily integrate their business functions by generating dll in standard format code, continuously enhancing application functionality.
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **Online Search**: AntSK, real-time access to the latest information, ensuring users receive the most timely and relevant data.
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf types supported by **llama.cpp** and models offline running supported by **llamafactory**.
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
- **模型微调**规划中基于llamafactory进行模型微调
- **Model Fine-Tuning**: Planned based on llamafactory for model fine-tuning.
## ⛪应用场景
## ⛪Application Scenarios
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
- 自动客服与聊天机器人
- 企业级搜索引擎
- 个性化推荐系统
- 智能辅助写作
- 教育与在线学习平台
- 其他有意思的AI App
AntSK is suitable for various business scenarios, such as:
- Enterprise knowledge management system
- Automatic customer service and chatbots
- Enterprise search engine
- Personalized recommendation system
- Intelligent writing assistance
- Education and online learning platforms
- Other interesting AI Apps
## ✏Function Examples
### Online Demo
[document](http://antsk.cn/)
[demo](https://antsk.ai-dotnet.com/)
## ✏️功能示例
### 在线演示
```
https://antsk.ai-dotnet.com/
```
```
默认账号test
Default account: test
默认密码:test
Default password: test
由于云服务器配置较低,无法运行本地模型,所以把系统设置权限关闭了,大家看看界面即可,要使用本地模型,请下载自行使用
Due to the low configuration of the cloud server, the local model cannot be run, so the system settings permissions have been closed. You can simply view the interface. If you want to use the local model, please download and use it on your own.
```
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
### Other Function Examples
[Video Demonstration](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
## ❓如何开始?
## ❓How to get started?
在这里我使用的是Postgres 作为数据存储和向量存储,因为Semantic KernelKernel Memory都支持他,当然你也可以换成其他的。
Here I am using Postgres as the data and vector storage because Semantic Kernel and Kernel Memory support it, but you can also use other options.
模型默认支持openaiazure openai、讯飞星火、阿里云积、 和llama支持的gguf本地模型 以及llamafactory的本地模型,如果需要使用其他模型,可以使用one-api进行集成。
The model by default supports the local model of openai, azure openai, and llama. If you need to use other models, you can integrate them using one-api.
配置文件中的Login配置是默认的登录账号和密码
The Login configuration in the configuration file is the default login account and password.
需要配置如下的配置文件
The following configuration file needs to be configured
## 1使用docker-compose
## 1Using docker-compose
提供了pg版本 **appsettings.json** 和 简化版本(**Sqlite+disk** **docker-compose.simple.yml**
Provided the pg version **appsettings.json** and simplified version (Sqlite+disk) **docker-compose.simple.yml**
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
Download **docker-compose.yml** from the project root directory and place the configuration file **appsettings.json** in the same directory.
这里已经把pg的镜像做好了。在docker-compose.yml中可以修改默认账号密码然后你的**appsettings.json**的数据库连接需要保持一致。
The pg image has already been prepared. You can modify the default username and password in docker-compose.yml, and then the database connection in your **appsettings.json** needs to be consistent.
然后你可以进入到目录后执行
Then you can execute the following command in the directory to start AntSK
```
docker-compose up -d
```
来启动AntSK
## 2如何在docker中挂载本地模型和模型下载的目录
## 2How to mount local models and model download directory in docker
```
# host 版本, 不使用本机代理
# Non-host version, do not use local proxy
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.2.3
ports:
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5ports:
- 5000:5000
networks:
- antsk
@@ -96,31 +94,35 @@ services:
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- ./appsettings.json:/app/appsettings.json # Local configuration file needs to be placed in the same directory
- D://model:/app/model
networks:
antsk:
```
以这个为示例意思是把windows本地D://model的文件夹挂载进 容器内/app/model 如果是这样你的appsettings.json中的模型地址应该配置为
Taking this as an example, it means mounting the local D://model folder of Windows into the container /app/model. If so, the model address in your appsettings.json should be configured as
```
model/xxx.gguf
```
## 3配置文件的一些含义
## 3Some meanings of configuration file
```
{
"DBConnection": {
"DbType": "Sqlite",
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"FileDir": {
"DirectoryPath": "D:\\git\\AntBlazor\\model"
},
"LLamaSharp": {
"RunType": "GPU",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
"RunType": "GPU",
"ContextSize": 2048,
"GpuLayerCount": 20
},
"Login": {
"User": "admin",
@@ -134,86 +136,85 @@ model/xxx.gguf
}
```
```
//支持多种数据库,具体可以查看SqlSugarMySqlSqlServerSqliteOraclePostgreSQLDmKdbndpOscarMySqlConnectorAccessOpenGaussQuestDBHGClickHouseGBaseOdbcOceanBaseForOracleTDengineGaussDBOceanBaseTidbVastbasePolarDBCustom
// Supports various databases, you can check SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGauss, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
// Connection string, need to use the corresponding string according to the different DB types
DBConnection.ConnectionStrings
//向量存储的类型,支持 PostgresDiskMemoryQdrantRedisAzureAISearch
//Postgres、Redis需要配置 ConnectionString
//Qdrant AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
//The type of vector storage, supporting Postgres, Disk, Memory, Qdrant, Redis, AzureAISearch
//Postgres and Redis require ConnectionString configuration
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint | APIKey
KernelMemory.VectorDb
//本地模型使用的运行方式 GUP CPU ,如果用在线API 这个随意使用一个即可
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
LLamaSharp.RunType
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
LLamaSharp.FileDirectory
//默认管理员账号密码
//Default admin account password
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
//Import asynchronous processing thread count. A higher count can be used for online API, but for local models, 1 is recommended to avoid memory overflow issues.
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠️找不到样式问题解决:
AntSK/src/AntSK下执行:
## ⚠️Fixing Style Issues:
Run the following in AntSK/src/AntSK:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
再去AntSK/src/AntSK/bin/Release/net8.0/publish
Then navigate to AntSK/src/AntSK/bin/Release/net8.0/publish and run:
```
dotnet AntSK.dll
```
然后启动就有样式了
The styles should now be applied after starting.
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的
I'm using CodeFirst mode for the database, so as long as the database connection is properly configured, the table structure will be created automatically.
## ✔️使用llamafactory
## ✔️Using llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如v0.2.3.2版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型
5、点击启动,这会开始从魔塔下载模型,你可能需要有一个较为漫长的等待
6、等待模型下载完毕后,在请求地址输入 http://localhost:8000/ 这里默认是使用8000端口
7、点击保存,然后就可以开始聊天了
8、很多人会问 LLamaSharpllamafactory有什么区别其实这两者LLamaSharp是llama.cpp的 dotnet实现但是只支持本地gguf模型 而llamafactory 支持的模型种类更多但使用的是python的实现其主要差异在这里另外llamafactory具有模型微调的能力这也是我们下一步需要重点集成的部分。
1. First, ensure that Python and pip are installed in your environment. This step is not necessary if using an image, such as version v0.2.3.2, which already includes the complete Python environment.
2. Go to the model add page and select llamafactory.
3. Click "Initialize" to check whether the 'pip install' environment setup is complete.
4. Choose a model that you like.
5. Click "Start" to begin downloading the model from the tower. This may involve a somewhat lengthy wait.
6. After the model has finished downloading, enter http://localhost:8000/ in the request address. The default port is 8000.
7. Click "Save" and start chatting.
8. Many people ask about the difference between LLamaSharp and llamafactory. In fact, LLamaSharp is a .NET implementation of llama.cpp, but only supports local gguf models, while llamafactory supports a wider variety of models and uses Python implementation. The main difference lies here. Additionally, llamafactory has the ability to fine-tune models, which is an area we will focus on integrating in the future.
```
## 🤝 贡献
## 🤝 Contributing
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)

如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).


## 💕 贡献者
[PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
## 💕 Contributors
This project exists thanks to all the people who contribute.
这个项目的存在要感谢所有的贡献者。

<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>

## 🚨 行为准则
该项目采用了贡献者公约定义的行为准则,以阐明我们社区的预期行为。有关更多信息,请参见 .NET Foundation 行为准则。 [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
想了解更多信息或开始使用 **AntSK**,可以关注我的公众号以及加入交流群。
## ☎️联系我
如有任何问题或建议,请通过以下方式关注我的公众号,发消息与我联系,我们也有交流群,可以发送进群等消息,然后我会拉你进交流群
![公众号](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
## 🌟 Star History
<a href="https://github.com/AIDotNet/AntSK/stargazers" target="_blank" style="display: block" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
</picture>
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>
## 🚨 Use Protocol
This warehouse follows the [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) open source protocol.
The Apache open source license allows the use of AntSK in commercial environments, provided that the license terms are followed. One of the main terms is to retain the copyright and license statements.
If you plan to use AntSK in commercial projects, you need to ensure that you follow the following steps:
1. Copyright statement containing Apache license. [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file).
2. If you modify the software source code, you need to clearly indicate these modifications in the source code.
## ☎Contact Me
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
![Official WeChat Account](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
---
We appreciate your interest in **AntSK** and look forward to collaborating with you to create an intelligent future!

238
README.zh.md Normal file
View File

@@ -0,0 +1,238 @@
中文|[English](./README.md)
# AntSK
## 使用.Net8+Blazor+SemanticKernel 打造的AI知识库/智能体
## ⭐核心功能
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库问答支持本地bge-embedding 向量模型 以及bge-rerank 重排模型。
- **文生图**:集成**StableDiffusion** 本地模型,可以进行文生图。
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **模型微调**规划中基于llamafactory进行模型微调
## ⛪应用场景
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
- 自动客服与聊天机器人
- 企业级搜索引擎
- 个性化推荐系统
- 智能辅助写作
- 教育与在线学习平台
- 其他有意思的AI App
## ✏️功能示例
### 在线演示
[文档地址](http://antsk.cn/)
[体验地址](https://antsk.ai-dotnet.com/)
```
默认账号test
默认密码test
由于云服务器配置较低,无法运行本地模型,所以把系统设置权限关闭了,大家看看界面即可,要使用本地模型,请下载自行使用
请勿在演示站点上传敏感信息
```
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
[在线文档http://antsk.cn](http://antsk.cn)
## ❓如何开始?
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
模型默认支持openai、azure openai、讯飞星火、阿里云积、 和llama支持的gguf本地模型 以及llamafactory的本地模型,如果需要使用其他模型可以使用one-api进行集成。
配置文件中的Login配置是默认的登录账号和密码
需要配置如下的配置文件
## 1⃣使用docker-compose
提供了pg版本 **appsettings.json** 和 简化版本(**Sqlite+disk** **docker-compose.simple.yml**
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
这里已经把pg的镜像做好了。在docker-compose.yml中可以修改默认账号密码然后你的**appsettings.json**的数据库连接需要保持一致。
然后你可以进入到目录后执行
```
docker-compose up -d
```
来启动AntSK
## 2⃣如何在docker中挂载本地模型和模型下载的目录
```
# 非 host 版本, 不使用本机代理
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.3.1
ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- D://model:/app/model
- D://model:/root/.cache/modelscope/hub/AI-ModelScope #使用Llamafactory时需要挂载 否则初始化的环境重启后会丢失
networks:
antsk:
```
以这个为示例意思是把windows本地D://model的文件夹挂载进 容器内/app/model 如果是这样你的appsettings.json中的模型地址应该配置为
```
model/xxx.gguf
```
## 3⃣配置文件的一些含义
```
{
"DBConnection": {
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"FileDir": {
"DirectoryPath": "D:\\git\\AntBlazor\\model"
},
"LLamaSharp": {
"RunType": "GPU",
"ContextSize": 2048,
"GpuLayerCount": 20
},
"Login": {
"User": "admin",
"Password": "xuzeyu"
},
"BackgroundTaskBroker": {
"ImportKMSTask": {
"WorkerCount": 1
}
}
}
```
```
//支持多种数据库具体可以查看SqlSugarMySqlSqlServerSqliteOraclePostgreSQLDmKdbndpOscarMySqlConnectorAccessOpenGaussQuestDBHGClickHouseGBaseOdbcOceanBaseForOracleTDengineGaussDBOceanBaseTidbVastbasePolarDBCustom
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
DBConnection.ConnectionStrings
//向量存储的类型,支持 Postgres、Disk、Memory、Qdrant、Redis、AzureAISearch
//Postgres、Redis需要配置 ConnectionString
//Qdrant 和AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
KernelMemory.VectorDb
//本地模型使用的运行方式 GUP CPU ,如果用在线API 这个随意使用一个即可
LLamaSharp.RunType
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
LLamaSharp.FileDirectory
//默认管理员账号密码
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠️找不到样式问题解决:
AntSK/src/AntSK下执行:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
再去AntSK/src/AntSK/bin/Release/net8.0/publish下
```
dotnet AntSK.dll
```
然后启动就有样式了
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的
## ✔使用llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如p0.2.4版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型
5、点击启动,这会开始从魔塔下载模型,你可能需要有一个较为漫长的等待
6、等待模型下载完毕后在请求地址输入 http://localhost:8000/ 这里默认是使用8000端口
7、点击保存然后就可以开始聊天了
8、很多人会问 LLamaSharp与llamafactory有什么区别其实这两者LLamaSharp是llama.cpp的 dotnet实现但是只支持本地gguf模型 而llamafactory 支持的模型种类更多但使用的是python的实现其主要差异在这里另外llamafactory具有模型微调的能力这也是我们下一步需要重点集成的部分。
```
## 🤝 贡献
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)

如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).


## 💕 贡献者
这个项目的存在要感谢所有的贡献者。

<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>

## 🚨 使用协议
本仓库遵循 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 开源协议。
Apache开源许可证允许在商业环境中使用AntSK前提是需要遵守许可证的条款。主要条款之一是要保留版权声明和许可证声明。
如果您打算在商业项目中使用AntSK您需要确保遵守以下步骤
1、包含Apache许可证的版权声明。 [Apache-2.0 License](https://github.com/AIDotNet/AntSK?tab=Apache-2.0-1-ov-file) 。
2、如果您修改了软件源代码您需要在源代码中明确标明这些修改。
## ☎️联系我
如有任何问题或建议请通过以下方式关注我的公众号《许泽宇的技术分享》发消息与我联系我们也有AIDotnet交流群可以发送进群等消息然后我会拉你进交流群
![公众号](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
## 🌟 Star History
<a href="https://github.com/AIDotNet/AntSK/stargazers" target="_blank" style="display: block" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
</picture>
</a>

View File

@@ -3,9 +3,9 @@ version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.4.1
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3.2
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.4.1
ports:
- 5000:5000
networks:
@@ -15,5 +15,7 @@ services:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- /AntSK/model:/app/model
- /AntSK/model:/root/.cache/modelscope/hub/AI-ModelScope # LLamaFactory模型文件
networks:
antsk:

View File

@@ -1,6 +1,20 @@
# 非 host 版本, 不使用本机代理
version: '3.8'
services:
aspire-dashboard:
container_name: aspire-dashboard
image: mcr.microsoft.com/dotnet/aspire-dashboard:8.0
networks:
- antsk
environment:
- DOTNET_DASHBOARD_UNSECURED_ALLOW_ANONYMOUS=true
- ASPIRE_ALLOW_UNSECURED_TRANSPORT=true
- DASHBOARD_OTLP_AUTHMODE=ApiKey
- DASHBOARD_OTLP_PRIMARYAPIKEY=antsk
ports:
- 18888:18888
- 18889:18889
restart: unless-stopped
antskpg:
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/pg:v0.5.0
container_name: antskpg
@@ -18,9 +32,9 @@ services:
- ./pg/data:/var/lib/postgresql/data
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.4.1
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3.2
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.4.1
ports:
- 5000:5000
networks:
@@ -30,7 +44,15 @@ services:
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
- ASPNETCORE_FORWARDEDHEADERS_ENABLED=true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES=true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES= true
- OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY=in_memory
- OTEL_EXPORTER_OTLP_ENDPOINT=http://aspire-dashboard:18889
- OTEL_SERVICE_NAME=antsk
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- /AntSK/model:/app/model
- /AntSK/model:/root/.cache/modelscope/hub/AI-ModelScope # LLamaFactory模型文件
networks:
antsk:

View File

@@ -0,0 +1,20 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsAspireHost>true</IsAspireHost>
<UserSecretsId>32ac67c8-178a-4eeb-871d-879023582e06</UserSecretsId>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Aspire.Hosting.AppHost" Version="8.0.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK\AntSK.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,5 @@
var builder = DistributedApplication.CreateBuilder(args);
builder.AddProject<Projects.AntSK>("antsk");
builder.Build().Run();

View File

@@ -0,0 +1,8 @@
{
"Logging": {
"LogLevel": {
"Default": "Information",
"Microsoft.AspNetCore": "Warning"
}
}
}

View File

@@ -0,0 +1,9 @@
{
"Logging": {
"LogLevel": {
"Default": "Information",
"Microsoft.AspNetCore": "Warning",
"Aspire.Hosting.Dcp": "Warning"
}
}
}

View File

@@ -0,0 +1,26 @@
services:
aspire-dashboard:
container_name: "aspire-dashboard"
image: "mcr.microsoft.com/dotnet/aspire-dashboard:8.0"
environment:
DOTNET_DASHBOARD_UNSECURED_ALLOW_ANONYMOUS: "true"
ports:
- target: 18888
published: 18888
restart: unless-stopped
antsk:
container_name: "antsk"
image: "antsk:latest"
environment:
OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES: "true"
OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES: "true"
OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY: "in_memory"
ASPNETCORE_FORWARDEDHEADERS_ENABLED: "true"
OTEL_EXPORTER_OTLP_ENDPOINT: "http://aspire-dashboard:18889"
OTEL_SERVICE_NAME: "antsk"
ports:
- target: 8080
published: 10000
- target: 8443
published: 10001
restart: unless-stopped

View File

@@ -0,0 +1,17 @@
{
"projectPath": ".",
"outputPath": "aspirate-output",
"containerImageTags": [
"latest"
],
"containerBuilder": "docker",
"outputFormat": "compose",
"privateRegistryEmail": "aspir8@aka.ms",
"includeDashboard": true,
"secrets": {
"salt": "fjamZa3pQbM1UyY4",
"hash": "QR\u002BSEr3p2SwD/w2oPE21vrWh/EerhNyVyTkr0atIREw=",
"secrets": {}
},
"processAllComponents": true
}

View File

@@ -0,0 +1,26 @@
{
"resources": {
"antsk": {
"type": "project.v0",
"path": "../AntSK/AntSK.csproj",
"env": {
"OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EXCEPTION_LOG_ATTRIBUTES": "true",
"OTEL_DOTNET_EXPERIMENTAL_OTLP_EMIT_EVENT_LOG_ATTRIBUTES": "true",
"OTEL_DOTNET_EXPERIMENTAL_OTLP_RETRY": "in_memory",
"ASPNETCORE_FORWARDEDHEADERS_ENABLED": "true"
},
"bindings": {
"http": {
"scheme": "http",
"protocol": "tcp",
"transport": "http"
},
"https": {
"scheme": "https",
"protocol": "tcp",
"transport": "http"
}
}
}
}
}

View File

@@ -0,0 +1,53 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<DocumentationFile>AntSK.Domain.xml</DocumentationFile>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102</NoWarn>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AntDesign.Charts" Version="0.5.1" />
<PackageReference Include="AntDesign.ProLayout" Version="0.18.2" />
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageReference Include="AutoMapper" Version="8.1.0" />
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
<PackageReference Include="Markdig" Version="0.37.0" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.151" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
<PackageReference Include="RestSharp" Version="110.2.0" />
<PackageReference Include="NPOI" Version="2.7.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.7.1" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.7.1" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.7.1-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.36.240415.2" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.36.240415.2" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="0.36.240415.2" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="0.36.240415.2" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="0.36.240415.2" />
<PackageReference Include="LLamaSharp" Version="0.11.2" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.11.2" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.11.2" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.11.2" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.11.2" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />
<ProjectReference Include="..\AntSk.LLM\AntSK.LLM.csproj" />
<ProjectReference Include="..\AntSK.OCR\AntSK.OCR.csproj" />
<ProjectReference Include="..\MiddleWare\AntSK.BackgroundTask\AntSK.BackgroundTask.csproj" />
</ItemGroup>
</Project>

View File

@@ -5,44 +5,54 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<DocumentationFile>AntSK.Domain.xml</DocumentationFile>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102</NoWarn>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102,KMEXP00</NoWarn>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AntDesign.Charts" Version="0.5.1" />
<PackageReference Include="AntDesign.ProLayout" Version="0.18.1" />
<PackageReference Include="AntDesign.Charts" Version="0.5.2" />
<PackageReference Include="AntDesign.ProLayout" Version="0.19.2" />
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
<PackageReference Include="AutoMapper" Version="8.1.0" />
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
<PackageReference Include="Markdig" Version="0.36.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.148" />
<PackageReference Include="Markdig" Version="0.37.0" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftVersion)" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.160" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
<PackageReference Include="RestSharp" Version="110.2.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.6.3" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.6.3" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.6.3-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="0.35.240321.1" />
<PackageReference Include="RestSharp" Version="$(RestSharpVersion)" />
<PackageReference Include="NPOI" Version="2.7.0" />
<PackageReference Include="LLamaSharp" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.10.0" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.10.0" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.10.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SKVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="$(SKVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="$(SKVersion)-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="$(KMVersion)" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="$(KMVersion)" />
<PackageReference Include="LLamaSharp" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="$(LLamaSharpVersion)" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="$(LLamaSharpVersion)" />
<PackageReference Include="Serilog" Version="4.0.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="6.0.0" />
<PackageReference Include="Serilog.Sinks.File" Version="6.0.0" />
<PackageReference Include="Serilog.Extensions.Logging" Version="8.0.1-dev-10391" />
<PackageReference Include="Serilog.Settings.Configuration" Version="8.0.1" />
<PackageReference Include="Serilog.Sinks.Seq" Version="8.0.0" />
<PackageReference Include="Serilog.Sinks.OpenTelemetry" Version="3.0.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />
<ProjectReference Include="..\AntSk.LLM\AntSK.LLM.csproj" />
<ProjectReference Include="..\AntSK.OCR\AntSK.OCR.csproj" />
<ProjectReference Include="..\MiddleWare\AntSK.BackgroundTask\AntSK.BackgroundTask.csproj" />
</ItemGroup>

View File

@@ -69,6 +69,84 @@
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ExcelToDataTable(System.String,System.Boolean)">
<summary>
将excel导入到datatable
</summary>
<param name="filePath">excel路径</param>
<param name="isColumnName">第一行是否是列名</param>
<returns>返回datatable</returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ExcelToDataTable(System.IO.Stream,System.Boolean)">
<summary>
将excel导入到datatable
</summary>
<param name="stream"></param>
<param name="isColumnName">第一行是否是列名</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ExcelToList``1(System.IO.Stream)">
<summary>
excel转list
</summary>
<typeparam name="TResult"></typeparam>
<param name="stream"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ExcelToList``1(System.IO.Stream,System.String)">
<summary>
excel转list-根据sheetName得到List
</summary>
<typeparam name="TResult"></typeparam>
<param name="stream"></param>
<param name="sheetName"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ListToExcel``1(``0[],System.String)">
<summary>
List导出excel 二进制流
</summary>
<typeparam name="T">实体</typeparam>
<param name="data">List</param>
<param name="sheetName">sheetname 可不填默认Sheet0</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.DataTableToExcel(System.Data.DataTable,System.String,System.String)">
<summary>
Dt导出excel 二进制流
</summary>
<param name="dt">datatable</param>
<param name="strFile">strFile</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.ListWriteExcel``1(``0[],System.String,System.String)">
<summary>
List写入excel
</summary>
<typeparam name="T"></typeparam>
<param name="data"></param>
<param name="strFile">路径</param>
<param name="sheetName"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.DataTableWriteExcel(System.Data.DataTable,System.String,System.String)">
<summary>
dt写入excel
</summary>
<param name="dt">datatable</param>
<param name="strFile">路径</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.ExeclHelper.SetCellDropdownList(NPOI.SS.UserModel.IWorkbook,NPOI.SS.UserModel.ISheet,System.Collections.Generic.List{System.String},System.String,System.Int32,System.Int32,System.Int32)">
<summary>
设置单元格下拉框(除去标题行)
</summary>
<param name="workbook"></param>
<param name="sheet"></param>
<param name="ddlList"></param>
<param name="firstcol"></param>
<param name="lastcol"></param>
</member>
<member name="T:AntSK.Domain.Domain.Model.Enum.AIType">
<summary>
AI类型
@@ -79,11 +157,6 @@
模型类型
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.MessageInfo.IsSend">
<summary>
发送是true 接收是false
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.PageList`1.PageIndex">
<summary>
当前页从1开始
@@ -99,12 +172,34 @@
总数
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Other.Bge.BegRerankConfig.LoadModel(System.String,System.String)">
<summary>
模型写死
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Other.Bge.BgeEmbeddingConfig.LoadModel(System.String,System.String)">
<summary>
模型写死
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Other.KMExcelHandler.StepName">
<inheritdoc />
</member>
<member name="M:AntSK.Domain.Domain.Other.KMExcelHandler.InvokeAsync(Microsoft.KernelMemory.Pipeline.DataPipeline,System.Threading.CancellationToken)">
<inheritdoc />
</member>
<member name="F:AntSK.Domain.Domain.Other.LLamaConfig.dicLLamaWeights">
<summary>
避免模型重复加载,本地缓存
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Service.ChatService.SendChatByAppAsync(AntSK.Domain.Repositories.Apps,System.String,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)">
<member name="P:AntSK.Domain.Domain.Other.QAHandler.StepName">
<inheritdoc />
</member>
<member name="M:AntSK.Domain.Domain.Other.QAHandler.InvokeAsync(Microsoft.KernelMemory.Pipeline.DataPipeline,System.Threading.CancellationToken)">
<inheritdoc />
</member>
<member name="M:AntSK.Domain.Domain.Service.ChatService.SendChatByAppAsync(AntSK.Domain.Repositories.Apps,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)">
<summary>
发送消息
</summary>
@@ -287,6 +382,56 @@
API调用秘钥
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.Relevance">
<summary>
相似度
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.MaxAskPromptSize">
<summary>
提问最大token数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.MaxMatchesCount">
<summary>
向量匹配数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.AnswerTokens">
<summary>
回答最大token数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.UserName">
<summary>
用户名
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.AppId">
<summary>
应用ID
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.Context">
<summary>
消息内容
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.IsSend">
<summary>
发送是true 接收是false
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.CreateTime">
<summary>
创建事件
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Chats.FileName">
<summary>
文件名
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Funs.Path">
<summary>
接口描述
@@ -771,6 +916,14 @@
<param name="parameters"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.ConvertUtils.ComparisonIgnoreCase(System.String,System.String)">
<summary>
忽略大小写匹配
</summary>
<param name="s"></param>
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.RepoFiles.SamplePluginsPath">
<summary>
Scan the local folders from the repo, looking for "samples/plugins" folder.

View File

@@ -5,6 +5,7 @@ using DocumentFormat.OpenXml.Office2016.Drawing.ChartDrawing;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.OpenApi.Models;
using SqlSugar;
using Swashbuckle.AspNetCore.SwaggerGen;
@@ -19,6 +20,12 @@ namespace AntSK.Domain.Common.DependencyInjection
{
public static class InitExtensions
{
private static ILogger _logger;
public static void InitLog(ILogger logger)
{
_logger = logger;
}
/// <summary>
/// 使用codefirst创建数据库表
/// </summary>
@@ -50,6 +57,10 @@ namespace AntSK.Domain.Common.DependencyInjection
_repository.GetDB().CodeFirst.InitTables(type);
}
}
//安装向量插件
_repository.GetDB().Ado.ExecuteCommandAsync($"CREATE EXTENSION IF NOT EXISTS vector;");
_logger.LogInformation("初始化表结构完成");
}
return app;
}
@@ -70,7 +81,7 @@ namespace AntSK.Domain.Common.DependencyInjection
llamafactoryStart.Value = "false";
_dic_Repository.Insert(llamafactoryStart);
}
_logger.LogInformation("初始化数据库初始数据完成");
}
return app;
}
@@ -97,7 +108,7 @@ namespace AntSK.Domain.Common.DependencyInjection
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
_logger.LogError(ex.Message + " ---- " + ex.StackTrace);
}
return app;
}

View File

@@ -0,0 +1,21 @@
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Common.Embedding
{
public static class BuilderBgeExtensions
{
public static IKernelMemoryBuilder WithBgeTextEmbeddingGeneration(this IKernelMemoryBuilder builder, HuggingfaceTextEmbeddingGenerator textEmbeddingGenerator)
{
builder.AddSingleton((ITextEmbeddingGenerator)textEmbeddingGenerator);
builder.AddIngestionEmbeddingGenerator(textEmbeddingGenerator);
return builder;
}
}
}

View File

@@ -0,0 +1,56 @@
using LLama.Common;
using LLama;
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AntSK.Domain.Domain.Other.Bge;
namespace AntSK.Domain.Common.Embedding
{
public class HuggingfaceTextEmbeddingGenerator : ITextEmbeddingGenerator, ITextTokenizer, IDisposable
{
public int MaxTokens => 1024;
public int MaxTokenTotal => 1024;
private readonly dynamic _embedder;
public HuggingfaceTextEmbeddingGenerator(string pyDllPath,string modelName)
{
_embedder = BgeEmbeddingConfig.LoadModel(pyDllPath, modelName);
}
public void Dispose()
{
BgeEmbeddingConfig.Dispose();
}
//public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingAsync(IList<string> data, CancellationToken cancellationToken = default)
//{
// IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();
// foreach (var d in data)
// {
// var embeddings = await EmbeddingConfig.GetEmbedding(d);
// results.Add(new ReadOnlyMemory<float>(embeddings));
// }
// return results;
//}
public async Task<Microsoft.KernelMemory.Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = await BgeEmbeddingConfig.GetEmbedding(text);
return new Microsoft.KernelMemory.Embedding(embeddings);
}
public int CountTokens(string text)
{
return BgeEmbeddingConfig.TokenCount(text);
}
}
}

View File

@@ -0,0 +1,822 @@
using NPOI.HSSF.UserModel;
using NPOI.SS.UserModel;
using NPOI.SS.Util;
using NPOI.XSSF.Streaming;
using NPOI.XSSF.UserModel;
using System;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
namespace AntSK.Domain
{
public class ExeclHelper
{
/// <summary>
/// 将excel导入到datatable
/// </summary>
/// <param name="filePath">excel路径</param>
/// <param name="isColumnName">第一行是否是列名</param>
/// <returns>返回datatable</returns>
public static DataTable ExcelToDataTable(string filePath, bool isColumnName)
{
DataTable dataTable = null;
FileStream fs = null;
DataColumn column = null;
DataRow dataRow = null;
IWorkbook workbook = null;
ISheet sheet = null;
IRow row = null;
ICell cell = null;
int startRow = 0;
try
{
using (fs = File.OpenRead(filePath))
{
// 2007版本
if (filePath.Contains(".xlsx"))
workbook = new XSSFWorkbook(fs);
// 2003版本
else if (filePath.Contains(".xls"))
workbook = new HSSFWorkbook(fs);
if (workbook != null)
{
sheet = workbook.GetSheetAt(0);//读取第一个sheet当然也可以循环读取每个sheet
dataTable = new DataTable();
if (sheet != null)
{
int rowCount = sheet.LastRowNum;//总行数
if (rowCount > 0)
{
IRow firstRow = sheet.GetRow(0);//第一行
int cellCount = firstRow.LastCellNum;//列数
//构建datatable的列
if (isColumnName)
{
startRow = 1;//如果第一行是列名,则从第二行开始读取
for (int i = firstRow.FirstCellNum; i < cellCount; ++i)
{
cell = firstRow.GetCell(i);
if (cell != null)
{
if (cell.StringCellValue != null)
{
column = new DataColumn(cell.StringCellValue);
dataTable.Columns.Add(column);
}
}
}
}
else
{
for (int i = firstRow.FirstCellNum; i < cellCount; ++i)
{
column = new DataColumn("column" + (i + 1));
dataTable.Columns.Add(column);
}
}
//填充行
for (int i = startRow; i <= rowCount; ++i)
{
row = sheet.GetRow(i);
if (row == null) continue;
dataRow = dataTable.NewRow();
for (int j = row.FirstCellNum; j < cellCount; ++j)
{
cell = row.GetCell(j);
if (cell == null)
{
dataRow[j] = "";
}
else
{
//CellType(Unknown = -1,Numeric = 0,String = 1,Formula = 2,Blank = 3,Boolean = 4,Error = 5,)
switch (cell.CellType)
{
case CellType.Blank:
dataRow[j] = "";
break;
case CellType.Numeric:
short format = cell.CellStyle.DataFormat;
//对时间格式2015.12.5、2015/12/5、2015-12-5等的处理
if (format == 14 || format == 31 || format == 57 || format == 58)
dataRow[j] = cell.DateCellValue;
else
dataRow[j] = cell.NumericCellValue;
break;
case CellType.String:
dataRow[j] = cell.StringCellValue;
break;
}
}
}
dataTable.Rows.Add(dataRow);
}
}
}
}
}
return dataTable;
}
catch (Exception)
{
if (fs != null)
{
fs.Close();
}
return null;
}
}
/// <summary>
/// 将excel导入到datatable
/// </summary>
/// <param name="stream">流</param>
/// <param name="isColumnName">第一行是否是列名</param>
/// <returns></returns>
public static DataTable ExcelToDataTable(Stream stream, bool isColumnName)
{
DataTable dataTable = null;
DataColumn column = null;
DataRow dataRow = null;
IWorkbook workbook = new XSSFWorkbook(stream);
ISheet sheet = null;
IRow row = null;
ICell cell = null;
int startRow = 0;
try
{
if (workbook != null)
{
sheet = workbook.GetSheetAt(0);//读取第一个sheet当然也可以循环读取每个sheet
dataTable = new DataTable();
if (sheet != null)
{
int rowCount = sheet.LastRowNum;//总行数
if (rowCount > 0)
{
IRow firstRow = sheet.GetRow(0);//第一行
int cellCount = firstRow.LastCellNum;//列数
//构建datatable的列
if (isColumnName)
{
startRow = 1;//如果第一行是列名,则从第二行开始读取
for (int i = firstRow.FirstCellNum; i < cellCount; ++i)
{
cell = firstRow.GetCell(i);
if (cell != null)
{
if (cell.StringCellValue != null)
{
column = new DataColumn(cell.StringCellValue);
dataTable.Columns.Add(column);
}
}
}
}
else
{
for (int i = firstRow.FirstCellNum; i < cellCount; ++i)
{
column = new DataColumn("column" + (i + 1));
dataTable.Columns.Add(column);
}
}
//填充行
for (int i = startRow; i <= rowCount; ++i)
{
row = sheet.GetRow(i);
if (row == null) continue;
dataRow = dataTable.NewRow();
for (int j = row.FirstCellNum; j < cellCount; ++j)
{
cell = row.GetCell(j);
if (cell == null)
{
dataRow[j] = "";
}
else
{
//CellType(Unknown = -1,Numeric = 0,String = 1,Formula = 2,Blank = 3,Boolean = 4,Error = 5,)
switch (cell.CellType)
{
case CellType.Blank:
dataRow[j] = "";
break;
case CellType.Numeric:
short format = cell.CellStyle.DataFormat;
//对时间格式2015.12.5、2015/12/5、2015-12-5等的处理
if (format == 14 || format == 31 || format == 57 || format == 58)
dataRow[j] = cell.DateCellValue;
else
dataRow[j] = cell.NumericCellValue;
break;
case CellType.String:
dataRow[j] = cell.StringCellValue;
break;
}
}
}
dataTable.Rows.Add(dataRow);
}
}
}
}
return dataTable;
}
catch (Exception)
{
throw;
}
}
/// <summary>
/// excel转list
/// </summary>
/// <typeparam name="TResult"></typeparam>
/// <param name="stream"></param>
/// <returns></returns>
public static IEnumerable<TResult> ExcelToList<TResult>(Stream stream) where TResult : new()
{
var propertyInfos = typeof(TResult).GetProperties(BindingFlags.Public | BindingFlags.Instance).Where(p => p.CustomAttributes.Count() > 0)
.OrderBy(p => p.GetCustomAttribute<ExeclPropertyAttribute>().Order).ToArray();
List<TResult> list = new List<TResult>();
IWorkbook workbook = new XSSFWorkbook(stream);
ISheet sheet = null;
IRow row = null;
ICell cell = null;
int startRow = 1;
try
{
if (workbook != null)
{
sheet = workbook.GetSheetAt(0);//读取第一个sheet当然也可以循环读取每个sheet
if (sheet != null)
{
int rowCount = sheet.LastRowNum;//总行数
if (rowCount > 0)
{
IRow firstRow = sheet.GetRow(0);//第一行
int cellCount = firstRow.LastCellNum;//列数
//填充行
for (int i = startRow; i <= rowCount; ++i)
{
row = sheet.GetRow(i);
if (row == null) continue;
bool emptyRow = true;//是否空行
TResult dataModel = new TResult();
for (int j = row.FirstCellNum; j < cellCount; ++j)
{
var execlPropertyAttribute = propertyInfos[j].GetCustomAttribute<ExeclPropertyAttribute>();
cell = row.GetCell(j);
if (cell == null)
{
propertyInfos[j].SetValue(dataModel, "");
}
else
{
switch (cell.CellType)
{
case CellType.Blank:
propertyInfos[j].SetValue(dataModel, "");
break;
case CellType.Numeric:
short format = cell.CellStyle.DataFormat;
//对时间格式2015.12.5、2015/12/5、2015-12-5等的处理
if (format == 14 || format == 31 || format == 57 || format == 58)
propertyInfos[j].SetValue(dataModel, cell.DateCellValue);
else
{
if (execlPropertyAttribute.CellType == CellType.String)
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue.ToString());
}
else
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue);
}
}
break;
case CellType.String:
propertyInfos[j].SetValue(dataModel, cell.StringCellValue);
break;
}
}
if (cell != null && !string.IsNullOrEmpty(cell.ToString().Trim()))
{
emptyRow = false;
}
}
//非空数据行数据添加到DataTable
if (!emptyRow)
{
list.Add(dataModel);
}
}
}
}
}
return list;
}
catch (Exception)
{
throw;
}
}
public static IEnumerable<TResult> ExcelToListFileName<TResult>(Stream stream, string fileName) where TResult : new()
{
var propertyInfos = typeof(TResult).GetProperties(BindingFlags.Public | BindingFlags.Instance).Where(p => p.CustomAttributes.Count() > 0)
.OrderBy(p => p.GetCustomAttribute<ExeclPropertyAttribute>().Order).ToArray();
List<TResult> list = new List<TResult>();
IWorkbook workbook = null;
if (fileName.Contains(".xlsx"))
workbook = new XSSFWorkbook(stream);
// 2003版本
else if (fileName.Contains(".xls"))
workbook = new HSSFWorkbook(stream);
ISheet sheet = null;
IRow row = null;
ICell cell = null;
int startRow = 1;
try
{
if (workbook != null)
{
sheet = workbook.GetSheetAt(0);//读取第一个sheet当然也可以循环读取每个sheet
if (sheet != null)
{
int rowCount = sheet.LastRowNum;//总行数
if (rowCount > 0)
{
IRow firstRow = sheet.GetRow(0);//第一行
int cellCount = firstRow.LastCellNum;//列数
//填充行
for (int i = startRow; i <= rowCount; ++i)
{
row = sheet.GetRow(i);
if (row == null) continue;
bool emptyRow = true;//是否空行
TResult dataModel = new TResult();
for (int j = row.FirstCellNum; j < cellCount; ++j)
{
var execlPropertyAttribute = propertyInfos[j].GetCustomAttribute<ExeclPropertyAttribute>();
cell = row.GetCell(j);
if (cell == null)
{
propertyInfos[j].SetValue(dataModel, "");
}
else
{
switch (cell.CellType)
{
case CellType.Blank:
propertyInfos[j].SetValue(dataModel, "");
break;
case CellType.Numeric:
short format = cell.CellStyle.DataFormat;
//对时间格式2015.12.5、2015/12/5、2015-12-5等的处理
if (format == 14 || format == 31 || format == 57 || format == 58)
propertyInfos[j].SetValue(dataModel, cell.DateCellValue);
else
{
if (execlPropertyAttribute.CellType == CellType.String)
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue.ToString());
}
else
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue);
}
}
break;
case CellType.String:
propertyInfos[j].SetValue(dataModel, cell.StringCellValue);
break;
}
}
if (cell != null && !string.IsNullOrEmpty(cell.ToString().Trim()))
{
emptyRow = false;
}
}
//非空数据行数据添加到DataTable
if (!emptyRow)
{
list.Add(dataModel);
}
}
}
}
}
return list;
}
catch (Exception)
{
throw;
}
}
/// <summary>
/// excel转list-根据sheetName得到List
/// </summary>
/// <typeparam name="TResult"></typeparam>
/// <param name="stream"></param>
/// <param name="sheetName"></param>
/// <returns></returns>
public static IEnumerable<TResult> ExcelToList<TResult>(Stream stream, string sheetName) where TResult : new()
{
var propertyInfos = typeof(TResult).GetProperties(BindingFlags.Public | BindingFlags.Instance)
.OrderBy(p => p.GetCustomAttribute<ExeclPropertyAttribute>().Order).ToArray();
List<TResult> list = new List<TResult>();
IWorkbook workbook = new XSSFWorkbook(stream);
ISheet sheet = null;
IRow row = null;
ICell cell = null;
int startRow = 1;
try
{
if (workbook != null)
{
sheet = workbook.GetSheet(sheetName);//根据sheet读取对应的DataTable
if (sheet != null)
{
int rowCount = sheet.LastRowNum;//总行数
if (rowCount > 0)
{
IRow firstRow = sheet.GetRow(0);//第一行
int cellCount = firstRow.LastCellNum;//列数
//填充行
for (int i = startRow; i <= rowCount; ++i)
{
row = sheet.GetRow(i);
if (row == null) continue;
bool emptyRow = true;//是否空行
TResult dataModel = new TResult();
for (int j = row.FirstCellNum; j < cellCount; ++j)
{
var execlPropertyAttribute = propertyInfos[j].GetCustomAttribute<ExeclPropertyAttribute>();
cell = row.GetCell(j);
if (cell == null)
{
propertyInfos[j].SetValue(dataModel, "");
}
else
{
switch (cell.CellType)
{
case CellType.Blank:
propertyInfos[j].SetValue(dataModel, "");
break;
case CellType.Numeric:
short format = cell.CellStyle.DataFormat;
//对时间格式2015.12.5、2015/12/5、2015-12-5等的处理
if (format == 14 || format == 31 || format == 57 || format == 58)
propertyInfos[j].SetValue(dataModel, cell.DateCellValue);
else
{
if (execlPropertyAttribute.CellType == CellType.String)
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue.ToString());
}
else
{
propertyInfos[j].SetValue(dataModel, cell.NumericCellValue);
}
}
break;
case CellType.String:
propertyInfos[j].SetValue(dataModel, cell.StringCellValue);
break;
}
}
if (cell != null && !string.IsNullOrEmpty(cell.ToString().Trim()))
{
emptyRow = false;
}
}
//非空数据行数据添加到DataTable
if (!emptyRow)
{
list.Add(dataModel);
}
}
}
}
}
return list;
}
catch (Exception ex)
{
throw;
}
}
/// <summary>
/// List导出excel 二进制流
/// </summary>
/// <typeparam name="T">实体</typeparam>
/// <param name="data">List</param>
/// <param name="sheetName">sheetname 可不填默认Sheet0</param>
/// <returns></returns>
public static byte[] ListToExcel<T>(T[] data, string sheetName = "Sheet0")
{
IWorkbook workbook = null;
IRow row = null;
ISheet sheet = null;
ICell cell = null;
var propertyInfos = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)
.OrderBy(p => p.GetCustomAttribute<ExeclPropertyAttribute>().Order).ToArray();
workbook = new XSSFWorkbook();
sheet = workbook.CreateSheet(sheetName);//创建一个名称为Sheet0的表
int rowCount = data.Count();//行数
int columnCount = propertyInfos.Length;//列数
//设置列头
row = sheet.CreateRow(0);//excel第一行设为列头
for (int c = 0; c < columnCount; c++)
{
cell = row.CreateCell(c);
cell.SetCellValue(propertyInfos[c].GetCustomAttribute<ExeclPropertyAttribute>().DisplayName);
}
//设置每行每列的单元格,
for (int i = 0; i < rowCount; i++)
{
row = sheet.CreateRow(i + 1);
for (int j = 0; j < columnCount; j++)
{
cell = row.CreateCell(j);//excel第二行开始写入数据
cell.SetCellValue(propertyInfos[j].GetValue(data[i])?.ToString());
}
}
using (MemoryStream memoryStream = new MemoryStream())
{
workbook.Write(memoryStream);//向打开的这个xls文件中写入数据
return memoryStream.ToArray();
}
}
/// <summary>
/// Dt导出excel 二进制流
/// </summary>
/// <param name="dt">datatable</param>
/// <param name="strFile">strFile</param>
/// <returns></returns>
public static byte[] DataTableToExcel(DataTable dt, string strFile, string sheetName = "Sheet0")
{
bool result = false;
IWorkbook workbook = null;
FileStream fs = null;
IRow row = null;
ISheet sheet = null;
ICell cell = null;
if (dt != null && dt.Rows.Count > 0)
{
workbook = new XSSFWorkbook();
sheet = workbook.CreateSheet(sheetName);//创建一个名称为Sheet0的表
int rowCount = dt.Rows.Count;//行数
int columnCount = dt.Columns.Count;//列数
//设置列头
row = sheet.CreateRow(0);//excel第一行设为列头
for (int c = 0; c < columnCount; c++)
{
cell = row.CreateCell(c);
cell.SetCellValue(dt.Columns[c].ColumnName);
}
//设置每行每列的单元格,
for (int i = 0; i < rowCount; i++)
{
row = sheet.CreateRow(i + 1);
for (int j = 0; j < columnCount; j++)
{
cell = row.CreateCell(j);//excel第二行开始写入数据
cell.SetCellValue(dt.Rows[i][j].ToString());
}
}
using (MemoryStream memoryStream = new MemoryStream())
{
workbook.Write(memoryStream);//向打开的这个xls文件中写入数据
return memoryStream.ToArray();
}
}
else
{
return new byte[0];
}
}
/// <summary>
/// List写入excel
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="data"></param>
/// <param name="strFile">路径</param>
/// <param name="sheetName"></param>
/// <returns></returns>
public static bool ListWriteExcel<T>(T[] data, string strFile, string sheetName = "Sheet0")
{
bool result = false;
IWorkbook workbook = null;
FileStream fs = null;
IRow row = null;
ISheet sheet = null;
ICell cell = null;
try
{
var propertyInfos = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)
.OrderBy(p => p.GetCustomAttribute<ExeclPropertyAttribute>().Order).ToArray();
workbook = new XSSFWorkbook();
sheet = workbook.CreateSheet(sheetName);//创建一个名称为Sheet0的表
int rowCount = data.Count();//行数
int columnCount = propertyInfos.Length;//列数
//设置列头
row = sheet.CreateRow(0);//excel第一行设为列头
for (int c = 0; c < columnCount; c++)
{
cell = row.CreateCell(c);
cell.SetCellValue(propertyInfos[c].GetCustomAttribute<ExeclPropertyAttribute>().DisplayName);
}
//设置每行每列的单元格,
for (int i = 0; i < rowCount; i++)
{
row = sheet.CreateRow(i + 1);
for (int j = 0; j < columnCount; j++)
{
cell = row.CreateCell(j);//excel第二行开始写入数据
cell.SetCellValue(propertyInfos[j].GetValue(data[i])?.ToString());
}
}
using (fs = File.OpenWrite(strFile))
{
workbook.Write(fs);//向打开的这个xls文件中写入数据
result = true;
}
return result;
}
catch (Exception ex)
{
if (fs != null)
{
fs.Close();
}
return false;
}
}
/// <summary>
/// dt写入excel
/// </summary>
/// <param name="dt">datatable</param>
/// <param name="strFile">路径</param>
/// <returns></returns>
public static bool DataTableWriteExcel(DataTable dt, string strFile, string sheetName = "Sheet0")
{
bool result = false;
IWorkbook workbook = null;
FileStream fs = null;
IRow row = null;
ISheet sheet = null;
ICell cell = null;
try
{
if (dt != null && dt.Rows.Count > 0)
{
workbook = new XSSFWorkbook();
sheet = workbook.CreateSheet(sheetName);//创建一个名称为Sheet0的表
int rowCount = dt.Rows.Count;//行数
int columnCount = dt.Columns.Count;//列数
//设置列头
row = sheet.CreateRow(0);//excel第一行设为列头
for (int c = 0; c < columnCount; c++)
{
cell = row.CreateCell(c);
cell.SetCellValue(dt.Columns[c].ColumnName);
}
//设置每行每列的单元格,
for (int i = 0; i < rowCount; i++)
{
row = sheet.CreateRow(i + 1);
for (int j = 0; j < columnCount; j++)
{
cell = row.CreateCell(j);//excel第二行开始写入数据
cell.SetCellValue(dt.Rows[i][j].ToString());
}
}
using (fs = File.OpenWrite(strFile))
{
workbook.Write(fs);//向打开的这个xls文件中写入数据
result = true;
}
}
return result;
}
catch (Exception ex)
{
if (fs != null)
{
fs.Close();
}
return false;
}
}
/// <summary>
/// 设置单元格下拉框(除去标题行)
/// </summary>
/// <param name="workbook"></param>
/// <param name="sheet"></param>
/// <param name="ddlList"></param>
/// <param name="firstcol"></param>
/// <param name="lastcol"></param>
public static void SetCellDropdownList(IWorkbook workbook, ISheet sheet, List<string> ddlList, string sheetname, int sheetIndex, int firstcol, int lastcol)
{
# region ExcelHSSFWorkbook
//ISheet sheet2 = workbook.CreateSheet(sheetname);
////隐藏
//workbook.SetSheetHidden(sheetIndex, 1);
//int rowIndex = 0;
//foreach (var item in ddlList)
//{
// IRow vrow = sheet2.CreateRow(rowIndex);
// vrow.CreateCell(0).SetCellValue(item);
// rowIndex++;
//}
////创建的下拉项的区域:
//var rangeName = sheetname + "Range";
//IName range = workbook.CreateName();
//range.RefersToFormula = sheetname + "!$A$1:$A$" + rowIndex;
//range.NameName = rangeName;
//CellRangeAddressList regions = new CellRangeAddressList(1, 65535, firstcol, lastcol);
//DVConstraint constraint = DVConstraint.CreateFormulaListConstraint(rangeName);
//HSSFDataValidation dataValidate = new HSSFDataValidation(regions, constraint);
//dataValidate.CreateErrorBox("输入不合法", "请输入或选择下拉列表中的值。");
//dataValidate.ShowPromptBox = true;
//sheet.AddValidationData(dataValidate);
#endregion
//高版本excel【XSSFWorkbook】 设置下拉框
XSSFSheet sheetDDL = (XSSFSheet)workbook.CreateSheet(sheetname);
workbook.SetSheetHidden(sheetIndex, 1); //隐藏下拉框数据sheet
String[] datas = ddlList.ToArray(); //下拉框数据源
XSSFDataValidationHelper dvHelper = new XSSFDataValidationHelper(sheetDDL);
XSSFDataValidationConstraint dvConstraint = (XSSFDataValidationConstraint)dvHelper.CreateExplicitListConstraint(datas);
CellRangeAddressList addressList = new CellRangeAddressList(1, 65535, firstcol, lastcol); //下拉设置列
XSSFDataValidation validation = (XSSFDataValidation)dvHelper.CreateValidation(dvConstraint, addressList);
validation.SuppressDropDownArrow = true;
validation.ShowErrorBox = true;
validation.ShowPromptBox = true;
sheet.AddValidationData(validation);
}
}
}

View File

@@ -0,0 +1,28 @@
using NPOI.SS.UserModel;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace AntSK.Domain
{
public class ExeclPropertyAttribute : Attribute
{
public ExeclPropertyAttribute()
{
}
public ExeclPropertyAttribute(string displayName, int order, CellType cellType = CellType.String)
{
DisplayName = displayName;
Order = order;
CellType = cellType;
}
public string DisplayName { get; set; }
public int Order { get; set; }
public CellType CellType { get; set; }
}
}

View File

@@ -1,4 +1,6 @@
using System;
using Amazon.Runtime.Internal.Util;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
@@ -7,7 +9,7 @@ using System.Threading.Tasks;
namespace AntSK.Domain.Common.LLamaFactory
{
public class ProcessWrapper
public class ProcessWrapper(ILogger<ProcessWrapper> _logger)
{
private Process process;
@@ -41,7 +43,7 @@ namespace AntSK.Domain.Common.LLamaFactory
isProcessComplete = true;
}
}
Console.WriteLine(result);
_logger.LogInformation(result);
}
start.WaitForExit();
}

View File

@@ -5,6 +5,7 @@ using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@@ -13,10 +14,10 @@ namespace AntSK.Domain.Domain.Interface
{
public interface IChatService
{
IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, ChatHistory history);
IAsyncEnumerable<string> SendChatByAppAsync(Apps app, ChatHistory history);
IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null);
Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList);
Task<string> SendImgByAppAsync(Apps app, string questions);
Task<ChatHistory> GetChatHistory(List<Chats> MessageList, ChatHistory history);
}
}

View File

@@ -7,13 +7,13 @@ namespace AntSK.Domain.Domain.Interface
{
public interface IKMService
{
MemoryServerless GetMemory(Apps app);
MemoryServerless GetMemoryByApp(Apps app);
MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null);
MemoryServerless GetMemoryByKMS(string kmsID);
Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId);
Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg);
Task<List<RelevantSource>> GetRelevantSourceList(Apps app, string msg);
List<UploadFileItem> FileList { get; }

View File

@@ -6,6 +6,8 @@ namespace AntSK.Domain.Domain.Interface
public interface IKernelService
{
Kernel GetKernelByApp(Apps app);
Kernel GetKernelByAIModelID(string modelid);
void ImportFunctionsByApp(Apps app, Kernel _kernel);
Task<string> HistorySummarize(Kernel _kernel, string questions, string history);
}

View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static AntSK.Domain.Domain.Service.OllamaService;
namespace AntSK.Domain.Domain.Interface
{
public interface IOllamaService
{
public event LogMessageHandler LogMessageReceived;
Task StartOllama(string modelName);
}
}

View File

@@ -9,7 +9,29 @@ namespace AntSK.Domain.Domain.Model.Constant
public class KmsConstantcs
{
public const string KmsIdTag = "kmsid";
public const string FileIdTag = "fileid";
public const string AppIdTag = "appid";
public const string KmsIndex = "kms";
public const string FileIndex = "kms";
public const string KmsSearchNull="知识库未搜索到相关内容";
public const string KmsPrompt = @"使用<data></data>标记的内容作为你的知识:
<data>
{{$doc}}
</data>
--------------------------
回答要求:
- 如果你不清楚答案,你需要澄清
- 避免提及你是从<data></data>获取的知识
- 保持答案与<data></data>众描述一致
- 使用Markdown语法优化回答格式。
- 如果Markdown有图片则正常显示
--------------------------
历史聊天记录:{{ConversationSummaryPlugin.SummarizeConversation $history}}
--------------------------
用户问题: {{$input}}";
public const string KMExcelSplit = "*&antsk_excel&*";
}
}

View File

@@ -8,6 +8,8 @@ namespace AntSK.Domain.Domain.Model.Dto
public string Text { get; set; }
public float Relevance { get; set; }
public double RerankScore { get; set; }
public override string ToString()
{
return $"[file:{SourceName};Relevance:{(Relevance * 100):F2}%]:{Text}";

View File

@@ -21,12 +21,21 @@ namespace AntSK.Domain.Domain.Model.Enum
[Display(Name = "灵积大模型")]
DashScope = 5,
[Display(Name = "LLamaFactory")]
LLamaFactory = 6,
[Display(Name = "Bge Embedding")]
BgeEmbedding = 7,
[Display(Name = "Bge Rerank")]
BgeRerank = 8,
[Display(Name = "StableDiffusion")]
StableDiffusion = 9,
[Display(Name = "Ollama")]
Ollama = 10,
[Display(Name = "模拟输出")]
Mock = 100,
}
/// <summary>
@@ -36,5 +45,7 @@ namespace AntSK.Domain.Domain.Model.Enum
{
Chat = 1,
Embedding = 2,
Image=3,
Rerank=4
}
}

View File

@@ -9,6 +9,7 @@ namespace AntSK.Domain.Domain.Model.Enum
public enum AppType
{
chat = 1,
kms = 2
kms = 2,
img=3
}
}

View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.Excel
{
public class KMSExcelModel
{
[ExeclProperty("问题",0)]
public string Question { get; set; }
[ExeclProperty("答案", 1)]
public string Answer { get; set; }
}
}

View File

@@ -17,11 +17,14 @@ namespace AntSK.Domain.Domain.Model
public string FilePath { get; set; } = "";
public string FileName { get; set; } = "";
public bool IsQA { get; set; } = false;
}
public class ImportKMSTaskReq : ImportKMSTaskDTO
{
public bool IsQA { get; set; }=false;
public KmsDetails KmsDetail { get; set; } = new KmsDetails();
}
@@ -29,6 +32,13 @@ namespace AntSK.Domain.Domain.Model
{
File = 1,
Url = 2,
Text = 3
Text = 3,
Excel=4
}
public class QAModel
{
public string ChatModelId { get; set; }
public string Context { get; set; }
}
}

View File

@@ -1,20 +0,0 @@
namespace AntSK.Domain.Domain.Model
{
public class MessageInfo
{
public string ID { get; set; } = "";
public string Context { get; set; } = "";
public string HtmlAnswers { get; set; } = "";
/// <summary>
/// 发送是true 接收是false
/// </summary>
public bool IsSend { get; set; } = false;
public DateTime CreateTime { get; set; }
public string? FilePath { get; set; }
public string? FileName { get; set; }
}
}

View File

@@ -1,27 +1,31 @@
using AntSK.BackgroundTask;
using Amazon.Runtime.Internal.Util;
using AntSK.BackgroundTask;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace AntSK.Domain.Domain.Other
{
public class BackGroundTaskHandler : IBackgroundTaskHandler<ImportKMSTaskReq>
{
private readonly IServiceScopeFactory _scopeFactory;
private readonly ILogger<BackGroundTaskHandler> _logger;
public BackGroundTaskHandler(IServiceScopeFactory scopeFactory)
public BackGroundTaskHandler(IServiceScopeFactory scopeFactory, ILogger<BackGroundTaskHandler> logger)
{
_scopeFactory = scopeFactory;
_logger = logger;
}
public async Task ExecuteAsync(ImportKMSTaskReq item)
{
using (var scope = _scopeFactory.CreateScope())
{
Console.WriteLine("ExecuteAsync.开始执行后台任务");
_logger.LogInformation("ExecuteAsync.开始执行后台任务");
var importKMSService = scope.ServiceProvider.GetRequiredService<IImportKMSService>();
//不能使用异步
importKMSService.ImportKMSTask(item);
Console.WriteLine("ExecuteAsync.后台任务执行完成");
_logger.LogInformation("ExecuteAsync.后台任务执行完成");
}
}

View File

@@ -0,0 +1,82 @@
using Newtonsoft.Json;
using Python.Runtime;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Python.Runtime.Py;
namespace AntSK.Domain.Domain.Other.Bge
{
public static class BegRerankConfig
{
public static dynamic model { get; set; }
static object lockobj = new object();
/// <summary>
/// 模型写死
/// </summary>
public static dynamic LoadModel(string pythondllPath, string modelName)
{
lock (lockobj)
{
if (model == null)
{
if (string.IsNullOrEmpty(Runtime.PythonDLL))
{
Runtime.PythonDLL = pythondllPath;
}
PythonEngine.Initialize();
PythonEngine.BeginAllowThreads();
try
{
using (GIL())// 初始化Python环境的Global Interpreter Lock)
{
dynamic modelscope = Py.Import("modelscope");
dynamic flagEmbedding = Py.Import("FlagEmbedding");
dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master");
dynamic flagReranker = flagEmbedding.FlagReranker(model_dir, use_fp16: true);
model = flagReranker;
return model;
}
}
catch (Exception ex)
{
throw ex;
}
}
else
{
return model;
}
}
}
public static double Rerank(List<string> list)
{
using (GIL())
{
try
{
PyList pyList = new PyList();
foreach (string item in list)
{
pyList.Append(item.ToPython()); // 将C# string转换为Python对象并添加到PyList中
}
PyObject result = model.compute_score(pyList, normalize: true);
return result.As<double>();
}
catch (Exception ex)
{
throw ex;
}
}
}
}
}

View File

@@ -0,0 +1,99 @@
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.AI.OpenAI.GPT3;
using Python.Runtime;
using Serilog;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Python.Runtime.Py;
namespace AntSK.Domain.Domain.Other.Bge
{
public static class BgeEmbeddingConfig
{
public static dynamic model { get; set; }
static object lockobj = new object();
/// <summary>
/// 模型写死
/// </summary>
public static dynamic LoadModel(string pythondllPath, string modelName)
{
lock (lockobj)
{
if (model == null)
{
//Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
if (string.IsNullOrEmpty(Runtime.PythonDLL))
{
Runtime.PythonDLL = pythondllPath;
}
PythonEngine.Initialize();
PythonEngine.BeginAllowThreads();
try
{
using (GIL())// 初始化Python环境的Global Interpreter Lock)
{
dynamic modelscope = Import("modelscope");
//dynamic model_dir = modelscope.snapshot_download("AI-ModelScope/bge-large-zh-v1.5", revision: "master");
dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master");
dynamic HuggingFaceBgeEmbeddingstemp = Import("langchain_community.embeddings.huggingface");
dynamic HuggingFaceBgeEmbeddings = HuggingFaceBgeEmbeddingstemp.HuggingFaceBgeEmbeddings;
string model_name = model_dir;
dynamic model_kwargs = new PyDict();
model_kwargs["device"] = new PyString("cpu");
dynamic hugginmodel = HuggingFaceBgeEmbeddings(
model_name: model_dir,
model_kwargs: model_kwargs
);
model = hugginmodel;
return hugginmodel;
}
}
catch (Exception ex)
{
throw ex;
}
}
else
return model;
}
}
public static Task<float[]> GetEmbedding(string queryStr)
{
using (GIL())
{
PyObject queryResult = model.embed_query(queryStr);
var floatList = queryResult.As<float[]>();
return Task.FromResult(floatList); ;
}
}
public static int TokenCount(string queryStr)
{
//using (Py.GIL())
//{
// PyObject queryResult = model.client.tokenize(queryStr);
// // 使用Python的内置len()函数获取长度
// PyObject lenFunc = Py.Import("builtins").GetAttr("len");
// PyObject length = lenFunc.Invoke(queryResult["input_ids"]);
// int len = length.As<int>(); // 将PyObject转换为C#中的整数
// return len;
//}
var tokenCount1 = DefaultGPTTokenizer.StaticCountTokens(queryStr);
return tokenCount1;
}
public static void Dispose()
{
Log.Information("python dispose");
}
}
}

View File

@@ -0,0 +1,157 @@
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.DataFormats.Text;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Extensions;
using Microsoft.KernelMemory.Pipeline;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Other
{
public class KMExcelHandler: IPipelineStepHandler
{
private readonly TextPartitioningOptions _options;
private readonly IPipelineOrchestrator _orchestrator;
private readonly ILogger<KMExcelHandler> _log;
private readonly TextChunker.TokenCounter _tokenCounter;
public KMExcelHandler(
string stepName,
IPipelineOrchestrator orchestrator,
TextPartitioningOptions? options = null,
ILogger<KMExcelHandler>? log = null)
{
this.StepName = stepName;
this._orchestrator = orchestrator;
this._options = options ?? new TextPartitioningOptions();
this._options.Validate();
this._log = log ?? DefaultLogger<KMExcelHandler>.Instance;
this._tokenCounter = DefaultGPTTokenizer.StaticCountTokens;
}
/// <inheritdoc />
public string StepName { get; }
/// <inheritdoc />
public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(
DataPipeline pipeline, CancellationToken cancellationToken = default)
{
this._log.LogDebug("Partitioning text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);
if (pipeline.Files.Count == 0)
{
this._log.LogWarning("Pipeline '{0}/{1}': there are no files to process, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId);
return (true, pipeline);
}
foreach (DataPipeline.FileDetails uploadedFile in pipeline.Files)
{
// Track new files being generated (cannot edit originalFile.GeneratedFiles while looping it)
Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();
foreach (KeyValuePair<string, DataPipeline.GeneratedFileDetails> generatedFile in uploadedFile.GeneratedFiles)
{
var file = generatedFile.Value;
if (file.AlreadyProcessedBy(this))
{
this._log.LogTrace("File {0} already processed by this handler", file.Name);
continue;
}
// Partition only the original text
if (file.ArtifactType != DataPipeline.ArtifactTypes.ExtractedText)
{
this._log.LogTrace("Skipping file {0} (not original text)", file.Name);
continue;
}
// Use a different partitioning strategy depending on the file type
List<string> partitions;
List<string> sentences;
BinaryData partitionContent = await this._orchestrator.ReadFileAsync(pipeline, file.Name, cancellationToken).ConfigureAwait(false);
// Skip empty partitions. Also: partitionContent.ToString() throws an exception if there are no bytes.
if (partitionContent.ToArray().Length == 0) { continue; }
switch (file.MimeType)
{
case MimeTypes.PlainText:
{
this._log.LogDebug("Partitioning text file {0}", file.Name);
string content = partitionContent.ToString();
var excelList = content.Split(KmsConstantcs.KMExcelSplit, StringSplitOptions.RemoveEmptyEntries).ToList();
sentences = excelList;
partitions = excelList;
break;
}
case MimeTypes.MarkDown:
{
this._log.LogDebug("Partitioning text file {0}", file.Name);
string content = partitionContent.ToString();
var excelList = content.Split(KmsConstantcs.KMExcelSplit, StringSplitOptions.RemoveEmptyEntries).ToList();
sentences = excelList;
partitions = excelList;
break;
}
default:
this._log.LogWarning("File {0} cannot be partitioned, type '{1}' not supported", file.Name, file.MimeType);
// Don't partition other files
continue;
}
if (partitions.Count == 0) { continue; }
this._log.LogDebug("Saving {0} file partitions", partitions.Count);
for (int partitionNumber = 0; partitionNumber < partitions.Count; partitionNumber++)
{
// TODO: turn partitions in objects with more details, e.g. page number
string text = partitions[partitionNumber];
int sectionNumber = 0; // TODO: use this to store the page number (if any)
BinaryData textData = new(text);
int tokenCount = this._tokenCounter(text);
this._log.LogDebug("Partition size: {0} tokens", tokenCount);
var destFile = uploadedFile.GetPartitionFileName(partitionNumber);
await this._orchestrator.WriteFileAsync(pipeline, destFile, textData, cancellationToken).ConfigureAwait(false);
var destFileDetails = new DataPipeline.GeneratedFileDetails
{
Id = Guid.NewGuid().ToString("N"),
ParentId = uploadedFile.Id,
Name = destFile,
Size = text.Length,
MimeType = MimeTypes.PlainText,
ArtifactType = DataPipeline.ArtifactTypes.TextPartition,
PartitionNumber = partitionNumber,
SectionNumber = sectionNumber,
Tags = pipeline.Tags,
ContentSHA256 = textData.AntSKCalculateSHA256(),
};
newFiles.Add(destFile, destFileDetails);
destFileDetails.MarkProcessedBy(this);
}
file.MarkProcessedBy(this);
}
// Add new files to pipeline status
foreach (var file in newFiles)
{
uploadedFile.GeneratedFiles.Add(file.Key, file.Value);
}
}
return (true, pipeline);
}
}
}

View File

@@ -1,4 +1,5 @@
using LLama;
using AntSK.Domain.Options;
using LLama;
using LLama.Common;
using LLamaSharp.KernelMemory;
@@ -29,10 +30,10 @@ namespace AntSK.Domain.Domain.Other
}
var parameters = new ModelParams(lsConfig.ModelPath)
{
ContextSize = lsConfig?.ContextSize ?? 2048,
ContextSize = LLamaSharpOption.ContextSize ?? 2048,
Seed = lsConfig?.Seed ?? 0,
GpuLayerCount = lsConfig?.GpuLayerCount ?? 20,
EmbeddingMode = true
GpuLayerCount = LLamaSharpOption.GpuLayerCount ?? 20,
Embeddings = true
};
var weights = LLamaWeights.LoadFromFile(parameters);
dicLLamaWeights.Add(modelPath, (weights, parameters));

View File

@@ -0,0 +1,173 @@
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Utils;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.DataFormats.Text;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Extensions;
using Microsoft.KernelMemory.Pipeline;
using Microsoft.SemanticKernel;
using Newtonsoft.Json;
using RestSharp;
using System.Security.Policy;
using System.Text;
using System.Text.RegularExpressions;
namespace AntSK.Domain.Domain.Other
{
public class QAHandler : IPipelineStepHandler
{
private readonly TextPartitioningOptions _options;
private readonly IPipelineOrchestrator _orchestrator;
private readonly ILogger<QAHandler> _log;
private readonly TextChunker.TokenCounter _tokenCounter;
private readonly IKernelService _kernelService;
public QAHandler(
string stepName,
IPipelineOrchestrator orchestrator,
IKernelService kernelService,
TextPartitioningOptions? options = null,
ILogger<QAHandler>? log = null
)
{
this.StepName = stepName;
this._orchestrator = orchestrator;
this._options = options ?? new TextPartitioningOptions();
this._options.Validate();
this._log = log ?? DefaultLogger<QAHandler>.Instance;
this._tokenCounter = DefaultGPTTokenizer.StaticCountTokens;
this._kernelService = kernelService;
}
/// <inheritdoc />
public string StepName { get; }
/// <inheritdoc />
public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(
DataPipeline pipeline, CancellationToken cancellationToken = default)
{
this._log.LogDebug("Partitioning text, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);
if (pipeline.Files.Count == 0)
{
this._log.LogWarning("Pipeline '{0}/{1}': there are no files to process, moving to next pipeline step.", pipeline.Index, pipeline.DocumentId);
return (true, pipeline);
}
foreach (DataPipeline.FileDetails uploadedFile in pipeline.Files)
{
// Track new files being generated (cannot edit originalFile.GeneratedFiles while looping it)
Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();
foreach (KeyValuePair<string, DataPipeline.GeneratedFileDetails> generatedFile in uploadedFile.GeneratedFiles)
{
var file = generatedFile.Value;
if (file.AlreadyProcessedBy(this))
{
this._log.LogTrace("File {0} already processed by this handler", file.Name);
continue;
}
// Partition only the original text
if (file.ArtifactType != DataPipeline.ArtifactTypes.ExtractedText)
{
this._log.LogTrace("Skipping file {0} (not original text)", file.Name);
continue;
}
// Use a different partitioning strategy depending on the file type
List<string> partitions;
List<string> sentences;
BinaryData partitionContent = await this._orchestrator.ReadFileAsync(pipeline, file.Name, cancellationToken).ConfigureAwait(false);
// Skip empty partitions. Also: partitionContent.ToString() throws an exception if there are no bytes.
if (partitionContent.ToArray().Length == 0) { continue; }
switch (file.MimeType)
{
case MimeTypes.PlainText:
case MimeTypes.MarkDown:
{
this._log.LogDebug("Partitioning text file {0}", file.Name);
string content = partitionContent.ToString();
var kernel = _kernelService.GetKernelByAIModelID(StepName);
var lines = TextChunker.SplitPlainTextLines(content, 299);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, 3000);
KernelFunction jsonFun = kernel.Plugins.GetFunction("KMSPlugin", "QA");
List<string> qaList = new List<string>();
foreach (var para in paragraphs)
{
var qaresult = await kernel.InvokeAsync(function: jsonFun, new KernelArguments() { ["input"] = para });
var qaListStr = qaresult.GetValue<string>().ConvertToString();
string pattern = @"Q\d+:.*?A\d+:.*?(?=(Q\d+:|$))";
RegexOptions options = RegexOptions.Singleline;
foreach (Match match in Regex.Matches(qaListStr, pattern, options))
{
qaList.Add(match.Value.Trim()); // Trim用于删除可能的首尾空格
}
}
sentences = qaList;
partitions = qaList;
break;
}
default:
this._log.LogWarning("File {0} cannot be partitioned, type '{1}' not supported", file.Name, file.MimeType);
// Don't partition other files
continue;
}
if (partitions.Count == 0) { continue; }
this._log.LogDebug("Saving {0} file partitions", partitions.Count);
for (int partitionNumber = 0; partitionNumber < partitions.Count; partitionNumber++)
{
// TODO: turn partitions in objects with more details, e.g. page number
string text = partitions[partitionNumber];
int sectionNumber = 0; // TODO: use this to store the page number (if any)
BinaryData textData = new(text);
int tokenCount = this._tokenCounter(text);
this._log.LogDebug("Partition size: {0} tokens", tokenCount);
var destFile = uploadedFile.GetPartitionFileName(partitionNumber);
await this._orchestrator.WriteFileAsync(pipeline, destFile, textData, cancellationToken).ConfigureAwait(false);
var destFileDetails = new DataPipeline.GeneratedFileDetails
{
Id = Guid.NewGuid().ToString("N"),
ParentId = uploadedFile.Id,
Name = destFile,
Size = text.Length,
MimeType = MimeTypes.PlainText,
ArtifactType = DataPipeline.ArtifactTypes.TextPartition,
PartitionNumber = partitionNumber,
SectionNumber = sectionNumber,
Tags = pipeline.Tags,
ContentSHA256 = textData.AntSKCalculateSHA256(),
};
newFiles.Add(destFile, destFileDetails);
destFileDetails.MarkProcessedBy(this);
}
file.MarkProcessedBy(this);
}
// Add new files to pipeline status
foreach (var file in newFiles)
{
uploadedFile.GeneratedFiles.Add(file.Key, file.Value);
}
}
return (true, pipeline);
}
}
}

View File

@@ -1,21 +1,23 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Repositories;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel;
using System.Text;
using AntSK.Domain.Utils;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Model.Constant;
using DocumentFormat.OpenXml.Drawing;
using System.Reflection.Metadata;
using Microsoft.KernelMemory;
using System.Collections.Generic;
using Markdig;
using ChatHistory = Microsoft.SemanticKernel.ChatCompletion.ChatHistory;
using Microsoft.SemanticKernel.Plugins.Core;
using Azure.Core;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Other.Bge;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using AntSK.LLM.StableDiffusion;
using Markdig;
using Microsoft.KernelMemory;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using System.Diagnostics;
using System.Drawing;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.RegularExpressions;
using ChatHistory = Microsoft.SemanticKernel.ChatCompletion.ChatHistory;
namespace AntSK.Domain.Domain.Service
{
@@ -23,7 +25,8 @@ namespace AntSK.Domain.Domain.Service
public class ChatService(
IKernelService _kernelService,
IKMService _kMService,
IKmsDetails_Repositories _kmsDetails_Repositories
IKmsDetails_Repositories _kmsDetails_Repositories,
IAIModels_Repositories _aIModels_Repositories
) : IChatService
{
/// <summary>
@@ -33,87 +36,183 @@ namespace AntSK.Domain.Domain.Service
/// <param name="questions"></param>
/// <param name="history"></param>
/// <returns></returns>
public async IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, ChatHistory history)
public async IAsyncEnumerable<string> SendChatByAppAsync(Apps app, ChatHistory history)
{
if (string.IsNullOrEmpty(app.Prompt) || !app.Prompt.Contains("{{$input}}"))
{
//如果模板为空,给默认提示词
app.Prompt = app.Prompt.ConvertToString() + "{{$input}}";
}
KernelArguments args =new KernelArguments();
if (history.Count > 10)
{
app.Prompt = @"${{ConversationSummaryPlugin.SummarizeConversation $history}}" + app.Prompt;
args = new() {
{ "history", string.Join("\n", history.Select(x => x.Role + ": " + x.Content)) },
{ "input", questions }
};
}
else
{
args=new()
{
{ "input", $"{string.Join("\n", history.Select(x => x.Role + ": " + x.Content))}{Environment.NewLine} user:{questions}" }
};
}
var _kernel = _kernelService.GetKernelByApp(app);
var chat = _kernel.GetRequiredService<IChatCompletionService>();
var temperature = app.Temperature / 100;//存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
List<string> completionList = new List<string>();
if (!string.IsNullOrEmpty(app.ApiFunctionList) || !string.IsNullOrEmpty(app.NativeFunctionList))//这里还需要加上本地插件的
{
_kernelService.ImportFunctionsByApp(app, _kernel);
settings.ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions;
settings.ToolCallBehavior = ToolCallBehavior.EnableKernelFunctions;
while (true)
{
ChatMessageContent result = await chat.GetChatMessageContentAsync(history, settings, _kernel);
if (result.Content is not null)
{
string chunkCompletion = result.Content.ConvertToString();
completionList.Add(chunkCompletion);
foreach (var content in completionList)
{
yield return content.ConvertToString();
}
break;
}
history.Add(result);
IEnumerable<FunctionCallContent> functionCalls = FunctionCallContent.GetFunctionCalls(result);
if (!functionCalls.Any())
{
break;
}
foreach (var functionCall in functionCalls)
{
FunctionResultContent resultContent = await functionCall.InvokeAsync(_kernel);
history.Add(resultContent.ToChatMessage());
}
}
}
var func = _kernel.CreateFunctionFromPrompt(app.Prompt, settings);
var chatResult = _kernel.InvokeStreamingAsync(function: func,
arguments: args);
await foreach (var content in chatResult)
else
{
yield return content;
var chatResult = chat.GetStreamingChatMessageContentsAsync(history, settings, _kernel);
await foreach (var content in chatResult)
{
yield return content.ConvertToString();
}
}
}
public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null)
{
var relevantSourceList = await _kMService.GetRelevantSourceList(app.KmsIdList, questions);
relevantSources?.Clear();
var relevantSourceList = await _kMService.GetRelevantSourceList(app, questions);
var _kernel = _kernelService.GetKernelByApp(app);
if (!string.IsNullOrWhiteSpace(filePath))
{
var memory = _kMService.GetMemory(app);
var fileId = Guid.NewGuid().ToString();
var result = await memory.ImportDocumentAsync(new Microsoft.KernelMemory.Document(fileId).AddFile(filePath)
.AddTag(KmsConstantcs.KmsIdTag, app.Id)
, index: KmsConstantcs.KmsIndex);
var memory = _kMService.GetMemoryByApp(app);
var filters = new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, app.Id);
// 匹配GUID的正则表达式
string pattern = @"\b[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}\b";
var searchResult = await memory.SearchAsync(questions, index: KmsConstantcs.KmsIndex, filters: [filters]);
relevantSourceList.AddRange(searchResult.Results.SelectMany(item => item.Partitions.Select(part => new RelevantSource()
// 使用正则表达式找到匹配
Match match = Regex.Match(filePath, pattern);
if (match.Success)
{
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Relevance = part.Relevance
})));
var fileId = match.Value;
var status=await memory.IsDocumentReadyAsync(fileId, index: KmsConstantcs.KmsIndex);
if (!status)
{
var result = await memory.ImportDocumentAsync(new Document(fileId).AddFile(filePath)
.AddTag(KmsConstantcs.AppIdTag, app.Id)
.AddTag(KmsConstantcs.FileIdTag, fileId)
, index: KmsConstantcs.FileIndex);
}
var filters = new List<MemoryFilter>() {
new MemoryFilter().ByTag(KmsConstantcs.AppIdTag, app.Id),
new MemoryFilter().ByTag(KmsConstantcs.FileIdTag, fileId)
};
var searchResult = await memory.SearchAsync(questions, index: KmsConstantcs.FileIndex, filters: filters);
relevantSourceList.AddRange(searchResult.Results.SelectMany(item => item.Partitions.Select(part => new RelevantSource()
{
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Relevance = part.Relevance
})));
app.Prompt = KmsConstantcs.KmsPrompt;
}
}
var dataMsg = new StringBuilder();
if (relevantSourceList.Any())
{
relevantSources?.AddRange(relevantSourceList);
if (!string.IsNullOrEmpty(app.RerankModelID))
{
var rerankModel=_aIModels_Repositories.GetById(app.RerankModelID);
BegRerankConfig.LoadModel(rerankModel.EndPoint, rerankModel.ModelName);
//进行rerank
foreach (var item in relevantSourceList)
{
List<string> rerank = new List<string>();
rerank.Add(questions);
rerank.Add(item.Text);
item.RerankScore = BegRerankConfig.Rerank(rerank);
}
relevantSourceList = relevantSourceList.OrderByDescending(p => p.RerankScore).Take(app.MaxMatchesCount).ToList();
}
bool isSearch = false;
foreach (var item in relevantSourceList)
{
dataMsg.AppendLine(item.ToString());
if (!string.IsNullOrEmpty(app.RerankModelID))
{
//匹配重排后相似度
if (item.RerankScore >= app.Relevance / 100)
{
dataMsg.AppendLine(item.ToString());
isSearch = true;
}
}
else
{
//匹配相似度
if (item.Relevance >= app.Relevance / 100)
{
dataMsg.AppendLine(item.ToString());
isSearch = true;
}
}
}
KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var chatResult = _kernel.InvokeStreamingAsync(function: jsonFun,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["questions"] = questions });
await foreach (var content in chatResult)
//处理markdown显示
relevantSources?.AddRange(relevantSourceList);
Dictionary<string, string> fileDic = new Dictionary<string, string>();
foreach (var item in relevantSourceList)
{
yield return content;
if (fileDic.ContainsKey(item.SourceName))
{
item.SourceName = fileDic[item.SourceName];
}
else
{
var fileDetail = _kmsDetails_Repositories.GetFirst(p => p.FileGuidName == item.SourceName);
if (fileDetail.IsNotNull())
{
string fileName = fileDetail.FileName;
fileDic.Add(item.SourceName, fileName);
item.SourceName = fileName;
}
}
item.Text = Markdown.ToHtml(item.Text);
}
if (isSearch)
{
//KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var temperature = app.Temperature / 100;//存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt(app.Prompt , settings);
var chatResult = _kernel.InvokeStreamingAsync(function: func,
arguments: new KernelArguments() { ["doc"] = dataMsg.ToString(), ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["input"] = questions });
await foreach (var content in chatResult)
{
yield return content;
}
}
else
{
yield return new StreamingTextContent(KmsConstantcs.KmsSearchNull);
}
}
else
@@ -122,25 +221,127 @@ namespace AntSK.Domain.Domain.Service
}
}
public async Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList)
{
ChatHistory history = new ChatHistory();
if (MessageList.Count > 1)
{
foreach (var item in MessageList)
public async Task<string> SendImgByAppAsync(Apps app, string questions)
{
var imageModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ImageModelID);
KernelArguments args = new() {
{ "input", questions }
};
var _kernel = _kernelService.GetKernelByApp(app);
var temperature = app.Temperature / 100; //存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt("Translate this into English:{{$input}}", settings);
var chatResult = await _kernel.InvokeAsync(function: func, arguments: args);
if (chatResult.IsNotNull())
{
//Can Load stable-diffusion library in diffenert environment
//SDHelper.LoadLibrary()
string versionString = string.Empty;
string extensionString = string.Empty;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
if (item.IsSend)
extensionString = ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
extensionString = ".so";
}
else
{
throw new InvalidOperationException("OS Platform no support");
}
ProcessStartInfo startInfo = new ProcessStartInfo("nvcc", "--version");
startInfo.RedirectStandardOutput = true;
startInfo.UseShellExecute = false;
startInfo.CreateNoWindow = true;
using (Process process = Process.Start(startInfo))
{
if (process != null)
{
history.AddUserMessage(item.Context);
string result = process.StandardOutput.ReadToEnd();
Regex regex = new Regex(@"release (\d+).[\d]");
Match match = regex.Match(result);
if (match.Success)
{
switch (match.Groups[1].Value.ToString())
{
case "11":
versionString = "Cuda11";
break;
case "12":
versionString = "Cuda12";
break;
default:
versionString = "CPU";
break;
}
}
}
else
{
history.AddAssistantMessage(item.Context);
throw new Exception("nvcc get an error");
}
}
string libraryPath = System.IO.Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "StableDiffusion", "Backend", versionString, "stable-diffusion" + extensionString);
NativeLibrary.TryLoad(libraryPath, out _);
string prompt = chatResult.GetValue<string>();
if (!SDHelper.IsInitialized)
{
Structs.ModelParams modelParams = new Structs.ModelParams
{
ModelPath = imageModel.ModelName,
RngType = Structs.RngType.CUDA_RNG,
//VaePath = vaePath,
//KeepVaeOnCpu = keepVaeOnCpu,
//set false can get a better image, otherwise can use lower vram
VaeTiling = false,
//LoraModelDir = loraModelDir,
};
bool result = SDHelper.Initialize(modelParams);
}
Structs.TextToImageParams textToImageParams = new Structs.TextToImageParams
{
Prompt = prompt,
NegativePrompt = "bad quality, wrong image, worst quality",
SampleMethod = (Structs.SampleMethod)Enum.Parse(typeof(Structs.SampleMethod), "EULER_A"),
//the base image size in SD1.5 is 512x512
Width = 512,
Height = 512,
NormalizeInput = true,
ClipSkip = -1,
CfgScale = 7,
SampleSteps = 20,
Seed = -1,
};
Bitmap[] outputImages = SDHelper.TextToImage(textToImageParams);
var base64 = ImageUtils.BitmapToBase64(outputImages[0]);
return base64;
}
else
{
return "";
}
}
public async Task<ChatHistory> GetChatHistory(List<Chats> MessageList, ChatHistory history)
{
foreach (var item in MessageList)
{
if (item.IsSend)
{
history.AddUserMessage(item.Context);
}
else
{
history.AddAssistantMessage(item.Context);
}
}
return history;
}
}
}
}

View File

@@ -8,6 +8,7 @@ using System.Text.RegularExpressions;
using Microsoft.SemanticKernel;
using HtmlAgilityPack;
using System.Collections.Generic;
using Serilog;
namespace AntSK.Domain.Domain.Service
{
@@ -115,7 +116,7 @@ namespace AntSK.Domain.Domain.Service
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
Log.Error(ex.Message + " ---- " + ex.StackTrace);
}
}
}

View File

@@ -2,8 +2,13 @@
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Excel;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Repositories;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Handlers;
using System.Text;
namespace AntSK.Domain.Domain.Service
{
@@ -11,7 +16,8 @@ namespace AntSK.Domain.Domain.Service
public class ImportKMSService(
IKMService _kMService,
IKmsDetails_Repositories _kmsDetails_Repositories,
IKmss_Repositories _kmss_Repositories
IKmss_Repositories _kmss_Repositories,
ILogger<ImportKMSService> _logger
) : IImportKMSService
{
@@ -20,18 +26,40 @@ namespace AntSK.Domain.Domain.Service
try
{
var km = _kmss_Repositories.GetFirst(p => p.Id == req.KmsId);
var _memory = _kMService.GetMemoryByKMS(km.Id);
string fileid = req.KmsDetail.Id;
List<string> step = new List<string>();
if (req.IsQA)
{
_memory.Orchestrator.AddHandler<TextExtractionHandler>("extract_text");
_memory.Orchestrator.AddHandler<QAHandler>(km.ChatModelID);
_memory.Orchestrator.AddHandler<GenerateEmbeddingsHandler>("generate_embeddings");
_memory.Orchestrator.AddHandler<SaveRecordsHandler>("save_memory_records");
step.Add("extract_text");
step.Add(km.ChatModelID);
step.Add("generate_embeddings");
step.Add("save_memory_records");
}
switch (req.ImportType)
{
case ImportType.File:
//导入文件
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
, index: KmsConstantcs.KmsIndex).Result;
//导入文件
if (req.IsQA)
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
,index: KmsConstantcs.KmsIndex ,steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
, index: KmsConstantcs.KmsIndex).Result;
}
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
string fileGuidName = Path.GetFileName(req.FilePath);
@@ -44,8 +72,16 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Url:
{
//导入url
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
if (req.IsQA)
{
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex, steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
}
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;
@@ -55,8 +91,16 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Text:
//导入文本
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
if (req.IsQA)
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex, steps: step.ToArray()).Result;
}
else
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
}
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;
@@ -64,17 +108,47 @@ namespace AntSK.Domain.Domain.Service
}
break;
case ImportType.Excel:
using (var fs = File.OpenRead(req.FilePath))
{
var excelList= ExeclHelper.ExcelToList<KMSExcelModel>(fs);
_memory.Orchestrator.AddHandler<TextExtractionHandler>("extract_text");
_memory.Orchestrator.AddHandler<KMExcelHandler>("antsk_excel_split");
_memory.Orchestrator.AddHandler<GenerateEmbeddingsHandler>("generate_embeddings");
_memory.Orchestrator.AddHandler<SaveRecordsHandler>("save_memory_records");
StringBuilder text = new StringBuilder();
foreach (var item in excelList)
{
text.AppendLine(@$"Question:{item.Question}{Environment.NewLine}Answer:{item.Answer}{KmsConstantcs.KMExcelSplit}");
}
var importResult = _memory.ImportTextAsync(text.ToString(), fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex,
steps: new[]
{
"extract_text",
"antsk_excel_split",
"generate_embeddings",
"save_memory_records"
}
).Result;
req.KmsDetail.FileName = req.FileName;
string fileGuidName = Path.GetFileName(req.FilePath);
req.KmsDetail.FileGuidName = fileGuidName;
req.KmsDetail.DataCount = excelList.Count();
}
break;
}
req.KmsDetail.Status = Model.Enum.ImportKmsStatus.Success;
_kmsDetails_Repositories.Update(req.KmsDetail);
//_kmsDetails_Repositories.GetList(p => p.KmsId == req.KmsId);
Console.WriteLine("后台导入任务成功:" + req.KmsDetail.DataCount);
_logger.LogInformation("后台导入任务成功:" + req.KmsDetail.DataCount);
}
catch (Exception ex)
{
req.KmsDetail.Status = Model.Enum.ImportKmsStatus.Fail;
_kmsDetails_Repositories.Update(req.KmsDetail);
Console.WriteLine("后台导入任务异常:" + ex.Message);
_logger.LogError("后台导入任务异常:" + ex.Message);
}
}
}

View File

@@ -1,5 +1,6 @@
using AntDesign;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Common.Embedding;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
@@ -7,6 +8,7 @@ using AntSK.Domain.Domain.Other;
using AntSK.Domain.Options;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using AntSK.OCR;
using DocumentFormat.OpenXml.Drawing.Diagrams;
using LLama;
using LLamaSharp.KernelMemory;
@@ -15,6 +17,7 @@ using Microsoft.AspNetCore.Components;
using Microsoft.Extensions.Configuration;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.DataFormats;
using Microsoft.KernelMemory.FileSystem.DevTools;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.KernelMemory.MemoryStorage.DevTools;
@@ -26,7 +29,8 @@ namespace AntSK.Domain.Domain.Service
public class KMService(
IKmss_Repositories _kmss_Repositories,
IAIModels_Repositories _aIModels_Repositories,
IMessageService? _message
IMessageService? _message,
IKernelService _kernelService
) : IKMService
{
private MemoryServerless _memory;
@@ -35,20 +39,36 @@ namespace AntSK.Domain.Domain.Service
public List<UploadFileItem> FileList => _fileList;
public MemoryServerless GetMemory(Apps app)
public MemoryServerless GetMemoryByApp(Apps app)
{
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == app.EmbeddingModelID);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);
var searchClientConfig = new SearchClientConfig
SearchClientConfig searchClientConfig;
if (string.IsNullOrEmpty(app.RerankModelID))
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
//不重排直接取查询数
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = app.MaxAskPromptSize,
MaxMatchesCount = app.MaxMatchesCount,
AnswerTokens = app.AnswerTokens,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
}
else
{
//重排取rerank数
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = app.MaxAskPromptSize,
MaxMatchesCount = app.RerankCount,
AnswerTokens = app.AnswerTokens,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
}
var memoryBuild = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
@@ -70,7 +90,7 @@ namespace AntSK.Domain.Domain.Service
return _memory;
}
public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
public MemoryServerless GetMemoryByKMS(string kmsID)
{
//if (_memory.IsNull())
{
@@ -84,33 +104,35 @@ namespace AntSK.Domain.Domain.Service
var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);
//搜索配置
if (searchClientConfig.IsNull())
{
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
}
//if (searchClientConfig.IsNull())
//{
// searchClientConfig = new SearchClientConfig
// {
// MaxAskPromptSize = 2048,
// MaxMatchesCount = 3,
// AnswerTokens = 1000,
// EmptyAnswer = KmsConstantcs.KmsSearchNull
// };
//}
var memoryBuild = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
//.WithSearchClientConfig(searchClientConfig)
.WithCustomTextPartitioningOptions(new TextPartitioningOptions
{
MaxTokensPerLine = kms.MaxTokensPerLine,
MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
OverlappingTokens = kms.OverlappingTokens
});
//加载OCR
WithOcr(memoryBuild, kms);
//加载会话模型
WithTextGenerationByAIType(memoryBuild, chatModel, chatHttpClient);
//加载向量模型
WithTextEmbeddingGenerationByAIType(memoryBuild, embedModel, embeddingHttpClient);
//加载向量库
WithMemoryDbByVectorDB(memoryBuild);
_memory = memoryBuild.Build<MemoryServerless>();
_memory = memoryBuild.AddSingleton<IKernelService>(_kernelService).Build<MemoryServerless>();
return _memory;
}
//else {
@@ -118,6 +140,14 @@ namespace AntSK.Domain.Domain.Service
//}
}
private static void WithOcr(IKernelMemoryBuilder memoryBuild, Kmss kms)
{
if (kms.IsOCR == 1)
{
memoryBuild.WithCustomImageOcr(new AntSKOcrEngine());
}
}
private void WithTextEmbeddingGenerationByAIType(IKernelMemoryBuilder memory, AIModels embedModel,
HttpClient embeddingHttpClient)
{
@@ -147,6 +177,11 @@ namespace AntSK.Domain.Domain.Service
var embedder = new LLamaEmbedder(weights, parameters);
memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
break;
case Model.Enum.AIType.BgeEmbedding:
string pyDll = embedModel.EndPoint;
string bgeEmbeddingModelName = embedModel.ModelName;
memory.WithBgeTextEmbeddingGeneration(new HuggingfaceTextEmbeddingGenerator(pyDll,bgeEmbeddingModelName));
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeDefaults(embedModel.ModelKey);
break;
@@ -183,6 +218,21 @@ namespace AntSK.Domain.Domain.Service
var executor = new StatelessExecutor(weights, parameters);
memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
break;
case Model.Enum.AIType.LLamaFactory:
memory.WithOpenAITextGeneration(new OpenAIConfig()
{
APIKey = "NotNull",
TextModel = chatModel.ModelName
}, null, chatHttpClient);
break;
case Model.Enum.AIType.Ollama:
memory.WithOpenAITextGeneration(new OpenAIConfig()
{
APIKey = "NotNull",
TextModel = chatModel.ModelName
}, null, chatHttpClient);
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeTextGeneration(new Cnblogs.KernelMemory.AI.DashScope.DashScopeConfig
{
@@ -248,12 +298,12 @@ namespace AntSK.Domain.Domain.Service
{
foreach (var memoryDb in memoryDbs)
{
var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileId) }, 100, true).ToListAsync();
var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileId) }, 1000, true).ToListAsync();
docTextList.AddRange(items.Select(item => new KMFile()
{
DocumentId = item.GetDocumentId(),
Text = item.GetPartitionText(),
Url = item.GetWebPageUrl(),
Url = item.GetWebPageUrl(KmsConstantcs.KmsIndex),
LastUpdate = item.GetLastUpdate().LocalDateTime.ToString("yyyy-MM-dd HH:mm:ss"),
File = item.GetFileName()
}));
@@ -263,15 +313,15 @@ namespace AntSK.Domain.Domain.Service
return docTextList;
}
public async Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg)
public async Task<List<RelevantSource>> GetRelevantSourceList(Apps app ,string msg)
{
var result = new List<RelevantSource>();
if (string.IsNullOrWhiteSpace(kmsIdListStr))
if (string.IsNullOrWhiteSpace(app.KmsIdList))
return result;
var kmsIdList = kmsIdListStr.Split(",");
var kmsIdList = app.KmsIdList.Split(",");
if (!kmsIdList.Any()) return result;
var memory = GetMemoryByKMS(kmsIdList.FirstOrDefault()!);
var memory = GetMemoryByApp(app);
var filters = kmsIdList.Select(kmsId => new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, kmsId)).ToList();
@@ -283,7 +333,7 @@ namespace AntSK.Domain.Domain.Service
result.AddRange(item.Partitions.Select(part => new RelevantSource()
{
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Text = part.Text,
Relevance = part.Relevance
}));
}
@@ -305,7 +355,10 @@ namespace AntSK.Domain.Domain.Service
"application/pdf",
"application/json",
"text/x-markdown",
"text/markdown"
"text/markdown",
"image/jpeg",
"image/png",
"image/tiff"
};
string[] exceptExts = [".md", ".pdf"];

View File

@@ -18,6 +18,12 @@ using AntSK.Domain.Domain.Model.Enum;
using AntSK.LLM.LLamaFactory;
using System.Reflection;
using DocumentFormat.OpenXml.Drawing;
using Microsoft.KernelMemory;
using OpenCvSharp.ML;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
using Amazon.Runtime.Internal.Util;
using Microsoft.Extensions.Logging;
namespace AntSK.Domain.Domain.Service
{
@@ -29,17 +35,20 @@ namespace AntSK.Domain.Domain.Service
private readonly FunctionService _functionService;
private readonly IServiceProvider _serviceProvider;
private Kernel _kernel;
private readonly ILogger<KernelService> _logger;
public KernelService(
IApis_Repositories apis_Repositories,
IAIModels_Repositories aIModels_Repositories,
FunctionService functionService,
IServiceProvider serviceProvider)
IServiceProvider serviceProvider,
ILogger<KernelService> logger)
{
_apis_Repositories = apis_Repositories;
_aIModels_Repositories = aIModels_Repositories;
_functionService = functionService;
_serviceProvider = serviceProvider;
_logger = logger;
}
/// <summary>
@@ -57,7 +66,7 @@ namespace AntSK.Domain.Domain.Service
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, app, chatModel, chatHttpClient);
WithTextGenerationByAIType(builder, chatModel, chatHttpClient);
_kernel = builder.Build();
RegisterPluginsWithKernel(_kernel);
@@ -69,7 +78,18 @@ namespace AntSK.Domain.Domain.Service
//}
}
private void WithTextGenerationByAIType(IKernelBuilder builder, Apps app, AIModels chatModel, HttpClient chatHttpClient)
public Kernel GetKernelByAIModelID(string modelid)
{
var chatModel = _aIModels_Repositories.GetById(modelid);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, chatModel, chatHttpClient);
_kernel = builder.Build();
RegisterPluginsWithKernel(_kernel);
return _kernel;
}
private void WithTextGenerationByAIType(IKernelBuilder builder,AIModels chatModel, HttpClient chatHttpClient)
{
switch (chatModel.AIType)
{
@@ -92,11 +112,35 @@ namespace AntSK.Domain.Domain.Service
var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
var ex = new StatelessExecutor(weights, parameters);
builder.Services.AddKeyedSingleton<ITextGenerationService>("local-llama", new LLamaSharpTextCompletion(ex));
builder.Services.AddKeyedSingleton<IChatCompletionService>("local-llama-chat", new LLamaSharpChatCompletion(ex));
break;
case Model.Enum.AIType.SparkDesk:
var options = new SparkDeskOptions { AppId = chatModel.EndPoint, ApiSecret = chatModel.ModelKey, ApiKey = chatModel.ModelName, ModelVersion = Sdcb.SparkDesk.ModelVersion.V3_5 };
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, app.Id));
var settings = chatModel.ModelKey.Split("|");
Sdcb.SparkDesk.ModelVersion modelVersion = Sdcb.SparkDesk.ModelVersion.V3_5;
switch (chatModel.ModelName)
{
case "V3_5":
modelVersion = Sdcb.SparkDesk.ModelVersion.V3_5;
break;
case "V3":
modelVersion = Sdcb.SparkDesk.ModelVersion.V3;
break;
case "V2":
modelVersion = Sdcb.SparkDesk.ModelVersion.V2;
break;
case "V1_5":
modelVersion = Sdcb.SparkDesk.ModelVersion.V1_5;
break;
}
SparkDeskOptions options = new SparkDeskOptions { AppId = settings[0], ApiSecret = settings[1], ApiKey = settings[2], ModelVersion = modelVersion };
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, chatModel.Id));
builder.Services.AddKeyedSingleton<IChatCompletionService>("spark-desk-chat", new SparkDeskChatCompletion(options, chatModel.Id));
break;
case Model.Enum.AIType.DashScope:
@@ -105,11 +149,19 @@ namespace AntSK.Domain.Domain.Service
case Model.Enum.AIType.Mock:
builder.Services.AddKeyedSingleton<ITextGenerationService>("mock", new MockTextCompletion());
builder.Services.AddKeyedSingleton<IChatCompletionService>("mock-chat", new MockChatCompletion());
break;
case Model.Enum.AIType.LLamaFactory:
builder.AddOpenAIChatCompletion(
modelId: chatModel.ModelName,
apiKey: "123",
apiKey: "NotNull",
httpClient: chatHttpClient
);
break;
case AIType.Ollama:
builder.AddOpenAIChatCompletion(
modelId: chatModel.ModelName,
apiKey: "NotNull",
httpClient: chatHttpClient
);
break;
@@ -160,7 +212,6 @@ namespace AntSK.Domain.Domain.Service
var getParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.Query}"
}
@@ -199,7 +250,6 @@ namespace AntSK.Domain.Domain.Service
//处理json body
var postParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.JsonBody}"
}
@@ -208,7 +258,7 @@ namespace AntSK.Domain.Domain.Service
{
try
{
Console.WriteLine(jsonBody);
_logger.LogInformation(jsonBody);
RestClient client = new RestClient();
RestRequest request = new RestRequest(api.Url, Method.Post);
foreach (var header in api.Header.ConvertToString().Split("\n"))
@@ -287,8 +337,8 @@ namespace AntSK.Domain.Domain.Service
KernelFunction sunFun = _kernel.Plugins.GetFunction("ConversationSummaryPlugin", "SummarizeConversation");
var summary = await _kernel.InvokeAsync(sunFun, new() { ["input"] = $"内容是:{history.ToString()} {Environment.NewLine} 请注意用中文总结" });
string his = summary.GetValue<string>();
var msg = $"history{Environment.NewLine}{history.ToString()}{Environment.NewLine} user{questions}{Environment.NewLine}"; ;
var msg = $"history{Environment.NewLine}{his}{Environment.NewLine} user{questions}{Environment.NewLine}";
return msg;
}
}
}
}

View File

@@ -1,13 +1,16 @@
using AntSK.Domain.Common.DependencyInjection;
using Amazon.Runtime.Internal.Util;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Options;
using AntSK.LLamaFactory.Model;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.Linq;
using System.Text;
using System.Text.Json;
@@ -16,7 +19,7 @@ using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(ILLamaFactoryService), ServiceLifetime.Singleton)]
public class LLamaFactoryService : ILLamaFactoryService
public class LLamaFactoryService(ILogger<LLamaFactoryService> _logger) : ILLamaFactoryService
{
private Process process;
@@ -25,7 +28,7 @@ namespace AntSK.Domain.Domain.Service
private readonly object _syncLock = new object();
private List<LLamaModel> modelList = new List<LLamaModel>();
public LLamaFactoryService() { }
public delegate Task LogMessageHandler(string message);
public event LogMessageHandler LogMessageReceived;
protected virtual async Task OnLogMessageReceived(string message)
@@ -55,19 +58,21 @@ namespace AntSK.Domain.Domain.Service
};
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
OnLogMessageReceived("--------------------完成--------------------");
}, TaskCreationOptions.LongRunning);
await cmdTask;
}
public async Task StartLLamaFactory(string modelName, string templateName)
@@ -82,31 +87,34 @@ namespace AntSK.Domain.Domain.Service
StartInfo = new ProcessStartInfo
{
FileName = "python",
Arguments = "api_demo.py --model_name_or_path " + modelName + " --template " + templateName + " ",
Arguments = "api_antsk.py --model_name_or_path " + modelName + " --template " + templateName + " ",
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError=true,
WorkingDirectory = Path.Combine(Path.GetDirectoryName(System.Reflection.Assembly.GetEntryAssembly().Location), "llamafactory"),
}
};
process.StartInfo.Environment["CUDA_VISIBLE_DEVICES"] = "0";
process.StartInfo.Environment["CUDA_VISIBLE_DEVICES"] = Environment.GetEnvironmentVariable("CUDA_VISIBLE_DEVICES") ?? "0";
process.StartInfo.Environment["API_PORT"] = "8000";
process.StartInfo.EnvironmentVariables["USE_MODELSCOPE_HUB"] = "1";
process.StartInfo.EnvironmentVariables["USE_MODELSCOPE_HUB"] = Environment.GetEnvironmentVariable("USE_MODELSCOPE_HUB") ?? "1";
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
_logger.LogInformation($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
OnLogMessageReceived("--------------------完成--------------------");
}, TaskCreationOptions.LongRunning);
await cmdTask;
}
private void Process_OutputDataReceived(object sender, DataReceivedEventArgs e)
@@ -131,7 +139,7 @@ namespace AntSK.Domain.Domain.Service
if (process1.ProcessName.ToLower() == "python")
{
process1.Kill();
System.Console.WriteLine("kill python");
_logger.LogInformation("kill python");
}
}
}

View File

@@ -0,0 +1,74 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Serilog;
using AntSK.Domain.Utils;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(IOllamaService), ServiceLifetime.Singleton)]
public class OllamaService : IOllamaService
{
private Process process;
public delegate Task LogMessageHandler(string message);
public event LogMessageHandler LogMessageReceived;
protected virtual async Task OnLogMessageReceived(string message)
{
LogMessageReceived?.Invoke(message);
}
public async Task StartOllama(string modelName)
{
Console.OutputEncoding = Encoding.UTF8;
var cmdTask = Task.Factory.StartNew(() =>
{
var isProcessComplete = false;
process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "ollama",
Arguments = "run " + modelName,
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError = true,
}
};
process.OutputDataReceived += (sender, eventArgs) =>
{
Log.Information($"{eventArgs.Data.ConvertToString()}");
if (!eventArgs.Data.ConvertToString().Contains("The handle is invalid"))
{
OnLogMessageReceived(eventArgs.Data.ConvertToString());
}
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Log.Error($"{eventArgs.Data.ConvertToString()}");
if (!eventArgs.Data.ConvertToString().Contains("The handle is invalid"))
{
OnLogMessageReceived(eventArgs.Data.ConvertToString());
}
};
process.StartInfo.StandardOutputEncoding = Encoding.UTF8;
process.StartInfo.StandardErrorEncoding = Encoding.UTF8;
process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
OnLogMessageReceived("--------------------完成--------------------");
}, TaskCreationOptions.LongRunning);
await cmdTask;
}
}
}

View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Options
{
public class FileDirOption
{
public static string DirectoryPath { get; set; } = Directory.GetCurrentDirectory();
}
}

View File

@@ -3,6 +3,7 @@
public class LLamaSharpOption
{
public static string RunType { get; set; }
public static string FileDirectory { get; set; } = Directory.GetCurrentDirectory();
public static uint? ContextSize { get; set; }
public static int? GpuLayerCount { get; set; }
}
}

View File

@@ -44,6 +44,10 @@ namespace AntSK.Domain.Repositories
/// </summary>
public string? EmbeddingModelID { get; set; }
public string? RerankModelID { get; set; }
public string? ImageModelID { get; set; }
/// <summary>
/// 温度
/// </summary>
@@ -53,6 +57,7 @@ namespace AntSK.Domain.Repositories
/// <summary>
/// 提示词
/// </summary>
[SugarColumn(ColumnDataType = "varchar(2000)")]
public string? Prompt { get; set; }
/// <summary>
@@ -76,5 +81,31 @@ namespace AntSK.Domain.Repositories
/// API调用秘钥
/// </summary>
public string? SecretKey { get; set; }
/// <summary>
/// 相似度
/// </summary>
[SugarColumn(DefaultValue = "70")]
public double Relevance { get; set; } = 70f;
/// <summary>
/// 提问最大token数
/// </summary>
[SugarColumn(DefaultValue = "2048")]
public int MaxAskPromptSize { get; set; } = 2048;
/// <summary>
/// 向量匹配数
/// </summary>
[SugarColumn(DefaultValue = "3")]
public int MaxMatchesCount { get; set; } = 3;
[SugarColumn(DefaultValue = "20")]
public int RerankCount { get; set; } = 20;
/// <summary>
/// 回答最大token数
/// </summary>
[SugarColumn(DefaultValue = "2048")]
public int AnswerTokens { get; set; } = 2048;
}
}

View File

@@ -0,0 +1,41 @@
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;
namespace AntSK.Domain.Repositories
{
[SugarTable("Chats")]
public partial class Chats
{
[SugarColumn(IsPrimaryKey = true)]
public string Id { get; set; }
/// <summary>
/// 用户名
/// </summary>
public string UserName { get; set; }
/// <summary>
/// 应用ID
/// </summary>
public string AppId { get; set; }
/// <summary>
/// 消息内容
/// </summary>
[SugarColumn(ColumnDataType = "varchar(4000)")]
public string Context { get; set; } = "";
/// <summary>
/// 发送是true 接收是false
/// </summary>
public bool IsSend { get; set; } = false;
/// <summary>
/// 创建事件
/// </summary>
public DateTime CreateTime { get; set; }
/// <summary>
/// 文件名
/// </summary>
public string? FileName { get; set; }
}
}

View File

@@ -0,0 +1,11 @@

using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
[ServiceDescription(typeof(IChats_Repositories), ServiceLifetime.Scoped)]
public class Chats_Repositories : Repository<Chats>, IChats_Repositories
{
}
}

View File

@@ -0,0 +1,8 @@
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
public interface IChats_Repositories : IRepository<Chats>
{
}
}

View File

@@ -55,6 +55,7 @@ namespace AntSK.Domain.Repositories
[SugarColumn(DefaultValue = "49")]
public int OverlappingTokens { get; set; } = 49;
[SugarColumn(DefaultValue = "0")]
public int IsOCR { get; set; } = 0;
}
}

View File

@@ -1,4 +1,5 @@
using System.Web;
using System.Security.Cryptography;
using System.Web;
namespace AntSK.Domain.Utils
{
@@ -250,5 +251,22 @@ namespace AntSK.Domain.Utils
return nameValueCollection.ToString();
}
/// <summary>
/// 忽略大小写匹配
/// </summary>
/// <param name="s"></param>
/// <param name="value"></param>
/// <returns></returns>
public static bool ComparisonIgnoreCase(this string s, string value)
{
return s.Equals(value, StringComparison.OrdinalIgnoreCase);
}
public static string AntSKCalculateSHA256(this BinaryData binaryData)
{
byte[] byteArray = SHA256.HashData(binaryData.ToMemory().Span);
return Convert.ToHexString(byteArray).ToLowerInvariant();
}
}
}

View File

@@ -0,0 +1,39 @@
using System;
using System.Collections.Generic;
using System.Drawing.Imaging;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Utils
{
public class ImageUtils
{
public static string BitmapToBase64(Bitmap bitmap)
{
using (MemoryStream memoryStream = new MemoryStream())
{
// 保存为JPEG格式也可以选择PngGif等等
bitmap.Save(memoryStream, ImageFormat.Jpeg);
// 获取内存流的字节数组
byte[] imageBytes = memoryStream.ToArray();
// 将字节转换为Base64字符串
string base64String = Convert.ToBase64String(imageBytes);
return base64String;
}
}
public static List<string> BitmapListToBase64(Bitmap[] bitmaps)
{
List<string> base64Strings = new List<string>();
foreach (Bitmap bitmap in bitmaps)
{
base64Strings.Add(BitmapToBase64(bitmap));
}
return base64Strings;
}
}
}

View File

@@ -1,4 +1,6 @@
using System.Text.RegularExpressions;

using Serilog;
using System.Text.RegularExpressions;
namespace AntSK.Domain.Utils
{
@@ -19,7 +21,7 @@ namespace AntSK.Domain.Utils
{
string requestBody = await request.Content.ReadAsStringAsync();
//便于调试查看请求prompt
Console.WriteLine(requestBody);
Log.Information(requestBody);
}
if (match.Success)
{

View File

@@ -0,0 +1,19 @@
import os
import uvicorn
from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
if __name__ == "__main__":
main()

View File

@@ -1,16 +0,0 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__":
main()

View File

@@ -1,27 +0,0 @@
import subprocess
import shlex
import os
class Start(object):
def __init__(self,model_name_or_path):
self.model_name_or_path=model_name_or_path
def StartCommand(self):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['API_PORT'] = '8000'
# 构建要执行的命令
command = (
'python api_demo.py'
' --model_name_or_path E:/model/Qwen1.5-0.5B-Chat_back'
' --template default '
)
# 使用shlex.split()去安全地分割命令字符串
command = shlex.split(command)
# 执行命令
subprocess.run(command, shell=True)
if __name__ == "__main__":
star= Start('model_name_or_path')
star.StartCommand()

View File

@@ -1,49 +0,0 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
def main():
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()

View File

@@ -1,10 +0,0 @@
from llmtuner import Evaluator
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
main()

View File

@@ -1,9 +0,0 @@
from llmtuner import export_model
def main():
export_model()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
__version__ = VERSION

View File

@@ -0,0 +1,108 @@
import os
from contextlib import asynccontextmanager
from typing import Optional
from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
create_stream_chat_completion_response,
)
from .protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelCard,
ModelList,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@app.get(
"/v1/models",
response_model=ModelList,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post(
"/v1/chat/completions",
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return await create_chat_completion_response(request, chat_model)
@app.post(
"/v1/score/evaluation",
response_model=ScoreEvaluationResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
return await create_score_evaluation_response(request, chat_model)
return app
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)

View File

@@ -0,0 +1,219 @@
import base64
import io
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
Function,
FunctionCall,
Role,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = None
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if image_url.startswith("data:image"): # base64 image
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image = Image.open(image_path).convert("RGB")
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
return input_messages, system, tools, image
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
stop=request.stop,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
if request.n > 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
async for new_token in chat_model.astream_chat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
stop=request.stop,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
)
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)

View File

@@ -0,0 +1,20 @@
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)

View File

@@ -1,6 +1,6 @@
import time
from enum import Enum, unique
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
@@ -39,15 +39,37 @@ class Function(BaseModel):
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
id: str
type: Literal["function"] = "function"
function: Function
class ImageURL(BaseModel):
url: str
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
class ChatMessage(BaseModel):
role: Role
content: str
content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
@@ -59,12 +81,13 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
@@ -74,7 +97,7 @@ class ChatCompletionResponseChoice(BaseModel):
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
@@ -87,7 +110,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
@@ -96,11 +119,11 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
choices: List[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
@@ -110,7 +133,7 @@ class ScoreEvaluationRequest(BaseModel):
class ScoreEvaluationResponse(BaseModel):
id: Literal["scoreeval-default"] = "scoreeval-default"
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]

View File

@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]: ...
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...

View File

@@ -2,12 +2,15 @@ import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from .base_engine import BaseEngine, Response
@@ -36,9 +39,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
return task.result()
async def achat(
@@ -46,18 +50,20 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs)
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs)
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -70,9 +76,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
yield new_token
def get_scores(
@@ -89,3 +96,45 @@ class ChatModel:
**input_kwargs,
) -> List[float]:
return await self.engine.get_scores(batch_input, **input_kwargs)
def run_chat() -> None:
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})

View File

@@ -2,25 +2,31 @@ import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine):
def __init__(
self,
@@ -30,55 +36,96 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
generating_args = generating_args.copy()
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
temperature=temperature if temperature is not None else generating_args["temperature"],
top_p=top_p if top_p is not None else generating_args["top_p"],
top_k=top_k if top_k is not None else generating_args["top_k"],
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
generating_args["do_sample"] = True
generating_args["temperature"] = generating_args["temperature"] or 1.0
if not generating_args["temperature"]:
generating_args["do_sample"] = False
if not generating_args["do_sample"]:
generating_args.pop("temperature", None)
generating_args.pop("top_p", None)
if max_length:
generating_args.pop("max_new_tokens", None)
@@ -90,10 +137,14 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs = dict(
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
return gen_kwargs, prompt_length
@staticmethod
@@ -101,15 +152,17 @@ class HuggingfaceEngine(BaseEngine):
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -134,15 +187,17 @@ class HuggingfaceEngine(BaseEngine):
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -198,6 +253,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -207,11 +263,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:
@@ -223,6 +281,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -232,11 +291,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:

View File

@@ -0,0 +1,214 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
config = load_config(model_args) # may download model from ms hub
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": model_args.vllm_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if model_args.visual_inputs:
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = SamplingParams(
n=num_return_sequences,
repetition_penalty=(
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search,
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,
skip_special_tokens=True,
)
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
request_id=request_id,
lora_request=self.lora_request,
)
return result_generator
async def start(self) -> None:
pass
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")

View File

@@ -0,0 +1,106 @@
import os
import random
import subprocess
import sys
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli eval -h: evaluate models |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
logger = get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main():
command = sys.argv.pop(1)
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}".format(command))

View File

@@ -0,0 +1,16 @@
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]

View File

@@ -0,0 +1,221 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
return outputs
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)

View File

@@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
target_feature = {
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)
return batch

View File

@@ -1,6 +1,5 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -11,7 +10,7 @@ if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
logger = get_logger(__name__)
@@ -26,25 +25,10 @@ class Role(str, Enum):
OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - max_target_len
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
@@ -78,9 +62,9 @@ def split_dataset(
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size

View File

@@ -1,21 +1,24 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
@@ -56,14 +59,12 @@ def load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
raise ValueError("File {} not found.".format(local_path))
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.file_sha1)
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
else:
raise NotImplementedError
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
if dataset_attr.load_from == "ms_hub":
try:
@@ -80,7 +81,9 @@ def load_single_dataset(
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
@@ -104,30 +107,43 @@ def load_single_dataset(
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
return align_dataset(dataset, dataset_attr, data_args)
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
@@ -138,12 +154,15 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
@@ -156,15 +175,21 @@ def get_dataset(
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
sys.exit(0)
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
if stage == "pt":
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
else:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset

View File

@@ -20,23 +20,28 @@ class DatasetAttr:
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
num_samples: Optional[int] = None
""" common columns """
system: Optional[str] = None
""" columns for the alpaca format """
tools: Optional[str] = None
images: Optional[str] = None
""" rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
""" columns for the sharegpt format """
""" sharegpt columns """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
""" sharegpt tags """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
@@ -53,22 +58,35 @@ class DatasetAttr:
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
if data_args.dataset is not None:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
else:
dataset_names = []
if data_args.dataset_dir == "ONLINE":
dataset_info = None
else:
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None:
load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
@@ -85,18 +103,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system"]
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

View File

@@ -0,0 +1,84 @@
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_feedback_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function

View File

@@ -0,0 +1,126 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, kl_response_ids = template.encode_oneturn(
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs

View File

@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))

View File

@@ -0,0 +1,36 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result

View File

@@ -0,0 +1,64 @@
import bisect
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = search_for_fit(numbers, remaining_capacity)
if index == -1:
break # no more numbers fit in this knapsack
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
knapsacks.append(current_knapsack)
return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)

View File

@@ -0,0 +1,169 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels = [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=None,
data_args=data_args,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
valid_num += 1
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_labels = [], []
for length in knapsack:
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["labels"].append(packed_labels)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))

View File

@@ -0,0 +1,92 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
return input_ids, labels
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .data_utils import Role, infer_max_len
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
@@ -26,6 +26,7 @@ class Template:
format_separator: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
force_system: bool
@@ -68,8 +69,8 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
system: Optional[str],
tools: Optional[str],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
@@ -195,7 +196,7 @@ class Llama2Template(Template):
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
TEMPLATES: Dict[str, Template] = {}
def _register_template(
@@ -209,6 +210,7 @@ def _register_template(
format_separator: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
@@ -246,7 +248,7 @@ def _register_template(
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
@@ -256,6 +258,7 @@ def _register_template(
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
@@ -276,7 +279,7 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'")
return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
@@ -290,10 +293,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
@@ -308,7 +311,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
@@ -325,9 +328,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -343,9 +348,9 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
template = TEMPLATES["empty"] # placeholder
else:
template = templates.get(name, None)
template = TEMPLATES.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
@@ -385,7 +390,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
)
@@ -414,7 +420,7 @@ _register_template(
_register_template(
name="baichuan",
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
@@ -441,6 +447,18 @@ _register_template(
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
default_system=(
"You are a helpful AI assistant built by MediaTek Research. "
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
),
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -490,6 +508,7 @@ _register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -500,6 +519,7 @@ _register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
@@ -514,6 +534,26 @@ _register_template(
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
),
default_system=(
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
"by providing thorough responses. You are trained by Cohere."
),
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -522,6 +562,32 @@ _register_template(
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -554,6 +620,16 @@ _register_template(
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -562,16 +638,39 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -601,17 +700,8 @@ _register_template(
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
)
@@ -623,6 +713,33 @@ _register_template(
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -633,8 +750,7 @@ _register_template(
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
)
@@ -643,12 +759,28 @@ _register_template(
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="openchat-3.6",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
stop_words=["<|eot_id|>"],
replace_eos=True,
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -657,10 +789,22 @@ _register_template(
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -688,7 +832,11 @@ _register_template(
_register_template(
name="vanilla",
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
stop_words=["<_end>"],
replace_eos=True,
)
@@ -742,12 +890,29 @@ _register_template(
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
),
stop_words=["###"],
efficient_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -762,7 +927,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
default_system="You are Zephyr, a helpful assistant.",
)

View File

@@ -14,16 +14,17 @@ from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
@@ -117,6 +118,5 @@ class Evaluator:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()
def run_eval() -> None:
Evaluator().eval()

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
answer: str
prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages
@@ -39,7 +42,7 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template
register_eval_template(
_register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
)
register_eval_template(
_register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
prefix=" ",
)

View File

@@ -0,0 +1,217 @@
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, output_dir: str) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
self.aborted = False
self.do_train = False
""" Web UI """
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def _create_thread_pool(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _close_thread_pool(self) -> None:
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
)
)
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if has_length(eval_dataloader):
if self.max_steps == 0:
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)
self._timing(cur_steps=self.cur_steps + 1)
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)

View File

@@ -2,13 +2,24 @@ from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
}
CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
@@ -24,28 +35,43 @@ IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
METHODS = ["full", "freeze", "lora"]
PEFT_METHODS = ["lora"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINER_LOG = "trainer_log.jsonl"
TRAINING_ARGS = "training_args.yaml"
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"KTO": "kto",
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
VISION_MODELS = set()
class DownloadSource(str, Enum):
DEFAULT = "hf"
@@ -54,8 +80,8 @@ class DownloadSource(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
module: Optional[str] = None,
template: Optional[str] = None,
vision: bool = False,
) -> None:
prefix = None
for name, path in models.items():
@@ -64,10 +90,23 @@ def register_model_group(
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
if vision:
VISION_MODELS.add(prefix)
register_model_group(
models={
"Aya-23-8B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
},
"Aya-23-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
},
},
template="cohere",
)
register_model_group(
@@ -85,7 +124,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
module="W_pack",
template="baichuan",
)
@@ -109,7 +147,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
},
},
module="W_pack",
template="baichuan2",
)
@@ -129,7 +166,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
module="query_key_value",
)
@@ -148,7 +184,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
module="query_key_value",
)
@@ -167,6 +202,19 @@ register_model_group(
)
register_model_group(
models={
"Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
},
"Breeze-7B-Chat": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
},
},
template="breeze",
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
@@ -174,7 +222,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
module="query_key_value",
template="chatglm2",
)
@@ -190,7 +237,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
module="query_key_value",
template="chatglm3",
)
@@ -226,6 +272,73 @@ register_model_group(
)
register_model_group(
models={
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
"CodeGemma-1.1-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"Codestral-22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
},
},
template="mistral",
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DBRX-132B-Base": {
DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
},
"DBRX-132B-Chat": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
},
},
template="dbrx",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
@@ -246,18 +359,36 @@ register_model_group(
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-16B-v2-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
},
"DeepSeek-MoE-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-MoE-16B-v2-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
},
"DeepSeek-MoE-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
},
template="deepseek",
)
@@ -298,6 +429,9 @@ register_model_group(
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-11B": {
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
@@ -319,7 +453,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
module="query_key_value",
template="falcon",
)
@@ -342,11 +475,36 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
},
"Gemma-1.1-2B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
},
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"GLM-4-9B": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
},
"GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
},
"GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
},
},
template="glm4",
)
register_model_group(
models={
"InternLM-7B": {
@@ -389,11 +547,20 @@ register_model_group(
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
},
module="wqkv",
template="intern2",
)
register_model_group(
models={
"Jambda-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
},
)
register_model_group(
models={
"LingoWhale-8B": {
@@ -401,7 +568,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
module="qkv_proj",
)
@@ -460,18 +626,72 @@ register_model_group(
register_model_group(
models={
"Mistral-7B": {
"LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
"LLaMA3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
},
"LLaMA3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
},
},
template="llama3",
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
vision=True,
)
register_model_group(
models={
"Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
"Mistral-7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
},
"Mistral-7B-v0.3-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
},
},
template="mistral",
)
@@ -479,14 +699,22 @@ register_model_group(
register_model_group(
models={
"Mixtral-8x7B": {
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
"Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
)
@@ -495,18 +723,18 @@ register_model_group(
register_model_group(
models={
"OLMo-1B": {
DownloadSource.DEFAULT: "allenai/OLMo-1B",
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
},
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B",
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
},
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
},
module="att_proj",
template="olmo",
)
@@ -514,13 +742,23 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
}
},
template="openchat",
)
register_model_group(
models={
"OpenChat3.6-8B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
}
},
template="openchat-3.6",
)
register_model_group(
models={
"Orion-14B-Base": {
@@ -548,6 +786,33 @@ register_model_group(
)
register_model_group(
models={
"PaliGemma-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
vision=True,
)
register_model_group(
models={
"Phi-1.5-1.3B": {
@@ -562,6 +827,37 @@ register_model_group(
)
register_model_group(
models={
"Phi3-4B-4k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi3-4B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi3-7B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi3-7B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi3-14B-8k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
},
"Phi3-14B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
},
template="phi",
)
register_model_group(
models={
"Qwen-1.8B": {
@@ -629,7 +925,6 @@ register_model_group(
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
},
module="c_attn",
template="qwen",
)
@@ -656,10 +951,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -680,10 +991,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -724,6 +1051,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
@@ -732,6 +1063,101 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
},
"Qwen2-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
},
"Qwen2-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
},
"Qwen2-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
},
"Qwen2-MoE-57B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
},
"Qwen2-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
},
"Qwen2-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
},
"Qwen2-MoE-57B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2-0.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
},
"Qwen2-1.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2-1.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
},
"Qwen2-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
},
"Qwen2-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
},
"Qwen2-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
},
"Qwen2-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
},
"Qwen2-MoE-57B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
},
},
template="qwen",
)
@@ -765,17 +1191,39 @@ register_model_group(
models={
"StarCoder2-3B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
},
"StarCoder2-7B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
},
"StarCoder2-15B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
},
}
)
register_model_group(
models={
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
},
"TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
},
},
template="telechat",
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": {
@@ -793,17 +1241,53 @@ register_model_group(
register_model_group(
models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan-2-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan-2-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
@@ -840,6 +1324,30 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
@@ -898,11 +1406,49 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
"Yi-1.5-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
},
"Yi-1.5-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
},
"Yi-1.5-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
},
"Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
},
"Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
},
"Yi-1.5-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
},
template="yi",
)
register_model_group(
models={
"YiVL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
},
"YiVL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
},
},
template="yi_vl",
vision=True,
)
register_model_group(
models={
"Yuan2-2B-Chat": {
@@ -932,21 +1478,9 @@ register_model_group(
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
register_model_group(
models={
"Atom-7B": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
},
"Atom-7B-Chat": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
},
},
template="atom",
)

View File

@@ -0,0 +1,55 @@
import platform
import accelerate
import datasets
import peft
import torch
import transformers
import trl
from transformers.integrations import is_deepspeed_available
from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
from .packages import is_vllm_available
VERSION = "0.8.0"
def print_env() -> None:
info = {
"`llamafactory` version": VERSION,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": torch.__version__,
"Transformers version": transformers.__version__,
"Datasets version": datasets.__version__,
"Accelerate version": accelerate.__version__,
"PEFT version": peft.__version__,
"TRL version": trl.__version__,
}
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
info["GPU type"] = torch.cuda.get_device_name()
if is_torch_npu_available():
info["PyTorch version"] += " (NPU)"
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann
if is_deepspeed_available():
import deepspeed # type: ignore
info["DeepSpeed version"] = deepspeed.__version__
if is_bitsandbytes_available():
import bitsandbytes
info["Bitsandbytes version"] = bitsandbytes.__version__
if is_vllm_available():
import vllm
info["vLLM version"] = vllm.__version__
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")

Some files were not shown because too many files have changed in this diff Show More