First Vulkan Relise
This commit is contained in:
@@ -80,7 +80,7 @@ void trainOnSequence(NeuralNetwork& nn, Tokenizer& tok, Embedder& emb, const std
|
||||
std::vector<int> context(allTokens.begin(), allTokens.begin() + i);
|
||||
std::vector<double> target(MAX_VOCAB, 0.0);
|
||||
target[allTokens[i]] = 1.0;
|
||||
totalLoss += nn.train(buildNetInput(context, emb), target, lr);
|
||||
totalLoss += nn.trainVulkan(buildNetInput(context, emb), target, lr);
|
||||
|
||||
trainSteps++;
|
||||
auto currentTime = std::chrono::high_resolution_clock::now();
|
||||
@@ -99,7 +99,7 @@ void trainOnSequence(NeuralNetwork& nn, Tokenizer& tok, Embedder& emb, const std
|
||||
std::cout << "SPEED: " << std::setw(6) << std::fixed << std::setprecision(1) << stepsPerSec
|
||||
<< " st/s | MODEL: " << std::setw(7) << modelSizeStr
|
||||
<< " | CURRENT: [" << std::left << std::setw(15) << tok.getWord(allTokens[i]) << "]"
|
||||
<< "\033[K" << std::flush;
|
||||
<< "\033[K" << std::flush;
|
||||
|
||||
}
|
||||
maxLoss = totalLoss;
|
||||
@@ -146,8 +146,7 @@ int main() {
|
||||
std::string aiPart;
|
||||
std::getline(std::cin, aiPart);
|
||||
|
||||
std::string finalData = "[SYS]" + currentSystemPrompt +
|
||||
"[USER]" + userPart +
|
||||
std::string finalData = "[USER]" + userPart +
|
||||
"[AI]" + aiPart + "<EOS>";
|
||||
|
||||
std::cout << "\nTraining logic: Pattern Recognition..." << std::endl;
|
||||
@@ -189,9 +188,6 @@ int main() {
|
||||
std::getline(std::cin, currentSystemPrompt);
|
||||
std::cout << "System Prompt updated!" << std::endl;
|
||||
|
||||
} else if (cmdIn == "/trainVulkan") {
|
||||
std::cout << nn.trainVulkan() << "\n";
|
||||
|
||||
} else if (cmdIn == "/help") {
|
||||
std::cout << "\n--- MENU ---" << std::endl;
|
||||
std::cout << "/train\n/trainFile\n/sysPrompt\n/help\n/exit\n";
|
||||
@@ -201,7 +197,7 @@ int main() {
|
||||
std::cout << "\033[2J\033[1;1H";
|
||||
|
||||
} else {
|
||||
std::string prompt = "[SYS]" + currentSystemPrompt + "[USER]" + cmdIn + "[AI]";
|
||||
std::string prompt = "[USER]" + cmdIn + "[AI]";
|
||||
std::vector<int> currentTokens = tok.textToTokens(prompt);
|
||||
|
||||
std::cout << "AI: ";
|
||||
|
||||
Reference in New Issue
Block a user