Compare commits

...

263 Commits
0.1.7 ... 0.2.5

Author SHA1 Message Date
zyxucp
74406d88a0 Merge pull request #58 from AIDotNet/feature_StableDiffusion
fix 修改为静态类
2024-04-01 23:57:12 +08:00
zeyu xu
e5f9d97560 fix 修改为静态类 2024-04-01 23:56:44 +08:00
zyxucp
59e768aaea Merge pull request #57 from AIDotNet/feature_StableDiffusion
Feature stable diffusion
2024-04-01 23:39:22 +08:00
zeyu xu
6a7cb24a5b add sd 2024-04-01 23:08:53 +08:00
zeyu xu
1db40d534c add apptype 2024-04-01 22:14:18 +08:00
zeyu xu
11d6e30f7e add sd function 2024-04-01 22:03:00 +08:00
zeyu xu
9d5214aaae add sdmodel 2024-04-01 21:57:18 +08:00
zeyu xu
010b906271 add sd 2024-04-01 21:35:51 +08:00
zeyu xu
16bf944edf add sd 2024-04-01 21:31:15 +08:00
zeyu xu
5bae5a099a margin 2024-04-01 21:01:29 +08:00
zyxucp
f771ea9521 Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-04-01 13:54:53 +08:00
zyxucp
994efbf37c update nuget 2024-04-01 13:54:20 +08:00
zyxucp
938cd86c88 Update README.md 2024-03-31 13:24:21 +08:00
zeyu xu
1339cbadbc fix 修改menukey 2024-03-31 13:07:30 +08:00
zeyu xu
bd0ad570ad add 增加使用文档 2024-03-31 13:07:08 +08:00
zeyu xu
234e649a7e fix 优化部分内容 2024-03-31 12:38:17 +08:00
zyxucp
c431dbc842 Update README.md 2024-03-31 00:28:16 +08:00
zyxucp
76283060d9 Update docker-compose.simple.yml 2024-03-30 23:28:52 +08:00
zyxucp
75ba506db4 Update docker-compose.yml 2024-03-30 23:28:33 +08:00
zyxucp
e086ca60df Merge pull request #56 from AIDotNet/feature_chatview
add chathistory to localstorage
2024-03-30 22:40:06 +08:00
zeyu xu
04acaa9b12 add chathistory to localstorage 2024-03-30 22:39:26 +08:00
zyxucp
7a824bf18c Merge pull request #55 from AIDotNet/feature_chatview
fix 处理文件上传会话 不配embedding模型则隐藏
2024-03-30 22:26:55 +08:00
zeyu xu
769de2e526 fix 处理文件上传会话 不配embedding模型则隐藏 2024-03-30 22:23:01 +08:00
zyxucp
c9950609c9 Merge pull request #54 from AIDotNet/feature_chatview
fix 删除不要元素
2024-03-30 21:58:08 +08:00
zeyu xu
ccee6cfea5 fix 删除不要元素 2024-03-30 21:57:38 +08:00
zyxucp
f626c618be Merge pull request #53 from AIDotNet/feature_chatview
Feature chatview
2024-03-30 21:53:01 +08:00
zeyu xu
3b601a9e3d fix 调整样式 2024-03-30 21:48:29 +08:00
zeyu xu
3d5f63d595 add chatview 2024-03-30 21:45:37 +08:00
zeyu xu
6933f2f495 add openchat file 2024-03-30 20:39:02 +08:00
zeyu xu
79c7e8626a add chatview 2024-03-30 20:28:40 +08:00
zyxucp
4a017d311c Update README.md 2024-03-30 20:03:57 +08:00
zeyu xu
0c8ad5fe8d add loadding 2024-03-30 19:50:29 +08:00
zeyu xu
68ce0db011 fix 样式修改 2024-03-30 17:35:40 +08:00
zeyu xu
c36de1a1e9 add 选项控制 2024-03-30 17:25:58 +08:00
zeyu xu
44ef759abd fix 修改控件 2024-03-30 14:47:29 +08:00
longdream
0c3d9844be Merge pull request #52 from longdream/main
bge embedding模型添加,bge用的CPU。
2024-03-29 21:51:35 +08:00
longdream
854c62a4ca 合并 2024-03-29 21:50:17 +08:00
longdream
5ed4fd5299 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 20:00:53 +08:00
longdream
af5ec43571 修改设置界面 2024-03-29 20:00:49 +08:00
zyxucp
24d685879e Update README.md 2024-03-29 18:43:10 +08:00
zyxucp
e801a2ec46 Update README.md 2024-03-29 18:42:18 +08:00
junlong
d7b56d1590 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 15:34:08 +08:00
longdream
b925f8890b 修改token长度 2024-03-28 23:06:21 +08:00
longdream
5d80ee994a 解决线程冲突问题 2024-03-28 19:04:11 +08:00
zeyu xu
da8f955ca2 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-28 18:34:59 +08:00
zeyu xu
2e04582c5e fix prompt 2024-03-28 18:34:46 +08:00
zyxucp
e69994f727 Merge pull request #50 from ElderJames/fix/sparkdesk-func-call
fix function calling result for sparkdesk
2024-03-28 11:14:52 +08:00
James Yeung
d8dc26127d fix function calling result for sparkdesk 2024-03-28 10:49:50 +08:00
longdream
f73bd2dfda 增减embedding 2024-03-27 22:53:45 +08:00
zyxucp
9f08b60348 Update README.md 2024-03-27 20:25:36 +08:00
zeyu xu
75c2f36b30 update nuget 2024-03-27 19:27:59 +08:00
zyxucp
39c02a6064 Update README.en.md 2024-03-27 19:20:21 +08:00
zyxucp
52c119befd Update README.md 2024-03-27 19:19:12 +08:00
zyxucp
23903ded3f Update README.en.md 2024-03-27 19:11:46 +08:00
zyxucp
4799fbac72 Update README.md 2024-03-27 19:11:02 +08:00
zyxucp
16a7d55271 Merge pull request #49 from AIDotNet/feature_vctordb
add AzureAISearch
2024-03-27 19:10:27 +08:00
zeyu xu
be0bafcc50 add AzureAISearch 2024-03-27 19:09:42 +08:00
zyxucp
defc51a074 Update README.md 2024-03-27 18:46:51 +08:00
zyxucp
09709c210d Merge pull request #48 from AIDotNet/feature_vctordb
add RedisMemoryDb
2024-03-27 18:44:59 +08:00
zeyu xu
8ebb2f54eb add RedisMemoryDb 2024-03-27 18:44:26 +08:00
zyxucp
b831aab115 Update README.md 2024-03-27 17:49:27 +08:00
zyxucp
8e322162cc Update README.md 2024-03-27 17:48:40 +08:00
zyxucp
6a8a6509b8 Merge pull request #47 from AIDotNet/feature_vctordb
add QdrantMemoryDb
2024-03-27 17:42:12 +08:00
zeyu xu
707dff09f8 add QdrantMemoryDb 2024-03-27 17:41:15 +08:00
zyxucp
17c8fca40f Merge pull request #45 from duyanming/main
文字纠正
2024-03-27 15:49:29 +08:00
zeyu xu
415f9757e9 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-27 15:40:57 +08:00
zeyu xu
27394f0699 fix 修改bearer 示例错误 2024-03-27 15:40:47 +08:00
duyanming
8a9ca40bb6 文字纠正 2024-03-27 08:09:23 +08:00
longdream
f340ee1088 embedding封装 2024-03-26 23:14:55 +08:00
zyxucp
080eb5765e Update README.md 2024-03-26 21:24:26 +08:00
zeyu xu
36c8ff184a Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-26 21:22:26 +08:00
zeyu xu
0486f67b50 add gzh 2024-03-26 21:21:57 +08:00
zyxucp
aa7e8d545c Update README.md 2024-03-26 21:07:02 +08:00
zyxucp
59f6a899a6 Update README.en.md 2024-03-26 21:06:28 +08:00
zyxucp
0fc98d42aa Update README.md 2024-03-26 21:04:21 +08:00
longdream
edad2644aa 删除没必要的py文件 2024-03-26 20:48:49 +08:00
longdream
8a56a0393a Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-26 20:48:07 +08:00
zyxucp
f4cbf9a40a Update README.md 2024-03-26 00:03:18 +08:00
zyxucp
fb5b92f499 Update README.en.md 2024-03-26 00:02:59 +08:00
zeyu xu
c286258f2b del 删除LLamaSharp早起http版本 2024-03-25 23:04:17 +08:00
zeyu xu
4416651589 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-25 22:29:16 +08:00
zeyu xu
48a33e8977 fix 修复格式变更 2024-03-25 22:29:04 +08:00
junlong
bd5ca06d8f test 2024-03-25 16:55:41 +08:00
junlong
e0985ecec3 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-25 16:48:21 +08:00
junlong
e56b74d4af 删除chat以外的文件 2024-03-25 16:48:11 +08:00
zyxucp
c417098c2c Merge pull request #44 from ElderJames/fix/sparkdesk-func-issue
fix: sparkdesk function call definition conversion
2024-03-25 14:15:03 +08:00
zyxucp
93527215a7 Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-03-25 13:43:57 +08:00
James Yeung
0cf3945693 fix: sparkdesk function call definition conversion 2024-03-25 13:16:10 +08:00
zyxucp
ced2a9b2e2 fix 调整llamafactory加载顺序 2024-03-25 12:03:09 +08:00
zyxucp
987b231c4d Update README.md 2024-03-25 11:16:26 +08:00
zyxucp
7a541c1da1 Update README.md 2024-03-25 11:14:43 +08:00
zyxucp
74e323158d fix 修改变量名规范 2024-03-24 23:46:14 +08:00
zyxucp
563a7409f6 Merge pull request #42 from AIDotNet/feature_chathistory
Feature chathistory
2024-03-24 23:44:20 +08:00
zyxucp
b13b93e04e Merge branch 'feature_chathistory' of https://github.com/AIDotNet/AntSK into feature_chathistory 2024-03-24 23:43:34 +08:00
zyxucp
44568c8d65 fix 修改OpenAIService 历史对话 2024-03-24 23:43:07 +08:00
zyxucp
fb277dff80 Merge pull request #41 from AIDotNet/feature_chathistory
Feature chathistory
2024-03-24 23:12:45 +08:00
zeyu xu
efae890650 fix 调整kms提示词 2024-03-24 23:08:32 +08:00
zyxucp
a146f6059e fix 调整历史记录会话 2024-03-24 22:51:20 +08:00
zyxucp
3c67096cd8 Update README.md 2024-03-24 19:56:13 +08:00
zeyu xu
a993a60f95 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-24 13:47:51 +08:00
zeyu xu
d3fdc77600 add 增加文档 2024-03-24 13:47:38 +08:00
zyxucp
b62c56e36f Update README.md 2024-03-24 13:05:18 +08:00
zeyu xu
7d72911239 fix 修改gpu默认分层为20 2024-03-24 12:23:27 +08:00
zeyu xu
9e24d7cc67 add Authors 2024-03-24 12:07:31 +08:00
zeyu xu
9baa24b496 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-24 12:06:02 +08:00
zeyu xu
da826525f7 add py环境docker file 2024-03-24 12:05:50 +08:00
zyxucp
62dfab41fd Update README.en.md 2024-03-24 12:03:08 +08:00
zeyu xu
04fc811c2c fix 修改首页样式和github链接样式 2024-03-24 12:02:05 +08:00
zeyu xu
8638ecbe29 fix 修改描述一致 2024-03-24 10:55:30 +08:00
zeyu xu
6f1f93fbaf update 更新AntDesign.ProLayout、SemanticKernel、KernelMemory 版本 2024-03-24 10:52:44 +08:00
zyxucp
dc38d83f89 Update README.md 2024-03-23 23:14:20 +08:00
zeyu xu
fd780780c5 update docker-compose.yml 2024-03-23 22:39:51 +08:00
zyxucp
6fd918f33b Update README.md 2024-03-23 21:48:17 +08:00
zyxucp
8fcfa8974b Update README.md 2024-03-23 16:49:15 +08:00
zyxucp
7e23c32c6c Update README.md 2024-03-23 16:48:24 +08:00
zyxucp
7fdfceeea5 Update docker-compose.yml 2024-03-23 13:01:03 +08:00
zyxucp
bf0e62634f Update README.md 2024-03-23 12:44:53 +08:00
zyxucp
469ef9aab2 Merge pull request #38 from AIDotNet/feature_llamafactory
Feature llamafactory
2024-03-23 12:34:11 +08:00
zeyu xu
f7df26030d add 增加环境安装 2024-03-23 12:30:52 +08:00
zeyu xu
f61cbe9780 add 增加记录llamafactory是否启动的状态 2024-03-23 12:20:56 +08:00
zeyu xu
56b62cff2a fix 提示修改 2024-03-23 11:55:50 +08:00
zeyu xu
0aec21cf03 fix 修改日志输出样式宽度 2024-03-23 11:40:30 +08:00
zeyu xu
ff4f6be5fc add 日志输出 2024-03-22 22:57:38 +08:00
longdream
964a5022c8 修改输出 2024-03-22 21:44:51 +08:00
longdream
849b18f677 Merge branch 'AIDotNet:main' into main 2024-03-22 19:36:20 +08:00
zeyu xu
b8c6a6a626 add 增加校验 2024-03-21 23:59:23 +08:00
zeyu xu
57fc9a9b7e fix 修改按钮启动后不可用 2024-03-21 23:58:31 +08:00
zeyu xu
068c126a23 add 增加llamafactory 下拉列表 2024-03-21 23:53:50 +08:00
zeyu xu
0ed1662c7b add llamafactory model 2024-03-21 23:20:05 +08:00
zeyu xu
a09377814f Merge branch 'main' into feature_llamafactory 2024-03-21 22:31:32 +08:00
zeyu xu
3cc952bb2a fix 修改SK config.json升级结构变更 2024-03-21 22:24:45 +08:00
zeyu xu
63c968742b Merge branch 'main' into feature_llamafactory 2024-03-21 22:15:48 +08:00
zeyu xu
a4d6f2a6fd fix 修复分享会话空指针bug 2024-03-21 22:14:49 +08:00
zeyu xu
d6de64853d fix 修改filelist封装 2024-03-21 21:51:09 +08:00
junlong
344128e49d Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-21 19:38:03 +08:00
junlong
56fc9dd517 test 2024-03-21 19:37:56 +08:00
zyxucp
c7c1911eb1 fix 限制会话中上传按钮点击 2024-03-21 19:02:48 +08:00
zyxucp
dec8b5bef7 fix 修改样式 2024-03-21 18:12:15 +08:00
zyxucp
db7271b519 Merge branch 'feature_llamafactory' of https://github.com/AIDotNet/AntSK into feature_llamafactory 2024-03-21 14:28:37 +08:00
zyxucp
fcbee1f64f margin 2024-03-21 14:15:23 +08:00
zyxucp
2d9443a0a1 Merge pull request #36 from jeffersyuan1976/main
IconPicker组件
2024-03-21 13:53:43 +08:00
zeyu xu
6e3dd00d6f fix 删除重复文件 2024-03-21 13:46:45 +08:00
Jeffers
7a0656cd81 IconPicker组件 2024-03-21 13:23:22 +08:00
zyxucp
3b89d9e974 fix 整理启动注入函数 2024-03-21 12:38:08 +08:00
zyxucp
cd174308cf fix 修改codefirst注入模式 2024-03-21 12:20:05 +08:00
zyxucp
aacef47626 Update README.md 2024-03-20 23:27:45 +08:00
zyxucp
fcef01a41f Update README.md 2024-03-20 23:27:19 +08:00
zyxucp
7d19b694fa Update README.md 2024-03-20 23:26:44 +08:00
zyxucp
966a31b156 Update README.md 2024-03-20 23:25:28 +08:00
zeyu xu
7538393742 add requirements 2024-03-20 23:09:52 +08:00
zeyu xu
3658188be2 fix 修改llamafactory 单独一个类库,并增加目录输出 2024-03-20 22:26:23 +08:00
zyxucp
fcd9fb9079 Update docker-compose.simple.yml 2024-03-20 21:38:59 +08:00
zyxucp
d1168e16d6 Update docker-compose.yml 2024-03-20 21:38:44 +08:00
zyxucp
25a6f00dd2 Merge pull request #35 from longdream/main
llama factory 初始化
2024-03-20 17:41:33 +08:00
longdream
13c474f084 Merge branch 'feature_llamafactory' into main 2024-03-20 14:55:28 +08:00
zyxucp
8d78270007 fix 修复导入文件无法导入的bug 2024-03-20 13:47:57 +08:00
junlong
a94b59c156 llama factory 初始化 2024-03-20 11:19:57 +08:00
zyxucp
ae6d61ee6d Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-03-20 09:52:13 +08:00
zyxucp
74a7c94619 fix 修改配置文件目录层级 2024-03-20 09:52:02 +08:00
zeyu xu
bdd34ac786 update docker file 2024-03-19 22:37:28 +08:00
zeyu xu
0e754b4732 fix 注释 2024-03-19 22:09:50 +08:00
zeyu xu
48a9fcfabf fix 修改样式,提取KMOption 2024-03-19 21:53:49 +08:00
zyxucp
7ae2b9ac3b Merge pull request #33 from ElderJames/feat/chat-file
support embeding files in chat
2024-03-19 21:32:21 +08:00
zeyu xu
a6bfe3c69b Merge branch 'main' of github.com:AIDotNet/AntSK 2024-03-19 20:40:24 +08:00
zeyu xu
6b146f4750 忽略本地目录 2024-03-19 20:40:19 +08:00
zeyu xu
1387e716d1 忽略tmp-memory-vectors/ 2024-03-19 20:39:31 +08:00
James Yeung
fadc350047 support embeding files in chat 2024-03-19 16:39:27 +08:00
zyxucp
8af4994a7d Merge pull request #32 from ElderJames/feat/DashScope
Add DashScope support for Kernal Memory
2024-03-19 15:54:05 +08:00
zyxucp
b505b61bfe fix 增加ask提示词 2024-03-19 15:36:01 +08:00
James Yeung
5d086c9383 Add DashScope support for Kernal Memory
Support DashScope
2024-03-19 13:35:45 +08:00
zeyu xu
dd0e367dc8 fix 增加默认端口5000 2024-03-19 00:06:36 +08:00
zyxucp
7110eea912 fix 修改注释错误 2024-03-18 17:26:40 +08:00
zyxucp
154e67ef98 fix 判断无插件不走function call 2024-03-18 17:17:03 +08:00
zyxucp
4a04423373 add swagger token 2024-03-18 16:04:55 +08:00
zyxucp
470ea50ebb update readme.en 2024-03-17 23:53:47 +08:00
zyxucp
2a30f3f221 fix 修改post提示词 2024-03-17 20:07:53 +08:00
zyxucp
315f58fdba fix 修改api插件 2024-03-17 20:04:48 +08:00
zyxucp
f027682175 add post的Function call 2024-03-17 17:23:50 +08:00
zyxucp
2618dffcd6 fix 修复知识库ID错误导致搜索不到的问题 2024-03-17 15:56:59 +08:00
zyxucp
ee42d5870b fix 增加清空导入插件功能 2024-03-17 12:05:32 +08:00
zyxucp
27500b1e08 fix 修改错误命名空间和类 2024-03-17 11:45:24 +08:00
zyxucp
4d71a98724 fix 修改Function Call注释获取不到的问题 2024-03-17 11:30:02 +08:00
zyxucp
b8ba0ab391 Merge pull request #30 from wmchuang/feature/km_disk
Feature/km disk
Thank you very much for your contribution
2024-03-16 21:24:08 +08:00
王闯
b589612913 feat: 解决disk模式下重启 向量文件读取不到的问题 2024-03-16 17:45:45 +08:00
王闯
ec4b440469 feat: 调整知识库搜索 没结果时直接返回 2024-03-16 17:16:29 +08:00
zyxucp
adbecb3b25 update 更新docker compose 2024-03-16 00:23:16 +08:00
zyxucp
277aacc34d add 动态添加dll 2024-03-15 23:53:53 +08:00
zyxucp
dba98f7968 add 增加上传dll功能 2024-03-15 22:55:23 +08:00
zyxucp
a2b0f3f3c2 add 插件导入 2024-03-15 22:45:36 +08:00
zyxucp
b5e527afdb add 增加函数列表 2024-03-15 22:33:39 +08:00
zyxucp
e84b05a39b fix 修改普通特性 2024-03-15 22:16:07 +08:00
zyxucp
631c563e71 fix 修改为普通特性,避免上传DLL需要引用特殊特性 2024-03-15 22:15:02 +08:00
zyxucp
47b304e46f update 升级SK到1.6.2 2024-03-15 22:08:22 +08:00
zyxucp
27ccfc5e88 add 增加分享使用 2024-03-15 21:57:21 +08:00
zyxucp
69441167d3 fix 修改高度自适应 2024-03-15 21:50:55 +08:00
zyxucp
4b97594217 fix 调整会话总结 2024-03-15 21:15:28 +08:00
zyxucp
9ee601f88c fix 首页调整 2024-03-15 21:12:31 +08:00
zyxucp
75b1a299e3 add 限制移动端屏幕缩放 2024-03-15 21:07:59 +08:00
zyxucp
2613c463a1 Merge branch 'main' of https://github.com/xuzeyu91/AntSK 2024-03-15 14:45:15 +08:00
zyxucp
78e0350d36 add 首页调整 2024-03-15 14:39:13 +08:00
zyxucp
bc2425fc3f Update README.md 2024-03-15 14:18:15 +08:00
zyxucp
3e861bc72f Update README.md 2024-03-15 13:01:35 +08:00
zyxucp
b41c464753 fix 暂时取消星火的会话拆分 2024-03-14 16:58:19 +08:00
zyxucp
2817275091 fix 暂时取消注册时间函数 2024-03-14 12:57:28 +08:00
zyxucp
e76b0cf326 fix 调整目录结构 2024-03-14 12:24:09 +08:00
zyxucp
cf34103e15 update docker-compose.yml 2024-03-14 12:07:17 +08:00
zyxucp
1588fd7d7a add 增加ErrorBoundary全局异常 2024-03-14 11:32:42 +08:00
zyxucp
4507ccde6c margin 2024-03-14 10:03:40 +08:00
zyxucp
73fffd766f fix 修复LLamaSharp给一个默认下载路径 2024-03-14 10:02:15 +08:00
zyxucp
e529146c5b fix 修改LLamaSharp文件夹给一个默认路径 2024-03-14 09:59:51 +08:00
zyxucp
7050e52009 Merge pull request #28 from ElderJames/main
Fix DI and OpenAI function calling
2024-03-13 22:25:45 +08:00
James Yeung
4f3238c4f6 Fix DI and OpenAI function calling 2024-03-13 22:22:22 +08:00
zyxucp
b7d27c5d50 fix 修复bug 2024-03-13 22:11:35 +08:00
zyxucp
fe94aa0564 fix 修复修改自身重复问题 2024-03-13 20:59:14 +08:00
zyxucp
ae60a9aced Merge branch 'main' of https://github.com/xuzeyu91/AntSK 2024-03-13 12:13:56 +08:00
zyxucp
4f686b0871 fix 修复按钮停止不了的问题 2024-03-13 10:23:52 +08:00
zyxucp
c61840b7e8 Update README.md 2024-03-13 10:17:50 +08:00
zyxucp
9adce95367 fix appsettings.json示例 2024-03-13 10:10:52 +08:00
zyxucp
eef943458e fix 模型选择只能选gguf,修复搜索空指针页面 2024-03-13 10:02:00 +08:00
zyxucp
f5c80689d4 add 增加模型列表 自动下载模型 2024-03-13 00:52:46 +08:00
zyxucp
5eaee3130a add 增加模型下载页面 2024-03-13 00:10:30 +08:00
zyxucp
5846473f28 Merge branch 'main' of https://github.com/xuzeyu91/AntSK 2024-03-12 23:20:02 +08:00
zyxucp
94c019b484 Merge pull request #25 from ElderJames/model-download
support model download and file list
2024-03-12 22:51:15 +08:00
zyxucp
7e1140c022 add 增加下载页面 2024-03-12 22:50:06 +08:00
James Yeung
ea9044719a support model download and file list 2024-03-12 22:41:12 +08:00
zyxucp
8a96095448 add 增加模型下载地址 2024-03-12 21:57:34 +08:00
zyxucp
fcc8f8751b add 修改对外接口授权添加Bearer 2024-03-12 21:34:58 +08:00
zyxucp
af09ae7c3e Update docker-compose.simple.yml 2024-03-12 10:22:39 +08:00
zyxucp
e8e6a36d7b Update docker-compose.yml 2024-03-12 10:22:12 +08:00
zyxucp
4f89d54ef0 add mock类型便于测试 2024-03-12 09:57:12 +08:00
zyxucp
2f9e2fb114 fix 增加导入后清空文件列表 2024-03-12 09:33:55 +08:00
zyxucp
b6098024b8 add 更换markdown组件,实现代码高亮 2024-03-12 01:11:48 +08:00
zyxucp
1700131066 add 增加首页点击跳转 2024-03-12 00:31:17 +08:00
zyxucp
189536471a fix 先隐藏星火的KM选择器 2024-03-12 00:15:04 +08:00
zyxucp
f534e0bcc3 fix 修复多文件不能同时导入的bug 2024-03-12 00:12:41 +08:00
zyxucp
e203a18e92 add 增加原生插件调用示例 2024-03-11 23:42:26 +08:00
zyxucp
575a69bf4d add 增加首页,暂时取消本地函数调用 2024-03-11 23:16:23 +08:00
zyxucp
69fd3a0367 fix 修改命名空间 2024-03-11 21:23:07 +08:00
zyxucp
8f7e70298e fix 修改描述错误 2024-03-11 21:16:23 +08:00
zyxucp
0fa3f5a554 fix 移动文件目录 2024-03-11 19:41:51 +08:00
zyxucp
f420012752 fix 移动文件目录 2024-03-11 19:40:15 +08:00
zyxucp
c1ca916549 fix 增加openai模型示例 2024-03-11 19:38:26 +08:00
zyxucp
dfcf2bdc85 fix 取消非空类型 2024-03-11 19:12:47 +08:00
zyxucp
72e7acfb7d 修改写法 2024-03-11 19:09:41 +08:00
zyxucp
9d06c127dc fix 增加注释 2024-03-11 19:08:52 +08:00
zyxucp
0460b388ab fix 修改kernel 和km 每次build的问题,进行缓存 2024-03-11 19:02:17 +08:00
zyxucp
45b84ae898 fix 修改类名 2024-03-11 18:26:26 +08:00
zyxucp
41b1cb6f2d Merge pull request #24 from ElderJames/feat/sparkdesk-func-cal
Add support for the Function Calling of Spark Desk
2024-03-11 18:16:27 +08:00
James Yeung
159aaab38e Add support for the Function Calling of Spark Desk 2024-03-11 18:10:14 +08:00
zyxucp
dc351238f6 Update README.md 2024-03-11 15:13:23 +08:00
zyxucp
e6491b39c6 Update README.md 2024-03-11 15:11:11 +08:00
zyxucp
91b4ed8940 Update README.md 2024-03-11 14:23:16 +08:00
zyxucp
ab99098afd Update README.md 2024-03-11 13:26:41 +08:00
zyxucp
d14ce2faa0 add 增加系统设置管理权限 2024-03-10 21:26:01 +08:00
zyxucp
ca293691a8 fix 修改GpuLayerCount 默认值为10 2024-03-10 17:48:00 +08:00
zyxucp
cf8955b9b6 add 更新AntDesign KernelMemory nuget包版本 2024-03-10 16:31:16 +08:00
zyxucp
512828fdc9 fix 调整星火key和screct的位置 2024-03-10 10:55:44 +08:00
zyxucp
91299a96e7 Merge pull request #22 from ElderJames/main
Add Spark Desk text ceneration support
2024-03-10 09:21:35 +08:00
James Yeung
4876d9e727 fix build 2024-03-10 01:34:18 +08:00
James Yeung
a856f2a0e3 clean 2024-03-10 01:31:21 +08:00
James Yeung
0e8113e7b0 Add Spark Desk text ceneration support 2024-03-10 00:55:33 +08:00
zyxucp
34a953589d Update docker-compose.simple.yml 2024-03-09 23:55:50 +08:00
zyxucp
504ea5a238 Update docker-compose.yml 2024-03-09 23:55:40 +08:00
267 changed files with 18916 additions and 1953 deletions

