peer_protocol.zig (5021B)
const std = @import("std");
const builtin = @import("builtin");

/// Number of consecutive keepalive frames tolerated before the peer is
/// treated as misbehaving.
const max_keepalives: u32 = 1000;

/// Tag byte and remaining payload size of one length-prefixed frame.
const FrameHeader = struct {
    tag: u8,
    /// Bytes left in the frame after the tag byte.
    payload_len: u32,
};

/// Reads frames until one that is not a keepalive arrives and returns its header.
///
/// Every frame starts with a big-endian u32 length. A length of zero is a
/// keepalive and is skipped; any other frame starts with a one-byte tag.
/// Returns `error.ProtocolError` after `max_keepalives` keepalives in a row,
/// or whatever error the reader reports.
fn readFrameHeader(reader: anytype) !FrameHeader {
    var keepalives: u32 = 0;
    while (keepalives < max_keepalives) : (keepalives += 1) {
        const frame_len = try reader.readInt(u32, .big);
        if (frame_len == 0) continue; // keepalive: nothing follows the length
        const tag = try reader.readByte();
        // frame_len is non-zero here, so the subtraction cannot underflow.
        return .{ .tag = tag, .payload_len = frame_len - 1 };
    }
    return error.ProtocolError;
}

/// Reads the next non-keepalive message and decodes it as `msgType`.
///
/// `msgType` must declare `Tag: u8` and `read(allocator, len, reader)`.
/// Returns `error.ProtocolError` if the frame's tag differs from
/// `msgType.Tag` (the payload is left unread in that case) or if the peer
/// sends too many keepalives; reader and `msgType.read` errors are propagated.
pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType {
    const header = try readFrameHeader(reader);
    if (header.tag != msgType.Tag) return error.ProtocolError;
    return try msgType.read(a, header.payload_len, reader);
}

/// Reads the next non-keepalive message when several message types are acceptable.
///
/// `Expected` should be a tagged union whose field types are the message
/// types you are expecting; the field whose type's `Tag` matches the frame's
/// tag byte is decoded and returned as the active union member.
/// Returns `error.ProtocolError` if no field matches or if the peer sends too
/// many keepalives; reader and message `read` errors are propagated.
pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected {
    const header = try readFrameHeader(reader);
    // Unrolled at comptime: one tag comparison per union field.
    inline for (std.meta.fields(Expected)) |field| {
        if (field.type.Tag == header.tag) {
            return @unionInit(Expected, field.name, try field.type.read(a, header.payload_len, reader));
        }
    }
    return error.ProtocolError;
}

// Handshake message has a different structure to the rest. And it's only read once per connection.
/// Connection handshake. It has a fixed 68-byte layout with no length prefix
/// and no tag byte, so it is read and written directly rather than through
/// the generic message readers.
pub const Handshake = struct {
    info_hash: [20]u8,
    peer_id: [20]u8,

    /// Protocol identifier carried in bytes 1..20 of the handshake.
    const protocol_name = "BitTorrent protocol";
    /// 1 length byte + 19 name bytes + 8 reserved + 20 info hash + 20 peer id.
    const wire_len = 68;

    /// Reads exactly 68 bytes and extracts the info hash and peer id.
    /// Returns `error.ProtocolError` if the name length byte or the protocol
    /// name does not match; reader errors (including EOF) are propagated.
    pub fn read(reader: anytype) !Handshake {
        var msg: [wire_len]u8 = undefined; // fully overwritten by readNoEof
        try reader.readNoEof(&msg);
        if (msg[0] != protocol_name.len) return error.ProtocolError;
        if (!std.mem.eql(u8, msg[1..20], protocol_name)) return error.ProtocolError;
        // The 8 reserved bytes (msg[20..28]) are not validated; a strict
        // all-zero check was previously disabled here.
        return .{
            .info_hash = msg[28..48].*,
            .peer_id = msg[48..68].*,
        };
    }

    /// Writes the 68-byte handshake with all reserved bytes set to zero.
    pub fn write(self: Handshake, writer: anytype) !void {
        try writer.writeByte(protocol_name.len);
        try writer.writeAll(protocol_name);
        try writer.writeByteNTimes(0, 8);
        try writer.writeAll(&self.info_hash);
        try writer.writeAll(&self.peer_id);
    }
};

