//! BitTorrent peer-wire protocol: the connection handshake and the
//! length-prefixed messages this client reads and writes.
//!
//! Origin: root/src/peer_protocol.zig, blob 91ea349c7ec056e8a218a09d7db2b01ece3cf36a.
const std = @import("std");
const builtin = @import("builtin");

/// Read the next non-keepalive message from `reader` and decode it as `msgType`.
///
/// Every message is a big-endian u32 length prefix followed by that many bytes.
/// A zero length is a keepalive and is skipped. Otherwise the first payload byte is
/// the message id, which must equal `msgType.Tag`. The remaining `len - 1` bytes are
/// decoded by `msgType.read(a, len - 1, reader)`.
///
/// `msgType` must declare `pub const Tag: u8` and
/// `pub fn read(std.mem.Allocator, usize, reader) !msgType`.
///
/// Returns error.ProtocolError if the id does not match, or after 1000 consecutive
/// keepalives (so a peer cannot stall us forever with keepalive spam). Reader errors
/// (e.g. error.EndOfStream) and errors from `msgType.read` propagate.
///
/// NOTE(review): the length prefix is not bounded here. `msgType.read` may be asked to
/// handle, and possibly allocate, up to ~4 GiB on behalf of an untrusted peer.
pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType {
    // Upper bound on back-to-back keepalives before giving up on the peer.
    const max_keepalives = 1000;
    var keepalives: u32 = 0;
    while (keepalives < max_keepalives) {
        const len = try reader.readInt(u32, .big);
        if (len == 0) {
            keepalives += 1;
            continue;
        }
        const id = try reader.readByte();
        if (id != msgType.Tag) return error.ProtocolError;
        // `len` counted the id byte we just consumed.
        return try msgType.read(a, len - 1, reader);
    }
    return error.ProtocolError;
}

/// Read the next non-keepalive message from `reader` when several message types are acceptable.
///
/// `Expected` must be a tagged union. Each field type must declare `pub const Tag: u8`
/// and `pub fn read(std.mem.Allocator, usize, reader) !FieldType`.
///
/// The message id byte is compared against each field's `Tag` in declaration order. The
/// first match is decoded from the remaining `len - 1` bytes and returned as that union
/// field. If two fields share a `Tag`, the later one is never selected.
///
/// Zero-length messages (keepalives) are skipped. Returns error.ProtocolError when the
/// id matches no field, or after 1000 consecutive keepalives. Reader errors and errors
/// from the selected `read` propagate.
pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected {
    // Upper bound on back-to-back keepalives before giving up on the peer.
    const max_keepalives = 1000;
    var keepalives: u32 = 0;
    while (keepalives < max_keepalives) {
        const len = try reader.readInt(u32, .big);
        if (len == 0) {
            keepalives += 1;
            continue;
        }
        const id = try reader.readByte();
        // Unrolled at comptime: one id comparison per union field.
        // std.meta.fields works across Zig versions, unlike `@typeInfo(..).Union`.
        inline for (std.meta.fields(Expected)) |field| {
            const MsgType = field.type;
            if (id == MsgType.Tag) {
                return @unionInit(Expected, field.name, try MsgType.read(a, len - 1, reader));
            }
        }
        return error.ProtocolError;
    }
    return error.ProtocolError;
}

/// The fixed 68-byte handshake that opens a peer connection.
///
/// Wire layout:
/// - 1-byte name length (19)
/// - the 19-byte name "BitTorrent protocol"
/// - 8 reserved bytes
/// - the 20-byte info hash
/// - the 20-byte peer id
///
/// Unlike the other messages it has no length prefix or id byte. It is exchanged only
/// once per connection.
pub const Handshake = struct {
    info_hash: [20]u8,
    peer_id: [20]u8,

    const protocol_name = "BitTorrent protocol";

    /// Read exactly 68 bytes from `reader` and parse them as a handshake.
    /// Returns error.ProtocolError if the length byte or protocol name is wrong.
    /// Reader errors propagate, including error.EndOfStream on a short read.
    pub fn read(reader: anytype) !Handshake {
        var raw = std.mem.zeroes([68]u8);
        try reader.readNoEof(&raw);

        const header_ok = raw[0] == 19 and std.mem.eql(u8, raw[1..20], protocol_name);
        if (!header_ok) return error.ProtocolError;

        // raw[20..28] are the reserved bytes and are not validated. NOTE: a strict
        // all-zero check was deliberately left disabled, presumably because peers
        // set extension bits here.
        return .{
            .info_hash = raw[28..48].*,
            .peer_id = raw[48..68].*,
        };
    }

    /// Write the 68-byte handshake to `writer`, with all eight reserved bytes zero.
    pub fn write(self: Handshake, writer: anytype) !void {
        try writer.writeByte(protocol_name.len);
        try writer.writeAll(protocol_name);
        try writer.writeByteNTimes(0, 8);
        try writer.writeAll(&self.info_hash);
        try writer.writeAll(&self.peer_id);
    }
};


