diff options
author | Martin Ashby <martin@ashbysoft.com> | 2023-11-11 20:47:46 +0000 |
---|---|---|
committer | Martin Ashby <martin@ashbysoft.com> | 2023-11-11 20:47:46 +0000 |
commit | e92459156c2f74de648566ca7acde3833de33425 (patch) | |
tree | 59d0bb42d23db0d295a0ef0a97e31e71b8f592cd /src | |
parent | dff4234bf2d957a0328ff4f3dc4f9bba1fbeffd4 (diff) | |
download | zbt-e92459156c2f74de648566ca7acde3833de33425.tar.gz zbt-e92459156c2f74de648566ca7acde3833de33425.tar.bz2 zbt-e92459156c2f74de648566ca7acde3833de33425.tar.xz zbt-e92459156c2f74de648566ca7acde3833de33425.zip |
Include ownership info on BValue.
It's a bit cumbersome, but it caters for the two situations: a. parsing
from a stream where strings must be copied into the structure and owned,
and b. constructing from application code where strings must _not_ be
freed. Either way, the base structures (array list, hash map) are always
owned and must be freed.
Add method and test for info hash
Diffstat (limited to 'src')
-rw-r--r-- | src/bencode.zig | 158 | ||||
-rw-r--r-- | src/metainfo.zig | 79 |
2 files changed, 168 insertions, 69 deletions
diff --git a/src/bencode.zig b/src/bencode.zig index 60a3746..3e83121 100644 --- a/src/bencode.zig +++ b/src/bencode.zig @@ -7,13 +7,24 @@ const AnyWriter = @import("anywriter.zig"); pub const Error = error.Malformatted || std.io.AnyReader.Error; // All content is owned by the BValue and must be freed with deinit. +// hmmm, this gets a bit awkward from the _writing_ side of things. +// What should I do? Optionally owned? Yeah let's do that. pub const BValue = union(enum) { - string: []const u8, int: i64, - list: std.ArrayList(BValue), - dict: std.StringArrayHashMap(BValue), - - pub fn bencode(self: *const BValue, base_writer: anytype) !void { + string: struct { + string: []const u8, + owned: bool = false, + }, + list: struct { + list: std.ArrayList(BValue), + valuesOwned: bool = false, + }, + dict: struct { + dict: std.StringArrayHashMap(BValue), + keysAndValuesOwned: bool = false, + }, + + pub fn bencode(self: *BValue, base_writer: anytype) !void { var wrap = AnyWriter.wrapper(base_writer); var writer = wrap.any(); try self.bencodeInner(writer); @@ -21,60 +32,68 @@ pub const BValue = union(enum) { // Note: uses defined types only to avoid trying to recursively evaulate this function // at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724 - fn bencodeInner(self: *const BValue, writer: AnyWriter) !void { + fn bencodeInner(self: *BValue, writer: AnyWriter) !void { switch (self.*) { + .int => |i| { + try std.fmt.format(writer, "i{}e", .{i}); + }, .string => |s| { - try std.fmt.format(writer, "{}:{s}", .{ s.len, s }); + try std.fmt.format(writer, "{}:{s}", .{ s.string.len, s.string }); }, - .list => |l| { + .list => |*l| { try writer.writeByte('l'); - for (l.items) |i| { + for (l.list.items) |*i| { try i.bencodeInner(writer); } try writer.writeByte('e'); }, - .dict => |d| { + .dict => |*d| { + // Keys must be strings and appear in sorted order (sorted as raw strings, not alphanumerics). The strings should be compared using a binary comparison, not a culture-specific "natural" comparison. + const Ctx = struct { + keys: [][]const u8, + pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool { + const a_k = ctx.keys[a_index]; + const b_k = ctx.keys[b_index]; + return std.mem.order(u8, a_k, b_k) == .lt; + } + }; + var dict: *std.StringArrayHashMap(BValue) = &d.dict; + dict.sort(Ctx{ .keys = dict.keys() }); + try writer.writeByte('d'); - var it = d.iterator(); + var it = dict.iterator(); while (it.next()) |entry| { try std.fmt.format(writer, "{}:{s}", .{ entry.key_ptr.*.len, entry.key_ptr.* }); try entry.value_ptr.*.bencodeInner(writer); } try writer.writeByte('e'); }, - .int => |i| { - try std.fmt.format(writer, "i{}e", .{i}); - }, } } pub fn deinit(self: *BValue, a: std.mem.Allocator) void { switch (self.*) { + .int => {}, .string => |s| { - a.free(s); + if (s.owned) + a.free(s.string); }, .list => |*l| { - for (l.items) |*i| { - i.deinit(a); - } - l.deinit(); + if (l.valuesOwned) + for (l.list.items) |*i| + i.deinit(a); + l.list.deinit(); }, .dict => |*d| { - var it = d.iterator(); - while (it.next()) |entry| { - a.free(entry.key_ptr.*); - entry.value_ptr.*.deinit(a); + if (d.keysAndValuesOwned) { + var it = d.dict.iterator(); + while (it.next()) |entry| { + a.free(entry.key_ptr.*); + entry.value_ptr.*.deinit(a); + } } - d.deinit(); + d.dict.deinit(); }, - .int => {}, - } - } - - pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) { - switch (self) { - .dict => |d| return d, - else => return error.WrongType, } } @@ -89,14 +108,21 @@ pub const BValue = union(enum) { pub fn asString(self: BValue) ![]const u8 { switch (self) { - .string => |s| return s, + .string => |s| return s.string, else => return error.WrongType, } } pub fn asList(self: BValue) !std.ArrayList(BValue) { switch (self) { - .list => |l| return l, + .list => |l| return l.list, + else => return error.WrongType, + } + } + + pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) { + switch (self) { + .dict => |d| return d.dict, else => return error.WrongType, } } @@ -125,7 +151,10 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal var byte = try reader.readByte(); if (std.ascii.isDigit(byte)) { try peekStream.putBackByte(byte); - return .{ .string = try readString(a, peekStream) }; + return .{ .string = .{ + .owned = true, + .string = try readString(a, peekStream), + } }; } else { switch (byte) { 'i' => { @@ -136,44 +165,35 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal return .{ .int = i }; }, 'l' => { - var list = std.ArrayList(BValue).init(a); - errdefer { - for (list.items) |*i| { - i.deinit(a); - } - list.deinit(); - } + var r: BValue = .{ .list = .{ + .valuesOwned = true, + .list = std.ArrayList(BValue).init(a), + } }; + errdefer r.deinit(a); while (true) { const b2 = try reader.readByte(); if (b2 == 'e') break; try peekStream.putBackByte(b2); var val = try bdecodeInner(a, peekStream, depth + 1); errdefer val.deinit(a); - try list.append(val); + try r.list.list.append(val); } - return .{ .list = list }; + return r; }, 'd' => { - var dict = std.StringArrayHashMap(BValue).init(a); - errdefer { - var it = dict.iterator(); - while (it.next()) |entry| { - a.free(entry.key_ptr.*); - entry.value_ptr.*.deinit(a); - } - dict.deinit(); - } - lp: while (true) { + var r: BValue = .{ .dict = .{ .keysAndValuesOwned = true, .dict = std.StringArrayHashMap(BValue).init(a) } }; + errdefer r.deinit(a); + while (true) { const b2 = try reader.readByte(); - if (b2 == 'e') break :lp; + if (b2 == 'e') break; try peekStream.putBackByte(b2); var key = try readString(a, peekStream); errdefer a.free(key); var val = try bdecode(a, reader); errdefer val.deinit(a); - try dict.put(key, val); + try r.dict.dict.put(key, val); } - return .{ .dict = dict }; + return r; }, else => return error.Malformatted, // TODO diagnostics } @@ -253,7 +273,7 @@ test "bdecode string" { var a = std.testing.allocator; var bval = try bdecodeBuf(a, "5:hello"); defer bval.deinit(a); - try std.testing.expectEqualDeep(BValue{ .string = "hello" }, bval); + try std.testing.expectEqualDeep(BValue{ .string = .{ .owned = true, .string = "hello" } }, bval); } test "bdecode string too short" { var a = std.testing.allocator; @@ -264,9 +284,9 @@ test "bdecode list" { var a = std.testing.allocator; var bval = try bdecodeBuf(a, "l5:hello5:worlde"); defer bval.deinit(a); - try std.testing.expectEqual(@as(usize, 2), bval.list.items.len); - try std.testing.expectEqualStrings("hello", bval.list.items[0].string); - try std.testing.expectEqualStrings("world", bval.list.items[1].string); + try std.testing.expectEqual(@as(usize, 2), bval.list.list.items.len); + try std.testing.expectEqualStrings("hello", bval.list.list.items[0].string.string); + try std.testing.expectEqualStrings("world", bval.list.list.items[1].string.string); } test "invalid list" { @@ -278,8 +298,8 @@ test "dict" { var a = std.testing.allocator; var bval = try bdecodeBuf(a, "d5:hello5:worlde"); defer bval.deinit(a); - var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull; - try std.testing.expectEqualStrings("world", v.string); + var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull; + try std.testing.expectEqualStrings("world", v.string.string); } test "invalid dict no value" { @@ -296,12 +316,12 @@ test "nested structure" { var a = std.testing.allocator; var bval = try bdecodeBuf(a, "d5:hello5:world2:hili123ei456el4:nesteee"); defer bval.deinit(a); - var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull; - try std.testing.expectEqualStrings("world", v.string); - var v2 = bval.dict.getPtr("hi") orelse return error.TestExpectedNotNull; - try std.testing.expectEqualDeep(v2.*.list.items[0], BValue{ .int = 123 }); - try std.testing.expectEqualDeep(v2.*.list.items[1], BValue{ .int = 456 }); - try std.testing.expectEqualStrings("nest", v2.*.list.items[2].list.items[0].string); + var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull; + try std.testing.expectEqualStrings("world", v.string.string); + var v2 = bval.dict.dict.getPtr("hi") orelse return error.TestExpectedNotNull; + try std.testing.expectEqualDeep(v2.*.list.list.items[0], BValue{ .int = 123 }); + try std.testing.expectEqualDeep(v2.*.list.list.items[1], BValue{ .int = 456 }); + try std.testing.expectEqualStrings("nest", v2.*.list.list.items[2].list.list.items[0].string.string); } test "round trip" { diff --git a/src/metainfo.zig b/src/metainfo.zig index a202b0e..31b14b7 100644 --- a/src/metainfo.zig +++ b/src/metainfo.zig @@ -9,10 +9,14 @@ pub const Info = struct { name: []const u8, length: u64, path: []const u8, + md5sum: ?[]const u8 = null, }; + piece_length: u64, pieces: []const u8, files: []File, + private: ?bool, + pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!Info { var d = b.asDict() catch return error.Malformatted; const pl = d.get("piece length") orelse return error.Malformatted; @@ -27,33 +31,89 @@ pub const Info = struct { const fin = fd.get("name") orelse return error.Malformatted; const fl = fd.get("length") orelse return error.Malformatted; const fp = fd.get("path") orelse return error.Malformatted; + var fm: ?[]const u8 = null; + if (fd.get("md5sum")) |md5| { + fm = md5.asString() catch return error.Malformatted; + } try files.append(.{ .name = fin.asString() catch return error.Malformatted, .length = fl.asInt(u64) catch return error.Malformatted, .path = fp.asString() catch return error.Malformatted, + .md5sum = fm, }); } } else { // single-file mode const fin = d.get("name") orelse return error.Malformatted; const fl = d.get("length") orelse return error.Malformatted; + var fm: ?[]const u8 = null; + if (d.get("md5sum")) |md5| { + fm = md5.asString() catch return error.Malformatted; + } try files.append(.{ .name = fin.asString() catch return error.Malformatted, .length = fl.asInt(u64) catch return error.Malformatted, .path = fin.asString() catch return error.Malformatted, // just use the file name as path + .md5sum = fm, }); } + var priv: ?bool = null; + if (d.get("private")) |pr| { + const pri = pr.asInt(u1) catch return error.Malformatted; + priv = pri == 1; + } return .{ .piece_length = pl.asInt(u64) catch return error.Malformatted, .pieces = pp.asString() catch return error.Malformatted, .files = try files.toOwnedSlice(), + .private = priv, }; } + pub fn encode(self: Info, a: std.mem.Allocator) Error!bencode.BValue { + var r: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } }; + errdefer r.deinit(a); + if (self.files.len > 1) { + @panic("TODO"); + } else if (self.files.len == 1) { + const f = self.files[0]; + const l = std.math.cast(i64, f.length) orelse return error.Malformatted; + try r.dict.dict.put("length", .{ .int = l }); + try r.dict.dict.put("name", .{ .string = .{ .string = f.name } }); + if (f.md5sum) |fm| { + try r.dict.dict.put("md5sum", .{ .string = .{ .string = fm } }); + } + } else { + return error.Malformatted; + } + + try r.dict.dict.put("pieces", .{ .string = .{ .string = self.pieces } }); + const pl = std.math.cast(i64, self.piece_length) orelse return error.Malformatted; + try r.dict.dict.put("piece length", .{ .int = pl }); + if (self.private) |pr| { + const pri: i64 = if (pr) 1 else 0; + try r.dict.dict.put("private", .{ .int = pri }); + } + + return r; + } + pub fn deinit(self: *Info, a: std.mem.Allocator) void { a.free(self.files); } + + const info_hash_len = 40; + pub fn hash(self: Info, a: std.mem.Allocator) ![info_hash_len]u8 { + var b = try self.encode(a); + defer b.deinit(a); + var sha1 = std.crypto.hash.Sha1.init(.{}); + var w = sha1.writer(); + try b.bencode(w); + var buf = [_]u8{0} ** info_hash_len; + _ = std.fmt.bufPrint(&buf, "{}", .{std.fmt.fmtSliceHexLower(&sha1.finalResult())}) catch unreachable; + return buf; + } }; info: Info, @@ -70,6 +130,14 @@ pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!MetaInfo { }; } +pub fn encode(self: MetaInfo, a: std.mem.Allocator) !bencode.BValue { + var d: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } }; + errdefer d.deinit(a); + try d.put("announce", .{ .string = .{ .string = self.announce } }); + try d.put("info", try self.info.encode(a)); + return d; +} + pub fn deinit(self: *MetaInfo, a: std.mem.Allocator) void { self.info.deinit(a); } @@ -85,3 +153,14 @@ test "sample" { try std.testing.expectEqual(@as(usize, 1), mi.info.files.len); try std.testing.expectEqualStrings("sample.txt", mi.info.files[0].name); } + +test "info hash" { + const a = std.testing.allocator; + const sample_str = @embedFile("sample.torrent"); + var b = try bencode.bdecodeBuf(a, sample_str); + defer b.deinit(a); + var mi = try MetaInfo.parse(a, b); + defer mi.deinit(a); + const hash = try mi.info.hash(a); + try std.testing.expectEqualStrings("d69f91e6b2ae4c542468d1073a71d4ea13879a7f", &hash); +} |