aboutsummaryrefslogtreecommitdiff
path: root/src/metainfo.zig
blob: 04e8f64649361d25def9fc06eedfcc804bd1d833 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
//! https://wiki.theory.org/BitTorrentSpecification#Metainfo_File_Structure
const std = @import("std");
const bencode = @import("bencode.zig");
const MetaInfo = @This();
/// Error set used by the parsing and encoding routines in this file:
/// `error.Malformatted` for input that does not have the expected shape,
/// plus the allocator's errors.
pub const Error = (error{Malformatted} || std.mem.Allocator.Error);

/// The `info` dictionary of a metainfo file.
///
/// Single-file and multi-file torrents are both represented through `files`;
/// a single-file torrent simply has one entry.
///
/// `files` is allocated by `parse` and released by `deinit`. Every other
/// slice is whatever `asString` returned for the corresponding bencode value
/// (presumably borrowed from the decoded `BValue`, so the `BValue` should
/// outlive this struct; verify against `bencode.zig`).
pub const Info = struct {
    /// One file described by the torrent.
    pub const File = struct {
        /// Value of the `name` key.
        name: []const u8,
        /// Value of the `length` key.
        length: u64,
        /// Value of the `path` key; in single-file mode `parse` reuses `name`.
        path: []const u8,
        /// Value of the optional `md5sum` key, `null` when absent.
        md5sum: ?[]const u8 = null,
    };

    /// Length in bytes of one piece digest inside `pieces` (SHA-1, 20 bytes).
    const piece_hash_len = std.crypto.hash.Sha1.digest_length;

    /// Value of the `piece length` key.
    piece_length: u32,
    /// Value of the `pieces` key: the piece digests, concatenated.
    pieces: []const u8,
    /// Files of the torrent; owned by this struct (see `deinit`).
    files: []File,
    /// Value of the optional `private` key (`1` means true), `null` when absent.
    private: ?bool,

    /// Builds an `Info` from the bencoded `info` dictionary `b`.
    ///
    /// Returns `error.Malformatted` when `b` is not a dictionary, when a
    /// required key is missing or has the wrong type, or when the length of
    /// `pieces` is not a multiple of the digest size. Returns
    /// `error.OutOfMemory` when the file list cannot be allocated with `a`.
    /// On success the caller must call `deinit` with the same allocator.
    pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!Info {
        const d = b.asDict() catch return error.Malformatted;

        // Check the scalar entries first so that malformed input is rejected
        // before anything is allocated.
        const pl = d.get("piece length") orelse return error.Malformatted;
        const piece_length = pl.asInt(u32) catch return error.Malformatted;
        const pp = d.get("pieces") orelse return error.Malformatted;
        const ps = pp.asString() catch return error.Malformatted;
        if (ps.len % piece_hash_len != 0) return error.Malformatted;

        var priv: ?bool = null;
        if (d.get("private")) |pr| {
            const pri = pr.asInt(u1) catch return error.Malformatted;
            priv = pri == 1;
        }

        var files = std.ArrayList(File).init(a);
        // Also correct on the success path: `toOwnedSlice` leaves the list empty.
        defer files.deinit();
        if (d.get("files")) |f| {
            // multi-file mode
            // NOTE(review): the BitTorrent spec puts `name` on the enclosing
            // info dictionary and makes each file's `path` a list of strings.
            // This code expects a per-file `name` and a string `path`, so it
            // looks like spec-conformant multi-file torrents are rejected
            // here. Confirm before relying on multi-file support.
            const l = f.asList() catch return error.Malformatted;
            for (l.items) |fi| {
                const fd = fi.asDict() catch return error.Malformatted;
                const fin = fd.get("name") orelse return error.Malformatted;
                const fl = fd.get("length") orelse return error.Malformatted;
                const fp = fd.get("path") orelse return error.Malformatted;
                var fm: ?[]const u8 = null;
                if (fd.get("md5sum")) |md5| {
                    fm = md5.asString() catch return error.Malformatted;
                }
                try files.append(.{
                    .name = fin.asString() catch return error.Malformatted,
                    .length = fl.asInt(u64) catch return error.Malformatted,
                    .path = fp.asString() catch return error.Malformatted,
                    .md5sum = fm,
                });
            }
        } else {
            // single-file mode
            const fin = d.get("name") orelse return error.Malformatted;
            const fl = d.get("length") orelse return error.Malformatted;
            var fm: ?[]const u8 = null;
            if (d.get("md5sum")) |md5| {
                fm = md5.asString() catch return error.Malformatted;
            }
            const name = fin.asString() catch return error.Malformatted;
            try files.append(.{
                .name = name,
                .length = fl.asInt(u64) catch return error.Malformatted,
                .path = name, // just use the file name as path
                .md5sum = fm,
            });
        }

        return .{
            .piece_length = piece_length,
            .pieces = ps,
            .files = try files.toOwnedSlice(),
            .private = priv,
        };
    }

    /// Builds the bencode dictionary for this `Info`.
    ///
    /// Only the single-file layout is implemented: more than one file panics
    /// with "TODO", and an empty file list returns `error.Malformatted`, as
    /// does a file length that does not fit in an `i64`. The string values in
    /// the result reference this struct's slices rather than copies. The
    /// caller releases the result with `BValue.deinit(a)`.
    pub fn encode(self: Info, a: std.mem.Allocator) Error!bencode.BValue {
        var r: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
        errdefer r.deinit(a);
        if (self.files.len > 1) {
            @panic("TODO");
        } else if (self.files.len == 1) {
            const f = self.files[0];
            const l = std.math.cast(i64, f.length) orelse return error.Malformatted;
            try r.dict.dict.put("length", .{ .int = l });
            try r.dict.dict.put("name", .{ .string = .{ .string = f.name } });
            if (f.md5sum) |fm| {
                try r.dict.dict.put("md5sum", .{ .string = .{ .string = fm } });
            }
        } else {
            return error.Malformatted;
        }

        try r.dict.dict.put("pieces", .{ .string = .{ .string = self.pieces } });
        // A u32 always fits in an i64, so plain widening is enough here.
        const pl: i64 = self.piece_length;
        try r.dict.dict.put("piece length", .{ .int = pl });
        if (self.private) |pr| {
            const pri: i64 = if (pr) 1 else 0;
            try r.dict.dict.put("private", .{ .int = pri });
        }

        return r;
    }

    /// Frees the `files` slice with `a`. No other field is freed.
    pub fn deinit(self: *Info, a: std.mem.Allocator) void {
        a.free(self.files);
    }

    const info_hash_len = std.crypto.hash.Sha1.digest_length;

    /// Returns the SHA-1 digest of the bencoded form of `encode`'s result.
    ///
    /// `a` is only used for the temporary encoded value, which is released
    /// before returning. Fails with whatever `encode` or the bencode writer
    /// returns.
    pub fn hash(self: Info, a: std.mem.Allocator) ![info_hash_len]u8 {
        var b = try self.encode(a);
        defer b.deinit(a);
        var sha1 = std.crypto.hash.Sha1.init(.{});
        const w = sha1.writer();
        try b.bencode(w);
        return sha1.finalResult();
    }

    /// Returns the digest stored for piece `ix`, or `null` when `pieces` does
    /// not contain a complete digest at that index.
    pub fn pieceHash(self: Info, ix: usize) ?[piece_hash_len]u8 {
        // Compare against the number of *complete* digests: this rejects a
        // trailing partial digest and keeps `ix * piece_hash_len` from
        // overflowing for huge indices.
        if (ix >= self.pieces.len / piece_hash_len) return null;
        const start = ix * piece_hash_len;
        return self.pieces[start..][0..piece_hash_len].*;
    }

    /// Returns the number of complete piece digests stored in `pieces`.
    pub fn pieceCount(self: Info) u32 {
        // Divide before narrowing so only the count, not the byte length,
        // has to fit in a u32.
        return @as(u32, @intCast(self.pieces.len / piece_hash_len));
    }
};