/// Unchoke message (id 1). It carries no payload; only decoding is provided.
pub const Unchoke = struct {
    pub const Tag: u8 = 1;

    /// Decode an Unchoke whose payload length (after the id byte) is `len`.
    /// Any payload is rejected with error.ProtocolError. Nothing is read from
    /// `reader`, and `a` is unused.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke {
        _ = reader;
        _ = a;
        if (len == 0) return .{};
        return error.ProtocolError;
    }
};

/// Interested message (id 2). It carries no payload.
pub const Interested = struct {
    pub const Tag: u8 = 2;

    /// Decode an Interested message whose payload length (after the id byte) is `len`.
    /// A non-empty payload is rejected with error.ProtocolError. Nothing is read from
    /// `reader`, and `a` is unused.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested {
        _ = reader;
        _ = a;
        if (len == 0) return .{};
        return error.ProtocolError;
    }

    /// Write the 5-byte frame: a big-endian length of 1 (just the id byte), then the id.
    pub fn write(writer: anytype) !void {
        const frame_len: u32 = 1;
        try writer.writeInt(u32, frame_len, .big);
        try writer.writeByte(Tag);
    }
};

/// Bitfield message (id 5). Its payload is currently discarded rather than decoded.
pub const Bitfield = struct {
    pub const Tag: u8 = 5;

    /// Consume and discard the `len`-byte payload from `reader`. `a` is unused.
    /// Reader errors from skipping propagate.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield {
        _ = a;
        // TODO: decode the bitfield and do something useful with it instead of discarding it.
        const payload_bytes: u64 = len;
        try reader.skipBytes(payload_bytes, .{});
        return Bitfield{};
    }
};

/// Request message (id 6). It carries a piece `index`, a byte offset `begin` and a
/// block `length`. Only encoding is provided.
pub const Request = struct {
    pub const Tag: u8 = 6;
    index: u32,
    begin: u32,
    length: u32,

    /// Write the 17-byte frame, all integers big-endian:
    /// - length prefix 13 (the id byte plus three u32s)
    /// - the id
    /// - `index`, `begin` and `length`
    pub fn write(self: Request, writer: anytype) !void {
        const frame_len: u32 = 13;
        const payload = [_]u32{ self.index, self.begin, self.length };
        try writer.writeInt(u32, frame_len, .big);
        try writer.writeByte(Tag);
        for (payload) |word| try writer.writeInt(u32, word, .big);
    }
};

/// Piece message (id 7). It carries a piece `index`, a byte offset `begin` and the
/// `block` data.
///
/// `block` is heap-allocated by `read`. The caller owns it and must release it with
/// `deinit`, using the same allocator.
pub const Piece = struct {
    pub const Tag: u8 = 7;
    index: u32,
    begin: u32,
    block: []const u8,

    /// Decode a Piece whose payload (after the id byte) is `len` bytes:
    /// - a big-endian u32 `index`
    /// - a big-endian u32 `begin`
    /// - `len - 8` bytes of block data, copied into a new allocation from `a`
    ///
    /// Returns error.ProtocolError, after logging, if `len <= 8` (no block data).
    /// Allocation and reader errors propagate. The block buffer is freed if reading
    /// it fails.
    ///
    /// NOTE(review): `len` comes from the peer and is not capped here, so a hostile
    /// peer can request an allocation of up to ~4 GiB.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece {
        if (len <= 8) {
            std.log.err("Piece#read len {}", .{len});
            return error.ProtocolError;
        }
        const index = try reader.readInt(u32, .big);
        const begin = try reader.readInt(u32, .big);
        const block = try a.alloc(u8, len - 8);
        // Only runs if readNoEof fails; on success ownership moves to the caller.
        errdefer a.free(block);
        try reader.readNoEof(block);
        return .{
            .index = index,
            .begin = begin,
            .block = block,
        };
    }

    /// Free the block allocated by `read`. `a` must be the allocator passed to `read`.
    /// `self` is poisoned afterwards, so use-after-deinit is caught in safe builds.
    pub fn deinit(self: *Piece, a: std.mem.Allocator) void {
        a.free(self.block);
        self.* = undefined;
    }
};

test "read any" {
    const a = std.testing.allocator;
    const T = union(enum) { u: Unchoke, i: Interested, b: Bitfield };

    // A bare Unchoke frame decodes to the `u` field.
    var fbs = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, Unchoke.Tag });
    const msg: T = try readAnyMessage(a, fbs.reader(), T);
    try std.testing.expect(msg == .u);

    // A keepalive is skipped, and the Bitfield payload is fully consumed.
    var with_keepalive = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 0, 0, 0, 0, 3, Bitfield.Tag, 0xFF, 0x80 });
    const bf: T = try readAnyMessage(a, with_keepalive.reader(), T);
    try std.testing.expect(bf == .b);
    try std.testing.expectEqual(with_keepalive.buffer.len, with_keepalive.pos);

    // An id that is not one of T's fields (0 here) is rejected.
    var unexpected = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, 0 });
    try std.testing.expectError(error.ProtocolError, readAnyMessage(a, unexpected.reader(), T));

    // readMessage accepts only the single requested type.
    var single = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, Interested.Tag });
    _ = try readMessage(a, single.reader(), Interested);
    var mismatch = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, Unchoke.Tag });
    try std.testing.expectError(error.ProtocolError, readMessage(a, mismatch.reader(), Interested));
}