#include #include #include #include #include #include #include #include #include #include #include "imgui.h" #include "imgui_impl_glfw.h" #include "imgui_impl_opengl3.h" #include #include "Xenith/core.hpp" #include "Xenith/token/token.hpp" struct UIState { std::string chatLog; TrainStatus lastStatus = {0}; float lr = 0.01f; int epochs = 10; char fileBuf[256] = "dataset.txt"; char editorBuf[1024 * 512] = ""; // 512KB для текста char inputBuf[512] = ""; bool scrollChat = true; std::atomic isTraining{false}; double genSpeed = 0; } ui; std::string formatTime(double seconds) { if (seconds < 0) seconds = 0; int h = (int)seconds / 3600; int m = ((int)seconds % 3600) / 60; int s = (int)seconds % 60; std::stringstream ss; ss << std::setfill('0') << std::setw(2) << h << ":" << std::setfill('0') << std::setw(2) << m << ":" << std::setfill('0') << std::setw(2) << s; return ss.str(); } std::vector buildNetInput(const std::vector& tokens, Embedder& emb) { std::vector netInput; netInput.reserve(MAX_CONTEXT * EMBED_DIM); int start = (int)tokens.size() - MAX_CONTEXT; if (start < 0) start = 0; int count = 0; for (int i = start; i < (int)tokens.size(); i++) { std::vector v = emb.get(tokens[i]); netInput.insert(netInput.end(), v.begin(), v.end()); count++; } while (count < MAX_CONTEXT) { for (int d = 0; d < EMBED_DIM; d++) netInput.push_back(0.0); count++; } return netInput; } int main() { if (!glfwInit()) return 1; GLFWwindow* window = glfwCreateWindow(1400, 800, "BiPy Studio", nullptr, nullptr); glfwMakeContextCurrent(window); glfwSwapInterval(1); IMGUI_CHECKVERSION(); ImGui::CreateContext(); ImGuiIO& io = ImGui::GetIO(); const char* font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"; if (std::filesystem::exists(font_path)) io.Fonts->AddFontFromFileTTF(font_path, 18.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic()); ImGui_ImplGlfw_InitForOpenGL(window, true); ImGui_ImplOpenGL3_Init("#version 130"); //LayerStructure_t layers[] = {{MAX_CONTEXT * EMBED_DIM, SIGMOID}, {MIDDLE_LAYER, SIGMOID}, {MAX_VOCAB, SIGMOID}}; LayerStructure_t layers[] = { {MAX_CONTEXT * EMBED_DIM, SIGMOID}, {2048, SIGMOID}, {2048, SIGMOID}, {1024, SIGMOID}, {MAX_VOCAB, SIGMOID} }; Tokenizer tok; Embedder emb(MAX_VOCAB, EMBED_DIM); NeuralNetwork nn(layers, 3, true); // GPU ON while (!glfwWindowShouldClose(window)) { glfwPollEvents(); ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame(); ImGui::SetNextWindowPos(ImVec2(0, 0)); ImGui::SetNextWindowSize(io.DisplaySize); ImGui::Begin("Studio", nullptr, ImGuiWindowFlags_NoDecoration); // Левая панель ImGui::BeginChild("Left", ImVec2(io.DisplaySize.x * 0.4f, 0), true); if (ImGui::BeginTabBar("Tabs")) { // ВКЛАДКА ЧАТ if (ImGui::BeginTabItem("Чат")) { ImGui::BeginChild("ChatLog", ImVec2(0, -60), true); ImGui::TextWrapped("%s", ui.chatLog.c_str()); if (ui.scrollChat) { ImGui::SetScrollHereY(1.0f); ui.scrollChat = false; } ImGui::EndChild(); if (ImGui::InputText("##In", ui.inputBuf, 512, ImGuiInputTextFlags_EnterReturnsTrue)) { ui.chatLog += "\n[USER]: " + std::string(ui.inputBuf); auto startGen = std::chrono::high_resolution_clock::now(); std::string prompt = "[USER]" + std::string(ui.inputBuf) + "[AI]"; std::vector ctx = tok.textToTokens(prompt); std::string aiRes = ""; for (int g = 0; g < 128; g++) { std::vector out = nn.feedForward(buildNetInput(ctx, emb)); int bId = 0; double mV = -1.0; for (int i = 0; i < MAX_VOCAB; i++) if (out[i] > mV) { mV = out[i]; bId = i; } if (bId <= 0) break; std::string w = tok.getWord(bId); if (w == "") break; aiRes += w; ctx.push_back(bId); if (ctx.size() > MAX_CONTEXT) ctx.erase(ctx.begin()); } auto endGen = std::chrono::high_resolution_clock::now(); ui.genSpeed = 1.0 / std::chrono::duration(endGen - startGen).count() * aiRes.size(); // симв/сек ui.chatLog += "\n[AI]: " + aiRes + "\n"; ui.inputBuf[0] = '\0'; ui.scrollChat = true; } ImGui::Text("Скорость генерации: %.1f симв/сек", ui.genSpeed); ImGui::EndTabItem(); } // ВКЛАДКА ОБУЧЕНИЕ if (ImGui::BeginTabItem("Обучение")) { ImGui::InputText("Файл", ui.fileBuf, 256); if (ImGui::Button("Загрузить")) { std::ifstream f(ui.fileBuf); if (f) { std::stringstream ss; ss << f.rdbuf(); strncpy(ui.editorBuf, ss.str().c_str(), sizeof(ui.editorBuf)-1); } } ImGui::SliderInt("Эпохи", &ui.epochs, 1, 500); ImGui::SliderFloat("LR", &ui.lr, 0.0001f, 0.1f); if (ImGui::Button("ПУСК", ImVec2(-1, 40)) && !ui.isTraining) { std::string data = ui.editorBuf; ui.isTraining = true; std::thread([&nn, &tok, &emb, data]() { nn.trainOnSequence(tok, emb, data, ui.epochs, (double)ui.lr, buildNetInput, [](const TrainStatus& s) { ui.lastStatus = s; }); ui.isTraining = false; }).detach(); } std::stringstream ss; if (ui.lastStatus.totalParams >= 1e12) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e12 << "t"; else if (ui.lastStatus.totalParams >= 1e9) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e9 << "b"; else if (ui.lastStatus.totalParams >= 1e6) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e6 << "m"; else if (ui.lastStatus.totalParams >= 1e3) ss << std::fixed << std::setprecision(1) << ui.lastStatus.totalParams / 1e3 << "k"; else ss << ui.lastStatus.totalParams; std::stringstream ss2; double bytes = (double)ui.lastStatus.totalParams * 4.0; if (bytes >= 1024.0 * 1024.0 * 1024.0) ss2 << std::fixed << std::setprecision(2) << bytes / (1024.0 * 1024.0 * 1024.0) << " GB"; else if (bytes >= 1024.0 * 1024.0) ss2 << std::fixed << std::setprecision(2) << bytes / (1024.0 * 1024.0) << " MB"; else ss2 << std::fixed << std::setprecision(2) << bytes / 1024.0 << " KB"; ImGui::ProgressBar(ui.lastStatus.percentage / 100.0f, ImVec2(-1, 20)); ImGui::Text("%d / %d", ui.lastStatus.currentEpoch, ui.lastStatus.totalEpochs); ImGui::Text("ETA: %s", formatTime(ui.lastStatus.eta).c_str()); ImGui::Text("Токенов: %d / %d", ui.lastStatus.currentToken, ui.lastStatus.totalTokens); ImGui::Text("Текущий Loss: %.6f", ui.lastStatus.currentLoss); ImGui::Text("Loss эпохи: %.6f", ui.lastStatus.lastEpochLoss); ImGui::Text("Скорость обучения: %.1f t/s", ui.lastStatus.speed); ImGui::Text("Параметров: %s (%s)", ss.str().c_str(), ss2.str().c_str()); ImGui::EndTabItem(); } ImGui::EndTabBar(); } ImGui::EndChild(); ImGui::SameLine(); ImGui::BeginChild("Right"); ImGui::Text("Редактор датасета:"); ImGui::InputTextMultiline("##ed", ui.editorBuf, sizeof(ui.editorBuf), ImVec2(-1, -1)); ImGui::EndChild(); ImGui::End(); ImGui::Render(); glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y); glClearColor(0.1f, 0.1f, 0.1f, 1.0f); glClear(GL_COLOR_BUFFER_BIT); ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); glfwSwapBuffers(window); } return 0; }