ziglang / zig

General-purpose programming language and toolchain for maintaining robust, optimal, and reusable software.
https://ziglang.org
MIT License

Optimise stringToEnum #3863

Open daurnimator opened 4 years ago

daurnimator commented 4 years ago

The current implementation of stringToEnum is:

```zig
pub fn stringToEnum(comptime T: type, str: []const u8) ?T {
    inline for (@typeInfo(T).Enum.fields) |enumField| {
        if (std.mem.eql(u8, str, enumField.name)) {
            return @field(T, enumField.name);
        }
    }
    return null;
}
```

This could be much more efficient if it used a perfect hash, which is one motivation for a perfect hashing algorithm to be in the standard library.

FWIW, my current use case for this is converting known HTTP field names to an enum.

data-man commented 4 years ago

> perfect hashing algorithm to be in the standard library

Based on https://andrewkelley.me/post/string-matching-comptime-perfect-hashing-zig.html

perfecthash.zig

```zig
const std = @import("std");
const assert = std.debug.assert;

pub fn perfectHash(comptime strs: []const []const u8) type {
    const Op = union(enum) {
        /// add the length of the string
        Length,

        /// add the byte at index % len
        Index: usize,

        /// right shift then xor with constant
        XorShiftMultiply: u32,
    };
    const S = struct {
        fn hash(comptime plan: []Op, s: []const u8) u32 {
            var h: u32 = 0;
            inline for (plan) |op| {
                switch (op) {
                    Op.Length => {
                        h +%= @truncate(u32, s.len);
                    },
                    Op.Index => |index| {
                        h +%= s[index % s.len];
                    },
                    Op.XorShiftMultiply => |x| {
                        h ^= x >> 16;
                    },
                }
            }
            return h;
        }

        fn testPlan(comptime plan: []Op) bool {
            var hit = [1]bool{false} ** strs.len;
            for (strs) |s| {
                const h = hash(plan, s);
                const i = h % hit.len;
                if (hit[i]) {
                    // hit this index twice
                    return false;
                }
                hit[i] = true;
            }
            return true;
        }
    };

    var ops_buf: [10]Op = undefined;

    const plan = have_a_plan: {
        var seed: u32 = 0x45d9f3b;
        var index_i: usize = 0;
        const try_seed_count = 50;
        const try_index_count = 50;

        while (index_i < try_index_count) : (index_i += 1) {
            const bool_values = if (index_i == 0) [_]bool{true} else [_]bool{ false, true };
            for (bool_values) |try_length| {
                var seed_i: usize = 0;
                while (seed_i < try_seed_count) : (seed_i += 1) {
                    comptime var rand_state = std.rand.Xoroshiro128.init(seed + seed_i);
                    const rng = &rand_state.random;

                    var ops_index = 0;

                    if (try_length) {
                        ops_buf[ops_index] = Op.Length;
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];
                    }

                    ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                    ops_index += 1;
                    if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                    const before_bytes_it_index = ops_index;

                    var byte_index = 0;
                    while (byte_index < index_i) : (byte_index += 1) {
                        ops_index = before_bytes_it_index;
                        ops_buf[ops_index] = Op{ .Index = rng.scalar(u32) % try_index_count };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];
                    }
                }
            }
        }
        @compileError("unable to come up with perfect hash");
    };

    return struct {
        pub fn case(comptime s: []const u8) usize {
            inline for (strs) |str| {
                if (std.mem.eql(u8, str, s)) return hash(s);
            }
            @compileError("case value '" ++ s ++ "' not declared");
        }

        pub fn hash(s: []const u8) usize {
            const ok = for (strs) |str| {
                if (std.mem.eql(u8, str, s)) break true;
            } else false;
            if (ok) {
                return S.hash(plan, s) % strs.len;
            } else {
                return S.hash(plan, s);
            }
        }
    };
}
```

test_pf.zig

```zig
const std = @import("std");
const perfectHash = @import("perfecthash.zig").perfectHash;
const assert = std.debug.assert;

test "perfect hashing" {
    basedOnLength("ab");
}

fn basedOnLength(target: []const u8) void {
    const ph = perfectHash(&[_][]const u8{
        "a",
        "ab",
        "abc",
    });
    switch (ph.hash(target)) {
        ph.case("a") => @panic("wrong one a\n"),
        ph.case("ab") => {}, // test pass
        ph.case("abc") => @panic("wrong one abc\n"),
        else => std.debug.warn("not found\n"),
    }
}

test "perfect hashing 2" {
    @setEvalBranchQuota(100000);
    const target = "eno";
    const ph = perfectHash(&[_][]const u8{
        "one", "eno", "two", "three", "four", "five",
    });
    switch (ph.hash(target)) {
        ph.case("one") => std.debug.warn("handle the one case\n"),
        ph.case("eno") => std.debug.warn("handle the eno case\n"),
        ph.case("two") => std.debug.warn("handle the two case\n"),
        ph.case("three") => std.debug.warn("handle the three case\n"),
        ph.case("four") => std.debug.warn("handle the four case\n"),
        ph.case("five") => std.debug.warn("handle the five case\n"),
        else => std.debug.warn("not found\n"),
    }
}

test "perfect hashing 3" {
    @setEvalBranchQuota(100000);
    const target = "six";
    const ph = perfectHash(&[_][]const u8{
        "one", "eno", "two", "three", "four", "five",
    });
    switch (ph.hash(target)) {
        ph.case("one") => std.debug.warn("handle the one case\n"),
        ph.case("eno") => std.debug.warn("handle the eno case\n"),
        ph.case("two") => std.debug.warn("handle the two case\n"),
        ph.case("three") => std.debug.warn("handle the three case\n"),
        ph.case("four") => std.debug.warn("handle the four case\n"),
        ph.case("five") => std.debug.warn("handle the five case\n"),
        else => std.debug.warn("{} not found\n", target),
    }
}
```

frmdstryr commented 4 years ago

I think it would be awesome to have this built into the language, e.g.:

```zig
switch (target) : (hashFn) { // Or some way to set which hash fn to use at comptime
    "one" => ...,
    "two" => ...,
    // etc...
}
```

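The `switch (target) : (hashFn)` syntax above is only a proposal. A rough approximation that is possible without new syntax, sketched here assuming a recent Zig release where `std.meta.stringToEnum` is available, is to declare the strings as an enum and switch on the lookup result. `Keyword` and `classify` are hypothetical names for illustration; the string comparison cost still lives inside `stringToEnum`, which is exactly what this issue wants to optimise.

```zig
const std = @import("std");

// Hypothetical enum listing the strings we want to "switch" on.
const Keyword = enum { one, two, three };

// Maps a string to a small integer, or null for unknown strings.
// stringToEnum does the string comparisons; the switch itself is on the enum,
// so the compiler still checks it for exhaustiveness.
fn classify(target: []const u8) ?u8 {
    const kw = std.meta.stringToEnum(Keyword, target) orelse return null;
    return switch (kw) {
        .one => 1,
        .two => 2,
        .three => 3,
    };
}
```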
data-man commented 4 years ago

Some useful links:

BBHash, go-bbhash and rust-boomphf, which are based on the paper "Fast and scalable minimal perfect hashing for massive key sets".

rust-phf uses the CHD algorithm.

N00byEdge commented 4 years ago

Or construct a compile-time trie...

data-man commented 4 years ago

Attempt №2

perfectHash for any type

```zig
const std = @import("std");
const mem = std.mem;
const warn = std.debug.warn;

pub fn perfectHash(comptime T: type, comptime cases: []const T) type {
    const Op = union(enum) {
        /// add the length of the string
        Length,

        /// add the byte at index % len
        Index: usize,

        /// right shift then xor with constant
        XorShiftMultiply: u32,
    };
    const S = struct {
        fn hash(comptime plan: []Op, s: []const u8) u32 {
            var h: u32 = 0;
            inline for (plan) |op| {
                switch (op) {
                    Op.Length => {
                        h +%= @truncate(u32, s.len);
                    },
                    Op.Index => |index| {
                        h +%= s[index % s.len];
                    },
                    Op.XorShiftMultiply => |x| {
                        h ^= x >> 16;
                    },
                }
            }
            return h;
        }

        fn testPlan(comptime plan: []Op) bool {
            comptime var hit = [1]bool{false} ** cases.len;
            for (cases) |c| {
                const b = mem.toBytes(c);
                const h = hash(plan, b[0..]);
                const i = h % hit.len;
                if (hit[i]) {
                    // hit this index twice
                    return false;
                }
                hit[i] = true;
            }
            return true;
        }
    };

    var ops_buf: [10]Op = undefined;

    const plan = have_a_plan: {
        var seed: u32 = 0x45d9f3b;
        var index_i: usize = 0;
        const try_seed_count = 50;
        const try_index_count = 50;

        @setEvalBranchQuota(50000);

        while (index_i < try_index_count) : (index_i += 1) {
            const bool_values = if (index_i == 0) [_]bool{true} else [_]bool{ false, true };
            for (bool_values) |try_length| {
                var seed_i: usize = 0;
                while (seed_i < try_seed_count) : (seed_i += 1) {
                    comptime var rand_state = std.rand.Xoroshiro128.init(seed + seed_i);
                    const rng = &rand_state.random;

                    var ops_index = 0;

                    if (try_length) {
                        ops_buf[ops_index] = Op.Length;
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];
                    }

                    ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                    ops_index += 1;
                    if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                    const before_bytes_it_index = ops_index;

                    var byte_index = 0;
                    while (byte_index < index_i) : (byte_index += 1) {
                        ops_index = before_bytes_it_index;
                        ops_buf[ops_index] = Op{ .Index = rng.scalar(u32) % try_index_count };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;
                        if (S.testPlan(ops_buf[0..ops_index])) break :have_a_plan ops_buf[0..ops_index];
                    }
                }
            }
        }
        @compileError("unable to come up with perfect hash");
    };

    return struct {
        pub fn case(comptime c: T) usize {
            inline for (cases) |c2| {
                if (std.meta.eql(c, c2)) return hash(c);
            }
            @compileLog("case value ", c, " not declared!");
        }

        pub fn hash(c: T) usize {
            const ok = for (cases) |c2| {
                if (std.meta.eql(c, c2)) break true;
            } else false;
            const b = mem.toBytes(c);
            if (ok) {
                return S.hash(plan, b[0..]) % cases.len;
            } else {
                return S.hash(plan, b[0..]);
            }
        }
    };
}

test "perfect hashing v2" {
    const target1 = 3;
    const target2 = 30;
    const ph = perfectHash(u16, &[_]u16{ 1, 2, 3, 4, 5, 6 });
    switch (ph.hash(target1)) {
        ph.case(1) => warn("handle the {} case\n", .{target1}),
        ph.case(2) => warn("handle the {} case\n", .{target1}),
        ph.case(3) => warn("handle the {} case\n", .{target1}),
        ph.case(4) => warn("handle the {} case\n", .{target1}),
        ph.case(5) => warn("handle the {} case\n", .{target1}),
        ph.case(6) => warn("handle the {} case\n", .{target1}),
        else => warn("case {} not found\n", .{target1}),
    }
    switch (ph.hash(target2)) {
        ph.case(1) => warn("handle the {} case\n", .{target2}),
        ph.case(2) => warn("handle the {} case\n", .{target2}),
        ph.case(3) => warn("handle the {} case\n", .{target2}),
        ph.case(4) => warn("handle the {} case\n", .{target2}),
        ph.case(5) => warn("handle the {} case\n", .{target2}),
        ph.case(6) => warn("handle the {} case\n", .{target2}),
        else => warn("case {} not found\n", .{target2}),
    }
}
```

My goals: autoPerfectHash and autoPerfectHashMap.

daurnimator commented 4 years ago

```zig
const ok = for (cases) |c2| {
    if (std.meta.eql(c, c2))
        break true;
} else false;
```

Isn't this the expensive comparison that perfect hashing is meant to avoid?

data-man commented 4 years ago

All questions to the creator. :smile:

I hope this loop is executed only at comptime.

squeek502 commented 4 years ago

> I hope this loop is executed only at comptime.

In Andrew's blog post, that loop is wrapped in `if (std.debug.runtime_safety)`:

```zig
if (std.debug.runtime_safety) {
    const ok = for (strs) |str| {
        if (std.mem.eql(u8, str, s))
            break true;
    } else false;
    if (!ok) {
        std.debug.panic("attempt to perfect hash {} which was not declared", s);
    }
}
```

That is, the check is not included in ReleaseFast/ReleaseSmall builds.

daurnimator commented 4 years ago

For stringToEnum you wouldn't want that loop in there either: return null when the element is not in the perfect hash.

squeek502 commented 4 years ago

> return null when the element is not in the perfect hash.

Is this feasible? How could the perfect hash know when something is not one of the original set of values?

As an example, if it turns out that `in` and `not_in` both hash to the same value with the chosen perfect hashing algorithm, wouldn't it just treat `not_in` the same as `in`? How could it know to return null for `not_in` instead?

daurnimator commented 4 years ago

> Is this feasible? How could the perfect hash know when something is not one of the original set of values?

Once you hash to a given element, you then verify that the input matches that member.
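A minimal sketch of "hash, then verify", assuming a tiny hypothetical set of HTTP field names (matched case-sensitively for brevity, unlike real HTTP) and a hand-picked hash, length plus first byte modulo 8, that happens to be collision-free for these four keys. A real implementation would search for such a function at compile time; `Field`, `table_size`, `slot`, `table` and `lookup` are illustrative names, not std APIs. The point is that a non-member such as `"Hosx"` can land in an occupied slot, and the single `mem.eql` afterwards is what rejects it.

```zig
const std = @import("std");

// Hypothetical enum of known HTTP field names.
const Field = enum { Host, Date, Accept, Cookie };

// 8 slots for 4 keys: perfect (no collisions) but not minimal.
const table_size = 8;

// Illustrative hash only: length plus first byte. Requires a non-empty string.
fn slot(s: []const u8) usize {
    return (s.len + s[0]) % table_size;
}

// Built at compile time; a collision becomes a compile error, which is how
// this sketch guarantees the hash is perfect for the declared keys.
const table: [table_size]?Field = blk: {
    var t = [_]?Field{null} ** table_size;
    inline for (std.meta.fields(Field)) |f| {
        const i = slot(f.name);
        if (t[i] != null) @compileError("hash collision on " ++ f.name);
        t[i] = @field(Field, f.name);
    }
    break :blk t;
};

fn lookup(s: []const u8) ?Field {
    if (s.len == 0) return null;
    // Step 1: one hash picks the only slot the input could occupy.
    const candidate = table[slot(s)] orelse return null;
    // Step 2: verify, because a non-member can hash to an occupied slot.
    if (!std.mem.eql(u8, s, @tagName(candidate))) return null;
    return candidate;
}
```

Lookup costs one hash plus at most one string comparison, regardless of how many keys are declared.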

squeek502 commented 4 years ago

Ah, I see; that sounds similar to what the Zig tokenizer does now: hashes for each keyword are computed at compile time, and the lookup function checks the hash first before checking `mem.eql`. Perfect hashing could remove the need for the loop in `getKeyword`, though.
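A rough sketch of the hash-prefilter approach just described. This is not the actual tokenizer code: the keyword list, the choice of FNV-1a as the hash, and the name `getKeywordSketch` are all assumptions for illustration. Each keyword's hash is a compile-time constant, so the unrolled loop compares integers and only calls `mem.eql` on a hash match; a perfect hash would replace the loop with a single table probe.

```zig
const std = @import("std");

// Hypothetical keyword set; a real tokenizer has many more.
const Keyword = enum { @"fn", @"const", @"var", @"return", @"while" };

// 32-bit FNV-1a, usable both at compile time and at runtime.
fn fnv1a(s: []const u8) u32 {
    var h: u32 = 0x811c9dc5;
    for (s) |c| {
        h ^= c;
        h *%= 0x01000193;
    }
    return h;
}

fn getKeywordSketch(s: []const u8) ?Keyword {
    const h = fnv1a(s);
    inline for (std.meta.fields(Keyword)) |f| {
        // The keyword's hash is computed once, at compile time.
        const kh = comptime fnv1a(f.name);
        // Cheap integer comparison first; mem.eql only runs on a hash match.
        if (h == kh and std.mem.eql(u8, s, f.name)) return @field(Keyword, f.name);
    }
    return null;
}
```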

daurnimator commented 4 years ago

Yep! Seems like you found another place that would be able to immediately make use of the new machinery :)

andrewrk commented 4 years ago

One thing to remember is to perf test. @hejsil did some experiments with this earlier and determined that, at least in release-fast mode, the optimizer, given if-else chains, was able to outperform a perfect hash implementation.

pixelherodev commented 4 years ago

Dumb question: what does "perfect" hash mean? Got a quick answer, thanks haze :)
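For reference: a hash function is *perfect* for a fixed key set if it sends every key to a distinct slot, so lookups need no collision handling; it is *minimal* if the table has exactly as many slots as keys. The sketch below illustrates the definition under stated assumptions and is not a proposed std API: it brute-forces a seed for a seeded FNV-1a hash with a xorshift-multiply finalizer (an arbitrary choice) at compile time until six hypothetical keys land in six distinct slots. `keys`, `seededHash`, `isMinimalPerfect`, `chosen_seed` and `slotOf` are illustrative names.

```zig
const std = @import("std");

// Hypothetical key set.
const keys = [_][]const u8{ "one", "two", "three", "four", "five", "six" };

// Seeded FNV-1a plus a finalizer that mixes high bits into the low bits,
// which are the ones that decide the slot.
fn seededHash(seed: u32, s: []const u8) u32 {
    var h: u32 = seed ^ 0x811c9dc5;
    for (s) |c| {
        h ^= c;
        h *%= 0x01000193;
    }
    h ^= h >> 15;
    h *%= 0x2c1b3c6d;
    h ^= h >> 12;
    return h;
}

// True if every key gets its own slot in a keys.len-slot table:
// perfect (no collisions) and minimal (no empty slots).
fn isMinimalPerfect(seed: u32) bool {
    var used = [_]bool{false} ** keys.len;
    for (keys) |k| {
        const i = seededHash(seed, k) % keys.len;
        if (used[i]) return false;
        used[i] = true;
    }
    return true;
}

// Brute-force search at compile time: try seeds 0, 1, 2, ... until one works.
const chosen_seed: u32 = blk: {
    @setEvalBranchQuota(1_000_000);
    var seed: u32 = 0;
    while (!isMinimalPerfect(seed)) seed += 1;
    break :blk seed;
};

fn slotOf(s: []const u8) usize {
    return seededHash(chosen_seed, s) % keys.len;
}
```

Random search like this only works well for small key sets; the BBHash and CHD approaches linked earlier in the thread are designed to scale to large ones.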

squeek502 commented 11 months ago

As discovered in https://github.com/Vexu/arocc/pull/524 (https://github.com/Vexu/arocc/pull/524#issuecomment-1762854941), stringToEnum could have considerably better codegen for large enums if the fields were sorted by length before the `inline for`:

https://github.com/ziglang/zig/blob/a126afa1c3fe7e39678c201f6470b7899aedd47e/lib/std/meta.zig#L41-L46

Here's a benchmark comparing only different orderings of the enum fields (the enum has 3948 fields; the shortest field name is 3 characters long and the longest is 43):

-------------------- unsorted ---------------------
            always found: 3718ns per lookup
not found (random bytes): 6638ns per lookup
 not found (1 char diff): 3819ns per lookup

----------- sorted by length (desc) ---------------
            always found: 1176ns per lookup
not found (random bytes): 68ns per lookup
 not found (1 char diff): 1173ns per lookup

----------- sorted by length (asc) ----------------
            always found: 1054ns per lookup
not found (random bytes): 67ns per lookup
 not found (1 char diff): 1053ns per lookup

-------- sorted lexicographically (asc) -----------
            always found: 2764ns per lookup
not found (random bytes): 4615ns per lookup
 not found (1 char diff): 2750ns per lookup

This would ultimately be a trade-off between compile time and runtime performance. I haven't tested how much the comptime sorting of the fields would add to compile time. We might end up hitting https://github.com/ziglang/zig/issues/4055, in which case this optimization might need to wait a bit.
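A minimal sketch of the sorted variant, assuming a recent Zig release (builtin and std names have shifted between versions, so adjustments may be needed). It copies the field names into a comptime array, insertion-sorts them by ascending length (stable, so equal-length names keep declaration order), then unrolls the comparisons in that order. `stringToEnumByLength` is a hypothetical name and this is not the std implementation. The insertion sort is quadratic, so for enums with thousands of fields it would need a larger `@setEvalBranchQuota` and would cost compile time, which is the trade-off described above.

```zig
const std = @import("std");

fn stringToEnumByLength(comptime T: type, str: []const u8) ?T {
    // Computed once at compile time: field names sorted by ascending length.
    const sorted = comptime blk: {
        const fields = std.meta.fields(T);
        var names: [fields.len][]const u8 = undefined;
        for (fields, 0..) |f, idx| names[idx] = f.name;
        // Stable insertion sort on name length.
        var i: usize = 1;
        while (i < names.len) : (i += 1) {
            const key = names[i];
            var j = i;
            while (j > 0 and names[j - 1].len > key.len) : (j -= 1) {
                names[j] = names[j - 1];
            }
            names[j] = key;
        }
        break :blk names;
    };
    // Same unrolled comparison as the original, just in sorted order.
    inline for (sorted) |name| {
        if (std.mem.eql(u8, str, name)) return @field(T, name);
    }
    return null;
}
```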

Note: Sorting would also make it easy to create a fast path that checks whether `str.len` is within the bounds of the shortest/longest enum field names.
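The length fast path could look roughly like this, as a sketch under the same recent-Zig assumption. With a sorted name array the bounds would simply be its first and last lengths; this standalone version computes them with a comptime loop instead, and keeps the unsorted comparison loop for brevity. `stringToEnumWithFastPath` is a hypothetical name. Inputs outside the length range are rejected before any string comparison, which is the case the "random bytes" benchmark row mostly exercises.

```zig
const std = @import("std");

fn stringToEnumWithFastPath(comptime T: type, str: []const u8) ?T {
    const fields = std.meta.fields(T);
    // Shortest and longest field name, computed at compile time.
    const bounds = comptime blk: {
        var lo: usize = std.math.maxInt(usize);
        var hi: usize = 0;
        for (fields) |f| {
            lo = @min(lo, f.name.len);
            hi = @max(hi, f.name.len);
        }
        break :blk [2]usize{ lo, hi };
    };
    // Fast path: no field name can match a string outside [lo, hi].
    if (str.len < bounds[0] or str.len > bounds[1]) return null;
    inline for (fields) |f| {
        if (std.mem.eql(u8, str, f.name)) return @field(T, f.name);
    }
    return null;
}
```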