xref: /linux/fs/bcachefs/mean_and_variance.h (revision bf5802238dc181b1f7375d358af1d01cd72d1c11)
1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
4 
5 #include <linux/types.h>
6 #include <linux/limits.h>
7 #include <linux/math.h>
8 #include <linux/math64.h>
9 
/* floor(sqrt(U64_MAX)) == 2^32 - 1: the largest value whose square still fits in a u64. */
#define SQRT_U64_MAX 0xffffffffULL
11 
/*
 * u128_u: a 128-bit unsigned integer wrapper type, because not all
 * architectures support a real int128 type.
 *
 * The native unsigned __int128 version below is only used in the kernel: in
 * userspace we link with Rust, and rustc has issues with u128, so userspace
 * always gets the two-u64 fallback.
 */
19 
20 #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC)
21 
/* Native representation: the whole 128-bit value lives in a single unsigned __int128. */
typedef struct {
	unsigned __int128 v;
} __aligned(16) u128_u;
25 
26 static inline u128_u u64_to_u128(u64 a)
27 {
28 	return (u128_u) { .v = a };
29 }
30 
31 static inline u64 u128_lo(u128_u a)
32 {
33 	return a.v;
34 }
35 
36 static inline u64 u128_hi(u128_u a)
37 {
38 	return a.v >> 64;
39 }
40 
41 static inline u128_u u128_add(u128_u a, u128_u b)
42 {
43 	a.v += b.v;
44 	return a;
45 }
46 
47 static inline u128_u u128_sub(u128_u a, u128_u b)
48 {
49 	a.v -= b.v;
50 	return a;
51 }
52 
53 static inline u128_u u128_shl(u128_u a, s8 shift)
54 {
55 	a.v <<= shift;
56 	return a;
57 }
58 
59 static inline u128_u u128_square(u64 a)
60 {
61 	u128_u b = u64_to_u128(a);
62 
63 	b.v *= b.v;
64 	return b;
65 }
66 
67 #else
68 
/* Portable fallback: the value represented is hi * 2^64 + lo. */
typedef struct {
	u64 hi, lo;
} __aligned(16) u128_u;
72 
73 /* conversions */
74 
75 static inline u128_u u64_to_u128(u64 a)
76 {
77 	return (u128_u) { .lo = a };
78 }
79 
80 static inline u64 u128_lo(u128_u a)
81 {
82 	return a.lo;
83 }
84 
85 static inline u64 u128_hi(u128_u a)
86 {
87 	return a.hi;
88 }
89 
90 /* arithmetic */
91 
92 static inline u128_u u128_add(u128_u a, u128_u b)
93 {
94 	u128_u c;
95 
96 	c.lo = a.lo + b.lo;
97 	c.hi = a.hi + b.hi + (c.lo < a.lo);
98 	return c;
99 }
100 
101 static inline u128_u u128_sub(u128_u a, u128_u b)
102 {
103 	u128_u c;
104 
105 	c.lo = a.lo - b.lo;
106 	c.hi = a.hi - b.hi - (c.lo > a.lo);
107 	return c;
108 }
109 
110 static inline u128_u u128_shl(u128_u i, s8 shift)
111 {
112 	u128_u r;
113 
114 	r.lo = i.lo << shift;
115 	if (shift < 64)
116 		r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
117 	else {
118 		r.hi = i.lo << (shift - 64);
119 		r.lo = 0;
120 	}
121 	return r;
122 }
123 
124 static inline u128_u u128_square(u64 i)
125 {
126 	u128_u r;
127 	u64  h = i >> 32, l = i & U32_MAX;
128 
129 	r =             u128_shl(u64_to_u128(h*h), 64);
130 	r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
131 	r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
132 	r = u128_add(r,          u64_to_u128(l*l));
133 	return r;
134 }
135 
136 #endif
137 
138 static inline u128_u u64s_to_u128(u64 hi, u64 lo)
139 {
140 	u128_u c = u64_to_u128(hi);
141 
142 	c = u128_shl(c, 64);
143 	c = u128_add(c, u64_to_u128(lo));
144 	return c;
145 }
146 
/*
 * Defined out of line (not in this header).  From its name and signature it
 * presumably returns @n / @d as a u128_u; rounding and d == 0 behaviour are
 * not visible here — NOTE(review): confirm against the implementation.
 */
u128_u u128_div(u128_u n, u64 d);
148 
/*
 * Running totals for the plain (unweighted) estimator, updated one sample at
 * a time by mean_and_variance_update().
 */
struct mean_and_variance {
	s64	n;		/* number of samples added so far */
	s64	sum;		/* sum of all samples */
	u128_u	sum_squares;	/* sum of the squares of all samples */
};
154 
/*
 * Exponentially weighted variant.  Its fields are maintained by
 * mean_and_variance_weighted_update(), which is defined out of line.
 */
struct mean_and_variance_weighted {
	bool	init;	/* presumably set once the first sample is seen — TODO confirm */
	u8	weight;	/* base 2 logarithm */
	s64	mean;
	u64	variance;
};
162 
163 /**
164  * fast_divpow2() - fast approximation for n / (1 << d)
165  * @n: numerator
166  * @d: the power of 2 denominator.
167  *
168  * note: this rounds towards 0.
169  */
170 static inline s64 fast_divpow2(s64 n, u8 d)
171 {
172 	return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
173 }
174 
175 /**
176  * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
177  * and return it.
178  * @s1: the mean_and_variance to update.
179  * @v1: the new sample.
180  *
181  * see linked pdf equation 12.
182  */
183 static inline void
184 mean_and_variance_update(struct mean_and_variance *s, s64 v)
185 {
186 	s->n++;
187 	s->sum += v;
188 	s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v)));
189 }
190 
/*
 * Accessors for struct mean_and_variance, defined out of line.  They take the
 * struct by value, so the caller's copy is never modified.
 */
s64 mean_and_variance_get_mean(struct mean_and_variance s);
u64 mean_and_variance_get_variance(struct mean_and_variance s1);
u32 mean_and_variance_get_stddev(struct mean_and_variance s);

/* Defined out of line; from its signature, updates *@s with the new sample @v. */
void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);

/* Accessors for struct mean_and_variance_weighted, defined out of line (by-value argument). */
s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
200 
#endif // MEAN_AND_VARIANCE_H_
202