aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.zig172
-rw-r--r--src/metainfo.zig13
-rw-r--r--src/peer_protocol.zig130
-rw-r--r--src/tracker_protocol.zig63
4 files changed, 289 insertions, 89 deletions
diff --git a/src/main.zig b/src/main.zig
index d50cc90..f27735d 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -3,18 +3,19 @@ const MetaInfo = @import("metainfo.zig");
const bencode = @import("bencode.zig");
const AnyWriter = @import("anywriter.zig");
const peerproto = @import("peer_protocol.zig");
-
-const prr = "00112233445566778899";
+const trackproto = @import("tracker_protocol.zig");
pub fn main() !void {
- var peer_id: [20]u8 = undefined;
- @memcpy(&peer_id, prr);
-
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const a = gpa.allocator();
- // open & parse the torrent file
+ // TODO figure this out. It's not that important, I think, unless
+ // other clients have special handling for different patterns.
+ // Spec looks like a bit of a free-for-all here.
+ var peer_id: [20]u8 = undefined;
+ @memcpy(&peer_id, "00112233445566778899");
+
const f = try std.fs.cwd().openFile("src/sample.torrent", .{});
defer f.close();
var fr = f.reader();
@@ -22,105 +23,106 @@ pub fn main() !void {
defer mib.deinit(a);
var mi = try MetaInfo.parse(a, mib);
defer mi.deinit(a);
+ const info_hash = try mi.info.hash(a);
- // call the tracker...
var c = std.http.Client{
.allocator = a,
};
defer c.deinit();
- var q = std.StringHashMap([]const u8).init(a);
- defer q.deinit();
- var info_hash = try mi.info.hash(a);
- var buf_left = [_]u8{0} ** 1024;
- try q.put("info_hash", &info_hash);
- try q.put("peer_id", &peer_id);
- try q.put("port", "6881");
- try q.put("uploaded", "0");
- try q.put("downloaded", "0");
- try q.put("left", try std.fmt.bufPrint(&buf_left, "{}", .{mi.info.files[0].length}));
- try q.put("compact", "1");
- var qs = try toqs(a, q);
- defer a.free(qs);
- var url = try std.fmt.allocPrint(a, "{s}?{s}", .{ mi.announce, qs });
+ const url = try trackproto.trackerRequestUrl(a, info_hash, peer_id, mi.info.files[0].length, mi.announce);
defer a.free(url);
var res = try c.fetch(a, .{ .location = .{ .url = url } });
defer res.deinit();
- if (res.status == .ok) {
- var trb = try bencode.bdecodeBuf(a, res.body.?);
- defer trb.deinit(a);
- var tr = try TrackerResp.parse(a, trb);
- defer tr.deinit(a);
-
- for (tr.peers) |peer| {
- std.log.info("peer: {}", .{peer});
- }
-
- // Handle peers...
-
+ if (res.status != .ok) {
+ return error.TrackerHttpError;
+ }
- const p = tr.peers[0];
- var ps = try std.net.tcpConnectToAddress(p);
- defer ps.close();
- var pw = ps.writer();
- var pr = ps.reader();
- var hs: peerproto.Handshake = .{
- .info_hash = info_hash,
- .peer_id = peer_id,
- };
+ var trb = try bencode.bdecodeBuf(a, res.body.?);
+ defer trb.deinit(a);
+ var tr = try trackproto.TrackerResp.parse(a, trb);
+ defer tr.deinit(a);
- try hs.write(pw);
- var phs = try peerproto.Handshake.read(pr);
- std.log.info("peer at {} peer_id {s}", .{ p, std.fmt.fmtSliceHexLower(&phs.peer_id) });
+ if (tr.peers.len == 0) {
+ std.log.info("no peers", .{});
+ return;
}
-}
-fn toqs(a: std.mem.Allocator, hm: std.StringHashMap([]const u8)) ![]const u8 {
- var al = std.ArrayList(u8).init(a);
- defer al.deinit();
- var w = al.writer();
- var it = hm.iterator();
- var first = true;
- while (it.next()) |entry| {
- if (!first) try w.writeByte('&');
- try std.Uri.writeEscapedQuery(w, entry.key_ptr.*);
- try w.writeByte('=');
- try std.Uri.writeEscapedQuery(w, entry.value_ptr.*);
- first = false;
+ for (tr.peers) |peer| {
+ std.log.info("peer: {}", .{peer});
}
- return try al.toOwnedSlice();
-}
-const TrackerResp = struct {
- // interval: u64,
- peers: []std.net.Address,
-
- pub fn parse(a: std.mem.Allocator, b: bencode.BValue) !TrackerResp {
- var ipl = std.ArrayList(std.net.Address).init(a);
- defer ipl.deinit();
- var d = b.asDict() catch return error.Malformatted;
- var pb = d.get("peers") orelse return error.Malformatted;
- var ps = pb.asString() catch return error.Malformatted;
- if ((ps.len % 6) != 0) return error.Malformatted;
- for (0..ps.len / 6) |ix| {
- const start = ix * 6;
- const port = std.mem.readInt(u16, ps[start + 4 .. start + 6][0..2], .big);
- var ip = [_]u8{0} ** 4;
- @memcpy(&ip, ps[start .. start + 4]);
- try ipl.append(std.net.Address.initIp4(ip, port));
- }
- return .{
- .peers = try ipl.toOwnedSlice(),
- };
- }
+ // Handle peers, PoC we're just going to handle 1 peer and download everything from them very simplistically.
+ const p = tr.peers[0];
+ const file = mi.info.files[0];
+ var ps = try std.net.tcpConnectToAddress(p);
+ defer ps.close();
+ var pw = ps.writer();
+ var pr = ps.reader();
+ var hs: peerproto.Handshake = .{
+ .info_hash = info_hash,
+ .peer_id = peer_id,
+ };
+
+ try hs.write(pw);
+ var phs = try peerproto.Handshake.read(pr);
+ std.log.info("peer at {} peer_id {s}", .{ p, std.fmt.fmtSliceHexLower(&phs.peer_id) });
+
- pub fn deinit(self: *TrackerResp, a: std.mem.Allocator) void {
- a.free(self.peers);
+ var bf = try peerproto.readMessage(a, pr, peerproto.Bitfield);
+ _ = bf; // ignore it for now.
+ try peerproto.Interested.write(pw);
+ _ = try peerproto.readMessage(a, pr, peerproto.Unchoke);
+
+ var of = try std.fs.cwd().createFile(file.name, .{});
+ defer of.close();
+ errdefer {
+ // try to truncate the now-bad file...
+ of.setEndPos(0) catch {};
}
-};
+ // Read the piece into memory, we'll check the hash before it goes to disk...
+ var piece_buf = try a.alloc(u8, mi.info.piece_length);
+ defer a.free(piece_buf);
+
+ for (0..mi.info.pieceCount()) |pi| {
+ const piece_length = @min(mi.info.piece_length, file.length - (pi * mi.info.piece_length));
+ var s1 = std.crypto.hash.Sha1.init(.{});
+
+ // Send a request message for each 16KiB block of the first piece
+ const blklen: u32 = 16*1024;
+ var blkcount = try std.math.divCeil(u32, piece_length, blklen);
+ for (0..blkcount) |i| {
+ const begin = std.math.cast(u32, i*blklen).?;
+ const len = @min(blklen, piece_length - begin);
+ const req = peerproto.Request{
+ .index = @intCast(pi),
+ .begin = begin,
+ .length = len,
+ };
+ std.log.info("Request {any}", .{req});
+ try req.write(pw);
+ var piece = try peerproto.readMessage(a, pr, peerproto.Piece);
+ defer piece.deinit(a);
+ if (piece.index != req.index) return error.ProtocolError;
+ if (piece.begin != req.begin) return error.ProtocolError;
+ if (piece.block.len != req.length) return error.ProtocolError;
+ s1.update(piece.block);
+ @memcpy(piece_buf[piece.begin..piece.begin+piece.block.len], piece.block);
+ }
+ var ah = s1.finalResult();
+ var ph0 = mi.info.pieceHash(pi).?;
+ if (std.mem.eql(u8, &ah, &ph0)) {
+ try of.writeAll(piece_buf[0..piece_length]);
+ } else {
+ return error.BadHash;
+ }
+ }
+ std.log.info("fin", .{});
+}
test {
_ = bencode;
_ = MetaInfo;
+ _ = peerproto;
}
diff --git a/src/metainfo.zig b/src/metainfo.zig
index cd00438..04e8f64 100644
--- a/src/metainfo.zig
+++ b/src/metainfo.zig
@@ -12,7 +12,7 @@ pub const Info = struct {
md5sum: ?[]const u8 = null,
};
- piece_length: u64,
+ piece_length: u32,
pieces: []const u8,
files: []File,
private: ?bool,
@@ -68,7 +68,7 @@ pub const Info = struct {
if (ps.len % 20 != 0) return error.Malformatted;
return .{
- .piece_length = pl.asInt(u64) catch return error.Malformatted,
+ .piece_length = pl.asInt(u32) catch return error.Malformatted,
.pieces = ps,
.files = try files.toOwnedSlice(),
.private = priv,
@@ -116,10 +116,15 @@ pub const Info = struct {
try b.bencode(w);
return sha1.finalResult();
}
- pub fn pieceHash(self: Info, ix: usize) ?[]const u8 {
+ pub fn pieceHash(self: Info, ix: usize) ?[20]u8 {
const start = 20 * ix;
if (start >= self.pieces.len) return null;
- return self.pieces[start .. start + 20];
+ var res: [20]u8 = undefined;
+ @memcpy(&res, self.pieces[start .. start + 20]);
+ return res;
+ }
+ pub fn pieceCount(self: Info) u32 {
+ return @as(u32, @intCast(self.pieces.len)) / 20;
}
};
diff --git a/src/peer_protocol.zig b/src/peer_protocol.zig
index cd70ad2..91ea349 100644
--- a/src/peer_protocol.zig
+++ b/src/peer_protocol.zig
@@ -1,5 +1,51 @@
const std = @import("std");
+const builtin = @import("builtin");
+pub fn readMessage(a: std.mem.Allocator, reader: anytype, comptime msgType: type) !msgType {
+ var nka: u32 = 0;
+ while (nka < 1000) {
+ var len = try reader.readInt(u32, .big);
+ // All messages except Keepalive start with a single byte message type.
+ // Skip keepalive messages, we don't care (unless you're spamming them)
+ if (len == 0) {
+ nka += 1;
+ continue;
+ }
+ var mt = try reader.readByte();
+ if (mt != msgType.Tag) return error.ProtocolError;
+ return try msgType.read(a, len-1, reader);
+ } else {
+ return error.ProtocolError;
+ }
+}
+
+// When you're expecting several possible messages.
+// 'Expected' should be a tagged union of message types you are expecting.
+pub fn readAnyMessage(a: std.mem.Allocator, reader: anytype, comptime Expected: type) !Expected {
+ var nka: u32 = 0;
+ while (nka < 1000) {
+ var len = try reader.readInt(u32, .big);
+ // All messages except Keepalive start with a single byte message type.
+ // Skip keepalive messages, we don't care (unless you're spamming them)
+ if (len == 0) {
+ nka += 1;
+ continue;
+ }
+ var mt = try reader.readByte();
+ inline for (@typeInfo(Expected).Union.fields) |field| {
+ const msgType = field.type;
+ if (msgType.Tag == mt) {
+ return @unionInit(Expected, field.name, try msgType.read(a, len-1, reader));
+ }
+ } else {
+ return error.ProtocolError;
+ }
+ } else {
+ return error.ProtocolError;
+ }
+}
+
+// Handshake message has a different structure to the rest. And it's only read once per connection.
pub const Handshake = struct {
info_hash: [20]u8,
peer_id: [20]u8,
@@ -24,3 +70,87 @@ pub const Handshake = struct {
try writer.writeAll(&self.peer_id);
}
};
+
+
+pub const Unchoke = struct {
+ pub const Tag: u8 = 1;
+ pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Unchoke {
+ _ = a;
+ _ = reader;
+ if (len != 0) return error.ProtocolError;
+ return .{};
+ }
+};
+
+pub const Interested = struct {
+ pub const Tag: u8 = 2;
+ pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Interested {
+ _ = a;
+ _ = reader;
+ if (len != 0) return error.ProtocolError;
+ return .{};
+ }
+ pub fn write(writer: anytype) !void {
+ try writer.writeInt(u32, 1, .big);
+ try writer.writeInt(u8, Tag, .big);
+ }
+};
+
+pub const Bitfield = struct {
+ pub const Tag: u8 = 5;
+ pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Bitfield {
+ // TODO actually read this message and do something useful.
+ _ = a;
+ try reader.skipBytes(len, .{}); //
+ return .{};
+ }
+};
+
+pub const Request = struct {
+ pub const Tag: u8 = 6;
+ index: u32,
+ begin: u32,
+ length: u32,
+ pub fn write(self: Request, writer: anytype) !void {
+ try writer.writeInt(u32, 13, .big);
+ try writer.writeInt(u8, Tag, .big);
+ try writer.writeInt(u32,self.index, .big);
+ try writer.writeInt(u32,self.begin, .big);
+ try writer.writeInt(u32,self.length, .big);
+ }
+};
+
+pub const Piece = struct {
+ pub const Tag: u8 = 7;
+ index: u32,
+ begin: u32,
+ block: []const u8,
+ pub fn read(a: std.mem.Allocator, len: usize, reader: anytype) !Piece {
+ if (len <= 8) {
+ std.log.err("Piece#read len {}", .{len});
+ return error.ProtocolError;
+ }
+ var ix = try reader.readInt(u32, .big);
+ var be = try reader.readInt(u32, .big);
+ var bl = try a.alloc(u8, len-8);
+ errdefer a.free(bl);
+ try reader.readNoEof(bl);
+ return .{
+ .index = ix,
+ .begin = be,
+ .block = bl,
+ };
+ }
+ pub fn deinit(self: *Piece, a: std.mem.Allocator) void {
+ a.free(self.block);
+ }
+};
+
+test "read any" {
+ const a = std.testing.allocator;
+ var fbs = std.io.fixedBufferStream(&[_]u8{0, 0, 0, 1, Unchoke.Tag});
+ var r = fbs.reader();
+ const T = union(enum) {u: Unchoke, i: Interested, b: Bitfield};
+ var msg: T = try readAnyMessage(a, r, T);
+ try std.testing.expect(msg == .u);
+} \ No newline at end of file
diff --git a/src/tracker_protocol.zig b/src/tracker_protocol.zig
new file mode 100644
index 0000000..0a6a9c2
--- /dev/null
+++ b/src/tracker_protocol.zig
@@ -0,0 +1,63 @@
+const std = @import("std");
+const bencode = @import("bencode.zig");
+
+// https://wiki.theory.org/BitTorrentSpecification#Tracker_Response
+pub const TrackerResp = struct {
+ // interval: u64,
+ peers: []std.net.Address,
+
+ pub fn parse(a: std.mem.Allocator, b: bencode.BValue) !TrackerResp {
+ var ipl = std.ArrayList(std.net.Address).init(a);
+ defer ipl.deinit();
+ var d = b.asDict() catch return error.Malformatted;
+ var pb = d.get("peers") orelse return error.Malformatted;
+ var ps = pb.asString() catch return error.Malformatted;
+ if ((ps.len % 6) != 0) return error.Malformatted;
+ for (0..ps.len / 6) |ix| {
+ const start = ix * 6;
+ const port = std.mem.readInt(u16, ps[start + 4 .. start + 6][0..2], .big);
+ var ip = [_]u8{0} ** 4;
+ @memcpy(&ip, ps[start .. start + 4]);
+ try ipl.append(std.net.Address.initIp4(ip, port));
+ }
+ return .{
+ .peers = try ipl.toOwnedSlice(),
+ };
+ }
+
+ pub fn deinit(self: *TrackerResp, a: std.mem.Allocator) void {
+ a.free(self.peers);
+ }
+};
+
+pub fn trackerRequestUrl(a: std.mem.Allocator, info_hash: [20]u8, peer_id: [20]u8, left: usize, announce: []const u8) ![]const u8 {
+ var q = std.StringHashMap([]const u8).init(a);
+ defer q.deinit();
+ try q.put("info_hash", &info_hash);
+ try q.put("peer_id", &peer_id);
+ try q.put("port", "6881");
+ try q.put("uploaded", "0");
+ try q.put("downloaded", "0");
+ var buf_left = [_]u8{0} ** 1024;
+ try q.put("left", try std.fmt.bufPrint(&buf_left, "{}", .{left}));
+ try q.put("compact", "1");
+ var qs = try toqs(a, q);
+ defer a.free(qs);
+ return try std.fmt.allocPrint(a, "{s}?{s}", .{ announce, qs });
+}
+
+fn toqs(a: std.mem.Allocator, hm: std.StringHashMap([]const u8)) ![]const u8 {
+ var al = std.ArrayList(u8).init(a);
+ defer al.deinit();
+ var w = al.writer();
+ var it = hm.iterator();
+ var first = true;
+ while (it.next()) |entry| {
+ if (!first) try w.writeByte('&');
+ try std.Uri.writeEscapedQuery(w, entry.key_ptr.*);
+ try w.writeByte('=');
+ try std.Uri.writeEscapedQuery(w, entry.value_ptr.*);
+ first = false;
+ }
+ return try al.toOwnedSlice();
+}