diff --git a/src/AntSK.Domain/AntSK.Domain.xml b/src/AntSK.Domain/AntSK.Domain.xml
index 3467025..72e858e 100644
--- a/src/AntSK.Domain/AntSK.Domain.xml
+++ b/src/AntSK.Domain/AntSK.Domain.xml
@@ -99,6 +99,11 @@
总数
+
+
+ 模型写死
+
+
避免模型重复加载,本地缓存
diff --git a/src/AntSK.Domain/Common/Embedding/BuilderBgeExtensions.cs b/src/AntSK.Domain/Common/Embedding/BuilderBgeExtensions.cs
new file mode 100644
index 0000000..c190102
--- /dev/null
+++ b/src/AntSK.Domain/Common/Embedding/BuilderBgeExtensions.cs
@@ -0,0 +1,21 @@
+using LLamaSharp.KernelMemory;
+using Microsoft.KernelMemory.AI;
+using Microsoft.KernelMemory;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AntSK.Domain.Common.Embedding
+{
+ public static class BuilderBgeExtensions // Extension methods that plug a HuggingfaceTextEmbeddingGenerator into an IKernelMemoryBuilder.
+ {
+ public static IKernelMemoryBuilder WithBgeTextEmbeddingGeneration(this IKernelMemoryBuilder builder, HuggingfaceTextEmbeddingGenerator textEmbeddingGenerator) // Registers textEmbeddingGenerator on builder and returns the same builder so calls can be chained.
+ {
+ builder.AddSingleton((ITextEmbeddingGenerator)textEmbeddingGenerator); // Register the supplied instance, cast to the ITextEmbeddingGenerator interface.
+ builder.AddIngestionEmbeddingGenerator(textEmbeddingGenerator); // Register the same instance as an ingestion embedding generator as well.
+ return builder;
+ }
+ }
+}
diff --git a/src/AntSK.Domain/Common/Embedding/HuggingfaceTextEmbeddingGenerator.cs b/src/AntSK.Domain/Common/Embedding/HuggingfaceTextEmbeddingGenerator.cs
new file mode 100644
index 0000000..60456ad
--- /dev/null
+++ b/src/AntSK.Domain/Common/Embedding/HuggingfaceTextEmbeddingGenerator.cs
@@ -0,0 +1,56 @@
+using LLama.Common;
+using LLama;
+using LLamaSharp.KernelMemory;
+using Microsoft.KernelMemory.AI;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AntSK.Domain.Domain.Other;
+
+namespace AntSK.Domain.Common.Embedding
+{
+ public class HuggingfaceTextEmbeddingGenerator : ITextEmbeddingGenerator, ITextTokenizer, IDisposable // Embedding generator that delegates loading, embedding and disposal to the static EmbeddingConfig helper.
+ {
+ public int MaxTokens => 1024; // Fixed value; not derived from the loaded model.
+
+ public int MaxTokenTotal => 1024; // Fixed value, same as MaxTokens.
+
+
+ private readonly dynamic _embedder; // Assigned in the constructor; not read anywhere else in this class.
+
+ public HuggingfaceTextEmbeddingGenerator(string pyDllPath,string modelName) // pyDllPath and modelName are forwarded unchanged to EmbeddingConfig.LoadModel.
+ {
+ _embedder = EmbeddingConfig.LoadModel(pyDllPath, modelName); // NOTE(review): LoadModel returns null when loading throws; that case is not checked here.
+ }
+
+ public void Dispose() // Delegates to the static EmbeddingConfig.Dispose, so the state released is static, not per-instance.
+ {
+ EmbeddingConfig.Dispose();
+ }
+
+ //public async Task>> GenerateEmbeddingAsync(IList data, CancellationToken cancellationToken = default)
+ //{
+ // IList> results = new List>();
+
+ // foreach (var d in data)
+ // {
+ // var embeddings = await EmbeddingConfig.GetEmbedding(d);
+ // results.Add(new ReadOnlyMemory(embeddings));
+ // }
+ // return results;
+ //}
+
+ public async Task GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) // Embeds a single text; cancellationToken is accepted but not used in the body.
+ {
+ var embeddings = await EmbeddingConfig.GetEmbedding(text); // Uses the statically cached model in EmbeddingConfig rather than _embedder.
+ return new Microsoft.KernelMemory.Embedding(embeddings); // Wrap the returned values in a Kernel Memory Embedding.
+ }
+
+ public int CountTokens(string text) // NOTE(review): returns the constant 1024 regardless of text; no tokenization is performed.
+ {
+ return 1024;
+ }
+ }
+}
diff --git a/src/AntSK.Domain/Domain/Interface/IPyNetEmbeddingService.cs b/src/AntSK.Domain/Domain/Interface/IPyNetEmbeddingService.cs
deleted file mode 100644
index 3f0bfe3..0000000
--- a/src/AntSK.Domain/Domain/Interface/IPyNetEmbeddingService.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-
-namespace AntSK.Domain.Domain.Interface
-{
- internal interface IPyNetEmbeddingService
- {
- }
-}
diff --git a/src/AntSK.Domain/Domain/Model/Enum/AIModelType.cs b/src/AntSK.Domain/Domain/Model/Enum/AIModelType.cs
index 98d5d2c..c879b57 100644
--- a/src/AntSK.Domain/Domain/Model/Enum/AIModelType.cs
+++ b/src/AntSK.Domain/Domain/Model/Enum/AIModelType.cs
@@ -21,12 +21,14 @@ namespace AntSK.Domain.Domain.Model.Enum
[Display(Name = "灵积大模型")]
DashScope = 5,
-
+
[Display(Name = "LLamaFactory")]
LLamaFactory = 6,
-
+ [Display(Name = "Bge Embedding")]
+ BgeEmbedding = 7,
[Display(Name = "模拟输出")]
Mock = 100,
+
}
///
diff --git a/src/AntSK.Domain/Domain/Other/EmbeddingConfig.cs b/src/AntSK.Domain/Domain/Other/EmbeddingConfig.cs
new file mode 100644
index 0000000..64cbc83
--- /dev/null
+++ b/src/AntSK.Domain/Domain/Other/EmbeddingConfig.cs
@@ -0,0 +1,73 @@
+using Python.Runtime;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static Python.Runtime.Py;
+
+namespace AntSK.Domain.Domain.Other
+{
+ public static class EmbeddingConfig // Static holder for a Python-hosted embedding model: load once, embed queries, dispose.
+ {
+ public static dynamic model { get; set; } // Cached Python model object; assigned in LoadModel and read by GetEmbedding.
+
+ static object lockobj = new object(); // Guards the load-once check inside LoadModel.
+
+ private static GILState GIL { get; set; } // Handle returned by Py.GIL() in LoadModel; disposed only in Dispose().
+
+ ///
+ /// Loads the embedding model once and caches it (original note, translated: "model is hard-coded"; the model name now comes from the modelName parameter).
+ ///
+ public static dynamic LoadModel(string pythondllPath, string modelName) // Returns the cached model, or null when loading throws.
+ {
+ lock (lockobj)
+ {
+ if (model == null)
+ {
+ //Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
+ Runtime.PythonDLL = pythondllPath;
+ PythonEngine.Initialize();
+ GIL= Py.GIL();// Acquire the Python Global Interpreter Lock; the handle is kept in the static GIL property until Dispose(). NOTE(review): confirm calls from other threads behave as intended while it is held.
+ try
+ {
+ dynamic modelscope = Py.Import("modelscope");
+ //dynamic model_dir = modelscope.snapshot_download("AI-ModelScope/bge-large-zh-v1.5", revision: "master");
+ dynamic model_dir = modelscope.snapshot_download(modelName, revision: "master"); // Result is used below as the model location.
+ dynamic HuggingFaceBgeEmbeddingstemp = Py.Import("langchain.embeddings");
+ dynamic HuggingFaceBgeEmbeddings = HuggingFaceBgeEmbeddingstemp.HuggingFaceBgeEmbeddings;
+ string model_name = model_dir; // Not used below; the constructor call passes model_dir directly.
+ dynamic model_kwargs = new PyDict();
+ model_kwargs["device"] = new PyString("cpu"); // Device is fixed to "cpu".
+ dynamic hugginmodel = HuggingFaceBgeEmbeddings(
+ model_name: model_dir,
+ model_kwargs: model_kwargs
+ );
+ model = hugginmodel; // Cache so later calls take the else branch below.
+ return hugginmodel;
+ }
+ catch // NOTE(review): every exception is swallowed without logging; model stays null, so the next call runs Initialize() and Py.GIL() again.
+ {
+ return null;
+ }
+
+ }
+ else
+ return model;
+ }
+ }
+
+ public static Task GetEmbedding(string queryStr) // Runs synchronously and wraps the result in a completed Task. NOTE(review): model is used without a null check.
+ {
+ PyObject queryResult = model.embed_query(queryStr);
+ var floatList = queryResult.As(); // Convert the Python result to a .NET value.
+ return Task.FromResult(floatList); ;
+ }
+
+ public static void Dispose() // Releases the GIL handle taken in LoadModel. NOTE(review): GIL is only assigned inside LoadModel; confirm Dispose cannot run before it.
+ {
+ Console.WriteLine("python dispose");
+ GIL.Dispose();
+ }
+ }
+}
diff --git a/src/AntSK.Domain/Domain/Service/KMService.cs b/src/AntSK.Domain/Domain/Service/KMService.cs
index a788a9d..f0c375a 100644
--- a/src/AntSK.Domain/Domain/Service/KMService.cs
+++ b/src/AntSK.Domain/Domain/Service/KMService.cs
@@ -1,5 +1,6 @@
using AntDesign;
using AntSK.Domain.Common.DependencyInjection;
+using AntSK.Domain.Common.Embedding;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Dto;
@@ -147,6 +148,11 @@ namespace AntSK.Domain.Domain.Service
var embedder = new LLamaEmbedder(weights, parameters);
memory.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder));
break;
+ case Model.Enum.AIType.BgeEmbedding:
+ string pyDll = embedModel.EndPoint;
+ string bgeEmbeddingModelName = embedModel.ModelName;
+ memory.WithBgeTextEmbeddingGeneration(new HuggingfaceTextEmbeddingGenerator(pyDll,bgeEmbeddingModelName));
+ break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeDefaults(embedModel.ModelKey);
break;
@@ -183,6 +189,14 @@ namespace AntSK.Domain.Domain.Service
var executor = new StatelessExecutor(weights, parameters);
memory.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor));
break;
+ case Model.Enum.AIType.LLamaFactory:
+
+ memory.WithOpenAITextGeneration(new OpenAIConfig()
+ {
+ APIKey = "123",
+ TextModel = chatModel.ModelName
+ }, null, chatHttpClient);
+ break;
case Model.Enum.AIType.DashScope:
memory.WithDashScopeTextGeneration(new Cnblogs.KernelMemory.AI.DashScope.DashScopeConfig
{
diff --git a/src/AntSK.Domain/Domain/Service/PyNetEmbeddingService.cs b/src/AntSK.Domain/Domain/Service/PyNetEmbeddingService.cs
deleted file mode 100644
index 9d743ef..0000000
--- a/src/AntSK.Domain/Domain/Service/PyNetEmbeddingService.cs
+++ /dev/null
@@ -1,74 +0,0 @@
-using AntSK.Domain.Common.DependencyInjection;
-using AntSK.Domain.Domain.Interface;
-using AntSK.Domain.Domain.Model.Dto;
-using AntSK.Domain.Options;
-using AntSK.LLamaFactory.Model;
-using Microsoft.AspNetCore.Mvc.ModelBinding;
-using Newtonsoft.Json;
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Linq;
-using System.Text;
-using System.Text.Json;
-using System.Threading.Tasks;
-using Python.Runtime;
-using System.Diagnostics;
-using static System.Net.Mime.MediaTypeNames;
-using AntSK.Domain.Repositories;
-using DocumentFormat.OpenXml.EMMA;
-
-namespace AntSK.Domain.Domain.Service
-{
- [ServiceDescription(typeof(IPyNetEmbeddingService), ServiceLifetime.Singleton)]
- public class PyNetEmbeddingService : IPyNetEmbeddingService
- {
- private readonly IAIModels_Repositories _aIModels_Repositories;
- public PyNetEmbeddingService(IAIModels_Repositories aIModels_Repositories)
- {
- _aIModels_Repositories = aIModels_Repositories;
- }
-
- public static dynamic model { get; set; }
- ///
- /// 模型写死
- ///
- public dynamic LoadModel()
- {
- if (model == null)
- {
- Runtime.PythonDLL = @"D:\Programs\Python\Python311\python311.dll";
- PythonEngine.Initialize();
- Py.GIL();// 初始化Python环境的Global Interpreter Lock
- try
- {
- dynamic modelscope = Py.Import("modelscope");
- dynamic model_dir = modelscope.snapshot_download("AI-ModelScope/bge-large-zh-v1.5", revision: "master");
- dynamic HuggingFaceBgeEmbeddingstemp = Py.Import("langchain.embeddings");
- dynamic HuggingFaceBgeEmbeddings = HuggingFaceBgeEmbeddingstemp.HuggingFaceBgeEmbeddings;
- string model_name = model_dir;
- dynamic model_kwargs = new PyDict();
- model_kwargs["device"] = new PyString("cpu");
- dynamic model = HuggingFaceBgeEmbeddings(
- model_name: model_dir,
- model_kwargs: model_kwargs
- );
- return model;
- }
- catch
- {
- return null;
- }
-
- }
- else
- return model;
- }
-
- public string GetEmbedding(string queryStr)
- {
- var queryResult = model.embed_query(queryStr).ToString();
- return queryResult;
- }
- }
-}
diff --git a/src/AntSK.LLamaFactory/llamafactory/api_demo.py b/src/AntSK.LLamaFactory/llamafactory/api_demo.py
new file mode 100644
index 0000000..a714067
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/api_demo.py
@@ -0,0 +1,16 @@
+import os
+
+import uvicorn
+
+from llmtuner import ChatModel, create_app
+
+
+def main():  # Build the chat model, wrap it in an app and serve it with uvicorn.
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))  # port comes from API_PORT, defaulting to 8000
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)  # binds to all interfaces; the same API_PORT lookup is cast to int here
+
+
+if __name__ == "__main__":  # start the server only when this file is run as a script
+ main()
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/__init__.py
new file mode 100644
index 0000000..6c22bc1
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/__init__.py
@@ -0,0 +1,4 @@
+from .tuner import export_model, run_exp
+
+
+__all__ = ["export_model", "run_exp"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/__init__.py
new file mode 100644
index 0000000..43fe942
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/__init__.py
@@ -0,0 +1,4 @@
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/collator.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/collator.py
new file mode 100644
index 0000000..7e8ba1c
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/collator.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+
+import torch
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for pairwise data.
+ """
+
+ def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:  # build labels that keep only each row's answer span
+ padded_labels = []
+ for feature, (prompt_len, answer_len) in zip(batch, positions):  # one (prompt_len, answer_len) pair per row of batch
+ if self.tokenizer.padding_side == "left":
+ start, end = feature.size(0) - answer_len, feature.size(0)  # left padding: the answer is the last answer_len positions
+ else:
+ start, end = prompt_len, prompt_len + answer_len  # otherwise the answer directly follows the prompt
+ padded_tensor = self.label_pad_token_id * torch.ones_like(feature)  # start with every position set to the label pad id
+ padded_tensor[start:end] = feature[start:end]  # copy the answer tokens back in
+ padded_labels.append(padded_tensor)
+ return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ r"""
+ Pads batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ concatenated_features = []
+ label_positions = []
+ for key in ("chosen_ids", "rejected_ids"):  # outer loop over key gives the chosen-then-rejected ordering
+ for feature in features:
+ prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+ concatenated_features.append(
+ {
+ "input_ids": feature["prompt_ids"] + feature[key],  # prompt followed by the chosen or rejected answer
+ "attention_mask": [1] * (prompt_len + answer_len),
+ }
+ )
+ label_positions.append((prompt_len, answer_len))  # recorded before padding; consumed by _pad_labels
+
+ batch = self.tokenizer.pad(
+ concatenated_features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)  # labels are derived from the padded input ids
+ return batch
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/trainer.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/trainer.py
new file mode 100644
index 0000000..ed8bf4c
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/trainer.py
@@ -0,0 +1,149 @@
+from collections import defaultdict
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
+from transformers import BatchEncoding, Trainer
+from trl import DPOTrainer
+from trl.trainer.utils import disable_dropout_in_model
+
+from ...extras.constants import IGNORE_INDEX
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+
+class CustomDPOTrainer(DPOTrainer):  # DPOTrainer subclass that sets its state by hand and can add an SFT term to the loss
+ def __init__(
+ self,
+ beta: float,
+ loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
+ ftx_gamma: float,  # weight of the SFT term added in get_batch_loss_metrics
+ model: Union["PreTrainedModel", torch.nn.Module],
+ ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+ disable_dropout: bool = True,
+ **kwargs,  # forwarded to Trainer.__init__
+ ):
+ if disable_dropout:
+ disable_dropout_in_model(model)
+ if ref_model is not None:
+ disable_dropout_in_model(ref_model)
+
+ self.reference_free = False
+ self.use_dpo_data_collator = True # hack to avoid warning
+ self.generate_during_eval = False # disable at evaluation
+ self.label_pad_token_id = IGNORE_INDEX
+ self.padding_value = 0
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ self.precompute_ref_log_probs = False
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+ self._peft_has_been_casted_to_bf16 = False
+
+ self.ref_model = ref_model
+ self.beta = beta
+ self.label_smoothing = 0
+ self.loss_type = loss_type
+ self.ftx_gamma = ftx_gamma
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ Trainer.__init__(self, model=model, **kwargs)  # calls Trainer.__init__ directly, not DPOTrainer.__init__; the attributes above are set by hand
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Please update `transformers`.")
+
+ if ref_model is not None:
+ if self.is_deepspeed_enabled:
+ if not (
+ getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+ ): # quantized models are already set on the correct device
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+ r"""
+ Computes supervised cross-entropy loss of given labels under the given logits.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
+ """
+ all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
+ return -all_logps  # negated averaged log-probabilities
+
+ def concatenated_forward(
+ self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:  # (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
+ batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
+
+ all_logits = model(
+ input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
+ ).logits.to(torch.float32)  # logits are cast to float32 before computing log-probs
+
+ all_logps = self.get_batch_logps(
+ all_logits,
+ batch["labels"],  # labels are read from the original batch, not the copy
+ average_log_prob=False,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+ batch_size = batch["input_ids"].size(0) // 2  # the batch is split into two equal halves below
+ chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)  # first half is treated as chosen, second as rejected
+ chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+ return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+ def get_batch_loss_metrics(
+ self,
+ model: "PreTrainedModel",
+ batch: Dict[str, torch.Tensor],
+ train_eval: Literal["train", "eval"] = "train",
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ r"""
+ Computes the DPO loss and other metrics for the given batch of inputs for train or test.
+ """
+ metrics = {}
+ (
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_chosen_logits,
+ policy_rejected_logits,
+ ) = self.concatenated_forward(model, batch)
+ with torch.no_grad():  # the reference forward pass is run without gradients
+ if self.ref_model is None:
+ ref_model = self.model  # no separate reference model: reuse the policy model
+ ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()  # with its adapter disabled for the reference pass
+ else:
+ ref_model = self.ref_model
+ ref_context = nullcontext()
+
+ with ref_context:
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ ) = self.concatenated_forward(ref_model, batch)
+
+ losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ )
+ if self.ftx_gamma > 1e-6:  # add the SFT term only when ftx_gamma is above this small threshold
+ batch_size = batch["input_ids"].size(0) // 2
+ chosen_labels, _ = batch["labels"].split(batch_size, dim=0)  # keep only the first (chosen) half of the labels
+ losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()  # 1.0 where the chosen reward is larger
+
+ prefix = "eval_" if train_eval == "eval" else ""  # metric keys get an "eval_" prefix in eval mode
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
+ metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
+ metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
+ metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+
+ return losses.mean(), metrics  # scalar mean loss plus the metrics dict
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/workflow.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/workflow.py
new file mode 100644
index 0000000..39ea1a0
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/dpo/workflow.py
@@ -0,0 +1,83 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model, load_tokenizer
+from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
+from .collator import DPODataCollatorWithPadding
+from .trainer import CustomDPOTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments
+
+
+def run_dpo(  # DPO workflow: load tokenizer/data/model, build the trainer, then train and/or evaluate
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer = load_tokenizer(model_args)
+ dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")  # the dataset is requested with stage="rm", not a DPO-specific stage
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+ data_collator = DPODataCollatorWithPadding(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=8,
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+ )
+
+ # Create reference model
+ if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
+ ref_model = model
+ else:
+ ref_model = create_ref_model(model_args, finetuning_args)
+
+ # Update arguments
+ training_args.remove_unused_columns = False # important for pairwise dataset
+
+ # Initialize our Trainer
+ optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+ trainer = CustomDPOTrainer(
+ beta=finetuning_args.dpo_beta,
+ loss_type=finetuning_args.dpo_loss,
+ ftx_gamma=finetuning_args.dpo_ftx,
+ model=model,
+ ref_model=ref_model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ optimizers=(optimizer, None),  # custom optimizer; no scheduler is passed
+ **split_dataset(dataset, data_args, training_args),  # result is expanded into trainer keyword arguments
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:  # plot only on the world-zero process
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ if id(model) == id(ref_model): # unable to compute rewards without a reference model
+ remove_keys = [key for key in metrics.keys() if "rewards" in key]  # collect keys first so the dict is not mutated while iterating
+ for key in remove_keys:
+ metrics.pop(key)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/__init__.py
new file mode 100644
index 0000000..d17336d
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/__init__.py
@@ -0,0 +1,4 @@
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/trainer.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/trainer.py
new file mode 100644
index 0000000..a06d7ef
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/trainer.py
@@ -0,0 +1,376 @@
+import math
+import os
+import sys
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+import torch
+from tqdm import tqdm
+from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from trl import PPOTrainer
+from trl.core import PPODecorators, logprobs_from_logits
+
+from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from trl import AutoModelForCausalLMWithValueHead
+
+ from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomPPOTrainer(PPOTrainer, Trainer):
+ r"""
+ Inherits PPOTrainer.
+ """
+
+ def __init__(  # initialise PPOTrainer, then attach the extra state used by ppo_train()
+ self,
+ model_args: "ModelArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: List["TrainerCallback"],  # indexed below: [0] is the log callback, [1] the save callback
+ reward_model: "AutoModelForCausalLMWithValueHead",
+ **kwargs,  # forwarded unchanged to PPOTrainer.__init__
+ ):
+ PPOTrainer.__init__(self, **kwargs)
+
+ self.args = training_args
+ self.model_args = model_args
+ self.finetuning_args = finetuning_args
+ self.reward_model = reward_model
+ self.current_device = get_current_device() # patch for deepspeed training
+
+ self.generation_config = GenerationConfig(
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,  # EOS plus every additional special token id
+ **generating_args.to_dict(),
+ )
+
+ self.state = TrainerState()
+ self.control = TrainerControl()
+ self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+ self.accelerator.state, "deepspeed_plugin"
+ )
+ self.log_callback, self.save_callback = callbacks[0], callbacks[1]
+ assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)  # callback order and types are required
+
+ if self.args.max_steps > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ if finetuning_args.reward_model_type == "full":  # only a "full" reward model is prepared for the accelerator here
+ if self.is_deepspeed_enabled:
+ if not (
+ getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
+ or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
+ ): # quantized models are already set on the correct device
+ self.reward_model = self._prepare_deepspeed(self.reward_model)
+ else:
+ self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+
+ def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
+ r"""
+ Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
+ """
+ if resume_from_checkpoint is not None:  # resuming is rejected outright
+ raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
+
+ total_train_batch_size = (
+ self.args.per_device_train_batch_size
+ * self.args.gradient_accumulation_steps
+ * self.finetuning_args.ppo_buffer_size
+ * self.args.world_size
+ )
+ if self.args.max_steps > 0:  # step-driven run: epochs are effectively unbounded
+ num_examples = total_train_batch_size * self.args.max_steps
+ num_train_epochs = sys.maxsize
+ max_steps = self.args.max_steps
+ steps_in_epoch = self.args.max_steps
+ else:  # epoch-driven run: steps are derived from the dataloader length
+ len_dataloader = len(self.dataloader)
+ num_examples = len(self.dataset)
+ num_train_epochs = self.args.num_train_epochs
+ max_steps = math.ceil(num_train_epochs * len_dataloader)
+ steps_in_epoch = len_dataloader
+
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ if self.is_world_process_zero():  # log the run summary only on the world-zero process
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = {}".format(num_examples))
+ logger.info(" Num Epochs = {}".format(num_train_epochs))
+ logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
+ logger.info(
+ " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+ total_train_batch_size
+ )
+ )
+ logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
+ logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
+ logger.info(" Total training steps = {}".format(max_steps))
+ logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
+
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ dataiter = iter(self.dataloader)
+ loss_meter = AverageMeter()
+ reward_meter = AverageMeter()
+ self.log_callback.on_train_begin(self.args, self.state, self.control)
+
+ for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):  # progress bar is shown only on the local-zero process
+ try:
+ batch = next(dataiter)
+ except StopIteration:  # dataloader exhausted: restart it and continue
+ dataiter = iter(self.dataloader)
+ batch = next(dataiter)
+
+ # Cast to inference mode
+ unwrapped_model.gradient_checkpointing_disable()
+ unwrapped_model.config.use_cache = True
+ self.model.eval()
+
+ # Get inputs
+ self.tokenizer.padding_side = "right" # change padding side
+ queries, responses, rewards = [], [], []
+ for idx in range(0, self.config.batch_size, self.config.mini_batch_size):  # generate and score one mini-batch at a time
+ mini_batch_queries, mini_batch_responses = self.get_inputs(
+ batch[idx : idx + self.config.mini_batch_size]
+ )
+ mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+ queries.extend(mini_batch_queries)
+ responses.extend(mini_batch_responses)
+ rewards.extend(mini_batch_rewards)
+
+ # Cast to training mode
+ unwrapped_model.gradient_checkpointing_enable()
+ unwrapped_model.config.use_cache = False
+ self.model.train()
+
+ # Run PPO step
+ stats = self.step(queries, responses, rewards)
+ self.tokenizer.padding_side = "left" # restore padding side
+ loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
+ reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
+
+ if self.config.log_with is not None:
+ try:
+ batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
+ batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
+ self.log_stats(stats, batch, rewards)
+ except Exception:  # best-effort stats logging: failures only produce a warning
+ logger.warning("Failed to save stats due to unknown errors.")
+
+ self.state.global_step += 1
+ self.log_callback.on_step_end(self.args, self.state, self.control)
+
+ if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
+ logs = dict(
+ loss=round(loss_meter.avg, 4),
+ reward=round(reward_meter.avg, 4),
+ learning_rate=stats["ppo/learning_rate"],
+ epoch=round(step / steps_in_epoch, 2),
+ )
+ tqdm.write(str(logs))
+ logs["step"] = step  # added after printing, so "step" is in the log history but not in the printed line
+ self.state.log_history.append(logs)
+ self.log_callback.on_log(self.args, self.state, self.control)
+ loss_meter.reset()  # meters cover one logging window each
+ reward_meter.reset()
+
+ if (step + 1) % self.args.save_steps == 0: # save checkpoint
+ self.save_model(
+ os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+ )
+ self.save_callback.on_save(
+ self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+ )
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:  # either stop flag ends the loop
+ break
+
+ self.log_callback.on_train_end(self.args, self.state, self.control)
+ self.save_callback.on_train_end(
+ self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+ )
+
+ @torch.no_grad()
+ def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ r"""
+ Generates model's responses given queries.
+ """
+ if self.model_args.upcast_layernorm:
+ layernorm_params = dump_layernorm(self.model)  # saved here and restored after generation under the same flag
+
+ if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+ start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()  # index of the first non-pad token
+ for k, v in batch.items():
+ batch[k] = v[:, start_index:]  # drop the leading pad columns from every tensor in the batch
+
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ generate_output: torch.Tensor = unwrapped_model.generate(
+ generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+ )
+
+ if self.model_args.upcast_layernorm:
+ restore_layernorm(self.model, layernorm_params)
+
+ query = batch["input_ids"].detach().cpu()
+ response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()  # keep only the columns after the prompt length
+ queries, responses = [], []
+ for i in range(len(query)):
+ query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()  # first non-pad position in the query
+ response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()  # positions of non-pad tokens in the response
+
+ if len(response_index) == 0:
+ response_length = 1 # allow empty response
+ else:
+ response_length = response_index[-1].item() + 1  # length up to and including the last non-pad token
+
+ queries.append(query[i, query_start_index:]) # remove padding from left
+ responses.append(response[i, :response_length]) # remove padding from right
+
+ return queries, responses
+
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead",
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes one reward per (query, response) pair using the configured reward model.
+
+        Non-"api" rewards are returned on CPU in fp32; the "api" branch returns the server's scores.
+        """
+        if self.finetuning_args.reward_model_type == "api":  # self.reward_model is passed as the server URL
+            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            return get_rewards_from_server(self.reward_model, messages)
+
+        if self.finetuning_args.reward_model_type == "lora":
+            replace_model(unwrapped_model, target="reward")
+            reward_model = self.model
+        else:
+            reward_model = self.reward_model
+
+        batch = self.prepare_model_inputs(queries, responses)
+
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
+            values = torch.transpose(values, 0, 1)
+        # Each reward is the value at the last non-pad token of the sequence (index 0 if there is none).
+        rewards = []
+        for i in range(values.size(0)):
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
+            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
+        # Switch a shared LoRA model back to its default adapter and value head.
+        if self.finetuning_args.reward_model_type == "lora":
+            replace_model(unwrapped_model, target="default")
+
+        return rewards
+
+    @PPODecorators.empty_device_cache()
+    def batched_forward_pass(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        model_inputs: dict,
+        return_logits: bool = False,
+        response_masks: Optional[torch.Tensor] = None,
+    ):
+        r"""
+        Runs `model` over `model_inputs` in chunks of `config.mini_batch_size`.
+
+        Returns (logprobs, logits or None, values, masks); logits, values and masks drop their last position.
+        """
+        bs = len(queries)
+        fbs = self.config.mini_batch_size
+        all_logprobs = []
+        all_logits = []
+        all_masks = []
+        all_values = []
+
+        for i in range(math.ceil(bs / fbs)):
+            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            query_batch = queries[i * fbs : (i + 1) * fbs]
+            response_batch = responses[i * fbs : (i + 1) * fbs]
+            if response_masks is not None:
+                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
+            input_ids = input_kwargs["input_ids"]
+            attention_mask = input_kwargs["attention_mask"]
+
+            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
+                logits, _, values = model(**input_kwargs)
+
+            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+                values = torch.transpose(values, 0, 1)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+            masks = torch.zeros_like(attention_mask)
+            masks[:, :-1] = attention_mask[:, 1:]
+            # Keep only the response span [start, end) of every sample in the mask.
+            for j in range(len(query_batch)):
+                start = len(query_batch[j]) - 1
+                if attention_mask[j, 0] == 0:  # offset left padding
+                    start += attention_mask[j, :].nonzero()[0].item()
+                end = start + len(response_batch[j])
+                # Use a per-sample local: rebinding response_masks_batch would lose the other samples' masks.
+                if response_masks is not None:
+                    response_mask = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
+
+                masks[j, :start] = 0
+                masks[j, end:] = 0
+                if response_masks is not None:
+                    masks[j, start:end] = masks[j, start:end] * response_mask[start:end]
+
+            if return_logits:
+                all_logits.append(logits)
+            else:
+                del logits
+
+            all_values.append(values)
+            all_logprobs.append(logprobs)
+            all_masks.append(masks)
+
+        return (
+            torch.cat(all_logprobs),
+            torch.cat(all_logits)[:, :-1] if return_logits else None,
+            torch.cat(all_values)[:, :-1],
+            torch.cat(all_masks)[:, :-1],
+        )
+
+    def save_model(self, output_dir: Optional[str] = None) -> None:
+        r"""
+        Saves a model checkpoint to `output_dir` when `self.args.should_save` is true.
+
+        If gathering the state dict raises ValueError, falls back to `self.model.save_checkpoint(output_dir)`.
+        """
+        if self.args.should_save:
+            try:
+                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
+                )
+                self._save(output_dir, state_dict={})  # written with an empty state dict first
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/utils.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/utils.py
new file mode 100644
index 0000000..e6bdb89
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/utils.py
@@ -0,0 +1,59 @@
+import json
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ...extras.packages import is_requests_available
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+ from trl import AutoModelForCausalLMWithValueHead
+
+if is_requests_available():
+ import requests
+
+
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+    r"""POSTs `messages` as JSON to `server_url` and returns the "scores" field of the reply as a tensor."""
+    payload = {"model": "model", "messages": messages}
+    response = requests.post(server_url, json=payload, headers={"Content-Type": "application/json"})
+    response.raise_for_status()  # surface HTTP errors instead of failing later on a missing "scores" key
+    return torch.Tensor(json.loads(response.text)["scores"])
+
+
+def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
+    r"""Activates the `target` LoRA adapter and loads the value-head weights stored under `<target>_head_*`."""
+    if is_deepspeed_zero3_enabled():
+        import deepspeed  # type: ignore
+        params = [model.v_head.summary.weight, model.v_head.summary.bias]
+        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+    else:
+        context_maybe_zero3 = nullcontext()
+
+    with context_maybe_zero3:
+        if target == "reward":  # save default head temporarily
+            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+        # The head for `target` is read back via get_buffer("<target>_head_weight" / "<target>_head_bias").
+        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
+
+
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+    r"""Casts every float32 parameter to `model.config.torch_dtype`; returns clones of the originals by name."""
+    saved_fp32 = {}
+    for param_name, weight in model.named_parameters():
+        if weight.data.dtype == torch.float32:
+            saved_fp32[param_name] = weight.data.detach().clone()
+            weight.data = weight.data.to(model.config.torch_dtype)
+    return saved_fp32
+
+
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+    for name, param in model.named_parameters():  # undo dump_layernorm: put the saved tensors back by name
+        if layernorm_params and name in layernorm_params:  # None (the default) or {} restores nothing
+            param.data = layernorm_params[name]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/workflow.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/workflow.py
new file mode 100644
index 0000000..de9f2a2
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/ppo/workflow.py
@@ -0,0 +1,110 @@
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+
+import math
+from typing import TYPE_CHECKING, List, Optional
+
+from torch.optim import AdamW
+from transformers import DataCollatorWithPadding
+from transformers.optimization import get_scheduler
+from trl import PPOConfig
+
+from ...data import get_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from .trainer import CustomPPOTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+def run_ppo(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    r"""Runs the PPO stage: builds the models, PPO config, optimizer and trainer, then trains if `do_train`."""
+    tokenizer = load_tokenizer(model_args)
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
+    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+    # Create reference model and reward model
+    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
+    reward_model = create_reward_model(model, model_args, finetuning_args)
+
+    # Create ppo config
+    backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+    ppo_config = PPOConfig(
+        model_name=model_args.model_name_or_path,
+        learning_rate=training_args.learning_rate,
+        mini_batch_size=training_args.per_device_train_batch_size,
+        batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
+        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+        ppo_epochs=finetuning_args.ppo_epochs,
+        max_grad_norm=training_args.max_grad_norm,
+        seed=training_args.seed,
+        optimize_device_cache=True,
+        target=finetuning_args.ppo_target,
+        log_with=finetuning_args.ppo_logger,
+        use_score_scaling=finetuning_args.ppo_score_norm,
+        use_score_norm=finetuning_args.ppo_score_norm,
+        whiten_rewards=finetuning_args.ppo_whiten_rewards,
+        accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        project_kwargs={"logging_dir": training_args.logging_dir},
+    )
+
+    # Create optimizer and scheduler
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+
+    if training_args.max_steps > 0:
+        num_training_steps = training_args.max_steps
+    else:
+        total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
+        num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+    lr_scheduler = get_scheduler(
+        training_args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps,
+    )
+
+    # Initialize our Trainer
+    ppo_trainer = CustomPPOTrainer(
+        model_args=model_args,
+        training_args=training_args,
+        finetuning_args=finetuning_args,
+        generating_args=generating_args,
+        callbacks=(callbacks or []) + [FixValueHeadModelCallback()],  # `callbacks` defaults to None
+        reward_model=reward_model,
+        config=ppo_config,
+        model=model,
+        ref_model=ref_model,
+        tokenizer=tokenizer,
+        dataset=dataset,
+        data_collator=data_collator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+    )
+
+    # Training
+    if training_args.do_train:
+        ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        ppo_trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        ppo_trainer.save_state()  # must be called after save_model to have a folder
+        if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/__init__.py
new file mode 100644
index 0000000..bdf397f
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/__init__.py
@@ -0,0 +1,4 @@
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/workflow.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/workflow.py
new file mode 100644
index 0000000..5a08854
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/pt/workflow.py
@@ -0,0 +1,67 @@
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
+
+import math
+from typing import TYPE_CHECKING, List, Optional
+
+from transformers import DataCollatorForLanguageModeling, Trainer
+
+from ...data import get_dataset, split_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..utils import create_custom_optimzer, create_modelcard_and_push
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, ModelArguments
+
+
+def run_pt(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    r"""Runs the pre-training stage: trains if `do_train`, evaluates with perplexity if `do_eval`."""
+    tokenizer = load_tokenizer(model_args)
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+    # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        optimizers=(optimizer, None),
+        **split_dataset(dataset, data_args, training_args),
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")  # exp() overflowed: report infinite perplexity
+
+        metrics["perplexity"] = perplexity
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/__init__.py
new file mode 100644
index 0000000..dedac35
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/__init__.py
@@ -0,0 +1,4 @@
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/collator.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/collator.py
new file mode 100644
index 0000000..8d5d4ad
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/collator.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Sequence
+
+import torch
+from transformers import DataCollatorWithPadding
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
+    r"""
+    Collator that turns pairwise (chosen / rejected) features into one padded batch.
+    """
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Builds 2 * n examples from n features and pads them with the parent collator.
+
+        The first n examples are prompt + chosen answer, the last n examples are
+        prompt + rejected answer, in the original feature order.
+        """
+        paired = []
+        for key in ("chosen_ids", "rejected_ids"):  # all chosen examples first, then all rejected ones
+            for feature in features:
+                token_ids = feature["prompt_ids"] + feature[key]
+                mask_length = len(feature["prompt_ids"]) + len(feature[key])
+                paired.append({"input_ids": token_ids, "attention_mask": [1] * mask_length})
+
+        # Padding itself is delegated to DataCollatorWithPadding.
+        return super().__call__(paired)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/metric.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/metric.py
new file mode 100644
index 0000000..99dc6ab
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/metric.py
@@ -0,0 +1,8 @@
+from typing import Dict, Sequence, Tuple, Union
+
+import numpy as np
+
+
+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    score_pair, _ = eval_preds  # the second element is unused; accuracy = share of rows where [0] > [1]
+    return {"accuracy": np.greater(score_pair[0], score_pair[1]).sum() / len(score_pair[0])}
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/trainer.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/trainer.py
new file mode 100644
index 0000000..f7e104c
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/trainer.py
@@ -0,0 +1,99 @@
+import json
+import os
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+import torch
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.trainer import PredictionOutput
+
+
+logger = get_logger(__name__)
+
+
+class PairwiseTrainer(Trainer):
+    r"""
+    `transformers.Trainer` subclass that computes a pairwise (chosen vs. rejected) loss.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.can_return_loss = True  # override property to return eval_loss
+
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        r"""
+        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
+
+        Returns the loss, or `(loss, [loss, chosen_scores, rejected_scores])` when `return_outputs` is set.
+
+        Note that the first element will be removed from the output tuple.
+        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
+        """
+        # Compute rewards
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+
+        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+            values = torch.transpose(values, 0, 1)
+
+        # Split the inputs and rewards into two parts, chosen and rejected
+        batch_size = inputs["input_ids"].size(0) // 2
+        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
+        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
+        chosen_scores, rejected_scores = [], []
+
+        # Compute pairwise loss. Only backprop on the different tokens before padding
+        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
+        loss = 0
+        for i in range(batch_size):
+            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
+            # Identical pair: compare only the last chosen token; otherwise start at the first differing token.
+            if len(check_divergence) == 0:
+                end_index = chosen_length
+                div_index = end_index - 1
+            else:
+                end_index = max(chosen_length, rejected_length)
+                div_index = check_divergence[0]
+
+            assert div_index > 0
+            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
+            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
+            if return_outputs:  # use the score on the last token except pad token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
+                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
+            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
+
+        loss = loss / batch_size
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+
+        return loss
+
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+
+        Writes `generated_predictions.jsonl`; does nothing unless this is the world process zero.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+        chosen_scores, rejected_scores = predict_results.predictions
+        # One JSON object per pair, with both scores rounded to two decimals.
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for c_score, r_score in zip(chosen_scores, rejected_scores):
+                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+            writer.write("\n".join(res))
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/workflow.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/workflow.py
new file mode 100644
index 0000000..9dfef30
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/rm/workflow.py
@@ -0,0 +1,76 @@
+# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import get_dataset, split_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..utils import create_custom_optimzer, create_modelcard_and_push
+from .collator import PairwiseDataCollatorWithPadding
+from .metric import compute_accuracy
+from .trainer import PairwiseTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, ModelArguments
+
+
+def run_rm(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    r"""Runs the reward-model stage: trains, evaluates and/or predicts according to `training_args`."""
+    tokenizer = load_tokenizer(model_args)
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    # Update arguments
+    training_args.remove_unused_columns = False  # important for pairwise dataset
+
+    # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+    trainer = PairwiseTrainer(
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=(callbacks or []) + [FixValueHeadModelCallback()],  # `callbacks` defaults to None
+        optimizers=(optimizer, None),
+        compute_metrics=compute_accuracy,
+        **split_dataset(dataset, data_args, training_args),
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    # Predict
+    if training_args.do_predict:
+        predict_results = trainer.predict(dataset, metric_key_prefix="predict")
+        trainer.log_metrics("predict", predict_results.metrics)
+        trainer.save_metrics("predict", predict_results.metrics)
+        trainer.save_predictions(predict_results)
+
+    # Create model card
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/__init__.py
new file mode 100644
index 0000000..f2f84e7
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/__init__.py
@@ -0,0 +1,4 @@
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/metric.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/metric.py
new file mode 100644
index 0000000..d1af4c1
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/metric.py
@@ -0,0 +1,61 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+
+import numpy as np
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
+
+
+if TYPE_CHECKING:
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+if is_jieba_available():
+ import jieba # type: ignore
+
+if is_nltk_available():
+ from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+
+if is_rouge_available():
+ from rouge_chinese import Rouge
+
+
+@dataclass
+class ComputeMetrics:
+    r"""
+    Callable that holds a tokenizer and scores decoded predictions against labels (ROUGE and BLEU).
+    """
+
+    tokenizer: "PreTrainedTokenizer"
+
+    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+        r"""
+        Decodes `(preds, labels)` and returns the mean "rouge-1", "rouge-2", "rouge-l" and "bleu-4" scores.
+        """
+        preds, labels = eval_preds
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        # Replace IGNORE_INDEX with the pad token id so that both arrays can be decoded.
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
+            # ROUGE is 0 when either side has no tokens left after whitespace splitting.
+            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
+                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
+            else:
+                rouge = Rouge()
+                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
+                result = scores[0]
+
+            for k, v in result.items():
+                score_dict[k].append(round(v["f"] * 100, 4))
+            # BLEU is computed on the characters of the decoded strings, not on the jieba tokens.
+            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+        return {k: float(np.mean(v)) for k, v in score_dict.items()}
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/trainer.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/trainer.py
new file mode 100644
index 0000000..36d09f3
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/trainer.py
@@ -0,0 +1,100 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import Seq2SeqTrainer
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers.trainer import PredictionOutput
+
+
+logger = get_logger(__name__)
+
+
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
+    r"""
+    `Seq2SeqTrainer` subclass that post-processes generated tokens and can save decoded predictions.
+    """
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        r"""
+        Removes the prompt part in the generated tokens.
+
+        Returns (loss, generated_tokens, labels), where labels is the untouched copy taken on entry.
+        """
+        labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # backup labels
+        if self.args.predict_with_generate:
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+            if prompt_len > label_len:
+                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
+                inputs["labels"] = inputs["labels"][:, :prompt_len]
+
+        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        )
+        if generated_tokens is not None and self.args.predict_with_generate:
+            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id  # overwrite the prompt positions
+            generated_tokens = generated_tokens.contiguous()
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Returns a pad-filled tensor shaped like `tgt_tensor` with `src_tensor` copied into its last columns.
+        """
+        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
+        return padded_tensor.contiguous()  # in contiguous memory
+
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+
+        Writes `generated_predictions.jsonl`; does nothing unless this is the world process zero.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+        # Replace IGNORE_INDEX with the pad token id so that both arrays can be decoded.
+        labels = np.where(
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+        )
+        preds = np.where(
+            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+        )
+        # Rotate each prediction so that the tokens before its first non-pad token move to the end.
+        for i in range(len(preds)):
+            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            if len(pad_len):
+                preds[i] = np.concatenate(
+                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
+                )  # move pad token to last
+
+        decoded_labels = self.tokenizer.batch_decode(
+            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for label, pred in zip(decoded_labels, decoded_preds):
+                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+            writer.write("\n".join(res))
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/workflow.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/workflow.py
new file mode 100644
index 0000000..099edc1
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/sft/workflow.py
@@ -0,0 +1,99 @@
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
+
+from typing import TYPE_CHECKING, List, Optional
+
+from transformers import DataCollatorForSeq2Seq
+
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import get_logits_processor
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ...train.sft.metric import ComputeMetrics
+from ...train.sft.trainer import CustomSeq2SeqTrainer
+from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+def run_sft(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ r"""
+ Runs the supervised fine-tuning workflow.
+
+ Loads the tokenizer, dataset and model, builds a `CustomSeq2SeqTrainer`, then trains,
+ evaluates and predicts depending on the `do_train`, `do_eval` and `do_predict` flags of
+ `training_args`, and calls `create_modelcard_and_push` at the end.
+
+ Side effects visible here: `training_args.generation_max_length` and
+ `training_args.generation_num_beams` are overwritten, and `tokenizer.padding_side` is set
+ to "left" when `training_args.predict_with_generate` is set.
+ """
+ tokenizer = load_tokenizer(model_args)
+ dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+ if training_args.predict_with_generate:
+ tokenizer.padding_side = "left" # use left-padding in generation
+
+ if getattr(model, "is_quantized", False) and not training_args.do_train:
+ setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
+
+ # Labels are padded with IGNORE_INDEX when `ignore_pad_token_for_loss` is set, else with the pad token id.
+ data_collator = DataCollatorForSeq2Seq(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+ )
+
+ # Override the decoding parameters of Seq2SeqTrainer
+ training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+ training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+
+ # Initialize our Trainer
+ # `create_custom_optimzer` may return None; the scheduler slot of `optimizers` is always None here.
+ optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+ trainer = CustomSeq2SeqTrainer(
+ model=model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ optimizers=(optimizer, None),
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+ **split_dataset(dataset, data_args, training_args),
+ )
+
+ # Keyword arguments for `model.generate`
+ gen_kwargs = generating_args.to_dict()
+ # Generation stops on the EOS token or on any additional special token.
+ gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+ gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+ gen_kwargs["logits_processor"] = get_logits_processor()
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
+ if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
+ metrics.pop("eval_loss", None)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ # NOTE(review): prediction is run on `dataset` as returned by `get_dataset`, not on a split — confirm this is intended.
+ if training_args.do_predict:
+ predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
+ if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
+ predict_results.metrics.pop("predict_loss", None)
+ trainer.log_metrics("predict", predict_results.metrics)
+ trainer.save_metrics("predict", predict_results.metrics)
+ trainer.save_predictions(predict_results)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/tuner.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/tuner.py
new file mode 100644
index 0000000..a1b7bec
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/tuner.py
@@ -0,0 +1,93 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+import torch
+from transformers import PreTrainedModel
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras.callbacks import LogCallback
+from ..extras.logging import get_logger
+from ..hparams import get_infer_args, get_train_args
+from ..model import load_model_and_tokenizer
+from .dpo import run_dpo
+from .ppo import run_ppo
+from .pt import run_pt
+from .rm import run_rm
+from .sft import run_sft
+
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+
+
+logger = get_logger(__name__)
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+    r"""
+    Parses the training arguments and dispatches to the workflow matching `finetuning_args.stage`.
+
+    A single `LogCallback` is used when no callbacks are supplied. The "sft" and "ppo" stages
+    additionally receive the generating arguments. Raises ValueError for any other stage name
+    than "pt", "sft", "rm", "ppo" or "dpo".
+    """
+    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    if callbacks is None:
+        callbacks = [LogCallback()]
+
+    stage = finetuning_args.stage
+    shared_args = (model_args, data_args, training_args, finetuning_args)
+
+    if stage == "pt":
+        run_pt(*shared_args, callbacks)
+    elif stage == "sft":
+        run_sft(*shared_args, generating_args, callbacks)
+    elif stage == "rm":
+        run_rm(*shared_args, callbacks)
+    elif stage == "ppo":
+        run_ppo(*shared_args, generating_args, callbacks)
+    elif stage == "dpo":
+        run_dpo(*shared_args, callbacks)
+    else:
+        raise ValueError("Unknown task.")
+
+
+def export_model(args: Optional[Dict[str, Any]] = None):
+    r"""
+    Loads a model from inference arguments and saves it, with its tokenizer, to `export_dir`.
+
+    Raises ValueError when `export_dir` is missing, when adapters are combined with
+    `export_quantization_bit`, when adapters would be merged into a quantized model, or when
+    the loaded model is not a `PreTrainedModel`. Model and tokenizer are also pushed to the
+    hub when `export_hub_model_id` is set. Tokenizer failures are only logged as a warning.
+    """
+    model_args, data_args, finetuning_args, _ = get_infer_args(args)
+
+    if model_args.export_dir is None:
+        raise ValueError("Please specify `export_dir`.")
+
+    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
+        raise ValueError("Please merge adapters before quantizing the model.")
+
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    get_template_and_fix_tokenizer(tokenizer, data_args.template)
+
+    is_quantized = bool(getattr(model, "quantization_method", None))
+    if is_quantized and model_args.adapter_name_or_path is not None:
+        raise ValueError("Cannot merge adapters to a quantized model.")
+
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
+
+    # Quantized models are only moved to CPU; others are cast to a dtype first.
+    if is_quantized:
+        model = model.to("cpu")
+    elif hasattr(model.config, "torch_dtype"):
+        model = model.to(model.config.torch_dtype).to("cpu")
+    else:
+        model = model.to(torch.float16).to("cpu")
+        model.config.torch_dtype = torch.float16
+
+    shard_size = "{}GB".format(model_args.export_size)
+    use_safetensors = not model_args.export_legacy_format
+
+    model.save_pretrained(
+        save_directory=model_args.export_dir,
+        max_shard_size=shard_size,
+        safe_serialization=use_safetensors,
+    )
+    if model_args.export_hub_model_id is not None:
+        model.push_to_hub(
+            model_args.export_hub_model_id,
+            token=model_args.hf_hub_token,
+            max_shard_size=shard_size,
+            safe_serialization=use_safetensors,
+        )
+
+    try:
+        tokenizer.padding_side = "left"  # restore padding side
+        tokenizer.init_kwargs["padding_side"] = "left"
+        tokenizer.save_pretrained(model_args.export_dir)
+        if model_args.export_hub_model_id is not None:
+            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
+    except Exception:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+
+if __name__ == "__main__":
+ run_exp()
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/utils.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/utils.py
new file mode 100644
index 0000000..425ff18
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/train/utils.py
@@ -0,0 +1,246 @@
+import math
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version
+
+from ..extras.logging import get_logger
+from ..extras.packages import is_galore_available
+from ..hparams import FinetuningArguments, ModelArguments
+from ..model import load_model_and_tokenizer, load_valuehead_params
+
+
+if is_galore_available():
+ from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments, Trainer
+ from transformers.modeling_utils import PreTrainedModel
+ from trl import AutoModelForCausalLMWithValueHead
+
+ from ..hparams import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+class DummyOptimizer(torch.optim.Optimizer):
+    r"""
+    Placeholder optimizer whose `step` and `zero_grad` do nothing.
+
+    It wraps one throwaway tensor so the base class is given a non-empty parameter list.
+    Constructor arguments are accepted and ignored.
+    """
+
+    def __init__(self, *args, **kwargs):
+        placeholder = torch.randn(1, 1)
+        super().__init__([placeholder], {"lr": 1e-3})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        return None
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        return None
+
+
+def create_modelcard_and_push(
+    trainer: "Trainer",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> None:
+    r"""
+    Pushes the trainer to the hub or writes a local model card, but only after training.
+
+    Nothing happens when `training_args.do_train` is false. Otherwise `trainer.push_to_hub`
+    is used when `training_args.push_to_hub` is set, and `trainer.create_model_card` if not.
+    """
+    dataset_names = [name.strip() for name in data_args.dataset.split(",")]
+    card_kwargs = {
+        "tasks": "text-generation",
+        "finetuned_from": model_args.model_name_or_path,
+        "dataset": dataset_names,
+        "tags": ["llama-factory", finetuning_args.finetuning_type],
+    }
+
+    if not training_args.do_train:
+        return
+
+    if training_args.push_to_hub:
+        trainer.push_to_hub(**card_kwargs)
+    else:
+        trainer.create_model_card(license="other", **card_kwargs)  # prevent from connecting to hub
+
+
+def create_ref_model(
+ model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
+) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
+ r"""
+ Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+
+ The valuehead parameter is randomly initialized since it is useless for PPO training.
+
+ When `finetuning_args.ref_model` is set, the model is loaded from it using a copy of
+ `model_args` whose model path, adapter path and quantization bit are replaced by the
+ `ref_model*` settings. Otherwise the result is None for "lora" fine-tuning, or a
+ non-trainable model loaded from `model_args` itself.
+
+ NOTE(review): the return annotation omits None, which the "lora" path produces.
+ """
+ if finetuning_args.ref_model is not None:
+ # Start from the caller's model arguments and override only the reference-specific fields.
+ ref_model_args_dict = model_args.to_dict()
+ ref_model_args_dict.update(
+ dict(
+ model_name_or_path=finetuning_args.ref_model,
+ adapter_name_or_path=finetuning_args.ref_model_adapters,
+ quantization_bit=finetuning_args.ref_model_quantization_bit,
+ )
+ )
+ ref_model_args = ModelArguments(**ref_model_args_dict)
+ # Fresh finetuning arguments: only `finetuning_type` is set explicitly.
+ ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
+ ref_model, _ = load_model_and_tokenizer(
+ ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ )
+ logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+ else:
+ if finetuning_args.finetuning_type == "lora":
+ ref_model = None
+ else:
+ ref_model, _ = load_model_and_tokenizer(
+ model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ )
+ logger.info("Created reference model from the model itself.")
+
+ return ref_model
+
+
+def create_reward_model(
+ model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
+) -> "AutoModelForCausalLMWithValueHead":
+ r"""
+ Creates reward model for PPO training.
+
+ The result depends on `finetuning_args.reward_model_type`:
+ - "api": returns `finetuning_args.reward_model` itself, which must start with "http".
+ - "lora": loads the reward adapter into `model.pretrained_model` under the name "reward",
+ registers the reward and default value-head tensors as non-persistent buffers on
+ `model`, and returns None.
+ - anything else: loads and returns a separate non-trainable model with a value head.
+
+ NOTE(review): the return annotation does not cover the string and None results.
+ """
+ if finetuning_args.reward_model_type == "api":
+ assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
+ logger.info("Use reward server {}".format(finetuning_args.reward_model))
+ return finetuning_args.reward_model
+ elif finetuning_args.reward_model_type == "lora":
+ model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
+ for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+ if "default" in name:
+ param.data = param.data.to(torch.float32) # trainable params should in fp32
+ vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
+ assert vhead_params is not None, "Reward model is not correctly loaded."
+ # Reward head tensors come from the loaded value-head parameters.
+ model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+ model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+ # Default head tensors are zeros of matching shape.
+ model.register_buffer(
+ "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
+ )
+ model.register_buffer(
+ "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
+ )
+ logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+ return None
+ else:
+ # Copy the caller's model arguments and override only the reward-specific fields.
+ reward_model_args_dict = model_args.to_dict()
+ reward_model_args_dict.update(
+ dict(
+ model_name_or_path=finetuning_args.reward_model,
+ adapter_name_or_path=finetuning_args.reward_model_adapters,
+ quantization_bit=finetuning_args.reward_model_quantization_bit,
+ )
+ )
+ reward_model_args = ModelArguments(**reward_model_args_dict)
+ reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
+ reward_model, _ = load_model_and_tokenizer(
+ reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+ )
+ logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
+ logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
+ return reward_model
+
+
+def create_custom_optimzer(
+ model: "PreTrainedModel",
+ dataset: Union["Dataset", "IterableDataset"],
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> Optional["torch.optim.Optimizer"]:
+ r"""
+ Builds a GaLore optimizer, or returns None when `finetuning_args.use_galore` is false.
+
+ GaLore parameters are the trainable parameters with more than one dimension that belong
+ to `torch.nn.Linear` modules whose name contains one of the comma-separated
+ `finetuning_args.galore_target` entries. All other trainable parameters are optimized
+ without the GaLore settings.
+
+ The optimizer class is chosen from `training_args.optim`; unsupported values raise
+ NotImplementedError. With `finetuning_args.galore_layerwise`, one optimizer and one
+ scheduler are created per parameter and stepped from a gradient hook, and a
+ `DummyOptimizer` is returned; gradient accumulation other than 1 raises ValueError.
+ """
+ if not finetuning_args.use_galore:
+ return None
+
+ require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
+ galore_params: List[torch.nn.Parameter] = []
+ galore_targets = finetuning_args.galore_target.split(",")
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
+ for param in module.parameters():
+ if param.requires_grad and len(param.shape) > 1:
+ galore_params.append(param)
+
+ # Parameters are compared by identity to split the trainable set into GaLore and non-GaLore.
+ id_galore_params = {id(param) for param in galore_params}
+ trainable_params = filter(lambda param: param.requires_grad, model.parameters())
+ non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
+
+ if training_args.optim == "adamw_torch":
+ optim_class = GaLoreAdamW
+ optim_kwargs = {
+ "lr": training_args.learning_rate,
+ "eps": training_args.adam_epsilon,
+ "betas": (training_args.adam_beta1, training_args.adam_beta2),
+ "weight_decay": training_args.weight_decay,
+ }
+
+ elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
+ optim_class = GaLoreAdamW8bit
+ optim_kwargs = {
+ "lr": training_args.learning_rate,
+ "eps": training_args.adam_epsilon,
+ "betas": (training_args.adam_beta1, training_args.adam_beta2),
+ "weight_decay": training_args.weight_decay,
+ "optim_bits": 8,
+ "is_paged": "paged" in training_args.optim,
+ }
+
+ elif training_args.optim == "adafactor":
+ optim_class = GaLoreAdafactor
+ optim_kwargs = {
+ "lr": training_args.learning_rate,
+ "weight_decay": training_args.weight_decay,
+ }
+
+ else:
+ raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
+
+ # Extra per-group options passed only with the GaLore parameter groups.
+ galore_kwargs = {
+ "rank": finetuning_args.galore_rank,
+ "update_proj_gap": finetuning_args.galore_update_interval,
+ "scale": finetuning_args.galore_scale,
+ "proj_type": finetuning_args.galore_proj_type,
+ }
+
+ if finetuning_args.galore_layerwise:
+ if training_args.gradient_accumulation_steps != 1:
+ raise ValueError("Per-layer GaLore does not support gradient accumulation.")
+
+ if training_args.max_steps > 0:
+ num_training_steps = training_args.max_steps
+ else:
+ # NOTE(review): `len(dataset)` is used here although the annotation also allows IterableDataset — confirm callers.
+ total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+ num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+ optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+ for param in non_galore_params:
+ param_groups = [dict(params=[param])]
+ optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+ for param in galore_params:
+ param_groups = [dict(params=[param], **galore_kwargs)]
+ optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+ scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
+ for param in non_galore_params + galore_params:
+ # NOTE(review): warmup and total step counts are doubled here; the reason is not evident from this file.
+ scheduler_dict[param] = get_scheduler(
+ training_args.lr_scheduler_type,
+ optimizer=optimizer_dict[param],
+ num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
+ num_training_steps=num_training_steps * 2,
+ )
+
+ # Steps the optimizer and scheduler owned by `param`, then clears its gradient.
+ def optimizer_hook(param: "torch.Tensor"):
+ if param.grad is not None:
+ optimizer_dict[param].step()
+ optimizer_dict[param].zero_grad()
+ scheduler_dict[param].step()
+
+ for param in non_galore_params + galore_params:
+ param.register_post_accumulate_grad_hook(optimizer_hook)
+
+ optimizer = DummyOptimizer()
+ else:
+ # Single optimizer with two groups: plain parameters, and GaLore parameters with their extra options.
+ param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+ optimizer = optim_class(param_groups, **optim_kwargs)
+
+ logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+ return optimizer
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/__init__.py
new file mode 100644
index 0000000..3e82dd6
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/__init__.py
@@ -0,0 +1,4 @@
+from .interface import create_ui, create_web_demo
+
+
+__all__ = ["create_ui", "create_web_demo"]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/chatter.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/chatter.py
new file mode 100644
index 0000000..d149ca2
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/chatter.py
@@ -0,0 +1,137 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+
+import gradio as gr
+from gradio.components import Component # cannot use TYPE_CHECKING here
+
+from ..chat import ChatModel
+from ..data import Role
+from ..extras.misc import torch_gc
+from .common import get_save_dir
+from .locales import ALERTS
+
+
+if TYPE_CHECKING:
+ from ..chat import BaseEngine
+ from .manager import Manager
+
+
+class WebChatModel(ChatModel):
+ def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
+ r"""
+ Stores the manager and demo flag and starts with no engine attached.
+
+ The parent initializer is called without arguments when `lazy_init` is false, and with
+ the `DEMO_MODEL` / `DEMO_TEMPLATE` environment values when `demo_mode` is set and both
+ variables are non-empty.
+ """
+ self.manager = manager
+ self.demo_mode = demo_mode
+ self.engine: Optional["BaseEngine"] = None
+
+ if not lazy_init: # read arguments from command line
+ super().__init__()
+
+ if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
+ model_name_or_path = os.environ.get("DEMO_MODEL")
+ template = os.environ.get("DEMO_TEMPLATE")
+ super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
+
+    @property
+    def loaded(self) -> bool:
+        r"""Whether an engine is currently attached, i.e. `self.engine` is not None."""
+        engine_present = self.engine is not None
+        return engine_present
+
+ def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+ r"""
+ Validates the UI state, then initializes the chat model from it, yielding status messages.
+
+ `data` maps UI components to their current values; values are looked up by element name
+ through `self.manager.get_elem_by_name`. An error message is yielded (and shown through
+ `gr.Warning`) when a model is already loaded, no model name or path is given, or demo
+ mode is active.
+ """
+ get = lambda name: data[self.manager.get_elem_by_name(name)]
+ lang = get("top.lang")
+ error = ""
+ if self.loaded:
+ error = ALERTS["err_exists"][lang]
+ elif not get("top.model_name"):
+ error = ALERTS["err_no_model"][lang]
+ elif not get("top.model_path"):
+ error = ALERTS["err_no_path"][lang]
+ elif self.demo_mode:
+ error = ALERTS["err_demo"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ # Selected adapters are resolved to save directories and joined into one comma-separated string.
+ if get("top.adapter_path"):
+ adapter_name_or_path = ",".join(
+ [
+ get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+ for adapter in get("top.adapter_path")
+ ]
+ )
+ else:
+ adapter_name_or_path = None
+
+ yield ALERTS["info_loading"][lang]
+ # Quantization and rope scaling are passed only for the listed values; anything else becomes None.
+ args = dict(
+ model_name_or_path=get("top.model_path"),
+ adapter_name_or_path=adapter_name_or_path,
+ finetuning_type=get("top.finetuning_type"),
+ quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ template=get("top.template"),
+ flash_attn=(get("top.booster") == "flash_attn"),
+ use_unsloth=(get("top.booster") == "unsloth"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ infer_backend=get("infer.infer_backend"),
+ )
+ super().__init__(args)
+
+ yield ALERTS["info_loaded"][lang]
+
+    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+        r"""
+        Drops the current engine and yields progress messages in the selected language.
+
+        In demo mode nothing is unloaded: the demo alert is shown via `gr.Warning` and yielded.
+        """
+        lang = data[self.manager.get_elem_by_name("top.lang")]
+
+        if self.demo_mode:
+            demo_alert = ALERTS["err_demo"][lang]
+            gr.Warning(demo_alert)
+            yield demo_alert
+            return
+
+        yield ALERTS["info_unloading"][lang]
+        self.engine = None
+        torch_gc()
+        yield ALERTS["info_unloaded"][lang]
+
+ def predict(
+ self,
+ chatbot: List[Tuple[str, str]],
+ role: str,
+ query: str,
+ messages: Sequence[Tuple[str, str]],
+ system: str,
+ tools: str,
+ max_new_tokens: int,
+ top_p: float,
+ temperature: float,
+ ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
+ r"""
+ Streams a reply to `query` and yields the updated chatbot view and message history.
+
+ The query is appended to `chatbot` (mutated in place) and to a copy of `messages`.
+ Text from `self.stream_chat` is accumulated; when `tools` is non-empty the accumulated
+ text is passed through the template's tool extractor. A tuple result is treated as a
+ function call and shown as a fenced JSON block, anything else as assistant text.
+
+ NOTE(review): `messages` is annotated as tuples but entries built here are
+ `{"role", "content"}` dicts, and `chatbot` entries are two-item lists.
+ """
+ chatbot.append([query, ""])
+ query_messages = messages + [{"role": role, "content": query}]
+ response = ""
+ for new_text in self.stream_chat(
+ query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ ):
+ response += new_text
+ if tools:
+ result = self.engine.template.format_tools.extract(response)
+ else:
+ result = response
+
+ if isinstance(result, tuple):
+ name, arguments = result
+ # Round-trip the arguments through JSON so the tool call is serialized as one object.
+ arguments = json.loads(arguments)
+ tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+ output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
+ bot_text = "```json\n" + tool_call + "\n```"
+ else:
+ output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
+ bot_text = result
+
+ chatbot[-1] = [query, self.postprocess(bot_text)]
+ yield chatbot, output_messages
+
+    def postprocess(self, response: str) -> str:
+        r"""
+        Escapes angle brackets that appear outside fenced code blocks.
+
+        The text is split on the triple-backtick fence marker; even-numbered segments lie
+        outside a fence and have `<` and `>` replaced by `&lt;` and `&gt;`, while segments
+        inside a fence are returned unchanged.
+        """
+        blocks = response.split("```")
+        for i, block in enumerate(blocks):
+            if i % 2 == 0:  # even segments are outside ``` fences
+                # The replacement must be the HTML entity; replacing "<" with "<" does nothing.
+                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
+        return "```".join(blocks)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/common.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/common.py
new file mode 100644
index 0000000..961d6f0
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/common.py
@@ -0,0 +1,115 @@
+import json
+import os
+from collections import defaultdict
+from typing import Any, Dict, Optional
+
+import gradio as gr
+from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
+
+from ..extras.constants import (
+ DATA_CONFIG,
+ DEFAULT_MODULE,
+ DEFAULT_TEMPLATE,
+ PEFT_METHODS,
+ SUPPORTED_MODELS,
+ TRAINING_STAGES,
+ DownloadSource,
+)
+from ..extras.misc import use_modelscope
+
+
+# Weight file names; `list_adapters` lists a sub-directory when it contains one of them.
+ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
+# Directory holding the user config file (see `get_config_path` and `save_config`).
+DEFAULT_CACHE_DIR = "cache"
+# Dataset directory used by `list_dataset` when none is given.
+DEFAULT_DATA_DIR = "data"
+# Root directory that `get_save_dir` joins its arguments onto.
+DEFAULT_SAVE_DIR = "saves"
+# File name of the user config inside `DEFAULT_CACHE_DIR`.
+USER_CONFIG = "user.config"
+
+
+def get_save_dir(*args) -> os.PathLike:
+    r"""Joins the given path components beneath `DEFAULT_SAVE_DIR`."""
+    components = (DEFAULT_SAVE_DIR,) + args
+    return os.path.join(*components)
+
+
+def get_config_path() -> os.PathLike:
+    r"""Returns the path of the user config file inside `DEFAULT_CACHE_DIR`."""
+    config_path = os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+    return config_path
+
+
+def load_config() -> Dict[str, Any]:
+    r"""
+    Reads the user config as JSON.
+
+    Any failure to open or parse the file yields a default config with empty values.
+    """
+    fallback = {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+    try:
+        with open(get_config_path(), "r", encoding="utf-8") as config_file:
+            loaded = json.load(config_file)
+    except Exception:
+        return fallback
+    return loaded
+
+
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+ r"""
+ Updates the stored user config and writes it back as JSON.
+
+ The cache directory is created if missing. A falsy `lang` keeps the stored language.
+ When `model_name` is given it becomes `last_model` and `model_path` is recorded for it
+ in `path_dict`.
+ """
+ os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+ user_config = load_config()
+ user_config["lang"] = lang or user_config["lang"]
+ if model_name:
+ user_config["last_model"] = model_name
+ user_config["path_dict"][model_name] = model_path
+ with open(get_config_path(), "w", encoding="utf-8") as f:
+ json.dump(user_config, f, indent=2, ensure_ascii=False)
+
+
+def get_model_path(model_name: str) -> str:
+    r"""
+    Resolves the path for `model_name`.
+
+    A path stored in the user config wins; otherwise the default download source from
+    `SUPPORTED_MODELS` is used. When ModelScope is enabled and the resolved path equals the
+    default source, the ModelScope path replaces it if one exists.
+    """
+    user_config = load_config()
+    sources: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
+    default_path = sources.get(DownloadSource.DEFAULT, None)
+    model_path = user_config["path_dict"].get(model_name, None) or default_path
+
+    if use_modelscope():
+        modelscope_path = sources.get(DownloadSource.MODELSCOPE)
+        if modelscope_path and model_path == default_path:  # replace path
+            model_path = modelscope_path
+    return model_path
+
+
+def get_prefix(model_name: str) -> str:
+    r"""Returns the part of `model_name` before its first "-" (the whole name if there is none)."""
+    prefix, _, _ = model_name.partition("-")
+    return prefix
+
+
+def get_module(model_name: str) -> str:
+    r"""Looks up the default module string for the model prefix, falling back to "q_proj,v_proj"."""
+    prefix = get_prefix(model_name)
+    return DEFAULT_MODULE.get(prefix, "q_proj,v_proj")
+
+
+def get_template(model_name: str) -> str:
+    r"""
+    Returns the template registered for the model prefix, or "default".
+
+    A registered template is used only when the name ends with "Chat".
+    """
+    if not model_name or not model_name.endswith("Chat"):
+        return "default"
+
+    prefix = get_prefix(model_name)
+    if prefix in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[prefix]
+    return "default"
+
+
+def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    r"""
+    Builds a dropdown update listing the saved adapters for a model.
+
+    For fine-tuning types outside `PEFT_METHODS` the dropdown is emptied and disabled.
+    Adapters are only scanned for "lora": each sub-directory of the save directory that
+    contains one of the `ADAPTER_NAMES` files is listed.
+    """
+    if finetuning_type not in PEFT_METHODS:
+        return gr.update(value=[], choices=[], interactive=False)
+
+    def has_weights(directory: str) -> bool:
+        return any(os.path.isfile(os.path.join(directory, name)) for name in ADAPTER_NAMES)
+
+    adapters = []
+    if model_name and finetuning_type == "lora":
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for entry in os.listdir(save_dir):
+                candidate = os.path.join(save_dir, entry)
+                if os.path.isdir(candidate) and has_weights(candidate):
+                    adapters.append(entry)
+    return gr.update(value=[], choices=adapters, interactive=True)
+
+
+def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
+    r"""
+    Reads the dataset config file in `dataset_dir` as JSON.
+
+    If the file cannot be opened or parsed, the problem is printed and an empty dict returned.
+    """
+    config_path = os.path.join(dataset_dir, DATA_CONFIG)
+    try:
+        with open(config_path, "r", encoding="utf-8") as config_file:
+            return json.load(config_file)
+    except Exception as err:
+        print("Cannot open {} due to {}.".format(config_path, str(err)))
+        return {}
+
+
+def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
+    r"""
+    Builds a dropdown update with the datasets usable for `training_stage`.
+
+    Datasets flagged with "ranking" are listed for the "rm" and "dpo" stages only; all other
+    stages list the datasets without that flag. `DEFAULT_DATA_DIR` is used when
+    `dataset_dir` is None.
+    """
+    source_dir = DEFAULT_DATA_DIR if dataset_dir is None else dataset_dir
+    dataset_info = load_dataset_info(source_dir)
+    ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
+
+    datasets = []
+    for name, attrs in dataset_info.items():
+        if attrs.get("ranking", False) == ranking:
+            datasets.append(name)
+    return gr.update(value=[], choices=datasets)
+
+
+def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
+    r"""Returns an update whose value is True only when the stage maps to "pt"."""
+    is_pretraining = TRAINING_STAGES[training_stage] == "pt"
+    return gr.update(value=is_pretraining)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/__init__.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/__init__.py
new file mode 100644
index 0000000..5c1e21b
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/__init__.py
@@ -0,0 +1,16 @@
+from .chatbot import create_chat_box
+from .eval import create_eval_tab
+from .export import create_export_tab
+from .infer import create_infer_tab
+from .top import create_top
+from .train import create_train_tab
+
+
+__all__ = [
+ "create_chat_box",
+ "create_eval_tab",
+ "create_export_tab",
+ "create_infer_tab",
+ "create_top",
+ "create_train_tab",
+]
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/chatbot.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/chatbot.py
new file mode 100644
index 0000000..bf5bb66
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/chatbot.py
@@ -0,0 +1,62 @@
+from typing import TYPE_CHECKING, Dict, Tuple
+
+import gradio as gr
+
+from ...data import Role
+from ..utils import check_json_schema
+
+
+if TYPE_CHECKING:
+ from gradio.blocks import Block
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_chat_box(
+ engine: "Engine", visible: bool = False
+) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
+ # Builds the chat panel and wires its events.
+ # Returns the enclosing box, the chatbot component, the message-history state, and a dict
+ # of the remaining input components keyed by name.
+ with gr.Box(visible=visible) as chat_box:
+ chatbot = gr.Chatbot()
+ messages = gr.State([])
+ with gr.Row():
+ with gr.Column(scale=4):
+ role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
+ system = gr.Textbox(show_label=False)
+ tools = gr.Textbox(show_label=False, lines=2)
+ query = gr.Textbox(show_label=False, lines=8)
+ submit_btn = gr.Button(variant="primary")
+
+ with gr.Column(scale=1):
+ max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
+ top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
+ temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+ clear_btn = gr.Button()
+
+ # Check the tools text with `check_json_schema` on every input, using the selected UI language.
+ tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
+
+ # Submit runs the chatter's predict generator, then empties the query box.
+ submit_btn.click(
+ engine.chatter.predict,
+ [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
+ [chatbot, messages],
+ show_progress=True,
+ ).then(lambda: gr.update(value=""), outputs=[query])
+
+ # Clear resets both the visible chat and the stored history.
+ clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
+
+ return (
+ chat_box,
+ chatbot,
+ messages,
+ dict(
+ role=role,
+ system=system,
+ tools=tools,
+ query=query,
+ submit_btn=submit_btn,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ temperature=temperature,
+ clear_btn=clear_btn,
+ ),
+ )
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/data.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/data.py
new file mode 100644
index 0000000..c63b6ea
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/data.py
@@ -0,0 +1,93 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Tuple
+
+import gradio as gr
+
+from ...extras.constants import DATA_CONFIG
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+# Number of samples shown per preview page; used for paging in `next_page` and `get_preview`.
+PAGE_SIZE = 2
+
+
+def prev_page(page_index: int) -> int:
+    r"""Steps back one page, staying put when `page_index` is not positive."""
+    if page_index > 0:
+        return page_index - 1
+    return page_index
+
+
+def next_page(page_index: int, total_num: int) -> int:
+    r"""Advances one page unless the next page would start at or beyond `total_num` samples."""
+    next_start = (page_index + 1) * PAGE_SIZE
+    if next_start < total_num:
+        return page_index + 1
+    return page_index
+
+
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
+    r"""
+    Decides whether the preview button should be enabled for the current selection.
+
+    The button is enabled only when the dataset config can be read, the first selected
+    dataset has an entry with a "file_name", and that file exists inside `dataset_dir`.
+    Every other situation, including an empty selection or a dataset name missing from the
+    config, disables the button instead of raising.
+    """
+    try:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+            dataset_info = json.load(f)
+    except Exception:
+        return gr.update(interactive=False)
+
+    # Look the entry up defensively: the selected name may be absent from this config file.
+    entry = None
+    if dataset and isinstance(dataset_info, dict):
+        entry = dataset_info.get(dataset[0])
+
+    previewable = (
+        isinstance(entry, dict)
+        and "file_name" in entry
+        and os.path.isfile(os.path.join(dataset_dir, entry["file_name"]))
+    )
+    return gr.update(interactive=previewable)
+
+
+def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
+    r"""
+    Loads the first selected dataset file and returns one page of it.
+
+    Returns the total sample count, the `PAGE_SIZE` samples of page `page_index`, and an
+    update that makes the preview box visible. ".json" files are parsed as a whole,
+    ".jsonl" files line by line, and any other file is returned as raw lines.
+    """
+    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as config_file:
+        dataset_info = json.load(config_file)
+
+    data_file: str = dataset_info[dataset[0]]["file_name"]
+    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as handle:
+        if data_file.endswith(".json"):
+            data = json.load(handle)
+        elif data_file.endswith(".jsonl"):
+            data = list(map(json.loads, handle))
+        else:
+            data = handle.readlines()
+
+    start = PAGE_SIZE * page_index
+    return len(data), data[start : start + PAGE_SIZE], gr.update(visible=True)
+
+
+def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
+ # Builds the dataset preview button and its initially hidden preview panel, wires the
+ # paging events, and returns the created components keyed by name.
+ data_preview_btn = gr.Button(interactive=False, scale=1)
+ with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
+ with gr.Row():
+ preview_count = gr.Number(value=0, interactive=False, precision=0)
+ page_index = gr.Number(value=0, interactive=False, precision=0)
+
+ with gr.Row():
+ prev_btn = gr.Button()
+ next_btn = gr.Button()
+ close_btn = gr.Button()
+
+ with gr.Row():
+ preview_samples = gr.JSON(interactive=False)
+
+ # Changing the selection re-evaluates the preview button, then resets paging to the first page.
+ dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
+ lambda: 0, outputs=[page_index], queue=False
+ )
+ data_preview_btn.click(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ # Paging buttons first update the page index, then reload the samples for that page.
+ prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
+ return dict(
+ data_preview_btn=data_preview_btn,
+ preview_count=preview_count,
+ page_index=page_index,
+ prev_btn=prev_btn,
+ next_btn=next_btn,
+ close_btn=close_btn,
+ preview_samples=preview_samples,
+ )
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/eval.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/eval.py
new file mode 100644
index 0000000..4c35ad8
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/eval.py
@@ -0,0 +1,76 @@
+from typing import TYPE_CHECKING, Dict
+
+import gradio as gr
+
+from ..common import DEFAULT_DATA_DIR, list_dataset
+from .data import create_preview_box
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+ dataset = gr.Dropdown(multiselect=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+
+ input_elems.update({dataset_dir, dataset})
+ elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+
+ with gr.Row():
+ cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+ max_samples = gr.Textbox(value="100000")
+ batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
+ predict = gr.Checkbox(value=True)
+
+ input_elems.update({cutoff_len, max_samples, batch_size, predict})
+ elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
+
+ with gr.Row():
+ max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
+ top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
+ temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+ output_dir = gr.Textbox()
+
+ input_elems.update({max_new_tokens, top_p, temperature, output_dir})
+ elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
+
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ start_btn = gr.Button()
+ stop_btn = gr.Button()
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+ process_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Box():
+ output_box = gr.Markdown()
+
+ output_elems = [output_box, process_bar]
+ elem_dict.update(
+ dict(
+ cmd_preview_btn=cmd_preview_btn,
+ start_btn=start_btn,
+ stop_btn=stop_btn,
+ resume_btn=resume_btn,
+ process_bar=process_bar,
+ output_box=output_box,
+ )
+ )
+
+ cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
+ start_btn.click(engine.runner.run_eval, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort, queue=False)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems)
+
+ return elem_dict
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/export.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/export.py
new file mode 100644
index 0000000..a40590c
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/export.py
@@ -0,0 +1,117 @@
+from typing import TYPE_CHECKING, Dict, Generator, List
+
+import gradio as gr
+
+from ...train import export_model
+from ..common import get_save_dir
+from ..locales import ALERTS
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+# Bit widths for which export quantization is enabled. They are kept as strings because
+# save_model compares them against the value of the export_quantization_bit dropdown.
+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
+def save_model(
+ lang: str,
+ model_name: str,
+ model_path: str,
+ adapter_path: List[str],
+ finetuning_type: str,
+ template: str,
+ max_shard_size: int,
+ export_quantization_bit: int,
+ export_quantization_dataset: str,
+ export_legacy_format: bool,
+ export_dir: str,
+ export_hub_model_id: str,
+) -> Generator[str, None, None]:
+ error = ""
+ if not model_name:
+ error = ALERTS["err_no_model"][lang]
+ elif not model_path:
+ error = ALERTS["err_no_path"][lang]
+ elif not export_dir:
+ error = ALERTS["err_no_export_dir"][lang]
+ elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+ error = ALERTS["err_no_dataset"][lang]
+ elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+ error = ALERTS["err_no_adapter"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ if adapter_path:
+ adapter_name_or_path = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
+ )
+ else:
+ adapter_name_or_path = None
+
+ args = dict(
+ model_name_or_path=model_path,
+ adapter_name_or_path=adapter_name_or_path,
+ finetuning_type=finetuning_type,
+ template=template,
+ export_dir=export_dir,
+ export_hub_model_id=export_hub_model_id or None,
+ export_size=max_shard_size,
+ export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+ export_quantization_dataset=export_quantization_dataset,
+ export_legacy_format=export_legacy_format,
+ )
+
+ yield ALERTS["info_exporting"][lang]
+ export_model(args)
+ yield ALERTS["info_exported"][lang]
+
+
+def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
+ with gr.Row():
+ max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
+ export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+ export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+ export_legacy_format = gr.Checkbox()
+
+ with gr.Row():
+ export_dir = gr.Textbox()
+ export_hub_model_id = gr.Textbox()
+
+ export_btn = gr.Button()
+ info_box = gr.Textbox(show_label=False, interactive=False)
+
+ export_btn.click(
+ save_model,
+ [
+ engine.manager.get_elem_by_name("top.lang"),
+ engine.manager.get_elem_by_name("top.model_name"),
+ engine.manager.get_elem_by_name("top.model_path"),
+ engine.manager.get_elem_by_name("top.adapter_path"),
+ engine.manager.get_elem_by_name("top.finetuning_type"),
+ engine.manager.get_elem_by_name("top.template"),
+ max_shard_size,
+ export_quantization_bit,
+ export_quantization_dataset,
+ export_legacy_format,
+ export_dir,
+ export_hub_model_id,
+ ],
+ [info_box],
+ )
+
+ return dict(
+ max_shard_size=max_shard_size,
+ export_quantization_bit=export_quantization_bit,
+ export_quantization_dataset=export_quantization_dataset,
+ export_legacy_format=export_legacy_format,
+ export_dir=export_dir,
+ export_hub_model_id=export_hub_model_id,
+ export_btn=export_btn,
+ info_box=info_box,
+ )
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/infer.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/infer.py
new file mode 100644
index 0000000..135535a
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/infer.py
@@ -0,0 +1,39 @@
+from typing import TYPE_CHECKING, Dict
+
+import gradio as gr
+
+from .chatbot import create_chat_box
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ with gr.Row():
+ load_btn = gr.Button()
+ unload_btn = gr.Button()
+
+ info_box = gr.Textbox(show_label=False, interactive=False)
+
+ input_elems.update({infer_backend})
+ elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+
+ chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+ elem_dict.update(dict(chat_box=chat_box, **chat_elems))
+
+ load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
+ lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
+ )
+
+ unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
+ lambda: ([], []), outputs=[chatbot, history]
+ ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
+
+ return elem_dict
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/top.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/top.py
new file mode 100644
index 0000000..d8b4958
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/top.py
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING, Dict, Tuple
+
+import gradio as gr
+
+from ...data import templates
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ..common import get_model_path, get_template, list_adapters, save_config
+from ..utils import can_quantize
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
+ available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+
+ with gr.Row():
+ lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
+ model_name = gr.Dropdown(choices=available_models, scale=3)
+ model_path = gr.Textbox(scale=3)
+
+ with gr.Row():
+ finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
+ adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
+ refresh_btn = gr.Button(scale=1)
+
+ with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+ with gr.Row():
+ quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
+ template = gr.Dropdown(choices=list(templates.keys()), value="default")
+ rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+ booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
+
+ model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
+ get_model_path, [model_name], [model_path], queue=False
+ ).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
+
+ model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
+
+ finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
+ can_quantize, [finetuning_type], [quantization_bit], queue=False
+ )
+
+ refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
+
+ return lang, dict(
+ lang=lang,
+ model_name=model_name,
+ model_path=model_path,
+ finetuning_type=finetuning_type,
+ adapter_path=adapter_path,
+ refresh_btn=refresh_btn,
+ advanced_tab=advanced_tab,
+ quantization_bit=quantization_bit,
+ template=template,
+ rope_scaling=rope_scaling,
+ booster=booster,
+ )
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/train.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/train.py
new file mode 100644
index 0000000..0725f5e
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/components/train.py
@@ -0,0 +1,246 @@
+from typing import TYPE_CHECKING, Dict
+
+import gradio as gr
+from transformers.trainer_utils import SchedulerType
+
+from ...extras.constants import TRAINING_STAGES
+from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
+from ..components.data import create_preview_box
+from ..utils import gen_plot
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+ """Build the "Train" tab and wire its controls to the engine runner.
+
+ Components that feed the runner are accumulated in ``input_elems`` (seeded from
+ ``engine.manager.get_base_elems()``); every created component is also recorded
+ by name in the returned ``elem_dict``.
+ """
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ # Training stage, dataset location/selection and the shared dataset preview box.
+ with gr.Row():
+ training_stage = gr.Dropdown(
+ choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
+ )
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+ dataset = gr.Dropdown(multiselect=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ # Changing the data directory updates the dataset dropdown via list_dataset.
+ dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+
+ input_elems.update({training_stage, dataset_dir, dataset})
+ elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+
+ # Core hyper-parameters; the numeric ones are entered as free-form text.
+ with gr.Row():
+ learning_rate = gr.Textbox(value="5e-5")
+ num_train_epochs = gr.Textbox(value="3.0")
+ max_grad_norm = gr.Textbox(value="1.0")
+ max_samples = gr.Textbox(value="100000")
+ compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+
+ input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
+ elem_dict.update(
+ dict(
+ learning_rate=learning_rate,
+ num_train_epochs=num_train_epochs,
+ max_grad_norm=max_grad_norm,
+ max_samples=max_samples,
+ compute_type=compute_type,
+ )
+ )
+
+ # Sequence length, batching, validation split and LR scheduler.
+ with gr.Row():
+ cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
+ batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
+ gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+ val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+ lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
+
+ input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
+ elem_dict.update(
+ dict(
+ cutoff_len=cutoff_len,
+ batch_size=batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ val_size=val_size,
+ lr_scheduler_type=lr_scheduler_type,
+ )
+ )
+
+ # Collapsed "Extra config" section: logging/saving cadence, optimizer and feature toggles.
+ with gr.Accordion(label="Extra config", open=False) as extra_tab:
+ with gr.Row():
+ logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+ save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+ warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
+ neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+ optim = gr.Textbox(value="adamw_torch")
+
+ with gr.Row():
+ resize_vocab = gr.Checkbox()
+ packing = gr.Checkbox()
+ upcast_layernorm = gr.Checkbox()
+ use_llama_pro = gr.Checkbox()
+ shift_attn = gr.Checkbox()
+
+ input_elems.update(
+ {
+ logging_steps,
+ save_steps,
+ warmup_steps,
+ neftune_alpha,
+ optim,
+ resize_vocab,
+ packing,
+ upcast_layernorm,
+ use_llama_pro,
+ shift_attn,
+ }
+ )
+ elem_dict.update(
+ dict(
+ extra_tab=extra_tab,
+ logging_steps=logging_steps,
+ save_steps=save_steps,
+ warmup_steps=warmup_steps,
+ neftune_alpha=neftune_alpha,
+ optim=optim,
+ resize_vocab=resize_vocab,
+ packing=packing,
+ upcast_layernorm=upcast_layernorm,
+ use_llama_pro=use_llama_pro,
+ shift_attn=shift_attn,
+ )
+ )
+
+ # Collapsed "Freeze config" section.
+ with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+ with gr.Row():
+ num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
+ name_module_trainable = gr.Textbox(value="all", scale=3)
+
+ input_elems.update({num_layer_trainable, name_module_trainable})
+ elem_dict.update(
+ dict(
+ freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+ )
+ )
+
+ # Collapsed "LoRA config" section.
+ with gr.Accordion(label="LoRA config", open=False) as lora_tab:
+ with gr.Row():
+ lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+ lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
+ lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+ lora_target = gr.Textbox(scale=2)
+
+ with gr.Row():
+ use_rslora = gr.Checkbox(scale=1)
+ use_dora = gr.Checkbox(scale=1)
+ create_new_adapter = gr.Checkbox(scale=1)
+ additional_target = gr.Textbox(scale=2)
+
+ input_elems.update(
+ {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+ )
+ elem_dict.update(
+ dict(
+ lora_tab=lora_tab,
+ lora_rank=lora_rank,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ lora_target=lora_target,
+ use_rslora=use_rslora,
+ use_dora=use_dora,
+ create_new_adapter=create_new_adapter,
+ additional_target=additional_target,
+ )
+ )
+
+ # Collapsed "RLHF config" section.
+ with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+ with gr.Row():
+ dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+ dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
+ reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
+
+ # A stage change chains three updates: the dataset dropdown (list_dataset), the
+ # reward_model dropdown (list_adapters) and the packing checkbox (autoset_packing).
+ training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
+ list_adapters,
+ [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
+ [reward_model],
+ queue=False,
+ ).then(autoset_packing, [training_stage], [packing], queue=False)
+
+ input_elems.update({dpo_beta, dpo_ftx, reward_model})
+ elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
+
+ # Collapsed "GaLore config" section.
+ with gr.Accordion(label="GaLore config", open=False) as galore_tab:
+ with gr.Row():
+ use_galore = gr.Checkbox(scale=1)
+ galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
+ galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
+ galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
+ galore_target = gr.Textbox(value="mlp,attn", scale=3)
+
+ input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
+ elem_dict.update(
+ dict(
+ galore_tab=galore_tab,
+ use_galore=use_galore,
+ galore_rank=galore_rank,
+ galore_update_interval=galore_update_interval,
+ galore_scale=galore_scale,
+ galore_target=galore_target,
+ )
+ )
+
+ # Action buttons: preview the command, start and stop.
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ start_btn = gr.Button()
+ stop_btn = gr.Button()
+
+ # Output area: output directory, hidden resume/progress widgets, log box and loss plot.
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row():
+ output_dir = gr.Textbox()
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False)
+ process_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Box():
+ output_box = gr.Markdown()
+
+ with gr.Column(scale=1):
+ loss_viewer = gr.Plot()
+
+ # output_dir is created in the output area above but is still passed to the runner as an input.
+ input_elems.add(output_dir)
+ output_elems = [output_box, process_bar]
+
+ # Runner callbacks: preview the command, start, abort, and monitor on resume.
+ cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
+ start_btn.click(engine.runner.run_train, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort, queue=False)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems)
+
+ elem_dict.update(
+ dict(
+ cmd_preview_btn=cmd_preview_btn,
+ start_btn=start_btn,
+ stop_btn=stop_btn,
+ output_dir=output_dir,
+ resume_btn=resume_btn,
+ process_bar=process_bar,
+ output_box=output_box,
+ loss_viewer=loss_viewer,
+ )
+ )
+
+ # Each change of the output box calls gen_plot with the model name, fine-tuning
+ # type and output directory, and sends its result to the loss plot.
+ output_box.change(
+ gen_plot,
+ [
+ engine.manager.get_elem_by_name("top.model_name"),
+ engine.manager.get_elem_by_name("top.finetuning_type"),
+ output_dir,
+ ],
+ loss_viewer,
+ queue=False,
+ )
+
+ return elem_dict
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/css.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/css.py
new file mode 100644
index 0000000..36e3d4c
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/css.py
@@ -0,0 +1,27 @@
+# Stylesheet handed to gr.Blocks(css=CSS) by the web UI.
+#   .duplicate-button  - class applied to the "Duplicate Space" button shown in demo mode
+#   .modal-box         - class applied to the dataset preview column; fixed position,
+#                        centred with translate(-50%, -50%), scrollable, white border in dark mode
+CSS = r"""
+.duplicate-button {
+ margin: auto !important;
+ color: white !important;
+ background: black !important;
+ border-radius: 100vh !important;
+}
+
+.modal-box {
+ position: fixed !important;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%); /* center horizontally */
+ max-width: 1000px;
+ max-height: 750px;
+ overflow-y: auto;
+ background-color: var(--input-background-fill);
+ flex-wrap: nowrap !important;
+ border: 2px solid black !important;
+ z-index: 1000;
+ padding: 10px;
+}
+
+.dark .modal-box {
+ border: 2px solid white !important;
+}
+"""
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/engine.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/engine.py
new file mode 100644
index 0000000..fb04ca0
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/engine.py
@@ -0,0 +1,62 @@
+from typing import Any, Dict, Generator
+
+import gradio as gr
+from gradio.components import Component # cannot use TYPE_CHECKING here
+
+from .chatter import WebChatModel
+from .common import get_model_path, list_dataset, load_config
+from .locales import LOCALES
+from .manager import Manager
+from .runner import Runner
+from .utils import get_time
+
+
+class Engine:
+ """Owns the component manager, the runner and the chat model used by the web UI."""
+
+ def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
+ """Create the manager, runner and chat model.
+
+ Args:
+ demo_mode: Stored on the engine and forwarded to the runner and the chat model.
+ pure_chat: Stored on the engine; the chat model is created with
+ ``lazy_init=(not pure_chat)``.
+ """
+ self.demo_mode = demo_mode
+ self.pure_chat = pure_chat
+ self.manager = Manager()
+ self.runner = Runner(self.manager, demo_mode)
+ self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
+
+ def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
+ """Turn ``{element name: update kwargs}`` into ``{component: gr.update(**kwargs)}``."""
+ return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
+
+ def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+ """Yield the component updates that restore the UI state.
+
+ The first yield carries the language, the chat box visibility and, depending on the
+ branches below, the dataset choices and the last used model. Later yields either
+ replay the state of a live run or set fresh default output directories.
+ """
+ # The saved user configuration is ignored in demo mode.
+ user_config = load_config() if not self.demo_mode else {}
+ lang = user_config.get("lang", None) or "en"
+
+ init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
+
+ if not self.pure_chat:
+ init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
+ init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
+
+ # NOTE(review): indentation was lost in this copy, so whether this branch sits inside
+ # the pure_chat guard above cannot be told from here - confirm against the real file.
+ if user_config.get("last_model", None):
+ init_dict["top.model_name"] = {"value": user_config["last_model"]}
+ init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
+
+ yield self._form_dict(init_dict)
+
+ if not self.pure_chat:
+ if self.runner.alive and not self.demo_mode:
+ # Replay the values recorded by the runner, then tick the matching resume checkbox.
+ yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
+ if self.runner.do_train:
+ yield self._form_dict({"train.resume_btn": {"value": True}})
+ else:
+ yield self._form_dict({"eval.resume_btn": {"value": True}})
+ else:
+ # Otherwise propose timestamped default output directories.
+ yield self._form_dict(
+ {
+ "train.output_dir": {"value": "train_" + get_time()},
+ "eval.output_dir": {"value": "eval_" + get_time()},
+ }
+ )
+
+ def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
+ """Return a localized update for every registered component whose name is in LOCALES."""
+ return {
+ component: gr.update(**LOCALES[name][lang])
+ for elems in self.manager.all_elems.values()
+ for name, component in elems.items()
+ if name in LOCALES
+ }
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/interface.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/interface.py
new file mode 100644
index 0000000..a1f4d53
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/interface.py
@@ -0,0 +1,74 @@
+import gradio as gr
+from transformers.utils.versions import require_version
+
+from .common import save_config
+from .components import (
+ create_chat_box,
+ create_eval_tab,
+ create_export_tab,
+ create_infer_tab,
+ create_top,
+ create_train_tab,
+)
+from .css import CSS
+from .engine import Engine
+
+
+require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
+
+
+def create_ui(demo_mode: bool = False) -> gr.Blocks:
+ engine = Engine(demo_mode=demo_mode, pure_chat=False)
+
+ with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+ if demo_mode:
+ gr.HTML("LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
")
+ gr.HTML(
+ '"
+ )
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
+ lang, engine.manager.all_elems["top"] = create_top()
+
+ with gr.Tab("Train"):
+ engine.manager.all_elems["train"] = create_train_tab(engine)
+
+ with gr.Tab("Evaluate & Predict"):
+ engine.manager.all_elems["eval"] = create_eval_tab(engine)
+
+ with gr.Tab("Chat"):
+ engine.manager.all_elems["infer"] = create_infer_tab(engine)
+
+ if not demo_mode:
+ with gr.Tab("Export"):
+ engine.manager.all_elems["export"] = create_export_tab(engine)
+
+ demo.load(engine.resume, outputs=engine.manager.list_elems())
+ lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+def create_web_demo() -> gr.Blocks:
+ engine = Engine(pure_chat=True)
+
+ with gr.Blocks(title="Web Demo", css=CSS) as demo:
+ lang = gr.Dropdown(choices=["en", "zh"])
+ engine.manager.all_elems["top"] = dict(lang=lang)
+
+ chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
+ engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
+
+ demo.load(engine.resume, outputs=engine.manager.list_elems())
+ lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+if __name__ == "__main__":
+ demo = create_ui()
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/locales.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/locales.py
new file mode 100644
index 0000000..4f329e8
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/locales.py
@@ -0,0 +1,1289 @@
+LOCALES = {
+ "lang": {
+ "en": {
+ "label": "Lang",
+ },
+ "ru": {
+ "label": "Русский",
+ },
+ "zh": {
+ "label": "语言",
+ },
+ },
+ "model_name": {
+ "en": {
+ "label": "Model name",
+ },
+ "ru": {
+ "label": "Название модели",
+ },
+ "zh": {
+ "label": "模型名称",
+ },
+ },
+ "model_path": {
+ "en": {
+ "label": "Model path",
+ "info": "Path to pretrained model or model identifier from Hugging Face.",
+ },
+ "ru": {
+ "label": "Путь к модели",
+ "info": "Путь к предварительно обученной модели или идентификатор модели от Hugging Face.",
+ },
+ "zh": {
+ "label": "模型路径",
+ "info": "本地模型的文件路径或 Hugging Face 的模型标识符。",
+ },
+ },
+ "finetuning_type": {
+ "en": {
+ "label": "Finetuning method",
+ },
+ "ru": {
+ "label": "Метод дообучения",
+ },
+ "zh": {
+ "label": "微调方法",
+ },
+ },
+ "adapter_path": {
+ "en": {
+ "label": "Adapter path",
+ },
+ "ru": {
+ "label": "Путь к адаптеру",
+ },
+ "zh": {
+ "label": "适配器路径",
+ },
+ },
+ "refresh_btn": {
+ "en": {
+ "value": "Refresh adapters",
+ },
+ "ru": {
+ "value": "Обновить адаптеры",
+ },
+ "zh": {
+ "value": "刷新适配器",
+ },
+ },
+ "advanced_tab": {
+ "en": {
+ "label": "Advanced configurations",
+ },
+ "ru": {
+ "label": "Расширенные конфигурации",
+ },
+ "zh": {
+ "label": "高级设置",
+ },
+ },
+ "quantization_bit": {
+ "en": {
+ "label": "Quantization bit",
+ "info": "Enable 4/8-bit model quantization (QLoRA).",
+ },
+ "ru": {
+ "label": "Уровень квантования",
+ "info": "Включить 4/8-битное квантование модели (QLoRA).",
+ },
+ "zh": {
+ "label": "量化等级",
+ "info": "启用 4/8 比特模型量化(QLoRA)。",
+ },
+ },
+ "template": {
+ "en": {
+ "label": "Prompt template",
+ "info": "The template used in constructing prompts.",
+ },
+ "ru": {
+ "label": "Шаблон запроса",
+ "info": "Шаблон, используемый при формировании запросов.",
+ },
+ "zh": {
+ "label": "提示模板",
+ "info": "构建提示词时使用的模板",
+ },
+ },
+ "rope_scaling": {
+ "en": {
+ "label": "RoPE scaling",
+ },
+ "ru": {
+ "label": "Масштабирование RoPE",
+ },
+ "zh": {
+ "label": "RoPE 插值方法",
+ },
+ },
+ "booster": {
+ "en": {
+ "label": "Booster",
+ },
+ "ru": {
+ "label": "Ускоритель",
+ },
+ "zh": {
+ "label": "加速方式",
+ },
+ },
+ "training_stage": {
+ "en": {
+ "label": "Stage",
+ "info": "The stage to perform in training.",
+ },
+ "ru": {
+ "label": "Этап",
+ "info": "Этап выполнения обучения.",
+ },
+ "zh": {
+ "label": "训练阶段",
+ "info": "目前采用的训练方式。",
+ },
+ },
+ "dataset_dir": {
+ "en": {
+ "label": "Data dir",
+ "info": "Path to the data directory.",
+ },
+ "ru": {
+ "label": "Директория данных",
+ "info": "Путь к директории данных.",
+ },
+ "zh": {
+ "label": "数据路径",
+ "info": "数据文件夹的路径。",
+ },
+ },
+ "dataset": {
+ "en": {
+ "label": "Dataset",
+ },
+ "ru": {
+ "label": "Набор данных",
+ },
+ "zh": {
+ "label": "数据集",
+ },
+ },
+ "data_preview_btn": {
+ "en": {
+ "value": "Preview dataset",
+ },
+ "ru": {
+ "value": "Просмотреть набор данных",
+ },
+ "zh": {
+ "value": "预览数据集",
+ },
+ },
+ "preview_count": {
+ "en": {
+ "label": "Count",
+ },
+ "ru": {
+ "label": "Количество",
+ },
+ "zh": {
+ "label": "数量",
+ },
+ },
+ "page_index": {
+ "en": {
+ "label": "Page",
+ },
+ "ru": {
+ "label": "Страница",
+ },
+ "zh": {
+ "label": "页数",
+ },
+ },
+ "prev_btn": {
+ "en": {
+ "value": "Prev",
+ },
+ "ru": {
+ "value": "Предыдущая",
+ },
+ "zh": {
+ "value": "上一页",
+ },
+ },
+ "next_btn": {
+ "en": {
+ "value": "Next",
+ },
+ "ru": {
+ "value": "Следующая",
+ },
+ "zh": {
+ "value": "下一页",
+ },
+ },
+ "close_btn": {
+ "en": {
+ "value": "Close",
+ },
+ "ru": {
+ "value": "Закрыть",
+ },
+ "zh": {
+ "value": "关闭",
+ },
+ },
+ "preview_samples": {
+ "en": {
+ "label": "Samples",
+ },
+ "ru": {
+ "label": "Примеры",
+ },
+ "zh": {
+ "label": "样例",
+ },
+ },
+ "learning_rate": {
+ "en": {
+ "label": "Learning rate",
+ "info": "Initial learning rate for AdamW.",
+ },
+ "ru": {
+ "label": "Скорость обучения",
+ "info": "Начальная скорость обучения для AdamW.",
+ },
+ "zh": {
+ "label": "学习率",
+ "info": "AdamW 优化器的初始学习率。",
+ },
+ },
+ "num_train_epochs": {
+ "en": {
+ "label": "Epochs",
+ "info": "Total number of training epochs to perform.",
+ },
+ "ru": {
+ "label": "Эпохи",
+ "info": "Общее количество эпох обучения.",
+ },
+ "zh": {
+ "label": "训练轮数",
+ "info": "需要执行的训练总轮数。",
+ },
+ },
+ "max_grad_norm": {
+ "en": {
+ "label": "Maximum gradient norm",
+ "info": "Norm for gradient clipping.",
+ },
+ "ru": {
+ "label": "Максимальная норма градиента",
+ "info": "Норма для обрезки градиента.",
+ },
+ "zh": {
+ "label": "最大梯度范数",
+ "info": "用于梯度裁剪的范数。",
+ },
+ },
+ "max_samples": {
+ "en": {
+ "label": "Max samples",
+ "info": "Maximum samples per dataset.",
+ },
+ "ru": {
+ "label": "Максимальное количество образцов",
+ "info": "Максимальное количество образцов на набор данных.",
+ },
+ "zh": {
+ "label": "最大样本数",
+ "info": "每个数据集的最大样本数。",
+ },
+ },
+ "compute_type": {
+ "en": {
+ "label": "Compute type",
+ "info": "Whether to use mixed precision training.",
+ },
+ "ru": {
+ "label": "Тип вычислений",
+ "info": "Использовать ли обучение смешанной точности.",
+ },
+ "zh": {
+ "label": "计算类型",
+ "info": "是否使用混合精度训练。",
+ },
+ },
+ "cutoff_len": {
+ "en": {
+ "label": "Cutoff length",
+ "info": "Max tokens in input sequence.",
+ },
+ "ru": {
+ "label": "Длина обрезки",
+ "info": "Максимальное количество токенов во входной последовательности.",
+ },
+ "zh": {
+ "label": "截断长度",
+ "info": "输入序列分词后的最大长度。",
+ },
+ },
+ "batch_size": {
+ "en": {
+ "label": "Batch size",
+ "info": "Number of samples processed on each GPU.",
+ },
+ "ru": {
+ "label": "Размер пакета",
+ "info": "Количество образцов для обработки на каждом GPU.",
+ },
+ "zh": {
+ "label": "批处理大小",
+ "info": "每个 GPU 处理的样本数量。",
+ },
+ },
+ "gradient_accumulation_steps": {
+ "en": {
+ "label": "Gradient accumulation",
+ "info": "Number of steps for gradient accumulation.",
+ },
+ "ru": {
+ "label": "Накопление градиента",
+ "info": "Количество шагов накопления градиента.",
+ },
+ "zh": {
+ "label": "梯度累积",
+ "info": "梯度累积的步数。",
+ },
+ },
+ "val_size": {
+ "en": {
+ "label": "Val size",
+ "info": "Proportion of data in the dev set.",
+ },
+ "ru": {
+ "label": "Размер валидации",
+ "info": "Пропорция данных в наборе для разработки.",
+ },
+ "zh": {
+ "label": "验证集比例",
+ "info": "验证集占全部样本的百分比。",
+ },
+ },
+ "lr_scheduler_type": {
+ "en": {
+ "label": "LR scheduler",
+ "info": "Name of the learning rate scheduler.",
+ },
+ "ru": {
+ "label": "Планировщик скорости обучения",
+ "info": "Название планировщика скорости обучения.",
+ },
+ "zh": {
+ "label": "学习率调节器",
+ "info": "学习率调度器的名称。",
+ },
+ },
+ "extra_tab": {
+ "en": {
+ "label": "Extra configurations",
+ },
+ "ru": {
+ "label": "Дополнительные конфигурации",
+ },
+ "zh": {
+ "label": "其它参数设置",
+ },
+ },
+ "logging_steps": {
+ "en": {
+ "label": "Logging steps",
+ "info": "Number of steps between two logs.",
+ },
+ "ru": {
+ "label": "Шаги логирования",
+ "info": "Количество шагов между двумя записями в журнале.",
+ },
+ "zh": {
+ "label": "日志间隔",
+ "info": "每两次日志输出间的更新步数。",
+ },
+ },
+ "save_steps": {
+ "en": {
+ "label": "Save steps",
+ "info": "Number of steps between two checkpoints.",
+ },
+ "ru": {
+ "label": "Шаги сохранения",
+ "info": "Количество шагов между двумя контрольными точками.",
+ },
+ "zh": {
+ "label": "保存间隔",
+ "info": "每两次断点保存间的更新步数。",
+ },
+ },
+ "warmup_steps": {
+ "en": {
+ "label": "Warmup steps",
+ "info": "Number of steps used for warmup.",
+ },
+ "ru": {
+ "label": "Шаги прогрева",
+ "info": "Количество шагов, используемых для прогрева.",
+ },
+ "zh": {
+ "label": "预热步数",
+ "info": "学习率预热采用的步数。",
+ },
+ },
+ "neftune_alpha": {
+ "en": {
+ "label": "NEFTune Alpha",
+ "info": "Magnitude of noise adding to embedding vectors.",
+ },
+ "ru": {
+ "label": "NEFTune Alpha",
+ "info": "Величина шума, добавляемого к векторам вложений.",
+ },
+ "zh": {
+ "label": "NEFTune 噪声参数",
+ "info": "嵌入向量所添加的噪声大小。",
+ },
+ },
+ "optim": {
+ "en": {
+ "label": "Optimizer",
+ "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+ },
+ "ru": {
+ "label": "Оптимизатор",
+ "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+ },
+ "zh": {
+ "label": "优化器",
+ "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
+ },
+ },
+ "resize_vocab": {
+ "en": {
+ "label": "Resize token embeddings",
+ "info": "Resize the tokenizer vocab and the embedding layers.",
+ },
+ "ru": {
+ "label": "Изменение размера токенных эмбеддингов",
+ "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.",
+ },
+ "zh": {
+ "label": "更改词表大小",
+ "info": "更改分词器词表和嵌入层的大小。",
+ },
+ },
+ "packing": {
+ "en": {
+ "label": "Pack sequences",
+ "info": "Pack sequences into samples of fixed length.",
+ },
+ "ru": {
+ "label": "Упаковка последовательностей",
+ "info": "Упаковка последовательностей в образцы фиксированной длины.",
+ },
+ "zh": {
+ "label": "序列打包",
+ "info": "将序列打包为等长样本。",
+ },
+ },
+ "upcast_layernorm": {
+ "en": {
+ "label": "Upcast LayerNorm",
+ "info": "Upcast weights of layernorm in float32.",
+ },
+ "ru": {
+ "label": "Приведение весов LayerNorm",
+ "info": "Приведение весов LayerNorm к float32.",
+ },
+ "zh": {
+ "label": "缩放归一化层",
+ "info": "将归一化层权重缩放至 32 位精度。",
+ },
+ },
+ "use_llama_pro": {
+ "en": {
+ "label": "Enable LLaMA Pro",
+ "info": "Make the parameters in the expanded blocks trainable.",
+ },
+ "ru": {
+ "label": "Включить LLaMA Pro",
+ "info": "Сделать параметры в расширенных блоках обучаемыми.",
+ },
+ "zh": {
+ "label": "使用 LLaMA Pro",
+ "info": "仅训练块扩展后的参数。",
+ },
+ },
+ "shift_attn": {
+ "en": {
+ "label": "Enable S^2 Attention",
+ "info": "Use shift short attention proposed by LongLoRA.",
+ },
+ "ru": {
+ "label": "Включить S^2 внимание",
+ "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
+ },
+ "zh": {
+ "label": "使用 S^2 Attention",
+ "info": "使用 LongLoRA 提出的 shift short attention。",
+ },
+ },
+ "freeze_tab": {
+ "en": {
+ "label": "Freeze tuning configurations",
+ },
+ "ru": {
+ "label": "конфигурации для настройки заморозки",
+ },
+ "zh": {
+ "label": "部分参数微调设置",
+ },
+ },
+ "num_layer_trainable": {
+ "en": {
+ "label": "Trainable layers",
+ "info": "The number of trainable layers.",
+ },
+ "ru": {
+ "label": "Обучаемые слои",
+ "info": "Количество обучаемых слоев.",
+ },
+ "zh": {
+ "label": "可训练层数",
+ "info": "可训练模型层的数量。",
+ },
+ },
+ "name_module_trainable": {
+ "en": {
+ "label": "Trainable modules",
+ "info": "The name of trainable modules. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Обучаемые модули",
+ "info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "可训练模块",
+ "info": "可训练模块的名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "lora_tab": {
+ "en": {
+ "label": "LoRA configurations",
+ },
+ "ru": {
+ "label": "Конфигурации LoRA",
+ },
+ "zh": {
+ "label": "LoRA 参数设置",
+ },
+ },
+ "lora_rank": {
+ "en": {
+ "label": "LoRA rank",
+ "info": "The rank of LoRA matrices.",
+ },
+ "ru": {
+ "label": "Ранг матриц LoRA",
+ "info": "Ранг матриц LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 秩",
+ "info": "LoRA 矩阵的秩大小。",
+ },
+ },
+ "lora_alpha": {
+ "en": {
+ "label": "LoRA alpha",
+ "info": "Lora scaling coefficient.",
+ },
+ "ru": {
+ "label": "LoRA alpha",
+ "info": "Коэффициент масштабирования LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 缩放系数",
+ "info": "LoRA 缩放系数大小。",
+ },
+ },
+ "lora_dropout": {
+ "en": {
+ "label": "LoRA dropout",
+ "info": "Dropout ratio of LoRA weights.",
+ },
+ "ru": {
+ "label": "Вероятность отсева LoRA",
+ "info": "Вероятность отсева весов LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 随机丢弃",
+ "info": "LoRA 权重随机丢弃的概率。",
+ },
+ },
+ "lora_target": {
+ "en": {
+ "label": "LoRA modules (optional)",
+ "info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Модули LoRA (опционально)",
+ "info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "LoRA 作用模块(非必填)",
+ "info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "use_rslora": {
+ "en": {
+ "label": "Use rslora",
+ "info": "Use the rank stabilization scaling factor for LoRA layer.",
+ },
+ "ru": {
+ "label": "Использовать rslora",
+ "info": "Использовать коэффициент масштабирования стабилизации ранга для слоя LoRA.",
+ },
+ "zh": {
+ "label": "使用 rslora",
+ "info": "对 LoRA 层使用秩稳定缩放方法。",
+ },
+ },
+ "use_dora": {
+ "en": {
+ "label": "Use DoRA",
+ "info": "Use weight-decomposed LoRA.",
+ },
+ "ru": {
+ "label": "Используйте DoRA",
+ "info": "Используйте LoRA с декомпозицией весов.",
+ },
+ "zh": {
+ "label": "使用 DoRA",
+ "info": "使用权重分解的 LoRA。",
+ },
+ },
+ "create_new_adapter": {
+ "en": {
+ "label": "Create new adapter",
+ "info": "Create a new adapter with randomly initialized weight upon the existing one.",
+ },
+ "ru": {
+ "label": "Создать новый адаптер",
+ "info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
+ },
+ "zh": {
+ "label": "新建适配器",
+ "info": "在现有的适配器上创建一个随机初始化后的新适配器。",
+ },
+ },
+ "additional_target": {
+ "en": {
+ "label": "Additional modules (optional)",
+ "info": (
+ "Name(s) of modules apart from LoRA layers to be set as trainable. "
+ "Use commas to separate multiple modules."
+ ),
+ },
+ "ru": {
+ "label": "Дополнительные модули (опционально)",
+ "info": (
+ "Имена модулей, кроме слоев LoRA, которые следует установить в качестве обучаемых. "
+ "Используйте запятые для разделения нескольких модулей."
+ ),
+ },
+ "zh": {
+ "label": "附加模块(非必填)",
+ "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "rlhf_tab": {
+ "en": {
+ "label": "RLHF configurations",
+ },
+ "ru": {
+ "label": "Конфигурации RLHF",
+ },
+ "zh": {
+ "label": "RLHF 参数设置",
+ },
+ },
+ "dpo_beta": {
+ "en": {
+ "label": "DPO beta",
+ "info": "Value of the beta parameter in the DPO loss.",
+ },
+ "ru": {
+ "label": "DPO бета",
+ "info": "Значение параметра бета в функции потерь DPO.",
+ },
+ "zh": {
+ "label": "DPO beta 参数",
+ "info": "DPO 损失函数中 beta 超参数大小。",
+ },
+ },
+ "dpo_ftx": {
+ "en": {
+ "label": "DPO-ftx weight",
+ "info": "The weight of SFT loss in the DPO-ftx.",
+ },
+ "ru": {
+ "label": "Вес DPO-ftx",
+ "info": "Вес функции потерь SFT в DPO-ftx.",
+ },
+ "zh": {
+ "label": "DPO-ftx 权重",
+ "info": "DPO-ftx 中 SFT 损失的权重大小。",
+ },
+ },
+ "reward_model": {
+ "en": {
+ "label": "Reward model",
+ "info": "Adapter of the reward model for PPO training.",
+ },
+ "ru": {
+ "label": "Модель вознаграждения",
+ "info": "Адаптер модели вознаграждения для обучения PPO.",
+ },
+ "zh": {
+ "label": "奖励模型",
+ "info": "PPO 训练中奖励模型的适配器路径。",
+ },
+ },
+ "galore_tab": {
+ "en": {
+ "label": "GaLore configurations",
+ },
+ "ru": {
+ "label": "Конфигурации GaLore",
+ },
+ "zh": {
+ "label": "GaLore 参数设置",
+ },
+ },
+ "use_galore": {
+ "en": {
+ "label": "Use GaLore",
+ "info": "Enable gradient low-Rank projection.",
+ },
+ "ru": {
+ "label": "Использовать GaLore",
+ "info": "Включить проекцию градиента на низкоранговое пространство.",
+ },
+ "zh": {
+ "label": "使用 GaLore",
+ "info": "使用梯度低秩投影。",
+ },
+ },
+ "galore_rank": {
+ "en": {
+ "label": "GaLore rank",
+ "info": "The rank of GaLore gradients.",
+ },
+ "ru": {
+ "label": "Ранг GaLore",
+ "info": "Ранг градиентов GaLore.",
+ },
+ "zh": {
+ "label": "GaLore 秩",
+ "info": "GaLore 梯度的秩大小。",
+ },
+ },
+ "galore_update_interval": {
+ "en": {
+ "label": "Update interval",
+ "info": "Number of steps to update the GaLore projection.",
+ },
+ "ru": {
+ "label": "Интервал обновления",
+ "info": "Количество шагов для обновления проекции GaLore.",
+ },
+ "zh": {
+ "label": "更新间隔",
+ "info": "相邻两次投影更新的步数。",
+ },
+ },
+ "galore_scale": {
+ "en": {
+ "label": "GaLore scale",
+ "info": "GaLore scaling coefficient.",
+ },
+ "ru": {
+ "label": "LoRA Alpha",
+ "info": "Коэффициент масштабирования GaLore.",
+ },
+ "zh": {
+ "label": "GaLore 缩放系数",
+ "info": "GaLore 缩放系数大小。",
+ },
+ },
+ "galore_target": {
+ "en": {
+ "label": "GaLore modules",
+ "info": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Модули GaLore",
+ "info": "Имена модулей для применения GaLore. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "GaLore 作用模块",
+ "info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "cmd_preview_btn": {
+ "en": {
+ "value": "Preview command",
+ },
+ "ru": {
+ "value": "Просмотр команды",
+ },
+ "zh": {
+ "value": "预览命令",
+ },
+ },
+ "start_btn": {
+ "en": {
+ "value": "Start",
+ },
+ "ru": {
+ "value": "Начать",
+ },
+ "zh": {
+ "value": "开始",
+ },
+ },
+ "stop_btn": {
+ "en": {
+ "value": "Abort",
+ },
+ "ru": {
+ "value": "Прервать",
+ },
+ "zh": {
+ "value": "中断",
+ },
+ },
+ "output_dir": {
+ "en": {
+ "label": "Output dir",
+ "info": "Directory for saving results.",
+ },
+ "ru": {
+ "label": "Выходной каталог",
+ "info": "Каталог для сохранения результатов.",
+ },
+ "zh": {
+ "label": "输出目录",
+ "info": "保存结果的路径。",
+ },
+ },
+ "output_box": {
+ "en": {
+ "value": "Ready.",
+ },
+ "ru": {
+ "value": "Готово.",
+ },
+ "zh": {
+ "value": "准备就绪。",
+ },
+ },
+ "loss_viewer": {
+ "en": {
+ "label": "Loss",
+ },
+ "ru": {
+ "label": "Потери",
+ },
+ "zh": {
+ "label": "损失",
+ },
+ },
+ "predict": {
+ "en": {
+ "label": "Save predictions",
+ },
+ "ru": {
+ "label": "Сохранить предсказания",
+ },
+ "zh": {
+ "label": "保存预测结果",
+ },
+ },
+ "infer_backend": {
+ "en": {
+ "label": "Inference engine",
+ },
+ "ru": {
+ "label": "Инференс движок",
+ },
+ "zh": {
+ "label": "推理引擎",
+ },
+ },
+ "load_btn": {
+ "en": {
+ "value": "Load model",
+ },
+ "ru": {
+ "value": "Загрузить модель",
+ },
+ "zh": {
+ "value": "加载模型",
+ },
+ },
+ "unload_btn": {
+ "en": {
+ "value": "Unload model",
+ },
+ "ru": {
+ "value": "Выгрузить модель",
+ },
+ "zh": {
+ "value": "卸载模型",
+ },
+ },
+ "info_box": {
+ "en": {
+ "value": "Model unloaded, please load a model first.",
+ },
+ "ru": {
+ "value": "Модель не загружена, загрузите модель сначала.",
+ },
+ "zh": {
+ "value": "模型未加载,请先加载模型。",
+ },
+ },
+ "role": {
+ "en": {
+ "label": "Role",
+ },
+ "ru": {
+ "label": "Роль",
+ },
+ "zh": {
+ "label": "角色",
+ },
+ },
+ "system": {
+ "en": {
+ "placeholder": "System prompt (optional)",
+ },
+ "ru": {
+ "placeholder": "Системный запрос (по желанию)",
+ },
+ "zh": {
+ "placeholder": "系统提示词(非必填)",
+ },
+ },
+ "tools": {
+ "en": {
+ "placeholder": "Tools (optional)",
+ },
+ "ru": {
+ "placeholder": "Инструменты (по желанию)",
+ },
+ "zh": {
+ "placeholder": "工具列表(非必填)",
+ },
+ },
+ "query": {
+ "en": {
+ "placeholder": "Input...",
+ },
+ "ru": {
+ "placeholder": "Ввод...",
+ },
+ "zh": {
+ "placeholder": "输入...",
+ },
+ },
+ "submit_btn": {
+ "en": {
+ "value": "Submit",
+ },
+ "ru": {
+ "value": "Отправить",
+ },
+ "zh": {
+ "value": "提交",
+ },
+ },
+ "max_length": {
+ "en": {
+ "label": "Maximum length",
+ },
+ "ru": {
+ "label": "Максимальная длина",
+ },
+ "zh": {
+ "label": "最大长度",
+ },
+ },
+ "max_new_tokens": {
+ "en": {
+ "label": "Maximum new tokens",
+ },
+ "ru": {
+ "label": "Максимальное количество новых токенов",
+ },
+ "zh": {
+ "label": "最大生成长度",
+ },
+ },
+ "top_p": {
+ "en": {
+ "label": "Top-p",
+ },
+ "ru": {
+ "label": "Лучшие-p",
+ },
+ "zh": {
+ "label": "Top-p 采样值",
+ },
+ },
+ "temperature": {
+ "en": {
+ "label": "Temperature",
+ },
+ "ru": {
+ "label": "Температура",
+ },
+ "zh": {
+ "label": "温度系数",
+ },
+ },
+ "clear_btn": {
+ "en": {
+ "value": "Clear history",
+ },
+ "ru": {
+ "value": "Очистить историю",
+ },
+ "zh": {
+ "value": "清空历史",
+ },
+ },
+ "max_shard_size": {
+ "en": {
+ "label": "Max shard size (GB)",
+ "info": "The maximum size for a model file.",
+ },
+ "ru": {
+ "label": "Максимальный размер фрагмента (ГБ)",
+ "info": "Максимальный размер файла модели.",
+ },
+ "zh": {
+ "label": "最大分块大小(GB)",
+ "info": "单个模型文件的最大大小。",
+ },
+ },
+ "export_quantization_bit": {
+ "en": {
+ "label": "Export quantization bit.",
+ "info": "Quantizing the exported model.",
+ },
+ "ru": {
+ "label": "Экспорт бита квантования",
+ "info": "Квантование экспортируемой модели.",
+ },
+ "zh": {
+ "label": "导出量化等级",
+ "info": "量化导出模型。",
+ },
+ },
+ "export_quantization_dataset": {
+ "en": {
+ "label": "Export quantization dataset",
+ "info": "The calibration dataset used for quantization.",
+ },
+ "ru": {
+ "label": "Экспорт набора данных для квантования",
+ "info": "Набор данных калибровки, используемый для квантования.",
+ },
+ "zh": {
+ "label": "导出量化数据集",
+ "info": "量化过程中使用的校准数据集。",
+ },
+ },
+ "export_legacy_format": {
+ "en": {
+ "label": "Export legacy format",
+ "info": "Do not use safetensors to save the model.",
+ },
+ "ru": {
+ "label": "Экспорт в устаревший формат",
+ "info": "Не использовать safetensors для сохранения модели.",
+ },
+ "zh": {
+ "label": "导出旧格式",
+ "info": "不使用 safetensors 格式保存模型。",
+ },
+ },
+ "export_dir": {
+ "en": {
+ "label": "Export dir",
+ "info": "Directory to save exported model.",
+ },
+ "ru": {
+ "label": "Каталог экспорта",
+ "info": "Каталог для сохранения экспортированной модели.",
+ },
+ "zh": {
+ "label": "导出目录",
+ "info": "保存导出模型的文件夹路径。",
+ },
+ },
+ "export_hub_model_id": {
+ "en": {
+ "label": "HF Hub ID (optional)",
+ "info": "Repo ID for uploading model to Hugging Face hub.",
+ },
+ "ru": {
+ "label": "HF Hub ID (опционально)",
+ "info": "Идентификатор репозитория для загрузки модели на Hugging Face hub.",
+ },
+ "zh": {
+ "label": "HF Hub ID(非必填)",
+ "info": "用于将模型上传至 Hugging Face Hub 的仓库 ID。",
+ },
+ },
+ "export_btn": {
+ "en": {
+ "value": "Export",
+ },
+ "ru": {
+ "value": "Экспорт",
+ },
+ "zh": {
+ "value": "开始导出",
+ },
+ },
+}
+
+
+# Localized status / error messages shown by the web UI.
+# Layout: ALERTS[alert_key][language_code] -> message, with language codes "en", "ru" and "zh".
+# Key prefixes follow the message severity: "err_*" for errors, "warn_*" for warnings,
+# "info_*" for progress / status notices.
+ALERTS = {
+ "err_conflict": {
+ "en": "A process is in running, please abort it first.",
+ "ru": "Процесс уже запущен, пожалуйста, сначала прервите его.",
+ "zh": "任务已存在,请先中断训练。",
+ },
+ "err_exists": {
+ "en": "You have loaded a model, please unload it first.",
+ "ru": "Вы загрузили модель, сначала разгрузите ее.",
+ "zh": "模型已存在,请先卸载模型。",
+ },
+ "err_no_model": {
+ "en": "Please select a model.",
+ "ru": "Пожалуйста, выберите модель.",
+ "zh": "请选择模型。",
+ },
+ "err_no_path": {
+ "en": "Model not found.",
+ "ru": "Модель не найдена.",
+ "zh": "模型未找到。",
+ },
+ "err_no_dataset": {
+ "en": "Please choose a dataset.",
+ "ru": "Пожалуйста, выберите набор данных.",
+ "zh": "请选择数据集。",
+ },
+ "err_no_adapter": {
+ "en": "Please select an adapter.",
+ "ru": "Пожалуйста, выберите адаптер.",
+ "zh": "请选择适配器。",
+ },
+ "err_no_reward_model": {
+ "en": "Please select a reward model.",
+ "ru": "Пожалуйста, выберите модель вознаграждения.",
+ "zh": "请选择奖励模型。",
+ },
+ "err_no_export_dir": {
+ "en": "Please provide export dir.",
+ "ru": "Пожалуйста, укажите каталог для экспорта.",
+ "zh": "请填写导出目录",
+ },
+ "err_failed": {
+ "en": "Failed.",
+ "ru": "Ошибка.",
+ "zh": "训练出错。",
+ },
+ "err_demo": {
+ "en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
+ "ru": "Обучение недоступно в демонстрационном режиме, сначала скопируйте пространство в частное.",
+ "zh": "展示模式不支持训练,请先复制到私人空间。",
+ },
+ "err_device_count": {
+ "en": "Multiple GPUs are not supported yet.",
+ "ru": "Пока не поддерживается множественные GPU.",
+ "zh": "尚不支持多 GPU 训练。",
+ },
+ "err_tool_name": {
+ "en": "Tool name not found.",
+ "ru": "Имя инструмента не найдено.",
+ "zh": "工具名称未找到。",
+ },
+ "err_json_schema": {
+ "en": "Invalid JSON schema.",
+ "ru": "Неверная схема JSON.",
+ "zh": "Json 格式错误。",
+ },
+ "warn_no_cuda": {
+ "en": "CUDA environment was not detected.",
+ "ru": "Среда CUDA не обнаружена.",
+ "zh": "未检测到 CUDA 环境。",
+ },
+ "info_aborting": {
+ "en": "Aborted, wait for terminating...",
+ "ru": "Прервано, ожидание завершения...",
+ "zh": "训练中断,正在等待线程结束……",
+ },
+ "info_aborted": {
+ "en": "Ready.",
+ "ru": "Готово.",
+ "zh": "准备就绪。",
+ },
+ "info_finished": {
+ "en": "Finished.",
+ "ru": "Завершено.",
+ "zh": "训练完毕。",
+ },
+ "info_loading": {
+ "en": "Loading model...",
+ "ru": "Загрузка модели...",
+ "zh": "加载中……",
+ },
+ "info_unloading": {
+ "en": "Unloading model...",
+ "ru": "Выгрузка модели...",
+ "zh": "卸载中……",
+ },
+ "info_loaded": {
+ "en": "Model loaded, now you can chat with your model!",
+ "ru": "Модель загружена, теперь вы можете общаться с вашей моделью!",
+ "zh": "模型已加载,可以开始聊天了!",
+ },
+ "info_unloaded": {
+ "en": "Model unloaded.",
+ "ru": "Модель выгружена.",
+ "zh": "模型已卸载。",
+ },
+ "info_exporting": {
+ "en": "Exporting model...",
+ "ru": "Экспорт модели...",
+ "zh": "正在导出模型……",
+ },
+ "info_exported": {
+ "en": "Model exported.",
+ "ru": "Модель экспортирована.",
+ "zh": "模型导出完成。",
+ },
+}
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/manager.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/manager.py
new file mode 100644
index 0000000..51ddf49
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/manager.py
@@ -0,0 +1,33 @@
+from typing import TYPE_CHECKING, Dict, List, Set
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+class Manager:
+    r"""
+    Registry of the web UI components, stored as ``all_elems[tab_name][elem_name]``.
+    """
+
+    def __init__(self) -> None:
+        # Starts empty; entries are added by whoever builds the tabs.
+        self.all_elems: Dict[str, Dict[str, "Component"]] = {}
+
+    def get_elem_by_name(self, name: str) -> "Component":
+        r"""
+        Looks up a component by its dotted name ``<tab>.<element>``.
+
+        Example: top.lang, train.dataset
+        """
+        tab_name, elem_name = name.split(".")
+        tab_elems = self.all_elems[tab_name]
+        return tab_elems[elem_name]
+
+    def get_base_elems(self) -> Set["Component"]:
+        r"""
+        Returns the set of components of the "top" tab listed below.
+        """
+        top_elems = self.all_elems["top"]
+        base_names = (
+            "lang",
+            "model_name",
+            "model_path",
+            "adapter_path",
+            "finetuning_type",
+            "quantization_bit",
+            "template",
+            "rope_scaling",
+            "booster",
+        )
+        return {top_elems[elem_name] for elem_name in base_names}
+
+    def list_elems(self) -> List["Component"]:
+        r"""
+        Flattens every registered component into one list, tab by tab.
+        """
+        flattened = []
+        for tab_elems in self.all_elems.values():
+            flattened.extend(tab_elems.values())
+        return flattened
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/runner.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/runner.py
new file mode 100644
index 0000000..1d5396a
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/runner.py
@@ -0,0 +1,306 @@
+import logging
+import os
+import time
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
+
+import gradio as gr
+import transformers
+from gradio.components import Component # cannot use TYPE_CHECKING here
+from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.utils import is_torch_cuda_available
+
+from ..extras.callbacks import LogCallback
+from ..extras.constants import TRAINING_STAGES
+from ..extras.logging import LoggerHandler
+from ..extras.misc import get_device_count, torch_gc
+from ..train import run_exp
+from .common import get_module, get_save_dir, load_config
+from .locales import ALERTS
+from .utils import gen_cmd, get_eval_results, update_process_bar
+
+
+if TYPE_CHECKING:
+ from .manager import Manager
+
+
+class Runner:
+    r"""
+    Validates the web UI form values, turns them into training / evaluation arguments,
+    runs the job in a background thread and streams its log and progress back to the UI.
+    """
+
+    def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
+        self.manager = manager
+        self.demo_mode = demo_mode
+        # Resume: information about the job currently attached to this runner
+        self.thread: "Thread" = None
+        self.do_train = True
+        self.running_data: Dict["Component", Any] = None
+        # State
+        self.aborted = False
+        self.running = False
+        # Handler: capture log records of the root logger and of transformers
+        self.logger_handler = LoggerHandler()
+        self.logger_handler.setLevel(logging.INFO)
+        logging.root.addHandler(self.logger_handler)
+        transformers.logging.add_handler(self.logger_handler)
+
+    @property
+    def alive(self) -> bool:
+        r"""
+        Whether a job thread is attached (it is detached again in `_finalize`).
+        """
+        return self.thread is not None
+
+    def set_abort(self) -> None:
+        r"""
+        Flags the current job as aborted; `monitor` and `_finalize` report it accordingly.
+        """
+        self.aborted = True
+
+    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
+        r"""
+        Validates the form values before a job is previewed or launched.
+
+        Returns a localized error message, or an empty string when everything is fine.
+        On success the abort flag and the log buffer are reset and a fresh callback is created.
+        """
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        dataset = get("train.dataset") if do_train else get("eval.dataset")
+
+        if self.running:
+            return ALERTS["err_conflict"][lang]
+
+        if not model_name:
+            return ALERTS["err_no_model"][lang]
+
+        if not model_path:
+            return ALERTS["err_no_path"][lang]
+
+        if len(dataset) == 0:
+            return ALERTS["err_no_dataset"][lang]
+
+        # The reward model requirement only concerns training. Reading the train tab for
+        # an evaluation run would let unrelated train settings block the evaluation
+        # (or fail the lookup if the train components are not part of `data`).
+        if do_train:
+            stage = TRAINING_STAGES[get("train.training_stage")]
+            if stage == "ppo" and not get("train.reward_model"):
+                return ALERTS["err_no_reward_model"][lang]
+
+        if not from_preview and self.demo_mode:
+            return ALERTS["err_demo"][lang]
+
+        if not from_preview and get_device_count() > 1:
+            return ALERTS["err_device_count"][lang]
+
+        if not from_preview and not is_torch_cuda_available():
+            gr.Warning(ALERTS["warn_no_cuda"][lang])
+
+        self.aborted = False
+        self.logger_handler.reset()
+        self.trainer_callback = LogCallback(self)
+        return ""
+
+    def _finalize(self, lang: str, finish_info: str) -> str:
+        r"""
+        Detaches the finished job, frees cached memory and returns the final status message:
+        the "aborted" alert if the job was aborted, otherwise `finish_info`.
+        """
+        self.thread = None
+        self.running_data = None
+        self.running = False
+        torch_gc()
+        if self.aborted:
+            return ALERTS["info_aborted"][lang]
+        else:
+            return finish_info
+
+    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        r"""
+        Builds the training argument dict from the "top" and "train" form values.
+        """
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
+        user_config = load_config()
+
+        if get("top.adapter_path"):
+            # several adapters are passed as one comma-separated string of save directories
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
+        else:
+            adapter_name_or_path = None
+
+        args = dict(
+            stage=TRAINING_STAGES[get("train.training_stage")],
+            do_train=True,
+            model_name_or_path=get("top.model_path"),
+            adapter_name_or_path=adapter_name_or_path,
+            cache_dir=user_config.get("cache_dir", None),
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flashattn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
+            dataset_dir=get("train.dataset_dir"),
+            dataset=",".join(get("train.dataset")),
+            cutoff_len=get("train.cutoff_len"),
+            learning_rate=float(get("train.learning_rate")),
+            num_train_epochs=float(get("train.num_train_epochs")),
+            max_samples=int(get("train.max_samples")),
+            per_device_train_batch_size=get("train.batch_size"),
+            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
+            lr_scheduler_type=get("train.lr_scheduler_type"),
+            max_grad_norm=float(get("train.max_grad_norm")),
+            logging_steps=get("train.logging_steps"),
+            save_steps=get("train.save_steps"),
+            warmup_steps=get("train.warmup_steps"),
+            neftune_noise_alpha=get("train.neftune_alpha") or None,
+            optim=get("train.optim"),
+            resize_vocab=get("train.resize_vocab"),
+            packing=get("train.packing"),
+            upcast_layernorm=get("train.upcast_layernorm"),
+            use_llama_pro=get("train.use_llama_pro"),
+            shift_attn=get("train.shift_attn"),
+            use_galore=get("train.use_galore"),
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
+            fp16=(get("train.compute_type") == "fp16"),
+            bf16=(get("train.compute_type") == "bf16"),
+            pure_bf16=(get("train.compute_type") == "pure_bf16"),
+        )
+        args["disable_tqdm"] = True
+
+        if args["finetuning_type"] == "freeze":
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+            args["name_module_trainable"] = get("train.name_module_trainable")
+        elif args["finetuning_type"] == "lora":
+            args["lora_rank"] = int(get("train.lora_rank"))
+            args["lora_alpha"] = int(get("train.lora_alpha"))
+            args["lora_dropout"] = float(get("train.lora_dropout"))
+            # fall back to the model's default target modules when the field is left empty
+            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["use_rslora"] = get("train.use_rslora")
+            args["use_dora"] = get("train.use_dora")
+            args["additional_target"] = get("train.additional_target") or None
+            if args["stage"] in ["rm", "ppo", "dpo"]:
+                args["create_new_adapter"] = args["quantization_bit"] is None
+            else:
+                args["create_new_adapter"] = get("train.create_new_adapter")
+
+        if args["use_llama_pro"]:
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+
+        if args["stage"] == "ppo":
+            args["reward_model"] = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("train.reward_model")
+                ]
+            )
+            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+
+        if args["stage"] == "dpo":
+            args["dpo_beta"] = get("train.dpo_beta")
+            args["dpo_ftx"] = get("train.dpo_ftx")
+
+        # a validation split is only configured when requested and never for PPO
+        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+            args["val_size"] = get("train.val_size")
+            args["evaluation_strategy"] = "steps"
+            args["eval_steps"] = args["save_steps"]
+            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
+            args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
+
+        if args["use_galore"]:
+            args["galore_rank"] = get("train.galore_rank")
+            args["galore_update_interval"] = get("train.galore_update_interval")
+            args["galore_scale"] = get("train.galore_scale")
+            args["galore_target"] = get("train.galore_target")
+
+        return args
+
+    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        r"""
+        Builds the evaluation / prediction argument dict from the "top" and "eval" form values.
+        """
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
+        user_config = load_config()
+
+        if get("top.adapter_path"):
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
+        else:
+            adapter_name_or_path = None
+
+        args = dict(
+            stage="sft",
+            model_name_or_path=get("top.model_path"),
+            adapter_name_or_path=adapter_name_or_path,
+            cache_dir=user_config.get("cache_dir", None),
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flashattn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
+            dataset_dir=get("eval.dataset_dir"),
+            dataset=",".join(get("eval.dataset")),
+            cutoff_len=get("eval.cutoff_len"),
+            max_samples=int(get("eval.max_samples")),
+            per_device_eval_batch_size=get("eval.batch_size"),
+            predict_with_generate=True,
+            max_new_tokens=get("eval.max_new_tokens"),
+            top_p=get("eval.top_p"),
+            temperature=get("eval.temperature"),
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
+        )
+
+        # "Save predictions" switches from plain evaluation to prediction
+        if get("eval.predict"):
+            args["do_predict"] = True
+        else:
+            args["do_eval"] = True
+
+        return args
+
+    def _preview(
+        self, data: Dict[Component, Any], do_train: bool
+    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Yields the command line equivalent to the current form (or the validation error)
+        together with an update hiding the progress bar.
+        """
+        error = self._initialize(data, do_train, from_preview=True)
+        if error:
+            gr.Warning(error)
+            yield error, gr.update(visible=False)
+        else:
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Starts the job in a background thread and relays everything `monitor` yields.
+        On a validation error, yields the error message instead.
+        """
+        error = self._initialize(data, do_train, from_preview=False)
+        if error:
+            gr.Warning(error)
+            yield error, gr.update(visible=False)
+        else:
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+            run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+            # remembered so that `monitor` can resolve the language and the output directory
+            self.do_train, self.running_data = do_train, data
+            self.thread = Thread(target=run_exp, kwargs=run_kwargs)
+            self.thread.start()
+            yield from self.monitor()
+
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Previews the training command.
+        """
+        yield from self._preview(data, do_train=True)
+
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Previews the evaluation command.
+        """
+        yield from self._preview(data, do_train=False)
+
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Launches a training job.
+        """
+        yield from self._launch(data, do_train=True)
+
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Launches an evaluation job.
+        """
+        yield from self._launch(data, do_train=False)
+
+    def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        r"""
+        Polls the job thread every two seconds, yielding the captured log and a progress bar
+        update, then yields the final status once the thread has exited.
+        """
+        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
+        self.running = True
+        lang = get("top.lang")
+        output_dir = get_save_dir(
+            get("top.model_name"),
+            get("top.finetuning_type"),
+            get("{}.output_dir".format("train" if self.do_train else "eval")),
+        )
+
+        while self.thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+        # success is inferred from the result file the job leaves in its output directory
+        if self.do_train:
+            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+                finish_info = ALERTS["info_finished"][lang]
+            else:
+                finish_info = ALERTS["err_failed"][lang]
+        else:
+            if os.path.exists(os.path.join(output_dir, "all_results.json")):
+                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+            else:
+                finish_info = ALERTS["err_failed"][lang]
+
+        yield self._finalize(lang, finish_info), gr.update(visible=False)
diff --git a/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/utils.py b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/utils.py
new file mode 100644
index 0000000..05cdd7f
--- /dev/null
+++ b/src/AntSK.LLamaFactory/llamafactory/llmtuner/webui/utils.py
@@ -0,0 +1,104 @@
+import json
+import os
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict
+
+import gradio as gr
+
+from ..extras.packages import is_matplotlib_available
+from ..extras.ploting import smooth
+from .common import get_save_dir
+from .locales import ALERTS
+
+
+if TYPE_CHECKING:
+ from ..extras.callbacks import LogCallback
+
+if is_matplotlib_available():
+ import matplotlib.figure
+ import matplotlib.pyplot as plt
+
+
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+    r"""
+    Builds the progress bar update for the running job.
+
+    The bar is hidden while `callback.max_steps` is unset or zero; otherwise it shows the
+    completed percentage and a "current/total: elapsed < remaining" label.
+    """
+    if not callback.max_steps:
+        return gr.update(visible=False)
+
+    # max_steps is non-zero past the guard above, so the division is always safe
+    percentage = round(100 * callback.cur_steps / callback.max_steps, 0)
+    label = "Running {:d}/{:d}: {} < {}".format(
+        callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
+    )
+    return gr.update(label=label, value=percentage, visible=True)
+
+
+def get_time() -> str:
+    r"""
+    Returns the current local time formatted as "year-month-day-hour-minute-second".
+    """
+    now = datetime.now()
+    return now.strftime("%Y-%m-%d-%H-%M-%S")
+
+
+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+    r"""
+    Returns the update for the quantization selector: editable for "lora",
+    otherwise reset to "None" and locked.
+    """
+    if finetuning_type == "lora":
+        return gr.update(interactive=True)
+
+    return gr.update(value="None", interactive=False)
+
+
+def check_json_schema(text: str, lang: str) -> None:
+    r"""
+    Validates the tool list entered by the user and shows a warning when it is invalid.
+
+    Args:
+        text: JSON text; an empty value (e.g. ``[]``) is accepted, otherwise it must be a
+            list of objects that each carry a "name" key.
+        lang: language code used to pick the warning message.
+
+    Nothing is raised: malformed JSON or a wrong structure produces the "err_json_schema"
+    warning, a tool without a name produces the "err_tool_name" warning.
+    """
+    # json.JSONDecodeError derives from ValueError, so parsing is handled on its own to
+    # keep malformed JSON from being reported as a missing tool name.
+    try:
+        tools = json.loads(text)
+    except Exception:
+        gr.Warning(ALERTS["err_json_schema"][lang])
+        return
+
+    if not tools:
+        return
+
+    # explicit check instead of `assert`, which is stripped when Python runs with -O
+    if not isinstance(tools, list):
+        gr.Warning(ALERTS["err_json_schema"][lang])
+        return
+
+    for tool in tools:
+        if not isinstance(tool, dict):
+            gr.Warning(ALERTS["err_json_schema"][lang])
+            return
+
+        if "name" not in tool:
+            gr.Warning(ALERTS["err_tool_name"][lang])
+            return
+
+
+def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Renders the arguments as a shell command inside a Markdown "bash" code block.
+
+    `disable_tqdm` is left out, `plot_loss` mirrors `do_train`, and arguments whose value
+    is None, False or an empty string are omitted. The caller's dict is not modified.
+    """
+    cmd_args = dict(args)  # work on a copy so the preview has no side effect on the caller
+    cmd_args.pop("disable_tqdm", None)
+    cmd_args["plot_loss"] = cmd_args.get("do_train", None)
+    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+    cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
+    for k, v in cmd_args.items():
+        if v is not None and v is not False and v != "":
+            cmd_lines.append(" --{} {} ".format(k, str(v)))
+    # one argument per line, joined with a trailing backslash as shell line continuation
+    cmd_text = "\\\n".join(cmd_lines)
+    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    return cmd_text
+
+
+def get_eval_results(path: os.PathLike) -> str:
+    r"""
+    Reads a JSON result file and returns it pretty-printed inside a Markdown "json" code block.
+    """
+    with open(path, "r", encoding="utf-8") as f:
+        metrics = json.load(f)
+
+    pretty = json.dumps(metrics, indent=4)
+    return "```json\n{}\n```\n".format(pretty)
+
+
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
+    r"""
+    Plots the training loss recorded in "trainer_log.jsonl" of the given run.
+
+    Returns the figure with the raw and the smoothed loss curves, or None when no model is
+    selected, the log file does not exist, or the log contains no loss value yet.
+    """
+    if not base_model:
+        return None
+    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
+    if not os.path.isfile(log_file):
+        return None
+
+    plt.close("all")
+    plt.switch_backend("agg")
+    steps, losses = [], []
+    with open(log_file, "r", encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue  # tolerate blank lines instead of failing on them
+            log_info = json.loads(line)
+            # compare against None so that a legitimate loss of 0.0 is still plotted
+            if log_info.get("loss", None) is not None:
+                steps.append(log_info["current_steps"])
+                losses.append(log_info["loss"])
+
+    if len(losses) == 0:
+        return None
+
+    # the figure is created only once there is something to draw, so no empty figure is left open
+    fig = plt.figure()
+    ax = fig.add_subplot(111)
+    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
+    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
+    ax.legend()
+    ax.set_xlabel("step")
+    ax.set_ylabel("loss")
+    return fig
diff --git a/src/AntSK.sln b/src/AntSK.sln
index 764be2c..98db8a8 100644
--- a/src/AntSK.sln
+++ b/src/AntSK.sln
@@ -24,8 +24,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AntSK.Test", "AntSK.Test\An
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AntSK.LLamaFactory", "AntSK.LLamaFactory\AntSK.LLamaFactory.csproj", "{664DFA1F-68B7-49C7-B889-FA14D1756D3D}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AntSK.PyNet", "AntSK.PyNet\AntSK.PyNet.csproj", "{1C04AC5E-A37D-41D9-8519-7389DCC6F9AC}"
-EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -56,10 +54,6 @@ Global
{664DFA1F-68B7-49C7-B889-FA14D1756D3D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{664DFA1F-68B7-49C7-B889-FA14D1756D3D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{664DFA1F-68B7-49C7-B889-FA14D1756D3D}.Release|Any CPU.Build.0 = Release|Any CPU
- {1C04AC5E-A37D-41D9-8519-7389DCC6F9AC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {1C04AC5E-A37D-41D9-8519-7389DCC6F9AC}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {1C04AC5E-A37D-41D9-8519-7389DCC6F9AC}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {1C04AC5E-A37D-41D9-8519-7389DCC6F9AC}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/AntSK/AntSK.csproj b/src/AntSK/AntSK.csproj
index b91e5ef..dde5c11 100644
--- a/src/AntSK/AntSK.csproj
+++ b/src/AntSK/AntSK.csproj
@@ -9,13 +9,10 @@
-
-
-
-
-
- Always
-
+
+
+
+
diff --git a/src/AntSK/Pages/Setting/AIModel/AddModel.razor b/src/AntSK/Pages/Setting/AIModel/AddModel.razor
index 7e0bd7e..d5651d0 100644
--- a/src/AntSK/Pages/Setting/AIModel/AddModel.razor
+++ b/src/AntSK/Pages/Setting/AIModel/AddModel.razor
@@ -129,6 +129,22 @@
}
+
+ @if (context.AIType == AIType.BgeEmbedding)
+ {
+
+
+
+
+
+
+
+
+
+
+
+ }
+
@if (context.AIType == AIType.Mock)
{
}
diff --git a/src/AntSK/Pages/Setting/AIModel/AddModel.razor.cs b/src/AntSK/Pages/Setting/AIModel/AddModel.razor.cs
index dd9b8a2..ff54e44 100644
--- a/src/AntSK/Pages/Setting/AIModel/AddModel.razor.cs
+++ b/src/AntSK/Pages/Setting/AIModel/AddModel.razor.cs
@@ -3,6 +3,7 @@ using AntDesign.ProLayout;
using AntSK.Domain.Domain.Interface;
using AntSK.Domain.Domain.Model.Constant;
using AntSK.Domain.Domain.Model.Enum;
+using AntSK.Domain.Domain.Other;
using AntSK.Domain.Domain.Service;
using AntSK.Domain.Options;
using AntSK.Domain.Repositories;
@@ -248,6 +249,21 @@ namespace AntSK.Pages.Setting.AIModel
_ILLamaFactoryService.PipInstall();
}
}
+ /// <summary>Checks that a model name and a Python DLL path were entered, then passes both to EmbeddingConfig.LoadModel.</summary>
+ private async Task BgeDownload()
+ {
+ if (string.IsNullOrEmpty(_aiModel.ModelName))
+ {
+ _ = Message.Error("请输入模型名称!", 2); // message text: "Please enter the model name!"; the returned value is discarded
+ return;
+ }
+ if (string.IsNullOrEmpty(_aiModel.EndPoint)) // per the message below, EndPoint carries the Python DLL path here
+ {
+ _ = Message.Error("请输入正确的Python dll路径!", 2); // message text: "Please enter a valid Python DLL path!"
+ return;
+ }
+ EmbeddingConfig.LoadModel(_aiModel.EndPoint, _aiModel.ModelName); // NOTE(review): nothing in this async method is awaited, so this call runs synchronously - confirm that blocking the caller is intended
+ }
private async Task CmdLogHandler(string message)
{
await InvokeAsync(() =>