summaryrefslogtreecommitdiff
path: root/src/zigwebserver.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/zigwebserver.zig')
-rw-r--r--src/zigwebserver.zig342
1 files changed, 342 insertions, 0 deletions
diff --git a/src/zigwebserver.zig b/src/zigwebserver.zig
new file mode 100644
index 0000000..6ce4439
--- /dev/null
+++ b/src/zigwebserver.zig
@@ -0,0 +1,342 @@
+const std = @import("std");
+
+/// Wrapper around a std.http.Server to run a multi-threaded HTTP server using thread-per-request and arena-per-request
+/// Context is .clone()'d and passed to each request, useful for passing user data to each request handler e.g. a database connection.
+/// HandlerType provides a single method handle() which is used to actually handle requests.
+pub fn Server(comptime Context: type, comptime Handler: type) type {
+ return struct {
+ address: std.net.Address,
+ context: Context,
+ handler: Handler,
+
+ max_header_size: usize = 8192,
+ n_threads: u32 = 50,
+
+ allocator: std.mem.Allocator,
+
+ pub fn serve(self: @This()) !void {
+ var tp = std.Thread.Pool{ .threads = &[_]std.Thread{}, .allocator = self.allocator };
+ try tp.init(.{ .allocator = self.allocator, .n_jobs = self.n_threads });
+ defer tp.deinit();
+
+ var svr_internal = std.http.Server.init(self.allocator, .{ .reuse_address = true });
+ defer svr_internal.deinit();
+ try svr_internal.listen(self.address);
+ std.log.info("server listening on {}", .{self.address});
+ while (true) {
+ var aa = std.heap.ArenaAllocator.init(self.allocator); // will be freed by the spawned thread.
+ var conn = try svr_internal.accept(.{ .allocator = aa.allocator(), .header_strategy = .{ .dynamic = self.max_header_size } });
+ const ctx: Context = self.context.clone();
+ try tp.spawn(handle, .{ self, &conn, ctx, aa });
+ }
+ }
+
+ fn handle(self: @This(), res: *std.http.Server.Response, ctx: Context, aa: std.heap.ArenaAllocator) void {
+ defer aa.deinit();
+ defer ctx.deinit();
+ defer res.deinit();
+ if (res.wait()) {
+ if (self.handler.handle(res, ctx)) {
+ std.log.info("Success handling request for {}", .{res.address});
+ } else |err| {
+ std.log.err("Error handling request for {} : {}", .{ res.address, err });
+ if (handle_simple_response(res, "<html><body>Server error!</body></html>", .internal_server_error)) {} else |err2| {
+ std.log.err("Error sending error page for {} : {}", .{ res.address, err2 });
+ }
+ }
+ } else |_| {
+ // Do nothing
+ }
+
+ if (res.state != .finished) {
+ std.log.err("request wasn't finished!", .{});
+ }
+ }
+
+ fn handle_simple_response(res: *std.http.Server.Response, content: []const u8, status: std.http.Status) !void {
+ res.status = status;
+ res.transfer_encoding = .{ .content_length = content.len };
+ try res.headers.append("content-type", "text/html");
+ try res.do();
+ try res.writer().writeAll(content);
+ try res.finish();
+ }
+ };
+}
+
+const Response = std.http.Server.Response;
+
+pub const Params = std.StringHashMap([]const u8);
+
+/// Routing component for an http server with wildcard matching and parameter
+/// Handles matching a request to a handler.
+/// Handler pattern can either be matched exactly
+/// or it can have matcher segments, so
+/// "/" -> matches request for "/" only
+/// "/foo" matches request for "/foo" only
+/// "/foo/{bar}/baz" matches request for "/foo/123/baz" and "/foo/bar/baz", and Params would contain "bar":"123" and "bar":"bar" respectively.
+/// or it can have terminating wildcards, so
+/// "/foo/*" -> matches "/foo", "/foo/bar","/foo/bar/baz"
+/// "/*" -> matches all requests
+/// TODO something clever to parse path parameters into the appropriate types, maybe smth like "/foo/{bar:u32}/baz"
+/// TODO something to handle query parameters and request body too
+pub fn Router(comptime Context: type, comptime ErrorType: type) type {
+ return struct {
+ pub const Handler = struct {
+ method: std.http.Method,
+ pattern: []const u8,
+ handle_fn: *const fn (res: *Response, ctx: Context, Params) ErrorType!void,
+ };
+
+ handlers: []const Handler,
+
+ notfound: *const fn (res: *Response, ctx: Context) ErrorType!void,
+
+ pub fn handle(self: @This(), res: *Response, ctx: Context) ErrorType!void {
+ // Routing can only happen after we have the headers
+ // It is a programmer error to call this without calling .wait first.
+ if (res.state != .waited) unreachable;
+
+ handler_loop: for (self.handlers) |handler| {
+ if (handler.method != res.request.method) {
+ continue :handler_loop;
+ }
+
+ var path_params: Params = std.StringHashMap([]const u8).init(res.allocator);
+ defer path_params.deinit();
+
+ var handle_split = std.mem.splitScalar(u8, handler.pattern, '/');
+ var req_split = std.mem.splitScalar(u8, res.request.target, '/');
+
+ while (true) {
+ const maybe_handle_seg = handle_split.next();
+ const maybe_req_seg = req_split.next();
+ if (maybe_handle_seg == null and maybe_req_seg == null) {
+ // End of both handler and request, they matched this far so
+ // the handler must handle.
+ try handler.handle_fn(res, ctx, path_params);
+ break :handler_loop;
+ } else if (maybe_handle_seg != null and std.mem.eql(u8, maybe_handle_seg.?, "*")) {
+ // Wildcard, this matches
+ try handler.handle_fn(res, ctx, path_params);
+ break :handler_loop;
+ } else if (maybe_handle_seg == null or maybe_req_seg == null) {
+ // path lengths don't match, try the next handler
+ continue :handler_loop;
+ } else {
+ const handle_seg = maybe_handle_seg.?;
+ const req_seg = maybe_req_seg.?;
+ if (handle_seg.len > 0 and handle_seg[0] == '{' and handle_seg[handle_seg.len - 1] == '}') {
+ // Capture and keep going
+ const key = handle_seg[1 .. handle_seg.len - 1];
+ try path_params.put(key, req_seg);
+ } else if (std.mem.eql(u8, handle_seg, req_seg)) {
+ // segments match, keep going
+ } else {
+ // mismatch, try the next handler
+ continue :handler_loop;
+ }
+ }
+ }
+ } else {
+ try self.notfound(res, ctx);
+ }
+ }
+ };
+}
+
+const T = struct {
+ const TestCtx = struct {};
+ const TestErr = error{ TestError, OutOfMemory };
+ const TestRouter = Router(TestCtx, TestErr);
+
+ var notfoundinvoked = false;
+ fn notfound(_: *Response, _: TestCtx) TestErr!void {
+ notfoundinvoked = true;
+ }
+ var route1invoked = false;
+ var route1params: ?Params = null;
+ fn route1(_: *Response, _: TestCtx, p: Params) TestErr!void {
+ route1invoked = true;
+ route1params = try p.clone();
+ }
+ var route2invoked = false;
+ fn route2(_: *Response, _: TestCtx, _: Params) TestErr!void {
+ route2invoked = true;
+ }
+ fn reset() void {
+ notfoundinvoked = false;
+ if (route1params != null) route1params.?.deinit();
+ route1params = null;
+ route1invoked = false;
+ route2invoked = false;
+ }
+
+ fn runTestRouter(handlers: []TestRouter.Handler, target: []const u8) !void {
+ const alloc = std.testing.allocator;
+ const ctx = TestCtx{};
+ var buf: [128]u8 = undefined;
+ const req = std.http.Server.Request{
+ .method = .GET,
+ .target = target,
+ .version = .@"HTTP/1.1",
+ .headers = std.http.Headers.init(alloc),
+ .parser = std.http.protocol.HeadersParser.initStatic(&buf),
+ };
+ const sock = try std.net.tcpConnectToAddress(std.net.Address{ .in = std.net.Ip4Address.init(.{ 127, 0, 0, 1 }, 22) });
+ defer sock.close();
+ const conn = std.http.Server.Connection{
+ .stream = sock,
+ .protocol = .plain,
+ };
+ var res = Response{
+ .allocator = alloc,
+ .address = std.net.Address{ .in = std.net.Ip4Address.init(.{ 127, 0, 0, 1 }, 8080) },
+ .connection = conn,
+ .headers = std.http.Headers.init(alloc),
+ .request = req,
+ .state = .waited,
+ };
+ const router = TestRouter{
+ .handlers = handlers,
+ .notfound = notfound,
+ };
+ try router.handle(&res, ctx);
+ }
+
+ // fn hmof(x: []const u8, y: []const u8) std.StringHashMap([]const u8) {
+ // var hm = std.StringHashMap([]const u8).init(std.testing.allocator);
+ // hm.put(x, y) catch @panic("failed to create hmof in test");
+ // return hm;
+ // }
+
+ const TestCase = struct {
+ target: []const u8,
+ route1: ?[]const u8 = null,
+ route2: ?[]const u8 = null,
+ notfoundexpected: bool = false,
+ route1expected: bool = false,
+ route2expected: bool = false,
+ // route1paramsexpected: ?Params = null,
+ };
+
+ fn expectEqual(maybe_pexp: ?Params, maybe_pact: ?Params) !void {
+ if (maybe_pexp == null and maybe_pact == null) {
+ // fine
+ } else if (maybe_pexp == null or maybe_pact == null) {
+ std.debug.print("isnull(pexp) = {} isnull(pact) = {}", .{ maybe_pexp == null, maybe_pact == null });
+ return error.TestUnexpectedResult;
+ } else {
+ const pexp = maybe_pexp.?;
+ const pact = maybe_pact.?;
+ try std.testing.expectEqual(pexp.count(), pact.count());
+ var it = pexp.keyIterator();
+ var kexp = it.next();
+ while (kexp != null) : (kexp = it.next()) {
+ var vexp = pexp.get(kexp.?.*).?;
+ var maybe_vact = pact.get(kexp.?.*);
+ if (maybe_vact) |vact| {
+ std.debug.print("{s} {s}", .{ vexp, vact });
+ try std.testing.expectEqual(vexp, vact);
+ } else {
+ std.debug.print("expected key {s} not found in actual", .{kexp.?.*});
+ return error.TestUnexpectedResult;
+ }
+ }
+ }
+ }
+
+ test "router tests" {
+ // var m0 = std.StringHashMap([]const u8).init(std.testing.allocator);
+ // defer m0.deinit();
+ // var m1 = hmof("var", "bam");
+ // defer m1.deinit();
+ const cases = [_]TestCase{
+ .{
+ .target = "/",
+ .notfoundexpected = true,
+ },
+ .{
+ .target = "/",
+ .route1 = "/",
+ .route1expected = true,
+ // .route1paramsexpected = m0,
+ },
+ .{
+ .target = "/foo",
+ .route1 = "/bar",
+ .notfoundexpected = true,
+ },
+ .{
+ .target = "/bar",
+ .route1 = "/foo",
+ .route2 = "/bar",
+ .route2expected = true,
+ },
+ .{
+ .target = "/baz",
+ .route1 = "/",
+ .notfoundexpected = true,
+ },
+ .{
+ .target = "/baz",
+ .route1 = "/*",
+ .route2 = "/bar",
+ .route1expected = true,
+ // .route1paramsexpected = m0,
+ },
+ .{
+ .target = "/baz",
+ .route1 = "/*",
+ .route2 = "/baz",
+ .route1expected = true, // first matching route takes prio
+ // .route1paramsexpected = m0,
+ },
+ .{
+ .target = "/baz",
+ .route1 = "/baz",
+ .route2 = "/*",
+ .route1expected = true, // first matching route takes prio
+ // .route1paramsexpected = m0,
+ },
+ .{
+ .target = "/baz/bam",
+ .route1 = "/baz/{var}",
+ .route1expected = true,
+ // .route1paramsexpected = m1,
+ },
+ .{
+ .target = "/baz/bam/boo",
+ .route1 = "/baz/{var}/boo",
+ .route1expected = true,
+ // .route1paramsexpected = m1,
+ },
+ // .{
+ // .target = "/baz/bam/bar",
+ // .route1 = "/baz/{var}/boo",
+ // .notfoundexpected = true,
+ // },
+ };
+
+ for (cases) |case| {
+ defer reset();
+ var handlers = std.ArrayList(TestRouter.Handler).init(std.testing.allocator);
+ defer handlers.deinit();
+ if (case.route1) |r1| {
+ try handlers.append(TestRouter.Handler{ .pattern = r1, .method = .GET, .handle_fn = route1 });
+ }
+ if (case.route2) |r2| {
+ try handlers.append(TestRouter.Handler{ .pattern = r2, .method = .GET, .handle_fn = route2 });
+ }
+ try runTestRouter(handlers.items, case.target);
+ try std.testing.expectEqual(case.notfoundexpected, notfoundinvoked);
+ try std.testing.expectEqual(case.route1expected, route1invoked);
+ try std.testing.expectEqual(case.route2expected, route2invoked);
+ // try expectEqual(case.route1paramsexpected, route1params); // TODO assert captures
+ }
+ }
+};
+
+test {
+ _ = T;
+}