JDK-8214761

Bug in parallel Kahan summation implementation

    Details

    • Subcomponent:
    • Understanding:
      Fix Understood
    • CPU:
      x86_64
    • OS:
      windows_10

      Description

      A DESCRIPTION OF THE PROBLEM :
      DoubleStream.sum and related functions (Collectors.averagingDouble, Collectors.summingDouble and possibly others) all use Kahan summation to reduce numerical error. As I understand it, the implementations of these functions use a double array in which index 0 holds the high-order bits of the running sum and index 1 holds the negation of the low-order bits. The documentation incorrectly states that index 1 holds the low-order bits (with no negation), and the code that combines two running sums incorrectly adds the negation of the low-order bits instead of the low-order bits themselves.
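To illustrate the sign convention, the sketch below reuses the update step from the test code further down (described there as taken unmodified from OpenJDK 8 Collectors) and adds a value that is too small to change the high-order word. The class name KahanSignDemo is hypothetical. After adding 1e-16 to 1.0, index 0 is still 1.0 and index 1 holds -1e-16, i.e. the lost low-order part with its sign flipped.

```java
public class KahanSignDemo {
    // Same update step as sumWithCompensation in the test code:
    // s[0] = high-order running sum, s[1] = compensation term.
    static double[] sumWithCompensation(double[] s, double value) {
        double tmp = value - s[1];
        double sum = s[0];
        double velvel = sum + tmp;
        s[1] = (velvel - sum) - tmp;
        s[0] = velvel;
        return s;
    }

    public static void main(String[] args) {
        double[] s = new double[2];
        sumWithCompensation(s, 1.0);
        // 1e-16 is less than half an ulp of 1.0, so 1.0 + 1e-16 rounds back to 1.0
        sumWithCompensation(s, 1e-16);
        System.out.println(s[0]); // 1.0
        System.out.println(s[1]); // -1.0E-16: the lost low-order part, negated
    }
}
```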

      This problem appears in OpenJDK 8 and 11. I think that in https://hg.openjdk.java.net/jdk/jdk/file/8613f3fdbdae/src/java.base/share/classes/java/util/stream/DoublePipeline.java, line 432 should be changed to

      Collectors.sumWithCompensation(ll, -rr[1]);

      and in https://hg.openjdk.java.net/jdk/jdk/file/8613f3fdbdae/src/java.base/share/classes/java/util/stream/Collectors.java, line 729 should be changed to

      return sumWithCompensation(a, -b[1]); },

      and line 841 should be changed to

      (a, b) -> { sumWithCompensation(a, b[0]); sumWithCompensation(a, -b[1]); a[2] += b[2]; a[3] += b[3]; return a; },
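The effect of the sign can be seen on a tiny input. The sketch below (class and method names are hypothetical; the update step is the same sumWithCompensation as in the test code) sums {1.0, 2^53, 1.0}, whose exact total 2^53 + 2 is representable as a double, using one possible split into partitions {1.0} and {2^53, 1.0}. Adding 1.0 to 2^53 rounds the 1.0 away, so the right-hand partial sum carries a compensation of -1.0. The combiner with the current sign then returns 2^53, while the proposed -rr[1] returns 2^53 + 2, matching sequential Kahan summation of the same values.

```java
import java.util.Arrays;

public class CombinerSignDemo {
    // Kahan update step, same as sumWithCompensation in the test code.
    static double[] sumWithCompensation(double[] s, double value) {
        double tmp = value - s[1];
        double sum = s[0];
        double velvel = sum + tmp;
        s[1] = (velvel - sum) - tmp;
        s[0] = velvel;
        return s;
    }

    // Sequential Kahan summation into a fresh {sum, compensation} pair.
    static double[] kahan(double... values) {
        double[] s = new double[2];
        for (double v : values) {
            sumWithCompensation(s, v);
        }
        return s;
    }

    // Combiner as currently written: adds the right-hand compensation as-is.
    static double[] combineCurrent(double[] ll, double[] rr) {
        double[] a = Arrays.copyOf(ll, 2);
        sumWithCompensation(a, rr[0]);
        sumWithCompensation(a, rr[1]);
        return a;
    }

    // Combiner with the proposed sign flip.
    static double[] combineProposed(double[] ll, double[] rr) {
        double[] a = Arrays.copyOf(ll, 2);
        sumWithCompensation(a, rr[0]);
        sumWithCompensation(a, -rr[1]);
        return a;
    }

    public static void main(String[] args) {
        double big = 0x1p53;              // 2^53: adjacent doubles are 2 apart here
        double[] left = kahan(1.0);       // {1.0, 0.0}
        double[] right = kahan(big, 1.0); // {2^53, -1.0}: the 1.0 was rounded away
        System.out.println(kahan(1.0, big, 1.0)[0]);         // sequential: 2^53 + 2
        System.out.println(combineCurrent(left, right)[0]);  // 2^53
        System.out.println(combineProposed(left, right)[0]); // 2^53 + 2
    }
}
```

In all three results the compensation term ends up as 0.0, so the difference does not depend on how the final sum treats index 1.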

      Alternatively, sumWithCompensation itself could be altered instead of the combiners.
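One possible form of that alternative, sketched here as an assumption rather than as the reporter's exact intent: keep the combiners as they are and change the update step so that index 1 stores the low-order bits themselves, as the documentation describes. The method name sumWithPositiveCompensation is hypothetical. With this convention, adding b[1] in the combiner and adding index 1 in the final sum (as computeFinalSum in the test code does) both have the right sign. The demo reuses the {1.0} / {2^53, 1.0} split.

```java
public class PositiveCompensationDemo {
    // Hypothetical variant of sumWithCompensation in which s[1] holds the
    // low-order bits lost from s[0], not their negation.
    static double[] sumWithPositiveCompensation(double[] s, double value) {
        double tmp = value + s[1];   // fold in previously lost bits
        double sum = s[0];
        double velvel = sum + tmp;
        s[1] = tmp - (velvel - sum); // part of tmp that did not make it into velvel
        s[0] = velvel;
        return s;
    }

    // Combiner in the same shape as the current JDK one: no sign flip needed.
    static double[] combine(double[] a, double[] b) {
        sumWithPositiveCompensation(a, b[0]);
        sumWithPositiveCompensation(a, b[1]);
        return a;
    }

    public static void main(String[] args) {
        double big = 0x1p53;
        double[] left = sumWithPositiveCompensation(new double[2], 1.0);
        double[] right = sumWithPositiveCompensation(
                sumWithPositiveCompensation(new double[2], big), 1.0); // {2^53, 1.0}
        double[] total = combine(left, right);
        System.out.println(total[0] + total[1]); // 2^53 + 2, the exact total
    }
}
```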

      I've attached test code below that compares the result of sum() using the current implementation with the result using the flipped sign. The results of the sequential and naive sums are included for comparison. For each approach, the program prints the sum of squared errors relative to a base case computed with sequential Kahan summation. With the random inputs used, the sum of squared errors is consistently lower with the proposed sign flip.


      ---------- BEGIN SOURCE ----------
      package test;

      import java.util.Random;
      import java.util.stream.DoubleStream;

      public class TestSum {
          
          public static void main(String [] args) {
              double naive = 0;
              double sequentialStream = 0;
              double parallelStream = 0;
              double mySequentialStream = 0;
              double myParallelStream = 0;
              
              for (int loop = 0; loop < 100; loop++) {
                  // sequence of random numbers of varying magnitudes, both positive and negative
                  double[] rand = new Random().doubles(1_000_000)
                          .map(Math::log)
                          .map(x -> (Double.doubleToLongBits(x) % 2 == 0) ? x : -x)
                          .toArray();
                  
                  // base case: standard Kahan summation
                  double[] sum = new double[2];
                  for (int i=0; i < rand.length; i++) {
                      sumWithCompensation(sum, rand[i]);
                  }
                  
                  // squared error of naive sum by reduction - should be large
                  naive += Math.pow(DoubleStream.of(rand).reduce((x, y) -> x+y).getAsDouble() - sum[0], 2);
                  
                  // squared error of sequential sum - should be 0
                  sequentialStream += Math.pow(DoubleStream.of(rand).sum() - sum[0], 2);
                  
                  // squared error of parallel sum
                  parallelStream += Math.pow(DoubleStream.of(rand).parallel().sum() - sum[0], 2);
                  
                  // squared error of modified sequential sum - should be 0
                  mySequentialStream += Math.pow(computeFinalSum(DoubleStream.of(rand).collect(
                          () -> new double[3],
                          (ll, d) -> {
                              sumWithCompensation(ll, d);
                              ll[2] += d;
                          },
                          (ll, rr) -> {
                              sumWithCompensation(ll, rr[0]);
                              sumWithCompensation(ll, -rr[1]); // minus is added
                              ll[2] += rr[2];
                          })) - sum[0], 2);
                  
                  // squared error of modified parallel sum - typically ~0.25-0.5 times squared error of parallel sum
                  myParallelStream += Math.pow(computeFinalSum(DoubleStream.of(rand).parallel().collect(
                          () -> new double[3],
                          (ll, d) -> {
                              sumWithCompensation(ll, d);
                              ll[2] += d;
                          },
                          (ll, rr) -> {
                              sumWithCompensation(ll, rr[0]);
                              sumWithCompensation(ll, -rr[1]); // minus is added
                              ll[2] += rr[2];
                          })) - sum[0], 2);
              }
              
              // print sum of squared errors
              System.out.println(naive);
              System.out.println(sequentialStream);
              System.out.println(parallelStream);
              System.out.println(mySequentialStream);
              System.out.println(myParallelStream);
          }
          
          // from OpenJDK8 Collectors, unmodified
          static double[] sumWithCompensation(double[] intermediateSum, double value) {
              double tmp = value - intermediateSum[1];
              double sum = intermediateSum[0];
              double velvel = sum + tmp; // Little wolf of rounding error
              intermediateSum[1] = (velvel - sum) - tmp;
              intermediateSum[0] = velvel;
              return intermediateSum;
          }
          
          // from OpenJDK8 Collectors, unmodified
          static double computeFinalSum(double[] summands) {
              double tmp = summands[0] + summands[1];
              double simpleSum = summands[summands.length - 1];
              if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
                  return simpleSum;
              else
                  return tmp;
          }

      }

      ---------- END SOURCE ----------

      FREQUENCY : always


              People

              • Assignee:
                igerasim Ivan Gerasimov
                Reporter:
                webbuggrp Webbug Group
              • Votes:
                0
                Watchers:
                3
