diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.zig | 5 | ||||
-rw-r--r-- | src/zigwebserver.zig | 342 |
2 files changed, 347 insertions, 0 deletions
diff --git a/src/main.zig b/src/main.zig index a37f6a0..54e6b4b 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const zws = @import("zigwebserver.zig"); // extremely basic http file server pub fn main() !void { @@ -61,3 +62,7 @@ fn serve_error(res: *std.http.Server.Response, status: std.http.Status) !void { \\ <!doctype html><html><body>{s}</body></html> , .{phrase}); } + +test { + _ = zws; +} diff --git a/src/zigwebserver.zig b/src/zigwebserver.zig new file mode 100644 index 0000000..6ce4439 --- /dev/null +++ b/src/zigwebserver.zig @@ -0,0 +1,342 @@ +const std = @import("std"); + +/// Wrapper around a std.http.Server to run a multi-threaded HTTP server using thread-per-request and arena-per-request +/// Context is .clone()'d and passed to each request, useful for passing user data to each request handler e.g. a database connection. +/// HandlerType provides a single method handle() which is used to actually handle requests. +pub fn Server(comptime Context: type, comptime Handler: type) type { + return struct { + address: std.net.Address, + context: Context, + handler: Handler, + + max_header_size: usize = 8192, + n_threads: u32 = 50, + + allocator: std.mem.Allocator, + + pub fn serve(self: @This()) !void { + var tp = std.Thread.Pool{ .threads = &[_]std.Thread{}, .allocator = self.allocator }; + try tp.init(.{ .allocator = self.allocator, .n_jobs = self.n_threads }); + defer tp.deinit(); + + var svr_internal = std.http.Server.init(self.allocator, .{ .reuse_address = true }); + defer svr_internal.deinit(); + try svr_internal.listen(self.address); + std.log.info("server listening on {}", .{self.address}); + while (true) { + var aa = std.heap.ArenaAllocator.init(self.allocator); // will be freed by the spawned thread. + var conn = try svr_internal.accept(.{ .allocator = aa.allocator(), .header_strategy = .{ .dynamic = self.max_header_size } }); + const ctx: Context = self.context.clone(); + try tp.spawn(handle, .{ self, &conn, ctx, aa }); + } + } + + fn handle(self: @This(), res: *std.http.Server.Response, ctx: Context, aa: std.heap.ArenaAllocator) void { + defer aa.deinit(); + defer ctx.deinit(); + defer res.deinit(); + if (res.wait()) { + if (self.handler.handle(res, ctx)) { + std.log.info("Success handling request for {}", .{res.address}); + } else |err| { + std.log.err("Error handling request for {} : {}", .{ res.address, err }); + if (handle_simple_response(res, "<html><body>Server error!</body></html>", .internal_server_error)) {} else |err2| { + std.log.err("Error sending error page for {} : {}", .{ res.address, err2 }); + } + } + } else |_| { + // Do nothing + } + + if (res.state != .finished) { + std.log.err("request wasn't finished!", .{}); + } + } + + fn handle_simple_response(res: *std.http.Server.Response, content: []const u8, status: std.http.Status) !void { + res.status = status; + res.transfer_encoding = .{ .content_length = content.len }; + try res.headers.append("content-type", "text/html"); + try res.do(); + try res.writer().writeAll(content); + try res.finish(); + } + }; +} + +const Response = std.http.Server.Response; + +pub const Params = std.StringHashMap([]const u8); + +/// Routing component for an http server with wildcard matching and parameter +/// Handles matching a request to a handler. +/// Handler pattern can either be matched exactly +/// or it can have matcher segments, so +/// "/" -> matches request for "/" only +/// "/foo" matches request for "/foo" only +/// "/foo/{bar}/baz" matches request for "/foo/123/baz" and "/foo/bar/baz", and Params would contain "bar":"123" and "bar":"bar" respectively. +/// or it can have terminating wildcards, so +/// "/foo/*" -> matches "/foo", "/foo/bar","/foo/bar/baz" +/// "/*" -> matches all requests +/// TODO something clever to parse path parameters into the appropriate types, maybe smth like "/foo/{bar:u32}/baz" +/// TODO something to handle query parameters and request body too +pub fn Router(comptime Context: type, comptime ErrorType: type) type { + return struct { + pub const Handler = struct { + method: std.http.Method, + pattern: []const u8, + handle_fn: *const fn (res: *Response, ctx: Context, Params) ErrorType!void, + }; + + handlers: []const Handler, + + notfound: *const fn (res: *Response, ctx: Context) ErrorType!void, + + pub fn handle(self: @This(), res: *Response, ctx: Context) ErrorType!void { + // Routing can only happen after we have the headers + // It is a programmer error to call this without calling .wait first. + if (res.state != .waited) unreachable; + + handler_loop: for (self.handlers) |handler| { + if (handler.method != res.request.method) { + continue :handler_loop; + } + + var path_params: Params = std.StringHashMap([]const u8).init(res.allocator); + defer path_params.deinit(); + + var handle_split = std.mem.splitScalar(u8, handler.pattern, '/'); + var req_split = std.mem.splitScalar(u8, res.request.target, '/'); + + while (true) { + const maybe_handle_seg = handle_split.next(); + const maybe_req_seg = req_split.next(); + if (maybe_handle_seg == null and maybe_req_seg == null) { + // End of both handler and request, they matched this far so + // the handler must handle. + try handler.handle_fn(res, ctx, path_params); + break :handler_loop; + } else if (maybe_handle_seg != null and std.mem.eql(u8, maybe_handle_seg.?, "*")) { + // Wildcard, this matches + try handler.handle_fn(res, ctx, path_params); + break :handler_loop; + } else if (maybe_handle_seg == null or maybe_req_seg == null) { + // path lengths don't match, try the next handler + continue :handler_loop; + } else { + const handle_seg = maybe_handle_seg.?; + const req_seg = maybe_req_seg.?; + if (handle_seg.len > 0 and handle_seg[0] == '{' and handle_seg[handle_seg.len - 1] == '}') { + // Capture and keep going + const key = handle_seg[1 .. handle_seg.len - 1]; + try path_params.put(key, req_seg); + } else if (std.mem.eql(u8, handle_seg, req_seg)) { + // segments match, keep going + } else { + // mismatch, try the next handler + continue :handler_loop; + } + } + } + } else { + try self.notfound(res, ctx); + } + } + }; +} + +const T = struct { + const TestCtx = struct {}; + const TestErr = error{ TestError, OutOfMemory }; + const TestRouter = Router(TestCtx, TestErr); + + var notfoundinvoked = false; + fn notfound(_: *Response, _: TestCtx) TestErr!void { + notfoundinvoked = true; + } + var route1invoked = false; + var route1params: ?Params = null; + fn route1(_: *Response, _: TestCtx, p: Params) TestErr!void { + route1invoked = true; + route1params = try p.clone(); + } + var route2invoked = false; + fn route2(_: *Response, _: TestCtx, _: Params) TestErr!void { + route2invoked = true; + } + fn reset() void { + notfoundinvoked = false; + if (route1params != null) route1params.?.deinit(); + route1params = null; + route1invoked = false; + route2invoked = false; + } + + fn runTestRouter(handlers: []TestRouter.Handler, target: []const u8) !void { + const alloc = std.testing.allocator; + const ctx = TestCtx{}; + var buf: [128]u8 = undefined; + const req = std.http.Server.Request{ + .method = .GET, + .target = target, + .version = .@"HTTP/1.1", + .headers = std.http.Headers.init(alloc), + .parser = std.http.protocol.HeadersParser.initStatic(&buf), + }; + const sock = try std.net.tcpConnectToAddress(std.net.Address{ .in = std.net.Ip4Address.init(.{ 127, 0, 0, 1 }, 22) }); + defer sock.close(); + const conn = std.http.Server.Connection{ + .stream = sock, + .protocol = .plain, + }; + var res = Response{ + .allocator = alloc, + .address = std.net.Address{ .in = std.net.Ip4Address.init(.{ 127, 0, 0, 1 }, 8080) }, + .connection = conn, + .headers = std.http.Headers.init(alloc), + .request = req, + .state = .waited, + }; + const router = TestRouter{ + .handlers = handlers, + .notfound = notfound, + }; + try router.handle(&res, ctx); + } + + // fn hmof(x: []const u8, y: []const u8) std.StringHashMap([]const u8) { + // var hm = std.StringHashMap([]const u8).init(std.testing.allocator); + // hm.put(x, y) catch @panic("failed to create hmof in test"); + // return hm; + // } + + const TestCase = struct { + target: []const u8, + route1: ?[]const u8 = null, + route2: ?[]const u8 = null, + notfoundexpected: bool = false, + route1expected: bool = false, + route2expected: bool = false, + // route1paramsexpected: ?Params = null, + }; + + fn expectEqual(maybe_pexp: ?Params, maybe_pact: ?Params) !void { + if (maybe_pexp == null and maybe_pact == null) { + // fine + } else if (maybe_pexp == null or maybe_pact == null) { + std.debug.print("isnull(pexp) = {} isnull(pact) = {}", .{ maybe_pexp == null, maybe_pact == null }); + return error.TestUnexpectedResult; + } else { + const pexp = maybe_pexp.?; + const pact = maybe_pact.?; + try std.testing.expectEqual(pexp.count(), pact.count()); + var it = pexp.keyIterator(); + var kexp = it.next(); + while (kexp != null) : (kexp = it.next()) { + var vexp = pexp.get(kexp.?.*).?; + var maybe_vact = pact.get(kexp.?.*); + if (maybe_vact) |vact| { + std.debug.print("{s} {s}", .{ vexp, vact }); + try std.testing.expectEqual(vexp, vact); + } else { + std.debug.print("expected key {s} not found in actual", .{kexp.?.*}); + return error.TestUnexpectedResult; + } + } + } + } + + test "router tests" { + // var m0 = std.StringHashMap([]const u8).init(std.testing.allocator); + // defer m0.deinit(); + // var m1 = hmof("var", "bam"); + // defer m1.deinit(); + const cases = [_]TestCase{ + .{ + .target = "/", + .notfoundexpected = true, + }, + .{ + .target = "/", + .route1 = "/", + .route1expected = true, + // .route1paramsexpected = m0, + }, + .{ + .target = "/foo", + .route1 = "/bar", + .notfoundexpected = true, + }, + .{ + .target = "/bar", + .route1 = "/foo", + .route2 = "/bar", + .route2expected = true, + }, + .{ + .target = "/baz", + .route1 = "/", + .notfoundexpected = true, + }, + .{ + .target = "/baz", + .route1 = "/*", + .route2 = "/bar", + .route1expected = true, + // .route1paramsexpected = m0, + }, + .{ + .target = "/baz", + .route1 = "/*", + .route2 = "/baz", + .route1expected = true, // first matching route takes prio + // .route1paramsexpected = m0, + }, + .{ + .target = "/baz", + .route1 = "/baz", + .route2 = "/*", + .route1expected = true, // first matching route takes prio + // .route1paramsexpected = m0, + }, + .{ + .target = "/baz/bam", + .route1 = "/baz/{var}", + .route1expected = true, + // .route1paramsexpected = m1, + }, + .{ + .target = "/baz/bam/boo", + .route1 = "/baz/{var}/boo", + .route1expected = true, + // .route1paramsexpected = m1, + }, + // .{ + // .target = "/baz/bam/bar", + // .route1 = "/baz/{var}/boo", + // .notfoundexpected = true, + // }, + }; + + for (cases) |case| { + defer reset(); + var handlers = std.ArrayList(TestRouter.Handler).init(std.testing.allocator); + defer handlers.deinit(); + if (case.route1) |r1| { + try handlers.append(TestRouter.Handler{ .pattern = r1, .method = .GET, .handle_fn = route1 }); + } + if (case.route2) |r2| { + try handlers.append(TestRouter.Handler{ .pattern = r2, .method = .GET, .handle_fn = route2 }); + } + try runTestRouter(handlers.items, case.target); + try std.testing.expectEqual(case.notfoundexpected, notfoundinvoked); + try std.testing.expectEqual(case.route1expected, route1invoked); + try std.testing.expectEqual(case.route2expected, route2invoked); + // try expectEqual(case.route1paramsexpected, route1params); // TODO assert captures + } + } +}; + +test { + _ = T; +} |