Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

math.hypot: fix incorrect over/underflow behavior #19472

Merged
merged 10 commits into from
May 30, 2024
Merged

math.hypot: fix incorrect over/underflow behavior #19472

merged 10 commits into from
May 30, 2024

Conversation

expikr
Copy link
Contributor

@expikr expikr commented Mar 29, 2024

This replaces the incorrectly ported musl code and instead implements the algorithm from this paper..

~15% speedup over the old version. Raw sqrt is faster than the new version by ~3x and faster than the old by ~3.5x

Benchmark Code
const std = @import("std");
const math = std.math;
const floatMin = math.floatMin;
const floatMax = math.floatMax;
const floatEps = math.floatEps;
const isNan = math.isNan;
const isInf = math.isInf;
const nan = math.nan;
const inf = math.inf;
const print = std.debug.print;
const expect = std.testing.expect;

pub fn main() !void {
    inline for (.{ f32, f64 }) |T| {
        const scale = floatEpsAt(T, @sqrt(floatEps(T)/2));
        print("\n{any} rescale: {any}, {any}\n", .{ T, scale, 1 / scale });
        print("| ref | new | old | raw |\n", .{});
        inline for ([7][3]T{
            .{ 0.0, -1.2, 1.2 },
            .{ 0.2, -0.34, 0.3944616584663203993612799816649560759946493601889826495362 },
            .{ 0.8923, 2.636890, 2.7837722899152509525110650481670176852603253522923737962880 },
            .{ 1.5, 5.25, 5.4600824169603887033229768686452745953332522619323580787836 },
            .{ 37.45, 159.835, 164.16372840856167640478217141034363907565754072954443805164 },
            .{ 89.123, 382.028905, 392.28687638576315875933966414927490685367196874260165618371 },
            .{ 123123.234375, 529428.707813, 543556.88524707706887251269205923830745438413088753096759371 },
        }) |v| {
            print("| {any} | {any} | {any} | {any} |\n", .{ v[2], hypot(v[0], v[1]), math.hypot(v[0], v[1]), @sqrt(v[0] * v[0] + v[1] * v[1]) });
        }

        try expect(math.isNan(hypot(nan(T), 0.0)));
        try expect(math.isNan(hypot(0.0, nan(T))));

        try expect(math.isPositiveInf(hypot(inf(T), 0.0)));
        try expect(math.isPositiveInf(hypot(0.0, inf(T))));
        try expect(math.isPositiveInf(hypot(inf(T), nan(T))));
        try expect(math.isPositiveInf(hypot(nan(T), inf(T))));

        try expect(math.isPositiveInf(hypot(-inf(T), 0.0)));
        try expect(math.isPositiveInf(hypot(0.0, -inf(T))));
        try expect(math.isPositiveInf(hypot(-inf(T), nan(T))));
        try expect(math.isPositiveInf(hypot(nan(T), -inf(T))));
    }
    print("comptime_float: {any}\n", .{hypot(3.0, 4.0)});
    print("comptime_int: {any}\n", .{hypot(3, 4)});

    var timer = std.time.Timer.start() catch unreachable;
    const side = 30000;

    var old_tot: u64 = 0;
    var new_tot: u64 = 0;
    for (1..10) |_| {
        print("old: ", .{});

        timer.reset();
        for (0..side) |a| {
            for (0..side) |b| {
                _ = std.mem.doNotOptimizeAway(math.hypot(@as(f64, @floatFromInt(a)), @as(f64, @floatFromInt(b))));
            }
        }
        const old_time = timer.read();
        old_tot += old_time;

        const old_avg = @as(f64, @floatFromInt(old_time)) / @as(f64, @floatFromInt(side * side));
        print("{any} ns, avg: {any} ns/calc\n", .{ old_time, old_avg });


        print("new: ", .{});

        timer.reset();
        for (0..side) |a| {
            for (0..side) |b| {
                _ = std.mem.doNotOptimizeAway(hypot(@as(f64, @floatFromInt(a)), @as(f64, @floatFromInt(b))));
            }
        }
        const new_time = timer.read();
        new_tot += new_time;

        const new_avg = @as(f64, @floatFromInt(new_time)) / @as(f64, @floatFromInt(side * side));
        print("{any} ns, avg: {any} ns/calc\n", .{ new_time, new_avg });

        print("old/new={d}%\n", .{100 * old_avg / new_avg});
    }
    print("Overall old/new={d}%\n", .{100 * @as(f64,@floatFromInt(old_tot)) / @as(f64,@floatFromInt(new_tot))});
}

