diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.zig | 5 | ||||
-rw-r--r-- | src/zigwebserver.zig | 176 |
2 files changed, 140 insertions, 41 deletions
diff --git a/src/main.zig b/src/main.zig index 1f4d8f9..70869b2 100644 --- a/src/main.zig +++ b/src/main.zig @@ -10,7 +10,8 @@ const Context = struct { }; const Handler = struct { pub fn handle(_: Handler, res: *std.http.Server.Response, _: Context) !void { - const p = try zws.Path.parse(res.request.target); + var p = try zws.Path.parse(res.allocator, res.request.target); + defer p.deinit(); const path = try std.fs.path.join(res.allocator, &[_][]const u8{ ".", p.path }); defer res.allocator.free(path); if (std.fs.cwd().openFile(path, .{})) |file| { @@ -43,7 +44,7 @@ const Handler = struct { while (true) { const read = try file.read(&buf); if (read == 0) break; - _ = try res.write(buf[0..read]); + _ = try res.writeAll(buf[0..read]); } } diff --git a/src/zigwebserver.zig b/src/zigwebserver.zig index 37a43a4..42188f8 100644 --- a/src/zigwebserver.zig +++ b/src/zigwebserver.zig @@ -97,7 +97,8 @@ pub fn Router(comptime Context: type, comptime ErrorType: type) type { // It is a programmer error to call this without calling .wait first. if (res.state != .waited) unreachable; - const p = try Path.parse(res.request.target); + var p = try Path.parse(res.allocator, res.request.target); + defer p.deinit(); const path = p.path; handler_loop: for (self.handlers) |handler| { @@ -348,86 +349,183 @@ const RouterTest = struct { /// HTTP path parsing /// which is a subset of URI parsing :) +/// RFC-3986 pub const Path = struct { + allocator: std.mem.Allocator, path: []const u8, query: []const u8, fragment: []const u8, // technically I think the fragment is never received on the server anyway + query_parsed: ?Form = null, - pub const ParseError = error{Malformatted}; + pub const ParseError = error{Malformatted} || Form.ParseError; - pub fn parse(path: []const u8) ParseError!Path { - var p = Path{ - .path = path, - .query = "", - .fragment = "", - }; - const q_ix = std.mem.indexOfScalar(u8, path, '?'); - const f_ix = std.mem.indexOfScalar(u8, path, '#'); + pub fn parse(allocator: std.mem.Allocator, str: []const u8) ParseError!Path { + var path: []const u8 = str; + var query: []const u8 = ""; + var fragment: []const u8 = ""; + const f_ix = std.mem.indexOfScalar(u8, str, '#'); + const q_ix = std.mem.indexOfScalar(u8, str, '?'); if (q_ix) |q| { - p.path = path[0..q]; + path = str[0..q]; if (f_ix) |f| { if (f < q) { return ParseError.Malformatted; } - p.query = path[(q + 1)..f]; - p.fragment = path[(f + 1)..]; + query = str[(q + 1)..f]; + fragment = str[(f + 1)..]; } else { - p.query = path[(q + 1)..]; + query = str[(q + 1)..]; } } else if (f_ix) |f| { - p.path = path[0..f]; - p.fragment = path[(f + 1)..]; + path = str[0..f]; + fragment = str[(f + 1)..]; } - return p; + return Path{ + .allocator = allocator, + .path = path, + .query = query, + .fragment = fragment, + }; } - pub fn get_query_param(self: Path, key: []const u8) ?[]const u8 { - var it1 = std.mem.splitScalar(u8, self.query, '&'); - var t: ?[]const u8 = it1.first(); - while (t != null) : (t = it1.next()) { - var it2 = std.mem.splitScalar(u8, t.?, '='); - const k = it2.first(); - const v = it2.next(); - if (std.mem.eql(u8, key, k)) { - return v; - } + pub fn get_query_param(self: *Path, key: []const u8) !?[]const u8 { + if (self.query_parsed == null) { + self.query_parsed = try Form.parse(self.allocator, self.query); + } + return self.query_parsed.?.data.get(key); + } + pub fn deinit(self: *Path) void { + if (self.query_parsed != null) { + self.query_parsed.?.deinit(); } - return null; } }; const PathTest = struct { test "path" { - const p = try Path.parse("/"); - try std.testing.expectEqualDeep(Path{ .path = "/", .query = "", .fragment = "" }, p); + var p = try Path.parse(std.testing.allocator, "/"); + defer p.deinit(); + try assertPath("/", "", "", p); + } + + fn assertPath(path: []const u8, query: []const u8, fragment: []const u8, actual: Path) !void { + try std.testing.expectEqualSlices(u8, path, actual.path); + try std.testing.expectEqualSlices(u8, query, actual.query); + try std.testing.expectEqualSlices(u8, fragment, actual.fragment); } test "query" { - const p = try Path.parse("/foo?bar=baz"); - try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "" }, p); + var p = try Path.parse(std.testing.allocator, "/foo?bar=baz"); + defer p.deinit(); + try assertPath("/foo", "bar=baz", "", p); } test "query and fragment" { - const p = try Path.parse("/foo?bar=baz#frag"); - try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "frag" }, p); + var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag"); + defer p.deinit(); + try assertPath("/foo", "bar=baz", "frag", p); } test "fragment" { - const p = try Path.parse("/foo#frag"); - try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "", .fragment = "frag" }, p); + var p = try Path.parse(std.testing.allocator, "/foo#frag"); + defer p.deinit(); + try assertPath("/foo", "", "frag", p); } test "query param" { - const p = try Path.parse("/foo?bar=baz#frag"); - const v1 = p.get_query_param("bar"); + var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag"); + defer p.deinit(); + const v1 = try p.get_query_param("bar"); try std.testing.expect(v1 != null); try std.testing.expectEqualSlices(u8, "baz", v1.?); - const v2 = p.get_query_param("bam"); + const v2 = try p.get_query_param("bam"); try std.testing.expect(v2 == null); } + + test "query param mixed" { + var p = try Path.parse(std.testing.allocator, "/foo?bar=baz&ba+m=bo+om&zigzag#frag"); + defer p.deinit(); + try assertPath("/foo", "bar=baz&ba+m=bo+om&zigzag", "frag", p); + const v1 = try p.get_query_param("bar"); + try std.testing.expect(v1 != null); + try std.testing.expectEqualSlices(u8, "baz", v1.?); + const v2 = try p.get_query_param("ba m"); + try std.testing.expect(v2 != null); + try std.testing.expectEqualSlices(u8, "bo om", v2.?); + } +}; + +pub const Form = struct { + allocator: std.mem.Allocator, + data: std.StringHashMap([]const u8), + + const ParseError = error{ Malformatted, InvalidLength, InvalidCharacter, NoSpaceLeft } || std.mem.Allocator.Error; + + // Tries to parse key=value&key2=value2 pairs from the form. + // Note that a URL query segment doesn't _have_ to be key-value pairs + // so this is quite lenient. + // Form struct owns all the keys and values in the resulting map. + pub fn parse(allocator: std.mem.Allocator, form: []const u8) ParseError!Form { + var res = std.StringHashMap([]const u8).init(allocator); + var iter1 = std.mem.splitScalar(u8, form, '&'); + while (iter1.next()) |split| { + var iter2 = std.mem.splitScalar(u8, split, '='); + if (iter2.next()) |key| { + if (iter2.next()) |value| { + try res.put(try percent_decode(allocator, key), try percent_decode(allocator, value)); + } else { + // Do nothing, it's a well-formatted kv pair + } + } else { + // Do nothing it's not a well-formatted kv pair + } + } + return Form{ .allocator = allocator, .data = res }; + } + pub fn deinit(self: *Form) void { + var it = self.data.iterator(); + var e = it.next(); + while (e != null) : (e = it.next()) { + self.allocator.free(e.?.key_ptr.*); + self.allocator.free(e.?.value_ptr.*); + } + self.data.deinit(); + } +}; + +fn percent_decode(allocator: std.mem.Allocator, str: []const u8) ![]const u8 { + var fbs = std.io.fixedBufferStream(str); + var rdr = fbs.reader(); + var out = std.ArrayList(u8).init(allocator); + var wtr = out.writer(); + defer out.deinit(); + while (true) { + const b = rdr.readByte() catch break; + if (b == '%') { + var hex_code: [2]u8 = undefined; + _ = try rdr.readAll(&hex_code); + var b2: [1]u8 = .{0}; + _ = try std.fmt.hexToBytes(&b2, &hex_code); + try wtr.writeByte(b2[0]); + } else if (b == '+') { + try wtr.writeByte(' '); + } else { + try wtr.writeByte(b); + } + } + return out.toOwnedSlice(); +} + +const PercentEncodeTest = struct { + test "decode" { + const decoded = try percent_decode(std.testing.allocator, "%C3%A7%C3%AE%C4%85%C3%B5+hithere"); + defer std.testing.allocator.free(decoded); + try std.testing.expectEqualStrings("çîąõ hithere", decoded); + } }; test { _ = RouterTest; _ = PathTest; + _ = PercentEncodeTest; } |