aboutsummaryrefslogtreecommitdiff
path: root/src/seekable_http_range.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/seekable_http_range.zig')
-rw-r--r--src/seekable_http_range.zig243
1 files changed, 243 insertions, 0 deletions
diff --git a/src/seekable_http_range.zig b/src/seekable_http_range.zig
new file mode 100644
index 0000000..059d1dd
--- /dev/null
+++ b/src/seekable_http_range.zig
@@ -0,0 +1,243 @@
+//! Reader/SeekableStream implementation of a resource over HTTP using Range requests.
+
+const std = @import("std");
+
+const SeekableHttpRange = @This();
+
+pub const Opts = struct {
+ allocator: std.mem.Allocator,
+ client: *std.http.Client,
+ url: []const u8,
+ buffer_size: usize = 1024,
+};
+
+allocator: std.mem.Allocator,
+client: *std.http.Client,
+url: []const u8,
+pos: u64 = 0,
+endpos: u64,
+buffer_pos: ?u64 = null,
+buffer: []u8,
+
+pub fn init(opts: Opts) !SeekableHttpRange {
+ var a = opts.allocator;
+ var client = opts.client;
+ var url = opts.url;
+ var res = try client.fetch(a, .{
+ .method = .HEAD,
+ .location = .{ .url = url },
+ .headers = std.http.Headers{ .allocator = a, .owned = false },
+ });
+ defer res.deinit();
+ if (res.status != .ok) return error.HttpStatusError;
+ const accept_header_val = res.headers.getFirstValue("accept-ranges") orelse return error.HttpRangeNotSpecified;
+ if (std.mem.eql(u8, accept_header_val, "none")) return error.HttpRangeNotSupported;
+ if (!std.mem.eql(u8, accept_header_val, "bytes")) return error.HttpRangeUnsupportedUnit;
+ const content_length_val = res.headers.getFirstValue("content-length") orelse return error.NoContentLength;
+ const content_length = std.fmt.parseInt(u64, content_length_val, 10) catch return error.ContentLengthFormatError;
+ var buffer = try a.alloc(u8, opts.buffer_size);
+ return .{
+ .allocator = opts.allocator,
+ .client = client,
+ .url = url,
+ .buffer = buffer,
+ .endpos = content_length,
+ };
+}
+
+pub fn deinit(self: *SeekableHttpRange) void {
+ self.allocator.free(self.buffer);
+}
+
+// it's a big list :(
+pub const ReadError = error{
+ UnsupportedHeader,
+ UnexpectedCharacter,
+ InvalidFormat,
+ InvalidPort,
+ OutOfMemory,
+ ConnectionRefused,
+ NetworkUnreachable,
+ ConnectionTimedOut,
+ ConnectionResetByPeer,
+ TemporaryNameServerFailure,
+ NameServerFailure,
+ UnknownHostName,
+ HostLacksNetworkAddresses,
+ UnexpectedConnectFailure,
+ TlsInitializationFailed,
+ UnsupportedUrlScheme,
+ UnexpectedWriteFailure,
+ InvalidContentLength,
+ UnsupportedTransferEncoding,
+ Overflow,
+ InvalidCharacter,
+ UriMissingHost,
+ CertificateBundleLoadFailure,
+ TlsFailure,
+ TlsAlert,
+ UnexpectedReadFailure,
+ EndOfStream,
+ HttpChunkInvalid,
+ SystemResources,
+ FileLocksNotSupported,
+ Unexpected,
+ AccessDenied,
+ NotWriteable,
+ MessageTooLong,
+ Unseekable,
+ InputOutput,
+ IsDir,
+ OperationAborted,
+ BrokenPipe,
+ NotOpenForReading,
+ NetNameDeleted,
+ WouldBlock,
+ MessageNotCompleted,
+ HttpHeadersExceededSizeLimit,
+ HttpHeadersInvalid,
+ HttpHeaderContinuationsUnsupported,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
+ CompressionNotSupported,
+ TooManyHttpRedirects,
+ RedirectRequiresResend,
+ HttpRedirectMissingLocation,
+ CompressionInitializationFailed,
+ DecompressionFailure,
+ InvalidTrailers,
+ StreamTooLong,
+ DiskQuota,
+ FileTooBig,
+ NoSpaceLeft,
+ DeviceBusy,
+ InvalidArgument,
+ NotOpenForWriting,
+ LockViolation,
+ HttpStatusError,
+ HttpNoBody,
+};
+
+pub const Reader = std.io.Reader(*SeekableHttpRange, ReadError, read);
+
+pub fn reader(self: *SeekableHttpRange) Reader {
+ return .{
+ .context = self,
+ };
+}
+
+pub fn read(self: *SeekableHttpRange, buffer: []u8) ReadError!usize {
+ var n: usize = 0;
+ for (0..buffer.len) |ix| {
+ const b = try readByte(self);
+ const bb = b orelse break;
+ buffer[ix] = bb;
+ n += 1;
+ }
+ return n;
+}
+
+fn readByte(self: *SeekableHttpRange) !?u8 {
+ if (self.pos >= self.endpos) return null;
+
+ if (self.buffer_pos) |buffer_pos| {
+ const buffer_end: u64 = buffer_pos + self.buffer.len;
+ if (self.pos >= buffer_pos and self.pos < buffer_end) {
+ return self.readFromBuffer();
+ }
+ }
+
+ // refill the buffer from pos
+ // max u64 formatted as decimal is 20 bytes long
+ const range_buf_len = "bytes=-".len + 20 + 20;
+ var range_buf = [_]u8{0} ** range_buf_len;
+ const nbuf_end = @min(self.pos + self.buffer.len, self.endpos);
+ // Range request end is _inclusive_
+ var range_value = std.fmt.bufPrint(&range_buf, "bytes={}-{}", .{ self.pos, nbuf_end - 1 }) catch unreachable;
+ var headers = std.http.Headers{ .allocator = self.allocator };
+ defer headers.deinit();
+ try headers.append("range", range_value);
+ var res = try self.client.fetch(self.allocator, .{
+ .location = .{ .url = self.url },
+ .headers = headers,
+ });
+ defer res.deinit();
+ if (res.status != .partial_content) return error.HttpStatusError;
+ const body = res.body orelse return error.HttpNoBody;
+ std.mem.copyForwards(u8, self.buffer, body);
+ self.buffer_pos = self.pos;
+ return self.readFromBuffer();
+}
+
+fn readFromBuffer(self: *SeekableHttpRange) u8 {
+ const pos_in_buffer = self.pos - self.buffer_pos.?;
+ defer self.pos += 1;
+ return self.buffer[pos_in_buffer];
+}
+
+pub const SeekError = error{};
+pub const GetSeekPosError = error{};
+pub const SeekableStream = std.io.SeekableStream(*SeekableHttpRange, SeekError, GetSeekPosError, seekTo, seekBy, getPos, getEndPos);
+
+pub fn seekableStream(self: *SeekableHttpRange) SeekableStream {
+ return .{
+ .context = self,
+ };
+}
+
+// copying from FixedBufferStream: clamp rather than return an error
+pub fn seekTo(self: *SeekableHttpRange, pos: u64) SeekError!void {
+ self.pos = std.math.clamp(pos, 0, self.endpos);
+}
+
+// copying from FixedBufferStream: clamp rather than return an error
+pub fn seekBy(self: *SeekableHttpRange, delta: i64) SeekError!void {
+ const np: u64 = if (std.math.sign(delta) == -1)
+ std.math.sub(u64, self.pos, std.math.absCast(delta)) catch 0
+ else
+ std.math.add(u64, self.pos, std.math.absCast(delta)) catch std.math.maxInt(u64);
+ self.pos = std.math.clamp(np, 0, self.endpos);
+}
+
+pub fn getPos(self: *SeekableHttpRange) GetSeekPosError!u64 {
+ return self.pos;
+}
+
+pub fn getEndPos(self: *SeekableHttpRange) GetSeekPosError!u64 {
+ return self.endpos;
+}
+
+test "endBytes" {
+ const a = std.testing.allocator;
+ var client = std.http.Client{ .allocator = a };
+ defer client.deinit();
+ var range = try SeekableHttpRange.init(.{ .allocator = a, .client = &client, .url = "https://mfashby.net/posts.zip" });
+ defer range.deinit();
+ var ss = range.seekableStream();
+ var rr = range.reader();
+
+ var buf = try a.alloc(u8, 20);
+ defer a.free(buf);
+
+ try ss.seekTo(try ss.getEndPos() - 20);
+ try rr.readNoEof(buf);
+ try std.testing.expectEqualSlices(u8, &[_]u8{ 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x31, 0x00, 0x31, 0x00, 0xFD, 0x11, 0x00, 0x00, 0xEE, 0xB7, 0x00, 0x00, 0x00, 0x00 }, buf);
+
+ try ss.seekBy(-300);
+ try rr.readNoEof(buf);
+ try std.testing.expectEqualSlices(u8, &[_]u8{ 0x00, 0x00, 0x00, 0x08, 0x00, 0xA1, 0x3A, 0x17, 0x57, 0x85, 0x9F, 0x53, 0xCE, 0x26, 0x05, 0x00, 0x00, 0x37, 0x0A, 0x00 }, buf);
+
+ try ss.seekTo(0);
+ try ss.seekBy(-1);
+ try std.testing.expectEqual(@as(u64, 0), try ss.getPos());
+
+ try ss.seekBy(std.math.minInt(i64));
+ try std.testing.expectEqual(@as(u64, 0), try ss.getPos());
+
+ try ss.seekTo(try ss.getEndPos());
+ try ss.seekBy(1);
+ try std.testing.expectEqual(try ss.getEndPos(), try ss.getPos());
+
+ try ss.seekBy(std.math.maxInt(i64));
+ try std.testing.expectEqual(try ss.getEndPos(), try ss.getPos());
+}