Baremetal-NN
Baremetal-NN API documentation
nn_math.h
#ifndef __NN_MATH_H
#define __NN_MATH_H


#include <assert.h>
#include <math.h>
#include <string.h>

#include "float16.h"

//
// fundamental operations
//

// inline static void NN_mad_F32(const int n, float *y, const float *x, const float v) {
// #if defined(GGML_SIMD)
//     const int np = (n & ~(GGML_F32_STEP - 1));

//     GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

//     GGML_F32_VEC ax[GGML_F32_ARR];
//     GGML_F32_VEC ay[GGML_F32_ARR];

//     for (int i = 0; i < np; i += GGML_F32_STEP) {
//         for (int j = 0; j < GGML_F32_ARR; j++) {
//             ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
//             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
//             ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);

//             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
//         }
//     }

//     // leftovers
//     for (int i = np; i < n; i += 1) {
//         y[i] += x[i]*v;
//     }
// #else
//     // scalar
//     for (int i = 0; i < n; i += 1) {
//         y[i] += x[i]*v;
//     }
// #endif
// }

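// Editor's note: a minimal, hedged usage sketch of the multiply-add helper above,
// kept commented out like the routine itself; the buffer names are hypothetical.
// NN_mad_F32 accumulates y[i] += x[i] * v over n elements:
//
//     float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
//     float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//     NN_mad_F32(4, acc, row, 0.5f);   // acc becomes {0.5f, 1.0f, 1.5f, 2.0f}
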
// inline static void NN_mad_f16(const int n, float16_t *y, const float16_t *x, const float v) {
// #if defined(GGML_SIMD)
//     const int np = (n & ~(GGML_F16_STEP - 1));

//     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

//     GGML_F16_VEC ax[GGML_F16_ARR];
//     GGML_F16_VEC ay[GGML_F16_ARR];

//     for (int i = 0; i < np; i += GGML_F16_STEP) {
//         for (int j = 0; j < GGML_F16_ARR; j++) {
//             ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
//             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
//             ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);

//             GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
//         }
//     }

//     // leftovers
//     for (int i = np; i < n; i += 1) {
//         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
//     }
// #else
//     // scalar
//     for (int i = 0; i < n; i += 1) {
//         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
//     }
// #endif
// }

// // xs and vs are byte strides of x and v
// inline static void NN_mad_F32_unroll(const int n, const int xs, const int vs, float *restrict y, const float *restrict xv, const float *restrict vv) {

//     const float *restrict x[GGML_VEC_MAD_UNROLL];
//     const float *restrict v[GGML_VEC_MAD_UNROLL];

//     for (int i = 0; i < GGML_VEC_MAD_UNROLL; i += 1) {
//         x[i] = (const float *) ((const char *) xv + i*xs);
//         v[i] = (const float *) ((const char *) vv + i*vs);
//     }

// #if defined(GGML_SIMD)
//     const int np = (n & ~(GGML_F32_STEP - 1));

//     GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];

//     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
//         vx[k] = GGML_F32_VEC_SET1(v[k][0]);
//     }

//     GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
//     GGML_F32_VEC ay[GGML_F32_ARR];

//     for (int i = 0; i < np; i += GGML_F32_STEP) {
//         for (int j = 0; j < GGML_F32_ARR; j++) {
//             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

//             for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
//                 ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
//                 ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
//             }

//             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
//         }
//     }

//     // leftovers
//     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
//         for (int i = np; i < n; i += 1) {
//             y[i] += x[k][i]*v[k][0];
//         }
//     }
// #else
//     // scalar
//     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
//         for (int i = 0; i < n; i += 1) {
//             y[i] += x[k][i]*v[k][0];
//         }
//     }
// #endif
// }

// inline static void NN_step_F32        (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
// inline static void NN_tanh_F32        (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = tanhf(x[i]); }
// inline static void NN_elu_F32         (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
// inline static void NN_relu_F32        (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
// inline static void NN_leaky_relu_F32  (const int n, float *y, const float *x, const float ns) { for (int i = 0; i < n; i += 1) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
// inline static void NN_sigmoid_F32     (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = 1.f / (1.f + expf(-x[i])); }
// // TODO: optimize performance
// inline static void NN_hardswish_F32   (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
// inline static void NN_hardsigmoid_F32 (const int n, float *y, const float *x) { for (int i = 0; i < n; i += 1) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }

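// Editor's note: the element-wise activation helpers above all share the same
// (n, y, x) calling convention. A hedged usage sketch, kept commented out like the
// helpers themselves; the buffer names are hypothetical:
//
//     float in[4]  = {-1.0f, 0.0f, 0.5f, 2.0f};
//     float out[4];
//     NN_relu_F32(4, out, in);      // out = {0.0f, 0.0f, 0.5f, 2.0f}
//     NN_sigmoid_F32(4, out, in);   // out[i] = 1 / (1 + expf(-in[i]))
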
// static const float GELU_COEF_A     = 0.044715f;
// static const float GELU_QUICK_COEF = -1.702f;
// static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;

// inline static float ggml_gelu_F32(float x) {
//     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
// }

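// Editor's note: ggml_gelu_F32 above is the usual tanh-based approximation of GELU,
// 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), with SQRT_2_OVER_PI ~= 0.7979 and
// GELU_COEF_A = 0.044715; e.g. gelu(0) = 0 and gelu(x) approaches x for large positive x.
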
// inline static void NN_gelu_f16(const int n, float16_t * y, const float16_t * x) {
//     const uint16_t * i16 = (const uint16_t *) x;
//     for (int i = 0; i < n; i += 1) {
//         y[i] = ggml_table_gelu_f16[i16[i]];
//     }
// }

// #ifdef GGML_GELU_FP16
// inline static void NN_gelu_F32(const int n, float *y, const float *x) {
//     uint16_t t;
//     for (int i = 0; i < n; i += 1) {
//         if (x[i] <= -10.0f) {
//             y[i] = 0.0f;
//         } else if (x[i] >= 10.0f) {
//             y[i] = x[i];
//         } else {
//             float16_t fp16 = GGML_FP32_TO_FP16(x[i]);
//             memcpy(&t, &fp16, sizeof(uint16_t));
//             y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
//         }
//     }
// }
// #else
// inline static void NN_gelu_F32(const int n, float *y, const float *x) {
//     for (int i = 0; i < n; i += 1) {
//         y[i] = ggml_gelu_F32(x[i]);
//     }
// }
// #endif

// inline static float ggml_gelu_quick_F32(float x) {
//     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
// }

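// Editor's note: the "quick" variant above approximates GELU with a single sigmoid,
// x * sigmoid(1.702 * x); with GELU_QUICK_COEF = -1.702 the expression
// x*(1/(1+expf(-1.702*x))) is exactly that form.
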
// //inline static void NN_gelu_quick_f16(const int n, float16_t * y, const float16_t * x) {
// //    const uint16_t * i16 = (const uint16_t *) x;
// //    for (int i = 0; i < n; i += 1) {
// //        y[i] = ggml_table_gelu_quick_f16[i16[i]];
// //    }
// //}

// #ifdef GGML_GELU_QUICK_FP16
// inline static void NN_gelu_quick_F32(const int n, float *y, const float *x) {
//     uint16_t t;
//     for (int i = 0; i < n; i += 1) {
//         float16_t fp16 = GGML_FP32_TO_FP16(x[i]);
//         memcpy(&t, &fp16, sizeof(uint16_t));
//         y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
//     }
// }
// #else
// inline static void NN_gelu_quick_F32(const int n, float *y, const float *x) {
//     for (int i = 0; i < n; i += 1) {
//         y[i] = ggml_gelu_quick_F32(x[i]);
//     }
// }
// #endif

// // Sigmoid Linear Unit (SiLU) function
// inline static float ggml_silu_F32(float x) {
//     return x/(1.0f + expf(-x));
// }

// #if defined(__ARM_NEON) && defined(__aarch64__)

// // adapted from arm limited optimized routine
// // the maximum error is 1.45358 plus 0.5 ulps
// // numbers above 88.38 will flush to infinity
// // numbers beneath -103.97 will flush to zero
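// // Editor's note (a reading of the code below; the same scheme appears in the
// // AVX512/AVX2/SSE2 variants further down): exp(x) is split as x = n*ln(2) + b by
// // multiplying with 1/ln(2) (0x1.715476p+0f) and rounding via the 0x1.8p23f
// // shifter trick, a short polynomial in b approximates exp(b), and the 2^n scale
// // is applied by adding n directly into the float exponent bits; the branch covers
// // lanes where |n| is large enough to overflow or underflow that exponent adjustment.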
// inline static float32x4_t ggml_v_expf(float32x4_t x) {
//     const float32x4_t r = vdupq_n_f32(0x1.8p23f);
//     const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
//     const float32x4_t n = vsubq_f32(z, r);
//     const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
//                                     vdupq_n_f32(0x1.7f7d1cp-20f));
//     const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
//     const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
//     const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
//     const float32x4_t u = vmulq_f32(b, b);
//     const float32x4_t j = vfmaq_f32(
//         vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
//         vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
//                   vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
//     if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
//         return vfmaq_f32(k, j, k);
//     const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
//     const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
//     const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
//     return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
//                      vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
// }

// // computes silu x/(1+exp(-x)) in single precision vector
// inline static float32x4_t ggml_v_silu(float32x4_t x) {
//     const float32x4_t one = vdupq_n_f32(1.0f);
//     const float32x4_t zero = vdupq_n_f32(0.0f);
//     const float32x4_t neg_x = vsubq_f32(zero, x);
//     const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
//     const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
//     return vdivq_f32(x, one_plus_exp_neg_x);
// }

// #elif defined(__AVX512F__) && defined(__AVX512DQ__)

// // adapted from arm limited optimized routine
// // the maximum error is 1.45358 plus 0.5 ulps
// // numbers above 88.38 will flush to infinity
// // numbers beneath -103.97 will flush to zero
// inline static __m512 ggml_v_expf(__m512 x) {
//     const __m512 r = _mm512_set1_ps(0x1.8p23f);
//     const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
//     const __m512 n = _mm512_sub_ps(z, r);
//     const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
//                                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
//     const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
//     const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
//     const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
//     const __m512 u = _mm512_mul_ps(b, b);
//     const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
//                                                                      _mm512_set1_ps(0x1.573e2ep-5f)), u,
//                                                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
//                                                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
//                                      u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
//     if (_mm512_kortestz(c, c))
//         return _mm512_fmadd_ps(j, k, k);
//     const __m512i g = _mm512_and_si512(
//         _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
//         _mm512_set1_epi32(0x82000000u));
//     const __m512 s1 =
//         _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
//     const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
//     const __mmask16 d =
//         _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
//     return _mm512_mask_blend_ps(
//         d, _mm512_mask_blend_ps(
//                c, _mm512_fmadd_ps(k, j, k),
//                _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
//         _mm512_mul_ps(s1, s1));
// }

// // computes silu x/(1+exp(-x)) in single precision vector
// inline static __m512 ggml_v_silu(__m512 x) {
//     const __m512 one = _mm512_set1_ps(1);
//     const __m512 zero = _mm512_setzero_ps();
//     const __m512 neg_x = _mm512_sub_ps(zero, x);
//     const __m512 exp_neg_x = ggml_v_expf(neg_x);
//     const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
//     return _mm512_div_ps(x, one_plus_exp_neg_x);
// }

// #elif defined(__AVX2__) && defined(__FMA__)

// // adapted from arm limited optimized routine
// // the maximum error is 1.45358 plus 0.5 ulps
// // numbers above 88.38 will flush to infinity
// // numbers beneath -103.97 will flush to zero
// inline static __m256 ggml_v_expf(__m256 x) {
//     const __m256 r = _mm256_set1_ps(0x1.8p23f);
//     const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
//     const __m256 n = _mm256_sub_ps(z, r);
//     const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
//                                       _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
//     const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
//     const __m256 k = _mm256_castsi256_ps(
//         _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
//     const __m256i c = _mm256_castps_si256(
//         _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
//                       _mm256_set1_ps(126), _CMP_GT_OQ));
//     const __m256 u = _mm256_mul_ps(b, b);
//     const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
//                                                                      _mm256_set1_ps(0x1.573e2ep-5f)), u,
//                                                      _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
//                                                                      _mm256_set1_ps(0x1.fffdb6p-2f))),
//                                      u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
//     if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
//         return _mm256_fmadd_ps(j, k, k);
//     const __m256i g = _mm256_and_si256(
//         _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
//         _mm256_set1_epi32(0x82000000u));
//     const __m256 s1 =
//         _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
//     const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
//     const __m256i d = _mm256_castps_si256(
//         _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
//                       _mm256_set1_ps(192), _CMP_GT_OQ));
//     return _mm256_or_ps(
//         _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
//         _mm256_andnot_ps(
//             _mm256_castsi256_ps(d),
//             _mm256_or_ps(
//                 _mm256_and_ps(_mm256_castsi256_ps(c),
//                               _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
//                 _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
// }

// // computes silu x/(1+exp(-x)) in single precision vector
// inline static __m256 ggml_v_silu(__m256 x) {
//     const __m256 one = _mm256_set1_ps(1);
//     const __m256 zero = _mm256_setzero_ps();
//     const __m256 neg_x = _mm256_sub_ps(zero, x);
//     const __m256 exp_neg_x = ggml_v_expf(neg_x);
//     const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
//     return _mm256_div_ps(x, one_plus_exp_neg_x);
// }

// #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON

// #if defined(__FMA__)
// #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
// #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
// #else
// #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
// #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
// #endif

// // adapted from arm limited optimized routine
// // the maximum error is 1.45358 plus 0.5 ulps
// // numbers above 88.38 will flush to infinity
// // numbers beneath -103.97 will flush to zero
// inline static __m128 ggml_v_expf(__m128 x) {
//     const __m128 r = _mm_set1_ps(0x1.8p23f);
//     const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
//     const __m128 n = _mm_sub_ps(z, r);
//     const __m128 b =
//         NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
//     const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
//     const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
//     const __m128i c =
//         _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
//     const __m128 u = _mm_mul_ps(b, b);
//     const __m128 j =
//         MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
//                         MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
//                 u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
//     if (!_mm_movemask_epi8(c))
//         return MADD128(j, k, k);
//     const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
//                                     _mm_set1_epi32(0x82000000u));
//     const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
//     const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
//     const __m128i d =
//         _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
//     return _mm_or_ps(
//         _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
//         _mm_andnot_ps(_mm_castsi128_ps(d),
//                       _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
//                                 _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
// }

// // computes silu x/(1+exp(-x)) in single precision vector
// inline static __m128 ggml_v_silu(__m128 x) {
//     const __m128 one = _mm_set1_ps(1);
//     const __m128 zero = _mm_setzero_ps();
//     const __m128 neg_x = _mm_sub_ps(zero, x);
//     const __m128 exp_neg_x = ggml_v_expf(neg_x);
//     const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
//     return _mm_div_ps(x, one_plus_exp_neg_x);
// }

// #endif // __ARM_NEON / __AVX2__ / __SSE2__

// static void NN_silu_F32(const int n, float *y, const float *x) {
//     int i = 0;
// #if defined(__AVX512F__) && defined(__AVX512DQ__)
//     for (; i + 15 < n; i += 16) {
//         _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
//     }
// #elif defined(__AVX2__) && defined(__FMA__)
//     for (; i + 7 < n; i += 8) {
//         _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
//     }
// #elif defined(__SSE2__)
//     for (; i + 3 < n; i += 4) {
//         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
//     }
// #elif defined(__ARM_NEON) && defined(__aarch64__)
//     for (; i + 3 < n; i += 4) {
//         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
//     }
// #endif
//     for (; i < n; i += 1) {
//         y[i] = ggml_silu_F32(x[i]);
//     }
// }

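// Editor's note: a hedged usage sketch of the SIMD-dispatching SiLU routine above,
// kept commented out like the routine itself; the buffer names are hypothetical:
//
//     float act[8] = {-2.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 4.0f};
//     float out[8];
//     NN_silu_F32(8, out, act);   // out[i] = act[i] / (1 + expf(-act[i]))
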

// inline static float ggml_silu_backward_F32(float x, float dy) {
//     const float s = 1.0f/(1.0f + expf(-x));
//     return dy*s*(1.0f + x*(1.0f - s));
// }

// inline static void NN_silu_backward_F32(const int n, float *dx, const float *x, const float *dy) {
//     for (int i = 0; i < n; i += 1) {
//         dx[i] = ggml_silu_backward_F32(x[i], dy[i]);
//     }
// }

// inline static void NN_argmax_F32(const int n, int * s, const float *x) {
//     float max = -INFINITY;
//     int idx = 0;
//     for (int i = 0; i < n; i += 1) {
//         max = MAX(max, x[i]);
//         if (max == x[i]) { idx = i; }
//     }
//     *s = idx;
// }

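// Editor's note: a hedged usage sketch of NN_argmax_F32 above, kept commented out
// like the routine itself; the names are hypothetical. It writes the index of the
// largest element:
//
//     float logits[3] = {0.1f, 2.3f, 1.7f};
//     int cls;
//     NN_argmax_F32(3, &cls, logits);   // cls == 1
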

#endif // __NN_MATH_H