diff options
Diffstat (limited to 'src/seekable_http_range.zig')
-rw-r--r-- | src/seekable_http_range.zig | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/src/seekable_http_range.zig b/src/seekable_http_range.zig new file mode 100644 index 0000000..059d1dd --- /dev/null +++ b/src/seekable_http_range.zig @@ -0,0 +1,243 @@ +//! Reader/SeekableStream implementation of a resource over HTTP using Range requests. + +const std = @import("std"); + +const SeekableHttpRange = @This(); + +pub const Opts = struct { + allocator: std.mem.Allocator, + client: *std.http.Client, + url: []const u8, + buffer_size: usize = 1024, +}; + +allocator: std.mem.Allocator, +client: *std.http.Client, +url: []const u8, +pos: u64 = 0, +endpos: u64, +buffer_pos: ?u64 = null, +buffer: []u8, + +pub fn init(opts: Opts) !SeekableHttpRange { + var a = opts.allocator; + var client = opts.client; + var url = opts.url; + var res = try client.fetch(a, .{ + .method = .HEAD, + .location = .{ .url = url }, + .headers = std.http.Headers{ .allocator = a, .owned = false }, + }); + defer res.deinit(); + if (res.status != .ok) return error.HttpStatusError; + const accept_header_val = res.headers.getFirstValue("accept-ranges") orelse return error.HttpRangeNotSpecified; + if (std.mem.eql(u8, accept_header_val, "none")) return error.HttpRangeNotSupported; + if (!std.mem.eql(u8, accept_header_val, "bytes")) return error.HttpRangeUnsupportedUnit; + const content_length_val = res.headers.getFirstValue("content-length") orelse return error.NoContentLength; + const content_length = std.fmt.parseInt(u64, content_length_val, 10) catch return error.ContentLengthFormatError; + var buffer = try a.alloc(u8, opts.buffer_size); + return .{ + .allocator = opts.allocator, + .client = client, + .url = url, + .buffer = buffer, + .endpos = content_length, + }; +} + +pub fn deinit(self: *SeekableHttpRange) void { + self.allocator.free(self.buffer); +} + +// it's a big list :( +pub const ReadError = error{ + UnsupportedHeader, + UnexpectedCharacter, + InvalidFormat, + InvalidPort, + OutOfMemory, + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, + TlsInitializationFailed, + UnsupportedUrlScheme, + UnexpectedWriteFailure, + InvalidContentLength, + UnsupportedTransferEncoding, + Overflow, + InvalidCharacter, + UriMissingHost, + CertificateBundleLoadFailure, + TlsFailure, + TlsAlert, + UnexpectedReadFailure, + EndOfStream, + HttpChunkInvalid, + SystemResources, + FileLocksNotSupported, + Unexpected, + AccessDenied, + NotWriteable, + MessageTooLong, + Unseekable, + InputOutput, + IsDir, + OperationAborted, + BrokenPipe, + NotOpenForReading, + NetNameDeleted, + WouldBlock, + MessageNotCompleted, + HttpHeadersExceededSizeLimit, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + CompressionNotSupported, + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectMissingLocation, + CompressionInitializationFailed, + DecompressionFailure, + InvalidTrailers, + StreamTooLong, + DiskQuota, + FileTooBig, + NoSpaceLeft, + DeviceBusy, + InvalidArgument, + NotOpenForWriting, + LockViolation, + HttpStatusError, + HttpNoBody, +}; + +pub const Reader = std.io.Reader(*SeekableHttpRange, ReadError, read); + +pub fn reader(self: *SeekableHttpRange) Reader { + return .{ + .context = self, + }; +} + +pub fn read(self: *SeekableHttpRange, buffer: []u8) ReadError!usize { + var n: usize = 0; + for (0..buffer.len) |ix| { + const b = try readByte(self); + const bb = b orelse break; + buffer[ix] = bb; + n += 1; + } + return n; +} + +fn readByte(self: *SeekableHttpRange) !?u8 { + if (self.pos >= self.endpos) return null; + + if (self.buffer_pos) |buffer_pos| { + const buffer_end: u64 = buffer_pos + self.buffer.len; + if (self.pos >= buffer_pos and self.pos < buffer_end) { + return self.readFromBuffer(); + } + } + + // refill the buffer from pos + // max u64 formatted as decimal is 20 bytes long + const range_buf_len = "bytes=-".len + 20 + 20; + var range_buf = [_]u8{0} ** range_buf_len; + const nbuf_end = @min(self.pos + self.buffer.len, self.endpos); + // Range request end is _inclusive_ + var range_value = std.fmt.bufPrint(&range_buf, "bytes={}-{}", .{ self.pos, nbuf_end - 1 }) catch unreachable; + var headers = std.http.Headers{ .allocator = self.allocator }; + defer headers.deinit(); + try headers.append("range", range_value); + var res = try self.client.fetch(self.allocator, .{ + .location = .{ .url = self.url }, + .headers = headers, + }); + defer res.deinit(); + if (res.status != .partial_content) return error.HttpStatusError; + const body = res.body orelse return error.HttpNoBody; + std.mem.copyForwards(u8, self.buffer, body); + self.buffer_pos = self.pos; + return self.readFromBuffer(); +} + +fn readFromBuffer(self: *SeekableHttpRange) u8 { + const pos_in_buffer = self.pos - self.buffer_pos.?; + defer self.pos += 1; + return self.buffer[pos_in_buffer]; +} + +pub const SeekError = error{}; +pub const GetSeekPosError = error{}; +pub const SeekableStream = std.io.SeekableStream(*SeekableHttpRange, SeekError, GetSeekPosError, seekTo, seekBy, getPos, getEndPos); + +pub fn seekableStream(self: *SeekableHttpRange) SeekableStream { + return .{ + .context = self, + }; +} + +// copying from FixedBufferStream: clamp rather than return an error +pub fn seekTo(self: *SeekableHttpRange, pos: u64) SeekError!void { + self.pos = std.math.clamp(pos, 0, self.endpos); +} + +// copying from FixedBufferStream: clamp rather than return an error +pub fn seekBy(self: *SeekableHttpRange, delta: i64) SeekError!void { + const np: u64 = if (std.math.sign(delta) == -1) + std.math.sub(u64, self.pos, std.math.absCast(delta)) catch 0 + else + std.math.add(u64, self.pos, std.math.absCast(delta)) catch std.math.maxInt(u64); + self.pos = std.math.clamp(np, 0, self.endpos); +} + +pub fn getPos(self: *SeekableHttpRange) GetSeekPosError!u64 { + return self.pos; +} + +pub fn getEndPos(self: *SeekableHttpRange) GetSeekPosError!u64 { + return self.endpos; +} + +test "endBytes" { + const a = std.testing.allocator; + var client = std.http.Client{ .allocator = a }; + defer client.deinit(); + var range = try SeekableHttpRange.init(.{ .allocator = a, .client = &client, .url = "https://mfashby.net/posts.zip" }); + defer range.deinit(); + var ss = range.seekableStream(); + var rr = range.reader(); + + var buf = try a.alloc(u8, 20); + defer a.free(buf); + + try ss.seekTo(try ss.getEndPos() - 20); + try rr.readNoEof(buf); + try std.testing.expectEqualSlices(u8, &[_]u8{ 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x31, 0x00, 0x31, 0x00, 0xFD, 0x11, 0x00, 0x00, 0xEE, 0xB7, 0x00, 0x00, 0x00, 0x00 }, buf); + + try ss.seekBy(-300); + try rr.readNoEof(buf); + try std.testing.expectEqualSlices(u8, &[_]u8{ 0x00, 0x00, 0x00, 0x08, 0x00, 0xA1, 0x3A, 0x17, 0x57, 0x85, 0x9F, 0x53, 0xCE, 0x26, 0x05, 0x00, 0x00, 0x37, 0x0A, 0x00 }, buf); + + try ss.seekTo(0); + try ss.seekBy(-1); + try std.testing.expectEqual(@as(u64, 0), try ss.getPos()); + + try ss.seekBy(std.math.minInt(i64)); + try std.testing.expectEqual(@as(u64, 0), try ss.getPos()); + + try ss.seekTo(try ss.getEndPos()); + try ss.seekBy(1); + try std.testing.expectEqual(try ss.getEndPos(), try ss.getPos()); + + try ss.seekBy(std.math.maxInt(i64)); + try std.testing.expectEqual(try ss.getEndPos(), try ss.getPos()); +} |