diff options
Diffstat (limited to 'src/peer_protocol.zig')
-rw-r--r-- | src/peer_protocol.zig | 130 |
1 files changed, 130 insertions, 0 deletions
diff --git a/src/peer_protocol.zig b/src/peer_protocol.zig index cd70ad2..91ea349 100644 --- a/src/peer_protocol.zig +++ b/src/peer_protocol.zig @@ -1,5 +1,51 @@ const std = @import("std"); +const builtin = @import("builtin"); +pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType { + var nka: u32 = 0; + while (nka < 1000) { + var len = try reader.readInt(u32, .big); + // All messages except Keepalive start with a single byte message type. + // Skip keepalive messages, we don't care (unless you're spamming them) + if (len == 0) { + nka += 1; + continue; + } + var mt = try reader.readByte(); + if (mt != msgType.Tag) return error.ProtocolError; + return try msgType.read(a, len-1, reader); + } else { + return error.ProtocolError; + } +} + +// When you're expecting several possible messages. +// 'Expected' should be a tagged union of message types you are expecting. +pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected { + var nka: u32 = 0; + while (nka < 1000) { + var len = try reader.readInt(u32, .big); + // All messages except Keepalive start with a single byte message type. + // Skip keepalive messages, we don't care (unless you're spamming them) + if (len == 0) { + nka += 1; + continue; + } + var mt = try reader.readByte(); + inline for (@typeInfo(Expected).Union.fields) |field| { + const msgType = field.type; + if (msgType.Tag == mt) { + return @unionInit(Expected, field.name, try msgType.read(a, len-1, reader)); + } + } else { + return error.ProtocolError; + } + } else { + return error.ProtocolError; + } +} + +// Handshake message has a different structure to the rest. And it's only read once per connection. pub const Handshake = struct { info_hash: [20]u8, peer_id: [20]u8, @@ -24,3 +70,87 @@ pub const Handshake = struct { try writer.writeAll(&self.peer_id); } }; + + +pub const Unchoke = struct { + pub const Tag: u8 = 1; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke { + _ = a; + _ = reader; + if (len != 0) return error.ProtocolError; + return .{}; + } +}; + +pub const Interested = struct { + pub const Tag: u8 = 2; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested { + _ = a; + _ = reader; + if (len != 0) return error.ProtocolError; + return .{}; + } + pub fn write(writer: anytype) !void { + try writer.writeInt(u32, 1, .big); + try writer.writeInt(u8, Tag, .big); + } +}; + +pub const Bitfield = struct { + pub const Tag: u8 = 5; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield { + // TODO actually read this message and do something useful. + _ = a; + try reader.skipBytes(len, .{}); // + return .{}; + } +}; + +pub const Request = struct { + pub const Tag: u8 = 6; + index: u32, + begin: u32, + length: u32, + pub fn write(self: Request, writer: anytype) !void { + try writer.writeInt(u32, 13, .big); + try writer.writeInt(u8, Tag, .big); + try writer.writeInt(u32,self.index, .big); + try writer.writeInt(u32,self.begin, .big); + try writer.writeInt(u32,self.length, .big); + } +}; + +pub const Piece = struct { + pub const Tag: u8 = 7; + index: u32, + begin: u32, + block: []const u8, + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece { + if (len <= 8) { + std.log.err("Piece#read len {}", .{len}); + return error.ProtocolError; + } + var ix = try reader.readInt(u32, .big); + var be = try reader.readInt(u32, .big); + var bl = try a.alloc(u8, len-8); + errdefer a.free(bl); + try reader.readNoEof(bl); + return .{ + .index = ix, + .begin = be, + .block = bl, + }; + } + pub fn deinit(self: *Piece, a: std.mem.Allocator) void { + a.free(self.block); + } +}; + +test "read any" { + const a = std.testing.allocator; + var fbs = std.io.fixedBufferStream(&[_]u8{0, 0, 0, 1, Unchoke.Tag}); + var r = fbs.reader(); + const T = union(enum) {u: Unchoke, i: Interested, b: Bitfield}; + var msg: T = try readAnyMessage(a, r, T); + try std.testing.expect(msg == .u); +}
\ No newline at end of file |