inline fn floatEpsAt(comptime T: type, x: T) T {
    switch (@typeInfo(T)) {
        .Float => |F| {
            const U: type = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = F.bits } });
            const u: U = @bitCast(x);
            const y: T = @bitCast(u ^ 1);
            return @abs(x - y);
        },
        else => @compileError("floatEpsAt only supports floats"),
    }
}

fn Hypot(comptime T: type) type {
    return if (T == comptime_int) comptime_float else T;
}

fn hypot(x: anytype, y: anytype) Hypot(@TypeOf(x, y)) {
    const T = @TypeOf(x, y);
    switch (@typeInfo(T)) {
        .Float => {},
        .ComptimeFloat => return @sqrt(x * x + y * y),
        .ComptimeInt => {
            const a: comptime_float = @floatFromInt(x);
            const b: comptime_float = @floatFromInt(y);
            return @sqrt(a * a + b * b);
        },
        else => @compileError("hypot not implemented for " ++ @typeName(T)),
    }
    const lower = @sqrt(floatMin(T));
    const upper = @sqrt(floatMax(T) / 2);
    const incre = @sqrt(floatEps(T) / 2);
    const scale = floatEpsAt(T, incre);
    var major: T = x;
    var minor: T = y;
    if (isInf(major) or isInf(minor)) return inf(T);
    if (isNan(major) or isNan(minor)) return nan(T);
    major = @abs(major);
    minor = @abs(minor);
    if (minor > major) {
        const tmp = major;
        major = minor;
        minor = tmp;
    }
    if (major * incre >= minor) return major;
    if (major > upper) return hypotFused(T, major * scale, minor * scale) / scale;
    if (minor < lower) return hypotFused(T, major / scale, minor / scale) * scale;
    return hypotFused(T, major, minor);
}

inline fn hypotFused(comptime F: type, x: F, y: F) F {
    var r = @sqrt(@mulAdd(F, x, x, y * y));
    const rr = r * r;
    const xx = x * x;
    const z = @mulAdd(F, -y, y, rr - xx) + @mulAdd(F, r, r, -rr) - @mulAdd(F, x, x, -xx);
    r -= z / (2 * r);
    return r;
}

closes #17915

@expikr expikr changed the title RFC for std.math.hypot: "An Improved Algorithm for hypot(a,b)" RFC for "An Improved Algorithm for hypot(a,b)" Mar 29, 2024
@expikr expikr marked this pull request as draft March 29, 2024 09:01
@expikr expikr changed the title RFC for "An Improved Algorithm for hypot(a,b)" overhaul math.hp Mar 29, 2024
@expikr expikr changed the title overhaul math.hp overhaul math.hypot Mar 29, 2024
@expikr expikr marked this pull request as ready for review March 29, 2024 10:59
@expikr expikr changed the title overhaul math.hypot overhaul math.hypot and add math.floatEpsAt for more rigorous numerical precision testing Mar 30, 2024
@expikr expikr changed the title overhaul math.hypot and add math.floatEpsAt for more rigorous numerical precision testing overhaul math.hypot and add math.floatEpsAt Mar 30, 2024
@expikr expikr changed the title overhaul math.hypot and add math.floatEpsAt math.hypot: fix incorrect over/underflow behavior Apr 24, 2024
@expikr
Copy link
Contributor Author

expikr commented May 23, 2024

@tiehuis mind giving this a quick review?

Copy link
Member

@tiehuis tiehuis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good PR. Keen to merge.

lib/std/math/hypot.zig Show resolved Hide resolved
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
@expikr expikr requested a review from tiehuis May 26, 2024 21:59
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
@expikr expikr requested a review from tiehuis May 30, 2024 07:09
lib/std/math/hypot.zig Outdated Show resolved Hide resolved
@tiehuis tiehuis enabled auto-merge (squash) May 30, 2024 07:42
@tiehuis tiehuis merged commit 103b885 into ziglang:master May 30, 2024
10 checks passed
@expikr expikr deleted the patch-6 branch May 30, 2024 17:09
Rexicon226 pushed a commit to Rexicon226/zig that referenced this pull request Jun 9, 2024
ryoppippi pushed a commit to ryoppippi/zig that referenced this pull request Jul 5, 2024
SammyJames pushed a commit to SammyJames/zig that referenced this pull request Aug 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RFC for std.math.hypot: "An Improved Algorithm for hypot(a,b)"
2 participants