const std = @import("std");

/// Wrapper around a `std.http.Server` that serves requests on a fixed-size
/// `std.Thread.Pool`, giving every accepted connection its own arena allocator.
///
/// `Context` must provide `clone()` and `deinit()`. The server's `context` is
/// `.clone()`'d once per connection and the clone is passed to the handler,
/// which makes it useful for per-request user data (e.g. a database connection).
/// The clone is deinitialized after the handler returns.
///
/// `Handler` provides a single method `handle(res, ctx)`, which is called after
/// the request headers have been read and is used to actually handle requests.
pub fn Server(comptime Context: type, comptime Handler: type) type {
    return struct {
        /// Address to listen on.
        address: std.net.Address,
        /// Cloned for every connection and passed to `handler`.
        context: Context,
        handler: Handler,
        /// Passed to `accept` as the dynamic header-buffer limit.
        max_header_size: usize = 8192,
        /// Number of worker threads in the pool.
        n_threads: u32 = 50,
        /// Backs the thread pool, the listener, and every per-connection arena.
        allocator: std.mem.Allocator,

        /// Listen on `address` and dispatch accepted connections to the worker pool.
        /// Loops forever; returns only when listening, accepting, allocating, or
        /// spawning a worker fails.
        pub fn serve(self: @This()) !void {
            var pool: std.Thread.Pool = undefined;
            try pool.init(.{ .allocator = self.allocator, .n_jobs = self.n_threads });
            defer pool.deinit();

            var svr_internal = std.http.Server.init(self.allocator, .{ .reuse_address = true });
            defer svr_internal.deinit();
            try svr_internal.listen(self.address);
            std.log.info("server listening on {}", .{self.address});

            while (true) {
                try self.acceptOne(&svr_internal, &pool);
            }
        }

        /// Accept one connection and hand it to a pool worker.
        ///
        /// The arena and the `Response` must outlive this call, so both are heap
        /// allocated (the arena from `self.allocator`, the `Response` inside the arena).
        /// Ownership of the arena, the response and the cloned context passes to
        /// `handle`, which frees them. If anything fails before the hand-off, the
        /// `errdefer`s below release them here instead.
        fn acceptOne(self: @This(), svr: *std.http.Server, pool: *std.Thread.Pool) !void {
            const arena = try self.allocator.create(std.heap.ArenaAllocator);
            arena.* = std.heap.ArenaAllocator.init(self.allocator);
            errdefer {
                arena.deinit();
                self.allocator.destroy(arena);
            }

            const res = try arena.allocator().create(std.http.Server.Response);
            res.* = try svr.accept(.{
                .allocator = arena.allocator(),
                .header_strategy = .{ .dynamic = self.max_header_size },
            });
            errdefer res.deinit();

            const ctx: Context = self.context.clone();
            errdefer ctx.deinit();

            try pool.spawn(handle, .{ self, res, ctx, arena });
        }

        /// Worker entry point: read the request headers, run the handler and log the
        /// outcome. Takes ownership of `res`, `ctx` and `arena` and frees all three.
        fn handle(self: @This(), res: *std.http.Server.Response, ctx: Context, arena: *std.heap.ArenaAllocator) void {
            // Defers run LIFO: the response is torn down first (it was allocated from
            // the arena and uses it), then the context, then the arena itself.
            defer {
                arena.deinit();
                self.allocator.destroy(arena);
            }
            defer ctx.deinit();
            defer res.deinit();

            res.wait() catch |err| {
                // The request headers could not be read, so there is nothing to respond to.
                std.log.debug("Error reading request from client {} error {}", .{ res.address, err });
                return;
            };

            if (self.handler.handle(res, ctx)) {
                std.log.info("Success handling request [{s} {s} {s}] status {d} client {}", .{ @tagName(res.request.method), res.request.target, @tagName(res.request.version), @intFromEnum(res.status), res.address });
            } else |err| {
                std.log.info("Error handling request [{s} {s} {s}] client {} error {}", .{ @tagName(res.request.method), res.request.target, @tagName(res.request.version), res.address, err });
                // An error page can only be sent while the response has not been started;
                // once the handler has called `do()` the headers are already sent.
                if (res.state == .waited) {
                    handle_simple_response(res, "Server error!", .internal_server_error) catch |err2| {
                        std.log.err("Error sending error page for {} : {}", .{ res.address, err2 });
                    };
                }
            }

            if (res.state != .finished) {
                std.log.err("request wasn't finished!", .{});
            }
        }

        /// Send a complete `text/html` response with the given status and body.
        fn handle_simple_response(res: *std.http.Server.Response, content: []const u8, status: std.http.Status) !void {
            res.status = status;
            res.transfer_encoding = .{ .content_length = content.len };
            try res.headers.append("content-type", "text/html");
            try res.do();
            try res.writer().writeAll(content);
            try res.finish();
        }
    };
}

