#include "core.hpp"
#include <vector>
#include <cmath>
#include <cstdlib>
#include <omp.h>

// Upper bound on OpenMP worker threads used during training.
#define MAX_CORES 16

// Builds a fully-connected network from the given layer descriptions.
// Weights use He-style initialization: uniform in [-s, s] with
// s = sqrt(2 / fan_in); biases start at zero.
// NOTE(review): relies on rand() — presumably the caller seeds with
// srand(); confirm, otherwise every run gets identical weights.
NeuralNetwork::NeuralNetwork(LayerStructure_t layers[], int count)
    : numLayers(count) {
    for (int i = 0; i < count; i++)
        sizes.push_back(layers[i].size);

    for (int i = 0; i < count - 1; i++) {
        std::vector<std::vector<double>> layerW;
        double scale = sqrt(2.0 / sizes[i]);
        for (int j = 0; j < sizes[i + 1]; j++) {
            std::vector<double> nodeW;
            for (int k = 0; k < sizes[i]; k++)
                nodeW.push_back(((double)rand() / RAND_MAX * 2 - 1) * scale);
            layerW.push_back(nodeW);
        }
        weights.push_back(layerW);
        biases.push_back(std::vector<double>(sizes[i + 1], 0.0));
    }
}

// Forward pass with sigmoid activation on every layer.
// Side effect: caches per-layer activations in `outputs`
// (outputs[0] == input) for use by train()'s backward pass.
// Returns the activations of the final layer.
std::vector<double> NeuralNetwork::feedForward(const std::vector<double>& input) {
    outputs.clear();
    outputs.push_back(input);
    std::vector<double> curr = input;
    for (int i = 0; i < numLayers - 1; i++) {
        std::vector<double> next;
        for (int j = 0; j < sizes[i + 1]; j++) {
            double sum = biases[i][j];
            for (int k = 0; k < (int)curr.size(); k++)
                sum += curr[k] * weights[i][j][k];
            next.push_back(1.0 / (1.0 + exp(-sum)));  // sigmoid
        }
        curr = next;
        outputs.push_back(curr);
    }
    return curr;
}

// One step of gradient descent via backpropagation on a single
// (input, target) pair. Returns the sum of squared output errors
// measured BEFORE the weight update.
// The OpenMP loops are race-free: each iteration j writes only
// errors[i][j] / weights[i][j] / biases[i][j], which it owns.
double NeuralNetwork::train(const std::vector<double>& input,
                            const std::vector<double>& target,
                            double lr) {
    omp_set_num_threads(MAX_CORES);
    std::vector<double> pred = feedForward(input);
    std::vector<std::vector<double>> errors(numLayers);

    // Output-layer deltas: (target - pred) * sigmoid'(activation).
    errors[numLayers - 1].resize(sizes[numLayers - 1]);
    double totalErr = 0;
    for (int i = 0; i < sizes[numLayers - 1]; i++) {
        double e = target[i] - pred[i];
        errors[numLayers - 1][i] = e * pred[i] * (1.0 - pred[i]);
        totalErr += e * e;
    }

    // Propagate deltas back through the hidden layers
    // (layer 0 is the input layer — it has no delta).
    for (int i = numLayers - 2; i > 0; i--) {
        errors[i].resize(sizes[i]);
        #pragma omp parallel for
        for (int j = 0; j < sizes[i]; j++) {
            double e = 0;
            for (int k = 0; k < sizes[i + 1]; k++) {
                e += errors[i + 1][k] * weights[i][k][j];
            }
            errors[i][j] = e * outputs[i][j] * (1.0 - outputs[i][j]);
        }
    }

    // Apply the gradient step to every weight and bias.
    for (int i = 0; i < numLayers - 1; i++) {
        #pragma omp parallel for
        for (int j = 0; j < sizes[i + 1]; j++) {
            double errorTerm = lr * errors[i + 1][j];
            // Nested weight-update loop: scale by the upstream activation.
            for (int k = 0; k < sizes[i]; k++) {
                weights[i][j][k] += errorTerm * outputs[i][k];
            }
            biases[i][j] += errorTerm;
        }
    }
    return totalErr;
}