diff options
author | Martin Ashby <martin@ashbysoft.com> | 2023-11-12 21:17:01 +0000 |
---|---|---|
committer | Martin Ashby <martin@ashbysoft.com> | 2023-11-12 21:17:01 +0000 |
commit | 6c24cd4862cd06f3364810b73e3de1bab411f31a (patch) | |
tree | 1ef0d35963d960e6f885ceca172e22c8de1c5217 | |
parent | 536837b44823aedf3dab0b8ef844d57cbae7af74 (diff) | |
download | zbt-6c24cd4862cd06f3364810b73e3de1bab411f31a.tar.gz zbt-6c24cd4862cd06f3364810b73e3de1bab411f31a.tar.bz2 zbt-6c24cd4862cd06f3364810b73e3de1bab411f31a.tar.xz zbt-6c24cd4862cd06f3364810b73e3de1bab411f31a.zip |
Move tracker protocol into it's own file
Change piece_length -> u32, protocol constrains that anyway
annnnd, successfully download a file (albeit from one torrent only with
no checking of a ton of stuff)
-rw-r--r-- | src/main.zig | 172 | ||||
-rw-r--r-- | src/metainfo.zig | 13 | ||||
-rw-r--r-- | src/peer_protocol.zig | 130 | ||||
-rw-r--r-- | src/tracker_protocol.zig | 63 |
4 files changed, 289 insertions, 89 deletions
diff --git a/src/main.zig b/src/main.zig index d50cc90..f27735d 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3,18 +3,19 @@ const MetaInfo = @import("metainfo.zig"); const bencode = @import("bencode.zig"); const AnyWriter = @import("anywriter.zig"); const peerproto = @import("peer_protocol.zig"); - -const prr = "00112233445566778899"; +const trackproto = @import("tracker_protocol.zig"); pub fn main() !void { - var peer_id: [20]u8 = undefined; - @memcpy(&peer_id, prr); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); const a = gpa.allocator(); - // open & parse the torrent file + // TODO figure this out. It's not that important, I think, unless + // other clients have special handling for different patterns. + // Spec looks like a bit of a free-for-all here. + var peer_id: [20]u8 = undefined; + @memcpy(&peer_id, "00112233445566778899"); + const f = try std.fs.cwd().openFile("src/sample.torrent", .{}); defer f.close(); var fr = f.reader(); @@ -22,105 +23,106 @@ pub fn main() !void { defer mib.deinit(a); var mi = try MetaInfo.parse(a, mib); defer mi.deinit(a); + const info_hash = try mi.info.hash(a); - // call the tracker... var c = std.http.Client{ .allocator = a, }; defer c.deinit(); - var q = std.StringHashMap([]const u8).init(a); - defer q.deinit(); - var info_hash = try mi.info.hash(a); - var buf_left = [_]u8{0} ** 1024; - try q.put("info_hash", &info_hash); - try q.put("peer_id", &peer_id); - try q.put("port", "6881"); - try q.put("uploaded", "0"); - try q.put("downloaded", "0"); - try q.put("left", try std.fmt.bufPrint(&buf_left, "{}", .{mi.info.files[0].length})); - try q.put("compact", "1"); - var qs = try toqs(a, q); - defer a.free(qs); - var url = try std.fmt.allocPrint(a, "{s}?{s}", .{ mi.announce, qs }); + const url = try trackproto.trackerRequestUrl(a, info_hash, peer_id, mi.info.files[0].length, mi.announce); defer a.free(url); var res = try c.fetch(a, .{ .location = .{ .url = url } }); defer res.deinit(); - if (res.status == .ok) { - var trb = try bencode.bdecodeBuf(a, res.body.?); - defer trb.deinit(a); - var tr = try TrackerResp.parse(a, trb); - defer tr.deinit(a); - - for (tr.peers) |peer| { - std.log.info("peer: {}", .{peer}); - } - - // Handle peers... - + if (res.status != .ok) { + return error.TrackerHttpError; + } - const p = tr.peers[0]; - var ps = try std.net.tcpConnectToAddress(p); - defer ps.close(); - var pw = ps.writer(); - var pr = ps.reader(); - var hs: peerproto.Handshake = .{ - .info_hash = info_hash, - .peer_id = peer_id, - }; + var trb = try bencode.bdecodeBuf(a, res.body.?); + defer trb.deinit(a); + var tr = try trackproto.TrackerResp.parse(a, trb); + defer tr.deinit(a); - try hs.write(pw); - var phs = try peerproto.Handshake.read(pr); - std.log.info("peer at {} peer_id {s}", .{ p, std.fmt.fmtSliceHexLower(&phs.peer_id) }); + if (tr.peers.len == 0) { + std.log.info("no peers", .{}); + return; } -} -fn toqs(a: std.mem.Allocator, hm: std.StringHashMap([]const u8)) ![]const u8 { - var al = std.ArrayList(u8).init(a); - defer al.deinit(); - var w = al.writer(); - var it = hm.iterator(); - var first = true; - while (it.next()) |entry| { - if (!first) try w.writeByte('&'); - try std.Uri.writeEscapedQuery(w, entry.key_ptr.*); - try w.writeByte('='); - try std.Uri.writeEscapedQuery(w, entry.value_ptr.*); - first = false; + for (tr.peers) |peer| { + std.log.info("peer: {}", .{peer}); } - return try al.toOwnedSlice(); -} -const TrackerResp = struct { - // interval: u64, - peers: []std.net.Address, - - pub fn parse(a: std.mem.Allocator, b: bencode.BValue) !TrackerResp { - var ipl = std.ArrayList(std.net.Address).init(a); - defer ipl.deinit(); - var d = b.asDict() catch return error.Malformatted; - var pb = d.get("peers") orelse return error.Malformatted; - var ps = pb.asString() catch return error.Malformatted; - if ((ps.len % 6) != 0) return error.Malformatted; - for (0..ps.len / 6) |ix| { - const start = ix * 6; - const port = std.mem.readInt(u16, ps[start + 4 .. start + 6][0..2], .big); - var ip = [_]u8{0} ** 4; - @memcpy(&ip, ps[start .. start + 4]); - try ipl.append(std.net.Address.initIp4(ip, port)); - } - return .{ - .peers = try ipl.toOwnedSlice(), - }; - } + // Handle peers, PoC we're just going to handle 1 peer and download everything from them very simplistically. + const p = tr.peers[0]; + const file = mi.info.files[0]; + var ps = try std.net.tcpConnectToAddress(p); + defer ps.close(); + var pw = ps.writer(); + var pr = ps.reader(); + var hs: peerproto.Handshake = .{ + .info_hash = info_hash, + .peer_id = peer_id, + }; + + try hs.write(pw); + var phs = try peerproto.Handshake.read(pr); + std.log.info("peer at {} peer_id {s}", .{ p, std.fmt.fmtSliceHexLower(&phs.peer_id) }); + - pub fn deinit(self: *TrackerResp, a: std.mem.Allocator) void { - a.free(self.peers); + var bf = try peerproto.readMessage(a, pr, peerproto.Bitfield); + _ = bf; // ignore it for now. + try peerproto.Interested.write(pw); + _ = try peerproto.readMessage(a, pr, peerproto.Unchoke); + + var of = try std.fs.cwd().createFile(file.name, .{}); + defer of.close(); + errdefer { + // try to truncate the now-bad file... + of.setEndPos(0) catch {}; } -}; + // Read the piece into memory, we'll check the hash before it goes to disk... + var piece_buf = try a.alloc(u8, mi.info.piece_length); + defer a.free(piece_buf); + + for (0..mi.info.pieceCount()) |pi| { + const piece_length = @min(mi.info.piece_length, file.length - (pi * mi.info.piece_length)); + var s1 = std.crypto.hash.Sha1.init(.{}); + + // Send a request message for each 16KiB block of the first piece + const blklen: u32 = 16*1024; + var blkcount = try std.math.divCeil(u32, piece_length, blklen); + for (0..blkcount) |i| { + const begin = std.math.cast(u32, i*blklen).?; + const len = @min(blklen, piece_length - begin); + const req = peerproto.Request{ + .index = @intCast(pi), + .begin = begin, + .length = len, + }; + std.log.info("Request {any}", .{req}); + try req.write(pw); + var piece = try peerproto.readMessage(a, pr, peerproto.Piece); + defer piece.deinit(a); + if (piece.index != req.index) return error.ProtocolError; + if (piece.begin != req.begin) return error.ProtocolError; + if (piece.block.len != req.length) return error.ProtocolError; + s1.update(piece.block); + @memcpy(piece_buf[piece.begin..piece.begin+piece.block.len], piece.block); + } + var ah = s1.finalResult(); + var ph0 = mi.info.pieceHash(pi).?; + if (std.mem.eql(u8, &ah, &ph0)) { + try of.writeAll(piece_buf[0..piece_length]); + } else { + return error.BadHash; + } + } + std.log.info("fin", .{}); +} test { _ = bencode; _ = MetaInfo; + _ = peerproto; } diff --git a/src/metainfo.zig b/src/metainfo.zig index cd00438..04e8f64 100644 --- a/src/metainfo.zig +++ b/src/metainfo.zig @@ -12,7 +12,7 @@ pub const Info = struct { md5sum: ?[]const u8 = null, }; - piece_length: u64, + piece_length: u32, pieces: []const u8, files: []File, private: ?bool, @@ -68,7 +68,7 @@ pub const Info = struct { if (ps.len % 20 != 0) return error.Malformatted; return .{ - .piece_length = pl.asInt(u64) catch return error.Malformatted, + .piece_length = pl.asInt(u32) catch return error.Malformatted, .pieces = ps, .files = try files.toOwnedSlice(), .private = priv, @@ -116,10 +116,15 @@ pub const Info = struct { try b.bencode(w); return sha1.finalResult(); } - pub fn pieceHash(self: Info, ix: usize) ?[]const u8 { + pub fn pieceHash(self: Info, ix: usize) ?[20]u8 { const start = 20 * ix; if (start >= self.pieces.len) return null; - return self.pieces[start .. start + 20]; + var res: [20]u8 = undefined; + @memcpy(&res, self.pieces[start .. start + 20]); + return res; + } + pub fn pieceCount(self: Info) u32 { + return @as(u32, @intCast(self.pieces.len)) / 20; } }; diff --git a/src/peer_protocol.zig b/src/peer_protocol.zig index cd70ad2..91ea349 100644 --- a/src/peer_protocol.zig +++ b/src/peer_protocol.zig @@ -1,5 +1,51 @@ const std = @import("std"); +const builtin = @import("builtin"); +pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType { + var nka: u32 = 0; + while (nka < 1000) { + var len = try reader.readInt(u32, .big); + // All messages except Keepalive start with a single byte message type. + // Skip keepalive messages, we don't care (unless you're spamming them) + if (len == 0) { + nka += 1; + continue; + } + var mt = try reader.readByte(); + if (mt != msgType.Tag) return error.ProtocolError; + return try msgType.read(a, len-1, reader); + } else { + return error.ProtocolError; + } +} + +// When you're expecting several possible messages. +// 'Expected' should be a tagged union of message types you are expecting. +pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected { + var nka: u32 = 0; + while (nka < 1000) { + var len = try reader.readInt(u32, .big); + // All messages except Keepalive start with a single byte message type. + // Skip keepalive messages, we don't care (unless you're spamming them) + if (len == 0) { + nka += 1; + continue; + } + var mt = try reader.readByte(); + inline for (@typeInfo(Expected).Union.fields) |field| { + const msgType = field.type; + if (msgType.Tag == mt) { + return @unionInit(Expected, field.name, try msgType.read(a, len-1, reader)); + } + } else { + return error.ProtocolError; + } + } else { + return error.ProtocolError; + } +} + +// Handshake message has a different structure to the rest. And it's only read once per connection. pub const Handshake = struct { info_hash: [20]u8, peer_id: [20]u8, @@ -24,3 +70,87 @@ pub const Handshake = struct { try writer.writeAll(&self.peer_id); } }; + + +pub const Unchoke = struct { + pub const Tag: u8 = 1; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke { + _ = a; + _ = reader; + if (len != 0) return error.ProtocolError; + return .{}; + } +}; + +pub const Interested = struct { + pub const Tag: u8 = 2; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested { + _ = a; + _ = reader; + if (len != 0) return error.ProtocolError; + return .{}; + } + pub fn write(writer: anytype) !void { + try writer.writeInt(u32, 1, .big); + try writer.writeInt(u8, Tag, .big); + } +}; + +pub const Bitfield = struct { + pub const Tag: u8 = 5; + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield { + // TODO actually read this message and do something useful. + _ = a; + try reader.skipBytes(len, .{}); // + return .{}; + } +}; + +pub const Request = struct { + pub const Tag: u8 = 6; + index: u32, + begin: u32, + length: u32, + pub fn write(self: Request, writer: anytype) !void { + try writer.writeInt(u32, 13, .big); + try writer.writeInt(u8, Tag, .big); + try writer.writeInt(u32,self.index, .big); + try writer.writeInt(u32,self.begin, .big); + try writer.writeInt(u32,self.length, .big); + } +}; + +pub const Piece = struct { + pub const Tag: u8 = 7; + index: u32, + begin: u32, + block: []const u8, + pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece { + if (len <= 8) { + std.log.err("Piece#read len {}", .{len}); + return error.ProtocolError; + } + var ix = try reader.readInt(u32, .big); + var be = try reader.readInt(u32, .big); + var bl = try a.alloc(u8, len-8); + errdefer a.free(bl); + try reader.readNoEof(bl); + return .{ + .index = ix, + .begin = be, + .block = bl, + }; + } + pub fn deinit(self: *Piece, a: std.mem.Allocator) void { + a.free(self.block); + } +}; + +test "read any" { + const a = std.testing.allocator; + var fbs = std.io.fixedBufferStream(&[_]u8{0, 0, 0, 1, Unchoke.Tag}); + var r = fbs.reader(); + const T = union(enum) {u: Unchoke, i: Interested, b: Bitfield}; + var msg: T = try readAnyMessage(a, r, T); + try std.testing.expect(msg == .u); +}
\ No newline at end of file diff --git a/src/tracker_protocol.zig b/src/tracker_protocol.zig new file mode 100644 index 0000000..0a6a9c2 --- /dev/null +++ b/src/tracker_protocol.zig @@ -0,0 +1,63 @@ +const std = @import("std"); +const bencode = @import("bencode.zig"); + +// https://wiki.theory.org/BitTorrentSpecification#Tracker_Response +pub const TrackerResp = struct { + // interval: u64, + peers: []std.net.Address, + + pub fn parse(a: std.mem.Allocator, b: bencode.BValue) !TrackerResp { + var ipl = std.ArrayList(std.net.Address).init(a); + defer ipl.deinit(); + var d = b.asDict() catch return error.Malformatted; + var pb = d.get("peers") orelse return error.Malformatted; + var ps = pb.asString() catch return error.Malformatted; + if ((ps.len % 6) != 0) return error.Malformatted; + for (0..ps.len / 6) |ix| { + const start = ix * 6; + const port = std.mem.readInt(u16, ps[start + 4 .. start + 6][0..2], .big); + var ip = [_]u8{0} ** 4; + @memcpy(&ip, ps[start .. start + 4]); + try ipl.append(std.net.Address.initIp4(ip, port)); + } + return .{ + .peers = try ipl.toOwnedSlice(), + }; + } + + pub fn deinit(self: *TrackerResp, a: std.mem.Allocator) void { + a.free(self.peers); + } +}; + +pub fn trackerRequestUrl(a: std.mem.Allocator, info_hash: [20]u8, peer_id: [20]u8, left: usize, announce: []const u8) ![]const u8 { + var q = std.StringHashMap([]const u8).init(a); + defer q.deinit(); + try q.put("info_hash", &info_hash); + try q.put("peer_id", &peer_id); + try q.put("port", "6881"); + try q.put("uploaded", "0"); + try q.put("downloaded", "0"); + var buf_left = [_]u8{0} ** 1024; + try q.put("left", try std.fmt.bufPrint(&buf_left, "{}", .{left})); + try q.put("compact", "1"); + var qs = try toqs(a, q); + defer a.free(qs); + return try std.fmt.allocPrint(a, "{s}?{s}", .{ announce, qs }); +} + +fn toqs(a: std.mem.Allocator, hm: std.StringHashMap([]const u8)) ![]const u8 { + var al = std.ArrayList(u8).init(a); + defer al.deinit(); + var w = al.writer(); + var it = hm.iterator(); + var first = true; + while (it.next()) |entry| { + if (!first) try w.writeByte('&'); + try std.Uri.writeEscapedQuery(w, entry.key_ptr.*); + try w.writeByte('='); + try std.Uri.writeEscapedQuery(w, entry.value_ptr.*); + first = false; + } + return try al.toOwnedSlice(); +} |