zigwebserver

Tools for writing web applications in Zig
Log | Files | Refs

zigwebserver.zig (23104B)


      1 const std = @import("std");
      2 
      3 /// Wrapper around a std.http.Server to run a multi-threaded HTTP server using thread-per-request and arena-per-request
      4 /// Context is .clone()'d and passed to each request, useful for passing user data to each request handler e.g. a database connection.
      5 /// HandlerType provides a single method handle() which is used to actually handle requests.
      6 pub fn Server(comptime Context: type, comptime Handler: type) type {
      7     return struct {
      8         address: std.net.Address,
      9         context: Context,
     10         handler: Handler,
     11 
     12         max_header_size: usize = 8192,
     13         n_threads: u32 = 50,
     14 
     15         allocator: std.mem.Allocator,
     16 
     17         pub fn serve(self: @This()) !void {
     18             var tp = std.Thread.Pool{ .threads = &[_]std.Thread{}, .allocator = self.allocator };
     19             try tp.init(.{ .allocator = self.allocator, .n_jobs = self.n_threads });
     20             defer tp.deinit();
     21 
     22             var svr_internal = std.http.Server.init(.{ .reuse_address = true });
     23             defer svr_internal.deinit();
     24             try svr_internal.listen(self.address);
     25             std.log.info("server listening on {}", .{self.address});
     26             while (true) {
     27                 var aa = std.heap.ArenaAllocator.init(self.allocator); // will be freed by the spawned thread.
     28                 var conn = try svr_internal.accept(.{ .allocator = aa.allocator(), .header_strategy = .{ .dynamic = self.max_header_size } });
     29                 const ctx: Context = self.context.clone();
     30                 try tp.spawn(handle, .{ self, &conn, ctx, aa });
     31             }
     32         }
     33 
     34         fn handle(self: @This(), res: *std.http.Server.Response, ctx: Context, aa: std.heap.ArenaAllocator) void {
     35             defer aa.deinit();
     36             defer ctx.deinit();
     37             defer res.deinit();
     38             if (res.wait()) {
     39                 if (self.handler.handle(res, ctx)) {
     40                     std.log.info("Success handling request [{s} {s} {s}] status {d} client {}", .{ @tagName(res.request.method), res.request.target, @tagName(res.request.version), @intFromEnum(res.status), res.address });
     41                 } else |err| {
     42                     std.log.info("Error handling request [{s} {s} {s}] client {} error {}", .{ @tagName(res.request.method), res.request.target, @tagName(res.request.version), res.address, err });
     43                     if (handle_simple_response(res, "<html><body>Server error!</body></html>", .internal_server_error)) {} else |err2| {
     44                         std.log.err("Error sending error page for {} : {}", .{ res.address, err2 });
     45                     }
     46                 }
     47             } else |_| {
     48                 // Do nothing
     49             }
     50 
     51             if (res.state != .finished) {
     52                 std.log.err("request wasn't finished!", .{});
     53             }
     54         }
     55 
     56         fn handle_simple_response(res: *std.http.Server.Response, content: []const u8, status: std.http.Status) !void {
     57             res.status = status;
     58             res.transfer_encoding = .{ .content_length = content.len };
     59             try res.headers.append("content-type", "text/html");
     60             try res.send();
     61             try res.writer().writeAll(content);
     62             try res.finish();
     63         }
     64     };
     65 }
     66 
     67 pub const Params = std.StringHashMap([]const u8);
     68 
     69 /// Routing component for an http server with wildcard matching and parameter
     70 /// Handles matching a request to a handler.
     71 /// Handler pattern can either be matched exactly
     72 /// or it can have matcher segments, so
     73 /// "/" -> matches request for "/" only
     74 /// "/foo" matches request for "/foo" only
     75 /// "/foo/{bar}/baz" matches request for "/foo/123/baz" and "/foo/bar/baz", and Params would contain "bar":"123" and "bar":"bar" respectively.
     76 /// or it can have terminating wildcards, so
     77 /// "/foo/*" -> matches "/foo", "/foo/bar","/foo/bar/baz"
     78 /// "/*" -> matches all requests
     79 /// TODO something clever to parse path parameters into the appropriate types, maybe smth like "/foo/{bar:u32}/baz"
     80 /// TODO something to handle query parameters and request body too
     81 pub fn Router(comptime Response: type, comptime Context: type, comptime ErrorType: type) type {
     82     return struct {
     83         pub const Handler = struct {
     84             method: std.http.Method,
     85             pattern: []const u8,
     86             handle_fn: *const fn (Response, Context, Params) ErrorType!void,
     87         };
     88 
     89         allocator: std.mem.Allocator,
     90 
     91         handlers: []const Handler,
     92 
     93         notfound: *const fn (Response, Context) ErrorType!void,
     94 
     95         pub fn handle(self: @This(), res: Response, ctx: Context) ErrorType!void {
     96             var p = try Path.parse(self.allocator, res.request.target);
     97             defer p.deinit();
     98             const path = p.path;
     99 
    100             handler_loop: for (self.handlers) |handler| {
    101                 if (handler.method != res.request.method) {
    102                     continue :handler_loop;
    103                 }
    104 
    105                 var path_params: Params = std.StringHashMap([]const u8).init(self.allocator);
    106                 defer path_params.deinit();
    107 
    108                 var handle_split = std.mem.splitScalar(u8, handler.pattern, '/');
    109                 var req_split = std.mem.splitScalar(u8, path, '/');
    110 
    111                 while (true) {
    112                     const maybe_handle_seg = handle_split.next();
    113                     const maybe_req_seg = req_split.next();
    114                     if (maybe_handle_seg == null and maybe_req_seg == null) {
    115                         // End of both handler and request, they matched this far so
    116                         // the handler must handle.
    117                         try handler.handle_fn(res, ctx, path_params);
    118                         break :handler_loop;
    119                     } else if (maybe_handle_seg != null and std.mem.eql(u8, maybe_handle_seg.?, "*")) {
    120                         // Wildcard, this matches
    121                         try handler.handle_fn(res, ctx, path_params);
    122                         break :handler_loop;
    123                     } else if (maybe_handle_seg == null or maybe_req_seg == null) {
    124                         // path lengths don't match, try the next handler
    125                         continue :handler_loop;
    126                     } else {
    127                         const handle_seg = maybe_handle_seg.?;
    128                         const req_seg = maybe_req_seg.?;
    129                         if (handle_seg.len > 0 and handle_seg[0] == '{' and handle_seg[handle_seg.len - 1] == '}') {
    130                             // Capture and keep going
    131                             const key = handle_seg[1 .. handle_seg.len - 1];
    132                             try path_params.put(key, req_seg);
    133                         } else if (std.mem.eql(u8, handle_seg, req_seg)) {
    134                             // segments match, keep going
    135                         } else {
    136                             // mismatch, try the next handler
    137                             continue :handler_loop;
    138                         }
    139                     }
    140                 }
    141             } else {
    142                 try self.notfound(res, ctx);
    143             }
    144         }
    145     };
    146 }
    147 
    148 const RouterTest = struct {
    149     const TestRequest = struct {
    150         method: std.http.Method,
    151         target: []const u8,
    152     };
    153     const TestResponse = struct {
    154         request: TestRequest,
    155     };
    156     const TestCtx = struct {};
    157     const TestErr = error{ TestError, OutOfMemory } || Path.ParseError;
    158     const TestRouter = Router(TestResponse, TestCtx, TestErr);
    159 
    160     var notfoundinvoked = false;
    161     fn notfound(_: TestResponse, _: TestCtx) TestErr!void {
    162         notfoundinvoked = true;
    163     }
    164     var route1invoked = false;
    165     var route1params: ?Params = null;
    166     fn route1(_: TestResponse, _: TestCtx, p: Params) TestErr!void {
    167         route1invoked = true;
    168         route1params = try p.clone();
    169     }
    170     var route2invoked = false;
    171     fn route2(_: TestResponse, _: TestCtx, _: Params) TestErr!void {
    172         route2invoked = true;
    173     }
    174     fn reset() void {
    175         notfoundinvoked = false;
    176         if (route1params != null) route1params.?.deinit();
    177         route1params = null;
    178         route1invoked = false;
    179         route2invoked = false;
    180     }
    181 
    182     fn runTestRouter(handlers: []TestRouter.Handler, target: []const u8) !void {
    183         const ctx = TestCtx{};
    184         const req = TestRequest{
    185             .method = .GET,
    186             .target = target,
    187         };
    188         const res = TestResponse{
    189             .request = req,
    190         };
    191         const router = TestRouter{
    192             .allocator = std.testing.allocator,
    193             .handlers = handlers,
    194             .notfound = notfound,
    195         };
    196         try router.handle(res, ctx);
    197     }
    198 
    199     // fn hmof(x: []const u8, y: []const u8) std.StringHashMap([]const u8) {
    200     //     var hm = std.StringHashMap([]const u8).init(std.testing.allocator);
    201     //     hm.put(x, y) catch @panic("failed to create hmof in test");
    202     //     return hm;
    203     // }
    204 
    205     const TestCase = struct {
    206         target: []const u8,
    207         route1: ?[]const u8 = null,
    208         route2: ?[]const u8 = null,
    209         notfoundexpected: bool = false,
    210         route1expected: bool = false,
    211         route2expected: bool = false,
    212         // route1paramsexpected: ?Params = null,
    213     };
    214 
    215     fn expectEqual(maybe_pexp: ?Params, maybe_pact: ?Params) !void {
    216         if (maybe_pexp == null and maybe_pact == null) {
    217             // fine
    218         } else if (maybe_pexp == null or maybe_pact == null) {
    219             std.debug.print("isnull(pexp) = {} isnull(pact) = {}", .{ maybe_pexp == null, maybe_pact == null });
    220             return error.TestUnexpectedResult;
    221         } else {
    222             const pexp = maybe_pexp.?;
    223             const pact = maybe_pact.?;
    224             try std.testing.expectEqual(pexp.count(), pact.count());
    225             var it = pexp.keyIterator();
    226             var kexp = it.next();
    227             while (kexp != null) : (kexp = it.next()) {
    228                 const vexp = pexp.get(kexp.?.*).?;
    229                 const maybe_vact = pact.get(kexp.?.*);
    230                 if (maybe_vact) |vact| {
    231                     std.debug.print("{s} {s}", .{ vexp, vact });
    232                     try std.testing.expectEqual(vexp, vact);
    233                 } else {
    234                     std.debug.print("expected key {s} not found in actual", .{kexp.?.*});
    235                     return error.TestUnexpectedResult;
    236                 }
    237             }
    238         }
    239     }
    240 
    241     test "router tests" {
    242         // var m0 = std.StringHashMap([]const u8).init(std.testing.allocator);
    243         // defer m0.deinit();
    244         // var m1 = hmof("var", "bam");
    245         // defer m1.deinit();
    246         const cases = [_]TestCase{
    247             .{
    248                 .target = "/",
    249                 .notfoundexpected = true,
    250             },
    251             .{
    252                 .target = "/",
    253                 .route1 = "/",
    254                 .route1expected = true,
    255                 // .route1paramsexpected = m0,
    256             },
    257             .{
    258                 .target = "/foo",
    259                 .route1 = "/bar",
    260                 .notfoundexpected = true,
    261             },
    262             .{
    263                 .target = "/bar",
    264                 .route1 = "/foo",
    265                 .route2 = "/bar",
    266                 .route2expected = true,
    267             },
    268             .{
    269                 .target = "/baz",
    270                 .route1 = "/",
    271                 .notfoundexpected = true,
    272             },
    273             .{
    274                 .target = "/baz",
    275                 .route1 = "/*",
    276                 .route2 = "/bar",
    277                 .route1expected = true,
    278                 // .route1paramsexpected = m0,
    279             },
    280             .{
    281                 .target = "/baz",
    282                 .route1 = "/*",
    283                 .route2 = "/baz",
    284                 .route1expected = true, // first matching route takes prio
    285                 // .route1paramsexpected = m0,
    286             },
    287             .{
    288                 .target = "/baz",
    289                 .route1 = "/baz",
    290                 .route2 = "/*",
    291                 .route1expected = true, // first matching route takes prio
    292                 // .route1paramsexpected = m0,
    293             },
    294             .{
    295                 .target = "/baz/bam",
    296                 .route1 = "/baz/{var}",
    297                 .route1expected = true,
    298                 // .route1paramsexpected = m1,
    299             },
    300             .{
    301                 .target = "/baz/bam/boo",
    302                 .route1 = "/baz/{var}/boo",
    303                 .route1expected = true,
    304                 // .route1paramsexpected = m1,
    305             },
    306             .{
    307                 .target = "/baz/bam/boo?somequery=foo",
    308                 .route1 = "/baz/{var}/boo",
    309                 .route1expected = true,
    310                 // .route1paramsexpected = m1,
    311             },
    312             // .{
    313             //     .target = "/baz/bam/bar",
    314             //     .route1 = "/baz/{var}/boo",
    315             //     .notfoundexpected = true,
    316             // },
    317         };
    318 
    319         for (cases) |case| {
    320             defer reset();
    321             var handlers = std.ArrayList(TestRouter.Handler).init(std.testing.allocator);
    322             defer handlers.deinit();
    323             if (case.route1) |r1| {
    324                 try handlers.append(TestRouter.Handler{ .pattern = r1, .method = .GET, .handle_fn = route1 });
    325             }
    326             if (case.route2) |r2| {
    327                 try handlers.append(TestRouter.Handler{ .pattern = r2, .method = .GET, .handle_fn = route2 });
    328             }
    329             try runTestRouter(handlers.items, case.target);
    330             try std.testing.expectEqual(case.notfoundexpected, notfoundinvoked);
    331             try std.testing.expectEqual(case.route1expected, route1invoked);
    332             try std.testing.expectEqual(case.route2expected, route2invoked);
    333             // try expectEqual(case.route1paramsexpected, route1params); // TODO assert captures
    334         }
    335     }
    336 };
    337 
    338 /// HTTP path parsing
    339 /// which is a subset of URI parsing :)
    340 /// RFC-3986
    341 pub const Path = struct {
    342     allocator: std.mem.Allocator,
    343     path: []const u8,
    344     query: []const u8,
    345     fragment: []const u8, // technically I think the fragment is never received on the server anyway
    346     query_parsed: ?Form = null,
    347 
    348     pub const ParseError = error{Malformatted} || Form.ParseError;
    349 
    350     pub fn parse(allocator: std.mem.Allocator, str: []const u8) ParseError!Path {
    351         var path: []const u8 = str;
    352         var query: []const u8 = "";
    353         var fragment: []const u8 = "";
    354         const f_ix = std.mem.indexOfScalar(u8, str, '#');
    355         const q_ix = std.mem.indexOfScalar(u8, str, '?');
    356         if (q_ix) |q| {
    357             path = str[0..q];
    358             if (f_ix) |f| {
    359                 if (f < q) {
    360                     return ParseError.Malformatted;
    361                 }
    362                 query = str[(q + 1)..f];
    363                 fragment = str[(f + 1)..];
    364             } else {
    365                 query = str[(q + 1)..];
    366             }
    367         } else if (f_ix) |f| {
    368             path = str[0..f];
    369             fragment = str[(f + 1)..];
    370         }
    371         return Path{
    372             .allocator = allocator,
    373             .path = path,
    374             .query = query,
    375             .fragment = fragment,
    376         };
    377     }
    378 
    379     pub fn get_query_param(self: *Path, key: []const u8) !?[]const u8 {
    380         if (self.query_parsed == null) {
    381             self.query_parsed = try Form.parse(self.allocator, self.query);
    382         }
    383         return self.query_parsed.?.data.get(key);
    384     }
    385 
    386     pub fn query_to_struct(self: *Path, comptime T: type) !T {
    387         if (self.query_parsed == null) {
    388             self.query_parsed = try Form.parse(self.allocator, self.query);
    389         }
    390         return self.query_parsed.?.form_to_struct(T);
    391     }
    392 
    393     pub fn deinit(self: *Path) void {
    394         if (self.query_parsed != null) {
    395             self.query_parsed.?.deinit();
    396         }
    397     }
    398 };
    399 
    400 const PathTest = struct {
    401     test "path" {
    402         var p = try Path.parse(std.testing.allocator, "/");
    403         defer p.deinit();
    404         try assertPath("/", "", "", p);
    405     }
    406 
    407     fn assertPath(path: []const u8, query: []const u8, fragment: []const u8, actual: Path) !void {
    408         try std.testing.expectEqualSlices(u8, path, actual.path);
    409         try std.testing.expectEqualSlices(u8, query, actual.query);
    410         try std.testing.expectEqualSlices(u8, fragment, actual.fragment);
    411     }
    412 
    413     test "query" {
    414         var p = try Path.parse(std.testing.allocator, "/foo?bar=baz");
    415         defer p.deinit();
    416         try assertPath("/foo", "bar=baz", "", p);
    417     }
    418 
    419     test "query and fragment" {
    420         var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag");
    421         defer p.deinit();
    422         try assertPath("/foo", "bar=baz", "frag", p);
    423     }
    424 
    425     test "fragment" {
    426         var p = try Path.parse(std.testing.allocator, "/foo#frag");
    427         defer p.deinit();
    428         try assertPath("/foo", "", "frag", p);
    429     }
    430 
    431     test "query param" {
    432         var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag");
    433         defer p.deinit();
    434         const v1 = try p.get_query_param("bar");
    435         try std.testing.expect(v1 != null);
    436         try std.testing.expectEqualSlices(u8, "baz", v1.?);
    437         const v2 = try p.get_query_param("bam");
    438         try std.testing.expect(v2 == null);
    439     }
    440 
    441     test "query param mixed" {
    442         var p = try Path.parse(std.testing.allocator, "/foo?bar=baz&ba+m=bo+om&zigzag#frag");
    443         defer p.deinit();
    444         try assertPath("/foo", "bar=baz&ba+m=bo+om&zigzag", "frag", p);
    445         const v1 = try p.get_query_param("bar");
    446         try std.testing.expect(v1 != null);
    447         try std.testing.expectEqualSlices(u8, "baz", v1.?);
    448         const v2 = try p.get_query_param("ba m");
    449         try std.testing.expect(v2 != null);
    450         try std.testing.expectEqualSlices(u8, "bo om", v2.?);
    451     }
    452 
    453     test "query to struct" {
    454         var p = try Path.parse(std.testing.allocator, "/foo?bar=ba+z&bam=55&zigzag#frag");
    455         defer p.deinit();
    456         const T = struct {
    457             bar: []const u8,
    458             bam: u64,
    459             fn deinit(self: *@This()) void {
    460                 std.testing.allocator.free(self.bar);
    461             }
    462         };
    463         var t = try p.query_to_struct(T);
    464         defer t.deinit();
    465         try std.testing.expectEqualDeep(T{ .bar = "ba z", .bam = 55 }, t);
    466     }
    467 };
    468 
    469 pub const Form = struct {
    470     allocator: std.mem.Allocator,
    471     data: std.StringHashMap([]const u8),
    472 
    473     const ParseError = error{ Malformatted, InvalidLength, InvalidCharacter, NoSpaceLeft } || std.mem.Allocator.Error;
    474 
    475     // Tries to parse key=value&key2=value2 pairs from the form.
    476     // Note that a URL query segment doesn't _have_ to be key-value pairs
    477     // so this is quite lenient.
    478     // Form struct owns all the keys and values in the resulting map.
    479     pub fn parse(allocator: std.mem.Allocator, form: []const u8) ParseError!Form {
    480         var res = std.StringHashMap([]const u8).init(allocator);
    481         var iter1 = std.mem.splitScalar(u8, form, '&');
    482         while (iter1.next()) |split| {
    483             var iter2 = std.mem.splitScalar(u8, split, '=');
    484             if (iter2.next()) |key| {
    485                 if (iter2.next()) |value| {
    486                     try res.put(try percent_decode(allocator, key), try percent_decode(allocator, value));
    487                 } else {
    488                     // Do nothing, it's a well-formatted kv pair
    489                 }
    490             } else {
    491                 // Do nothing it's not a well-formatted kv pair
    492             }
    493         }
    494         return Form{ .allocator = allocator, .data = res };
    495     }
    496     pub fn form_to_struct(self: *Form, comptime T: type) !T {
    497         return to_struct(self.allocator, T, self.data);
    498     }
    499     pub fn deinit(self: *Form) void {
    500         var it = self.data.iterator();
    501         var e = it.next();
    502         while (e != null) : (e = it.next()) {
    503             self.allocator.free(e.?.key_ptr.*);
    504             self.allocator.free(e.?.value_ptr.*);
    505         }
    506         self.data.deinit();
    507     }
    508 };
    509 
    510 fn percent_decode(allocator: std.mem.Allocator, str: []const u8) ![]const u8 {
    511     var fbs = std.io.fixedBufferStream(str);
    512     var rdr = fbs.reader();
    513     var out = std.ArrayList(u8).init(allocator);
    514     var wtr = out.writer();
    515     defer out.deinit();
    516     while (true) {
    517         const b = rdr.readByte() catch break;
    518         if (b == '%') {
    519             var hex_code: [2]u8 = undefined;
    520             _ = try rdr.readAll(&hex_code);
    521             var b2: [1]u8 = .{0};
    522             _ = try std.fmt.hexToBytes(&b2, &hex_code);
    523             try wtr.writeByte(b2[0]);
    524         } else if (b == '+') {
    525             try wtr.writeByte(' ');
    526         } else {
    527             try wtr.writeByte(b);
    528         }
    529     }
    530     return out.toOwnedSlice();
    531 }
    532 
    533 const PercentEncodeTest = struct {
    534     test "decode" {
    535         const decoded = try percent_decode(std.testing.allocator, "%C3%A7%C3%AE%C4%85%C3%B5+hithere");
    536         defer std.testing.allocator.free(decoded);
    537         try std.testing.expectEqualStrings("çîąõ hithere", decoded);
    538     }
    539 };
    540 
    541 /// Populate a struct from a hashmap
    542 fn to_struct(allocator: std.mem.Allocator, comptime T: type, hm: std.StringHashMap([]const u8)) !T {
    543     const ti = @typeInfo(T);
    544     if (ti != .Struct) {
    545         @compileError("to_struct T was not a struct type");
    546     }
    547     var t: T = undefined;
    548     inline for (ti.Struct.fields) |field| {
    549         if (field.is_comptime) {
    550             @compileError("can't dynamically set comptime field " ++ field.name);
    551         }
    552         const value: []const u8 = hm.get(field.name) orelse {
    553             return error.FieldNotPresent; // TODO somehow be more useful.
    554         };
    555         switch (@typeInfo(field.type)) { // TODO handle more types, default values etc etc.
    556             .Int => {
    557                 @field(t, field.name) = try std.fmt.parseInt(field.type, value, 10);
    558             },
    559             .Pointer => |ptrinfo| {
    560                 if (ptrinfo.size != .Slice) {
    561                     @compileError("field pointer size " ++ @tagName(ptrinfo.size) ++ " is not supported, only []u8 is supported right now");
    562                 }
    563                 if (ptrinfo.child != u8) {
    564                     @compileError("field pointer type " ++ @tagName(@typeInfo(ptrinfo.child)) ++ " is not supported, only []u8 is supported right now");
    565                 }
    566                 const dvalue = try allocator.dupe(u8, value);
    567                 errdefer allocator.free(dvalue);
    568                 @field(t, field.name) = dvalue;
    569             },
    570             else => @compileError("field type " ++ @tagName(@typeInfo(field.type)) ++ " not supported on field " ++ field.name),
    571         }
    572     }
    573     return t;
    574 }
    575 
    576 const StructTest = struct {
    577     test "to struct" {
    578         const T = struct {
    579             foo: i64,
    580             bar: []const u8,
    581             pub fn deinit(self: *@This()) void {
    582                 std.testing.allocator.free(self.bar);
    583             }
    584         };
    585         var hm = std.StringHashMap([]const u8).init(std.testing.allocator);
    586         defer hm.deinit();
    587         try hm.put("foo", "42");
    588         try hm.put("bar", "oops");
    589         var t = try to_struct(std.testing.allocator, T, hm);
    590         defer t.deinit();
    591         try std.testing.expectEqualDeep(T{ .foo = 42, .bar = "oops" }, t);
    592     }
    593 };
    594 
    595 test {
    596     _ = RouterTest;
    597     _ = PathTest;
    598     _ = PercentEncodeTest;
    599     _ = StructTest;
    600 }