65 lines
2.1 KiB
Plaintext
65 lines
2.1 KiB
Plaintext
#version 450

// Training kernel for a fully connected sigmoid network. The push constant
// `mode` (see Params) selects the stage this dispatch performs; the offsets in
// Params select which layer of the packed buffers below is touched.

// 256 invocations per workgroup along x. main() uses the global x invocation
// ID as the index of the neuron each invocation handles.
layout(local_size_x = 256) in;

// Weights of all layers, packed back to back. main() addresses one layer's
// matrix as W[wOff + n * prevSize + k]: one row of prevSize entries per
// next-layer neuron n.
layout(std430, binding = 0) buffer Weights { float W[]; };

// Biases of all layers, addressed as B[bOff + n] for next-layer neuron n.
layout(std430, binding = 1) buffer Biases { float B[]; };

// Activations of all layers, addressed via oOff (previous layer) and
// nextOOff (next layer).
layout(std430, binding = 2) buffer Outputs { float O[]; };

// Error terms (deltas), indexed with the same offsets as Outputs.
layout(std430, binding = 3) buffer Errors { float E[]; };

// Target values for the output layer, read starting at index 0.
// NOTE(review): presumably holds only the current sample's targets — confirm
// against the host upload code.
layout(std430, binding = 4) buffer Targets { float T[]; };
|
|
|
|
// Per-dispatch parameters. Member order determines the push-constant memory
// offsets the host writes to, so members must not be reordered.
layout(push_constant) uniform Params {
    uint mode;      // 0: forward pass, 1: output-layer error, 2: backprop to previous layer, 3: weight/bias update
    uint prevSize;  // neuron count of the previous (input-side) layer
    uint nextSize;  // neuron count of the next (output-side) layer
    uint wOff;      // start of this layer's weight matrix in W
    uint bOff;      // start of the next layer's biases in B
    uint oOff;      // start of the previous layer's activations in O (and its deltas in E)
    uint nextOOff;  // start of the next layer's activations in O (and its deltas in E)
    float lr;       // learning rate; read only by mode 3
} p;
|
|
|
|
// Logistic sigmoid 1 / (1 + e^-x), evaluated so that exp() only ever receives
// a non-positive argument. The naive form computes exp(-x), which overflows
// to +Inf for large negative x; Vulkan implementations may assume no Inf/NaN
// results unless float controls are enabled, so relying on 1/(1+Inf) == 0 is
// not portable.
float sigmoid(float x) {
    if (x >= 0.0) {
        return 1.0 / (1.0 + exp(-x));
    }
    // x < 0: sigmoid(x) = e^x / (1 + e^x), with e^x in (0, 1) — no overflow.
    float ex = exp(x);
    return ex / (1.0 + ex);
}
|
|
// Derivative of the sigmoid expressed in terms of its OUTPUT: if
// s = sigmoid(z), then d sigmoid / dz = s * (1 - s). Callers must pass an
// already-activated value (e.g. an entry of O), not the pre-activation z.
float dSigmoid(float s) {
    return (1.0 - s) * s;
}
|
|
|
|
// Mode 0 (forward pass): activation of next-layer neuron n is
// sigmoid(bias + sum_k O[prev k] * W[n][k]).
void forwardNeuron(uint n) {
    if (n >= p.nextSize) {
        return;
    }
    float acc = B[p.bOff + n];
    uint row = p.wOff + n * p.prevSize;
    for (uint k = 0u; k < p.prevSize; ++k) {
        acc += O[p.oOff + k] * W[row + k];
    }
    O[p.nextOOff + n] = sigmoid(acc);
}

// Mode 1 (output-layer error): delta = (target - output) * sigmoid'(output).
// NOTE(review): T is indexed from 0, i.e. it is assumed to hold exactly this
// layer's targets — confirm on the host side.
void outputLayerDelta(uint n) {
    if (n >= p.nextSize) {
        return;
    }
    float y = O[p.nextOOff + n];
    E[p.nextOOff + n] = (T[n] - y) * dSigmoid(y);
}

// Mode 2 (backpropagation): delta of previous-layer neuron j is the
// weight-column-weighted sum of next-layer deltas, times sigmoid'(O[prev j]).
void hiddenLayerDelta(uint j) {
    if (j >= p.prevSize) {
        return;
    }
    float acc = 0.0;
    for (uint k = 0u; k < p.nextSize; ++k) {
        acc += E[p.nextOOff + k] * W[p.wOff + k * p.prevSize + j];
    }
    E[p.oOff + j] = acc * dSigmoid(O[p.oOff + j]);
}

// Mode 3 (parameter update): row n of the weight matrix moves by
// lr * delta[n] * O[prev k], and bias n by lr * delta[n]. With mode 1's
// (target - output) sign this is a descent step on squared error.
void applyGradientStep(uint n) {
    if (n >= p.nextSize) {
        return;
    }
    float scaled = E[p.nextOOff + n] * p.lr;
    uint row = p.wOff + n * p.prevSize;
    for (uint k = 0u; k < p.prevSize; ++k) {
        W[row + k] += scaled * O[p.oOff + k];
    }
    B[p.bOff + n] += scaled;
}

// Compute entry point. Each invocation handles at most one neuron, chosen by
// its global x index; out-of-range invocations do nothing. p.mode selects the
// training stage; any value other than 0..3 is a no-op.
// NOTE(review): mode 2 reads this layer's W, so if the host runs mode 3 for
// the same layer first, backprop sees already-updated weights — confirm the
// host's dispatch order.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    switch (p.mode) {
        case 0u:
            forwardNeuron(gid);
            break;
        case 1u:
            outputLayerDelta(gid);
            break;
        case 2u:
            hiddenLayerDelta(gid);
            break;
        case 3u:
            applyGradientStep(gid);
            break;
        default:
            break;
    }
}