zbt

CLI BitTorrent client written in Zig
Log | Files | Refs | README

metainfo.zig (6921B)


      1 //! https://wiki.theory.org/BitTorrentSpecification#Metainfo_File_Structure
      2 const std = @import("std");
      3 const bencode = @import("bencode.zig");
      4 const MetaInfo = @This();
      5 pub const Error = (error{Malformatted} || std.mem.Allocator.Error);
      6 
      7 pub const Info = struct {
      8     pub const File = struct {
      9         name: []const u8,
     10         length: u64,
     11         path: []const u8,
     12         md5sum: ?[]const u8 = null,
     13     };
     14 
     15     piece_length: u32,
     16     pieces: []const u8,
     17     files: []File,
     18     private: ?bool,
     19 
     20     pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!Info {
     21         var d = b.asDict() catch return error.Malformatted;
     22         const pl = d.get("piece length") orelse return error.Malformatted;
     23         const pp = d.get("pieces") orelse return error.Malformatted;
     24         var files = std.ArrayList(File).init(a);
     25         defer files.deinit();
     26         if (d.get("files")) |f| {
     27             // multi-file mode
     28             const l = f.asList() catch return error.Malformatted;
     29             for (l.items) |fi| {
     30                 const fd = fi.asDict() catch return error.Malformatted;
     31                 const fin = fd.get("name") orelse return error.Malformatted;
     32                 const fl = fd.get("length") orelse return error.Malformatted;
     33                 const fp = fd.get("path") orelse return error.Malformatted;
     34                 var fm: ?[]const u8 = null;
     35                 if (fd.get("md5sum")) |md5| {
     36                     fm = md5.asString() catch return error.Malformatted;
     37                 }
     38                 try files.append(.{
     39                     .name = fin.asString() catch return error.Malformatted,
     40                     .length = fl.asInt(u64) catch return error.Malformatted,
     41                     .path = fp.asString() catch return error.Malformatted,
     42                     .md5sum = fm,
     43                 });
     44             }
     45         } else {
     46             // single-file mode
     47             const fin = d.get("name") orelse return error.Malformatted;
     48             const fl = d.get("length") orelse return error.Malformatted;
     49             var fm: ?[]const u8 = null;
     50             if (d.get("md5sum")) |md5| {
     51                 fm = md5.asString() catch return error.Malformatted;
     52             }
     53             try files.append(.{
     54                 .name = fin.asString() catch return error.Malformatted,
     55                 .length = fl.asInt(u64) catch return error.Malformatted,
     56                 .path = fin.asString() catch return error.Malformatted, // just use the file name as path
     57                 .md5sum = fm,
     58             });
     59         }
     60 
     61         var priv: ?bool = null;
     62         if (d.get("private")) |pr| {
     63             const pri = pr.asInt(u1) catch return error.Malformatted;
     64             priv = pri == 1;
     65         }
     66         // Validation....
     67         const ps = pp.asString() catch return error.Malformatted;
     68         if (ps.len % 20 != 0) return error.Malformatted;
     69 
     70         return .{
     71             .piece_length = pl.asInt(u32) catch return error.Malformatted,
     72             .pieces = ps,
     73             .files = try files.toOwnedSlice(),
     74             .private = priv,
     75         };
     76     }
     77 
     78     pub fn encode(self: Info, a: std.mem.Allocator) Error!bencode.BValue {
     79         var r: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
     80         errdefer r.deinit(a);
     81         if (self.files.len > 1) {
     82             @panic("TODO");
     83         } else if (self.files.len == 1) {
     84             const f = self.files[0];
     85             const l = std.math.cast(i64, f.length) orelse return error.Malformatted;
     86             try r.dict.dict.put("length", .{ .int = l });
     87             try r.dict.dict.put("name", .{ .string = .{ .string = f.name } });
     88             if (f.md5sum) |fm| {
     89                 try r.dict.dict.put("md5sum", .{ .string = .{ .string = fm } });
     90             }
     91         } else {
     92             return error.Malformatted;
     93         }
     94 
     95         try r.dict.dict.put("pieces", .{ .string = .{ .string = self.pieces } });
     96         const pl = std.math.cast(i64, self.piece_length) orelse return error.Malformatted;
     97         try r.dict.dict.put("piece length", .{ .int = pl });
     98         if (self.private) |pr| {
     99             const pri: i64 = if (pr) 1 else 0;
    100             try r.dict.dict.put("private", .{ .int = pri });
    101         }
    102 
    103         return r;
    104     }
    105 
    106     pub fn deinit(self: *Info, a: std.mem.Allocator) void {
    107         a.free(self.files);
    108     }
    109 
    110     const info_hash_len = std.crypto.hash.Sha1.digest_length;
    111     pub fn hash(self: Info, a: std.mem.Allocator) ![info_hash_len]u8 {
    112         var b = try self.encode(a);
    113         defer b.deinit(a);
    114         var sha1 = std.crypto.hash.Sha1.init(.{});
    115         var w = sha1.writer();
    116         try b.bencode(w);
    117         return sha1.finalResult();
    118     }
    119     pub fn pieceHash(self: Info, ix: usize) ?[20]u8 {
    120         const start = 20 * ix;
    121         if (start >= self.pieces.len) return null;
    122         var res: [20]u8 = undefined;
    123         @memcpy(&res, self.pieces[start .. start + 20]);
    124         return res;
    125     }
    126     pub fn pieceCount(self: Info) u32 {
    127         return @as(u32, @intCast(self.pieces.len)) / 20;
    128     }
    129 };
    130 
