1 /**
2 Copyright: Copyright (c) 2020, Joakim Brännström. All rights reserved.
3 License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost Software License 1.0)
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
This module contains some simple statistics functionality. It isn't intended to
be a full-blown statistics package; for that, see
[mir](http://mir-algorithm.libmir.org/mir_math_stat.html). I wrote this module
because I had problems using **mir** and only needed a small subset of its
functionality.
11 
The functions probably contain rounding errors etc., so be aware. But they seem
to work well enough for simple needs.
14 */
15 module my.stat;
16 
17 import logger = std.experimental.logger;
18 import std;
19 import std.array : appender;
20 import std.ascii : newline;
21 import std.format : formattedWrite;
22 import std.range : isOutputRange, put;
23 
24 @safe:
25 
/// Example:
unittest {
    // A small hand-written data set: summary statistics plus a 3-bucket histogram.
    auto small = makeData([3, 14, 18, 24, 29]);

    writeln(basicStat(small));

    writeln(histogram(small, 3));
    writeln(histogram(small, 3).toBar);

    // 10000 values drawn from the generator for NormDistribution(0, 1).
    const standardNormal = NormDistribution(0, 1);
    auto drawn = makeData(pdf(standardNormal).take(10000));
    writeln(basicStat(drawn));
    writeln(stdError(drawn));

    auto drawnHist = histogram(drawn, 21);
    writeln(drawnHist.toBar);
    writeln(drawnHist.mode);

    // Difference of the cumulative distribution between -1 and 1.
    writeln(cdf(standardNormal, 1) - cdf(standardNormal, -1));
}
45 
/// Container for the samples that the statistics functions in this module operate on.
struct StatData {
    /// The samples, stored as doubles.
    double[] value;

    /// Returns: the number of samples in `value`.
    ///
    /// Marked `const` so that it can also be called on `const` and
    /// `immutable` instances.
    size_t length() const pure nothrow @nogc @safe {
        return value.length;
    }
}
53 
/** Convert user data to a representation useful for simple statistics calculations.
 *
 * Params:
 *  raw = a range whose elements can be cast to `double`
 *
 * Returns: a `StatData` holding every element of `raw` converted to `double`.
 *
 * Throws: `Exception` if `raw` yields fewer than two elements.
 */
StatData makeData(T)(T raw) {
    import std.algorithm : map;
    import std.array : array;

    auto samples = raw.map!(elem => cast(double) elem).array;
    if (samples.length < 2)
        throw new Exception("Too few samples");
    return StatData(samples);
}
63 
/// Strongly typed wrapper for an arithmetic mean.
struct Mean {
    double value;
}

/// Returns: the sum of all samples in `data` divided by the number of samples.
Mean mean(StatData data) {
    const total = sum(data.value);
    const count = cast(double) data.length;
    return Mean(total / count);
}
72 
/// According to wikipedia this is the Corrected Sample Standard Deviation
struct SampleStdDev {
    double value;
}

/** Standard deviation of `data` around the supplied `mean`.
 *
 * The sum of squared deviations is divided by `N - 1`, where `N` is the
 * number of samples, before the square root is taken.
 */
SampleStdDev sampleStdDev(StatData data, Mean mean) {
    auto squaredDeviations = data.value.map!(x => pow(x - mean.value, 2.0));
    const sumOfSquares = squaredDeviations.sum;
    const degreesOfFreedom = cast(double) data.length - 1.0;
    return SampleStdDev(sqrt(sumOfSquares / degreesOfFreedom));
}
83 
/// Strongly typed wrapper for a median.
struct Median {
    double value;
}

/** Median of the samples in `data_`.
 *
 * For an even number of samples the result is the average of the two middle
 * values; for an odd number it is the middle value.
 *
 * The samples are copied before sorting. `StatData` is passed by value but its
 * `value` slice still refers to the caller's memory, so sorting it directly
 * would reorder the caller's data as a side effect.
 */
Median median(StatData data_)
in (data_.value.length > 0, "median requires at least one sample") {
    auto sorted = data_.value.dup;
    sorted.sort();

    const mid = sorted.length / 2;
    if (sorted.length % 2 == 0)
        return Median((sorted[mid - 1] + sorted[mid]) / 2.0);
    return Median(sorted[mid]);
}
95 
/** Histogram with equally wide buckets covering the closed range `[low, high]`.
 *
 * Values are added with `put`. The counts can be rendered either as a compact
 * one-line description (`toString`) or as a textual bar chart (`toBar`).
 */
struct Histogram {
    /// Number of values counted in each bucket.
    long[] buckets;
    /// Lower bound of the first bucket.
    double low;
    /// Upper bound of the last bucket.
    double high;
    /// Width of one bucket.
    double interval;

