const std = @import("std");
const builtin = @import("builtin");

/// Number of consecutive keepalive (zero-length) frames tolerated while waiting
/// for a real message. Reaching this many yields error.ProtocolError.
const max_consecutive_keepalives: u32 = 1000;

/// Length prefix and message id of the first non-keepalive frame.
const MessageHeader = struct {
    /// Payload bytes that follow the id byte (frame length minus one).
    payload_len: usize,
    /// The single-byte message id.
    tag: u8,
};

/// Read big-endian u32 length-prefixed frames from `reader`, discarding
/// keepalives (length 0), and return the header of the first real message.
/// All messages except keepalive start with a single byte message type.
/// Returns error.ProtocolError after `max_consecutive_keepalives` keepalives
/// (so a peer spamming keepalives cannot stall us forever), and propagates
/// any reader error (e.g. error.EndOfStream).
fn readMessageHeader(reader: anytype) !MessageHeader {
    var keepalives: u32 = 0;
    while (keepalives < max_consecutive_keepalives) : (keepalives += 1) {
        const len = try reader.readInt(u32, .big);
        if (len == 0) continue; // keepalive: nothing follows the prefix
        const tag = try reader.readByte();
        return .{ .payload_len = len - 1, .tag = tag };
    }
    return error.ProtocolError;
}

/// Read the next non-keepalive message from `reader` and decode it as `msgType`.
///
/// `msgType` must declare `pub const Tag: u8` and
/// `pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !msgType`,
/// where `len` is the payload length excluding the id byte.
/// Returns error.ProtocolError if the message id is not `msgType.Tag` (the
/// payload is left unread in that case) or if too many keepalives arrive.
pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType {
    const header = try readMessageHeader(reader);
    if (header.tag != msgType.Tag) return error.ProtocolError;
    return try msgType.read(a, header.payload_len, reader);
}

/// When you're expecting several possible messages.
/// `Expected` should be a tagged union whose field types are message types
/// (each with `Tag` and `read` as described for `readMessage`). The field whose
/// type's `Tag` matches the incoming id is decoded and returned.
/// Returns error.ProtocolError for an id none of the fields expects (payload
/// left unread) or if too many keepalives arrive.
pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected {
    const header = try readMessageHeader(reader);
    inline for (@typeInfo(Expected).Union.fields) |field| {
        const msgType = field.type;
        if (msgType.Tag == header.tag) {
            return @unionInit(Expected, field.name, try msgType.read(a, header.payload_len, reader));
        }
    }
    return error.ProtocolError;
}

// Handshake message has a different structure to the rest. And it's only read once per connection.
/// The fixed 68-byte handshake: pstrlen byte (19), the protocol string
/// "BitTorrent protocol", 8 reserved bytes, 20-byte info_hash, 20-byte peer_id.
pub const Handshake = struct {
    info_hash: [20]u8,
    peer_id: [20]u8,

    /// Protocol identifier shared by `read` and `write`.
    const pstr = "BitTorrent protocol";

    /// Read exactly 68 bytes from `reader` and decode them.
    /// Returns error.ProtocolError if the length byte or protocol string is
    /// wrong; propagates reader errors (e.g. error.EndOfStream).
    pub fn read(reader: anytype) !Handshake {
        var msg: [68]u8 = undefined;
        try reader.readNoEof(&msg);
        if (msg[0] != pstr.len) return error.ProtocolError;
        if (!std.mem.eql(u8, msg[1..20], pstr)) return error.ProtocolError;
        // Bytes 20..28 are the reserved bytes; they are intentionally not validated.
        return .{
            .info_hash = msg[28..48].*,
            .peer_id = msg[48..68].*,
        };
    }

    /// Write the 68-byte handshake with all reserved bytes set to zero.
    pub fn write(self: Handshake, writer: anytype) !void {
        try writer.writeByte(pstr.len);
        try writer.writeAll(pstr);
        try writer.writeByteNTimes(0, 8);
        try writer.writeAll(&self.info_hash);
        try writer.writeAll(&self.peer_id);
    }
};

/// Unchoke (id 1): carries no payload.
pub const Unchoke = struct {
    pub const Tag: u8 = 1;

    /// Decode an Unchoke; `len` is the payload length and must be 0.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke {
        _ = a;
        _ = reader;
        if (len != 0) return error.ProtocolError;
        return .{};
    }
};