const Response = std.http.Server.Response;

/// Path parameters captured by a `Router` pattern, keyed by the name written
/// between the braces (e.g. `{bar}` -> "bar").
pub const Params = std.StringHashMap([]const u8);

/// Routing component for an http server with wildcard matching and parameter capture.
/// Handles matching a request to a handler.
/// Handler pattern can either be matched exactly
/// or it can have matcher segments, so
/// "/" -> matches request for "/" only
/// "/foo" matches request for "/foo" only
/// "/foo/{bar}/baz" matches request for "/foo/123/baz" and "/foo/bar/baz", and Params would contain "bar":"123" and "bar":"bar" respectively.
/// or it can have terminating wildcards, so
/// "/foo/*" -> matches "/foo", "/foo/bar","/foo/bar/baz"
/// "/*" -> matches all requests
/// Only the path component of the request target is matched; the query string and
/// fragment are ignored. Trailing slashes are significant ("/foo/" does not match "/foo").
/// Handlers are tried in order and the first match wins; if none matches, `notfound` runs.
/// TODO something clever to parse path parameters into the appropriate types, maybe smth like "/foo/{bar:u32}/baz"
/// TODO something to handle query parameters and request body too
pub fn Router(comptime Context: type, comptime ErrorType: type) type {
    return struct {
        pub const Handler = struct {
            method: std.http.Method,
            pattern: []const u8,
            handle_fn: *const fn (res: *Response, ctx: Context, Params) ErrorType!void,
        };

        handlers: []const Handler,
        notfound: *const fn (res: *Response, ctx: Context) ErrorType!void,

        /// Dispatch `res` to the first handler whose method and pattern match the request.
        /// `ErrorType` must be a superset of `std.mem.Allocator.Error` and `Path.ParseError`,
        /// because both are propagated from here with `try`.
        pub fn handle(self: @This(), res: *Response, ctx: Context) ErrorType!void {
            // Routing can only happen after we have the headers.
            // It is a programmer error to call this without calling .wait first.
            std.debug.assert(res.state == .waited);
            const path = (try Path.parse(res.request.target)).path;

            for (self.handlers) |handler| {
                if (handler.method != res.request.method) continue;

                // Captures only live for the duration of the handler call: the map
                // is freed when this iteration ends, and its keys/values are slices
                // into the pattern and the request target.
                var path_params = Params.init(res.allocator);
                defer path_params.deinit();

                if (try matchPattern(handler.pattern, path, &path_params)) {
                    return handler.handle_fn(res, ctx, path_params);
                }
            }
            return self.notfound(res, ctx);
        }

        /// Match request `path` against handler `pattern` one '/'-separated segment at a
        /// time, storing `{name}` captures in `params`. Returns true when the handler
        /// should run for this path.
        fn matchPattern(pattern: []const u8, path: []const u8, params: *Params) std.mem.Allocator.Error!bool {
            var pattern_segs = std.mem.splitScalar(u8, pattern, '/');
            var path_segs = std.mem.splitScalar(u8, path, '/');
            while (true) {
                const maybe_req = path_segs.next();
                // Pattern exhausted: it matches only if the request is exhausted too.
                const pat = pattern_segs.next() orelse return maybe_req == null;
                // A "*" segment matches everything from here on, including nothing.
                if (std.mem.eql(u8, pat, "*")) return true;
                // Request exhausted before the pattern: path lengths don't match.
                const req = maybe_req orelse return false;
                if (captureName(pat)) |name| {
                    try params.put(name, req);
                } else if (!std.mem.eql(u8, pat, req)) {
                    return false;
                }
            }
        }

        /// Return the name inside a `{name}` segment, or null if `seg` is not a capture.
        /// At least two bytes are required, so a lone "{" is treated as a literal
        /// instead of producing an invalid `[1..0]` slice.
        fn captureName(seg: []const u8) ?[]const u8 {
            if (seg.len < 2 or seg[0] != '{' or seg[seg.len - 1] != '}') return null;
            return seg[1 .. seg.len - 1];
        }
    };
}

