aboutsummaryrefslogtreecommitdiff
path: root/src/bencode.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/bencode.zig')
-rw-r--r--src/bencode.zig158
1 files changed, 89 insertions, 69 deletions
diff --git a/src/bencode.zig b/src/bencode.zig
index 60a3746..3e83121 100644
--- a/src/bencode.zig
+++ b/src/bencode.zig
@@ -7,13 +7,24 @@ const AnyWriter = @import("anywriter.zig");
pub const Error = error.Malformatted || std.io.AnyReader.Error;
// All content is owned by the BValue and must be freed with deinit.
+// hmmm, this gets a bit awkward from the _writing_ side of things.
+// What should I do? Optionally owned? Yeah let's do that.
pub const BValue = union(enum) {
- string: []const u8,
int: i64,
- list: std.ArrayList(BValue),
- dict: std.StringArrayHashMap(BValue),
-
- pub fn bencode(self: *const BValue, base_writer: anytype) !void {
+ string: struct {
+ string: []const u8,
+ owned: bool = false,
+ },
+ list: struct {
+ list: std.ArrayList(BValue),
+ valuesOwned: bool = false,
+ },
+ dict: struct {
+ dict: std.StringArrayHashMap(BValue),
+ keysAndValuesOwned: bool = false,
+ },
+
+ pub fn bencode(self: *BValue, base_writer: anytype) !void {
var wrap = AnyWriter.wrapper(base_writer);
var writer = wrap.any();
try self.bencodeInner(writer);
@@ -21,60 +32,68 @@ pub const BValue = union(enum) {
// Note: uses defined types only to avoid trying to recursively evaulate this function
// at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724
- fn bencodeInner(self: *const BValue, writer: AnyWriter) !void {
+ fn bencodeInner(self: *BValue, writer: AnyWriter) !void {
switch (self.*) {
+ .int => |i| {
+ try std.fmt.format(writer, "i{}e", .{i});
+ },
.string => |s| {
- try std.fmt.format(writer, "{}:{s}", .{ s.len, s });
+ try std.fmt.format(writer, "{}:{s}", .{ s.string.len, s.string });
},
- .list => |l| {
+ .list => |*l| {
try writer.writeByte('l');
- for (l.items) |i| {
+ for (l.list.items) |*i| {
try i.bencodeInner(writer);
}
try writer.writeByte('e');
},
- .dict => |d| {
+ .dict => |*d| {
+ // Keys must be strings and appear in sorted order (sorted as raw strings, not alphanumerics). The strings should be compared using a binary comparison, not a culture-specific "natural" comparison.
+ const Ctx = struct {
+ keys: [][]const u8,
+ pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
+ const a_k = ctx.keys[a_index];
+ const b_k = ctx.keys[b_index];
+ return std.mem.order(u8, a_k, b_k) == .lt;
+ }
+ };
+ var dict: *std.StringArrayHashMap(BValue) = &d.dict;
+ dict.sort(Ctx{ .keys = dict.keys() });
+
try writer.writeByte('d');
- var it = d.iterator();
+ var it = dict.iterator();
while (it.next()) |entry| {
try std.fmt.format(writer, "{}:{s}", .{ entry.key_ptr.*.len, entry.key_ptr.* });
try entry.value_ptr.*.bencodeInner(writer);
}
try writer.writeByte('e');
},
- .int => |i| {
- try std.fmt.format(writer, "i{}e", .{i});
- },
}
}
pub fn deinit(self: *BValue, a: std.mem.Allocator) void {
switch (self.*) {
+ .int => {},
.string => |s| {
- a.free(s);
+ if (s.owned)
+ a.free(s.string);
},
.list => |*l| {
- for (l.items) |*i| {
- i.deinit(a);
- }
- l.deinit();
+ if (l.valuesOwned)
+ for (l.list.items) |*i|
+ i.deinit(a);
+ l.list.deinit();
},
.dict => |*d| {
- var it = d.iterator();
- while (it.next()) |entry| {
- a.free(entry.key_ptr.*);
- entry.value_ptr.*.deinit(a);
+ if (d.keysAndValuesOwned) {
+ var it = d.dict.iterator();
+ while (it.next()) |entry| {
+ a.free(entry.key_ptr.*);
+ entry.value_ptr.*.deinit(a);
+ }
}
- d.deinit();
+ d.dict.deinit();
},
- .int => {},
- }
- }
-
- pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) {
- switch (self) {
- .dict => |d| return d,
- else => return error.WrongType,
}
}
@@ -89,14 +108,21 @@ pub const BValue = union(enum) {
pub fn asString(self: BValue) ![]const u8 {
switch (self) {
- .string => |s| return s,
+ .string => |s| return s.string,
else => return error.WrongType,
}
}
pub fn asList(self: BValue) !std.ArrayList(BValue) {
switch (self) {
- .list => |l| return l,
+ .list => |l| return l.list,
+ else => return error.WrongType,
+ }
+ }
+
+ pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) {
+ switch (self) {
+ .dict => |d| return d.dict,
else => return error.WrongType,
}
}
@@ -125,7 +151,10 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal
var byte = try reader.readByte();
if (std.ascii.isDigit(byte)) {
try peekStream.putBackByte(byte);
- return .{ .string = try readString(a, peekStream) };
+ return .{ .string = .{
+ .owned = true,
+ .string = try readString(a, peekStream),
+ } };
} else {
switch (byte) {
'i' => {
@@ -136,44 +165,35 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal
return .{ .int = i };
},
'l' => {
- var list = std.ArrayList(BValue).init(a);
- errdefer {
- for (list.items) |*i| {
- i.deinit(a);
- }
- list.deinit();
- }
+ var r: BValue = .{ .list = .{
+ .valuesOwned = true,
+ .list = std.ArrayList(BValue).init(a),
+ } };
+ errdefer r.deinit(a);
while (true) {
const b2 = try reader.readByte();
if (b2 == 'e') break;
try peekStream.putBackByte(b2);
var val = try bdecodeInner(a, peekStream, depth + 1);
errdefer val.deinit(a);
- try list.append(val);
+ try r.list.list.append(val);
}
- return .{ .list = list };
+ return r;
},
'd' => {
- var dict = std.StringArrayHashMap(BValue).init(a);
- errdefer {
- var it = dict.iterator();
- while (it.next()) |entry| {
- a.free(entry.key_ptr.*);
- entry.value_ptr.*.deinit(a);
- }
- dict.deinit();
- }
- lp: while (true) {
+ var r: BValue = .{ .dict = .{ .keysAndValuesOwned = true, .dict = std.StringArrayHashMap(BValue).init(a) } };
+ errdefer r.deinit(a);
+ while (true) {
const b2 = try reader.readByte();
- if (b2 == 'e') break :lp;
+ if (b2 == 'e') break;
try peekStream.putBackByte(b2);
var key = try readString(a, peekStream);
errdefer a.free(key);
var val = try bdecode(a, reader);
errdefer val.deinit(a);
- try dict.put(key, val);
+ try r.dict.dict.put(key, val);
}
- return .{ .dict = dict };
+ return r;
},
else => return error.Malformatted, // TODO diagnostics
}
@@ -253,7 +273,7 @@ test "bdecode string" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "5:hello");
defer bval.deinit(a);
- try std.testing.expectEqualDeep(BValue{ .string = "hello" }, bval);
+ try std.testing.expectEqualDeep(BValue{ .string = .{ .owned = true, .string = "hello" } }, bval);
}
test "bdecode string too short" {
var a = std.testing.allocator;
@@ -264,9 +284,9 @@ test "bdecode list" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "l5:hello5:worlde");
defer bval.deinit(a);
- try std.testing.expectEqual(@as(usize, 2), bval.list.items.len);
- try std.testing.expectEqualStrings("hello", bval.list.items[0].string);
- try std.testing.expectEqualStrings("world", bval.list.items[1].string);
+ try std.testing.expectEqual(@as(usize, 2), bval.list.list.items.len);
+ try std.testing.expectEqualStrings("hello", bval.list.list.items[0].string.string);
+ try std.testing.expectEqualStrings("world", bval.list.list.items[1].string.string);
}
test "invalid list" {
@@ -278,8 +298,8 @@ test "dict" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "d5:hello5:worlde");
defer bval.deinit(a);
- var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualStrings("world", v.string);
+ var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string.string);
}
test "invalid dict no value" {
@@ -296,12 +316,12 @@ test "nested structure" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "d5:hello5:world2:hili123ei456el4:nesteee");
defer bval.deinit(a);
- var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualStrings("world", v.string);
- var v2 = bval.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualDeep(v2.*.list.items[0], BValue{ .int = 123 });
- try std.testing.expectEqualDeep(v2.*.list.items[1], BValue{ .int = 456 });
- try std.testing.expectEqualStrings("nest", v2.*.list.items[2].list.items[0].string);
+ var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string.string);
+ var v2 = bval.dict.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualDeep(v2.*.list.list.items[0], BValue{ .int = 123 });
+ try std.testing.expectEqualDeep(v2.*.list.list.items[1], BValue{ .int = 456 });
+ try std.testing.expectEqualStrings("nest", v2.*.list.list.items[2].list.list.items[0].string.string);
}
test "round trip" {