/// The parsed `info` dictionary; its allocations are released by `deinit`.
info: Info,
/// Value of the `announce` key. Not copied: it is the slice returned by the
/// bencode accessor, so the decoded data must outlive this struct.
announce: []const u8,
    133 
    134 pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!MetaInfo {
    135     // TODO diagnostics
    136     var d = b.asDict() catch return error.Malformatted;
    137     const i = d.get("info") orelse return error.Malformatted;
    138     const an = d.get("announce") orelse return error.Malformatted;
    139     return .{
    140         .info = try Info.parse(a, i),
    141         .announce = an.asString() catch return error.Malformatted,
    142     };
    143 }
    144 
    145 pub fn encode(self: MetaInfo, a: std.mem.Allocator) !bencode.BValue {
    146     var d: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
    147     errdefer d.deinit(a);
    148     try d.put("announce", .{ .string = .{ .string = self.announce } });
    149     try d.put("info", try self.info.encode(a));
    150     return d;
    151 }
    152 
    153 pub fn deinit(self: *MetaInfo, a: std.mem.Allocator) void {
    154     self.info.deinit(a);
    155 }
    156 
    157 test "sample" {
    158     const a = std.testing.allocator;
    159     const sample_str = @embedFile("sample.torrent");
    160     var b = try bencode.bdecodeBuf(a, sample_str);
    161     defer b.deinit(a);
    162     var mi = try MetaInfo.parse(a, b);
    163     defer mi.deinit(a);
    164     try std.testing.expectEqualStrings("http://bittorrent-test-tracker.codecrafters.io/announce", mi.announce);
    165     try std.testing.expectEqual(@as(usize, 1), mi.info.files.len);
    166     try std.testing.expectEqualStrings("sample.txt", mi.info.files[0].name);
    167 }
    168 
    169 test "info hash" {
    170     const a = std.testing.allocator;
    171     const sample_str = @embedFile("sample.torrent");
    172     var b = try bencode.bdecodeBuf(a, sample_str);
    173     defer b.deinit(a);
    174     var mi = try MetaInfo.parse(a, b);
    175     defer mi.deinit(a);
    176     const hash = try mi.info.hash(a);
    177     var hash_hex = try std.fmt.allocPrint(a, "{s}", .{std.fmt.fmtSliceHexLower(&hash)});
    178     defer a.free(hash_hex);
    179     try std.testing.expectEqualStrings("d69f91e6b2ae4c542468d1073a71d4ea13879a7f", hash_hex);
    180 }