const RouterTest = struct {
    const TestCtx = struct {};
    const TestErr = error{ TestError, OutOfMemory } || Path.ParseError;
    const TestRouter = Router(TestCtx, TestErr);

    var notfoundinvoked = false;
    fn notfound(_: *Response, _: TestCtx) TestErr!void {
        notfoundinvoked = true;
    }

    var route1invoked = false;
    var route1params: ?Params = null;
    fn route1(_: *Response, _: TestCtx, p: Params) TestErr!void {
        route1invoked = true;
        // The router frees its map after the call, so keep a copy for the assertions.
        route1params = try p.clone();
    }

    var route2invoked = false;
    fn route2(_: *Response, _: TestCtx, _: Params) TestErr!void {
        route2invoked = true;
    }

    fn reset() void {
        notfoundinvoked = false;
        if (route1params) |*p| p.deinit();
        route1params = null;
        route1invoked = false;
        route2invoked = false;
    }

    /// Route a synthetic GET request for `target` through a router built from `handlers`.
    fn runTestRouter(handlers: []const TestRouter.Handler, target: []const u8) !void {
        const alloc = std.testing.allocator;
        var buf: [128]u8 = undefined;
        const req = std.http.Server.Request{
            .method = .GET,
            .target = target,
            .version = .@"HTTP/1.1",
            .headers = std.http.Headers.init(alloc),
            .parser = std.http.protocol.HeadersParser.initStatic(&buf),
        };

        // A Response needs a real connection even though these handlers never write
        // to it. Connect to a throwaway listener on an ephemeral loopback port rather
        // than relying on some external service (previously port 22) being up.
        var listener = std.net.StreamServer.init(.{});
        defer listener.deinit();
        try listener.listen(std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0));
        const sock = try std.net.tcpConnectToAddress(listener.listen_address);
        defer sock.close();

        var res = Response{
            .allocator = alloc,
            .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 8080),
            .connection = .{ .stream = sock, .protocol = .plain },
            .headers = std.http.Headers.init(alloc),
            .request = req,
            .state = .waited,
        };
        const router = TestRouter{ .handlers = handlers, .notfound = notfound };
        try router.handle(&res, TestCtx{});
    }

    const TestCase = struct {
        target: []const u8,
        route1: ?[]const u8 = null,
        route2: ?[]const u8 = null,
        notfoundexpected: bool = false,
        route1expected: bool = false,
        route2expected: bool = false,
        /// Captures route1 must receive, as { key, value } pairs; checked when route1expected.
        route1params_expected: []const [2][]const u8 = &.{},
    };

    /// Fail unless `actual` holds exactly the `expected` key/value pairs (compared by content).
    fn expectParams(expected: []const [2][]const u8, actual: ?Params) !void {
        const params = actual orelse return error.TestUnexpectedResult;
        try std.testing.expectEqual(expected.len, params.count());
        for (expected) |kv| {
            const value = params.get(kv[0]) orelse {
                std.debug.print("expected capture {s} not found in actual\n", .{kv[0]});
                return error.TestUnexpectedResult;
            };
            try std.testing.expectEqualStrings(kv[1], value);
        }
    }

    test "router tests" {
        const cases = [_]TestCase{
            .{ .target = "/", .notfoundexpected = true },
            .{ .target = "/", .route1 = "/", .route1expected = true },
            .{ .target = "/foo", .route1 = "/bar", .notfoundexpected = true },
            .{ .target = "/bar", .route1 = "/foo", .route2 = "/bar", .route2expected = true },
            .{ .target = "/baz", .route1 = "/", .notfoundexpected = true },
            .{ .target = "/baz", .route1 = "/*", .route2 = "/bar", .route1expected = true },
            // first matching route takes prio
            .{ .target = "/baz", .route1 = "/*", .route2 = "/baz", .route1expected = true },
            .{ .target = "/baz", .route1 = "/baz", .route2 = "/*", .route1expected = true },
            // a terminating wildcard also matches the bare prefix
            .{ .target = "/foo", .route1 = "/foo/*", .route1expected = true },
            .{ .target = "/baz/bam", .route1 = "/baz/{var}", .route1expected = true, .route1params_expected = &.{.{ "var", "bam" }} },
            .{ .target = "/baz/bam/boo", .route1 = "/baz/{var}/boo", .route1expected = true, .route1params_expected = &.{.{ "var", "bam" }} },
            .{ .target = "/baz/bam/boo?somequery=foo", .route1 = "/baz/{var}/boo", .route1expected = true, .route1params_expected = &.{.{ "var", "bam" }} },
            .{ .target = "/baz/bam/x/y", .route1 = "/baz/{var}/*", .route1expected = true, .route1params_expected = &.{.{ "var", "bam" }} },
            .{ .target = "/baz/bam/bar", .route1 = "/baz/{var}/boo", .notfoundexpected = true },
            // a lone "{" is a literal segment, not a capture
            .{ .target = "/x", .route1 = "/{", .notfoundexpected = true },
            .{ .target = "/{", .route1 = "/{", .route1expected = true },
        };
        for (cases) |case| {
            defer reset();
            var handlers = std.ArrayList(TestRouter.Handler).init(std.testing.allocator);
            defer handlers.deinit();
            if (case.route1) |r1| {
                try handlers.append(TestRouter.Handler{ .pattern = r1, .method = .GET, .handle_fn = route1 });
            }
            if (case.route2) |r2| {
                try handlers.append(TestRouter.Handler{ .pattern = r2, .method = .GET, .handle_fn = route2 });
            }

            try runTestRouter(handlers.items, case.target);

            try std.testing.expectEqual(case.notfoundexpected, notfoundinvoked);
            try std.testing.expectEqual(case.route1expected, route1invoked);
            try std.testing.expectEqual(case.route2expected, route2invoked);
            if (case.route1expected) {
                try expectParams(case.route1params_expected, route1params);
            }
        }
    }
};