/// Parsed contents of the top-level `info` key.
info: Info,
/// String value of the top-level `announce` key, as returned by `asString`.
announce: []const u8,

/// Builds a `MetaInfo` from the decoded top-level bencode value `b`.
///
/// Returns `error.Malformatted` when `b` is not a dictionary, when `info` or
/// `announce` is missing, when `announce` is not a string, or when
/// `Info.parse` rejects the `info` value. Returns `error.OutOfMemory` when
/// `Info.parse` cannot allocate with `a`. On success the caller must call
/// `deinit` with the same allocator.
pub fn parse(a: std.mem.Allocator, b: bencode.BValue) Error!MetaInfo {
    // TODO diagnostics
    const d = b.asDict() catch return error.Malformatted;
    const info_value = d.get("info") orelse return error.Malformatted;
    const announce_value = d.get("announce") orelse return error.Malformatted;
    // Resolve `announce` before `Info.parse` allocates, so that a bad
    // `announce` cannot leak the freshly allocated file list.
    const announce = announce_value.asString() catch return error.Malformatted;
    return .{
        .info = try Info.parse(a, info_value),
        .announce = announce,
    };
}

/// Builds the top-level bencode dictionary with the `announce` and `info`
/// entries.
///
/// The `announce` string in the result references `self.announce` rather
/// than a copy. Fails with the errors of `Info.encode` or with
/// `error.OutOfMemory`. The caller releases the result with
/// `BValue.deinit(a)`.
pub fn encode(self: MetaInfo, a: std.mem.Allocator) Error!bencode.BValue {
    var d: bencode.BValue = .{ .dict = .{ .dict = std.StringArrayHashMap(bencode.BValue).init(a) } };
    errdefer d.deinit(a);
    // Insert through the underlying map, the same way `Info.encode` does.
    try d.dict.dict.put("announce", .{ .string = .{ .string = self.announce } });

    var info_value = try self.info.encode(a);
    // Until the insertion succeeds `d` does not own `info_value`, so it has
    // to be released here if the insertion fails.
    d.dict.dict.put("info", info_value) catch |err| {
        info_value.deinit(a);
        return err;
    };
    return d;
}

