001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      https://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import java.util.function.DoubleSupplier;
021import org.apache.commons.numbers.gamma.Erf;
022import org.apache.commons.numbers.gamma.ErfDifference;
023import org.apache.commons.numbers.gamma.Erfcx;
024import org.apache.commons.rng.UniformRandomProvider;
025import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
026
027/**
028 * Implementation of the truncated normal distribution.
029 *
030 * <p>The probability density function of \( X \) is:
031 *
032 * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \]
033 *
034 * <p>for \( \mu \) mean of the parent normal distribution,
035 * \( \sigma \) standard deviation of the parent normal distribution,
036 * \( -\infty \le a \lt b \le \infty \) the truncation interval, and
037 * \( x \in [a, b] \), where \( \phi \) is the probability
038 * density function of the standard normal distribution and \( \Phi \)
039 * is its cumulative distribution function.
040 *
041 * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">
042 * Truncated normal distribution (Wikipedia)</a>
043 */
044public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {
045
046    /** The max allowed value for x where (x*x) will not overflow.
047     * This is a limit on computation of the moments of the truncated normal
048     * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */
049    private static final double MAX_X = 0x1.fffffffffffffp511;
050
051    /** The min allowed probability range of the parent normal distribution.
052     * Set to 0.0. This may be too low for accurate usage. It is a signal that
053     * the truncation is invalid. */
054    private static final double MIN_P = 0.0;
055
056    /** sqrt(2). */
057    private static final double ROOT2 = Constants.ROOT_TWO;
058    /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */
059    private static final double ROOT_2_PI = Constants.ROOT_TWO_DIV_PI;
060    /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
061    private static final double ROOT_PI_2 = Constants.ROOT_PI_DIV_TWO;
062
063    /**
064     * The threshold to switch to a rejection sampler. When the truncated
065     * distribution covers more than this fraction of the CDF then rejection
066     * sampling will be more efficient than inverse CDF sampling. Performance
067     * benchmarks indicate that a normalized Gaussian sampler is up to 10 times
068     * faster than inverse transform sampling using a fast random generator. See
069     * STATISTICS-55.
070     */
071    private static final double REJECTION_THRESHOLD = 0.2;
072
073    /** Parent normal distribution. */
074    private final NormalDistribution parentNormal;
075    /** Lower bound of this distribution. */
076    private final double lower;
077    /** Upper bound of this distribution. */
078    private final double upper;
079
080    /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to
081     * normalise the probability computations. */
082    private final double cdfDelta;
083    /** log(cdfDelta). */
084    private final double logCdfDelta;
085    /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map
086     * a probability into the range of the parent normal distribution. */
087    private final double cdfAlpha;
088    /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map
089     * a probability into the range of the parent normal distribution. */
090    private final double sfBeta;
091
092    /**
093     * @param parent Parent distribution.
094     * @param z Probability of the parent distribution for {@code [lower, upper]}.
095     * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
096     * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
097     */
098    private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
099        this.parentNormal = parent;
100        this.lower = lower;
101        this.upper = upper;
102
103        cdfDelta = z;
104        logCdfDelta = Math.log(cdfDelta);
105        // Used to map the inverse probability.
106        cdfAlpha = parentNormal.cumulativeProbability(lower);
107        sfBeta = parentNormal.survivalProbability(upper);
108    }
109
110    /**
111     * Creates a truncated normal distribution.
112     *
113     * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution,
114     * and not the true mean and standard deviation of the truncated normal distribution.
115     * The {@code lower} and {@code upper} bounds define the truncation of the parent
116     * normal distribution.
117     *
118     * @param mean Mean for the parent distribution.
119     * @param sd Standard deviation for the parent distribution.
120     * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
121     * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
122     * @return the distribution
123     * @throws IllegalArgumentException if {@code sd <= 0}; if {@code lower >= upper}; or if
124     * the truncation covers no probability range in the parent distribution.
125     */
126    public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
127        if (sd <= 0) {
128            throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
129        }
130        if (lower >= upper) {
131            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper);
132        }
133
134        // Use an instance for the parent normal distribution to maximise accuracy
135        // in range computations using the error function
136        final NormalDistribution parent = NormalDistribution.of(mean, sd);
137
138        // If there is no computable range then raise an exception.
139        final double z = parent.probability(lower, upper);
140        if (z <= MIN_P) {
141            // Map the bounds to a standard normal distribution for the message
142            final double a = (lower - mean) / sd;
143            final double b = (upper - mean) / sd;
144            throw new DistributionException(
145                "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
146        }
147
148        // Here we have a meaningful truncation. Note that excess truncation may not be optimal.
149        // For example truncation close to zero where the PDF is constant can be approximated
150        // using a uniform distribution.
151
152        return new TruncatedNormalDistribution(parent, z, lower, upper);
153    }
154
155    /**
156     * Gets the mean for the parent distribution.
157     *
158     * <p>Note that the mean is of the parent normal distribution,
159     * and not the true mean of the truncated normal distribution.
160     * This is the {@code mean} parameter used to construct the truncated distribution.
161     *
162     * @return the parent mean.
163     * @see #getMean
164     * @since 1.3
165     */
166    public double getParentMean() {
167        return parentNormal.getMean();
168    }
169
170    /**
171     * Gets the standard deviation for the parent distribution.
172     *
173     * <p>Note that the standard deviation (SD) is of the parent normal distribution,
174     * and not the true standard deviation of the truncated normal distribution.
175     * This is the {@code sd} parameter used to construct the truncated distribution.
176     *
177     * @return the parent standard deviation.
178     * @since 1.3
179     */
180    public double getParentStandardDeviation() {
181        return parentNormal.getStandardDeviation();
182    }
183
184    /** {@inheritDoc} */
185    @Override
186    public double density(double x) {
187        if (x < lower || x > upper) {
188            return 0;
189        }
190        return parentNormal.density(x) / cdfDelta;
191    }
192
193    /** {@inheritDoc} */
194    @Override
195    public double probability(double x0, double x1) {
196        if (x0 > x1) {
197            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
198                                            x0, x1);
199        }
200        return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
201    }
202
203    /** {@inheritDoc} */
204    @Override
205    public double logDensity(double x) {
206        if (x < lower || x > upper) {
207            return Double.NEGATIVE_INFINITY;
208        }
209        return parentNormal.logDensity(x) - logCdfDelta;
210    }
211
212    /** {@inheritDoc} */
213    @Override
214    public double cumulativeProbability(double x) {
215        if (x <= lower) {
216            return 0;
217        } else if (x >= upper) {
218            return 1;
219        }
220        return parentNormal.probability(lower, x) / cdfDelta;
221    }
222
223    /** {@inheritDoc} */
224    @Override
225    public double survivalProbability(double x) {
226        if (x <= lower) {
227            return 1;
228        } else if (x >= upper) {
229            return 0;
230        }
231        return parentNormal.probability(x, upper) / cdfDelta;
232    }
233
234    /** {@inheritDoc} */
235    @Override
236    public double inverseCumulativeProbability(double p) {
237        ArgumentUtils.checkProbability(p);
238        // Exact bound
239        if (p == 0) {
240            return lower;
241        } else if (p == 1) {
242            return upper;
243        }
244        // Linearly map p to the range [lower, upper]
245        final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
246        return clipToRange(x);
247    }
248
249    /** {@inheritDoc} */
250    @Override
251    public double inverseSurvivalProbability(double p) {
252        ArgumentUtils.checkProbability(p);
253        // Exact bound
254        if (p == 1) {
255            return lower;
256        } else if (p == 0) {
257            return upper;
258        }
259        // Linearly map p to the range [lower, upper]
260        final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
261        return clipToRange(x);
262    }
263
264    /** {@inheritDoc} */
265    @Override
266    public Sampler createSampler(UniformRandomProvider rng) {
267        // Map the bounds to a standard normal distribution
268        final double u = parentNormal.getMean();
269        final double s = parentNormal.getStandardDeviation();
270        final double a = (lower - u) / s;
271        final double b = (upper - u) / s;
272        // If the truncation covers a reasonable amount of the normal distribution
273        // then a rejection sampler can be used.
274        double threshold = REJECTION_THRESHOLD;
275        // If the truncation is entirely in the upper or lower half then adjust the
276        // threshold as twice the samples can be used
277        if (a >= 0 || b <= 0) {
278            threshold *= 0.5;
279        }
280
281        if (cdfDelta > threshold) {
282            // Create the rejection sampler
283            final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
284            final DoubleSupplier gen;
285            // Use mirroring if possible
286            if (a >= 0) {
287                // Return the upper-half of the Gaussian
288                gen = () -> Math.abs(sampler.sample());
289            } else if (b <= 0) {
290                // Return the lower-half of the Gaussian
291                gen = () -> -Math.abs(sampler.sample());
292            } else {
293                // Return the full range of the Gaussian
294                gen = sampler::sample;
295            }
296            // Sample in [a, b] using rejection
297            return () -> {
298                double x = gen.getAsDouble();
299                while (x < a || x > b) {
300                    x = gen.getAsDouble();
301                }
302                // Avoid floating-point error when mapping back
303                return clipToRange(u + x * s);
304            };
305        }
306
307        // Default to an inverse CDF sampler
308        return super.createSampler(rng);
309    }
310
311    /**
312     * {@inheritDoc}
313     *
314     * <p>Represents the true mean of the truncated normal distribution rather
315     * than the parent normal distribution mean.
316     *
317     * <p>For \( \mu \) mean of the parent normal distribution,
318     * \( \sigma \) standard deviation of the parent normal distribution, and
319     * \( a \lt b \) the truncation interval of the parent normal distribution, the mean is:
320     *
321     * <p>\[ \mu + \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)}\sigma \]
322     *
323     * <p>where \( \phi \) is the probability density function of the standard normal distribution
324     * and \( \Phi \) is its cumulative distribution function.
325     */
326    @Override
327    public double getMean() {
328        final double u = parentNormal.getMean();
329        final double s = parentNormal.getStandardDeviation();
330        final double a = (lower - u) / s;
331        final double b = (upper - u) / s;
332        return u + moment1(a, b) * s;
333    }
334
335    /**
336     * {@inheritDoc}
337     *
338     * <p>Represents the true variance of the truncated normal distribution rather
339     * than the parent normal distribution variance.
340     *
341     * <p>For \( \mu \) mean of the parent normal distribution,
342     * \( \sigma \) standard deviation of the parent normal distribution, and
343     * \( a \lt b \) the truncation interval of the parent normal distribution, the variance is:
344     *
345     * <p>\[ \sigma^2 \left[1 + \frac{a\phi(a)-b\phi(b)}{\Phi(b) - \Phi(a)} -
346     *       \left( \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)} \right)^2 \right] \]
347     *
348     * <p>where \( \phi \) is the probability density function of the standard normal distribution
349     * and \( \Phi \) is its cumulative distribution function.
350     */
351    @Override
352    public double getVariance() {
353        final double u = parentNormal.getMean();
354        final double s = parentNormal.getStandardDeviation();
355        final double a = (lower - u) / s;
356        final double b = (upper - u) / s;
357        return variance(a, b) * s * s;
358    }
359
360    /**
361     * {@inheritDoc}
362     *
363     * <p>The lower bound of the support is equal to the lower bound parameter
364     * of the distribution.
365     */
366    @Override
367    public double getSupportLowerBound() {
368        return lower;
369    }
370
371    /**
372     * {@inheritDoc}
373     *
374     * <p>The upper bound of the support is equal to the upper bound parameter
375     * of the distribution.
376     */
377    @Override
378    public double getSupportUpperBound() {
379        return upper;
380    }
381
382    /**
383     * Clip the value to the range [lower, upper].
384     * This is used to handle floating-point error at the support bound.
385     *
386     * @param x Value x
387     * @return x clipped to the range
388     */
389    private double clipToRange(double x) {
390        return clip(x, lower, upper);
391    }
392
393    /**
394     * Clip the value to the range [lower, upper].
395     *
396     * @param x Value x
397     * @param lower Lower bound (inclusive)
398     * @param upper Upper bound (inclusive)
399     * @return x clipped to the range
400     */
401    private static double clip(double x, double lower, double upper) {
402        if (x <= lower) {
403            return lower;
404        }
405        return x < upper ? x : upper;
406    }
407
408    // Calculation of variance and mean can suffer from cancellation.
409    //
410    // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the
411    // terms of the MIT "Expat" License (see NOTICE and LICENSE).
412    //
413    // These formulas use the complementary error function
414    //   erfcx(z) = erfc(z) * exp(z^2)
415    // This avoids computation of exp terms for the Gaussian PDF and then
416    // dividing by the error functions erf or erfc:
417    //   exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2))
418    // At large z the erfcx function is computable but exp(-0.5*z*z) and
419    // erfc(z) are zero. Use of these formulas allows computation of the
420    // mean and variance for the usable range of the truncated distribution
421    // (cdf(a, b) != 0). The variance is not accurate when it approaches
422    // machine epsilon (2^-52) at extremely narrow truncations and the
423    // computation -> 0.
424    //
425    // See: https://github.com/cossio/TruncatedNormal.jl
426
427    /**
428     * Compute the first moment (mean) of the truncated standard normal distribution.
429     *
430     * <p>Assumes {@code a <= b}.
431     *
432     * @param a Lower bound
433     * @param b Upper bound
434     * @return the first moment
435     */
436    static double moment1(double a, double b) {
437        // Assume a <= b
438        if (a == b) {
439            return a;
440        }
441        if (Math.abs(a) > Math.abs(b)) {
442            // Subtract from zero to avoid generating -0.0
443            return 0 - moment1(-b, -a);
444        }
445
446        // Here:
447        // |a| <= |b|
448        // a < b
449        // 0 < b
450
451        if (a <= -MAX_X) {
452            // No truncation
453            return 0;
454        }
455        if (b >= MAX_X) {
456            // One-sided truncation
457            return ROOT_2_PI / Erfcx.value(a / ROOT2);
458        }
459
460        // pdf = exp(-0.5*x*x) / sqrt(2*pi)
461        // cdf = erfc(-x/sqrt(2)) / 2
462        // Compute:
463        // -(pdf(b) - pdf(a)) / cdf(b, a)
464        // Note:
465        // exp(-0.5*b*b) - exp(-0.5*a*a)
466        // Use cancellation of powers:
467        // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a)
468        // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a)
469
470        // dx = -0.5*(b*b-a*a)
471        final double dx = 0.5 * (b + a) * (b - a);
472        final double m;
473        if (a <= 0) {
474            // Opposite signs
475            m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
476        } else {
477            final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
478            if (z == 0) {
479                // Occurs when a and b have large magnitudes and are very close
480                return (a + b) * 0.5;
481            }
482            m = ROOT_2_PI * Math.expm1(-dx) / z;
483        }
484
485        // Clip to the range
486        return clip(m, a, b);
487    }
488
489    /**
490     * Compute the second moment of the truncated standard normal distribution.
491     *
492     * <p>Assumes {@code a <= b}.
493     *
494     * @param a Lower bound
495     * @param b Upper bound
496     * @return the first moment
497     */
498    private static double moment2(double a, double b) {
499        // Assume a < b.
500        // a == b is handled in the variance method
501        if (Math.abs(a) > Math.abs(b)) {
502            return moment2(-b, -a);
503        }
504
505        // Here:
506        // |a| <= |b|
507        // a < b
508        // 0 < b
509
510        if (a <= -MAX_X) {
511            // No truncation
512            return 1;
513        }
514        if (b >= MAX_X) {
515            // One-sided truncation.
516            // For a -> inf : moment2 -> a*a
517            // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms
518            // cancel. z > 6.71e7, a > 9.49e7
519            return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
520        }
521
522        // pdf = exp(-0.5*x*x) / sqrt(2*pi)
523        // cdf = erfc(-x/sqrt(2)) / 2
524        // Compute:
525        // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a)
526        // = (cdf(b, a) - b*pdf(b) -a*pdf(a)) / cdf(b, a)
527
528        // Note:
529        // For z -> 0:
530        //   sqrt(pi / 2) * erf(z / sqrt(2)) -> z
531        //   z * Math.exp(-0.5 * z * z) -> z
532        // Both computations below have cancellation as b -> 0 and the
533        // second moment is not computable as the fraction P/Q
534        // since P < ulp(Q). This always occurs when b < MIN_X
535        // if MIN_X is set at the point where
536        //   exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi).
537        // This is JDK dependent due to variations in Math.exp.
538        // For b < MIN_X the second moment can be approximated using
539        // a uniform distribution: (b^3 - a^3) / (3b - 3a).
540        // In practice it also occurs when b > MIN_X since any a < MIN_X
541        // is effectively zero for part of the computation. A
542        // threshold to transition to a uniform distribution
543        // approximation is a compromise. Also note it will not
544        // correct computation when (b-a) is small and is far from 0.
545        // Thus the second moment is left to be inaccurate for
546        // small ranges (b-a) and the variance -> 0 when the true
547        // variance is close to or below machine epsilon.
548
549        double m;
550
551        if (a <= 0) {
552            // Opposite signs
553            final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
554            final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
555            final double fa = ea - a * Math.exp(-0.5 * a * a);
556            final double fb = eb - b * Math.exp(-0.5 * b * b);
557            // Assume fb >= fa && eb >= ea
558            // If fb <= fa this is a tiny range around 0
559            m = (fb - fa) / (eb - ea);
560            // Clip to the range
561            m = clip(m, 0, 1);
562        } else {
563            final double dx = 0.5 * (b + a) * (b - a);
564            final double ex = Math.exp(-dx);
565            final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
566            final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
567            final double fa = ea + a;
568            final double fb = eb + b;
569            m = (fa - fb * ex) / (ea - eb * ex);
570            // Clip to the range
571            m = clip(m, a * a, b * b);
572        }
573        return m;
574    }
575
576    /**
577     * Compute the variance of the truncated standard normal distribution.
578     *
579     * <p>Assumes {@code a <= b}.
580     *
581     * @param a Lower bound
582     * @param b Upper bound
583     * @return the first moment
584     */
585    static double variance(double a, double b) {
586        if (a == b) {
587            return 0;
588        }
589
590        final double m1 = moment1(a, b);
591        double m2 = moment2(a, b);
592        // variance = m2 - m1*m1
593        // rearrange x^2 - y^2 as (x-y)(x+y)
594        m2 = Math.sqrt(m2);
595        final double variance = (m2 - m1) * (m2 + m1);
596
597        // Detect floating-point error.
598        if (variance >= 1) {
599            // Note:
600            // Extreme truncations in the tails can compute a variance above 1,
601            // for example if m2 is infinite: m2 - m1*m1 > 1
602            // Detect no truncation as the terms a and b lie far either side of zero;
603            // otherwise return 0 to indicate very small unknown variance.
604            return a < -1 && b > 1 ? 1 : 0;
605        } else if (variance <= 0) {
606            // Floating-point error can create negative variance so return 0.
607            return 0;
608        }
609
610        return variance;
611    }
612}