/// HTTP path parsing, which
/// is a subset of URI parsing (RFC 3986): a request target is split into its
/// path, query and fragment components. No percent-decoding is performed; every
/// field is a slice into the string passed to `parse`.
pub const Path = struct {
    path: []const u8,
    /// Text after the first '?' that precedes any '#', or "" if there is none.
    query: []const u8,
    /// Text after the first '#', or "" if there is none. Clients are not supposed to
    /// send fragments in a request target, so this is normally empty on a server.
    fragment: []const u8,

    /// Reserved for malformed targets. `parse` currently accepts every input; the
    /// error set is kept so existing callers and error-set unions built from it compile.
    pub const ParseError = error{Malformatted};

    /// Split `path` into path, query and fragment components.
    /// The fragment starts at the first '#'. The query starts at the first '?' before
    /// that '#'. A '?' after the '#' is part of the fragment (RFC 3986 section 3.5),
    /// so "/a#b?c" has fragment "b?c" and no query.
    pub fn parse(path: []const u8) ParseError!Path {
        var rest = path;

        var fragment: []const u8 = "";
        if (std.mem.indexOfScalar(u8, rest, '#')) |f| {
            fragment = rest[f + 1 ..];
            rest = rest[0..f];
        }

        var query: []const u8 = "";
        if (std.mem.indexOfScalar(u8, rest, '?')) |q| {
            query = rest[q + 1 ..];
            rest = rest[0..q];
        }

        return Path{ .path = rest, .query = query, .fragment = fragment };
    }

    /// Return the value of the first `key=value` pair in the query whose key equals
    /// `key`, or null if there is none. A key present without '=' also yields null.
    /// The value is everything after the first '=', so values may contain '='
    /// (e.g. base64 padding). Keys and values are compared and returned raw.
    pub fn get_query_param(self: Path, key: []const u8) ?[]const u8 {
        var pairs = std.mem.splitScalar(u8, self.query, '&');
        while (pairs.next()) |pair| {
            const eq = std.mem.indexOfScalar(u8, pair, '=');
            const name = if (eq) |i| pair[0..i] else pair;
            if (std.mem.eql(u8, name, key)) {
                return if (eq) |i| pair[i + 1 ..] else null;
            }
        }
        return null;
    }
};

const PathTest = struct {
    test "path" {
        const p = try Path.parse("/");
        try std.testing.expectEqualDeep(Path{ .path = "/", .query = "", .fragment = "" }, p);
    }
    test "query" {
        const p = try Path.parse("/foo?bar=baz");
        try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "" }, p);
    }
    test "query and fragment" {
        const p = try Path.parse("/foo?bar=baz#frag");
        try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "frag" }, p);
    }
    test "fragment" {
        const p = try Path.parse("/foo#frag");
        try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "", .fragment = "frag" }, p);
    }
    test "fragment containing question mark" {
        const p = try Path.parse("/foo#frag?x=1");
        try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "", .fragment = "frag?x=1" }, p);
    }
    test "query param" {
        const p = try Path.parse("/foo?bar=baz#frag");
        const v1 = p.get_query_param("bar");
        try std.testing.expect(v1 != null);
        try std.testing.expectEqualSlices(u8, "baz", v1.?);
        const v2 = p.get_query_param("bam");
        try std.testing.expect(v2 == null);
    }
    test "query param value containing '='" {
        const p = try Path.parse("/foo?a=1&token=abc==");
        try std.testing.expectEqualStrings("1", p.get_query_param("a").?);
        try std.testing.expectEqualStrings("abc==", p.get_query_param("token").?);
    }
    test "query param without value" {
        const p = try Path.parse("/foo?flag&a=1");
        try std.testing.expect(p.get_query_param("flag") == null);
        try std.testing.expectEqualStrings("1", p.get_query_param("a").?);
    }
};

test {
    _ = RouterTest;
    _ = PathTest;
}