commit 8156e741ef7c8a59317ba801aa0c8940e9fcbd00
parent 95699a6ed77e1480e7b9256035225981cb33bfbb
Author: Martin Ashby <martin@ashbysoft.com>
Date: Sun, 6 Aug 2023 19:55:20 +0100
Percent-decode query parameters like they should be according to the spec
https://www.w3.org/TR/html401/interact/forms.html
Diffstat:
2 files changed, 140 insertions(+), 41 deletions(-)
diff --git a/src/main.zig b/src/main.zig
@@ -10,7 +10,8 @@ const Context = struct {
};
const Handler = struct {
pub fn handle(_: Handler, res: *std.http.Server.Response, _: Context) !void {
- const p = try zws.Path.parse(res.request.target);
+ var p = try zws.Path.parse(res.allocator, res.request.target);
+ defer p.deinit();
const path = try std.fs.path.join(res.allocator, &[_][]const u8{ ".", p.path });
defer res.allocator.free(path);
if (std.fs.cwd().openFile(path, .{})) |file| {
@@ -43,7 +44,7 @@ const Handler = struct {
while (true) {
const read = try file.read(&buf);
if (read == 0) break;
- _ = try res.write(buf[0..read]);
+ _ = try res.writeAll(buf[0..read]);
}
}
diff --git a/src/zigwebserver.zig b/src/zigwebserver.zig
@@ -97,7 +97,8 @@ pub fn Router(comptime Context: type, comptime ErrorType: type) type {
// It is a programmer error to call this without calling .wait first.
if (res.state != .waited) unreachable;
- const p = try Path.parse(res.request.target);
+ var p = try Path.parse(res.allocator, res.request.target);
+ defer p.deinit();
const path = p.path;
handler_loop: for (self.handlers) |handler| {
@@ -348,86 +349,183 @@ const RouterTest = struct {
/// HTTP path parsing
/// which is a subset of URI parsing :)
+/// RFC-3986
pub const Path = struct {
+ allocator: std.mem.Allocator,
path: []const u8,
query: []const u8,
fragment: []const u8, // technically I think the fragment is never received on the server anyway
+ query_parsed: ?Form = null,
- pub const ParseError = error{Malformatted};
+ pub const ParseError = error{Malformatted} || Form.ParseError;
- pub fn parse(path: []const u8) ParseError!Path {
- var p = Path{
- .path = path,
- .query = "",
- .fragment = "",
- };
- const q_ix = std.mem.indexOfScalar(u8, path, '?');
- const f_ix = std.mem.indexOfScalar(u8, path, '#');
+ pub fn parse(allocator: std.mem.Allocator, str: []const u8) ParseError!Path {
+ var path: []const u8 = str;
+ var query: []const u8 = "";
+ var fragment: []const u8 = "";
+ const f_ix = std.mem.indexOfScalar(u8, str, '#');
+ const q_ix = std.mem.indexOfScalar(u8, str, '?');
if (q_ix) |q| {
- p.path = path[0..q];
+ path = str[0..q];
if (f_ix) |f| {
if (f < q) {
return ParseError.Malformatted;
}
- p.query = path[(q + 1)..f];
- p.fragment = path[(f + 1)..];
+ query = str[(q + 1)..f];
+ fragment = str[(f + 1)..];
} else {
- p.query = path[(q + 1)..];
+ query = str[(q + 1)..];
}
} else if (f_ix) |f| {
- p.path = path[0..f];
- p.fragment = path[(f + 1)..];
+ path = str[0..f];
+ fragment = str[(f + 1)..];
}
- return p;
+ return Path{
+ .allocator = allocator,
+ .path = path,
+ .query = query,
+ .fragment = fragment,
+ };
}
- pub fn get_query_param(self: Path, key: []const u8) ?[]const u8 {
- var it1 = std.mem.splitScalar(u8, self.query, '&');
- var t: ?[]const u8 = it1.first();
- while (t != null) : (t = it1.next()) {
- var it2 = std.mem.splitScalar(u8, t.?, '=');
- const k = it2.first();
- const v = it2.next();
- if (std.mem.eql(u8, key, k)) {
- return v;
- }
+ pub fn get_query_param(self: *Path, key: []const u8) !?[]const u8 {
+ if (self.query_parsed == null) {
+ self.query_parsed = try Form.parse(self.allocator, self.query);
+ }
+ return self.query_parsed.?.data.get(key);
+ }
+ pub fn deinit(self: *Path) void {
+ if (self.query_parsed != null) {
+ self.query_parsed.?.deinit();
}
- return null;
}
};
const PathTest = struct {
test "path" {
- const p = try Path.parse("/");
- try std.testing.expectEqualDeep(Path{ .path = "/", .query = "", .fragment = "" }, p);
+ var p = try Path.parse(std.testing.allocator, "/");
+ defer p.deinit();
+ try assertPath("/", "", "", p);
+ }
+
+ fn assertPath(path: []const u8, query: []const u8, fragment: []const u8, actual: Path) !void {
+ try std.testing.expectEqualSlices(u8, path, actual.path);
+ try std.testing.expectEqualSlices(u8, query, actual.query);
+ try std.testing.expectEqualSlices(u8, fragment, actual.fragment);
}
test "query" {
- const p = try Path.parse("/foo?bar=baz");
- try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "" }, p);
+ var p = try Path.parse(std.testing.allocator, "/foo?bar=baz");
+ defer p.deinit();
+ try assertPath("/foo", "bar=baz", "", p);
}
test "query and fragment" {
- const p = try Path.parse("/foo?bar=baz#frag");
- try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "bar=baz", .fragment = "frag" }, p);
+ var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag");
+ defer p.deinit();
+ try assertPath("/foo", "bar=baz", "frag", p);
}
test "fragment" {
- const p = try Path.parse("/foo#frag");
- try std.testing.expectEqualDeep(Path{ .path = "/foo", .query = "", .fragment = "frag" }, p);
+ var p = try Path.parse(std.testing.allocator, "/foo#frag");
+ defer p.deinit();
+ try assertPath("/foo", "", "frag", p);
}
test "query param" {
- const p = try Path.parse("/foo?bar=baz#frag");
- const v1 = p.get_query_param("bar");
+ var p = try Path.parse(std.testing.allocator, "/foo?bar=baz#frag");
+ defer p.deinit();
+ const v1 = try p.get_query_param("bar");
try std.testing.expect(v1 != null);
try std.testing.expectEqualSlices(u8, "baz", v1.?);
- const v2 = p.get_query_param("bam");
+ const v2 = try p.get_query_param("bam");
try std.testing.expect(v2 == null);
}
+
+ test "query param mixed" {
+ var p = try Path.parse(std.testing.allocator, "/foo?bar=baz&ba+m=bo+om&zigzag#frag");
+ defer p.deinit();
+ try assertPath("/foo", "bar=baz&ba+m=bo+om&zigzag", "frag", p);
+ const v1 = try p.get_query_param("bar");
+ try std.testing.expect(v1 != null);
+ try std.testing.expectEqualSlices(u8, "baz", v1.?);
+ const v2 = try p.get_query_param("ba m");
+ try std.testing.expect(v2 != null);
+ try std.testing.expectEqualSlices(u8, "bo om", v2.?);
+ }
+};
+
+pub const Form = struct {
+ allocator: std.mem.Allocator,
+ data: std.StringHashMap([]const u8),
+
+ const ParseError = error{ Malformatted, InvalidLength, InvalidCharacter, NoSpaceLeft } || std.mem.Allocator.Error;
+
+ // Tries to parse key=value&key2=value2 pairs from the form.
+ // Note that a URL query segment doesn't _have_ to be key-value pairs
+ // so this is quite lenient.
+ // Form struct owns all the keys and values in the resulting map.
+ pub fn parse(allocator: std.mem.Allocator, form: []const u8) ParseError!Form {
+ var res = std.StringHashMap([]const u8).init(allocator);
+ var iter1 = std.mem.splitScalar(u8, form, '&');
+ while (iter1.next()) |split| {
+ var iter2 = std.mem.splitScalar(u8, split, '=');
+ if (iter2.next()) |key| {
+ if (iter2.next()) |value| {
+ try res.put(try percent_decode(allocator, key), try percent_decode(allocator, value));
+ } else {
+ // Do nothing: no '=' in this segment, so it is not a key-value pair
+ }
+ } else {
+ // Do nothing it's not a well-formatted kv pair
+ }
+ }
+ return Form{ .allocator = allocator, .data = res };
+ }
+ pub fn deinit(self: *Form) void {
+ var it = self.data.iterator();
+ var e = it.next();
+ while (e != null) : (e = it.next()) {
+ self.allocator.free(e.?.key_ptr.*);
+ self.allocator.free(e.?.value_ptr.*);
+ }
+ self.data.deinit();
+ }
+};
+
+fn percent_decode(allocator: std.mem.Allocator, str: []const u8) ![]const u8 {
+ var fbs = std.io.fixedBufferStream(str);
+ var rdr = fbs.reader();
+ var out = std.ArrayList(u8).init(allocator);
+ var wtr = out.writer();
+ defer out.deinit();
+ while (true) {
+ const b = rdr.readByte() catch break;
+ if (b == '%') {
+ var hex_code: [2]u8 = undefined;
+ _ = try rdr.readAll(&hex_code);
+ var b2: [1]u8 = .{0};
+ _ = try std.fmt.hexToBytes(&b2, &hex_code);
+ try wtr.writeByte(b2[0]);
+ } else if (b == '+') {
+ try wtr.writeByte(' ');
+ } else {
+ try wtr.writeByte(b);
+ }
+ }
+ return out.toOwnedSlice();
+}
+
+const PercentEncodeTest = struct {
+ test "decode" {
+ const decoded = try percent_decode(std.testing.allocator, "%C3%A7%C3%AE%C4%85%C3%B5+hithere");
+ defer std.testing.allocator.free(decoded);
+ try std.testing.expectEqualStrings("çîąõ hithere", decoded);
+ }
};
test {
_ = RouterTest;
_ = PathTest;
+ _ = PercentEncodeTest;
}