zbt

CLI BitTorrent client, written in Zig
Log | Files | Refs | README

bencode.zig (12142B)


      1 //! Bencoding
      2 //! See specification here https://wiki.theory.org/BitTorrentSpecification#Bencoding
      3 
      4 const std = @import("std");
      5 const AnyWriter = @import("anywriter.zig");
      6 
      7 pub const Error = error.Malformatted || std.io.AnyReader.Error;
      8 
      9 // All content is owned by the BValue and must be freed with deinit.
     10 // hmmm, this gets a bit awkward from the _writing_ side of things.
     11 // What should I do? Optionally owned? Yeah let's do that.
     12 pub const BValue = union(enum) {
     13     int: i64,
     14     string: struct {
     15         string: []const u8,
     16         owned: bool = false,
     17     },
     18     list: struct {
     19         list: std.ArrayList(BValue),
     20         valuesOwned: bool = false,
     21     },
     22     dict: struct {
     23         dict: std.StringArrayHashMap(BValue),
     24         keysAndValuesOwned: bool = false,
     25     },
     26 
     27     pub fn bencode(self: *BValue, base_writer: anytype) !void {
     28         var wrap = AnyWriter.wrapper(base_writer);
     29         var writer = wrap.any();
     30         try self.bencodeInner(writer);
     31     }
     32 
     33     // Note: uses defined types only to avoid trying to recursively evaulate this function
     34     // at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724
     35     fn bencodeInner(self: *BValue, writer: AnyWriter) !void {
     36         switch (self.*) {
     37             .int => |i| {
     38                 try std.fmt.format(writer, "i{}e", .{i});
     39             },
     40             .string => |s| {
     41                 try std.fmt.format(writer, "{}:{s}", .{ s.string.len, s.string });
     42             },
     43             .list => |*l| {
     44                 try writer.writeByte('l');
     45                 for (l.list.items) |*i| {
     46                     try i.bencodeInner(writer);
     47                 }
     48                 try writer.writeByte('e');
     49             },
     50             .dict => |*d| {
     51                 // Keys must be strings and appear in sorted order (sorted as raw strings, not alphanumerics). The strings should be compared using a binary comparison, not a culture-specific "natural" comparison.
     52                 const Ctx = struct {
     53                     keys: [][]const u8,
     54                     pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
     55                         const a_k = ctx.keys[a_index];
     56                         const b_k = ctx.keys[b_index];
     57                         return std.mem.order(u8, a_k, b_k) == .lt;
     58                     }
     59                 };
     60                 var dict: *std.StringArrayHashMap(BValue) = &d.dict;
     61                 dict.sort(Ctx{ .keys = dict.keys() });
     62 
     63                 try writer.writeByte('d');
     64                 var it = dict.iterator();
     65                 while (it.next()) |entry| {
     66                     try std.fmt.format(writer, "{}:{s}", .{ entry.key_ptr.*.len, entry.key_ptr.* });
     67                     try entry.value_ptr.*.bencodeInner(writer);
     68                 }
     69                 try writer.writeByte('e');
     70             },
     71         }
     72     }
     73 
     74     pub fn deinit(self: *BValue, a: std.mem.Allocator) void {
     75         switch (self.*) {
     76             .int => {},
     77             .string => |s| {
     78                 if (s.owned)
     79                     a.free(s.string);
     80             },
     81             .list => |*l| {
     82                 if (l.valuesOwned)
     83                     for (l.list.items) |*i|
     84                         i.deinit(a);
     85                 l.list.deinit();
     86             },
     87             .dict => |*d| {
     88                 if (d.keysAndValuesOwned) {
     89                     var it = d.dict.iterator();
     90                     while (it.next()) |entry| {
     91                         a.free(entry.key_ptr.*);
     92                         entry.value_ptr.*.deinit(a);
     93                     }
     94                 }
     95                 d.dict.deinit();
     96             },
     97         }
     98     }
     99 
    100     pub fn asInt(self: BValue, comptime itype: type) !itype {
    101         switch (self) {
    102             .int => |i| {
    103                 return std.math.cast(itype, i) orelse error.Overflow;
    104             },
    105             else => return error.WrongType,
    106         }
    107     }
    108 
    109     pub fn asString(self: BValue) ![]const u8 {
    110         switch (self) {
    111             .string => |s| return s.string,
    112             else => return error.WrongType,
    113         }
    114     }
    115 
    116     pub fn asList(self: BValue) !std.ArrayList(BValue) {
    117         switch (self) {
    118             .list => |l| return l.list,
    119             else => return error.WrongType,
    120         }
    121     }
    122 
    123     pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) {
    124         switch (self) {
    125             .dict => |d| return d.dict,
    126             else => return error.WrongType,
    127         }
    128     }
    129 };
    130 
    131 pub fn bdecodeBuf(a: std.mem.Allocator, buf: []const u8) !BValue {
    132     var fbs = std.io.fixedBufferStream(buf);
    133     return try bdecode(a, fbs.reader());
    134 }
    135 
    136 pub fn bdecode(a: std.mem.Allocator, base_reader: anytype) anyerror!BValue {
    137     var reader = PeekStream.init(base_reader.any());
    138     return bdecodeInner(a, &reader, 0);
    139 }
    140 
    141 const PeekStream = std.io.PeekStream(.{ .Static = 1 }, std.io.AnyReader);
    142 
    143 // Note: uses defined types only to avoid trying to recursively evaulate this function
    144 // at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724
    145 fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BValue {
    146     if (depth > 100) {
    147         // TODO diagnostic...
    148         return error.Malformatted;
    149     }
    150     var reader = peekStream.reader();
    151     var byte = try reader.readByte();
    152     if (std.ascii.isDigit(byte)) {
    153         try peekStream.putBackByte(byte);
    154         return .{ .string = .{
    155             .owned = true,
    156             .string = try readString(a, peekStream),
    157         } };
    158     } else {
    159         switch (byte) {
    160             'i' => {
    161                 const max_len = comptime std.fmt.comptimePrint("{}", .{std.math.minInt(i64)}).len;
    162                 var s = reader.readUntilDelimiterAlloc(a, 'e', max_len) catch return error.Malformatted;
    163                 defer a.free(s);
    164                 const i = std.fmt.parseInt(i64, s, 10) catch return error.Malformatted;
    165                 return .{ .int = i };
    166             },
    167             'l' => {
    168                 var r: BValue = .{ .list = .{
    169                     .valuesOwned = true,
    170                     .list = std.ArrayList(BValue).init(a),
    171                 } };
    172                 errdefer r.deinit(a);
    173                 while (true) {
    174                     const b2 = try reader.readByte();
    175                     if (b2 == 'e') break;
    176                     try peekStream.putBackByte(b2);
    177                     var val = try bdecodeInner(a, peekStream, depth + 1);
    178                     errdefer val.deinit(a);
    179                     try r.list.list.append(val);
    180                 }
    181                 return r;
    182             },
    183             'd' => {
    184                 var r: BValue = .{ .dict = .{ .keysAndValuesOwned = true, .dict = std.StringArrayHashMap(BValue).init(a) } };
    185                 errdefer r.deinit(a);
    186                 while (true) {
    187                     const b2 = try reader.readByte();
    188                     if (b2 == 'e') break;
    189                     try peekStream.putBackByte(b2);
    190                     var key = try readString(a, peekStream);
    191                     errdefer a.free(key);
    192                     var val = try bdecode(a, reader);
    193                     errdefer val.deinit(a);
    194                     try r.dict.dict.put(key, val);
    195                 }
    196                 return r;
    197             },
    198             else => return error.Malformatted, // TODO diagnostics
    199         }
    200     }
    201 }
    202 
    203 // Result is owned by the caller and must be freed
    204 fn readString(a: std.mem.Allocator, peekStream: *PeekStream) ![]const u8 {
    205     var reader = peekStream.reader();
    206     const max_len = comptime std.fmt.comptimePrint("{}", .{std.math.maxInt(usize)}).len;
    207     const str_len_s = reader.readUntilDelimiterAlloc(a, ':', max_len) catch {
    208         return error.Malformatted;
    209     };
    210     defer a.free(str_len_s);
    211     var strlen = std.fmt.parseInt(usize, str_len_s, 10) catch return error.Malformatted;
    212     var string = try a.alloc(u8, strlen);
    213     errdefer a.free(string);
    214     reader.readNoEof(string) catch return error.Malformatted;
    215     return string;
    216 }
    217 
    218 test "bdecode empty" {
    219     var a = std.testing.allocator;
    220     try std.testing.expectError(error.EndOfStream, bdecodeBuf(a, ""));
    221 }
    222 
    223 test "bdecode too short" {
    224     var a = std.testing.allocator;
    225     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "1"));
    226 }
    227 
    228 test "bdecode plain number" {
    229     var a = std.testing.allocator;
    230     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "12"));
    231 }
    232 
    233 test "bdecode garbage" {
    234     var a = std.testing.allocator;
    235     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "xz1234"));
    236 }
    237 
    238 test "bdecode number" {
    239     var a = std.testing.allocator;
    240     var bval = try bdecodeBuf(a, "i123e");
    241     defer bval.deinit(a);
    242     try std.testing.expectEqualDeep(BValue{ .int = 123 }, bval);
    243 }
    244 
    245 test "bdecode number negative" {
    246     var a = std.testing.allocator;
    247     var bval = try bdecodeBuf(a, "i-123e");
    248     defer bval.deinit(a);
    249     try std.testing.expectEqualDeep(BValue{ .int = -123 }, bval);
    250 }
    251 
    252 test "bdecode number empty" {
    253     var a = std.testing.allocator;
    254     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "ie"));
    255 }
    256 
    257 test "bdecode number just sign" {
    258     var a = std.testing.allocator;
    259     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i-e"));
    260 }
    261 
    262 test "bdecode number no end" {
    263     var a = std.testing.allocator;
    264     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i123123671283"));
    265 }
    266 
    267 test "bdecode number out of range" {
    268     var a = std.testing.allocator;
    269     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i9223372036854775808e"));
    270 }
    271 
    272 test "bdecode string" {
    273     var a = std.testing.allocator;
    274     var bval = try bdecodeBuf(a, "5:hello");
    275     defer bval.deinit(a);
    276     try std.testing.expectEqualDeep(BValue{ .string = .{ .owned = true, .string = "hello" } }, bval);
    277 }
    278 test "bdecode string too short" {
    279     var a = std.testing.allocator;
    280     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "5:hell"));
    281 }
    282 
    283 test "bdecode list" {
    284     var a = std.testing.allocator;
    285     var bval = try bdecodeBuf(a, "l5:hello5:worlde");
    286     defer bval.deinit(a);
    287     try std.testing.expectEqual(@as(usize, 2), bval.list.list.items.len);
    288     try std.testing.expectEqualStrings("hello", bval.list.list.items[0].string.string);
    289     try std.testing.expectEqualStrings("world", bval.list.list.items[1].string.string);
    290 }
    291 
    292 test "invalid list" {
    293     var a = std.testing.allocator;
    294     try std.testing.expectError(error.EndOfStream, bdecodeBuf(a, "l5:hello5:world")); // missing end
    295 }
    296 
    297 test "dict" {
    298     var a = std.testing.allocator;
    299     var bval = try bdecodeBuf(a, "d5:hello5:worlde");
    300     defer bval.deinit(a);
    301     var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
    302     try std.testing.expectEqualStrings("world", v.string.string);
    303 }
    304 
    305 test "invalid dict no value" {
    306     var a = std.testing.allocator;
    307     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "d5:hello5:world2:hie"));
    308 }
    309 
    310 test "invalid dict wrong key type" {
    311     var a = std.testing.allocator;
    312     try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "di32e5:helloe"));
    313 }
    314 
    315 test "nested structure" {
    316     var a = std.testing.allocator;
    317     var bval = try bdecodeBuf(a, "d5:hello5:world2:hili123ei456el4:nesteee");
    318     defer bval.deinit(a);
    319     var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
    320     try std.testing.expectEqualStrings("world", v.string.string);
    321     var v2 = bval.dict.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
    322     try std.testing.expectEqualDeep(v2.*.list.list.items[0], BValue{ .int = 123 });
    323     try std.testing.expectEqualDeep(v2.*.list.list.items[1], BValue{ .int = 456 });
    324     try std.testing.expectEqualStrings("nest", v2.*.list.list.items[2].list.list.items[0].string.string);
    325 }
    326 
    327 test "round trip" {
    328     var a = std.testing.allocator;
    329     const in = "d5:hello5:world2:hili123ei456el4:nesteee";
    330     var bval = try bdecodeBuf(a, in);
    331     defer bval.deinit(a);
    332     var bw = std.ArrayList(u8).init(a);
    333     defer bw.deinit();
    334     var writer = bw.writer();
    335     try bval.bencode(writer);
    336     var out = try bw.toOwnedSlice();
    337     defer a.free(out);
    338     try std.testing.expectEqualStrings(in, out);
    339 }