#ifndef CORE_H #define CORE_H #include "typedef.hpp" #include #include #include #include struct TrainStatus { int currentEpoch; int totalEpochs; int currentToken; int totalTokens; double currentLoss; double epochLoss; double lastEpochLoss; double speed; double eta; float percentage; long totalParams; }; class Tokenizer; class Embedder; class NeuralNetwork { private: int numLayers; std::vector sizes; std::vector h_weights, h_biases, h_outputs, h_errors; std::vector wOff, bOff, oOff; bool useVulkan; vk::Instance instance; vk::PhysicalDevice physDev; vk::Device device; vk::Queue queue; vk::CommandPool cmdPool; uint32_t computeQueueFamilyIndex; vk::Buffer gpuW, gpuB, gpuO, gpuE, gpuT; vk::DeviceMemory memW, memB, memO, memE, memT; void *pW = nullptr, *pB = nullptr, *pO = nullptr, *pE = nullptr, *pT = nullptr; vk::DescriptorPool descriptorPool; vk::DescriptorSet descriptorSet; vk::DescriptorSetLayout dsLayout; vk::PipelineLayout pipeLayout; vk::Pipeline pipeline; vk::ShaderModule shaderModule; struct TrainParams { uint32_t mode; uint32_t prevSize; uint32_t nextSize; uint32_t wOff; uint32_t bOff; uint32_t oOff; uint32_t nextOOff; float lr; }; void initVulkan(); void initVulkanResources(); uint32_t findMemoryType(uint32_t typeFilter, vk::MemoryPropertyFlags properties); std::vector readFile(const std::string& filename); double runTrainCPU(const std::vector& input, const std::vector& target, double lr); public: int cpu_count = 4; NeuralNetwork(LayerStructure_t layers[], int count, bool useVulkan = false); ~NeuralNetwork(); void syncToCPU(); void syncToGPU(); std::vector feedForward(const std::vector& input); double train(const std::vector& input, const std::vector& target, double lr); void trainOnSequence( Tokenizer& tok, Embedder& emb, const std::string& dataset, int epochs, double lr, std::function(const std::vector&, Embedder&)> buildInput, std::function onProgress = nullptr ); long long getTotalParameters() { long long total = 0; for (int i = 0; i < numLayers - 1; i++) { total += (long long)sizes[i] * sizes[i+1]; total += (long long)sizes[i+1]; } return total; } }; #endif