/// Interested (id 2): carries no payload.
pub const Interested = struct {
    pub const Tag: u8 = 2;

    /// Decode an Interested; `len` is the payload length and must be 0.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested {
        _ = a;
        _ = reader;
        if (len != 0) return error.ProtocolError;
        return .{};
    }

    /// Write the 5-byte Interested frame (length prefix 1, then the id).
    pub fn write(writer: anytype) !void {
        try writer.writeInt(u32, 1, .big);
        try writer.writeInt(u8, Tag, .big);
    }
};

/// Bitfield (id 5). The payload is currently discarded.
pub const Bitfield = struct {
    pub const Tag: u8 = 5;

    /// Skip `len` payload bytes and return an empty Bitfield.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield {
        // TODO actually read this message and do something useful.
        _ = a;
        try reader.skipBytes(len, .{});
        return .{};
    }
};

/// Request (id 6) for `length` bytes at offset `begin` of piece `index`.
pub const Request = struct {
    pub const Tag: u8 = 6;
    index: u32,
    begin: u32,
    length: u32,

    /// Write the 17-byte Request frame (length prefix 13: id + three u32s).
    pub fn write(self: Request, writer: anytype) !void {
        try writer.writeInt(u32, 13, .big);
        try writer.writeInt(u8, Tag, .big);
        try writer.writeInt(u32, self.index, .big);
        try writer.writeInt(u32, self.begin, .big);
        try writer.writeInt(u32, self.length, .big);
    }
};

/// Piece (id 7): a block of data at offset `begin` within piece `index`.
/// `block` is allocated by `read`; release it with `deinit`.
pub const Piece = struct {
    pub const Tag: u8 = 7;
    index: u32,
    begin: u32,
    block: []const u8,

    /// Decode a Piece whose payload is `len` bytes: index (u32), begin (u32),
    /// then `len - 8` block bytes. A payload of 8 bytes or fewer (no block
    /// data) is rejected with error.ProtocolError.
    /// Caller owns the returned `block` and frees it via `deinit`.
    pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece {
        if (len <= 8) {
            std.log.err("Piece#read len {}", .{len});
            return error.ProtocolError;
        }
        const index = try reader.readInt(u32, .big);
        const begin = try reader.readInt(u32, .big);
        // NOTE(review): `len` comes from the peer's length prefix, so this
        // allocation size is peer-controlled (up to ~4 GiB). Consider capping
        // it at the largest block size this client ever requests.
        const block = try a.alloc(u8, len - 8);
        errdefer a.free(block);
        try reader.readNoEof(block);
        return .{
            .index = index,
            .begin = begin,
            .block = block,
        };
    }

    /// Free the block allocated by `read` using the same allocator.
    pub fn deinit(self: *Piece, a: std.mem.Allocator) void {
        a.free(self.block);
    }
};

test "read any" {
    const a = std.testing.allocator;
    var fbs = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 1, Unchoke.Tag });
    const r = fbs.reader();
    const T = union(enum) { u: Unchoke, i: Interested, b: Bitfield };
    const msg: T = try readAnyMessage(a, r, T);
    try std.testing.expect(msg == .u);
}

test "handshake round trip" {
    const hs = Handshake{ .info_hash = [_]u8{0xab} ** 20, .peer_id = [_]u8{0xcd} ** 20 };
    var buf: [68]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);
    try hs.write(fbs.writer());
    fbs.reset();
    const got = try Handshake.read(fbs.reader());
    try std.testing.expectEqualSlices(u8, &hs.info_hash, &got.info_hash);
    try std.testing.expectEqualSlices(u8, &hs.peer_id, &got.peer_id);
}

test "piece read" {
    const a = std.testing.allocator;
    var fbs = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 2, 0, 0, 0, 16, 'x', 'y' });
    var piece = try Piece.read(a, 10, fbs.reader());
    defer piece.deinit(a);
    try std.testing.expectEqual(@as(u32, 2), piece.index);
    try std.testing.expectEqual(@as(u32, 16), piece.begin);
    try std.testing.expectEqualSlices(u8, "xy", piece.block);
}