205 lines
8.7 KiB
C++
205 lines
8.7 KiB
C++
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstring>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "imgui.h"
#include "imgui_impl_glfw.h"
#include "imgui_impl_opengl3.h"
#include <GLFW/glfw3.h>

#include "Xenith/core.hpp"
#include "Xenith/token/token.hpp"
|
|
|
|
struct UIState {
|
|
std::string chatLog;
|
|
TrainStatus lastStatus = {0};
|
|
float lr = 0.01f;
|
|
int epochs = 10;
|
|
char fileBuf[256] = "dataset.txt";
|
|
char editorBuf[1024 * 512] = ""; // 512KB для текста
|
|
char inputBuf[512] = "";
|
|
bool scrollChat = true;
|
|
std::atomic<bool> isTraining{false};
|
|
double genSpeed = 0;
|
|
} ui;
|
|
|
|
/// Formats a duration given in seconds as "HH:MM:SS".
/// Hours widen beyond two digits when needed; fractional seconds are truncated.
/// Negative input (including -infinity) is clamped to zero. NaN, +infinity and
/// values too large for a 64-bit integer produce the placeholder "--:--:--".
std::string formatTime(double seconds) {
    if (seconds < 0) seconds = 0;

    // Converting a non-finite or out-of-range double to an integer type is
    // undefined behaviour, so such values get a placeholder instead.
    if (!std::isfinite(seconds) ||
        seconds >= static_cast<double>(std::numeric_limits<long long>::max())) {
        return "--:--:--";
    }

    const long long total = static_cast<long long>(seconds);
    const long long hours = total / 3600;
    const long long minutes = (total % 3600) / 60;
    const long long secs = total % 60;

    std::ostringstream out;
    // setfill is sticky; setw applies only to the next insertion.
    out << std::setfill('0') << std::setw(2) << hours << ':'
        << std::setw(2) << minutes << ':'
        << std::setw(2) << secs;
    return out.str();
}
|
|
|
|
std::vector<double> buildNetInput(const std::vector<int>& tokens, Embedder& emb) {
|
|
std::vector<double> netInput;
|
|
netInput.reserve(MAX_CONTEXT * EMBED_DIM);
|
|
int start = (int)tokens.size() - MAX_CONTEXT;
|
|
if (start < 0) start = 0;
|
|
int count = 0;
|
|
for (int i = start; i < (int)tokens.size(); i++) {
|
|
std::vector<double> v = emb.get(tokens[i]);
|
|
netInput.insert(netInput.end(), v.begin(), v.end());
|
|
count++;
|
|
}
|
|
while (count < MAX_CONTEXT) {
|
|
for (int d = 0; d < EMBED_DIM; d++) netInput.push_back(0.0);
|
|
count++;
|
|
}
|
|
return netInput;
|
|
}
|
|
|
|
int main() {
|
|
if (!glfwInit()) return 1;
|
|
GLFWwindow* window = glfwCreateWindow(1400, 800, "BiPy Studio", nullptr, nullptr);
|
|
glfwMakeContextCurrent(window);
|
|
glfwSwapInterval(1);
|
|
|
|
IMGUI_CHECKVERSION();
|
|
ImGui::CreateContext();
|
|
ImGuiIO& io = ImGui::GetIO();
|
|
|
|
const char* font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf";
|
|
if (std::filesystem::exists(font_path))
|
|
io.Fonts->AddFontFromFileTTF(font_path, 18.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic());
|
|
|
|
ImGui_ImplGlfw_InitForOpenGL(window, true);
|
|
ImGui_ImplOpenGL3_Init("#version 130");
|
|
|
|
//LayerStructure_t layers[] = {{MAX_CONTEXT * EMBED_DIM, SIGMOID}, {MIDDLE_LAYER, SIGMOID}, {MAX_VOCAB, SIGMOID}};
|
|
|
|
LayerStructure_t layers[] = {
|
|
{MAX_CONTEXT * EMBED_DIM, SIGMOID},
|
|
{2048, SIGMOID},
|
|
{2048, SIGMOID},
|
|
{1024, SIGMOID},
|
|
{MAX_VOCAB, SIGMOID}
|
|
};
|
|
|
|
Tokenizer tok;
|
|
Embedder emb(MAX_VOCAB, EMBED_DIM);
|
|
NeuralNetwork nn(layers, 3, true); // GPU ON
|
|
|
|
while (!glfwWindowShouldClose(window)) {
|
|
glfwPollEvents();
|
|
ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame();
|
|
|
|
ImGui::SetNextWindowPos(ImVec2(0, 0));
|
|
ImGui::SetNextWindowSize(io.DisplaySize);
|
|
ImGui::Begin("Studio", nullptr, ImGuiWindowFlags_NoDecoration);
|
|
|
|
// Левая панель
|
|
ImGui::BeginChild("Left", ImVec2(io.DisplaySize.x * 0.4f, 0), true);
|
|
if (ImGui::BeginTabBar("Tabs")) {
|
|
|
|
// ВКЛАДКА ЧАТ
|
|
if (ImGui::BeginTabItem("Чат")) {
|
|
ImGui::BeginChild("ChatLog", ImVec2(0, -60), true);
|
|
ImGui::TextWrapped("%s", ui.chatLog.c_str());
|
|
if (ui.scrollChat) { ImGui::SetScrollHereY(1.0f); ui.scrollChat = false; }
|
|
ImGui::EndChild();
|
|
|
|
if (ImGui::InputText("##In", ui.inputBuf, 512, ImGuiInputTextFlags_EnterReturnsTrue)) {
|
|
ui.chatLog += "\n[USER]: " + std::string(ui.inputBuf);
|
|
auto startGen = std::chrono::high_resolution_clock::now();
|
|
|
|
std::string prompt = "[USER]" + std::string(ui.inputBuf) + "[AI]";
|
|
std::vector<int> ctx = tok.textToTokens(prompt);
|
|
std::string aiRes = "";
|
|
for (int g = 0; g < 128; g++) {
|
|
std::vector<double> out = nn.feedForward(buildNetInput(ctx, emb));
|
|
int bId = 0; double mV = -1.0;
|
|
for (int i = 0; i < MAX_VOCAB; i++) if (out[i] > mV) { mV = out[i]; bId = i; }
|
|
if (bId <= 0) break;
|
|
std::string w = tok.getWord(bId);
|
|
if (w == "<EOS>") break;
|
|
aiRes += w; ctx.push_back(bId);
|
|
if (ctx.size() > MAX_CONTEXT) ctx.erase(ctx.begin());
|
|
}
|
|
auto endGen = std::chrono::high_resolution_clock::now();
|
|
ui.genSpeed = 1.0 / std::chrono::duration<double>(endGen - startGen).count() * aiRes.size(); // симв/сек
|
|
ui.chatLog += "\n[AI]: " + aiRes + "\n";
|
|
ui.inputBuf[0] = '\0'; ui.scrollChat = true;
|
|
}
|
|
ImGui::Text("Скорость генерации: %.1f симв/сек", ui.genSpeed);
|
|
ImGui::EndTabItem();
|
|
}
|
|
|
|
// ВКЛАДКА ОБУЧЕНИЕ
|
|
if (ImGui::BeginTabItem("Обучение")) {
|
|
ImGui::InputText("Файл", ui.fileBuf, 256);
|
|
if (ImGui::Button("Загрузить")) {
|
|
std::ifstream f(ui.fileBuf);
|
|
if (f) {
|
|
std::stringstream ss; ss << f.rdbuf();
|
|
strncpy(ui.editorBuf, ss.str().c_str(), sizeof(ui.editorBuf)-1);
|
|
}
|
|
}
|
|
ImGui::SliderInt("Эпохи", &ui.epochs, 1, 500);
|
|
ImGui::SliderFloat("LR", &ui.lr, 0.0001f, 0.1f);
|
|
|
|
if (ImGui::Button("ПУСК", ImVec2(-1, 40)) && !ui.isTraining) {
|
|
std::string data = ui.editorBuf;
|
|
ui.isTraining = true;
|
|
std::thread([&nn, &tok, &emb, data]() {
|
|
nn.trainOnSequence(tok, emb, data, ui.epochs, (double)ui.lr, buildNetInput,
|
|
[](const TrainStatus& s) { ui.lastStatus = s; });
|
|
ui.isTraining = false;
|
|
}).detach();
|
|
}
|
|
|
|
std::stringstream ss;
|
|
if (ui.lastStatus.totalParams >= 1e12) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e12 << "t";
|
|
else if (ui.lastStatus.totalParams >= 1e9) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e9 << "b";
|
|
else if (ui.lastStatus.totalParams >= 1e6) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e6 << "m";
|
|
else if (ui.lastStatus.totalParams >= 1e3) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e3 << "k";
|
|
else ss << ui.lastStatus.totalParams;
|
|
|
|
std::stringstream ss2;
|
|
double bytes = (double)ui.lastStatus.totalParams * 4.0;
|
|
if (bytes >= 1024.0 * 1024.0 * 1024.0)
|
|
ss2 << std::fixed << std::setprecision(2) << bytes / (1024.0 * 1024.0 * 1024.0) << " GB";
|
|
else if (bytes >= 1024.0 * 1024.0)
|
|
ss2 << std::fixed << std::setprecision(2) << bytes / (1024.0 * 1024.0) << " MB";
|
|
else
|
|
ss2 << std::fixed << std::setprecision(2) << bytes / 1024.0 << " KB";
|
|
|
|
|
|
ImGui::ProgressBar(ui.lastStatus.percentage / 100.0f, ImVec2(-1, 20));
|
|
ImGui::Text("%d / %d", ui.lastStatus.currentEpoch, ui.lastStatus.totalEpochs);
|
|
ImGui::Text("ETA: %s", formatTime(ui.lastStatus.eta).c_str());
|
|
ImGui::Text("Токенов: %d / %d", ui.lastStatus.currentToken, ui.lastStatus.totalTokens);
|
|
ImGui::Text("Текущий Loss: %.6f", ui.lastStatus.currentLoss);
|
|
ImGui::Text("Loss эпохи: %.6f", ui.lastStatus.lastEpochLoss);
|
|
ImGui::Text("Скорость обучения: %.1f t/s", ui.lastStatus.speed);
|
|
ImGui::Text("Параметров: %s (%s)", ss.str().c_str(), ss2.str().c_str());
|
|
ImGui::EndTabItem();
|
|
}
|
|
ImGui::EndTabBar();
|
|
}
|
|
ImGui::EndChild();
|
|
ImGui::SameLine();
|
|
ImGui::BeginChild("Right");
|
|
ImGui::Text("Редактор датасета:");
|
|
ImGui::InputTextMultiline("##ed", ui.editorBuf, sizeof(ui.editorBuf), ImVec2(-1, -1));
|
|
ImGui::EndChild();
|
|
ImGui::End();
|
|
|
|
ImGui::Render();
|
|
glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y);
|
|
glClearColor(0.1f, 0.1f, 0.1f, 1.0f); glClear(GL_COLOR_BUFFER_BIT);
|
|
ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
|
|
glfwSwapBuffers(window);
|
|
}
|
|
return 0;
|
|
}
|