zbt

CLI BitTorrent client, written in Zig
Log | Files | Refs | README

peer_protocol.zig (5021B)


      1 const std = @import("std");
      2 const builtin = @import("builtin");
      3 
      4 pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType {
      5     var nka: u32 = 0;
      6     while (nka < 1000) {
      7         var len = try reader.readInt(u32, .big);
      8         // All messages except Keepalive start with a single byte message type.
      9         // Skip keepalive messages, we don't care (unless you're spamming them)
     10         if (len == 0) {
     11             nka += 1; 
     12             continue;
     13         }
     14         var mt = try reader.readByte();
     15         if (mt != msgType.Tag) return error.ProtocolError;
     16         return try msgType.read(a, len-1, reader);
     17     } else {
     18         return error.ProtocolError;
     19     }
     20 }
     21 
     22 // When you're expecting several possible messages.
     23 // 'Expected' should be a tagged union of message types you are expecting.
     24 pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected {
     25     var nka: u32 = 0;
     26     while (nka < 1000) {
     27         var len = try reader.readInt(u32, .big);
     28         // All messages except Keepalive start with a single byte message type.
     29         // Skip keepalive messages, we don't care (unless you're spamming them)
     30         if (len == 0) {
     31             nka += 1; 
     32             continue;
     33         }
     34         var mt = try reader.readByte();
     35         inline for (@typeInfo(Expected).Union.fields) |field| {
     36             const msgType = field.type;
     37             if (msgType.Tag == mt) {
     38                 return @unionInit(Expected, field.name, try msgType.read(a, len-1, reader));
     39             }
     40         } else {
     41             return error.ProtocolError;
     42         }
     43     } else {
     44         return error.ProtocolError;
     45     }
     46 }
     47 
     48 // Handshake message has a different structure to the rest. And it's only read once per connection.
     49 pub const Handshake = struct {
     50     info_hash: [20]u8,
     51     peer_id: [20]u8,
     52 
     53     pub fn read(reader: anytype) !Handshake {
     54         var msg = [_]u8{0} ** 68;
     55         try reader.readNoEof(&msg);
     56         if (msg[0] != 19) return error.ProtocolError;
     57         if (!std.mem.eql(u8, msg[1..20], "BitTorrent protocol")) return error.ProtocolError;
     58         //if (!std.mem.allEqual(u8, msg[20..28], 0)) return error.ProtocolError;
     59         var res: Handshake = undefined;
     60         @memcpy(&res.info_hash, msg[28..48]);
     61         @memcpy(&res.peer_id, msg[48..68]);
     62         return res;
     63     }
     64 
     65     pub fn write(self: Handshake, writer: anytype) !void {
     66         try writer.writeByte(19);
     67         try writer.writeAll("BitTorrent protocol");
     68         try writer.writeByteNTimes(0, 8);
     69         try writer.writeAll(&self.info_hash);
     70         try writer.writeAll(&self.peer_id);
     71     }
     72 };
     73 
     74 
     75 pub const Unchoke = struct {
     76     pub const Tag: u8 = 1;
     77     pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke {
     78         _ = a;
     79         _ = reader;
     80         if (len != 0) return error.ProtocolError;
     81         return .{};
     82     }
     83 };
     84 
     85 pub const Interested = struct {
     86     pub const Tag: u8 = 2;
     87     pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested {
     88         _ = a;
     89         _ = reader;
     90         if (len != 0) return error.ProtocolError;
     91         return .{};   
     92     }
     93     pub fn write(writer: anytype) !void {
     94         try writer.writeInt(u32, 1, .big);
     95         try writer.writeInt(u8, Tag, .big);
     96     }
     97 };
     98 
     99 pub const Bitfield = struct {
    100     pub const Tag: u8 = 5;
    101     pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield {
    102         // TODO actually read this message and do something useful.
    103         _ = a;
    104         try reader.skipBytes(len, .{}); // 
    105         return .{};
    106     }
    107 };
    108 
    109 pub const Request = struct {
    110     pub const Tag: u8 = 6;
    111     index: u32, 
    112     begin: u32, 
    113     length: u32,
    114     pub fn write(self: Request, writer: anytype) !void {
    115         try writer.writeInt(u32, 13, .big);
    116         try writer.writeInt(u8, Tag, .big);
    117         try writer.writeInt(u32,self.index, .big);
    118         try writer.writeInt(u32,self.begin, .big);
    119         try writer.writeInt(u32,self.length, .big);
    120     }
    121 };
    122 
    123 pub const Piece = struct {
    124     pub const Tag: u8 = 7;
    125     index: u32,
    126     begin: u32,
    127     block: []const u8,
    128     pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece {
    129         if (len <= 8) {
    130             std.log.err("Piece#read len {}", .{len});
    131             return error.ProtocolError;
    132         }
    133         var ix = try reader.readInt(u32, .big);
    134         var be = try reader.readInt(u32, .big);
    135         var bl = try a.alloc(u8, len-8);
    136         errdefer a.free(bl);
    137         try reader.readNoEof(bl);
    138         return .{
    139             .index = ix, 
    140             .begin = be,
    141             .block = bl,
    142         };
    143     }
    144     pub fn deinit(self: *Piece, a: std.mem.Allocator) void {
    145         a.free(self.block);
    146     }
    147 };
    148 
    149 test "read any" {
    150     const a = std.testing.allocator;
    151     var fbs = std.io.fixedBufferStream(&[_]u8{0, 0, 0, 1, Unchoke.Tag});
    152     var r = fbs.reader();
    153     const T = union(enum) {u: Unchoke, i: Interested, b: Bitfield};
    154     var msg: T = try readAnyMessage(a, r, T);
    155     try std.testing.expect(msg == .u);
    156 }