aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-11-09 22:09:08 +0000
committerMartin Ashby <martin@ashbysoft.com>2023-11-09 22:09:08 +0000
commit68c104c8b0580c51c9f16ec33d6a957fd4c08c0c (patch)
tree651032a12bcb2e20bf319e0544042069f0bac3ca /src
downloadzbt-68c104c8b0580c51c9f16ec33d6a957fd4c08c0c.tar.gz
zbt-68c104c8b0580c51c9f16ec33d6a957fd4c08c0c.tar.bz2
zbt-68c104c8b0580c51c9f16ec33d6a957fd4c08c0c.tar.xz
zbt-68c104c8b0580c51c9f16ec33d6a957fd4c08c0c.zip
Add bencoding decoding library
Diffstat (limited to 'src')
-rw-r--r--src/bencode.zig239
-rw-r--r--src/main.zig7
2 files changed, 246 insertions, 0 deletions
diff --git a/src/bencode.zig b/src/bencode.zig
new file mode 100644
index 0000000..b78ee9f
--- /dev/null
+++ b/src/bencode.zig
@@ -0,0 +1,239 @@
+//! Bencoding
+//! See specification here https://wiki.theory.org/BitTorrentSpecification#Bencoding
+
+const std = @import("std");
+
+pub const Error = error.Malformatted || std.io.AnyReader.Error;
+
+// All content is owned by the BValue and must be freed with deinit.
+pub const BValue = union(enum) {
+ string: []const u8,
+ int: i64,
+ list: std.ArrayList(BValue),
+ dict: std.StringArrayHashMap(BValue),
+
+ pub fn deinit(self: *BValue, a: std.mem.Allocator) void {
+ switch (self.*) {
+ .string => |s| {
+ a.free(s);
+ },
+ .list => |*l| {
+ for (l.items) |*i| {
+ i.deinit(a);
+ }
+ l.deinit();
+ },
+ .dict => |*d| {
+ var it = d.iterator();
+ while (it.next()) |entry| {
+ a.free(entry.key_ptr.*);
+ entry.value_ptr.*.deinit(a);
+ }
+ d.deinit();
+ },
+ .int => {},
+ }
+ }
+};
+
+pub fn bdecodeBuf(a: std.mem.Allocator, buf: []const u8) !BValue {
+ var fbs = std.io.fixedBufferStream(buf);
+ return try bdecode(a, fbs.reader());
+}
+
+pub fn bdecode(a: std.mem.Allocator, base_reader: anytype) anyerror!BValue {
+ var reader = PeekStream.init(base_reader.any());
+ return bdecodeInner(a, &reader, 0);
+}
+
+const PeekStream = std.io.PeekStream(.{ .Static = 1 }, std.io.AnyReader);
+
+// Note: uses defined types only to avoid trying to recursively evaulate this function
+// at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724
+fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BValue {
+ if (depth > 100) {
+ // TODO diagnostic...
+ return error.Malformatted;
+ }
+ var reader = peekStream.reader();
+ var byte = try reader.readByte();
+ if (std.ascii.isDigit(byte)) {
+ try peekStream.putBackByte(byte);
+ return .{ .string = try readString(a, peekStream) };
+ } else {
+ switch (byte) {
+ 'i' => {
+ const max_len = comptime std.fmt.comptimePrint("{}", .{std.math.minInt(i64)}).len;
+ var s = reader.readUntilDelimiterAlloc(a, 'e', max_len) catch return error.Malformatted;
+ defer a.free(s);
+ const i = std.fmt.parseInt(i64, s, 10) catch return error.Malformatted;
+ return .{ .int = i };
+ },
+ 'l' => {
+ var list = std.ArrayList(BValue).init(a);
+ errdefer {
+ for (list.items) |*i| {
+ i.deinit(a);
+ }
+ list.deinit();
+ }
+ while (true) {
+ const b2 = try reader.readByte();
+ if (b2 == 'e') break;
+ try peekStream.putBackByte(b2);
+ var val = try bdecodeInner(a, peekStream, depth + 1);
+ errdefer val.deinit(a);
+ try list.append(val);
+ }
+ return .{ .list = list };
+ },
+ 'd' => {
+ var dict = std.StringArrayHashMap(BValue).init(a);
+ errdefer {
+ var it = dict.iterator();
+ while (it.next()) |entry| {
+ a.free(entry.key_ptr.*);
+ entry.value_ptr.*.deinit(a);
+ }
+ dict.deinit();
+ }
+ lp: while (true) {
+ const b2 = try reader.readByte();
+ if (b2 == 'e') break :lp;
+ try peekStream.putBackByte(b2);
+ var key = try readString(a, peekStream);
+ errdefer a.free(key);
+ var val = try bdecode(a, reader);
+ errdefer val.deinit(a);
+ try dict.put(key, val);
+ }
+ return .{ .dict = dict };
+ },
+ else => return error.Malformatted, // TODO diagnostics
+ }
+ }
+}
+
+// Result is owned by the caller and must be freed
+fn readString(a: std.mem.Allocator, peekStream: *PeekStream) ![]const u8 {
+ var reader = peekStream.reader();
+ const max_len = comptime std.fmt.comptimePrint("{}", .{std.math.maxInt(usize)}).len;
+ const str_len_s = reader.readUntilDelimiterAlloc(a, ':', max_len) catch {
+ return error.Malformatted;
+ };
+ defer a.free(str_len_s);
+ var strlen = std.fmt.parseInt(usize, str_len_s, 10) catch return error.Malformatted;
+ var string = try a.alloc(u8, strlen);
+ errdefer a.free(string);
+ reader.readNoEof(string) catch return error.Malformatted;
+ return string;
+}
+
+test "bdecode empty" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.EndOfStream, bdecodeBuf(a, ""));
+}
+
+test "bdecode too short" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "1"));
+}
+
+test "bdecode plain number" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "12"));
+}
+
+test "bdecode garbage" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "xz1234"));
+}
+
+test "bdecode number" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "i123e");
+ defer bval.deinit(a);
+ try std.testing.expectEqualDeep(BValue{ .int = 123 }, bval);
+}
+
+test "bdecode number negative" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "i-123e");
+ defer bval.deinit(a);
+ try std.testing.expectEqualDeep(BValue{ .int = -123 }, bval);
+}
+
+test "bdecode number empty" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "ie"));
+}
+
+test "bdecode number just sign" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i-e"));
+}
+
+test "bdecode number no end" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i123123671283"));
+}
+
+test "bdecode number out of range" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "i9223372036854775808e"));
+}
+
+test "bdecode string" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "5:hello");
+ defer bval.deinit(a);
+ try std.testing.expectEqualDeep(BValue{ .string = "hello" }, bval);
+}
+test "bdecode string too short" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "5:hell"));
+}
+
+test "bdecode list" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "l5:hello5:worlde");
+ defer bval.deinit(a);
+ try std.testing.expectEqual(@as(usize, 2), bval.list.items.len);
+ try std.testing.expectEqualStrings("hello", bval.list.items[0].string);
+ try std.testing.expectEqualStrings("world", bval.list.items[1].string);
+}
+
+test "invalid list" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.EndOfStream, bdecodeBuf(a, "l5:hello5:world")); // missing end
+}
+
+test "dict" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "d5:hello5:worlde");
+ defer bval.deinit(a);
+ var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string);
+}
+
+test "invalid dict no value" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "d5:hello5:world2:hie"));
+}
+
+test "invalid dict wrong key type" {
+ var a = std.testing.allocator;
+ try std.testing.expectError(error.Malformatted, bdecodeBuf(a, "di32e5:helloe"));
+}
+
+test "nested structure" {
+ var a = std.testing.allocator;
+ var bval = try bdecodeBuf(a, "d5:hello5:world2:hili123ei456el4:nesteee");
+ defer bval.deinit(a);
+ var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string);
+ var v2 = bval.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualDeep(v2.*.list.items[0], BValue{ .int = 123 });
+ try std.testing.expectEqualDeep(v2.*.list.items[1], BValue{ .int = 456 });
+ try std.testing.expectEqualStrings("nest", v2.*.list.items[2].list.items[0].string);
+}
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..fa8a1b6
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,7 @@
+const std = @import("std");
+
+pub fn main() !void {}
+
+test {
+ _ = @import("bencode.zig");
+}