z1brc

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README

main.zig (6259B)


      1 const std = @import("std");
      2 
      3 pub const std_options = struct {
      4     pub const log_level = .info;
      5 };
      6 
      7 const Accumulator = struct {
      8     min: f32,
      9     max: f32,
     10     sum: f64,
     11     count: u64,
     12 };
     13 
     14 pub fn main() !void {
     15     var t = try std.time.Timer.start();
     16     std.log.info("start!", .{});
     17 
     18     var gpa = std.heap.GeneralPurposeAllocator(.{}){};
     19     defer _ = gpa.deinit();
     20     const a = gpa.allocator();
     21     
     22     var args = std.process.args();
     23     defer args.deinit();
     24     if (!args.skip()) @panic("program name wasn't supplied wtf");
     25     const infile_name = args.next() orelse return error.NoInputFile;
     26     const infile = try open_mmap(std.fs.cwd(), infile_name);
     27     defer std.os.munmap(infile);
     28 
     29     const out = try run(a, infile, &t);
     30     defer a.free(out);
     31     try std.io.getStdOut().writeAll(out);
     32 
     33     std.log.info("finished at {} s", .{t.read() / std.time.ns_per_s});
     34 }
     35 
     36 // Result must be closed with std.os.munmap
     37 fn open_mmap(dir: std.fs.Dir, file_path: []const u8) ![]align(std.mem.page_size) u8 {
     38     var f = try dir.openFile(file_path, .{ .mode = .read_only });
     39     defer f.close();
     40     const stat = try f.stat();
     41     return try std.os.mmap(null, stat.size, std.os.PROT.READ, std.os.MAP.PRIVATE, f.handle, 0);
     42 }
     43 
     44 fn run(a: std.mem.Allocator, infile: []const u8, _: *std.time.Timer) ![]const u8 {
     45     const threadcount = try std.Thread.getCpuCount();
     46     var ress = try a.alloc(std.StringArrayHashMap(Accumulator), threadcount);
     47     defer a.free(ress);
     48     var threads = try a.alloc(std.Thread, threadcount);
     49     defer a.free(threads);
     50     var threadnames = try a.alloc([]const u8, threadcount);
     51     defer a.free(threadnames);
     52 
     53     var start: usize = 0;
     54     for (0..threadcount) |i| {
     55         ress[i] = std.StringArrayHashMap(Accumulator).init(a);
     56         var end = ((infile.len * (i+1)) / threadcount);
     57         while (end < infile.len and infile[end] != '\n') end += 1;
     58         const infile_part = infile[start..end];
     59         const threadname = try std.fmt.allocPrint(a, "threads {}", .{i});
     60         threadnames[i] = threadname;
     61         threads[i] = try std.Thread.spawn(.{}, run_part, .{&ress[i], infile_part, threadname});
     62         start = end+1;
     63     }
     64     defer {
     65         for (0..threadcount) |i| {
     66             defer free_keys_and_deinit(&ress[i]);
     67             a.free(threadnames[i]);
     68         }
     69     }
     70     for (0..threadcount) |i| {
     71         threads[i].join();
     72     }
     73     // Now merge the results
     74     var res = std.StringArrayHashMap(Accumulator).init(a);
     75     defer res.deinit(); // Doesn't own it's own keys
     76     for (0..threadcount) |i| {
     77         try merge_in(&res, &ress[i]);
     78     }
     79 
     80     // Sort and print
     81     const Srt = struct {
     82         keys: [][]const u8,
     83         pub fn lessThan(self: @This(), a_index: usize, b_index: usize) bool {
     84             // character value order!
     85             return std.mem.order(u8, self.keys[a_index], self.keys[b_index]).compare(.lt);
     86         }
     87     };
     88     res.sort(Srt{.keys = res.keys()});
     89 
     90     var rr = std.ArrayList(u8).init(a);
     91     defer rr.deinit();
     92     var ww = rr.writer();
     93     try ww.writeAll("{");
     94     var it = res.iterator();
     95     while (it.next()) |nxt| {
     96         const k = nxt.key_ptr.*;
     97         try ww.writeAll(k);
     98         try ww.writeAll("=");
     99         const v = nxt.value_ptr.*;
    100         try std.fmt.format(ww, "{d:.1}", .{v.min});
    101         try ww.writeAll("/");
    102         try std.fmt.format(ww, "{d:.1}", .{v.max});
    103         try ww.writeAll("/");
    104         const mean = v.sum / @as(f64, @floatFromInt(v.count));
    105         try std.fmt.format(ww, "{d:.1}", .{mean});
    106         try ww.writeAll(", ");
    107     }
    108     try ww.writeAll("}");
    109     return try rr.toOwnedSlice();
    110 }
    111 
    112 fn run_part(res: *std.StringArrayHashMap(Accumulator), infile: []const u8, name: []const u8) !void {
    113     var t = try std.time.Timer.start(); // I know it's supported on my platform
    114     var lines = std.mem.tokenizeScalar(u8, infile, '\n');
    115     var ct: usize = 0;
    116     while (lines.next()) |line| {
    117         ct += 1;
    118         if (ct % 1000000 == 0) {
    119             const sec = t.read() / std.time.ns_per_s;
    120             const rows_sec = ct / sec;
    121             std.log.info("thread {s} processed {} lines at {} seconds, rate {} rows / sec", .{name, ct, sec, rows_sec});
    122         }
    123         var spl = std.mem.splitScalar(u8, line, ';');
    124         const key = spl.first();
    125         const val_s = spl.next() orelse unreachable;
    126         const val = std.fmt.parseFloat(f32, val_s) catch unreachable;
    127         
    128         if (res.contains(key)) {
    129             const e = res.getPtr(key) orelse unreachable;
    130             e.* = .{
    131                 .min = @min(e.min, val),
    132                 .max = @max(e.max, val),
    133                 .sum = e.sum + val,
    134                 .count = e.count + 1,
    135             };
    136         } else {
    137             const kd = try res.allocator.dupe(u8, key);
    138             try res.put(kd,.{
    139                 .min = val,
    140                 .max = val,
    141                 .sum = val,
    142                 .count = 1,
    143             });
    144         }
    145     }    
    146 }
    147 
    148 fn merge_in(res_f: *std.StringArrayHashMap(Accumulator), res_a: *std.StringArrayHashMap(Accumulator)) !void {
    149     var it = res_a.iterator();
    150     while (it.next()) |e| {
    151         const r = e.value_ptr.*;
    152         const gpr = try res_f.getOrPut(e.key_ptr.*);
    153         if (gpr.found_existing) {
    154             const rr = gpr.value_ptr.*;
    155             gpr.value_ptr.* = Accumulator{
    156                 .min = @min(rr.min, r.min),
    157                 .max = @max(rr.max, r.max),
    158                 .sum = rr.sum + r.sum,
    159                 .count = rr.count + r.count,
    160             };
    161         } else {
    162             gpr.value_ptr.* = r;
    163         }
    164     }
    165 }
    166 
    167 fn free_keys_and_deinit(hm: *std.StringArrayHashMap(Accumulator)) void {
    168     for (hm.keys()) |*k| {
    169         hm.allocator.free(k.*);
    170     }
    171     hm.deinit();
    172 }
    173 
    174 
    175 test {
    176     const test_input = @embedFile("measurement_test.txt");
    177     const test_output = 
    178         \\{Bridgetown=26.9/26.9/26.9, Bulawayo=8.9/8.9/8.9, Conakry=31.2/31.2/31.2, Cracow=12.6/12.6/12.6, Hamburg=12.0/12.0/12.0, Istanbul=-15.0/23.0/-1.3, Palembang=38.8/38.8/38.8, Roseau=34.4/34.4/34.4, St. John's=15.2/15.2/15.2, }
    179     ;
    180     var t = try std.time.Timer.start();
    181     const a = std.testing.allocator;
    182     const out = try run(a, test_input, &t);
    183     defer a.free(out);
    184     try std.testing.expectEqualStrings(test_output, out);
    185 }