    /** Create an empty histogram.
     *
     * Params:
     *  low = smallest value that can be stored
     *  high = largest value that can be stored, must be greater than `low`
     *  nrBuckets = number of buckets, must be greater than one
     */
    this(double low, double high, long nrBuckets)
    in (nrBuckets > 1, "failed nrBuckets > 1")
    in (low < high, "failed low < high") {
        this.low = low;
        this.high = high;
        interval = (high - low) / cast(double) nrBuckets;
        // Allocate exactly the requested number of zeroed buckets. Deriving the
        // count again from `(high - low) / interval` can round up to one bucket
        // too many.
        buckets = new long[](cast(size_t) nrBuckets);
    }

    /// Count `v` in the bucket it falls into. `v` must be in `[low, high]`.
    void put(const double v)
    in (v >= low && v <= high, "v must be in the range [low, high]") {
        const idx = cast(long) floor((v - low) / interval);
        assert(idx >= 0);

        // `v == high` computes an index one past the end; it belongs to the
        // last bucket.
        if (idx < buckets.length)
            buckets[idx] += 1;
        else
            buckets[$ - 1] += 1;
    }

    /// Returns: the one-line description produced by the writer overload.
    string toString() @safe const {
        auto buf = appender!string;
        toString(buf);
        return buf.data;
    }

    /// Write the bounds, the interval and every bucket with its limits and count to `w`.
    void toString(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        // The free function is imported locally because the member `put` hides it.
        import std.range : put;

        formattedWrite(w, "Histogram(low:%s, high:%s, interval:%s, buckets: [",
                low, high, interval);
        foreach (const i; 0 .. buckets.length) {
            if (i != 0)
                put(w, ", ");
            formattedWrite(w, "[%s, %s]:%s", (low + i * interval),
                    (low + (i + 1) * interval), buckets[i]);
        }
        put(w, "])");
    }

    /// Returns: the bar chart produced by the writer overload.
    string toBar() @safe const {
        auto buf = appender!string;
        toBar(buf);
        return buf.data;
    }

    /** Write one line per bucket to `w`: index, limits, a bar of `#` and the count.
     *
     * The bars are scaled down so that the fullest bucket uses at most
     * `maxWidth` characters.
     */
    void toBar(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        import std.range : put;
        import std.range : repeat;

        immutable maxWidth = 42;
        // Number of counted values represented by one `#`.
        const fit = () {
            const m = maxElement(buckets);
            if (m > maxWidth)
                return cast(double) m / cast(double) maxWidth;
            return 1.0;
        }();

        // Explicit cast: an integer argument does not select a unique
        // floating point overload of log10.
        const indexWidth = cast(int) ceil(log10(cast(double) buckets.length) + 1);

        foreach (const i; 0 .. buckets.length) {
            const row = format("[%.3f, %.3f]", (low + i * interval), (low + (i + 1) * interval));
            formattedWrite(w, "%*s %30s: %-(%s%) %s", indexWidth, i, row,
                    repeat("#", cast(size_t)(buckets[i] / fit)), buckets[i]);
            put(w, newline);
        }
    }
}
170 
/** Build a histogram with `nrBuckets` buckets spanning the smallest to the
 * largest sample in `data`, and count every sample in it.
 *
 * If all samples are equal the range would have zero width, which the
 * `Histogram` constructor rejects. In that case the range is widened by 0.5 on
 * each side so the samples land in a bucket. For magnitudes so large that
 * adding 0.5 does not change the value, the constructor contract still applies.
 */
Histogram histogram(StatData data, long nrBuckets) {
    double low = data.value[0];
    double high = data.value[0];
    foreach (const v; data.value) {
        low = min(low, v);
        high = max(high, v);
    }

    if (low == high) {
        low -= 0.5;
        high += 0.5;
    }

    auto hist = Histogram(low, high, nrBuckets);
    foreach (const v; data.value)
        hist.put(v);

    return hist;
}
187 
/// Strongly typed wrapper for a mode estimated from a histogram.
struct Mode {
    double value;
}

/** Estimate the mode as the midpoint of the fullest bucket in `hist`.
 *
 * When several buckets share the highest count the first of them is used.
 * The midpoint is reported for every bucket, including the first one, so the
 * estimate does not jump to the lower bucket edge when bucket 0 is the fullest.
 */
Mode mode(Histogram hist)
in (hist.buckets.length > 0, "histogram has no buckets") {
    size_t fullest = 0;
    foreach (const i; 1 .. hist.buckets.length) {
        // Strict comparison keeps the first bucket among equals.
        if (hist.buckets[i] > hist.buckets[fullest])
            fullest = i;
    }

    return Mode(hist.low + (fullest + 0.5) * hist.interval);
}
204 
/// Bundle of the mean, median and sample standard deviation of one data set.
struct BasicStat {
    Mean mean;
    Median median;
    SampleStdDev sd;

    /// Returns: the text produced by the writer overload.
    string toString() @safe const {
        auto sink = appender!string;
        this.toString(sink);
        return sink.data;
    }

