Compare commits

...

57 Commits
0.2.4 ... 0.2.6

Author SHA1 Message Date
zyxucp
d450efcffe Merge pull request #60 from AIDotNet/feature_kms
fix 修改提示词上限
2024-04-05 15:51:22 +08:00
zeyu xu
2a6c84c200 fix 修改提示词上限 2024-04-05 15:50:46 +08:00
zyxucp
138a952ace Merge pull request #59 from AIDotNet/feature_kms
Feature kms
2024-04-05 15:41:14 +08:00
zeyu xu
eb6528ecd2 add 修改message结构,减少localstore存储 2024-04-05 15:39:56 +08:00
zeyu xu
2c30bbfa09 fix 细节调整 2024-04-05 15:29:39 +08:00
zeyu xu
c5a78c2135 add modeldownchange 2024-04-05 15:12:29 +08:00
zeyu xu
f03362ee41 fix 修改dropdown Trigger.Click 2024-04-05 15:04:44 +08:00
zeyu xu
fad3167d97 add kms settings 2024-04-05 15:00:37 +08:00
zeyu xu
ad949681dd add change 2024-04-05 14:41:58 +08:00
zeyu xu
27999d76b0 fix 修改知识库函数 2024-04-05 14:26:35 +08:00
zeyu xu
83278352d6 add kms 配置 2024-04-05 14:12:25 +08:00
zeyu xu
fcc56f5fef fix bge embedding 无法切片问题 2024-04-04 00:37:08 +08:00
zyxucp
4ebe2ecc32 fix 修改初始化,增加完成标识 2024-04-02 13:53:32 +08:00
zeyu xu
e684cba527 Merge branch 'main' of github.com:AIDotNet/AntSK 2024-04-02 13:34:38 +08:00
zeyu xu
888dc19ee0 fix bgeembedding 2024-04-02 13:34:24 +08:00
zyxucp
731aea702f fix 修改提示词 2024-04-02 11:17:53 +08:00
zyxucp
09e22bc76a Update README.md 2024-04-02 00:07:08 +08:00
zyxucp
74406d88a0 Merge pull request #58 from AIDotNet/feature_StableDiffusion
fix 修改为静态类
2024-04-01 23:57:12 +08:00
zeyu xu
e5f9d97560 fix 修改为静态类 2024-04-01 23:56:44 +08:00
zyxucp
59e768aaea Merge pull request #57 from AIDotNet/feature_StableDiffusion
Feature stable diffusion
2024-04-01 23:39:22 +08:00
zeyu xu
6a7cb24a5b add sd 2024-04-01 23:08:53 +08:00
zeyu xu
1db40d534c add apptype 2024-04-01 22:14:18 +08:00
zeyu xu
11d6e30f7e add sd function 2024-04-01 22:03:00 +08:00
zeyu xu
9d5214aaae add sdmodel 2024-04-01 21:57:18 +08:00
zeyu xu
010b906271 add sd 2024-04-01 21:35:51 +08:00
zeyu xu
16bf944edf add sd 2024-04-01 21:31:15 +08:00
zeyu xu
5bae5a099a margin 2024-04-01 21:01:29 +08:00
zyxucp
f771ea9521 Merge branch 'main' of https://github.com/AIDotNet/AntSK 2024-04-01 13:54:53 +08:00
zyxucp
994efbf37c update nuget 2024-04-01 13:54:20 +08:00
zyxucp
938cd86c88 Update README.md 2024-03-31 13:24:21 +08:00
zeyu xu
1339cbadbc fix 修改menukey 2024-03-31 13:07:30 +08:00
zeyu xu
bd0ad570ad add 增加使用文档 2024-03-31 13:07:08 +08:00
zeyu xu
234e649a7e fix 优化部分内容 2024-03-31 12:38:17 +08:00
zyxucp
c431dbc842 Update README.md 2024-03-31 00:28:16 +08:00
zyxucp
76283060d9 Update docker-compose.simple.yml 2024-03-30 23:28:52 +08:00
zyxucp
75ba506db4 Update docker-compose.yml 2024-03-30 23:28:33 +08:00
zeyu xu
0c8ad5fe8d add loadding 2024-03-30 19:50:29 +08:00
zeyu xu
68ce0db011 fix 样式修改 2024-03-30 17:35:40 +08:00
zeyu xu
c36de1a1e9 add 选项控制 2024-03-30 17:25:58 +08:00
zeyu xu
44ef759abd fix 修改控件 2024-03-30 14:47:29 +08:00
longdream
0c3d9844be Merge pull request #52 from longdream/main
bge embedding模型添加,bge用的CPU。
2024-03-29 21:51:35 +08:00
longdream
854c62a4ca 合并 2024-03-29 21:50:17 +08:00
longdream
5ed4fd5299 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 20:00:53 +08:00
longdream
af5ec43571 修改设置界面 2024-03-29 20:00:49 +08:00
junlong
d7b56d1590 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-29 15:34:08 +08:00
longdream
b925f8890b 修改token长度 2024-03-28 23:06:21 +08:00
longdream
5d80ee994a 解决线程冲突问题 2024-03-28 19:04:11 +08:00
longdream
f73bd2dfda 增减embedding 2024-03-27 22:53:45 +08:00
longdream
f340ee1088 embedding封装 2024-03-26 23:14:55 +08:00
longdream
edad2644aa 删除没必要的py文件 2024-03-26 20:48:49 +08:00
longdream
8a56a0393a Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-26 20:48:07 +08:00
junlong
bd5ca06d8f test 2024-03-25 16:55:41 +08:00
junlong
e0985ecec3 Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-25 16:48:21 +08:00
junlong
e56b74d4af 删除chat以外的文件 2024-03-25 16:48:11 +08:00
longdream
849b18f677 Merge branch 'AIDotNet:main' into main 2024-03-22 19:36:20 +08:00
junlong
344128e49d Merge branch 'main' of https://github.com/longdream/AntSK 2024-03-21 19:38:03 +08:00
junlong
56fc9dd517 test 2024-03-21 19:37:56 +08:00
53 changed files with 1317 additions and 256 deletions

View File

