aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-11-11 20:47:46 +0000
committerMartin Ashby <martin@ashbysoft.com>2023-11-11 20:47:46 +0000
commite92459156c2f74de648566ca7acde3833de33425 (patch)
tree59d0bb42d23db0d295a0ef0a97e31e71b8f592cd /src
parentdff4234bf2d957a0328ff4f3dc4f9bba1fbeffd4 (diff)
downloadzbt-e92459156c2f74de648566ca7acde3833de33425.tar.gz
zbt-e92459156c2f74de648566ca7acde3833de33425.tar.bz2
zbt-e92459156c2f74de648566ca7acde3833de33425.tar.xz
zbt-e92459156c2f74de648566ca7acde3833de33425.zip
Include ownership info on BValue.
It's a bit cumbersome, but it caters for the two situations: a. parsing from a stream where strings must be copied into the structure and owned, and b. constructing from application code where strings must _not_ be freed. Either way, the base structures (array list, hash map) are always owned and must be freed. Add method and test for info hash
Diffstat (limited to 'src')
-rw-r--r--src/bencode.zig158
-rw-r--r--src/metainfo.zig79
2 files changed, 168 insertions, 69 deletions
diff --git a/src/bencode.zig b/src/bencode.zig
index 60a3746..3e83121 100644
--- a/src/bencode.zig
+++ b/src/bencode.zig
@@ -7,13 +7,24 @@ const AnyWriter = @import("anywriter.zig");
pub const Error = error.Malformatted || std.io.AnyReader.Error;
// All content is owned by the BValue and must be freed with deinit.
+// hmmm, this gets a bit awkward from the _writing_ side of things.
+// What should I do? Optionally owned? Yeah let's do that.
pub const BValue = union(enum) {
- string: []const u8,
int: i64,
- list: std.ArrayList(BValue),
- dict: std.StringArrayHashMap(BValue),
-
- pub fn bencode(self: *const BValue, base_writer: anytype) !void {
+ string: struct {
+ string: []const u8,
+ owned: bool = false,
+ },
+ list: struct {
+ list: std.ArrayList(BValue),
+ valuesOwned: bool = false,
+ },
+ dict: struct {
+ dict: std.StringArrayHashMap(BValue),
+ keysAndValuesOwned: bool = false,
+ },
+
+ pub fn bencode(self: *BValue, base_writer: anytype) !void {
var wrap = AnyWriter.wrapper(base_writer);
var writer = wrap.any();
try self.bencodeInner(writer);
@@ -21,60 +32,68 @@ pub const BValue = union(enum) {
// Note: uses defined types only to avoid trying to recursively evaulate this function
// at compile time, otherwise we run into https://github.com/ziglang/zig/issues/13724
- fn bencodeInner(self: *const BValue, writer: AnyWriter) !void {
+ fn bencodeInner(self: *BValue, writer: AnyWriter) !void {
switch (self.*) {
+ .int => |i| {
+ try std.fmt.format(writer, "i{}e", .{i});
+ },
.string => |s| {
- try std.fmt.format(writer, "{}:{s}", .{ s.len, s });
+ try std.fmt.format(writer, "{}:{s}", .{ s.string.len, s.string });
},
- .list => |l| {
+ .list => |*l| {
try writer.writeByte('l');
- for (l.items) |i| {
+ for (l.list.items) |*i| {
try i.bencodeInner(writer);
}
try writer.writeByte('e');
},
- .dict => |d| {
+ .dict => |*d| {
+ // Keys must be strings and appear in sorted order (sorted as raw strings, not alphanumerics). The strings should be compared using a binary comparison, not a culture-specific "natural" comparison.
+ const Ctx = struct {
+ keys: [][]const u8,
+ pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
+ const a_k = ctx.keys[a_index];
+ const b_k = ctx.keys[b_index];
+ return std.mem.order(u8, a_k, b_k) == .lt;
+ }
+ };
+ var dict: *std.StringArrayHashMap(BValue) = &d.dict;
+ dict.sort(Ctx{ .keys = dict.keys() });
+
try writer.writeByte('d');
- var it = d.iterator();
+ var it = dict.iterator();
while (it.next()) |entry| {
try std.fmt.format(writer, "{}:{s}", .{ entry.key_ptr.*.len, entry.key_ptr.* });
try entry.value_ptr.*.bencodeInner(writer);
}
try writer.writeByte('e');
},
- .int => |i| {
- try std.fmt.format(writer, "i{}e", .{i});
- },
}
}
pub fn deinit(self: *BValue, a: std.mem.Allocator) void {
switch (self.*) {
+ .int => {},
.string => |s| {
- a.free(s);
+ if (s.owned)
+ a.free(s.string);
},
.list => |*l| {
- for (l.items) |*i| {
- i.deinit(a);
- }
- l.deinit();
+ if (l.valuesOwned)
+ for (l.list.items) |*i|
+ i.deinit(a);
+ l.list.deinit();
},
.dict => |*d| {
- var it = d.iterator();
- while (it.next()) |entry| {
- a.free(entry.key_ptr.*);
- entry.value_ptr.*.deinit(a);
+ if (d.keysAndValuesOwned) {
+ var it = d.dict.iterator();
+ while (it.next()) |entry| {
+ a.free(entry.key_ptr.*);
+ entry.value_ptr.*.deinit(a);
+ }
}
- d.deinit();
+ d.dict.deinit();
},
- .int => {},
- }
- }
-
- pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) {
- switch (self) {
- .dict => |d| return d,
- else => return error.WrongType,
}
}
@@ -89,14 +108,21 @@ pub const BValue = union(enum) {
pub fn asString(self: BValue) ![]const u8 {
switch (self) {
- .string => |s| return s,
+ .string => |s| return s.string,
else => return error.WrongType,
}
}
pub fn asList(self: BValue) !std.ArrayList(BValue) {
switch (self) {
- .list => |l| return l,
+ .list => |l| return l.list,
+ else => return error.WrongType,
+ }
+ }
+
+ pub fn asDict(self: BValue) !std.StringArrayHashMap(BValue) {
+ switch (self) {
+ .dict => |d| return d.dict,
else => return error.WrongType,
}
}
@@ -125,7 +151,10 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal
var byte = try reader.readByte();
if (std.ascii.isDigit(byte)) {
try peekStream.putBackByte(byte);
- return .{ .string = try readString(a, peekStream) };
+ return .{ .string = .{
+ .owned = true,
+ .string = try readString(a, peekStream),
+ } };
} else {
switch (byte) {
'i' => {
@@ -136,44 +165,35 @@ fn bdecodeInner(a: std.mem.Allocator, peekStream: *PeekStream, depth: u32) !BVal
return .{ .int = i };
},
'l' => {
- var list = std.ArrayList(BValue).init(a);
- errdefer {
- for (list.items) |*i| {
- i.deinit(a);
- }
- list.deinit();
- }
+ var r: BValue = .{ .list = .{
+ .valuesOwned = true,
+ .list = std.ArrayList(BValue).init(a),
+ } };
+ errdefer r.deinit(a);
while (true) {
const b2 = try reader.readByte();
if (b2 == 'e') break;
try peekStream.putBackByte(b2);
var val = try bdecodeInner(a, peekStream, depth + 1);
errdefer val.deinit(a);
- try list.append(val);
+ try r.list.list.append(val);
}
- return .{ .list = list };
+ return r;
},
'd' => {
- var dict = std.StringArrayHashMap(BValue).init(a);
- errdefer {
- var it = dict.iterator();
- while (it.next()) |entry| {
- a.free(entry.key_ptr.*);
- entry.value_ptr.*.deinit(a);
- }
- dict.deinit();
- }
- lp: while (true) {
+ var r: BValue = .{ .dict = .{ .keysAndValuesOwned = true, .dict = std.StringArrayHashMap(BValue).init(a) } };
+ errdefer r.deinit(a);
+ while (true) {
const b2 = try reader.readByte();
- if (b2 == 'e') break :lp;
+ if (b2 == 'e') break;
try peekStream.putBackByte(b2);
var key = try readString(a, peekStream);
errdefer a.free(key);
var val = try bdecode(a, reader);
errdefer val.deinit(a);
- try dict.put(key, val);
+ try r.dict.dict.put(key, val);
}
- return .{ .dict = dict };
+ return r;
},
else => return error.Malformatted, // TODO diagnostics
}
@@ -253,7 +273,7 @@ test "bdecode string" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "5:hello");
defer bval.deinit(a);
- try std.testing.expectEqualDeep(BValue{ .string = "hello" }, bval);
+ try std.testing.expectEqualDeep(BValue{ .string = .{ .owned = true, .string = "hello" } }, bval);
}
test "bdecode string too short" {
var a = std.testing.allocator;
@@ -264,9 +284,9 @@ test "bdecode list" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "l5:hello5:worlde");
defer bval.deinit(a);
- try std.testing.expectEqual(@as(usize, 2), bval.list.items.len);
- try std.testing.expectEqualStrings("hello", bval.list.items[0].string);
- try std.testing.expectEqualStrings("world", bval.list.items[1].string);
+ try std.testing.expectEqual(@as(usize, 2), bval.list.list.items.len);
+ try std.testing.expectEqualStrings("hello", bval.list.list.items[0].string.string);
+ try std.testing.expectEqualStrings("world", bval.list.list.items[1].string.string);
}
test "invalid list" {
@@ -278,8 +298,8 @@ test "dict" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "d5:hello5:worlde");
defer bval.deinit(a);
- var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualStrings("world", v.string);
+ var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string.string);
}
test "invalid dict no value" {
@@ -296,12 +316,12 @@ test "nested structure" {
var a = std.testing.allocator;
var bval = try bdecodeBuf(a, "d5:hello5:world2:hili123ei456el4:nesteee");
defer bval.deinit(a);
- var v = bval.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualStrings("world", v.string);
- var v2 = bval.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
- try std.testing.expectEqualDeep(v2.*.list.items[0], BValue{ .int = 123 });
- try std.testing.expectEqualDeep(v2.*.list.items[1], BValue{ .int = 456 });
- try std.testing.expectEqualStrings("nest", v2.*.list.items[2].list.items[0].string);
+ var v = bval.dict.dict.getPtr("hello") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualStrings("world", v.string.string);
+ var v2 = bval.dict.dict.getPtr("hi") orelse return error.TestExpectedNotNull;
+ try std.testing.expectEqualDeep(v2.*.list.list.items[0], BValue{ .int = 123 });
+ try std.testing.expectEqualDeep(v2.*.list.list.items[1], BValue{ .int = 456 });
+ try std.testing.expectEqualStrings("nest", v2.*.list.list.items[2].list.list.items[0].string.string);
}
test "round trip" {
diff --git a/src/metainfo.zig b/src/metainfo.zig
index a202b0e..31b14b7 100644
--- a/src/metainfo.zig
+++ b/src/metainfo.zig
@@ -9,10 +9,14 @@ pub const Info = struct {
name: []const u8,
length: u64,
path: []const u8,
+ md5sum: ?[]const u8 = null,
};
+
piece_length: u64,
pieces: []const u8,
files: []File,
+ private: ?bool,
+
pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!Info {
var d = b.asDict() catch return error.Malformatted;
const pl = d.get("piece length") orelse return error.Malformatted;
@@ -27,33 +31,89 @@ pub const Info = struct {
const fin = fd.get("name") orelse return error.Malformatted;
const fl = fd.get("length") orelse return error.Malformatted;
const fp = fd.get("path") orelse return error.Malformatted;
+ var fm: ?[]const u8 = null;
+ if (fd.get("md5sum")) |md5| {
+ fm = md5.asString() catch return error.Malformatted;
+ }
try files.append(.{
.name = fin.asString() catch return error.Malformatted,
.length = fl.asInt(u64) catch return error.Malformatted,
.path = fp.asString() catch return error.Malformatted,
+ .md5sum = fm,
});
}
} else {
// single-file mode
const fin = d.get("name") orelse return error.Malformatted;
const fl = d.get("length") orelse return error.Malformatted;
+ var fm: ?[]const u8 = null;
+ if (d.get("md5sum")) |md5| {
+ fm = md5.asString() catch return error.Malformatted;
+ }
try files.append(.{
.name = fin.asString() catch return error.Malformatted,
.length = fl.asInt(u64) catch return error.Malformatted,
.path = fin.asString() catch return error.Malformatted, // just use the file name as path
+ .md5sum = fm,
});
}
+ var priv: ?bool = null;
+ if (d.get("private")) |pr| {
+ const pri = pr.asInt(u1) catch return error.Malformatted;
+ priv = pri == 1;
+ }
return .{
.piece_length = pl.asInt(u64) catch return error.Malformatted,
.pieces = pp.asString() catch return error.Malformatted,
.files = try files.toOwnedSlice(),
+ .private = priv,
};
}
+ pub fn encode(self: Info, a: std.mem.Allocator) Error!bencode.BValue {
+ var r: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
+ errdefer r.deinit(a);
+ if (self.files.len > 1) {
+ @panic("TODO");
+ } else if (self.files.len == 1) {
+ const f = self.files[0];
+ const l = std.math.cast(i64, f.length) orelse return error.Malformatted;
+ try r.dict.dict.put("length", .{ .int = l });
+ try r.dict.dict.put("name", .{ .string = .{ .string = f.name } });
+ if (f.md5sum) |fm| {
+ try r.dict.dict.put("md5sum", .{ .string = .{ .string = fm } });
+ }
+ } else {
+ return error.Malformatted;
+ }
+
+ try r.dict.dict.put("pieces", .{ .string = .{ .string = self.pieces } });
+ const pl = std.math.cast(i64, self.piece_length) orelse return error.Malformatted;
+ try r.dict.dict.put("piece length", .{ .int = pl });
+ if (self.private) |pr| {
+ const pri: i64 = if (pr) 1 else 0;
+ try r.dict.dict.put("private", .{ .int = pri });
+ }
+
+ return r;
+ }
+
pub fn deinit(self: *Info, a: std.mem.Allocator) void {
a.free(self.files);
}
+
+ const info_hash_len = 40;
+ pub fn hash(self: Info, a: std.mem.Allocator) ![info_hash_len]u8 {
+ var b = try self.encode(a);
+ defer b.deinit(a);
+ var sha1 = std.crypto.hash.Sha1.init(.{});
+ var w = sha1.writer();
+ try b.bencode(w);
+ var buf = [_]u8{0} ** info_hash_len;
+ _ = std.fmt.bufPrint(&buf, "{}", .{std.fmt.fmtSliceHexLower(&sha1.finalResult())}) catch unreachable;
+ return buf;
+ }
};
info: Info,
@@ -70,6 +130,14 @@ pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!MetaInfo {
};
}
+pub fn encode(self: MetaInfo, a: std.mem.Allocator) !bencode.BValue {
+ var d: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
+ errdefer d.deinit(a);
+ try d.put("announce", .{ .string = .{ .string = self.announce } });
+ try d.put("info", try self.info.encode(a));
+ return d;
+}
+
pub fn deinit(self: *MetaInfo, a: std.mem.Allocator) void {
self.info.deinit(a);
}
@@ -85,3 +153,14 @@ test "sample" {
try std.testing.expectEqual(@as(usize, 1), mi.info.files.len);
try std.testing.expectEqualStrings("sample.txt", mi.info.files[0].name);
}
+
+test "info hash" {
+ const a = std.testing.allocator;
+ const sample_str = @embedFile("sample.torrent");
+ var b = try bencode.bdecodeBuf(a, sample_str);
+ defer b.deinit(a);
+ var mi = try MetaInfo.parse(a, b);
+ defer mi.deinit(a);
+ const hash = try mi.info.hash(a);
+ try std.testing.expectEqualStrings("d69f91e6b2ae4c542468d1073a71d4ea13879a7f", &hash);
+}