    /// Write the three statistics to `w` on a single line.
    void toString(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        w.formattedWrite("BasicStat(mean:%s, median:%s, stdev: %s)",
                mean.value, median.value, sd.value);
    }
}
221 
/// Compute the mean, median and sample standard deviation of `data`.
BasicStat basicStat(StatData data) {
    auto average = mean(data);
    auto middle = median(data);
    // The standard deviation reuses the mean computed above.
    auto deviation = sampleStdDev(data, average);
    return BasicStat(average, middle, deviation);
}
226 
/// Parameters of a normal distribution.
struct NormDistribution {
    /// Mean of the distribution.
    double mean;
    /// Standard deviation of the distribution.
    double sd;
}
231 
/** Infinite range of random numbers following the normal distribution `nd`.
 *
 * From the C++ standard library implementation.
 *
 * Each round of the generator produces two values; the second one is kept in
 * `V` and handed out by the next call to `popFront`.
 */
struct NormalDistributionPdf {
    NormDistribution nd;
    private double front_;
    private double V;
    private bool Vhot;

    /// The range never runs out of values.
    enum bool empty = false;

    /// Returns: the most recently generated value.
    double front() @safe pure nothrow {
        assert(!empty, "Can't get front of an empty range");
        return front_;
    }

    /// Generate the next value and store it as the new front.
    void popFront() @safe {
        assert(!empty, "Can't pop front of an empty range");

        import std.random : uniform;

        double standard;

        if (!Vhot) {
            double x, y, radiusSq;

            // Draw points in the square [-1, 1) x [-1, 1) until one lies
            // inside the unit circle and is not the origin.
            do {
                x = uniform(-1.0, 1.0);
                y = uniform(-1.0, 1.0);
                radiusSq = x * x + y * y;
            }
            while (radiusSq > 1 || radiusSq == 0);

            const double scale = sqrt(-2.0 * log(radiusSq) / radiusSq);
            // Save the second value for the next call.
            V = y * scale;
            Vhot = true;
            standard = x * scale;
        } else {
            // Use the value saved by the previous round.
            Vhot = false;
            standard = V;
        }

        // Scale and shift to the requested standard deviation and mean.
        front_ = standard * nd.sd + nd.mean;
    }
}
276 
/// Returns: a random number range for `nd` with its first value already generated.
NormalDistributionPdf pdf(NormDistribution nd) {
    NormalDistributionPdf generator = NormalDistributionPdf(nd);
    // Prime the range so that `front` is valid immediately.
    generator.popFront();
    return generator;
}
282 
/** Cumulative distribution function of the normal distribution `nd` at `x`.
 *
 * Params:
 *  nd = the distribution, its standard deviation must be positive
 *  x = the point to evaluate at
 *
 * Returns: 0 for negative infinity, 1 for positive infinity and otherwise
 * `erfc(-(x - mean) / (sd * sqrt(2))) / 2`.
 */
double cdf(NormDistribution nd, double x)
in (nd.sd > 0, "domain error") {
    if (isInfinity(x))
        return x < 0 ? 0.0 : 1.0;

    const scaled = (x - nd.mean) / (nd.sd * SQRT2);
    const upperTail = cast(double) erfc(-scaled);

    return upperTail / 2.0;
}
295 
/// Strongly typed wrapper for the standard error of the mean.
struct StdMeanError {
    double value;
}

/** Estimate the standard error of the mean of `data` by bootstrapping.
 *
 * `max(30, N)` resamples are drawn, where `N` is the number of samples in
 * `data`. The mean of each resample is computed and the sample standard
 * deviation of those means is returned.
 *
 * Each resample is requested with `N` draws and its mean is computed from the
 * number of values actually drawn, so the divisor always matches the sum.
 * Note that the work grows with `N * max(30, N)`.
 */
StdMeanError stdError(StatData data)
in (data.value.length > 1) {
    const len = data.value.length;
    const rounds = max(30, len);

    double[] means;
    means.reserve(rounds);
    foreach (_; 0 .. rounds) {
        auto resample = bootstrap(data, cast(long) len);
        const drawn = cast(double) resample.length;
        means ~= resample.sum / drawn;
    }

    auto resampledMeans = StatData(means);
    return StdMeanError(sampleStdDev(resampledMeans, mean(resampledMeans)).value);
}
311 
/** Draw a random resample, with replacement, from `data`.
 *
 * Params:
 *  data = the samples to draw from, at least two
 *  minSamples = lower bound for the number of values to draw
 *
 * Returns: a lazy range of `max(minSamples, N)` values, where `N` is the
 * number of samples in `data`. The range is lazy: the random index is drawn
 * when an element is read, so reading the same position twice can give
 * different values. Iterate it once, or copy it to an array first.
 */
auto bootstrap(StatData data, long minSamples = 5)
in (minSamples > 0)
in (data.value.length > 1) {
    const len = data.value.length;
    // `minSamples` is a floor for the resample size, not a cap.
    const count = max(cast(size_t) minSamples, len);
    // The upper bound of uniform is exclusive, so `len` makes every index
    // 0 .. len - 1 selectable, including the last sample.
    return iota(count).map!(a => uniform!"[)"(0, len))
        .map!(a => data.value[a]);
}