/// Releases what `parse` allocated by delegating to `Info.deinit`.
/// `a` must be the allocator that was passed to `parse`.
pub fn deinit(self: *MetaInfo, a: std.mem.Allocator) void {
    Info.deinit(&self.info, a);
}

// Parses the embedded sample torrent and checks the announce URL and the
// single file entry.
test "sample" {
    const allocator = std.testing.allocator;

    var decoded = try bencode.bdecodeBuf(allocator, @embedFile("sample.torrent"));
    defer decoded.deinit(allocator);

    var meta = try MetaInfo.parse(allocator, decoded);
    defer meta.deinit(allocator);

    const files = meta.info.files;
    try std.testing.expectEqualStrings("http://bittorrent-test-tracker.codecrafters.io/announce", meta.announce);
    try std.testing.expectEqual(@as(usize, 1), files.len);
    try std.testing.expectEqualStrings("sample.txt", files[0].name);
}

// Checks the info hash of the embedded sample torrent against its known
// hex digest.
test "info hash" {
    const allocator = std.testing.allocator;

    var decoded = try bencode.bdecodeBuf(allocator, @embedFile("sample.torrent"));
    defer decoded.deinit(allocator);

    var meta = try MetaInfo.parse(allocator, decoded);
    defer meta.deinit(allocator);

    const digest = try meta.info.hash(allocator);

    // Two hex characters per digest byte; a stack buffer is enough.
    var hex_buf: [2 * digest.len]u8 = undefined;
    const digest_hex = try std.fmt.bufPrint(&hex_buf, "{s}", .{std.fmt.fmtSliceHexLower(&digest)});
    try std.testing.expectEqualStrings("d69f91e6b2ae4c542468d1073a71d4ea13879a7f", digest_hex);
}