Files
2026-05-03 21:02:34 +07:00

105 lines
5.0 KiB
C++

#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "Xenith/core.hpp"
#include "Xenith/token/token.hpp"
// Presumably meant to hold a system prompt; it is not referenced by the code in this file.
std::string currentSystemPrompt{};

// Network topology handed to NeuralNetwork in main(), every layer using SIGMOID.
// The input width matches what buildNetInput() produces (MAX_CONTEXT slots of
// EMBED_DIM values each); the output width matches the MAX_VOCAB scores that
// main() takes an argmax over when generating text.
LayerStructure_t layers[] = {
    { MAX_CONTEXT * EMBED_DIM, SIGMOID },  // input: flattened context embeddings
    { 1024,                    SIGMOID },  // single hidden layer
    { MAX_VOCAB,               SIGMOID },  // output: one score per token id
};
// Formats a duration given in seconds as "HH:MM:SS" (fractional seconds are
// truncated). Hours are zero-padded to at least two digits but not capped, so
// 100 hours prints as "100:00:00".
//
// Edge cases:
//  - Negative durations are shown as "00:00:00".
//  - Non-finite inputs (NaN / infinity, possible when an ETA estimate divides by a
//    zero rate) are shown as "--:--:--". Converting them to an integer would be
//    undefined behaviour.
//  - Durations above INT_MAX seconds are clamped, for the same reason.
std::string formatTime(double seconds) {
    if (!std::isfinite(seconds)) return "--:--:--";
    if (seconds < 0) seconds = 0;

    // A double -> integer conversion whose value does not fit is UB; clamp first.
    const double limit = static_cast<double>(std::numeric_limits<int>::max());
    if (seconds > limit) seconds = limit;

    const long long total = static_cast<long long>(seconds);
    const long long h = total / 3600;
    const long long m = (total % 3600) / 60;
    const long long s = total % 60;

    std::ostringstream ss;
    // setfill is sticky; setw applies only to the next field, so repeat it.
    ss << std::setfill('0') << std::setw(2) << h << ':'
       << std::setw(2) << m << ':'
       << std::setw(2) << s;
    return ss.str();
}
// Builds the flat input vector for the network from the most recent tokens.
//
// Only the last MAX_CONTEXT entries of `tokens` are used, oldest first. Each
// token's embedding (emb.get) is appended in order. Every context slot left
// unfilled then contributes EMBED_DIM zeros, so shorter contexts are padded on
// the right. The total length is MAX_CONTEXT * EMBED_DIM, assuming emb.get()
// returns EMBED_DIM values per token (not checked here; confirm in Embedder).
std::vector<double> buildNetInput(const std::vector<int>& tokens, Embedder& emb) {
    const int total = static_cast<int>(tokens.size());
    const int first = total > MAX_CONTEXT ? total - MAX_CONTEXT : 0;

    std::vector<double> features;
    features.reserve(MAX_CONTEXT * EMBED_DIM);

    for (auto it = tokens.begin() + first; it != tokens.end(); ++it) {
        const std::vector<double> vec = emb.get(*it);
        features.insert(features.end(), vec.begin(), vec.end());
    }

    // One zero vector per context slot that had no token.
    for (int slot = total - first; slot < MAX_CONTEXT; ++slot) {
        features.insert(features.end(), EMBED_DIM, 0.0);
    }
    return features;
}
int main() {
Tokenizer tok; Embedder emb(MAX_VOCAB, EMBED_DIM);
NeuralNetwork nn(layers, sizeof(layers)/sizeof(layers[0]), true);
while (true) {
std::cout << "\033[1;32mxenith\033[0m~$ ";
std::string cmdIn; std::getline(std::cin, cmdIn);
if (cmdIn == "/exit") break;
if (cmdIn == "/train" || cmdIn == "/trainFile") {
std::string content;
if (cmdIn == "/trainFile") {
std::cout << "Filename: "; std::string fn; std::getline(std::cin, fn);
std::ifstream f(fn); std::stringstream ss; ss << f.rdbuf(); content = ss.str();
} else {
std::cout << "User: "; std::string u; std::getline(std::cin, u);
std::cout << "AI: "; std::string a; std::getline(std::cin, a);
content = "[CLR][USER]" + u + "[AI]" + a + "<EOS>";
}
std::cout << "Epochs: "; std::string ep; std::getline(std::cin, ep);
std::cout << "LR: "; std::string lr; std::getline(std::cin, lr);
std::cout << "\n\033[s";
nn.trainOnSequence(tok, emb, content, std::stoi(ep), std::stod(lr), buildNetInput, [](const TrainStatus& s) {
std::stringstream ss;
if (s.totalParams >= 1e12) ss << std::fixed << std::setprecision(1) << s.totalParams / 1e12 << "t";
else if (s.totalParams >= 1e9) ss << std::fixed << std::setprecision(1) << s.totalParams / 1e9 << "b";
else if (s.totalParams >= 1e6) ss << std::fixed << std::setprecision(1) << s.totalParams / 1e6 << "m";
else if (s.totalParams >= 1e3) ss << std::fixed << std::setprecision(1) << s.totalParams / 1e3 << "k";
else ss << s.totalParams;
std::cout << "\033[u";
int width = 100;
int pos = width * (s.percentage / 100.0f);
std::cout << "[\033[1;36m";
for(int i=0; i<width; i++) std::cout << (i < pos ? "" : " ");
std::cout << "\033[0m] " << std::fixed << std::setprecision(1) << s.percentage << "% | ETA: \033[1;33m" << formatTime(s.eta) << "\033[0m | Params: \033[1;32m" << ss.str() << "\033[0m\n";
std::cout << "Epoch: " << s.currentEpoch << "/" << s.totalEpochs
<< " | Token: " << s.currentToken << "/" << s.totalTokens << "\n";
std::cout << "Loss: " << std::fixed << std::setprecision(6) << s.currentLoss
<< " | Ep Loss: " << s.epochLoss << "\n";
std::cout << "Prev Ep Loss: " << s.lastEpochLoss << "\n";
std::cout << "Speed: " << std::fixed << std::setprecision(1) << s.speed << " t/s\033[K" << std::flush;
}
);
std::cout << "\n\nDone.\n";
} else {
std::string prompt = "[USER]" + cmdIn + "[AI]";
std::vector<int> ctx = tok.textToTokens(prompt);
int eosId = -1; auto s = tok.textToTokens("<EOS>"); if(!s.empty()) eosId = s[0];
std::cout << "\033[1;33mAI:\033[0m ";
for (int g = 0; g < 256; g++) {
std::vector<double> out = nn.feedForward(buildNetInput(ctx, emb));
int bId = 0; double mV = -1.0;
for (int i = 0; i < MAX_VOCAB; i++) if (out[i] > mV) { mV = out[i]; bId = i; }
if (bId == eosId || bId == 0) break;
std::string w = tok.getWord(bId);
if (w != "[AI]" && w != "[USER]" && w != "[CLR]") std::cout << w << std::flush;
ctx.push_back(bId); if (ctx.size() > MAX_CONTEXT) ctx.erase(ctx.begin());
}
std::cout << std::endl;
}
}
return 0;
}