/// Unchoke message (tag 1). Carries no payload.
pub const Unchoke = struct {
    pub const Tag: u8 = 1;

    /// `len` is the payload size after the tag byte and must be zero.
    /// Nothing is read from `reader` and nothing is allocated.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke {
        _ = a;
        _ = reader;
        if (len != 0) return error.ProtocolError;
        return .{};
    }
};

/// Interested message (tag 2). Carries no payload.
pub const Interested = struct {
    pub const Tag: u8 = 2;

    /// `len` is the payload size after the tag byte and must be zero.
    /// Nothing is read from `reader` and nothing is allocated.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested {
        _ = a;
        _ = reader;
        if (len != 0) return error.ProtocolError;
        return .{};
    }

    /// Writes the complete frame: length prefix 1 followed by the tag byte.
    pub fn write(writer: anytype) !void {
        try writer.writeInt(u32, 1, .big);
        try writer.writeByte(Tag);
    }
};

/// Bitfield message (tag 5). The payload is currently discarded.
pub const Bitfield = struct {
    pub const Tag: u8 = 5;

    /// Skips `len` payload bytes so the stream stays aligned on the next frame.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield {
        // TODO actually read this message and do something useful.
        _ = a;
        try reader.skipBytes(len, .{});
        return .{};
    }
};

/// Request message (tag 6): three big-endian u32 fields. Write-only here.
pub const Request = struct {
    pub const Tag: u8 = 6;
    index: u32,
    begin: u32,
    length: u32,

    /// Writes the complete frame: length prefix 13 (tag byte + three u32
    /// fields), the tag, then index, begin and length.
    pub fn write(self: Request, writer: anytype) !void {
        try writer.writeInt(u32, 13, .big);
        try writer.writeByte(Tag);
        try writer.writeInt(u32, self.index, .big);
        try writer.writeInt(u32, self.begin, .big);
        try writer.writeInt(u32, self.length, .big);
    }
};

/// Piece message (tag 7): index, begin, then the block bytes.
/// `block` is heap-allocated by `read`; release it with `deinit`.
pub const Piece = struct {
    pub const Tag: u8 = 7;
    index: u32,
    begin: u32,
    block: []const u8,

    /// Decodes a piece payload of `len` bytes (the size after the tag byte).
    /// The first 8 bytes are index and begin; the remaining `len - 8` bytes
    /// are copied into a buffer allocated from `a`, which the caller owns.
    /// Returns `error.ProtocolError` (and logs) when `len <= 8`, i.e. when
    /// there would be no block data.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece {
        if (len <= 8) {
            std.log.err("Piece#read len {}", .{len});
            return error.ProtocolError;
        }
        const index = try reader.readInt(u32, .big);
        const begin = try reader.readInt(u32, .big);
        // NOTE(review): the allocation size comes straight from the peer's
        // length prefix; callers may want to bound it against the block size
        // they requested.
        const block = try a.alloc(u8, len - 8);
        errdefer a.free(block);
        try reader.readNoEof(block);
        return .{
            .index = index,
            .begin = begin,
            .block = block,
        };
    }

    /// Frees `block`. `a` must be the allocator that was passed to `read`.
    pub fn deinit(self: *Piece, a: std.mem.Allocator) void {
        a.free(self.block);
    }
};

test "read any" {
    const a = std.testing.allocator;
    var fbs = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, Unchoke.Tag });
    const r = fbs.reader();
    const T = union(enum) { u: Unchoke, i: Interested, b: Bitfield };
    const msg: T = try readAnyMessage(a, r, T);
    try std.testing.expect(msg == .u);
}