3
.gitignore vendored
View File

@@ -337,6 +337,9 @@ ASALocalRun/
/AntSK/appsettings.Development.json
/AntSK.db
**/tmp-memory-files/*
**/tmp-memory-vectors/*
/src/AntSK/AntSK.db
/src/AntSK/appsettings.Development.json
/src/AntSK.db
/src/AntSK/llama_models
/src/AntSK/AntSK.xml

28
Dockerfile-py Normal file
View File

@@ -0,0 +1,28 @@
# 1. Define the Python image to use for getting pip
FROM pytorch/pytorch AS python-base
# 2. Define the .NET SDK image to build your application
FROM mcr.microsoft.com/dotnet/sdk:8.0 AS build
WORKDIR /src
COPY ["src/AntSK/AntSK.csproj", "AntSK/"]
RUN dotnet restore "AntSK/AntSK.csproj"
COPY src/ .
WORKDIR "/src/AntSK"
RUN dotnet build "AntSK.csproj" -c Release -o /app/build
RUN dotnet publish "AntSK.csproj" -c Release -o /app/publish
# 3. Define the final image that will contain both .NET runtime and Python
FROM mcr.microsoft.com/dotnet/aspnet:8.0 AS final
# Copy the Python/pip installation from the official Python image
COPY --from=python-base /usr/local /usr/local
COPY --from=python-base /opt/conda/ /opt/conda/
WORKDIR /app
COPY --from=build /app/publish .
# Make sure the app and Python directories are in PATH
ENV PATH="/app:/opt/conda/bin:/usr/local/bin:${PATH}"
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
RUN echo 'Asia/Shanghai' >/etc/timezone
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
ENTRYPOINT ["dotnet", "AntSK.dll"]

View File

@@ -1,162 +1,123 @@
[简体中文](./README.md) | English
# AntSK
## AI Knowledge Base/Intelligent Agent built on .Net8+AntBlazor+SemanticKernel
## Based on AI knowledge base/agent created by Net8+AntBlazor+SemanticKernel
## ⭐Core Features
- **Semantic Kernel**: Utilizes advanced natural language processing technology to accurately understand, process, and respond to complex semantic queries, providing users with precise information retrieval and recommendation services.
- **Kernel Memory**: Capable of continuous learning and storing knowledge points, AntSK has long-term memory function, accumulates experience, and provides a more personalized interaction experience.
## Core functions
- **Knowledge Base**: Import knowledge base through documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and perform knowledge base Q&A.
- **GPT Generation**: This platform supports creating personalized GPT models, enabling users to build their own GPT models.
- **API Interface Publishing**: Exposes internal functions in the form of APIs, enabling developers to integrate AntSK into other applications and enhance application intelligence.
- **Semantic Kernel**: It uses advanced natural language processing technology to accurately understand, process and respond to complex semantic queries, and provides users with accurate information retrieval and recommendation services.
- **API Plugin System**: Open API plugin system that allows third-party developers or service providers to easily integrate their services into AntSK, continuously enhancing application functionality.
- **.Net Plugin System**: Open dll plugin system that allows third-party developers or service providers to easily integrate their business functions by generating dll in standard format code, continuously enhancing application functionality.
- **Online Search**: AntSK, real-time access to the latest information, ensuring users receive the most timely and relevant data.
- **Kernel Memory**: It has the ability to continuously learn and store knowledge points. AntSK has a long-term memory function to accumulate experience and provide a more personalized interactive experience.
- **Model Management**: Adapts and manages integration of different models from different manufacturers, including gguf types supported by **llama.cpp** and models offline running supported by **llamafactory**.
- **Domestic Innovation**: AntSK supports domestic models and databases and can run under domestic innovation conditions.
- **Model Fine-Tuning**: Planned based on llamafactory for model fine-tuning.
- **Knowledge base**: Knowledge base documents can be created by importing knowledge base documents (Word, PDF, Excel, Txt, Markdown, Json, PPT) and other forms.
- **API plug-in system**: an open API plug-in system that allows third-party developers or service providers to easily integrate their services into AntSK and continuously enhance application functions.
- **Online search**: AntSK can obtain the latest information in real time to ensure that the information received by users is always the most timely and relevant.
- **GPTs generation**: This platform supports the creation of personalized GPT models and attempts to build your own GPT models.
- **API interface publishing**: internal functions are provided externally in the form of API, so that developers can easily translate Xzy AntSK KnowledgeBase is integrated into other applications to enhance application intelligence.
- **Model management**: Adapt and manage different models from different vendors.
## Application scenarios
AntSK is applicable to a variety of business scenarios, such as:
- Enterprise level knowledge management system
- Automatic customer service and chat robot
- Enterprise Search Engine
## ⛪Application Scenarios
AntSK is suitable for various business scenarios, such as:
- Enterprise knowledge management system
- Automatic customer service and chatbots
- Enterprise search engine
- Personalized recommendation system
- Intelligent assisted writing
- Education and online learning platform
- Intelligent writing assistance
- Education and online learning platforms
- Other interesting AI Apps
## Function example
First, you need to create a knowledge base
![Knowledge base](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%9F%A5%E8%AF%86%E5%BA%93.png)
In the knowledge base, you can use documents or urls to import
![Knowledge base details](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%9F%A5%E8%AF%86%E5%BA%93%E8%AF%A6%E6%83%85.png)
Click View to view the document slicing of the knowledge base
![Document Slice](https://github.com/xuzeyu91/AntSK/blob/main/images/%E6%96%87%E6%A1%A3%E5%88%87%E7%89%87.png)
Then we need to create applications, which can create dialog applications and knowledge bases.
![Application](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%BA%94%E7%94%A8.png)
The application of knowledge base needs to select the existing knowledge base, which can be multiple
![Application Configuration](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%BA%94%E7%94%A8%E9%85%8D%E7%BD%AE.png)
Then you can ask questions about the knowledge base documents in the dialogue
![Q&A](https://github.com/xuzeyu91/AntSK/blob/main/images/%E9%97%AE%E7%AD%94.png)
In addition, we can also create dialogue applications, and configure prompt word templates in corresponding applications
![Conversation application](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%AE%80%E5%8D%95%E5%AF%B9%E8%AF%9D.png)
Let's see the effect
![Conversation effect](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%AF%B9%E8%AF%9D%E6%95%88%E6%9E%9C.png)
## How do I get started?
Login is the default login account and password
Here I use Postgres as data storage and vector storage, because both the Semantic Kernel and Kernel Memory support it. Of course, you can switch to other ones.
The model supports openai by default. If you need to use azure openai and need to adjust the dependency injection of SK, you can also use one api for integration.
The following configuration files need to be configured
## Using Docker Compose
Provided pg version appsettings. json and simplified version (Sqlite+disk) Docker Compose. simple. yml
Download Docker Compose.yml from the project root directory, and then place the configuration file appsettings.json and it in a unified directory,
The image of PG has been prepared here. You can modify the default account password in Docker Compose.yml, and your appsettings. json database connection needs to be consistent.
Then you can enter the directory and execute it
## ✏Function Examples
### Online Demo
```
docker compose up - d
https://antsk.ai-dotnet.com/
```
To start AntSK
```
Default account: test
Some meanings of configuration files
Default password: test
Due to the low configuration of the cloud server, the local model cannot be run, so the system settings permissions have been closed. You can simply view the interface. If you want to use the local model, please download and use it on your own.
```
### Other Function Examples
[Video Demonstration](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
## ❓How to get started?
Here I am using Postgres as the data and vector storage because Semantic Kernel and Kernel Memory support it, but you can also use other options.
The model by default supports the local model of openai, azure openai, and llama. If you need to use other models, you can integrate them using one-api.
The Login configuration in the configuration file is the default login account and password.
The following configuration file needs to be configured
## 1⃣Using docker-compose
Provided the pg version **appsettings.json** and simplified version (Sqlite+disk) **docker-compose.simple.yml**
Download **docker-compose.yml** from the project root directory and place the configuration file **appsettings.json** in the same directory.
The pg image has already been prepared. You can modify the default username and password in docker-compose.yml, and then the database connection in your **appsettings.json** needs to be consistent.
Then you can execute the following command in the directory to start AntSK
```
docker-compose up -d
```
## 2⃣How to mount local models and model download directory in docker
```
# Non-host version, do not use local proxy
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # Local configuration file needs to be placed in the same directory
- D://model:/app/model
networks:
antsk:
```
Taking this as an example, it means mounting the local D://model folder of Windows into the container /app/model. If so, the model address in your appsettings.json should be configured as
```
model/xxx.gguf
```
## 3⃣Some meanings of configuration file
```
{
"DBConnection": {
"DbType": "Sqlite",
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"OpenAIOption": {
"EndPoint": "http://localhost:5000/llama/",
"Key": "NotNull",
"Model": "gpt4-turbo",
"EmbeddingModel": "text-embedding-ada-002"
},
"KernelMemory": {
"VectorDb": "Disk",
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"LLamaSharp": {
"RunType": "GPU",
"Chat": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
"Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf"
"RunType": "GPU",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
},
"Login": {
"User": "admin",
@@ -168,48 +129,86 @@ Some meanings of configuration files
}
}
}
```
```
//Supports multiple databases, including SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGaussian, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection DbType
//Connection string, corresponding strings need to be used according to different DB types
DBConnection ConnectionStrings
//You can use an online API that conforms to the OpenAI format (domestic models use one API adapter), or you can use AntSK's built-in llama API, with the IP and port being the AntSK startup address
OpenAIOption EndPoint
//Model key, if using a local model, it can default to Notnull. Chinese cannot be used here
OpenAIOption Key
//The type of vector storage supports Postgres Disk Memory, where Postgres requires the configuration of ConnectionString
KernelMemory VectorDb
//The running mode used by the local model is GUP CPU. If using an online API, you can freely use one
LLamaSharp RunType
//The model path of the local session model should pay attention to distinguishing between Linux and Windows drive letters
LLamaSharp Chat
//The model path of the local vector model should pay attention to distinguishing between Linux and Windows drive letters
LLamaSharp Embedding
//Default administrator account password
// Supports various databases, you can check SqlSugar, MySql, SqlServer, Sqlite, Oracle, PostgreSQL, Dm, Kdbndp, Oscar, MySqlConnector, Access, OpenGauss, QuestDB, HG, ClickHouse, GBase, Odbc, OceanBaseForOracle, TDengine, GaussDB, OceanBase, Tidb, Vastbase, PolarDB, Custom
DBConnection.DbType
// Connection string, need to use the corresponding string according to the different DB types
DBConnection.ConnectionStrings
//The type of vector storage, supporting Postgres, Disk, Memory, Qdrant, Redis, AzureAISearch
//Postgres and Redis require ConnectionString configuration
//The ConnectionString of Qdrant and AzureAISearch uses Endpoint | APIKey
KernelMemory.VectorDb
//Local model execution options: GPU and CPU. When using the online API, any option can be used.
LLamaSharp.RunType
//Local model path, used for quick selection of models under llama, as well as saving downloaded models.
LLamaSharp.FileDirectory
//Default admin account password
Login
//The number of threads for importing asynchronous processing can be higher when using online APIs. Local models suggest 1, otherwise memory overflow and crash may occur
BackgroundTaskBroker ImportKMSTask WorkerCount
//Import asynchronous processing thread count. A higher count can be used for online API, but for local models, 1 is recommended to avoid memory overflow issues.
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## ⚠Fixing Style Issues:
Run the following in AntSK/src/AntSK:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
Then navigate to AntSK/src/AntSK/bin/Release/net8.0/publish and run:
```
dotnet AntSK.dll
```
The styles should now be applied after starting.
To learn more or start using**AntSK**, you can follow my public account and join the exchange group.
I'm using CodeFirst mode for the database, so as long as the database connection is properly configured, the table structure will be created automatically.
## ✔Using llamafactory
```
1. First, ensure that Python and pip are installed in your environment. This step is not necessary if using an image, such as version v0.2.3.2, which already includes the complete Python environment.
2. Go to the model add page and select llamafactory.
3. Click "Initialize" to check whether the 'pip install' environment setup is complete.
4. Choose a model that you like.
5. Click "Start" to begin downloading the model from the tower. This may involve a somewhat lengthy wait.
6. After the model has finished downloading, enter http://localhost:8000/ in the request address. The default port is 8000.
7. Click "Save" and start chatting.
8. Many people ask about the difference between LLamaSharp and llamafactory. In fact, LLamaSharp is a .NET implementation of llama.cpp, but only supports local gguf models, while llamafactory supports a wider variety of models and uses Python implementation. The main difference lies here. Additionally, llamafactory has the ability to fine-tune models, which is an area we will focus on integrating in the future.
```
## 🤝 Contributing
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)
If you would like to contribute, feel free to create a [Pull Request](https://github.com/AIDotNet/AntSK/pulls), or give us [Bug Report](https://github.com/AIDotNet/AntSK/issues/new).
## 💕 Contributors
## Contact me
This project exists thanks to all the people who contribute.
If you have any questions or suggestions, please follow my public account through the following ways, and send a message to me. We also have an exchange group, which can send messages such as joining the group, and then I will bring you into the exchange group
<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>
![Official account](https://github.com/xuzeyu91/Avalonia-Assistant/blob/main/img/gzh.jpg)
## 🚨 Code of Conduct
This project has adopted the code of conduct defined by the Contributor Covenant to clarify expected behavior in our community.
For more information see the [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
To learn more or get started with **AntSK**, follow my official WeChat account and join the discussion group.
## ☎Contact Me
If you have any questions or suggestions, please contact me through my official WeChat account. We also have a discussion group where you can send a message to join, and then I will add you to the group.
![Official WeChat Account](https://github.com/AIDotNet/Avalonia-Assistant/blob/main/img/gzh.jpg)
---
We appreciate your interest in**AntSK**and look forward to working with you to create an intelligent future!
We appreciate your interest in **AntSK** and look forward to collaborating with you to create an intelligent future!

148
README.md
View File

@@ -1,26 +1,33 @@
中文|[English](https://github.com/xuzeyu91/AntSK/blob/main/README.en.md)
中文|[English](https://github.com/AIDotNet/AntSK/blob/main/README.en.md)
# AntSK
## 基于.Net8+AntBlazor+SemanticKernel 打造的AI知识库/智能体
## 使用.Net8+Blazor+SemanticKernel 打造的AI知识库/智能体
## 核心功能
## 核心功能
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库文档
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库问答
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **模型管理**:适配和管理集成不同厂商的不同模型
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能
## 应用场景
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **模型微调**规划中基于llamafactory进行模型微调
## ⛪应用场景
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
@@ -31,47 +38,37 @@ AntSK 适用于多种业务场景,例如:
- 教育与在线学习平台
- 其他有意思的AI App
## 功能示例
## ✏️功能示例
### 在线演示
```
https://antsk.ai-dotnet.com/
```
```
默认账号test
默认密码test
由于云服务器配置较低,无法运行本地模型,所以把系统设置权限关闭了,大家看看界面即可,要使用本地模型,请下载自行使用
```
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
首先需要创建知识库
![知识库](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%9F%A5%E8%AF%86%E5%BA%93.png)
[在线文档http://antsk.cn](http://antsk.cn)
在知识库里可以使用文档或者url进行导入
![知识库详情](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%9F%A5%E8%AF%86%E5%BA%93%E8%AF%A6%E6%83%85.png)
点击查看可以查看知识库的文档切片情况
![文档切片](https://github.com/xuzeyu91/AntSK/blob/main/images/%E6%96%87%E6%A1%A3%E5%88%87%E7%89%87.png)
然后我们需要创建应用,可以创建对话应用和知识库。
![应用](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%BA%94%E7%94%A8.png)
知识库应用需要选择已有的知识库,可以选多个
![应用配置](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%BA%94%E7%94%A8%E9%85%8D%E7%BD%AE.png)
然后再对话中可以对知识库的文档进行提问
![问答](https://github.com/xuzeyu91/AntSK/blob/main/images/%E9%97%AE%E7%AD%94.png)
另外我们也可以创建对话应用,可以在对应应用中配置提示词模板
![对话应用](https://github.com/xuzeyu91/AntSK/blob/main/images/%E7%AE%80%E5%8D%95%E5%AF%B9%E8%AF%9D.png)
下面来看看效果吧
![对话效果](https://github.com/xuzeyu91/AntSK/blob/main/images/%E5%AF%B9%E8%AF%9D%E6%95%88%E6%9E%9C.png)
## 如何开始?
## ❓如何开始?
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
模型默认支持openai,如果需要使用azure openai需要调整SK的依赖注入可以使用one-api进行集成。
模型默认支持openaiazure openai、讯飞星火、阿里云积、 和llama支持的gguf本地模型 以及llamafactory的本地模型,如果需要使用其他模型,可以使用one-api进行集成。
Login是默认的登账号和密码
配置文件中的Login配置是默认的登账号和密码
需要配置如下的配置文件
## 使用docker-compose
## 1使用docker-compose
提供了pg版本 **appsettings.json** 和 简化版本Sqlite+disk **docker-compose.simple.yml**
提供了pg版本 **appsettings.json** 和 简化版本(**Sqlite+disk** **docker-compose.simple.yml**
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
@@ -83,14 +80,14 @@ docker-compose up -d
```
来启动AntSK
## 如何在docker中挂载本地模型
## 2如何在docker中挂载本地模型,和模型下载的目录
```
# 非 host 版本, 不使用本机代理
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.1.5
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.2.3
ports:
- 5000:5000
networks:
@@ -111,7 +108,7 @@ networks:
model/xxx.gguf
```
## 配置文件的一些含义
## 3配置文件的一些含义
```
{
"DBConnection": {
@@ -125,8 +122,7 @@ model/xxx.gguf
},
"LLamaSharp": {
"RunType": "GPU",
"Chat": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
"Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf"
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
},
"Login": {
"User": "admin",
@@ -144,25 +140,25 @@ model/xxx.gguf
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
DBConnection.ConnectionStrings
//可以使用符合openai格式的在线API国产模型使用one-api转接 也可以使用AntSK自带的llama apiip和端口是AntSK启动地址
OpenAIOption.EndPoint
//模型秘钥如果使用本地模型可以默认NotNull 这里不能用中文
OpenAIOption.Key
//向量存储的类型,支持 Postgres Disk Memory 其中Postgres需要配置 ConnectionString
//向量存储的类型,支持 Postgres、Disk、Memory、Qdrant、Redis、AzureAISearch
//Postgres、Redis需要配置 ConnectionString
//Qdrant 和AzureAISearch 的 ConnectionString 使用 Endpoint|APIKey
KernelMemory.VectorDb
//本地模型使用的运行方式 GUP CPU ,如果用在线API 这个随意使用一个即可
LLamaSharp.RunType
//本地会话模型的模型路径 注意区分linux和windows盘符不同
LLamaSharp.Chat
//本地向量模型的模型路径 注意区分linux和windows盘符不同
LLamaSharp.Embedding
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
LLamaSharp.FileDirectory
//默认管理员账号密码
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```
## 找不到样式问题解决:
## ⚠️找不到样式问题解决:
AntSK/src/AntSK下执行:
```
dotnet clean
@@ -177,15 +173,49 @@ dotnet AntSK.dll
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的
## ✔使用llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如p0.2.4版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型
5、点击启动,这会开始从魔塔下载模型,你可能需要有一个较为漫长的等待
6、等待模型下载完毕后在请求地址输入 http://localhost:8000/ 这里默认是使用8000端口
7、点击保存然后就可以开始聊天了
8、很多人会问 LLamaSharp与llamafactory有什么区别其实这两者LLamaSharp是llama.cpp的 dotnet实现但是只支持本地gguf模型 而llamafactory 支持的模型种类更多但使用的是python的实现其主要差异在这里另外llamafactory具有模型微调的能力这也是我们下一步需要重点集成的部分。
```
## 🤝 贡献
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/AIDotNet/AntSK/pulls)

如果你想贡献,可以创建一个[拉取请求](https://github.com/AIDotNet/AntSK/pulls), 或给我们[错误报告](https://github.com/AIDotNet/AntSK/issues/new).


## 💕 贡献者
这个项目的存在要感谢所有的贡献者。

<a href="https://github.com/AIDotNet/AntSK/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AIDotNet/AntSK&max=1000&columns=15&anon=1" />
</a>

## 🚨 行为准则
该项目采用了贡献者公约定义的行为准则,以阐明我们社区的预期行为。有关更多信息,请参见 .NET Foundation 行为准则。 [.NET Foundation Code of Conduct](https://dotnetfoundation.org/code-of-conduct).
想了解更多信息或开始使用 **AntSK**,可以关注我的公众号以及加入交流群。
## 联系我
## ☎️联系我
如有任何问题或建议,请通过以下方式关注我的公众号,发消息与我联系,我们也有交流群,可以发送进群等消息,然后我会拉你进交流群
![公众号](https://github.com/xuzeyu91/Avalonia-Assistant/blob/main/img/gzh.jpg)
![公众号](https://github.com/AIDotNet/AntSK/blob/main/images/gzh.jpg)
---
我们对您在**AntSK**的兴趣表示感谢,并期待与您携手共创智能化的未来!
## 🌟 Star History
<a href="https://github.com/AIDotNet/AntSK/stargazers" target="_blank" style="display: block" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=AIDotNet/AntSK&type=Date" />
</picture>
</a>

View File

@@ -3,7 +3,9 @@ version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.1.6.2
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.4
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.2.4
ports:
- 5000:5000
networks:

View File

@@ -18,7 +18,9 @@ services:
- ./pg/data:/var/lib/postgresql/data
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.1.6.2
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.4
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.2.4
ports:
- 5000:5000
networks:

View File

@@ -0,0 +1,14 @@
{
"position": 3,
"label": "部署",
"collapsible": true,
"collapsed": false,
"className": "red",
"link": {
"type": "generated-index",
"title": "使用案例"
},
"customProps": {
"description": "提供快速使用AntSK的一些案例"
}
}

56
docs/deploy/settings.md Normal file
View File

@@ -0,0 +1,56 @@
---
sidebar_position: 2
---
# 配置文件的一些含义
```
{
"DBConnection": {
"DbType": "Sqlite",
"ConnectionStrings": "Data Source=AntSK.db;"
},
"KernelMemory": {
"VectorDb": "Disk",
"ConnectionString": "Host=;Port=;Database=antsk;Username=;Password=",
"TableNamePrefix": "km-"
},
"LLamaSharp": {
"RunType": "GPU",
"Chat": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
"Embedding": "D:\\Code\\AI\\AntBlazor\\model\\qwen1_5-1_8b-chat-q8_0.gguf",
"FileDirectory": "D:\\Code\\AI\\AntBlazor\\model\\"
},
"Login": {
"User": "admin",
"Password": "xuzeyu"
},
"BackgroundTaskBroker": {
"ImportKMSTask": {
"WorkerCount": 1
}
}
}
```
```
//支持多种数据库具体可以查看SqlSugarMySqlSqlServerSqliteOraclePostgreSQLDmKdbndpOscarMySqlConnectorAccessOpenGaussQuestDBHGClickHouseGBaseOdbcOceanBaseForOracleTDengineGaussDBOceanBaseTidbVastbasePolarDBCustom
DBConnection.DbType
//连接字符串需要根据不同DB类型用对应的字符串
DBConnection.ConnectionStrings
//向量存储的类型,支持 Postgres Disk Memory 其中Postgres需要配置 ConnectionString
KernelMemory.VectorDb
//本地模型使用的运行方式 GUP CPU ,如果用在线API 这个随意使用一个即可
LLamaSharp.RunType
//本地会话模型的模型路径 注意区分linux和windows盘符不同
LLamaSharp.Chat
//本地向量模型的模型路径 注意区分linux和windows盘符不同
LLamaSharp.Embedding
//本地模型路径用于在选择llama时可以快速选择目录下的模型以及保存下载的模型
LLamaSharp.FileDirectory
//默认管理员账号密码
Login
//导入异步处理的线程数使用在线API可以高一点本地模型建议1 否则容易内存溢出崩掉
BackgroundTaskBroker.ImportKMSTask.WorkerCount
```

57
docs/deploy/start.md Normal file
View File

@@ -0,0 +1,57 @@
---
sidebar_position: 1
---
# 如何开始?
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
模型默认支持openai、azure openai 和llama支持的gguf本地模型,如果需要使用其他模型可以使用one-api进行集成。
配置文件中的Login配置是默认的登陆账号和密码
需要配置如下的配置文件
## 使用docker-compose
提供了pg版本 **appsettings.json** 和 简化版本Sqlite+disk **docker-compose.simple.yml**
从项目根目录下载**docker-compose.yml**,然后把配置文件**appsettings.json**和它放在统一目录,
这里已经把pg的镜像做好了。在docker-compose.yml中可以修改默认账号密码然后你的**appsettings.json**的数据库连接需要保持一致。
然后你可以进入到目录后执行
```
docker-compose up -d
```
来启动AntSK
## 如何在docker中挂载本地模型和模型下载的目录
```
# 非 host 版本, 不使用本机代理
version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/AIDotNet/antsk:v0.1.5
ports:
- 5000:5000
networks:
- antsk
depends_on:
- antskpg
restart: always
environment:
- ASPNETCORE_URLS=http://*:5000
volumes:
- ./appsettings.json:/app/appsettings.json # 本地配置文件 需要放在同级目录
- D://model:/app/model
networks:
antsk:
```
以这个为示例意思是把windows本地D://model的文件夹挂载进 容器内/app/model 如果是这样你的appsettings.json中的模型地址应该配置为
```
model/xxx.gguf
```
DB我使用的是CodeFirst模式只要配置好数据库链接表结构是自动创建的

16
docs/deploy/style.md Normal file
View File

@@ -0,0 +1,16 @@
---
sidebar_position: 3
---
# 找不到样式问题解决
AntSK/src/AntSK下执行:
```
dotnet clean
dotnet build
dotnet publish "AntSK.csproj"
```
再去AntSK/src/AntSK/bin/Release/net8.0/publish下
```
dotnet AntSK.dll
```
然后启动就有样式了

View File

@@ -0,0 +1,14 @@
{
"position": 2,
"label": "快速开发",
"collapsible": true,
"collapsed": false,
"className": "red",
"link": {
"type": "generated-index",
"title": "快速开发"
},
"customProps": {
"description": "快速基于项目二次开发!"
}
}

View File

@@ -0,0 +1,14 @@
{
"position": 2,
"label": "介绍",
"collapsible": true,
"collapsed": false,
"className": "red",
"link": {
"type": "generated-index",
"title": "使用案例"
},
"customProps": {
"description": "提供快速使用AntSK的一些案例"
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 202 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

70
docs/introduce/readme.md Normal file
View File

@@ -0,0 +1,70 @@
---
sidebar_position: 1
---
# AntSK功能介绍
## 基于.Net8+AntBlazor+SemanticKernel 打造的AI知识库/智能体
## 核心功能
- **语义内核 (Semantic Kernel)**:采用领先的自然语言处理技术,准确理解、处理和响应复杂的语义查询,为用户提供精确的信息检索和推荐服务。
- **内存内核 (Kernel Memory)**具备持续学习和存储知识点的能力AntSK 拥有长期记忆功能,累积经验,提供更个性化的交互体验。
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库问答。
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
- **API插件系统**开放式API插件系统允许第三方开发者或服务商轻松将其服务集成到AntSK不断增强应用功能。
- **.Net插件系统**开放式dll插件系统允许第三方开发者或服务商轻松将其业务功能通过标准格式的代码生成dll后集成到AntSK不断增强应用功能。
- **联网搜索**AntSK实时获取最新信息确保用户接受到的资料总是最及时、最相关的。
- **模型管理**:适配和管理集成不同厂商的不同模型。并且支持**llama.cpp**所支持的gguf类型以及**llamafactory**所支持的模型离线运行
- **国产信创**AntSK支持国产模型和国产数据库可以在信创条件下运行
- **模型微调**规划中基于llamafactory进行模型微调
## 应用场景
AntSK 适用于多种业务场景,例如:
- 企业级知识管理系统
- 自动客服与聊天机器人
- 企业级搜索引擎
- 个性化推荐系统
- 智能辅助写作
- 教育与在线学习平台
- 其他有意思的AI App
## 功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
首先需要创建知识库
![知识库](./img/知识库.png)
在知识库里可以使用文档或者url进行导入
![知识库详情](./img/知识库详情.png)
点击查看可以查看知识库的文档切片情况
![文档切片](./img/文档切片.png)
然后我们需要创建应用,可以创建对话应用和知识库。
![应用](./img/应用.png)
知识库应用需要选择已有的知识库,可以选多个
![应用配置](./img/应用配置.png)
然后再对话中可以对知识库的文档进行提问
![问答](./img/问答.png)
另外我们也可以创建对话应用,可以在对应应用中配置提示词模板
![对话应用](./img/简单对话.png)
下面来看看效果吧
![对话效果](./img/对话效果.png)

BIN
images/gzh.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 180 KiB

View File

@@ -9,31 +9,42 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AntDesign.Charts" Version="0.5.1" />
<PackageReference Include="AntDesign.ProLayout" Version="0.17.3" />
<PackageReference Include="AntDesign.ProLayout" Version="0.18.1" />
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageReference Include="AutoMapper" Version="8.1.0" />
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
<PackageReference Include="MarkdownSharp" Version="2.0.5" />
<PackageReference Include="Markdig" Version="0.36.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.143" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.149" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
<PackageReference Include="RestSharp" Version="110.2.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.4.0" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.4.0" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.4.0-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.30.240227.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.30.240227.1" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.6.3" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.6.3" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Core" Version="1.6.3-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Qdrant" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="0.35.240321.1" />
<PackageReference Include="LLamaSharp" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.10.0" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.10.0" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.10.0" />
<PackageReference Include="LLamaSharp" Version="0.11.1" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.11.1" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.11.1" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.11.1" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.11.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AntSK.LLamaFactory\AntSK.LLamaFactory.csproj" />
<ProjectReference Include="..\AntSk.LLM\AntSK.LLM.csproj" />
<ProjectReference Include="..\MiddleWare\AntSK.BackgroundTask\AntSK.BackgroundTask.csproj" />
</ItemGroup>

View File

@@ -17,6 +17,27 @@
<param name="assemblies">程序集集合</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.CodeFirst(Microsoft.AspNetCore.Builder.WebApplication)">
<summary>
使用codefirst创建数据库表
</summary>
<param name="services"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.LoadFun(Microsoft.AspNetCore.Builder.WebApplication)">
<summary>
加载数据库的插件
</summary>
<param name="services"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Common.DependencyInjection.InitExtensions.AddAntSKSwagger(Microsoft.Extensions.DependencyInjection.IServiceCollection)">
<summary>
swagger 初始化
</summary>
<param name="serviceCollection"></param>
<returns></returns>
</member>
<member name="F:AntSK.Domain.Common.DependencyInjection.ServiceLifetime.Scoped">
<summary>
作用域
@@ -48,12 +69,47 @@
<param name="value"></param>
<returns></returns>
</member>
<member name="T:AntSK.Domain.Domain.Model.Enum.AIType">
<summary>
AI类型
</summary>
</member>
<member name="T:AntSK.Domain.Domain.Model.Enum.AIModelType">
<summary>
模型类型
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.MessageInfo.IsSend">
<summary>
发送是true 接收是false
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.PageList`1.PageIndex">
<summary>
当前页从1开始
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.PageList`1.PageSize">
<summary>
每页数量
</summary>
</member>
<member name="P:AntSK.Domain.Domain.Model.PageList`1.TotalCount">
<summary>
总数
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Other.EmbeddingConfig.LoadModel(System.String,System.String)">
<summary>
模型写死
</summary>
</member>
<member name="F:AntSK.Domain.Domain.Other.LLamaConfig.dicLLamaWeights">
<summary>
避免模型重复加载,本地缓存
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Service.ChatService.SendChatByAppAsync(AntSK.Domain.Repositories.Apps,System.String,System.String)">
<member name="M:AntSK.Domain.Domain.Service.ChatService.SendChatByAppAsync(AntSK.Domain.Repositories.Apps,System.String,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)">
<summary>
发送消息
</summary>
@@ -62,6 +118,11 @@
<param name="history"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Domain.Service.FunctionService.SearchMarkedMethods">
<summary>
查询程序集中的方法委托后续利用Source Generators生成
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Service.KernelService.GetKernelByApp(AntSK.Domain.Repositories.Apps)">
<summary>
获取kernel实例依赖注入不好按每个用户去Import不同的插件所以每次new一个新的kernel
@@ -77,6 +138,20 @@
<param name="app"></param>
<param name="_kernel"></param>
</member>
<member name="M:AntSK.Domain.Domain.Service.KernelService.ImportApiFunction(AntSK.Domain.Repositories.Apps,System.Collections.Generic.List{Microsoft.SemanticKernel.KernelFunction})">
<summary>
导入API插件
</summary>
<param name="app"></param>
<param name="functions"></param>
</member>
<member name="M:AntSK.Domain.Domain.Service.KernelService.ImportNativeFunction(AntSK.Domain.Repositories.Apps,System.Collections.Generic.List{Microsoft.SemanticKernel.KernelFunction})">
<summary>
导入原生插件
</summary>
<param name="app"></param>
<param name="functions"></param>
</member>
<member name="M:AntSK.Domain.Domain.Service.KernelService.RegisterPluginsWithKernel(Microsoft.SemanticKernel.Kernel)">
<summary>
注册默认插件
@@ -92,61 +167,6 @@
<param name="history"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Map.MapperExtend.ToDTOList``1(System.Object)">
<summary>
Entity集合转DTO集合
</summary>
<typeparam name="T"></typeparam>
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Map.MapperExtend.ToDTO``1(System.Object)">
<summary>
Entity转DTO
</summary>
<typeparam name="T"></typeparam>
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Map.MapperExtend.MapTo``1(System.Object,``0)">
<summary>
给已有对象map,适合update场景如需过滤空值需要在AutoMapProfile 设置
</summary>
<typeparam name="T"></typeparam>
<param name="self"></param>
<param name="result"></param>
<returns></returns>
</member>
<member name="T:AntSK.Domain.Model.Enum.AIType">
<summary>
AI类型
</summary>
</member>
<member name="T:AntSK.Domain.Model.Enum.AIModelType">
<summary>
模型类型
</summary>
</member>
<member name="P:AntSK.Domain.Model.MessageInfo.IsSend">
<summary>
发送是true 接收是false
</summary>
</member>
<member name="P:AntSK.Domain.Model.PageList`1.PageIndex">
<summary>
当前页从1开始
</summary>
</member>
<member name="P:AntSK.Domain.Model.PageList`1.PageSize">
<summary>
每页数量
</summary>
</member>
<member name="P:AntSK.Domain.Model.PageList`1.TotalCount">
<summary>
总数
</summary>
</member>
<member name="P:AntSK.Domain.Options.DBConnectionOption.DbType">
<summary>
sqlite连接字符串
@@ -237,6 +257,11 @@
会话模型ID
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.EmbeddingModelID">
<summary>
Embedding 模型Id
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.Temperature">
<summary>
温度
@@ -252,6 +277,11 @@
插件列表
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.NativeFunctionList">
<summary>
本地函数列表
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.KmsIdList">
<summary>
知识库ID列表
@@ -262,6 +292,11 @@
API调用秘钥
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Funs.Path">
<summary>
接口描述
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.KmsDetails.FileName">
<summary>
文件名称
@@ -734,11 +769,151 @@
<param name="stream"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.ConvertUtils.ToQueryString(System.Collections.Generic.Dictionary{System.String,System.String})">
<summary>
json参数转化querystring参数
</summary>
<param name="parameters"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.ConvertUtils.ComparisonIgnoreCase(System.String,System.String)">
<summary>
忽略大小写匹配
</summary>
<param name="s"></param>
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.RepoFiles.SamplePluginsPath">
<summary>
Scan the local folders from the repo, looking for "samples/plugins" folder.
</summary>
<returns>The full path to samples/plugins</returns>
</member>
<member name="T:AntSK.Domain.Utils.XmlCommentHelper">
<summary>
注释辅助类
</summary>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.LoadAll">
<summary>
从当前dll文件中加载所有的xml文件
</summary>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.LoadXml(System.String[])">
<summary>
从xml中加载
</summary>
<param name="xmls"></param>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.Load(System.String[])">
<summary>
从文件中加载
</summary>
<param name="xmlFiles"></param>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.Load(System.IO.Stream[])">
<summary>
从流中加载
</summary>
<param name="streams"></param>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetTypeComment(System.Type,System.String,System.Boolean)">
<summary>
读取类型中的注释
</summary>
<param name="type">类型</param>
<param name="xPath">注释路径</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetFieldOrPropertyComment(System.Reflection.MemberInfo,System.String,System.Boolean)">
<summary>
读取字段或者属性的注释
</summary>
<param name="fieldOrPropertyInfo">字段或者属性</param>
<param name="xPath">注释路径</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMethodComment(System.Reflection.MethodInfo,System.String,System.Boolean)">
<summary>
读取方法中的注释
</summary>
<param name="methodInfo">方法</param>
<param name="xPath">注释路径</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMethodReturnComment(System.Reflection.MethodInfo,System.Boolean)">
<summary>
读取方法中的返回值注释
</summary>
<param name="methodInfo">方法</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetParameterComment(System.Reflection.ParameterInfo,System.Boolean)">
<summary>
读取参数的注释
</summary>
<param name="parameterInfo">参数</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetParameterComments(System.Reflection.MethodInfo,System.Boolean)">
<summary>
读取方法的所有参数的注释
</summary>
<param name="methodInfo">方法</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetComment(System.String,System.String,System.Boolean)">
<summary>
读取指定名称节点的注释
</summary>
<param name="name">节点名称</param>
<param name="xPath">注释路径</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetSummary(System.String,System.Boolean)">
<summary>
读取指定节点的summary注释
</summary>
<param name="name">节点名称</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetExample(System.String,System.Boolean)">
<summary>
读取指定节点的example注释
</summary>
<param name="name">节点名称</param>
<param name="humanize">可读性优化(比如去掉xml标记)</param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForMethod(System.Reflection.MethodInfo)">
<summary>
获取方法的节点名称
</summary>
<param name="method"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForType(System.Type)">
<summary>
获取类型的节点名称
</summary>
<param name="type"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.XmlCommentHelper.GetMemberNameForFieldOrProperty(System.Reflection.MemberInfo)">
<summary>
获取字段或者属性的节点名称
</summary>
<param name="fieldOrPropertyInfo"></param>
<returns></returns>
</member>
</members>
</doc>

View File

@@ -0,0 +1,8 @@
namespace AntSK.Domain.Common
{
[AttributeUsage(AttributeTargets.Method)]
public class AntSkFunctionAttribute : Attribute
{
// 自定义的ActionAttribute
}
}

View File

@@ -0,0 +1,164 @@
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Service;
using AntSK.Domain.Repositories;
using DocumentFormat.OpenXml.Office2016.Drawing.ChartDrawing;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.OpenApi.Models;
using SqlSugar;
using Swashbuckle.AspNetCore.SwaggerGen;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Common.DependencyInjection
{
public static class InitExtensions
{
/// <summary>
/// 使用codefirst创建数据库表
/// </summary>
/// <param name="services"></param>
/// <returns></returns>
public static WebApplication CodeFirst(this WebApplication app)
{
using (var scope = app.Services.CreateScope())
{
// 获取仓储服务
var _repository = scope.ServiceProvider.GetRequiredService<IApps_Repositories>();
// 创建数据库(如果不存在)
_repository.GetDB().DbMaintenance.CreateDatabase();
// 获取当前应用程序域中所有程序集
var assemblies = AppDomain.CurrentDomain.GetAssemblies();
// 在所有程序集中查找具有[SugarTable]特性的类
foreach (var assembly in assemblies)
{
// 获取该程序集中所有具有SugarTable特性的类型
var entityTypes = assembly.GetTypes()
.Where(type => TypeIsEntity(type));
// 为每个找到的类型初始化数据库表
foreach (var type in entityTypes)
{
_repository.GetDB().CodeFirst.InitTables(type);
}
}
}
return app;
}
public static WebApplication InitDbData(this WebApplication app)
{
using (var scope = app.Services.CreateScope())
{
// 初始化字典
var _dic_Repository = scope.ServiceProvider.GetRequiredService<IDics_Repositories>();
var llamafactoryStart = _dic_Repository.GetFirst(p => p.Type == LLamaFactoryConstantcs.LLamaFactorDic && p.Key == LLamaFactoryConstantcs.IsStartKey);
if (llamafactoryStart==null)
{
llamafactoryStart = new Dics();
llamafactoryStart.Id=Guid.NewGuid().ToString();
llamafactoryStart.Type = LLamaFactoryConstantcs.LLamaFactorDic;
llamafactoryStart.Key = LLamaFactoryConstantcs.IsStartKey;
llamafactoryStart.Value = "false";
_dic_Repository.Insert(llamafactoryStart);
}
}
return app;
}
/// <summary>
/// 加载数据库的插件
/// </summary>
/// <param name="services"></param>
/// <returns></returns>
public static WebApplication LoadFun(this WebApplication app)
{
try
{
using (var scope = app.Services.CreateScope())
{
//codefirst 创建表
var funRep = scope.ServiceProvider.GetRequiredService<IFuns_Repositories>();
var functionService = scope.ServiceProvider.GetRequiredService<FunctionService>();
var funs = funRep.GetList();
foreach (var fun in funs)
{
functionService.FuncLoad(fun.Path);
}
}
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
}
return app;
}
private static bool TypeIsEntity(Type type)
{
// 检查类型是否具有SugarTable特性
return type.GetCustomAttributes(typeof(SugarTable), inherit: false).Length > 0;
}
/// <summary>
/// swagger 初始化
/// </summary>
/// <param name="serviceCollection"></param>
/// <returns></returns>
public static IServiceCollection AddAntSKSwagger(this IServiceCollection serviceCollection)
{
serviceCollection.AddSwaggerGen(c =>
{
c.SwaggerDoc("v1", new() { Title = "AntSK.Api", Version = "v1" });
//添加Api层注释true表示显示控制器注释
var xmlFile = $"{Assembly.GetExecutingAssembly().GetName().Name}.xml";
var xmlPath = Path.Combine(AppContext.BaseDirectory, xmlFile);
c.IncludeXmlComments(xmlPath, true);
//添加Domain层注释true表示显示控制器注释
var xmlFile1 = $"{Assembly.GetExecutingAssembly().GetName().Name.Replace("Api", "Domain")}.xml";
var xmlPath1 = Path.Combine(AppContext.BaseDirectory, xmlFile1);
c.IncludeXmlComments(xmlPath1, true);
c.DocInclusionPredicate((docName, apiDes) =>
{
if (!apiDes.TryGetMethodInfo(out MethodInfo method))
return false;
var version = method.DeclaringType.GetCustomAttributes(true).OfType<ApiExplorerSettingsAttribute>().Select(m => m.GroupName);
if (docName == "v1" && !version.Any())
return true;
var actionVersion = method.GetCustomAttributes(true).OfType<ApiExplorerSettingsAttribute>().Select(m => m.GroupName);
if (actionVersion.Any())
return actionVersion.Any(v => v == docName);
return version.Any(v => v == docName);
});
c.AddSecurityDefinition("Bearer", new OpenApiSecurityScheme()
{
Description = "Directly enter bearer {token} in the box below (note that there is a space between bearer and token)",
Name = "Authorization",
In = ParameterLocation.Header,
Type = SecuritySchemeType.ApiKey,
});
c.AddSecurityRequirement(new OpenApiSecurityRequirement
{
{
new OpenApiSecurityScheme
{
Reference = new OpenApiReference()
{
Id = "Bearer",
Type = ReferenceType.SecurityScheme
}
}, Array.Empty<string>()
}
});
});
return serviceCollection;
}
}
}

View File

@@ -0,0 +1,21 @@
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Common.Embedding
{
public static class BuilderBgeExtensions
{
public static IKernelMemoryBuilder WithBgeTextEmbeddingGeneration(this IKernelMemoryBuilder builder, HuggingfaceTextEmbeddingGenerator textEmbeddingGenerator)
{
builder.AddSingleton((ITextEmbeddingGenerator)textEmbeddingGenerator);
builder.AddIngestionEmbeddingGenerator(textEmbeddingGenerator);
return builder;
}
}
}

View File

@@ -0,0 +1,56 @@
using LLama.Common;
using LLama;
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AntSK.Domain.Domain.Other;
namespace AntSK.Domain.Common.Embedding
{
public class HuggingfaceTextEmbeddingGenerator : ITextEmbeddingGenerator, ITextTokenizer, IDisposable
{
public int MaxTokens => 1024;
public int MaxTokenTotal => 1024;
private readonly dynamic _embedder;
public HuggingfaceTextEmbeddingGenerator(string pyDllPath,string modelName)
{
_embedder = EmbeddingConfig.LoadModel(pyDllPath, modelName);
}
public void Dispose()
{
EmbeddingConfig.Dispose();
}
//public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingAsync(IList<string> data, CancellationToken cancellationToken = default)
//{
// IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();
// foreach (var d in data)
// {
// var embeddings = await EmbeddingConfig.GetEmbedding(d);
// results.Add(new ReadOnlyMemory<float>(embeddings));
// }
// return results;
//}
public async Task<Microsoft.KernelMemory.Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = await EmbeddingConfig.GetEmbedding(text);
return new Microsoft.KernelMemory.Embedding(embeddings);
}
public int CountTokens(string text)
{
return EmbeddingConfig.TokenCount(text);
}
}
}

View File

@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Common.LLamaFactory
{
public class ProcessWrapper
{
private Process process;
public static bool isProcessComplete = false;
public void StartProcess(string arguments, string workingDirectory)
{
process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "python",
Arguments = arguments,
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError = true,
CreateNoWindow = true,
WorkingDirectory = workingDirectory
}
};
using (Process start = Process.Start(process.StartInfo))
{
using (StreamReader reader = start.StandardOutput)
{
string result = reader.ReadToEnd();
if (result != null)
{
if (result.Contains(":8000"))
{
isProcessComplete = true;
}
}
Console.WriteLine(result);
}
start.WaitForExit();
}
}
public string WaitForProcessExit()
{
process.WaitForExit();
return process.StandardOutput.ReadToEnd();
}
public void KillProcess()
{
try
{
if (!process.HasExited)
{
process.Kill();
}
}
catch (InvalidOperationException)
{
// Process already exited.
}
}
}
}

View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Dto
{
public class RelevantSource
{
public string SourceName { get; set; }
public string Text { get; set; }
public float Relevance { get; set; }
}
}

View File

@@ -1,8 +1,11 @@
using AntSK.Domain.Domain.Dto;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Repositories;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@@ -11,8 +14,10 @@ namespace AntSK.Domain.Domain.Interface
{
public interface IChatService
{
IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, string history);
IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, ChatHistory history);
IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, List<RelevantSource> relevantSources = null);
IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null);
Task<string> SendImgByAppAsync(Apps app, string questions);
Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList);
}
}
}

View File

@@ -1,4 +1,4 @@
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
namespace AntSK.Domain.Domain.Interface
{

View File

@@ -1,11 +1,24 @@
using AntSK.Domain.Domain.Dto;
using AntDesign;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Repositories;
using Microsoft.KernelMemory;
namespace AntSK.Domain.Domain.Interface
{
public interface IKMService
{
MemoryServerless GetMemory(Apps app);
MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null);
Task<List<KMFile>> GetDocumentByFileID(string kmsid, string fileid);
Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId);
Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg);
List<UploadFileItem> FileList { get; }
bool BeforeUpload(UploadFileItem file);
void OnSingleCompleted(UploadInfo fileinfo);
}
}
}

View File

@@ -0,0 +1,21 @@
using AntSK.LLamaFactory.Model;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static AntSK.Domain.Domain.Service.LLamaFactoryService;
namespace AntSK.Domain.Domain.Interface
{
public interface ILLamaFactoryService
{
public event LogMessageHandler LogMessageReceived;
Task PipInstall();
Task StartLLamaFactory(string modelName, string templateName);
void KillProcess();
List<LLamaModel> GetLLamaFactoryModels();
}
}

View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.Constant
{
public class KmsConstantcs
{
public const string KmsIdTag = "kmsid";
public const string KmsIndex = "kms";
public const string KmsSearchNull="知识库未搜索到相关内容";
}
}

View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.Constant
{
public class LLamaFactoryConstantcs
{
public const string LLamaFactorDic = "llamafactory";
public const string IsStartKey = "isstart";
}
}

View File

@@ -1,15 +1,14 @@
namespace AntSK.Domain.Domain.Dto
namespace AntSK.Domain.Domain.Model.Dto
{
public class KMFile
{
public string DocumentId { get; set; }
public string Text { get; set; }
public string Url { get; set; }
public string? Url { get; set; }
public string LastUpdate { get; set; }
public string Schema { get; set; }
public string File { get; set; }
}
}
}

View File

@@ -1,4 +1,4 @@
namespace AntSK.Domain.Domain.Dto
namespace AntSK.Domain.Domain.Model.Dto.OpenAPI
{
public class OpenAIModel
{

View File

@@ -1,6 +1,6 @@
using Newtonsoft.Json;
namespace AntSK.Domain.Domain.Dto
namespace AntSK.Domain.Domain.Model.Dto.OpenAPI
{
public class OpenAIResult
{

View File

@@ -0,0 +1,16 @@

namespace AntSK.Domain.Domain.Model.Dto
{
public class RelevantSource
{
public string SourceName { get; set; }
public string Text { get; set; }
public float Relevance { get; set; }
public override string ToString()
{
return $"[file:{SourceName};Relevance:{(Relevance * 100):F2}%]:{Text}";
}
}
}

View File

@@ -0,0 +1,45 @@
using System.ComponentModel.DataAnnotations;
namespace AntSK.Domain.Domain.Model.Enum
{
/// <summary>
/// AI类型
/// </summary>
public enum AIType
{
[Display(Name = "Open AI")]
OpenAI = 1,
[Display(Name = "Azure Open AI")]
AzureOpenAI = 2,
[Display(Name = "LLama本地模型")]
LLamaSharp = 3,
[Display(Name = "星火大模型")]
SparkDesk = 4,
[Display(Name = "灵积大模型")]
DashScope = 5,
[Display(Name = "LLamaFactory")]
LLamaFactory = 6,
[Display(Name = "Bge Embedding")]
BgeEmbedding = 7,
[Display(Name = "StableDiffusion")]
StableDiffusion = 8,
[Display(Name = "模拟输出")]
Mock = 100,
}
/// <summary>
/// 模型类型
/// </summary>
public enum AIModelType
{
Chat = 1,
Embedding = 2,
Image=3,
}
}

View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.Enum
{
public enum AppType
{
chat = 1,
kms = 2,
img=3
}
}

View File

@@ -1,9 +1,8 @@
namespace AntSK.Domain.Model
namespace AntSK.Domain.Domain.Model.Enum
{
public enum HttpMethodType
{
Get = 1,
Post = 2,
}
}
}

View File

@@ -1,4 +1,4 @@
namespace AntSK.Domain.Model.Enum
namespace AntSK.Domain.Domain.Model.Enum
{
public enum ImportKmsStatus
{

View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.Fun
{
public class FunDto
{
public string Name { get; set; }
public string Description { get; set; }
public FunType FunType { get; set; }
}
public enum FunType
{
System=1,
Import=2
}
}

View File

@@ -1,6 +1,6 @@
using AntSK.Domain.Repositories;
namespace AntSK.Domain.Model
namespace AntSK.Domain.Domain.Model
{
public class ImportKMSTaskDTO
{

View File

@@ -1,4 +1,4 @@
namespace AntSK.Domain.Model
namespace AntSK.Domain.Domain.Model
{
public class MessageInfo
{
@@ -10,7 +10,11 @@
/// 发送是true 接收是false
/// </summary>
public bool IsSend { get; set; } = false;
public DateTime CreateTime { get; set; }
public string? FilePath { get; set; }
public string? FileName { get; set; }
}
}
}

View File

@@ -1,4 +1,4 @@
namespace AntSK.Domain.Model
namespace AntSK.Domain.Domain.Model
{
public class PageList<T>
{

View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.hfmirror
{
public class HfModel
{
public List<HfModels> models { get; set; }
public int numItemsPerPage { get; set; }
public int numTotalItems { get; set; }
public int pageIndex { get; set; }
}
public class HfModels
{
public string Author { get; set; }
public HfAuthorData AuthorData { get; set; }
public int Downloads { get; set; }
public object Gated { get; set; }
public string Id { get; set; }
public DateTime LastModified { get; set; }
public int Likes { get; set; }
public string PipelineTag { get; set; }
public bool Private { get; set; }
public string RepoType { get; set; }
public bool IsLikedByUser { get; set; }
}
public class HfAuthorData
{
public string AvatarUrl { get; set; }
public string Fullname { get; set; }
public string Name { get; set; }
public string Type { get; set; }
public bool IsHf { get; set; }
public bool IsEnterprise { get; set; }
}
}

View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Model.hfmirror
{
public class HfModelDetail
{
public string Name { get; set; }
public string Size { get; set; }
public string Path { get; set; }
public string Time { get; set; }
}
}

View File

@@ -1,6 +1,6 @@
using AntSK.BackgroundTask;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
using Microsoft.Extensions.DependencyInjection;
namespace AntSK.Domain.Domain.Other

View File

@@ -0,0 +1,88 @@
using Python.Runtime;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Python.Runtime.Py;
namespace AntSK.Domain.Domain.Other
{
public static class EmbeddingConfig
{
public static dynamic model { get; set; }
static object lockobj = new object();
/// <summary>
/// 模型写死
/// </summary>
public static dynamic LoadModel(string pythondllPath, string modelName)
{
lock (lockobj)
{
if (model == null)
{
//Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
Runtime.PythonDLL = pythondllPath;
PythonEngine.Initialize();
PythonEngine.BeginAllowThreads();
try
{
using (Py.GIL())// 初始化Python环境的Global Interpreter Lock)
{
dynamic modelscope = Py.Import("modelscope");
//dynamic model_dir = modelscope.snapshot_download("AI-ModelScope/bge-large-zh-v1.5", revision: "master");
dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master");
dynamic HuggingFaceBgeEmbeddingstemp = Py.Import("langchain.embeddings");
dynamic HuggingFaceBgeEmbeddings = HuggingFaceBgeEmbeddingstemp.HuggingFaceBgeEmbeddings;
string model_name = model_dir;
dynamic model_kwargs = new PyDict();
model_kwargs["device"] = new PyString("cpu");
dynamic hugginmodel = HuggingFaceBgeEmbeddings(
model_name: model_dir,
model_kwargs: model_kwargs
);
model = hugginmodel;
return hugginmodel;
}
}
catch(Exception ex)
{
throw ex;
}
}
else
return model;
}
}
public static Task<float[]> GetEmbedding(string queryStr)
{
using (Py.GIL())
{
PyObject queryResult = model.embed_query(queryStr);
var floatList = queryResult.As<float[]>();
return Task.FromResult(floatList); ;
}
}
public static int TokenCount(string queryStr)
{
using (Py.GIL())
{
PyObject queryResult = model.client.tokenize(queryStr);
int len = (int)(queryResult.Length());
return len;
}
}
public static void Dispose()
{
Console.WriteLine("python dispose");
}
}
}

View File

@@ -3,24 +3,30 @@ using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Repositories;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AntSK.Domain.Utils;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Model.Constant;
using DocumentFormat.OpenXml.Drawing;
using System.Reflection.Metadata;
using Microsoft.KernelMemory;
using AntSK.Domain.Model;
using MarkdownSharp;
using AntSK.Domain.Domain.Dto;
using System.Collections.Generic;
using Markdig;
using ChatHistory = Microsoft.SemanticKernel.ChatCompletion.ChatHistory;
using Microsoft.SemanticKernel.Plugins.Core;
using Azure.Core;
using AntSK.Domain.Domain.Model;
using AntSK.LLM.StableDiffusion;
using System.Drawing;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(IChatService), ServiceLifetime.Scoped)]
public class ChatService(
IKernelService _kernelService,
IKMService _kMService ,
IKmsDetails_Repositories _kmsDetails_Repositories
IKMService _kMService,
IKmsDetails_Repositories _kmsDetails_Repositories,
IAIModels_Repositories _aIModels_Repositories
) : IChatService
{
/// <summary>
@@ -30,76 +36,166 @@ namespace AntSK.Domain.Domain.Service
/// <param name="questions"></param>
/// <param name="history"></param>
/// <returns></returns>
public async IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, string history)
public async IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, ChatHistory history)
{
if (string.IsNullOrEmpty(app.Prompt) || !app.Prompt.Contains("{{$input}}"))
{
//如果模板为空,给默认提示词
app.Prompt = app.Prompt.ConvertToString() + "{{$input}}";
}
KernelArguments args =new KernelArguments();
if (history.Count > 10)
{
app.Prompt = @"${{ConversationSummaryPlugin.SummarizeConversation $history}}" + app.Prompt;
args = new() {
{ "history", string.Join("\n", history.Select(x => x.Role + ": " + x.Content)) },
{ "input", questions }
};
}
else
{
args=new()
{
{ "input", $"{string.Join("\n", history.Select(x => x.Role + ": " + x.Content))}{Environment.NewLine} user:{questions}" }
};
}
var _kernel = _kernelService.GetKernelByApp(app);
var temperature = app.Temperature / 100;//存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
if (!string.IsNullOrEmpty(app.ApiFunctionList))
if (!string.IsNullOrEmpty(app.ApiFunctionList) || !string.IsNullOrEmpty(app.NativeFunctionList))//这里还需要加上本地插件的
{
_kernelService.ImportFunctionsByApp(app, _kernel);
settings.ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions;
}
var func = _kernel.CreateFunctionFromPrompt(app.Prompt, settings);
var chatResult = _kernel.InvokeStreamingAsync(function: func, arguments: new KernelArguments() { ["input"] = $"{history}{Environment.NewLine} user:{questions}" });
var chatResult = _kernel.InvokeStreamingAsync(function: func,
arguments: args);
await foreach (var content in chatResult)
{
yield return content;
}
}
public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, string history, List<RelevantSource> relevantSources=null)
public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null)
{
var relevantSourceList = await _kMService.GetRelevantSourceList(app.KmsIdList, questions);
var _kernel = _kernelService.GetKernelByApp(app);
//知识库问答
var filters = new List<MemoryFilter>();
var kmsidList = app.KmsIdList.Split(",");
//只取第一个知识库的配置
var _memory = _kMService.GetMemoryByKMS(kmsidList.FirstOrDefault());
foreach (var kmsid in kmsidList)
if (!string.IsNullOrWhiteSpace(filePath))
{
filters.Add(new MemoryFilter().ByTag("kmsid", kmsid));
}
var xlresult = await _memory.SearchAsync(questions, index: "kms", filters: filters);
string dataMsg = "";
if (xlresult != null)
{
foreach (var item in xlresult.Results)
var memory = _kMService.GetMemory(app);
var fileId = Guid.NewGuid().ToString();
var result = await memory.ImportDocumentAsync(new Microsoft.KernelMemory.Document(fileId).AddFile(filePath)
.AddTag(KmsConstantcs.KmsIdTag, app.Id)
, index: KmsConstantcs.KmsIndex);
var filters = new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, app.Id);
var searchResult = await memory.SearchAsync(questions, index: KmsConstantcs.KmsIndex, filters: [filters]);
relevantSourceList.AddRange(searchResult.Results.SelectMany(item => item.Partitions.Select(part => new RelevantSource()
{
foreach (var part in item.Partitions)
{
dataMsg += $"[file:{item.SourceName};Relevance:{(part.Relevance * 100).ToString("F2")}%]:{part.Text}{Environment.NewLine}";
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Relevance = part.Relevance
})));
}
if (relevantSources.IsNotNull())
{
var markdown = new Markdown();
string sourceName = item.SourceName;
var fileDetail = _kmsDetails_Repositories.GetFirst(p => p.FileGuidName == item.SourceName);
if (fileDetail.IsNotNull())
{
sourceName = fileDetail.FileName;
}
relevantSources.Add(new RelevantSource() { SourceName = sourceName, Text = markdown.Transform(part.Text), Relevance = part.Relevance });
}
}
var dataMsg = new StringBuilder();
if (relevantSourceList.Any())
{
relevantSources?.AddRange(relevantSourceList);
foreach (var item in relevantSourceList)
{
dataMsg.AppendLine(item.ToString());
}
KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask");
var chatResult = _kernel.InvokeStreamingAsync(function: jsonFun,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = history, ["questions"] = questions });
MessageInfo info = null;
var markdown1 = new Markdown();
KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var chatResult = _kernel.InvokeStreamingAsync(function: jsonFun,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["questions"] = questions });
await foreach (var content in chatResult)
{
yield return content;
}
}
}
else
{
yield return new StreamingTextContent(KmsConstantcs.KmsSearchNull);
}
}
public async Task<string> SendImgByAppAsync(Apps app, string questions)
{
var imageModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ImageModelID);
KernelArguments args = new() {
{ "input", questions }
};
var _kernel = _kernelService.GetKernelByApp(app);
var temperature = app.Temperature / 100; //存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt("你是一个StableDiffusion提示词助手,需要将用户问题转化为StableDiffusion的英文提示词并返回,请注意只返回提示词不要有其他多余内容,用户的问题是:{{$input}}", settings);
var chatResult = await _kernel.InvokeAsync(function: func, arguments: args);
if (chatResult.IsNotNull())
{
string prompt = chatResult.GetValue<string>();
if (!SDHelper.IsInitialized)
{
Structs.ModelParams modelParams = new Structs.ModelParams
{
ModelPath = imageModel.ModelName,
RngType = Structs.RngType.CUDA_RNG,
//VaePath = vaePath,
//KeepVaeOnCpu = keepVaeOnCpu,
//VaeTiling = vaeTiling,
//LoraModelDir = loraModelDir,
};
bool result = SDHelper.Initialize(modelParams);
}
Structs.TextToImageParams textToImageParams = new Structs.TextToImageParams
{
Prompt = prompt,
NegativePrompt = "2d, 3d, cartoon, paintings",
SampleMethod = (Structs.SampleMethod)Enum.Parse(typeof(Structs.SampleMethod), "EULER_A"),
Width = 256,
Height = 256,
NormalizeInput = true,
ClipSkip = -1,
CfgScale = 7,
SampleSteps = 20,
Seed = -1,
};
Bitmap[] outputImages = SDHelper.TextToImage(textToImageParams);
var base64 = ImageUtils.BitmapToBase64(outputImages[0]);
return base64;
}
else
{
return "";
}
}
public async Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList)
{
ChatHistory history = new ChatHistory();
if (MessageList.Count > 1)
{
foreach (var item in MessageList)
{
if (item.IsSend)
{
history.AddUserMessage(item.Context);
}
else
{
history.AddAssistantMessage(item.Context);
}
}
}
return history;
}
}
}
}

View File

@@ -0,0 +1,122 @@
using System.ComponentModel;
using System.Reflection;
using System.Runtime.Loader;
using System.Xml;
using AntSK.Domain.Common;
using AntSK.Domain.Utils;
using System.Text.RegularExpressions;
using Microsoft.SemanticKernel;
using HtmlAgilityPack;
using System.Collections.Generic;
namespace AntSK.Domain.Domain.Service
{
public class FunctionService
{
private readonly Dictionary<string, MethodInfo> _methodCache;
private readonly Dictionary<string, (string Description, (Type ParameterType, string Description) ReturnType, (string ParameterName, Type ParameterType, string Description)[] Parameters)> _methodInfos;
private readonly IServiceProvider _serviceProvider;
private Assembly[] _assemblies;
private readonly AssemblyLoadContext loadContext;
public FunctionService(IServiceProvider serviceProvider, Assembly[] assemblies)
{
_methodCache = [];
_methodInfos = [];
_serviceProvider = serviceProvider;
_assemblies = assemblies;
loadContext = new AssemblyLoadContext("AntSKLoadContext", true);
}
public Dictionary<string, MethodInfo> Functions => _methodCache;
public Dictionary<string, (string Description, (Type ParameterType, string Description) ReturnType, (string ParameterName, Type ParameterType, string Description)[] Parameters)> MethodInfos => _methodInfos;
/// <summary>
/// 查询程序集中的方法委托后续利用Source Generators生成
/// </summary>
public void SearchMarkedMethods()
{
var markedMethods = new List<MethodInfo>();
_methodCache.Clear();
_methodInfos.Clear();
foreach (var assembly in _assemblies)
{
// 从缓存中获取标记了ActionAttribute的方法
foreach (var type in assembly.GetTypes())
{
markedMethods.AddRange(type.GetMethods().Where(m =>
{
DescriptionAttribute da = (DescriptionAttribute)m.GetCustomAttributes(typeof(DescriptionAttribute), true).FirstOrDefault();
return da != null && da.Description.Contains( "AntSK");
}));
}
}
//动态加载部分
var loadedAssemblies = loadContext.Assemblies.ToList();
foreach (var assembly in loadedAssemblies)
{
// 从缓存中获取标记了ActionAttribute的方法
foreach (var type in assembly.GetTypes())
{
markedMethods.AddRange(type.GetMethods().Where(m =>
{
DescriptionAttribute da = (DescriptionAttribute)m.GetCustomAttributes(typeof(DescriptionAttribute), true).FirstOrDefault();
return da != null && da.Description.Contains("AntSK");
}));
}
}
// 构建方法调用
foreach (var method in markedMethods)
{
var key = $"{method.DeclaringType.Assembly.GetName().Name}_{method.DeclaringType.Name}_{method.Name}";
string pattern = "[^a-zA-Z0-9_]";
// 使用 '-' 替换非ASCII的正则表达式的字符
key = Regex.Replace(key, pattern, "_");
_methodCache.TryAdd(key, method);
var description= method.GetCustomAttribute<DescriptionAttribute>().Description.ConvertToString().Replace("AntSK:","");
var returnType = method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>().Description.ConvertToString();
var parameters = method.GetParameters().Select(x => (x.Name, x.ParameterType,x.GetCustomAttribute<DescriptionAttribute>()?.Description)).ToArray();
// 假设 _methodInfos 是一个已经定义好的字典,用来保存方法的相关信息
_methodInfos.TryAdd(key, (description, (method.ReflectedType, returnType), parameters));
}
}
public void FuncLoad(string pluginPath)
{
try
{
if (File.Exists(pluginPath))
{
string directory = Path.GetDirectoryName(pluginPath);
string fileName = Path.GetFileName(pluginPath);
var resolver = new AssemblyDependencyResolver(directory);
// Create a custom AssemblyLoadContext
loadContext.Resolving += (context, assemblyName) =>
{
string assemblyPath = resolver.ResolveAssemblyToPath(assemblyName);
if (assemblyPath != null)
{
return context.LoadFromAssemblyPath(assemblyPath);
}
return null;
};
// Load your assembly
Assembly pluginAssembly = loadContext.LoadFromAssemblyPath(pluginPath);
}
}
catch (Exception ex)
{
Console.WriteLine(ex.Message + " ---- " + ex.StackTrace);
}
}
}
}

View File

@@ -1,6 +1,7 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Repositories;
using Microsoft.KernelMemory;
@@ -29,8 +30,8 @@ namespace AntSK.Domain.Domain.Service
{
var importResult = _memory.ImportDocumentAsync(new Document(fileid)
.AddFile(req.FilePath)
.AddTag("kmsid", req.KmsId)
, index: "kms").Result;
.AddTag(KmsConstantcs.KmsIdTag, req.KmsId)
, index: KmsConstantcs.KmsIndex).Result;
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
string fileGuidName = Path.GetFileName(req.FilePath);
@@ -43,8 +44,8 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Url:
{
//导入url
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { "kmsid", req.KmsId } }
, index: "kms").Result;
var importResult = _memory.ImportWebPageAsync(req.Url, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;
@@ -54,8 +55,8 @@ namespace AntSK.Domain.Domain.Service
case ImportType.Text:
//导入文本
{
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { "kmsid", req.KmsId } }
, index: "kms").Result;
var importResult = _memory.ImportTextAsync(req.Text, fileid, new TagCollection() { { KmsConstantcs.KmsIdTag, req.KmsId } }
, index: KmsConstantcs.KmsIndex).Result;
//查询文档数量
var docTextList = _kMService.GetDocumentByFileID(km.Id, fileid).Result;
req.KmsDetail.Url = req.Url;

View File

@@ -1,71 +1,126 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Dto;
using AntDesign;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Common.Embedding;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Options;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using DocumentFormat.OpenXml.Drawing.Diagrams;
using LLama;
using LLamaSharp.KernelMemory;
using Markdig;
using Microsoft.AspNetCore.Components;
using Microsoft.Extensions.Configuration;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.ContentStorage.DevTools;
using Microsoft.KernelMemory.FileSystem.DevTools;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.KernelMemory.MemoryStorage.DevTools;
using Microsoft.KernelMemory.Postgres;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(IKMService), ServiceLifetime.Scoped)]
public class KMService(
IConfiguration _config,
IKmss_Repositories _kmss_Repositories,
IAIModels_Repositories _aIModels_Repositories
) : IKMService
IKmss_Repositories _kmss_Repositories,
IAIModels_Repositories _aIModels_Repositories,
IMessageService? _message
) : IKMService
{
public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
{
//获取KMS配置
var kms = _kmss_Repositories.GetFirst(p => p.Id == kmsID);
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.ChatModelID);
var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.EmbeddingModelID);
private MemoryServerless _memory;
//http代理
private List<UploadFileItem> _fileList = [];
public List<UploadFileItem> FileList => _fileList;
public MemoryServerless GetMemory(Apps app)
{
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == app.EmbeddingModelID);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);
//搜索配置
if (searchClientConfig.IsNull())
var searchClientConfig = new SearchClientConfig
{
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = "知识库未搜索到相关内容"
};
}
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
var memory = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
.WithCustomTextPartitioningOptions(new TextPartitioningOptions
{
MaxTokensPerLine = kms.MaxTokensPerLine,
MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
OverlappingTokens = kms.OverlappingTokens
});
//加载huihu 模型
WithTextGenerationByAIType(memory, chatModel, chatHttpClient);
var memoryBuild = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
//.WithCustomTextPartitioningOptions(new TextPartitioningOptions
//{
// MaxTokensPerLine = app.MaxTokensPerLine,
// MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
// OverlappingTokens = kms.OverlappingTokens
//})
;
//加载会话模型
WithTextGenerationByAIType(memoryBuild, chatModel, chatHttpClient);
//加载向量模型
WithTextEmbeddingGenerationByAIType(memory, embedModel, embeddingHttpClient);
WithTextEmbeddingGenerationByAIType(memoryBuild, embedModel, embeddingHttpClient);
//加载向量库
WithMemoryDbByVectorDB(memory, _config);
var result = memory.Build<MemoryServerless>();
return result;
WithMemoryDbByVectorDB(memoryBuild);
_memory = memoryBuild.Build<MemoryServerless>();
return _memory;
}
private void WithTextEmbeddingGenerationByAIType(IKernelMemoryBuilder memory, AIModels embedModel, HttpClient embeddingHttpClient)
public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
{
//if (_memory.IsNull())
{
//获取KMS配置
var kms = _kmss_Repositories.GetFirst(p => p.Id == kmsID);
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.ChatModelID);
var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == kms.EmbeddingModelID);
//http代理
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);
//搜索配置
if (searchClientConfig.IsNull())
{
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
}
var memoryBuild = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
.WithCustomTextPartitioningOptions(new TextPartitioningOptions
{
MaxTokensPerLine = kms.MaxTokensPerLine,
MaxTokensPerParagraph = kms.MaxTokensPerParagraph,
OverlappingTokens = kms.OverlappingTokens
});
//加载会话模型
WithTextGenerationByAIType(memoryBuild, chatModel, chatHttpClient);
//加载向量模型
WithTextEmbeddingGenerationByAIType(memoryBuild, embedModel, embeddingHttpClient);
//加载向量库
WithMemoryDbByVectorDB(memoryBuild);
_memory = memoryBuild.Build<MemoryServerless>();
return _memory;
}
//else {
// return _memory;
//}
}
private void WithTextEmbeddingGenerationByAIType(IKernelMemoryBuilder memory, AIModels embedModel,
HttpClient embeddingHttpClient)
{
switch (embedModel.AIType)
{
@@ -76,6 +131,7 @@ namespace AntSK.Domain.Domain.Service
EmbeddingModel = embedModel.ModelName
}, null, false, embeddingHttpClient);
break;
case Model.Enum.AIType.AzureOpenAI:
memory.WithAzureOpenAITextEmbeddingGeneration(new AzureOpenAIConfig()
{
@@ -86,15 +142,25 @@ namespace AntSK.Domain.Domain.Service
APIType = AzureOpenAIConfig.APITypes.EmbeddingGeneration,
});
break;
case Model.Enum.AIType.LLamaSharp:
var (weights, parameters) = LLamaConfig.GetLLamaConfig(embedModel.ModelName);
var embedder = new LLamaEmbedder(weights, parameters);
memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
break;
case Model.Enum.AIType.BgeEmbedding:
string pyDll = embedModel.EndPoint;
string bgeEmbeddingModelName = embedModel.ModelName;
memory.WithBgeTextEmbeddingGeneration(new HuggingfaceTextEmbeddingGenerator(pyDll,bgeEmbeddingModelName));
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeDefaults(embedModel.ModelKey);
break;
}
}
private void WithTextGenerationByAIType(IKernelMemoryBuilder memory, AIModels chatModel, HttpClient chatHttpClient)
private void WithTextGenerationByAIType(IKernelMemoryBuilder memory, AIModels chatModel,
HttpClient chatHttpClient)
{
switch (chatModel.AIType)
{
@@ -105,6 +171,7 @@ namespace AntSK.Domain.Domain.Service
TextModel = chatModel.ModelName
}, null, chatHttpClient);
break;
case Model.Enum.AIType.AzureOpenAI:
memory.WithAzureOpenAITextGeneration(new AzureOpenAIConfig()
{
@@ -115,20 +182,35 @@ namespace AntSK.Domain.Domain.Service
APIType = AzureOpenAIConfig.APITypes.TextCompletion,
});
break;
case Model.Enum.AIType.LLamaSharp:
var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
var context = weights.CreateContext(parameters);
var executor = new StatelessExecutor(weights, parameters);
memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
break;
case Model.Enum.AIType.LLamaFactory:
memory.WithOpenAITextGeneration(new OpenAIConfig()
{
APIKey = "123",
TextModel = chatModel.ModelName
}, null, chatHttpClient);
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeTextGeneration(new Cnblogs.KernelMemory.AI.DashScope.DashScopeConfig
{
ApiKey = chatModel.ModelKey,
});
break;
}
}
private void WithMemoryDbByVectorDB(IKernelMemoryBuilder memory, IConfiguration _config)
private void WithMemoryDbByVectorDB(IKernelMemoryBuilder memory)
{
string VectorDb = _config["KernelMemory:VectorDb"].ConvertToString();
string ConnectionString = _config["KernelMemory:ConnectionString"].ConvertToString();
string TableNamePrefix = _config["KernelMemory:TableNamePrefix"].ConvertToString();
string VectorDb = KernelMemoryOption.VectorDb.ConvertToString();
string ConnectionString = KernelMemoryOption.ConnectionString.ConvertToString();
string TableNamePrefix = KernelMemoryOption.TableNamePrefix.ConvertToString();
switch (VectorDb)
{
case "Postgres":
@@ -138,51 +220,134 @@ namespace AntSK.Domain.Domain.Service
TableNamePrefix = TableNamePrefix
});
break;
case "Disk":
memory.WithSimpleFileStorage(new SimpleFileStorageConfig()
memory.WithSimpleVectorDb(new SimpleVectorDbConfig()
{
StorageType = FileSystemTypes.Disk
StorageType = FileSystemTypes.Disk,
});
break;
case "Memory":
memory.WithSimpleFileStorage(new SimpleFileStorageConfig()
memory.WithSimpleVectorDb(new SimpleVectorDbConfig()
{
StorageType = FileSystemTypes.Volatile
});
break;
case "Qdrant":
var qdrantConfig = ConnectionString.Split("|");
memory.WithQdrantMemoryDb(qdrantConfig[0],qdrantConfig[1]);
break;
case "Redis":
memory.WithRedisMemoryDb(new RedisConfig()
{
ConnectionString = ConnectionString,
});
break;
case "AzureAISearch":
var aisearchConfig = ConnectionString.Split("|");
memory.WithAzureAISearchMemoryDb(aisearchConfig[0], aisearchConfig[1]);
break;
}
}
public async Task<List<KMFile>> GetDocumentByFileID(string kmsid, string fileid)
public async Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId)
{
var _memory = GetMemoryByKMS(kmsid);
var memories = await _memory.ListIndexesAsync();
var memoryDbs = _memory.Orchestrator.GetMemoryDbs();
List<KMFile> docTextList = new List<KMFile>();
var memory = GetMemoryByKMS(kmsId);
var memories = await memory.ListIndexesAsync();
var memoryDbs = memory.Orchestrator.GetMemoryDbs();
var docTextList = new List<KMFile>();
foreach (var memoryIndex in memories)
{
foreach (var memoryDb in memoryDbs)
{
var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileid) }, 100, true).ToListAsync();
foreach (var item in items)
var items = await memoryDb.GetListAsync(memoryIndex.Name, new List<MemoryFilter>() { new MemoryFilter().ByDocument(fileId) }, 100, true).ToListAsync();
docTextList.AddRange(items.Select(item => new KMFile()
{
KMFile file = new KMFile()
{
Text = item.Payload.FirstOrDefault(p => p.Key == "text").Value.ConvertToString(),
Url = item.Payload.FirstOrDefault(p => p.Key == "url").Value.ConvertToString(),
LastUpdate = item.Payload.FirstOrDefault(p => p.Key == "last_update").Value.ConvertToString(),
Schema = item.Payload.FirstOrDefault(p => p.Key == "schema").Value.ConvertToString(),
File = item.Payload.FirstOrDefault(p => p.Key == "file").Value.ConvertToString(),
};
docTextList.Add(file);
}
DocumentId = item.GetDocumentId(),
Text = item.GetPartitionText(),
Url = item.GetWebPageUrl(),
LastUpdate = item.GetLastUpdate().LocalDateTime.ToString("yyyy-MM-dd HH:mm:ss"),
File = item.GetFileName()
}));
}
}
return docTextList;
}
public async Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg)
{
var result = new List<RelevantSource>();
if (string.IsNullOrWhiteSpace(kmsIdListStr))
return result;
var kmsIdList = kmsIdListStr.Split(",");
if (!kmsIdList.Any()) return result;
var memory = GetMemoryByKMS(kmsIdList.FirstOrDefault()!);
var filters = kmsIdList.Select(kmsId => new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, kmsId)).ToList();
var searchResult = await memory.SearchAsync(msg, index: KmsConstantcs.KmsIndex, filters: filters);
if (!searchResult.NoResult)
{
foreach (var item in searchResult.Results)
{
result.AddRange(item.Partitions.Select(part => new RelevantSource()
{
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Relevance = part.Relevance
}));
}
}
return result;
}
public bool BeforeUpload(UploadFileItem file)
{
List<string> types = new List<string>() {
"text/plain",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/pdf",
"application/json",
"text/x-markdown",
"text/markdown"
};
string[] exceptExts = [".md", ".pdf"];
var validTypes = types.Contains(file.Type) || exceptExts.Contains(file.Ext);
if (!validTypes && file.Ext != ".md")
{
_message.Error("文件格式错误,请重新选择!");
}
var IsLt500K = file.Size < 1024 * 1024 * 100;
if (!IsLt500K)
{
_message.Error("文件需不大于100MB!");
}
return validTypes && IsLt500K;
}
public void OnSingleCompleted(UploadInfo fileinfo)
{
if (fileinfo.File.State == UploadState.Success)
{
//文件列表
_fileList.Add(new UploadFileItem()
{
FileName = fileinfo.File.FileName,
Url = fileinfo.File.Url = fileinfo.File.Response
});
}
}
}
}
}

View File

@@ -1,7 +1,7 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.LLM.SparkDesk;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Model;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using LLama;
@@ -11,7 +11,13 @@ using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Plugins.Core;
using Microsoft.SemanticKernel.TextGeneration;
using RestSharp;
using System;
using ServiceLifetime = AntSK.Domain.Common.DependencyInjection.ServiceLifetime;
using AntSK.LLM.Mock;
using AntSK.Domain.Domain.Model.Enum;
using AntSK.LLM.LLamaFactory;
using System.Reflection;
using DocumentFormat.OpenXml.Drawing;
namespace AntSK.Domain.Domain.Service
{
@@ -20,15 +26,22 @@ namespace AntSK.Domain.Domain.Service
{
private readonly IApis_Repositories _apis_Repositories;
private readonly IAIModels_Repositories _aIModels_Repositories;
private readonly FunctionService _functionService;
private readonly IServiceProvider _serviceProvider;
private Kernel _kernel;
public KernelService(
IApis_Repositories apis_Repositories,
IAIModels_Repositories aIModels_Repositories
)
IAIModels_Repositories aIModels_Repositories,
FunctionService functionService,
IServiceProvider serviceProvider)
{
_apis_Repositories = apis_Repositories;
_aIModels_Repositories = aIModels_Repositories;
_functionService = functionService;
_serviceProvider = serviceProvider;
}
/// <summary>
/// 获取kernel实例依赖注入不好按每个用户去Import不同的插件所以每次new一个新的kernel
/// </summary>
@@ -37,20 +50,26 @@ namespace AntSK.Domain.Domain.Service
/// <returns></returns>
public Kernel GetKernelByApp(Apps app)
{
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
//if (_kernel.IsNull())
{
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var chatHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(chatModel.EndPoint);
var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, chatModel, chatHttpClient);
var builder = Kernel.CreateBuilder();
WithTextGenerationByAIType(builder, app, chatModel, chatHttpClient);
var kernel = builder.Build();
RegisterPluginsWithKernel(kernel);
return kernel;
_kernel = builder.Build();
RegisterPluginsWithKernel(_kernel);
return _kernel;
}
//else
//{
// return _kernel;
//}
}
private void WithTextGenerationByAIType(IKernelBuilder builder, AIModels chatModel, HttpClient chatHttpClient)
private void WithTextGenerationByAIType(IKernelBuilder builder, Apps app, AIModels chatModel, HttpClient chatHttpClient)
{
switch (chatModel.AIType)
{
@@ -60,6 +79,7 @@ namespace AntSK.Domain.Domain.Service
apiKey: chatModel.ModelKey,
httpClient: chatHttpClient);
break;
case Model.Enum.AIType.AzureOpenAI:
builder.AddAzureOpenAIChatCompletion(
deploymentName: chatModel.ModelName,
@@ -67,11 +87,32 @@ namespace AntSK.Domain.Domain.Service
endpoint: chatModel.EndPoint
);
break;
case Model.Enum.AIType.LLamaSharp:
var (weights, parameters) = LLamaConfig.GetLLamaConfig(chatModel.ModelName);
var ex = new StatelessExecutor(weights, parameters);
builder.Services.AddKeyedSingleton<ITextGenerationService>("local-llama", new LLamaSharpTextCompletion(ex));
break;
case Model.Enum.AIType.SparkDesk:
var options = new SparkDeskOptions { AppId = chatModel.EndPoint, ApiSecret = chatModel.ModelKey, ApiKey = chatModel.ModelName, ModelVersion = Sdcb.SparkDesk.ModelVersion.V3_5 };
builder.Services.AddKeyedSingleton<ITextGenerationService>("spark-desk", new SparkDeskTextCompletion(options, app.Id));
break;
case Model.Enum.AIType.DashScope:
builder.Services.AddDashScopeChatCompletion(chatModel.ModelKey, chatModel.ModelName);
break;
case Model.Enum.AIType.Mock:
builder.Services.AddKeyedSingleton<ITextGenerationService>("mock", new MockTextCompletion());
break;
case Model.Enum.AIType.LLamaFactory:
builder.AddOpenAIChatCompletion(
modelId: chatModel.ModelName,
apiKey: "123",
httpClient: chatHttpClient
);
break;
}
}
@@ -82,25 +123,57 @@ namespace AntSK.Domain.Domain.Service
/// <param name="_kernel"></param>
public void ImportFunctionsByApp(Apps app, Kernel _kernel)
{
//开启自动插件调用
var apiIdList = app.ApiFunctionList.Split(",");
var apiList = _apis_Repositories.GetList(p => apiIdList.Contains(p.Id));
List<KernelFunction> functions = new List<KernelFunction>();
var plugin = _kernel.Plugins.FirstOrDefault(p => p.Name == "ApiFunctions");
//插件不能重复注册,否则会异常
if (_kernel.Plugins.Any(p => p.Name == "AntSkFunctions"))
{
return;
}
List<KernelFunction> functions = new List<KernelFunction>();
//API插件
ImportApiFunction(app, functions);
//本地函数插件
ImportNativeFunction(app, functions);
_kernel.ImportPluginFromFunctions("AntSkFunctions", functions);
}
/// <summary>
/// 导入API插件
/// </summary>
/// <param name="app"></param>
/// <param name="functions"></param>
private void ImportApiFunction(Apps app, List<KernelFunction> functions)
{
if (!string.IsNullOrWhiteSpace(app.ApiFunctionList))
{
//开启自动插件调用
var apiIdList = app.ApiFunctionList.Split(",");
var apiList = _apis_Repositories.GetList(p => apiIdList.Contains(p.Id));
foreach (var api in apiList)
{
var returnType = new KernelReturnParameterMetadata() { Description = api.OutputPrompt };
switch (api.Method)
{
case HttpMethodType.Get:
functions.Add(_kernel.CreateFunctionFromMethod((string msg) =>
var getParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.Query}"
}
};
functions.Add(_kernel.CreateFunctionFromMethod((string jsonbody) =>
{
try
{
Console.WriteLine(msg);
//将json 转换为query参数
var queryString = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, string>>(jsonbody);
RestClient client = new RestClient();
RestRequest request = new RestRequest(api.Url, Method.Get);
foreach (var header in api.Header.Split("\n"))
foreach (var header in api.Header.ConvertToString().Split("\n"))
{
var headerArray = header.Split(":");
if (headerArray.Length == 2)
@@ -109,13 +182,9 @@ namespace AntSK.Domain.Domain.Service
}
}
//这里应该还要处理一次参数提取,等后面再迭代
foreach (var query in api.Query.Split("\n"))
foreach (var q in queryString)
{
var queryArray = query.Split("=");
if (queryArray.Length == 2)
{
request.AddQueryParameter(queryArray[0], queryArray[1]);
}
request.AddQueryParameter(q.Key, q.Value);
}
var result = client.Execute(request);
return result.Content;
@@ -124,17 +193,25 @@ namespace AntSK.Domain.Domain.Service
{
return "调用失败:" + ex.Message;
}
}, api.Name, $"{api.Describe}"));
}, api.Name, api.Describe, getParametes, returnType));
break;
case HttpMethodType.Post:
functions.Add(_kernel.CreateFunctionFromMethod((string msg) =>
//处理json body
var postParametes = new List<KernelParameterMetadata>() {
new KernelParameterMetadata("jsonbody"){
Name="json参数字符串",
ParameterType=typeof(string),
Description=$"背景文档:{Environment.NewLine}{api.InputPrompt} {Environment.NewLine}提取出对应的json格式字符串参考如下格式:{Environment.NewLine}{api.JsonBody}"
}
};
functions.Add(_kernel.CreateFunctionFromMethod((string jsonBody) =>
{
try
{
Console.WriteLine(msg);
Console.WriteLine(jsonBody);
RestClient client = new RestClient();
RestRequest request = new RestRequest(api.Url, Method.Post);
foreach (var header in api.Header.Split("\n"))
foreach (var header in api.Header.ConvertToString().Split("\n"))
{
var headerArray = header.Split(":");
if (headerArray.Length == 2)
@@ -143,7 +220,7 @@ namespace AntSK.Domain.Domain.Service
}
}
//这里应该还要处理一次参数提取,等后面再迭代
request.AddJsonBody(api.JsonBody);
request.AddJsonBody(jsonBody.ConvertToString());
var result = client.Execute(request);
return result.Content;
}
@@ -151,24 +228,51 @@ namespace AntSK.Domain.Domain.Service
{
return "调用失败:" + ex.Message;
}
}, api.Name, $"{api.Describe}"));
}, api.Name, api.Describe, postParametes, returnType));
break;
}
}
_kernel.ImportPluginFromFunctions("ApiFunctions", functions);
}
}
/// <summary>
/// 导入原生插件
/// </summary>
/// <param name="app"></param>
/// <param name="functions"></param>
private void ImportNativeFunction(Apps app, List<KernelFunction> functions)
{
if (!string.IsNullOrWhiteSpace(app.NativeFunctionList))//需要添加判断应用是否开启了本地函数插件
{
var nativeIdList = app.NativeFunctionList.Split(",");
_functionService.SearchMarkedMethods();
using var scope = _serviceProvider.CreateScope();
foreach (var func in _functionService.Functions)
{
if (nativeIdList.Contains(func.Key))
{
var methodInfo = _functionService.MethodInfos[func.Key];
var parameters = methodInfo.Parameters.Select(x => new KernelParameterMetadata(x.ParameterName) { ParameterType = x.ParameterType, Description = x.Description });
var returnType = new KernelReturnParameterMetadata() { ParameterType = methodInfo.ReturnType.ParameterType, Description = methodInfo.ReturnType.Description };
var target = ActivatorUtilities.CreateInstance(scope.ServiceProvider, func.Value.DeclaringType);
functions.Add(_kernel.CreateFunctionFromMethod(func.Value, target, func.Key, methodInfo.Description, parameters, returnType));
}
}
}
}
/// <summary>
/// 注册默认插件
/// </summary>
/// <param name="kernel"></param>
void RegisterPluginsWithKernel(Kernel kernel)
private void RegisterPluginsWithKernel(Kernel kernel)
{
kernel.ImportPluginFromObject(new ConversationSummaryPlugin(), "ConversationSummaryPlugin");
kernel.ImportPluginFromObject(new TimePlugin(), "TimePlugin");
kernel.ImportPluginFromPromptDirectory(Path.Combine(RepoFiles.SamplePluginsPath(), "KMSPlugin"));
//kernel.ImportPluginFromObject(new TimePlugin(), "TimePlugin");
kernel.ImportPluginFromPromptDirectory(System.IO.Path.Combine(RepoFiles.SamplePluginsPath(), "KMSPlugin"));
}
/// <summary>
@@ -183,8 +287,8 @@ namespace AntSK.Domain.Domain.Service
KernelFunction sunFun = _kernel.Plugins.GetFunction("ConversationSummaryPlugin", "SummarizeConversation");
var summary = await _kernel.InvokeAsync(sunFun, new() { ["input"] = $"内容是:{history.ToString()} {Environment.NewLine} 请注意用中文总结" });
string his = summary.GetValue<string>();
var msg = $"history{history.ToString()}{Environment.NewLine} user{questions}"; ;
var msg = $"history{Environment.NewLine}{history.ToString()}{Environment.NewLine} user{questions}{Environment.NewLine}"; ;
return msg;
}
}
}
}

View File

@@ -0,0 +1,164 @@
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Options;
using AntSK.LLamaFactory.Model;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
namespace AntSK.Domain.Domain.Service
{
[ServiceDescription(typeof(ILLamaFactoryService), ServiceLifetime.Singleton)]
public class LLamaFactoryService : ILLamaFactoryService
{
private Process process;
public static bool isProcessComplete = false;
private readonly object _syncLock = new object();
private List<LLamaModel> modelList = new List<LLamaModel>();
public LLamaFactoryService() { }
public delegate Task LogMessageHandler(string message);
public event LogMessageHandler LogMessageReceived;
protected virtual async Task OnLogMessageReceived(string message)
{
LogMessageReceived?.Invoke(message);
}
public async Task PipInstall()
{
var cmdTask = Task.Factory.StartNew(() =>
{
var isProcessComplete = false;
process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "pip",
Arguments = "install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple",
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError = true,
WorkingDirectory = AppDomain.CurrentDomain.BaseDirectory,
}
};
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
}, TaskCreationOptions.LongRunning);
}
public async Task StartLLamaFactory(string modelName, string templateName)
{
var cmdTask = Task.Factory.StartNew(() =>
{
var isProcessComplete = false;
process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "python",
Arguments = "api_demo.py --model_name_or_path " + modelName + " --template " + templateName + " ",
UseShellExecute = false,
RedirectStandardOutput = true,
RedirectStandardError=true,
WorkingDirectory = Path.Combine(Path.GetDirectoryName(System.Reflection.Assembly.GetEntryAssembly().Location), "llamafactory"),
}
};
process.StartInfo.Environment["CUDA_VISIBLE_DEVICES"] = "0";
process.StartInfo.Environment["API_PORT"] = "8000";
process.StartInfo.EnvironmentVariables["USE_MODELSCOPE_HUB"] = "1";
process.OutputDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.ErrorDataReceived += (sender, eventArgs) =>
{
Console.WriteLine($"{eventArgs.Data}");
OnLogMessageReceived(eventArgs.Data);
};
process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
}, TaskCreationOptions.LongRunning);
}
private void Process_OutputDataReceived(object sender, DataReceivedEventArgs e)
{
throw new NotImplementedException();
}
public string WaitForProcessExit()
{
process.WaitForExit();
return process.StandardOutput.ReadToEnd();
}
public void KillProcess()
{
try
{
Process[] processes = Process.GetProcesses();
foreach (Process process1 in processes)
{
if (process1.ProcessName.ToLower() == "python")
{
process1.Kill();
System.Console.WriteLine("kill python");
}
}
}
catch (InvalidOperationException ex)
{
// Process already exited.
}
}
public List<LLamaModel> GetLLamaFactoryModels()
{
if (modelList.Count==0)
{
string jsonString = File.ReadAllText(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "modelList.json"));
// 反序列化 JSON 字符串到相应的 C# 对象
var Models = JsonConvert.DeserializeObject<List<LLamaFactoryModel>>(jsonString);
foreach (var model in Models)
{
foreach (var m in model.Models)
{
modelList.Add(new LLamaModel() { Name=m.Key, ModelScope=m.Value.MODELSCOPE });
}
}
}
return modelList;
}
}
}

View File

@@ -1,13 +0,0 @@
using AutoMapper;
namespace AntSK.Domain.Map
{
public class AutoMapProfile : Profile
{
public AutoMapProfile()
{
}
}
}

View File

@@ -1,48 +0,0 @@
using AutoMapper;
namespace AntSK.Domain.Map
{
public static class MapperExtend
{
/// <summary>
/// Entity集合转DTO集合
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="value"></param>
/// <returns></returns>
public static List<T> ToDTOList<T>(this object value)
{
if (value == null)
return new List<T>();
return Mapper.Map<List<T>>(value);
}
/// <summary>
/// Entity转DTO
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="value"></param>
/// <returns></returns>
public static T ToDTO<T>(this object value)
{
if (value == null)
return default(T);
return Mapper.Map<T>(value);
}
/// <summary>
/// 给已有对象map,适合update场景如需过滤空值需要在AutoMapProfile 设置
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="self"></param>
/// <param name="result"></param>
/// <returns></returns>
public static T MapTo<T>(this object self, T result)
{
if (self == null)
return default(T);
return (T)Mapper.Map(self, result, self.GetType(), typeof(T));
}
}
}

View File

@@ -1,30 +0,0 @@
using AutoMapper;
using Microsoft.Extensions.DependencyInjection;
namespace AntSK.Domain.Map
{
public static class MapperRegister
{
public static void AddMapper(this IServiceCollection services)
{
var config = new MapperConfiguration(cfg =>
{
cfg.CreateMissingTypeMaps = true;
cfg.ValidateInlineMaps = false;
cfg.ShouldMapMethod = m => false;
cfg.AddProfile<AutoMapProfile>();
});
IMapper mapper = config.CreateMapper();
//启动实体映射
Mapper.Initialize(cfg =>
{
cfg.CreateMissingTypeMaps = true;
cfg.ValidateInlineMaps = false;
cfg.ShouldMapMethod = m => false;
cfg.AddProfile<AutoMapProfile>();
});
}
}
}

View File

@@ -1,21 +0,0 @@
namespace AntSK.Domain.Model.Enum
{
/// <summary>
/// AI类型
/// </summary>
public enum AIType
{
OpenAI = 1,
AzureOpenAI = 2,
LLamaSharp = 3
}
/// <summary>
/// 模型类型
/// </summary>
public enum AIModelType
{
Chat = 1,
Embedding = 2,
}
}

View File

@@ -3,8 +3,6 @@
public class LLamaSharpOption
{
public static string RunType { get; set; }
public static string Chat { get; set; }
public static string Embedding { get; set; }
public static string FileDirectory { get; set; } = Directory.GetCurrentDirectory();
}
}

View File

@@ -1,4 +1,4 @@
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;

View File

@@ -20,6 +20,7 @@ namespace AntSK.Domain.Repositories
/// </summary>
[Required]
public string Describe { get; set; }
/// <summary>
/// 图标
/// </summary>
@@ -38,6 +39,12 @@ namespace AntSK.Domain.Repositories
[Required]
public string? ChatModelID { get; set; }
/// <summary>
/// Embedding 模型Id
/// </summary>
public string? EmbeddingModelID { get; set; }
public string? ImageModelID { get; set; }
/// <summary>
/// 温度
/// </summary>
@@ -52,16 +59,23 @@ namespace AntSK.Domain.Repositories
/// <summary>
/// 插件列表
/// </summary>
[SugarColumn(ColumnDataType = "varchar(1000)")]
public string? ApiFunctionList { get; set; }
/// <summary>
/// 本地函数列表
/// </summary>
[SugarColumn(ColumnDataType = "varchar(1000)")]
public string? NativeFunctionList { get; set; }
/// <summary>
/// 知识库ID列表
/// </summary>
public string? KmsIdList { get; set; }
/// <summary>
/// API调用秘钥
/// </summary>
public string? SecretKey { get; set; }
}
}
}

View File

@@ -0,0 +1,20 @@
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;
namespace AntSK.Domain.Repositories
{
[SugarTable("Funs")]
public partial class Funs
{
[SugarColumn(IsPrimaryKey = true)]
public string Id { get; set; }
/// <summary>
/// 接口描述
/// </summary>
[Required]
public string Path { get; set; }
}
}

View File

@@ -0,0 +1,11 @@

using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
[ServiceDescription(typeof(IFuns_Repositories), ServiceLifetime.Scoped)]
public class Funs_Repositories : Repository<Funs>, IFuns_Repositories
{
}
}

View File

@@ -0,0 +1,8 @@
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
public interface IFuns_Repositories : IRepository<Funs>
{
}
}

View File

@@ -1,4 +1,4 @@
using AntSK.Domain.Model.Enum;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
namespace AntSK.Domain.Repositories

View File

@@ -1,4 +1,4 @@
using AntSK.Domain.Model;
using AntSK.Domain.Domain.Model;
using SqlSugar;
using System.Linq.Expressions;

View File

@@ -1,5 +1,6 @@
using AntSK.Domain.Map;
using AntSK.Domain.Model;
using AntSK.Domain.Common.Map;
using AntSK.Domain.Domain.Model;
using SqlSugar;
using System.Linq.Expressions;

View File

@@ -1,4 +1,4 @@
using AntSK.Domain.Model.Enum;
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;

View File

@@ -0,0 +1,16 @@
using AntSK.Domain.Domain.Model.Enum;
using SqlSugar;
using System.ComponentModel.DataAnnotations;
namespace AntSK.Domain.Repositories
{
[SugarTable("Dics")]
public partial class Dics
{
[SugarColumn(IsPrimaryKey = true)]
public string Id { get; set; }
public string Type { get; set; }
public string Key { get; set; }
public string Value { get; set; }
}
}

View File

@@ -0,0 +1,11 @@

using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
[ServiceDescription(typeof(IDics_Repositories), ServiceLifetime.Scoped)]
public class Dics_Repositories : Repository<Dics>, IDics_Repositories
{
}
}

View File

@@ -0,0 +1,8 @@
using AntSK.Domain.Repositories.Base;
namespace AntSK.Domain.Repositories
{
public interface IDics_Repositories : IRepository<Dics>
{
}
}

View File

@@ -1,4 +1,6 @@
namespace AntSK.Domain.Utils
using System.Web;
namespace AntSK.Domain.Utils
{
public static class ConvertUtils
{
@@ -231,5 +233,33 @@
{
return $"{Environment.NewLine}```json{Environment.NewLine}{s}{Environment.NewLine}```{Environment.NewLine}";
}
/// <summary>
/// json参数转化querystring参数
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
public static string ToQueryString(this Dictionary<string, string> parameters)
{
var nameValueCollection = HttpUtility.ParseQueryString(string.Empty);
foreach (var param in parameters)
{
nameValueCollection[param.Key] = param.Value;
}
return nameValueCollection.ToString();
}
/// <summary>
/// 忽略大小写匹配
/// </summary>
/// <param name="s"></param>
/// <param name="value"></param>
/// <returns></returns>
public static bool ComparisonIgnoreCase(this string s, string value)
{
return s.Equals(value, StringComparison.OrdinalIgnoreCase);
}
}
}

View File

@@ -0,0 +1,39 @@
using System;
using System.Collections.Generic;
using System.Drawing.Imaging;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Utils
{
public class ImageUtils
{
public static string BitmapToBase64(Bitmap bitmap)
{
using (MemoryStream memoryStream = new MemoryStream())
{
// 保存为JPEG格式也可以选择PngGif等等
bitmap.Save(memoryStream, ImageFormat.Jpeg);
// 获取内存流的字节数组
byte[] imageBytes = memoryStream.ToArray();
// 将字节转换为Base64字符串
string base64String = Convert.ToBase64String(imageBytes);
return base64String;
}
}
public static List<string> BitmapListToBase64(Bitmap[] bitmaps)
{
List<string> base64Strings = new List<string>();
foreach (Bitmap bitmap in bitmaps)
{
base64Strings.Add(BitmapToBase64(bitmap));
}
return base64Strings;
}
}
}

View File

@@ -0,0 +1,375 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using System.Xml.XPath;
namespace AntSK.Domain.Utils
{
/// <summary>
/// 注释辅助类
/// </summary>
public class XmlCommentHelper
{
private static Regex RefTagPattern = new Regex(@"<(see|paramref) (name|cref)=""([TPF]{1}:)?(?<display>.+?)"" ?/>");
private static Regex CodeTagPattern = new Regex(@"<c>(?<display>.+?)</c>");
private static Regex ParaTagPattern = new Regex(@"<para>(?<display>.+?)</para>", RegexOptions.Singleline);
List<XPathNavigator> navigators = new List<XPathNavigator>();
/// <summary>
/// 从当前dll文件中加载所有的xml文件
/// </summary>
public void LoadAll()
{
var files = Directory.GetFiles(Directory.GetCurrentDirectory());
foreach (var file in files)
{
if (string.Equals(Path.GetExtension(file), ".xml", StringComparison.OrdinalIgnoreCase))
{
Load(file);
}
}
}
/// <summary>
/// 从xml中加载
/// </summary>
/// <param name="xmls"></param>
public void LoadXml(params string[] xmls)
{
foreach (var xml in xmls)
{
Load(new MemoryStream(Encoding.UTF8.GetBytes(xml)));
}
}
/// <summary>
/// 从文件中加载
/// </summary>
/// <param name="xmlFiles"></param>
public void Load(params string[] xmlFiles)
{
foreach (var xmlFile in xmlFiles)
{
var doc = new XPathDocument(xmlFile);
navigators.Add(doc.CreateNavigator());
}
}
/// <summary>
/// 从流中加载
/// </summary>
/// <param name="streams"></param>
public void Load(params Stream[] streams)
{
foreach (var stream in streams)
{
var doc = new XPathDocument(stream);
navigators.Add(doc.CreateNavigator());
}
}
/// <summary>
/// 读取类型中的注释
/// </summary>
/// <param name="type">类型</param>
/// <param name="xPath">注释路径</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetTypeComment(Type type, string xPath = "summary", bool humanize = true)
{
var typeMemberName = GetMemberNameForType(type);
return GetComment(typeMemberName, xPath, humanize);
}
/// <summary>
/// 读取字段或者属性的注释
/// </summary>
/// <param name="fieldOrPropertyInfo">字段或者属性</param>
/// <param name="xPath">注释路径</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetFieldOrPropertyComment(MemberInfo fieldOrPropertyInfo, string xPath = "summary", bool humanize = true)
{
var fieldOrPropertyMemberName = GetMemberNameForFieldOrProperty(fieldOrPropertyInfo);
return GetComment(fieldOrPropertyMemberName, xPath, humanize);
}
/// <summary>
/// 读取方法中的注释
/// </summary>
/// <param name="methodInfo">方法</param>
/// <param name="xPath">注释路径</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetMethodComment(MethodInfo methodInfo, string xPath = "summary", bool humanize = true)
{
var methodMemberName = GetMemberNameForMethod(methodInfo);
return GetComment(methodMemberName, xPath, humanize);
}
/// <summary>
/// 读取方法中的返回值注释
/// </summary>
/// <param name="methodInfo">方法</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetMethodReturnComment(MethodInfo methodInfo, bool humanize = true)
{
return GetMethodComment(methodInfo, "returns", humanize);
}
/// <summary>
/// 读取参数的注释
/// </summary>
/// <param name="parameterInfo">参数</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetParameterComment(ParameterInfo parameterInfo, bool humanize = true)
{
if (!(parameterInfo.Member is MethodInfo methodInfo)) return string.Empty;
var methodMemberName = GetMemberNameForMethod(methodInfo);
return GetComment(methodMemberName, $"param[@name='{parameterInfo.Name}']", humanize);
}
/// <summary>
/// 读取方法的所有参数的注释
/// </summary>
/// <param name="methodInfo">方法</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public Dictionary<string, string> GetParameterComments(MethodInfo methodInfo, bool humanize = true)
{
var parameterInfos = methodInfo.GetParameters();
Dictionary<string, string> dict = new Dictionary<string, string>();
foreach (var parameterInfo in parameterInfos)
{
dict[parameterInfo.Name] = GetParameterComment(parameterInfo, humanize);
}
return dict;
}
/// <summary>
/// 读取指定名称节点的注释
/// </summary>
/// <param name="name">节点名称</param>
/// <param name="xPath">注释路径</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetComment(string name, string xPath, bool humanize = true)
{
foreach (var _xmlNavigator in navigators)
{
var typeSummaryNode = _xmlNavigator.SelectSingleNode($"/doc/members/member[@name='{name}']/{xPath.Trim('/', '\\')}");
if (typeSummaryNode != null)
{
return humanize ? Humanize(typeSummaryNode.InnerXml) : typeSummaryNode.InnerXml;
}
}
return string.Empty;
}
/// <summary>
/// 读取指定节点的summary注释
/// </summary>
/// <param name="name">节点名称</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetSummary(string name, bool humanize = true)
{
return GetComment(name, "summary", humanize);
}
/// <summary>
/// 读取指定节点的example注释
/// </summary>
/// <param name="name">节点名称</param>
/// <param name="humanize">可读性优化(比如去掉xml标记)</param>
/// <returns></returns>
public string GetExample(string name, bool humanize = true)
{
return GetComment(name, "example", humanize);
}
/// <summary>
/// 获取方法的节点名称
/// </summary>
/// <param name="method"></param>
/// <returns></returns>
public string GetMemberNameForMethod(MethodInfo method)
{
var builder = new StringBuilder("M:");
builder.Append(QualifiedNameFor(method.DeclaringType));
builder.Append($".{method.Name}");
var parameters = method.GetParameters();
if (parameters.Any())
{
var parametersNames = parameters.Select(p =>
{
return p.ParameterType.IsGenericParameter
? $"`{p.ParameterType.GenericParameterPosition}"
: QualifiedNameFor(p.ParameterType, expandGenericArgs: true);
});
builder.Append($"({string.Join(",", parametersNames)})");
}
return builder.ToString();
}
/// <summary>
/// 获取类型的节点名称
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
public string GetMemberNameForType(Type type)
{
var builder = new StringBuilder("T:");
builder.Append(QualifiedNameFor(type));
return builder.ToString();
}
/// <summary>
/// 获取字段或者属性的节点名称
/// </summary>
/// <param name="fieldOrPropertyInfo"></param>
/// <returns></returns>
public string GetMemberNameForFieldOrProperty(MemberInfo fieldOrPropertyInfo)
{
var builder = new StringBuilder((fieldOrPropertyInfo.MemberType & MemberTypes.Field) != 0 ? "F:" : "P:");
builder.Append(QualifiedNameFor(fieldOrPropertyInfo.DeclaringType));
builder.Append($".{fieldOrPropertyInfo.Name}");
return builder.ToString();
}
private string QualifiedNameFor(Type type, bool expandGenericArgs = false)
{
if (type.IsArray)
return $"{QualifiedNameFor(type.GetElementType(), expandGenericArgs)}[]";
var builder = new StringBuilder();
if (!string.IsNullOrEmpty(type.Namespace))
builder.Append($"{type.Namespace}.");
if (type.IsNested)
{
builder.Append($"{string.Join(".", GetNestedTypeNames(type))}.");
}
if (type.IsConstructedGenericType && expandGenericArgs)
{
var nameSansGenericArgs = type.Name.Split('`').First();
builder.Append(nameSansGenericArgs);
var genericArgsNames = type.GetGenericArguments().Select(t =>
{
return t.IsGenericParameter
? $"`{t.GenericParameterPosition}"
: QualifiedNameFor(t, true);
});
builder.Append($"{{{string.Join(",", genericArgsNames)}}}");
}
else
{
builder.Append(type.Name);
}
return builder.ToString();
}
private IEnumerable<string> GetNestedTypeNames(Type type)
{
if (!type.IsNested || type.DeclaringType == null) yield break;
foreach (var nestedTypeName in GetNestedTypeNames(type.DeclaringType))
{
yield return nestedTypeName;
}
yield return type.DeclaringType.Name;
}
private string Humanize(string text)
{
if (text == null)
throw new ArgumentNullException("text");
//Call DecodeXml at last to avoid entities like &lt and &gt to break valid xml
text = NormalizeIndentation(text);
text = HumanizeRefTags(text);
text = HumanizeCodeTags(text);
text = HumanizeParaTags(text);
text = DecodeXml(text);
return text;
}
private string NormalizeIndentation(string text)
{
string[] lines = text.Split('\n');
string padding = GetCommonLeadingWhitespace(lines);
int padLen = padding == null ? 0 : padding.Length;
// remove leading padding from each line
for (int i = 0, l = lines.Length; i < l; ++i)
{
string line = lines[i].TrimEnd('\r'); // remove trailing '\r'
if (padLen != 0 && line.Length >= padLen && line.Substring(0, padLen) == padding)
line = line.Substring(padLen);
lines[i] = line;
}
// remove leading empty lines, but not all leading padding
// remove all trailing whitespace, regardless
return string.Join("\r\n", lines.SkipWhile(x => string.IsNullOrWhiteSpace(x))).TrimEnd();
}
private string GetCommonLeadingWhitespace(string[] lines)
{
if (null == lines)
throw new ArgumentException("lines");
if (lines.Length == 0)
return null;
string[] nonEmptyLines = lines
.Where(x => !string.IsNullOrWhiteSpace(x))
.ToArray();
if (nonEmptyLines.Length < 1)
return null;
int padLen = 0;
// use the first line as a seed, and see what is shared over all nonEmptyLines
string seed = nonEmptyLines[0];
for (int i = 0, l = seed.Length; i < l; ++i)
{
if (!char.IsWhiteSpace(seed, i))
break;
if (nonEmptyLines.Any(line => line[i] != seed[i]))
break;
++padLen;
}
if (padLen > 0)
return seed.Substring(0, padLen);
return null;
}
private string HumanizeRefTags(string text)
{
return RefTagPattern.Replace(text, (match) => match.Groups["display"].Value);
}
private string HumanizeCodeTags(string text)
{
return CodeTagPattern.Replace(text, (match) => "{" + match.Groups["display"].Value + "}");
}
private string HumanizeParaTags(string text)
{
return ParaTagPattern.Replace(text, (match) => "<br>" + match.Groups["display"].Value);
}
private string DecodeXml(string text)
{
return System.Net.WebUtility.HtmlDecode(text);
}
}
}

View File

@@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Content Include="llamafactory\**">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>
<ItemGroup>
<None Update="modelList.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="requirements.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.LLamaFactory.Model
{
public class ModelInfo
{
public string DEFAULT { get; set; }
public string MODELSCOPE { get; set; }
}
public class LLamaFactoryModel
{
public Dictionary<string, ModelInfo> Models { get; set; }
public string Template { get; set; }
}
public class LLamaModel
{
public string Name { get; set; }
public string ModelScope { get; set; }
}
}

View File

@@ -0,0 +1,16 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,11 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from .api import create_app
from .chat import ChatModel
from .eval import Evaluator
from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.5.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -0,0 +1,4 @@
from .app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,224 @@
import json
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, Sequence
from pydantic import BaseModel
from ..chat import ChatModel
from ..data import Role as DataRole
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from .protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
Finish,
Function,
FunctionCall,
ModelCard,
ModelList,
Role,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
role_mapping = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = ""
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
if request.stream:
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(input_messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream")
responses = await chat_model.achat(
input_messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
response_message = ChatCompletionMessage(
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
)
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def stream_chat_completion(
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
async for new_token in chat_model.astream_chat(
messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
):
if len(new_token) == 0:
continue
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
yield "[DONE]"
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)
return app
if __name__ == "__main__":
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

View File

@@ -0,0 +1,116 @@
import time
from enum import Enum, unique
from typing import List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
@unique
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
TOOL = "tool_calls"
class ModelCard(BaseModel):
id: str
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[ModelCard] = []
class Function(BaseModel):
name: str
arguments: str
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
function: Function
class ChatMessage(BaseModel):
role: Role
content: str
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stream: bool = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: List[str]
max_length: Optional[int] = None
class ScoreEvaluationResponse(BaseModel):
id: Literal["scoreeval-default"] = "scoreeval-default"
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]

View File

@@ -0,0 +1,5 @@
from .base_engine import BaseEngine
from .chat_model import ChatModel
__all__ = ["BaseEngine", "ChatModel"]

View File

@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class BaseEngine(ABC):
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: Dict[str, Any]
@abstractmethod
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None: ...
@abstractmethod
async def start(
self,
) -> None: ...
@abstractmethod
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]: ...
@abstractmethod
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]: ...

View File

@@ -0,0 +1,91 @@
import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from .base_engine import BaseEngine, Response
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
return task.result()
async def achat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
yield task.result()
except StopAsyncIteration:
break
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
yield new_token
def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
return await self.engine.get_scores(batch_input, **input_kwargs)

View File

@@ -0,0 +1,263 @@
import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data import Template
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
class HuggingfaceEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=inputs,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
return gen_kwargs, prompt_length
@staticmethod
@torch.inference_mode()
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@staticmethod
@torch.inference_mode()
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
max_length = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
scores.append(values[i, end_index].nan_to_num().item())
return scores
async def start(self) -> None:
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.template,
self.generating_args,
messages,
system,
tools,
input_kwargs,
)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.template,
self.generating_args,
messages,
system,
tools,
input_kwargs,
)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
except StopAsyncIteration:
break
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)

View File

@@ -0,0 +1,149 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,
trust_remote_code=True,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count() or 1,
gpu_memory_utilization=model_args.vllm_gpu_util,
disable_log_stats=True,
disable_log_requests=True,
enforce_eager=model_args.vllm_enforce_eager,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
self.tokenizer = load_tokenizer(model_args)
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt_ids)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.copy()
generating_args.update(
dict(
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
)
)
if max_length:
generating_args["max_new_tokens"] = max_length - prompt_length
if max_new_tokens:
generating_args["max_new_tokens"] = max_new_tokens
sampling_params = SamplingParams(
n=generating_args["num_return_sequences"],
repetition_penalty=generating_args["repetition_penalty"],
temperature=generating_args["temperature"],
top_p=generating_args["top_p"],
top_k=generating_args["top_k"],
use_beam_search=generating_args["num_beams"] > 1,
length_penalty=generating_args["length_penalty"],
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)
result_generator = self.model.generate(
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
)
return result_generator
async def start(self) -> None:
pass
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")

View File

@@ -0,0 +1,6 @@
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

View File

@@ -0,0 +1,133 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else:
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)

View File

@@ -0,0 +1,187 @@
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n"
"```\n"
)
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
items = (
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
)
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
if not action_match:
return content
tool_name = action_match.group(1).strip()
tool_input = action_match.group(2).strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return tool_name, json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
raise NotImplementedError
@dataclass
class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
elements = []
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.tool_format == "default":
return [default_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [""]
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
if self.tool_format == "default":
return default_tool_extractor(content)
else:
raise NotImplementedError

View File

@@ -0,0 +1,170 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "file":
data_files = []
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.file_sha1)
else:
raise NotImplementedError
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
**kwargs,
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
return align_dataset(dataset, dataset_attr, data_args)
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off `streaming` when saving dataset to disk.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset

View File

@@ -0,0 +1,119 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
system: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
""" columns for the sharegpt format """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
tag_names = (
"role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr)
return dataset_list

View File

@@ -0,0 +1,276 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
logger = get_logger(__name__)
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
for source_ids, target_ids in template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i]
):
if data_args.train_on_prompt:
source_mask = source_ids
elif len(input_ids) != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
continue
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(
"labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
)
)
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function

View File

@@ -0,0 +1,773 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .formatter import SLOTS, Formatter
logger = get_logger(__name__)
@dataclass
class Template:
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts elements to token ids.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
def _make_pairs(
self,
encoded_messages: Sequence[List[int]],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
source_ids = encoded_messages[i][:max_source_len]
target_ids = encoded_messages[i + 1][:max_target_len]
total_length += len(source_ids) + len(target_ids)
encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
def _register_template(
name: str,
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
format_system: Optional["Formatter"] = None,
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
) -> None:
r"""
Registers a chat template.
To add the following chat template:
```
[HUMAN]:
user prompt here
[AI]:
model response here
[HUMAN]:
user prompt here
[AI]:
model response here
```
The corresponding code should be:
```
_register_template(
name="custom",
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
```
"""
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
)
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = ""
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if isinstance(template, Llama2Template):
pass
elif template.force_system:
jinja_template += "{{ " + system_message + " }}"
else:
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
return jinja_template
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
return template
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
efficient_eos=True,
)
_register_template(
name="atom",
format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
),
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True,
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
_register_template(
name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
)
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
)
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
stop_words=["<|EOT|>"],
efficient_eos=True,
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
stop_words=["<eoa>"],
efficient_eos=True,
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed "
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文."
),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
)
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
force_system=True,
)
_register_template(
name="vanilla",
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words=["<|End|>"],
)
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<eod>"],
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]),
)

View File

@@ -0,0 +1,94 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - max_target_len
return max_source_len, max_target_len
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -0,0 +1,4 @@
from .evaluator import Evaluator
__all__ = ["Evaluator"]

Some files were not shown because too many files have changed in this diff Show More