@@ -10,6 +10,8 @@
- **知识库**通过文档Word、PDF、Excel、Txt、Markdown、Json、PPT等形式导入知识库可以进行知识库问答。
- **文生图**:集成**StableDiffusion** 本地模型,可以进行文生图。
- **GPTs 生成**此平台支持创建个性化的GPT模型尝试构建您自己的GPT模型。
- **API接口发布**将内部功能以API的形式对外提供便于开发者将AntSK 集成进其他应用,增强应用智慧。
@@ -54,6 +56,8 @@ https://antsk.ai-dotnet.com/
### 其他功能示例
[视频示例](https://www.bilibili.com/video/BV1zH4y1h7Y9/)
[在线文档http://antsk.cn](http://antsk.cn)
## ❓如何开始?
在这里我使用的是Postgres 作为数据存储和向量存储因为Semantic Kernel和Kernel Memory都支持他当然你也可以换成其他的。
@@ -173,7 +177,7 @@ DB我使用的是CodeFirst模式只要配置好数据库链接表结构是
## ✔使用llamafactory
```
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如v0.2.3.2版本已经包含了 python全套环境则无需此步骤
1、首先需要确保你的环境已经安装了python和pip如果使用镜像例如p0.2.4版本已经包含了 python全套环境则无需此步骤
2、进入模型添加页面选择llamafactory
3、点击初始化可以检查pip install 环境是否完成
4、选择一个喜欢的模型

View File

@@ -3,9 +3,9 @@ version: '3.8'
services:
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.4
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3.2
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.2.4
ports:
- 5000:5000
networks:

View File

@@ -18,9 +18,9 @@ services:
- ./pg/data:/var/lib/postgresql/data
antsk:
container_name: antsk
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3
image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.4
# 如果需要pytorch环境需要使用下面这个镜像镜像比较大
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:v0.2.3.2
# image: registry.cn-hangzhou.aliyuncs.com/xuzeyu91/antsk:p0.2.4
ports:
- 5000:5000
networks:

View File

@@ -13,13 +13,15 @@
<PackageReference Include="BlazorComponents.Terminal" Version="0.6.0" />
<PackageReference Include="Blazored.LocalStorage" Version="4.5.0" />
<PackageReference Include="pythonnet" Version="3.0.3" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
<PackageReference Include="AutoMapper" Version="8.1.0" />
<PackageReference Include="BCrypt.Net-Next" Version="4.0.3" />
<PackageReference Include="Markdig" Version="0.36.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.148" />
<PackageReference Include="SqlSugarCore" Version="5.1.4.149" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.118" />
<PackageReference Include="RestSharp" Version="110.2.0" />
@@ -32,11 +34,11 @@
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Redis" Version="0.35.240321.1" />
<PackageReference Include="Microsoft.KernelMemory.MemoryDb.AzureAISearch" Version="0.35.240321.1" />
<PackageReference Include="LLamaSharp" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.10.0" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.10.0" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.10.0" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.10.0" />
<PackageReference Include="LLamaSharp" Version="0.11.1" />
<PackageReference Include="LLamaSharp.Backend.Cpu" Version="0.11.1" />
<PackageReference Include="LLamaSharp.Backend.Cuda12" Version="0.11.1" />
<PackageReference Include="LLamaSharp.kernel-memory" Version="0.11.1" />
<PackageReference Include="LLamaSharp.semantic-kernel" Version="0.11.1" />
</ItemGroup>

View File

@@ -99,6 +99,11 @@
总数
</summary>
</member>
<member name="M:AntSK.Domain.Domain.Other.EmbeddingConfig.LoadModel(System.String,System.String)">
<summary>
模型写死
</summary>
</member>
<member name="F:AntSK.Domain.Domain.Other.LLamaConfig.dicLLamaWeights">
<summary>
避免模型重复加载,本地缓存
@@ -287,6 +292,26 @@
API调用秘钥
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.Relevance">
<summary>
相似度
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.MaxAskPromptSize">
<summary>
提问最大token数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.MaxMatchesCount">
<summary>
向量匹配数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Apps.AnswerTokens">
<summary>
回答最大token数
</summary>
</member>
<member name="P:AntSK.Domain.Repositories.Funs.Path">
<summary>
接口描述
@@ -771,6 +796,14 @@
<param name="parameters"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.ConvertUtils.ComparisonIgnoreCase(System.String,System.String)">
<summary>
忽略大小写匹配
</summary>
<param name="s"></param>
<param name="value"></param>
<returns></returns>
</member>
<member name="M:AntSK.Domain.Utils.RepoFiles.SamplePluginsPath">
<summary>
Scan the local folders from the repo, looking for "samples/plugins" folder.

View File

@@ -0,0 +1,21 @@
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Common.Embedding
{
public static class BuilderBgeExtensions
{
public static IKernelMemoryBuilder WithBgeTextEmbeddingGeneration(this IKernelMemoryBuilder builder, HuggingfaceTextEmbeddingGenerator textEmbeddingGenerator)
{
builder.AddSingleton((ITextEmbeddingGenerator)textEmbeddingGenerator);
builder.AddIngestionEmbeddingGenerator(textEmbeddingGenerator);
return builder;
}
}
}

View File

@@ -0,0 +1,56 @@
using LLama.Common;
using LLama;
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory.AI;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AntSK.Domain.Domain.Other;
namespace AntSK.Domain.Common.Embedding
{
public class HuggingfaceTextEmbeddingGenerator : ITextEmbeddingGenerator, ITextTokenizer, IDisposable
{
public int MaxTokens => 1024;
public int MaxTokenTotal => 1024;
private readonly dynamic _embedder;
public HuggingfaceTextEmbeddingGenerator(string pyDllPath,string modelName)
{
_embedder = EmbeddingConfig.LoadModel(pyDllPath, modelName);
}
public void Dispose()
{
EmbeddingConfig.Dispose();
}
//public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingAsync(IList<string> data, CancellationToken cancellationToken = default)
//{
// IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();
// foreach (var d in data)
// {
// var embeddings = await EmbeddingConfig.GetEmbedding(d);
// results.Add(new ReadOnlyMemory<float>(embeddings));
// }
// return results;
//}
public async Task<Microsoft.KernelMemory.Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = await EmbeddingConfig.GetEmbedding(text);
return new Microsoft.KernelMemory.Embedding(embeddings);
}
public int CountTokens(string text)
{
return EmbeddingConfig.TokenCount(text);
}
}
}

View File

@@ -5,6 +5,7 @@ using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@@ -16,7 +17,7 @@ namespace AntSK.Domain.Domain.Interface
IAsyncEnumerable<StreamingKernelContent> SendChatByAppAsync(Apps app, string questions, ChatHistory history);
IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null);
Task<string> SendImgByAppAsync(Apps app, string questions);
Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList);
}
}

View File

@@ -7,13 +7,13 @@ namespace AntSK.Domain.Domain.Interface
{
public interface IKMService
{
MemoryServerless GetMemory(Apps app);
MemoryServerless GetMemoryByApp(Apps app);
MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null);
MemoryServerless GetMemoryByKMS(string kmsID);
Task<List<KMFile>> GetDocumentByFileID(string kmsId, string fileId);
Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg);
Task<List<RelevantSource>> GetRelevantSourceList(Apps app, string msg);
List<UploadFileItem> FileList { get; }

View File

@@ -11,5 +11,22 @@ namespace AntSK.Domain.Domain.Model.Constant
public const string KmsIdTag = "kmsid";
public const string KmsIndex = "kms";
public const string KmsSearchNull="知识库未搜索到相关内容";
public const string KmsPrompt = @"使用<data></data>标记的内容作为你的知识:
<data>
{{$doc}}
</data>
--------------------------
回答要求:
- 如果你不清楚答案,你需要澄清
- 避免提及你是从<data></data>获取的知识
- 保持答案与<data></data>众描述一致
- 使用Markdown语法优化回答格式。
- 如果Markdown有图片则正常显示
--------------------------
历史聊天记录:{{ConversationSummaryPlugin.SummarizeConversation $history}}
--------------------------
用户问题: {{$input}}";
}
}

View File

@@ -21,12 +21,16 @@ namespace AntSK.Domain.Domain.Model.Enum
[Display(Name = "灵积大模型")]
DashScope = 5,
[Display(Name = "LLamaFactory")]
LLamaFactory = 6,
[Display(Name = "Bge Embedding")]
BgeEmbedding = 7,
[Display(Name = "StableDiffusion")]
StableDiffusion = 8,
[Display(Name = "模拟输出")]
Mock = 100,
}
/// <summary>
@@ -36,5 +40,6 @@ namespace AntSK.Domain.Domain.Model.Enum
{
Chat = 1,
Embedding = 2,
Image=3,
}
}

View File

@@ -9,6 +9,7 @@ namespace AntSK.Domain.Domain.Model.Enum
public enum AppType
{
chat = 1,
kms = 2
kms = 2,
img=3
}
}

View File

@@ -4,7 +4,6 @@
{
public string ID { get; set; } = "";
public string Context { get; set; } = "";
public string HtmlAnswers { get; set; } = "";
/// <summary>
/// 发送是true 接收是false
@@ -13,8 +12,6 @@
public DateTime CreateTime { get; set; }
public string? FilePath { get; set; }
public string? FileName { get; set; }
}
}

View File

@@ -0,0 +1,95 @@
using Microsoft.KernelMemory.AI.OpenAI.GPT3;
using Python.Runtime;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Python.Runtime.Py;
namespace AntSK.Domain.Domain.Other
{
public static class EmbeddingConfig
{
public static dynamic model { get; set; }
static object lockobj = new object();
/// <summary>
/// 模型写死
/// </summary>
public static dynamic LoadModel(string pythondllPath, string modelName)
{
lock (lockobj)
{
if (model == null)
{
//Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
Runtime.PythonDLL = pythondllPath;
PythonEngine.Initialize();
PythonEngine.BeginAllowThreads();
try
{
using (Py.GIL())// 初始化Python环境的Global Interpreter Lock)
{
dynamic modelscope = Py.Import("modelscope");
//dynamic model_dir = modelscope.snapshot_download("AI-ModelScope/bge-large-zh-v1.5", revision: "master");
dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master");
dynamic HuggingFaceBgeEmbeddingstemp = Py.Import("langchain_community.embeddings.huggingface");
dynamic HuggingFaceBgeEmbeddings = HuggingFaceBgeEmbeddingstemp.HuggingFaceBgeEmbeddings;
string model_name = model_dir;
dynamic model_kwargs = new PyDict();
model_kwargs["device"] = new PyString("cpu");
dynamic hugginmodel = HuggingFaceBgeEmbeddings(
model_name: model_dir,
model_kwargs: model_kwargs
);
model = hugginmodel;
return hugginmodel;
}
}
catch(Exception ex)
{
throw ex;
}
}
else
return model;
}
}
public static Task<float[]> GetEmbedding(string queryStr)
{
using (Py.GIL())
{
PyObject queryResult = model.embed_query(queryStr);
var floatList = queryResult.As<float[]>();
return Task.FromResult(floatList); ;
}
}
public static int TokenCount(string queryStr)
{
//using (Py.GIL())
//{
// PyObject queryResult = model.client.tokenize(queryStr);
// // 使用Python的内置len()函数获取长度
// PyObject lenFunc = Py.Import("builtins").GetAttr("len");
// PyObject length = lenFunc.Invoke(queryResult["input_ids"]);
// int len = length.As<int>(); // 将PyObject转换为C#中的整数
// return len;
//}
var tokenCount1 = GPT3Tokenizer.Encode(queryStr).Count;
return tokenCount1;
}
public static void Dispose()
{
Console.WriteLine("python dispose");
}
}
}

View File

@@ -16,6 +16,8 @@ using ChatHistory = Microsoft.SemanticKernel.ChatCompletion.ChatHistory;
using Microsoft.SemanticKernel.Plugins.Core;
using Azure.Core;
using AntSK.Domain.Domain.Model;
using AntSK.LLM.StableDiffusion;
using System.Drawing;
namespace AntSK.Domain.Domain.Service
{
@@ -23,7 +25,8 @@ namespace AntSK.Domain.Domain.Service
public class ChatService(
IKernelService _kernelService,
IKMService _kMService,
IKmsDetails_Repositories _kmsDetails_Repositories
IKmsDetails_Repositories _kmsDetails_Repositories,
IAIModels_Repositories _aIModels_Repositories
) : IChatService
{
/// <summary>
@@ -77,11 +80,12 @@ namespace AntSK.Domain.Domain.Service
public async IAsyncEnumerable<StreamingKernelContent> SendKmsByAppAsync(Apps app, string questions, ChatHistory history, string filePath, List<RelevantSource> relevantSources = null)
{
var relevantSourceList = await _kMService.GetRelevantSourceList(app.KmsIdList, questions);
relevantSources?.Clear();
var relevantSourceList = await _kMService.GetRelevantSourceList(app, questions);
var _kernel = _kernelService.GetKernelByApp(app);
if (!string.IsNullOrWhiteSpace(filePath))
{
var memory = _kMService.GetMemory(app);
var memory = _kMService.GetMemoryByApp(app);
var fileId = Guid.NewGuid().ToString();
var result = await memory.ImportDocumentAsync(new Microsoft.KernelMemory.Document(fileId).AddFile(filePath)
.AddTag(KmsConstantcs.KmsIdTag, app.Id)
@@ -101,20 +105,43 @@ namespace AntSK.Domain.Domain.Service
var dataMsg = new StringBuilder();
if (relevantSourceList.Any())
{
bool isSearch=false;
foreach (var item in relevantSourceList)
{
//匹配相似度
if (item.Relevance >= app.Relevance/100)
{
dataMsg.AppendLine(item.ToString());
isSearch=true;
}
}
//处理markdown显示
relevantSources?.AddRange(relevantSourceList);
foreach (var item in relevantSourceList)
{
dataMsg.AppendLine(item.ToString());
item.Text = Markdown.ToHtml(item.Text);
}
KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var chatResult = _kernel.InvokeStreamingAsync(function: jsonFun,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["questions"] = questions });
await foreach (var content in chatResult)
if (isSearch)
{
yield return content;
//KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var temperature = app.Temperature / 100;//存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt(app.Prompt, settings);
var chatResult = _kernel.InvokeStreamingAsync(function: func,
arguments: new KernelArguments() { ["doc"] = dataMsg.ToString(), ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["input"] = questions });
await foreach (var content in chatResult)
{
yield return content;
}
}
else
{
yield return new StreamingTextContent(KmsConstantcs.KmsSearchNull);
}
}
else
{
@@ -122,6 +149,58 @@ namespace AntSK.Domain.Domain.Service
}
}
public async Task<string> SendImgByAppAsync(Apps app, string questions)
{
var imageModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ImageModelID);
KernelArguments args = new() {
{ "input", questions }
};
var _kernel = _kernelService.GetKernelByApp(app);
var temperature = app.Temperature / 100; //存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt("Translate this into English:{{$input}}", settings);
var chatResult = await _kernel.InvokeAsync(function: func, arguments: args);
if (chatResult.IsNotNull())
{
string prompt = chatResult.GetValue<string>();
if (!SDHelper.IsInitialized)
{
Structs.ModelParams modelParams = new Structs.ModelParams
{
ModelPath = imageModel.ModelName,
RngType = Structs.RngType.CUDA_RNG,
//VaePath = vaePath,
//KeepVaeOnCpu = keepVaeOnCpu,
//VaeTiling = vaeTiling,
//LoraModelDir = loraModelDir,
};
bool result = SDHelper.Initialize(modelParams);
}
Structs.TextToImageParams textToImageParams = new Structs.TextToImageParams
{
Prompt = prompt,
NegativePrompt = "2d, 3d, cartoon, paintings",
SampleMethod = (Structs.SampleMethod)Enum.Parse(typeof(Structs.SampleMethod), "EULER_A"),
Width = 256,
Height = 256,
NormalizeInput = true,
ClipSkip = -1,
CfgScale = 7,
SampleSteps = 20,
Seed = -1,
};
Bitmap[] outputImages = SDHelper.TextToImage(textToImageParams);
var base64 = ImageUtils.BitmapToBase64(outputImages[0]);
return base64;
}
else
{
return "";
}
}
public async Task<ChatHistory> GetChatHistory(List<MessageInfo> MessageList)
{
ChatHistory history = new ChatHistory();

View File

@@ -1,5 +1,6 @@
using AntDesign;
using AntSK.Domain.Common.DependencyInjection;
using AntSK.Domain.Common.Embedding;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
@@ -35,7 +36,7 @@ namespace AntSK.Domain.Domain.Service
public List<UploadFileItem> FileList => _fileList;
public MemoryServerless GetMemory(Apps app)
public MemoryServerless GetMemoryByApp(Apps app)
{
var chatModel = _aIModels_Repositories.GetFirst(p => p.Id == app.ChatModelID);
var embedModel = _aIModels_Repositories.GetFirst(p => p.Id == app.EmbeddingModelID);
@@ -44,9 +45,9 @@ namespace AntSK.Domain.Domain.Service
var searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
MaxAskPromptSize = app.MaxAskPromptSize,
MaxMatchesCount = app.MaxMatchesCount,
AnswerTokens = app.AnswerTokens,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
@@ -70,7 +71,7 @@ namespace AntSK.Domain.Domain.Service
return _memory;
}
public MemoryServerless GetMemoryByKMS(string kmsID, SearchClientConfig searchClientConfig = null)
public MemoryServerless GetMemoryByKMS(string kmsID)
{
//if (_memory.IsNull())
{
@@ -84,19 +85,19 @@ namespace AntSK.Domain.Domain.Service
var embeddingHttpClient = OpenAIHttpClientHandlerUtil.GetHttpClient(embedModel.EndPoint);
//搜索配置
if (searchClientConfig.IsNull())
{
searchClientConfig = new SearchClientConfig
{
MaxAskPromptSize = 2048,
MaxMatchesCount = 3,
AnswerTokens = 1000,
EmptyAnswer = KmsConstantcs.KmsSearchNull
};
}
//if (searchClientConfig.IsNull())
//{
// searchClientConfig = new SearchClientConfig
// {
// MaxAskPromptSize = 2048,
// MaxMatchesCount = 3,
// AnswerTokens = 1000,
// EmptyAnswer = KmsConstantcs.KmsSearchNull
// };
//}
var memoryBuild = new KernelMemoryBuilder()
.WithSearchClientConfig(searchClientConfig)
//.WithSearchClientConfig(searchClientConfig)
.WithCustomTextPartitioningOptions(new TextPartitioningOptions
{
MaxTokensPerLine = kms.MaxTokensPerLine,
@@ -147,6 +148,11 @@ namespace AntSK.Domain.Domain.Service
var embedder = new LLamaEmbedder(weights, parameters);
memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
break;
case Model.Enum.AIType.BgeEmbedding:
string pyDll = embedModel.EndPoint;
string bgeEmbeddingModelName = embedModel.ModelName;
memory.WithBgeTextEmbeddingGeneration(new HuggingfaceTextEmbeddingGenerator(pyDll,bgeEmbeddingModelName));
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeDefaults(embedModel.ModelKey);
break;
@@ -183,6 +189,14 @@ namespace AntSK.Domain.Domain.Service
var executor = new StatelessExecutor(weights, parameters);
memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
break;
case Model.Enum.AIType.LLamaFactory:
memory.WithOpenAITextGeneration(new OpenAIConfig()
{
APIKey = "123",
TextModel = chatModel.ModelName
}, null, chatHttpClient);
break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeTextGeneration(new Cnblogs.KernelMemory.AI.DashScope.DashScopeConfig
{
@@ -263,15 +277,15 @@ namespace AntSK.Domain.Domain.Service
return docTextList;
}
public async Task<List<RelevantSource>> GetRelevantSourceList(string kmsIdListStr, string msg)
public async Task<List<RelevantSource>> GetRelevantSourceList(Apps app ,string msg)
{
var result = new List<RelevantSource>();
if (string.IsNullOrWhiteSpace(kmsIdListStr))
if (string.IsNullOrWhiteSpace(app.KmsIdList))
return result;
var kmsIdList = kmsIdListStr.Split(",");
var kmsIdList = app.KmsIdList.Split(",");
if (!kmsIdList.Any()) return result;
var memory = GetMemoryByKMS(kmsIdList.FirstOrDefault()!);
var memory = GetMemoryByApp(app);
var filters = kmsIdList.Select(kmsId => new MemoryFilter().ByTag(KmsConstantcs.KmsIdTag, kmsId)).ToList();
@@ -283,7 +297,7 @@ namespace AntSK.Domain.Domain.Service
result.AddRange(item.Partitions.Select(part => new RelevantSource()
{
SourceName = item.SourceName,
Text = Markdown.ToHtml(part.Text),
Text = part.Text,
Relevance = part.Relevance
}));
}

View File

@@ -8,6 +8,7 @@ using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.Linq;
using System.Text;
using System.Text.Json;
@@ -67,7 +68,9 @@ namespace AntSK.Domain.Domain.Service
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
OnLogMessageReceived("--------------------完成--------------------");
}, TaskCreationOptions.LongRunning);
await cmdTask;
}
public async Task StartLLamaFactory(string modelName, string templateName)
@@ -106,7 +109,10 @@ namespace AntSK.Domain.Domain.Service
process.BeginOutputReadLine();
process.BeginErrorReadLine();
process.WaitForExit();
OnLogMessageReceived("--------------------完成--------------------");
}, TaskCreationOptions.LongRunning);
await cmdTask;
}
private void Process_OutputDataReceived(object sender, DataReceivedEventArgs e)

View File

@@ -44,6 +44,7 @@ namespace AntSK.Domain.Repositories
/// </summary>
public string? EmbeddingModelID { get; set; }
public string? ImageModelID { get; set; }
/// <summary>
/// 温度
/// </summary>
@@ -53,6 +54,7 @@ namespace AntSK.Domain.Repositories
/// <summary>
/// 提示词
/// </summary>
[SugarColumn(ColumnDataType = "varchar(2000)")]
public string? Prompt { get; set; }
/// <summary>
@@ -76,5 +78,28 @@ namespace AntSK.Domain.Repositories
/// API调用秘钥
/// </summary>
public string? SecretKey { get; set; }
/// <summary>
/// 相似度
/// </summary>
[SugarColumn(DefaultValue = "70")]
public double Relevance { get; set; } = 70;
/// <summary>
/// 提问最大token数
/// </summary>
[SugarColumn(DefaultValue = "2048")]
public int MaxAskPromptSize { get; set; } = 2048;
/// <summary>
/// 向量匹配数
/// </summary>
[SugarColumn(DefaultValue = "3")]
public int MaxMatchesCount { get; set; } = 3;
/// <summary>
/// 回答最大token数
/// </summary>
[SugarColumn(DefaultValue = "2048")]
public int AnswerTokens { get; set; } = 2048;
}
}

View File

@@ -250,5 +250,16 @@ namespace AntSK.Domain.Utils
return nameValueCollection.ToString();
}
/// <summary>
/// 忽略大小写匹配
/// </summary>
/// <param name="s"></param>
/// <param name="value"></param>
/// <returns></returns>
public static bool ComparisonIgnoreCase(this string s, string value)
{
return s.Equals(value, StringComparison.OrdinalIgnoreCase);
}
}
}

View File

@@ -0,0 +1,39 @@
using System;
using System.Collections.Generic;
using System.Drawing.Imaging;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.Domain.Utils
{
public class ImageUtils
{
public static string BitmapToBase64(Bitmap bitmap)
{
using (MemoryStream memoryStream = new MemoryStream())
{
// 保存为JPEG格式也可以选择PngGif等等
bitmap.Save(memoryStream, ImageFormat.Jpeg);
// 获取内存流的字节数组
byte[] imageBytes = memoryStream.ToArray();
// 将字节转换为Base64字符串
string base64String = Convert.ToBase64String(imageBytes);
return base64String;
}
}
public static List<string> BitmapListToBase64(Bitmap[] bitmaps)
{
List<string> base64Strings = new List<string>();
foreach (Bitmap bitmap in bitmaps)
{
base64Strings.Add(BitmapToBase64(bitmap));
}
return base64Strings;
}
}
}

View File

@@ -1,27 +0,0 @@
import subprocess
import shlex
import os
class Start(object):
def __init__(self,model_name_or_path):
self.model_name_or_path=model_name_or_path
def StartCommand(self):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['API_PORT'] = '8000'
# 构建要执行的命令
command = (
'python api_demo.py'
' --model_name_or_path E:/model/Qwen1.5-0.5B-Chat_back'
' --template default '
)
# 使用shlex.split()去安全地分割命令字符串
command = shlex.split(command)
# 执行命令
subprocess.run(command, shell=True)
if __name__ == "__main__":
star= Start('model_name_or_path')
star.StartCommand()

View File

@@ -1,49 +0,0 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
def main():
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()

View File

@@ -1,10 +0,0 @@
from llmtuner import Evaluator
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
main()

View File

@@ -1,9 +0,0 @@
from llmtuner import export_model
def main():
export_model()
if __name__ == "__main__":
main()

View File

@@ -1,14 +0,0 @@
from llmtuner import run_exp
def main():
run_exp()
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@@ -1,11 +0,0 @@
from llmtuner import create_ui
def main():
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@@ -1,11 +0,0 @@
from llmtuner import create_web_demo
def main():
demo = create_web_demo()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,4 @@
torch>=1.13.1
torch>=1.13.1 --index-url https://download.pytorch.org/whl/cu121
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.27.2
@@ -16,3 +16,5 @@ sse-starlette
matplotlib
fire
modelscope
langchain-community
sentence_transformers

View File

@@ -0,0 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
</Project>

View File

@@ -0,0 +1,7 @@
namespace AntSK.PyNet
{
public class Class1
{
}
}

View File

@@ -7,6 +7,13 @@
<DocumentationFile>AntSK.xml</DocumentationFile>
<NoWarn>CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0010,SKEXP0011,,SKEXP0012,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0040,SKEXP0041,SKEXP0042,SKEXP0050,SKEXP0051,SKEXP0052,SKEXP0053,SKEXP0054,SKEXP0055,SKEXP0060,SKEXP0061,SKEXP0101,SKEXP0102</NoWarn>
</PropertyGroup>
<ItemGroup>
<Compile Remove="llamafactory\**" />
<Content Remove="llamafactory\**" />
<EmbeddedResource Remove="llamafactory\**" />
<None Remove="llamafactory\**" />
</ItemGroup>
<ItemGroup>

View File

@@ -35,7 +35,6 @@ namespace AntSK.Controllers
[HttpPost]
public async Task<IActionResult> ImportKMSTask(ImportKMSTaskDTO model)
{
Console.WriteLine("api/kms/ImportKMSTask 开始");
ImportKMSTaskReq req = model.ToDTO<ImportKMSTaskReq>();
KmsDetails detail = new KmsDetails()
{
@@ -49,7 +48,6 @@ namespace AntSK.Controllers
await _kmsDetailsRepositories.InsertAsync(detail);
req.KmsDetail = detail;
_taskBroker.QueueWorkItem(req);
Console.WriteLine("api/kms/ImportKMSTask 结束");
return Ok();
}
}

View File

@@ -24,45 +24,57 @@
<IconPicker @bind-Value="@context.Icon"></IconPicker>
</FormItem>
<FormItem Label="类型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<RadioGroup @bind-Value="context.Type">
<RadioGroup @bind-Value="context.Type" OnChange="OnAppTypeChange" TValue="string">
<Radio RadioButton Value="@AppType.chat.ToString()">会话应用</Radio>
<Radio RadioButton Value="@AppType.kms.ToString()">知识库</Radio>
</RadioGroup>
</FormItem>
<FormItem Label="描述" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Input Placeholder="请输入描述" @bind-Value="@context.Describe" />
</FormItem>
<Radio RadioButton Value="@AppType.kms.ToString()">知识库</Radio>
<Radio RadioButton Value="@AppType.img.ToString()">做图应用</Radio>
</RadioGroup>
</FormItem>
<FormItem Label="描述" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Input Placeholder="请输入描述" @bind-Value="@context.Describe" />
</FormItem>
<FormItem Label="会话模型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@_chatList"
@bind-Value="@context.ChatModelID"
ValueProperty="c=>c.Id"
LabelProperty="c=>'【'+c.AIType.ToString()+'】'+c.ModelDescription">
</Select>
<Button Type="@ButtonType.Link" OnClick="NavigateModelList">去创建</Button>
</FormItem>
<FormItem Label="向量模型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@_embedignList"
@bind-Value="@context.EmbeddingModelID"
<FormItem Label="会话模型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@_chatList"
@bind-Value="@context.ChatModelID"
ValueProperty="c=>c.Id"
LabelProperty="c=>'【'+c.AIType.ToString()+'】'+c.ModelDescription">
</Select>
<Button Type="@ButtonType.Link" OnClick="NavigateModelList">去创建</Button>
</FormItem>
@if (@context.Type == AppType.chat.ToString())
@if (@context.Type != AppType.img.ToString())
{
<FormItem Label="向量模型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@_embedingList"
@bind-Value="@context.EmbeddingModelID"
ValueProperty="c=>c.Id"
LabelProperty="c=>'【'+c.AIType.ToString()+'】'+c.ModelDescription">
</Select>
<Button Type="@ButtonType.Link" OnClick="NavigateModelList">去创建</Button>
</FormItem>
<FormItem Label="提示词" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<TextArea MinRows="4" Placeholder="请输入提示词,用户输入使用{{$input}} 来做占位符" @bind-Value="@context.Prompt" />
</FormItem>
<FormItem Label="温度系数" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<span>更确定</span>
<Slider TValue="double" Style="display: inline-block;width: 300px; " Min="0" Max="100" DefaultValue="70" @bind-Value="@context.Temperature" />
<span>更发散</span>
</FormItem>
<FormItem Label="API插件列表" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
}
else
{
<FormItem Label="图片模型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@_imageList"
@bind-Value="@context.ImageModelID"
ValueProperty="c=>c.Id"
LabelProperty="c=>'【'+c.AIType.ToString()+'】'+c.ModelDescription">
</Select>
<Button Type="@ButtonType.Link" OnClick="NavigateModelList">去创建</Button>
</FormItem>
}
@if (@context.Type == AppType.chat.ToString())
{
<FormItem Label="API插件列表" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select Mode="multiple"
@bind-Values="apiIds"
Placeholder="选择API插件, 选择后会开启自动调用"
@@ -110,6 +122,19 @@
}
</SelectOptions>
</Select>
<Button Type="@ButtonType.Link" OnClick="NavigateKmsList">去创建</Button>
</FormItem>
<FormItem Label="提问最大token数" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<AntDesign.InputNumber @bind-Value="context.MaxAskPromptSize" PlaceHolder="提问最大token数"></AntDesign.InputNumber>
</FormItem>
<FormItem Label="回答最大token数" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<AntDesign.InputNumber @bind-Value="context.AnswerTokens" PlaceHolder="回答最大token数"></AntDesign.InputNumber>
</FormItem>
<FormItem Label="向量匹配数" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<AntDesign.InputNumber @bind-Value="context.MaxMatchesCount" PlaceHolder="向量匹配数"></AntDesign.InputNumber>
</FormItem>
<FormItem Label="最低相似度(%" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Slider TValue="double" Style="display: inline-block;width: 300px; " Min="0" Max="100" DefaultValue="70" @bind-Value="@context.Relevance" />
</FormItem>
}
<FormItem Label=" " Style="margin-top:32px" WrapperCol="LayoutModel._submitFormLayout.WrapperCol">

View File

@@ -1,4 +1,5 @@
using AntDesign;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Enum;
using AntSK.Domain.Domain.Service;
using AntSK.Domain.Repositories;
@@ -45,7 +46,8 @@ namespace AntSK.Pages.AppPage
public Dictionary<string, string> _funList = new Dictionary<string, string>();
private List<AIModels> _chatList;
private List<AIModels> _embedignList;
private List<AIModels> _embedingList;
private List<AIModels> _imageList;
protected override async Task OnInitializedAsync()
{
await base.OnInitializedAsync();
@@ -53,7 +55,8 @@ namespace AntSK.Pages.AppPage
_apiList = _apis_Repositories.GetList();
var models=_aimodels_Repositories.GetList();
_chatList = models.Where(p => p.AIModelType == AIModelType.Chat).ToList();
_embedignList = models.Where(p => p.AIModelType == AIModelType.Embedding).ToList();
_embedingList = models.Where(p => p.AIModelType == AIModelType.Embedding).ToList();
_imageList = models.Where(p => p.AIModelType == AIModelType.Image).ToList();
_functionService.SearchMarkedMethods();
foreach (var func in _functionService.Functions)
@@ -87,7 +90,14 @@ namespace AntSK.Pages.AppPage
}
_appModel.KmsIdList = string.Join(",", kmsIds);
}
if (_appModel.Type == AppType.kms.ToString())
{
if (string.IsNullOrEmpty(_appModel.Prompt)|| !_appModel.Prompt.Contains("{{$doc}}") || !_appModel.Prompt.Contains("{{$input}}"))
{
_ = Message.Error("知识库提示词必须包含 {{$doc}} 和 {{$input}}", 2);
return;
}
}
if (apiIds.IsNotNull())
{
_appModel.ApiFunctionList = string.Join(",", apiIds);
@@ -97,7 +107,7 @@ namespace AntSK.Pages.AppPage
_appModel.NativeFunctionList = string.Join(",", funIds);
}
if (string.IsNullOrEmpty(AppId))
{
//新增
@@ -133,6 +143,25 @@ namespace AntSK.Pages.AppPage
{
NavigationManager.NavigateTo("/setting/modellist");
}
private void NavigateKmsList()
{
NavigationManager.NavigateTo("/KmsList");
}
private void OnAppTypeChange(string value)
{
if (value == AppType.kms.ToString() && string.IsNullOrEmpty( _appModel.Prompt))
{
_appModel.Prompt = KmsConstantcs.KmsPrompt;
}
if (value == AppType.chat.ToString())
{
_appModel.Prompt = "";
}
}
}
}

View File

@@ -58,6 +58,10 @@
<Tag Color="@PresetColor.Green.ToString()">知识库</Tag>
}
else if (context.Type == AppType.img.ToString())
{
<Tag Color="@PresetColor.Red.ToString()">做图应用</Tag>
}
</DescriptionTemplate>
</CardMeta>
</Card>

View File

@@ -46,7 +46,7 @@
</GridCol>
<GridCol Span="23">
<div class="chat-bubble received">
@((MarkupString)(item.HtmlAnswers))
@((MarkupString)(item.Context))
</div>
</GridCol>

View File

@@ -2,13 +2,17 @@
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model;
using AntSK.Domain.Domain.Model.Dto;
using AntSK.Domain.Domain.Model.Enum;
using AntSK.Domain.Repositories;
using AntSK.Domain.Utils;
using AntSK.LLM.StableDiffusion;
using Blazored.LocalStorage;
using Markdig;
using Microsoft.AspNetCore.Components;
using Microsoft.JSInterop;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Newtonsoft.Json;
namespace AntSK.Pages.ChatPage.Components
@@ -76,7 +80,9 @@ namespace AntSK.Pages.ChatPage.Components
{
MessageList.Clear();
await _localStorage.SetItemAsync<List<MessageInfo>>("msgs", MessageList);
await InvokeAsync(StateHasChanged);
_ = Message.Info("清理成功");
}
}
else
@@ -107,9 +113,7 @@ namespace AntSK.Pages.ChatPage.Components
Sendding = true;
await SendAsync(_messageInput,filePath);
_messageInput = "";
Sendding = false;
await _localStorage.SetItemAsync<List<MessageInfo>>("msgs", MessageList);
Sendding = false;
}
catch (System.Exception ex)
{
@@ -144,23 +148,48 @@ namespace AntSK.Pages.ChatPage.Components
{
history = await _chatService.GetChatHistory(MessageList);
}
switch (app.Type)
{
case "chat" when filePath == null||app.EmbeddingModelID.IsNull():
//普通会话
await SendChat(questions, history, app);
break;
default:
//知识库问答
await SendKms(questions, history, app, filePath);
break;
if (app.Type == AppType.chat.ToString() && (filePath == null || app.EmbeddingModelID.IsNull()))
{
await SendChat(questions, history, app);
}
else if (app.Type == AppType.kms.ToString() || filePath != null || app.EmbeddingModelID.IsNotNull())
{
await SendKms(questions, history, app, filePath);
}
else if (app.Type == AppType.img.ToString())
{
await SendImg(questions,app);
}
//缓存消息记录
if (app.Type != AppType.img.ToString())
{
await _localStorage.SetItemAsync<List<MessageInfo>>("msgs", MessageList);
}
return await Task.FromResult(true);
}
private async Task SendImg(string questions,Apps app)
{
MessageInfo info = new MessageInfo();
info.ID = Guid.NewGuid().ToString();
info.CreateTime = DateTime.Now;
var base64= await _chatService.SendImgByAppAsync(app, questions);
if (string.IsNullOrEmpty(base64))
{
info.Context = "生成失败";
}
else
{
info.Context = $"<img src=\"data:image/jpeg;base64,{base64}\" alt=\"Base64 Image\" />";
}
MessageList.Add(info);
}
/// <summary>
/// 发送知识库问答
/// </summary>
@@ -179,14 +208,13 @@ namespace AntSK.Pages.ChatPage.Components
info = new MessageInfo();
info.ID = Guid.NewGuid().ToString();
info.Context = content.ConvertToString();
info.HtmlAnswers = content.ConvertToString();
info.CreateTime = DateTime.Now;
MessageList.Add(info);
}
else
{
info.HtmlAnswers += content.ConvertToString();
info.Context += content.ConvertToString();
await Task.Delay(50);
}
await InvokeAsync(StateHasChanged);
@@ -214,14 +242,13 @@ namespace AntSK.Pages.ChatPage.Components
info = new MessageInfo();
info.ID = Guid.NewGuid().ToString();
info.Context = content.ConvertToString();
info.HtmlAnswers = content.ConvertToString();
info.CreateTime = DateTime.Now;
MessageList.Add(info);
}
else
{
info.HtmlAnswers += content.ConvertToString();
info.Context += content.ConvertToString();
await Task.Delay(50);
}
await InvokeAsync(StateHasChanged);
@@ -235,7 +262,7 @@ namespace AntSK.Pages.ChatPage.Components
if (info.IsNotNull())
{
// info!.HtmlAnswers = markdown.Transform(info.HtmlAnswers);
info!.HtmlAnswers = Markdown.ToHtml(info.HtmlAnswers);
info!.Context = Markdown.ToHtml(info.Context);
}
await InvokeAsync(StateHasChanged);

View File

@@ -17,7 +17,7 @@
<Extra>
<Button Type="@ButtonType.Primary" Style="position: absolute; right:360px; margin-bottom: 8px;" OnClick="Refresh">刷新 </Button>
<Dropdown Style="position: absolute; right: 20px; margin-bottom: 8px;">
<Dropdown Style="position: absolute; right: 20px; margin-bottom: 8px;" Trigger="@(new Trigger[] { Trigger.Click })">
<Overlay>
<Menu>
@(_fileUpload(() => FileShowModal()))

View File

@@ -24,18 +24,31 @@
</FormItem>
<FormItem Label="AI类型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<EnumRadioGroup @bind-Value="context.AIType"></EnumRadioGroup>
<EnumRadioGroup @bind-Value="context.AIType" ButtonStyle="RadioButtonStyle.Solid" OnChange="AITypeChange" TEnum="AIType"> </EnumRadioGroup>
</FormItem>
<FormItem Label="模型类型" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<RadioGroup @bind-Value="context.AIModelType">
<Radio RadioButton Value="@(AIModelType.Chat)">会话模型</Radio>
<Radio RadioButton Value="@(AIModelType.Embedding)">向量模型</Radio>
<RadioGroup @bind-Value="context.AIModelType">
@if (context.AIType == AIType.StableDiffusion)
{
<Radio RadioButton Value="@(AIModelType.Image)">图片模型</Radio>
}
else
{
@if (context.AIType != AIType.BgeEmbedding)
{
<Radio RadioButton Value="@(AIModelType.Chat)">会话模型</Radio>
}
@if (context.AIType != AIType.LLamaFactory && context.AIType != AIType.Mock && context.AIType != AIType.SparkDesk)
{
<Radio RadioButton Value="@(AIModelType.Embedding)">向量模型</Radio>
}
}
</RadioGroup>
</FormItem>
@if (context.AIModelType == AIModelType.Embedding)
{
<FormItem Label="注意事项" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<b>请不要使用不同维度的向量模型,否则会导致无法向量存储</b>
<span style="color:red"><b>请不要使用不同维度的向量模型,否则会导致无法向量存储</b></span>
</FormItem>
}
@@ -84,7 +97,7 @@
<Input Placeholder="请输入模型名称" @bind-Value="@context.ModelName" />
</FormItem>
}
@if (context.AIType == AIType.LLamaSharp)
@if (context.AIType == AIType.LLamaSharp || context.AIType == AIType.StableDiffusion)
{
<FormItem Label="模型路径" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<InputGroup>
@@ -113,6 +126,9 @@
<FormItem Label="请求地址" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Input Placeholder="http://localhost:8080/" @bind-Value="@context.EndPoint" />
</FormItem>
<FormItem Label="环境安装" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Button Type="primary" OnClick="PipInstall">环境安装</Button>
</FormItem>
<FormItem Label="llama factory服务" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<InputGroup>
@if (!llamaFactoryIsStart)
@@ -124,11 +140,31 @@
<Button OnClick="StopLFService" Disabled="@(!llamaFactoryIsStart)">停止</Button>
}
</InputGroup>
</FormItem>
}
@if (context.AIType == AIType.BgeEmbedding)
{
<FormItem Label="模型名称" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Select DataSource="@bgeEmbeddingList"
@bind-Value="@context.ModelName"
ValueProperty="c=>c"
LabelProperty="c=>c">
</Select>
</FormItem>
<FormItem Label="环境安装" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Button Type="primary" OnClick="PipInstall" >初始化</Button>
<FormItem Label="PythonDll路径" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Input Placeholder="D:\Programs\Python\Python311\python311.dll" @bind-Value="@context.EndPoint" />
</FormItem>
<FormItem Label="环境安装" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Button Type="primary" OnClick="PipInstall">环境安装</Button>
</FormItem>
<FormItem Label="下载并初始化" LabelCol="LayoutModel._formItemLayout.LabelCol" WrapperCol="LayoutModel._formItemLayout.WrapperCol">
<Spin Tip="请等待..." Spinning="@(BgeIsStart)" >
<Button Type="primary" Disabled="@(BgeIsStart)" OnClick="BgeDownload">@BgeBtnText</Button>
</Spin>
</FormItem>
}
@if (context.AIType == AIType.Mock)
{
}
@@ -168,3 +204,7 @@
</Modal>
@code
{
}

View File

@@ -3,6 +3,7 @@ using AntDesign.ProLayout;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Enum;
using AntSK.Domain.Domain.Other;
using AntSK.Domain.Domain.Service;
using AntSK.Domain.Options;
using AntSK.Domain.Repositories;
@@ -12,6 +13,7 @@ using BlazorComponents.Terminal;
using DocumentFormat.OpenXml.Office2010.Excel;
using Downloader;
using Microsoft.AspNetCore.Components;
using NRedisStack.Search;
using System.ComponentModel;
namespace AntSK.Pages.Setting.AIModel
@@ -58,6 +60,10 @@ namespace AntSK.Pages.Setting.AIModel
private TerminalParagraph para;
private bool _logModalVisible;
private List<string> bgeEmbeddingList = new List<string>() { "AI-ModelScope/bge-small-zh-v1.5", "AI-ModelScope/bge-base-zh-v1.5", "AI-ModelScope/bge-large-zh-v1.5" };
private bool BgeIsStart = false;
private string BgeBtnText = "初始化";
protected override async Task OnInitializedAsync()
{
try
@@ -75,12 +81,25 @@ namespace AntSK.Pages.Setting.AIModel
llamaFactoryIsStart = llamaFactoryDic.Value == "false" ? false : true;
}
//目前只支持gguf的 所以筛选一下
_modelFiles = Directory.GetFiles(Path.Combine(Directory.GetCurrentDirectory(), LLamaSharpOption.FileDirectory)).Where(p=>p.Contains(".gguf")).ToArray();
_modelFiles = Directory.GetFiles(Path.Combine(Directory.GetCurrentDirectory(), LLamaSharpOption.FileDirectory)).Where(p=> p.Contains(".gguf")||p.Contains(".ckpt")|| p.Contains(".safetensors")).ToArray();
if (!string.IsNullOrEmpty(ModelPath))
{
string extension = Path.GetExtension(ModelPath);
switch (extension)
{
case ".gguf":
_aiModel.AIType = AIType.LLamaSharp;
break;
case ".safetensors":
case ".ckpt":
_aiModel.AIType = AIType.StableDiffusion;
break;
}
//下载页跳入
_aiModel.AIType = AIType.LLamaSharp;
_downloadModalVisible = true;
_downloadUrl = $"https://hf-mirror.com{ModelPath.Replace("---","/")}";
@@ -214,7 +233,7 @@ namespace AntSK.Pages.Setting.AIModel
/// <summary>
/// 启动服务
/// </summary>
private void StartLFService()
private async Task StartLFService()
{
if (string.IsNullOrEmpty(_aiModel.ModelName))
{
@@ -248,6 +267,37 @@ namespace AntSK.Pages.Setting.AIModel
_ILLamaFactoryService.PipInstall();
}
}
private async Task BgeDownload()
{
if (string.IsNullOrEmpty(_aiModel.ModelName))
{
_ = Message.Error("请输入模型名称!", 2);
return;
}
if (string.IsNullOrEmpty(_aiModel.EndPoint))
{
_ = Message.Error("请输入正确的Python dll路径", 2);
return;
}
BgeIsStart = true;
BgeBtnText = "正在初始化...";
await Task.Run(() =>
{
try
{
EmbeddingConfig.LoadModel(_aiModel.EndPoint, _aiModel.ModelName);
BgeBtnText = "初始化完成";
BgeIsStart = false;
}
catch (System.Exception ex)
{
_ = Message.Error(ex.Message, 2);
BgeIsStart = false;
}
});
}
private async Task CmdLogHandler(string message)
{
await InvokeAsync(() =>
@@ -263,5 +313,28 @@ namespace AntSK.Pages.Setting.AIModel
private void OnCancelLog() {
_logModalVisible = false;
}
private void AITypeChange(AIType aiType)
{
switch (aiType)
{
case AIType.LLamaFactory:
_aiModel.EndPoint = "http://localhost:8080/";
_aiModel.AIModelType=AIModelType.Chat;
break;
case AIType.StableDiffusion:
_aiModel.AIModelType = AIModelType.Image;
break;
case AIType.Mock:
_aiModel.AIModelType = AIModelType.Chat;
break ;
case AIType.BgeEmbedding:
_aiModel.AIModelType = AIModelType.Embedding;
break;
default:
_aiModel.AIModelType = AIModelType.Chat;
break;
}
}
}
}

View File

@@ -10,12 +10,19 @@
<PageContainer Title="模型列表">
<Content>
<RadioGroup @bind-Value="@_modelType" OnChange="OnModelTypeChange" TValue="string">
<Radio Value="@("gguf")" DefaultChecked=true>LLama本地模型(gguf)</Radio>
<Radio Value="@("safetensors")">StableDiffusion(safetensors)</Radio>
<Radio Value="@("ckpt")">StableDiffusion2(ckpt)</Radio>
</RadioGroup>
<div style="text-align: center;">
<Search Placeholder="输入回车"
EnterButton="@("搜索")"
Size="large"
Style="max-width: 522px; width: 100%;"
OnSearch="Search" />
</div>
</Content>
<ChildContent>

View File

@@ -16,7 +16,7 @@ namespace AntSK.Pages.Setting.AIModel
private readonly IList<string> _selectCategories = new List<string>();
private List<HfModels> _modelList = new List<HfModels>();
private string _modelType;
protected override async Task OnInitializedAsync()
{
await base.OnInitializedAsync();
@@ -27,7 +27,7 @@ namespace AntSK.Pages.Setting.AIModel
{
var param = searchKey.ConvertToString().Split(" ");
string urlBase = "https://hf-mirror.com/models-json?sort=trending&search=gguf";
string urlBase = $"https://hf-mirror.com/models-json?sort=trending&search={_modelType}";
if (param.Count() > 0)
{
urlBase += "+" + string.Join("+", param);
@@ -48,5 +48,10 @@ namespace AntSK.Pages.Setting.AIModel
{
NavigationManager.NavigateTo($"/setting/modeldown/detail/{modelPath}");
}
private void OnModelTypeChange(string value)
{
InitData("");
}
}
}

View File

@@ -70,6 +70,14 @@
{
<Tag Color="@PresetColor.Cyan.ToString()">LLamaFactory</Tag>
}
else if (context.AIType == AIType.BgeEmbedding)
{
<Tag Color="@PresetColor.Gold.ToString()">BgeEmbedding</Tag>
}
else if (context.AIType == AIType.StableDiffusion)
{
<Tag Color="@PresetColor.Lime.ToString()">StableDiffusion</Tag>
}
</p>
</div>
@@ -85,6 +93,10 @@
{
<Tag Color="@PresetColor.Green.ToString()">向量模型</Tag>
}
else if (context.AIModelType == AIModelType.Image)
{
<Tag Color="@PresetColor.Lime.ToString()">图片模型</Tag>
}
</p>
</div>
<div class="listContentItem" style="width:20%">

View File

@@ -201,7 +201,7 @@ namespace AntSK.Services.OpenApi
string result = "";
var _kernel = _kernelService.GetKernelByApp(app);
var relevantSource = await _kMService.GetRelevantSourceList(app.KmsIdList, questions);
var relevantSource = await _kMService.GetRelevantSourceList(app, questions);
var dataMsg = new StringBuilder();
if (relevantSource.Any())
{
@@ -210,9 +210,12 @@ namespace AntSK.Services.OpenApi
dataMsg.AppendLine(item.ToString());
}
KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var chatResult = await _kernel.InvokeAsync(function: jsonFun,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["questions"] = questions });
//KernelFunction jsonFun = _kernel.Plugins.GetFunction("KMSPlugin", "Ask1");
var temperature = app.Temperature / 100;//存的是0~100需要缩小
OpenAIPromptExecutionSettings settings = new() { Temperature = temperature };
var func = _kernel.CreateFunctionFromPrompt(app.Prompt, settings);
var chatResult = await _kernel.InvokeAsync(function: func,
arguments: new KernelArguments() { ["doc"] = dataMsg, ["history"] = string.Join("\n", history.Select(x => x.Role + ": " + x.Content)), ["input"] = questions });
if (chatResult.IsNotNull())
{
string answers = chatResult.GetValue<string>();
@@ -236,15 +239,15 @@ namespace AntSK.Services.OpenApi
for (int i = 0; i < model.messages.Count() - 1; i++)
{
var item = model.messages[i];
if (item.role.ToLower() == "user")
if (item.role.ComparisonIgnoreCase("user"))
{
history.AddUserMessage(item.content);
}
else if (item.role.ToLower() == "assistant")
else if (item.role.ComparisonIgnoreCase("assistant"))
{
history.AddAssistantMessage(item.content);
}
else if (item.role.ToLower() == "system")
else if (item.role.ComparisonIgnoreCase("system"))
{
history.AddSystemMessage(item.content);
}

View File

@@ -3,7 +3,7 @@ Facts:
--------------------------
History:{{ConversationSummaryPlugin.SummarizeConversation $history}}
--------------------------
Question: {{$questions}}
Question: {{$input}}
--------------------------
Given only the facts above, provide a comprehensive/detailed answer.
You don't know where the knowledge comes from, just answer.

View File

@@ -13,5 +13,5 @@
历史聊天记录:{{ConversationSummaryPlugin.SummarizeConversation $history}}
--------------------------
用户问题: {{$questions}}
用户问题: {{$input}}

View File

@@ -57,5 +57,11 @@
"key": "setting.modeldown"
}
]
},
{
"path": "http://antsk.cn/",
"name": "文档",
"key": "antskdoc",
"icon": "question-circle"
}
]

View File

@@ -13,6 +13,13 @@
<PackageReference Include="Cnblogs.SemanticKernel.Connectors.DashScope" Version="0.3.2" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.6.3" />
<PackageReference Include="Sdcb.SparkDesk" Version="3.0.0" />
<PackageReference Include="System.Drawing.Common" Version="8.0.0" />
</ItemGroup>
<ItemGroup>
<None Update="stable-diffusion.dll">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,108 @@
using System;
using System.Runtime.InteropServices;
namespace AntSK.LLM.StableDiffusion
{
using static AntSK.LLM.StableDiffusion.Structs;
using int32_t = Int32;
using int64_t = Int64;
using SdContext = IntPtr;
using SDImagePtr = IntPtr;
using UpscalerContext = IntPtr;
internal class Native
{
const string DllName = "stable-diffusion";
internal delegate void SdLogCallback(SdLogLevel level, [MarshalAs(UnmanagedType.LPStr)] string text, IntPtr data);
internal delegate void SdProgressCallback(int step, int steps, float time, IntPtr data);
[DllImport(DllName, EntryPoint = "new_sd_ctx", CallingConvention = CallingConvention.Cdecl)]
internal extern static SdContext new_sd_ctx(string model_path,
string vae_path,
string taesd_path,
string control_net_path_c_str,
string lora_model_dir,
string embed_dir_c_str,
string stacked_id_embed_dir_c_str,
bool vae_decode_only,
bool vae_tiling,
bool free_params_immediately,
int n_threads,
WeightType weightType,
RngType rng_type,
ScheduleType s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
[DllImport(DllName, EntryPoint = "txt2img", CallingConvention = CallingConvention.Cdecl)]
internal static extern SDImagePtr txt2img(SdContext sd_ctx,
string prompt,
string negative_prompt,
int clip_skip,
float cfg_scale,
int width,
int height,
SampleMethod sample_method,
int sample_steps,
int64_t seed,
int batch_count,
SDImagePtr control_cond,
float control_strength,
float style_strength,
bool normalize_input,
string input_id_images_path);
[DllImport(DllName, EntryPoint = "img2img", CallingConvention = CallingConvention.Cdecl)]
internal static extern SDImagePtr img2img(SdContext sd_ctx,
SDImage init_image,
string prompt_c_str,
string negative_prompt_c_str,
int clip_skip,
float cfg_scale,
int width,
int height,
SampleMethod sample_method,
int sample_steps,
float strength,
int64_t seed,
int batch_count);
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern IntPtr preprocess_canny(IntPtr imgData,
int width,
int height,
float high_threshold,
float low_threshold,
float weak,
float strong,
bool inverse);
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern UpscalerContext new_upscaler_ctx(string esrgan_path,
int n_threads,
WeightType wtype);
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern int32_t get_num_physical_cores();
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern void free_sd_ctx(SdContext sd_ctx);
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern void free_upscaler_ctx(UpscalerContext upscaler_ctx);
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
internal static extern SDImage upscale(UpscalerContext upscaler_ctx, SDImage input_image, int upscale_factor);
[DllImport(DllName, EntryPoint = "sd_set_log_callback", CallingConvention = CallingConvention.Cdecl)]
internal static extern void sd_set_log_callback(SdLogCallback cb, IntPtr data);
[DllImport(DllName, EntryPoint = "sd_set_progress_callback", CallingConvention = CallingConvention.Cdecl)]
internal static extern void sd_set_progress_callback(SdProgressCallback cb, IntPtr data);
}
}

View File

@@ -0,0 +1,217 @@
using System;
using System.Drawing;
using System.Drawing.Imaging;
using System.Runtime.InteropServices;
namespace AntSK.LLM.StableDiffusion
{
using static AntSK.LLM.StableDiffusion.Structs;
using SdContext = IntPtr;
using SDImagePtr = IntPtr;
using UpscalerContext = IntPtr;
public static class SDHelper
{
public static bool IsInitialized => SdContext.Zero != sd_ctx;
public static bool IsUpscalerInitialized => UpscalerContext.Zero != upscaler_ctx;
private static SdContext sd_ctx = new SdContext();
private static UpscalerContext upscaler_ctx = new UpscalerContext();
public static event EventHandler<StableDiffusionEventArgs.StableDiffusionLogEventArgs> Log;
public static event EventHandler<StableDiffusionEventArgs.StableDiffusionProgressEventArgs> Progress;
static readonly Native.SdLogCallback sd_Log_Cb;
static readonly Native.SdProgressCallback sd_Progress_Cb;
static SDHelper()
{
sd_Log_Cb = new Native.SdLogCallback(OnNativeLog);
Native.sd_set_log_callback(sd_Log_Cb, IntPtr.Zero);
sd_Progress_Cb = new Native.SdProgressCallback(OnProgressRunning);
Native.sd_set_progress_callback(sd_Progress_Cb, IntPtr.Zero);
}
public static bool Initialize(ModelParams modelParams)
{
sd_ctx = Native.new_sd_ctx(modelParams.ModelPath,
modelParams.VaePath,
modelParams.TaesdPath,
modelParams.ControlnetPath,
modelParams.LoraModelDir,
modelParams.EmbeddingsPath,
modelParams.StackedIdEmbeddingsPath,
modelParams.VaeDecodeOnly,
modelParams.VaeTiling,
modelParams.FreeParamsImmediately,
modelParams.Threads,
modelParams.SdType,
modelParams.RngType,
modelParams.Schedule,
modelParams.KeepClipOnCpu,
modelParams.KeepControlNetOnCpu,
modelParams.KeepVaeOnCpu);
return SdContext.Zero != sd_ctx;
}
public static bool InitializeUpscaler(UpscalerParams @params)
{
upscaler_ctx = Native.new_upscaler_ctx(@params.ESRGANPath, @params.Threads, @params.SdType);
return UpscalerContext.Zero != upscaler_ctx;
}
public static void FreeSD()
{
if (SdContext.Zero != sd_ctx)
{
Native.free_sd_ctx(sd_ctx);
sd_ctx = SdContext.Zero;
}
}
public static void FreeUpscaler()
{
if (UpscalerContext.Zero != upscaler_ctx)
{
Native.free_upscaler_ctx(upscaler_ctx);
upscaler_ctx = UpscalerContext.Zero;
}
}
public static Bitmap[] TextToImage(TextToImageParams textToImageParams)
{
if (!IsInitialized) throw new ArgumentNullException("Model not loaded!");
SDImagePtr sd_Image_ptr = Native.txt2img(sd_ctx,
textToImageParams.Prompt,
textToImageParams.NegativePrompt,
textToImageParams.ClipSkip,
textToImageParams.CfgScale,
textToImageParams.Width,
textToImageParams.Height,
textToImageParams.SampleMethod,
textToImageParams.SampleSteps,
textToImageParams.Seed,
textToImageParams.BatchCount,
SDImagePtr.Zero,
textToImageParams.ControlStrength,
textToImageParams.StyleStrength,
textToImageParams.NormalizeInput,
textToImageParams.InputIdImagesPath);
Bitmap[] images = new Bitmap[textToImageParams.BatchCount];
for (int i = 0; i < textToImageParams.BatchCount; i++)
{
SDImage sd_image = Marshal.PtrToStructure<SDImage>(sd_Image_ptr + i * Marshal.SizeOf<SDImage>());
images[i] = GetBitmapFromSdImage(sd_image);
}
return images;
}
public static Bitmap ImageToImage(ImageToImageParams imageToImageParams)
{
if (!IsInitialized) throw new ArgumentNullException("Model not loaded!");
SDImage input_sd_image = GetSDImageFromBitmap(imageToImageParams.InputImage);
SDImagePtr sdImgPtr = Native.img2img(sd_ctx,
input_sd_image,
imageToImageParams.Prompt,
imageToImageParams.NegativePrompt,
imageToImageParams.ClipSkip,
imageToImageParams.CfgScale,
imageToImageParams.Width,
imageToImageParams.Height,
imageToImageParams.SampleMethod,
imageToImageParams.SampleSteps,
imageToImageParams.Strength,
imageToImageParams.Seed,
imageToImageParams.BatchCount);
SDImage sdImg = Marshal.PtrToStructure<SDImage>(sdImgPtr);
return GetBitmapFromSdImage(sdImg);
}
public static Bitmap UpscaleImage(Bitmap image, int upscaleFactor)
{
if (!IsUpscalerInitialized) throw new ArgumentNullException("Upscaler not loaded!");
SDImage inputSDImg = GetSDImageFromBitmap(image);
SDImage result = Native.upscale(upscaler_ctx, inputSDImg, upscaleFactor);
return GetBitmapFromSdImage(result);
}
private static Bitmap GetBitmapFromSdImage(SDImage sd_Image)
{
int width = (int)sd_Image.Width;
int height = (int)sd_Image.Height;
int channel = (int)sd_Image.Channel;
byte[] bytes = new byte[width * height * channel];
Marshal.Copy(sd_Image.Data, bytes, 0, bytes.Length);
Bitmap bmp = new Bitmap(width, height, PixelFormat.Format24bppRgb);
int stride = bmp.Width * channel;
byte[] des = new byte[bytes.Length];
for (int i = 0; i < height; i++)
{
for (int j = 0; j < width; j++)
{
des[stride * i + channel * j + 0] = bytes[stride * i + channel * j + 2];
des[stride * i + channel * j + 1] = bytes[stride * i + channel * j + 1];
des[stride * i + channel * j + 2] = bytes[stride * i + channel * j + 0];
}
}
BitmapData bitmapData = bmp.LockBits(new Rectangle(0, 0, width, height), ImageLockMode.WriteOnly, bmp.PixelFormat);
Marshal.Copy(des, 0, bitmapData.Scan0, bytes.Length);
bmp.UnlockBits(bitmapData);
return bmp;
}
private static SDImage GetSDImageFromBitmap(Bitmap bmp)
{
int width = bmp.Width;
int height = bmp.Height;
int channel = Bitmap.GetPixelFormatSize(bmp.PixelFormat) / 8;
int stride = width * channel;
byte[] bytes = new byte[width * height * channel];
BitmapData bitmapData = bmp.LockBits(new Rectangle(0, 0, width, height), ImageLockMode.ReadOnly, bmp.PixelFormat);
Marshal.Copy(bitmapData.Scan0, bytes, 0, bytes.Length);
bmp.UnlockBits(bitmapData);
byte[] sdImageBytes = new byte[bytes.Length];
for (int i = 0; i < height; i++)
{
for (int j = 0; j < width; j++)
{
sdImageBytes[stride * i + j * 3 + 0] = bytes[stride * i + j * 3 + 2];
sdImageBytes[stride * i + j * 3 + 1] = bytes[stride * i + j * 3 + 1];
sdImageBytes[stride * i + j * 3 + 2] = bytes[stride * i + j * 3 + 0];
}
}
SDImage sd_Image = new SDImage
{
Width = (uint)width,
Height = (uint)height,
Channel = 3,
Data = Marshal.UnsafeAddrOfPinnedArrayElement(sdImageBytes, 0),
};
return sd_Image;
}
private static void OnNativeLog(SdLogLevel level, string text, IntPtr data)
{
Log?.Invoke(null, new StableDiffusionEventArgs.StableDiffusionLogEventArgs { Level = level, Text = text });
}
private static void OnProgressRunning(int step, int steps, float time, IntPtr data)
{
Progress?.Invoke(null, new StableDiffusionEventArgs.StableDiffusionProgressEventArgs { Step = step, Steps = steps, Time = time });
}
}
}

View File

@@ -0,0 +1,33 @@
using System;
using static AntSK.LLM.StableDiffusion.Structs;
namespace AntSK.LLM.StableDiffusion
{
public class StableDiffusionEventArgs
{
public class StableDiffusionProgressEventArgs : EventArgs
{
#region Properties & Fields
public int Step { get; set; }
public int Steps { get; set; }
public float Time { get; set; }
public IntPtr Data { get; set; }
public double Progress => (double)Step / Steps;
public float IterationsPerSecond => 1.0f / Time;
#endregion
}
public class StableDiffusionLogEventArgs : EventArgs
{
#region Properties & Fields
public SdLogLevel Level { get; set; }
public string Text { get; set; }
#endregion
}
}
}

View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace AntSK.LLM.StableDiffusion
{
public static class StableDiffusionService
{
}
}

View File

@@ -0,0 +1,154 @@
using System;
using System.Drawing;
using System.Runtime.InteropServices;
namespace AntSK.LLM.StableDiffusion
{
using int64_t = Int64;
using uint32_t = UInt32;
public class Structs
{
public class ModelParams
{
public string ModelPath = string.Empty;
public string VaePath = string.Empty;
public string TaesdPath = string.Empty;
public string ControlnetPath = string.Empty;
public string LoraModelDir = string.Empty;
public string EmbeddingsPath = string.Empty;
public string StackedIdEmbeddingsPath = string.Empty;
public bool VaeDecodeOnly = false;
public bool VaeTiling = true;
public bool FreeParamsImmediately = false;
public int Threads = Native.get_num_physical_cores();
public WeightType SdType = WeightType.SD_TYPE_COUNT;
public RngType RngType = RngType.CUDA_RNG;
public ScheduleType Schedule = ScheduleType.DEFAULT;
public bool KeepClipOnCpu = false;
public bool KeepControlNetOnCpu = false;
public bool KeepVaeOnCpu = false;
}
public class TextToImageParams
{
public string Prompt = string.Empty;
public string NegativePrompt = string.Empty;
public int ClipSkip = 0;
public float CfgScale = 7;
public int Width = 512;
public int Height = 512;
public SampleMethod SampleMethod = SampleMethod.EULER_A;
public int SampleSteps = 20;
public int64_t Seed = -1;
public int BatchCount = 1;
public Bitmap ControlCond = new Bitmap(1, 1);
public float ControlStrength = 0.9f;
public float StyleStrength = 0.75f;
public bool NormalizeInput = false;
public string InputIdImagesPath = string.Empty;
}
public class ImageToImageParams
{
public Bitmap InputImage;
public string Prompt = string.Empty;
public string NegativePrompt = string.Empty;
public int ClipSkip = -1;
public float CfgScale = 7.0f;
public int Width = 512;
public int Height = 512;
public SampleMethod SampleMethod = SampleMethod.EULER_A;
public int SampleSteps = 20;
public float Strength = 0.75f;
public int64_t Seed = 42;
public int BatchCount = 1;
}
public class UpscalerParams
{
public string ESRGANPath = string.Empty;
public int Threads = Native.get_num_physical_cores();
public WeightType SdType = WeightType.SD_TYPE_COUNT;
}
[StructLayout(LayoutKind.Sequential)]
internal struct SDImage
{
public uint32_t Width;
public uint32_t Height;
public uint32_t Channel;
public IntPtr Data;
}
public enum WeightType
{
SD_TYPE_F32 = 0,
SD_TYPE_F16 = 1,
SD_TYPE_Q4_0 = 2,
SD_TYPE_Q4_1 = 3,
// SD_TYPE_Q4_2 = 4, support has been removed
// SD_TYPE_Q4_3 (5) support has been removed
SD_TYPE_Q5_0 = 6,
SD_TYPE_Q5_1 = 7,
SD_TYPE_Q8_0 = 8,
SD_TYPE_Q8_1 = 9,
// k-quantizations
SD_TYPE_Q2_K = 10,
SD_TYPE_Q3_K = 11,
SD_TYPE_Q4_K = 12,
SD_TYPE_Q5_K = 13,
SD_TYPE_Q6_K = 14,
SD_TYPE_Q8_K = 15,
SD_TYPE_IQ2_XXS = 16,
SD_TYPE_IQ2_XS = 17,
SD_TYPE_IQ3_XXS = 18,
SD_TYPE_IQ1_S = 19,
SD_TYPE_IQ4_NL = 20,
SD_TYPE_IQ3_S = 21,
SD_TYPE_IQ2_S = 22,
SD_TYPE_IQ4_XS = 23,
SD_TYPE_I8,
SD_TYPE_I16,
SD_TYPE_I32,
SD_TYPE_COUNT,
};
public enum RngType
{
STD_DEFAULT_RNG,
CUDA_RNG
};
public enum ScheduleType
{
DEFAULT,
DISCRETE,
KARRAS,
N_SCHEDULES
};
public enum SampleMethod
{
EULER_A,
EULER,
HEUN,
DPM2,
DPMPP2S_A,
DPMPP2M,
DPMPP2Mv2,
LCM,
N_SAMPLE_METHODS
};
public enum SdLogLevel
{
Debug,
Info,
Warn,
Error
}
}
}