From 35494bc81b59165ee9264cd1004bb05a120279a3 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Thu, 28 Sep 2023 10:54:58 +0100 Subject: Reduce allocations, the message takes ownership of the bytes read from the stream and is responsible for deallocating, rather than copying them to their own storage. --- src/proto/authentication_request.zig | 11 ++++++----- src/proto/backend_key_data.zig | 10 +++++----- src/proto/command_complete.zig | 13 ++++++------- src/proto/copy_x_response.zig | 6 +++--- src/proto/data_row.zig | 23 ++++++++++------------- src/proto/error_response.zig | 5 ++--- src/proto/parameter_status.zig | 18 +++++++++--------- src/proto/proto.zig | 2 +- src/proto/ready_for_query.zig | 9 ++++----- src/proto/row_description.zig | 35 ++++++++++++++++++----------------- 10 files changed, 64 insertions(+), 68 deletions(-) diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig index 3ea5cd1..c4bddb5 100644 --- a/src/proto/authentication_request.zig +++ b/src/proto/authentication_request.zig @@ -25,12 +25,14 @@ pub const AuthRequestCleartextPassword = struct {}; inner_type: InnerAuthRequestType, inner: InnerAuthRequest, -pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { - if (b.len != 4) { - log.err("invalid message length, expected 4 got {}", .{b.len}); +// takes ownership of b +pub fn read(a: std.mem.Allocator, buf: []const u8) !AuthenticationRequest { + defer a.free(buf); // No need to retain it we copy all the interesting data out + if (buf.len != 4) { + log.err("invalid message length, expected 4 got {}", .{buf.len}); return ProtocolError.InvalidMessageLength; } - const auth_type_int = std.mem.readIntBig(u32, b[0..4]); + const auth_type_int = std.mem.readIntBig(u32, buf[0..4]); const inner_type = enum_from_int(InnerAuthRequestType, auth_type_int) orelse { log.err("Unsupported auth type {}", .{auth_type_int}); return ClientError.UnsupportedAuthType; @@ -71,7 +73,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try AuthenticationRequest.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/backend_key_data.zig b/src/proto/backend_key_data.zig index 7c32178..4e7f30d 100644 --- a/src/proto/backend_key_data.zig +++ b/src/proto/backend_key_data.zig @@ -10,11 +10,12 @@ pub const Tag: u8 = 'K'; process_id: u32, secret_key: u32, -pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData { - if (b.len != 8) return ProtocolError.InvalidMessageLength; +pub fn read(a: std.mem.Allocator, buf: []const u8) !BackendKeyData { + defer a.free(buf); + if (buf.len != 8) return ProtocolError.InvalidMessageLength; return .{ - .process_id = std.mem.readIntBig(u32, b[0..4]), - .secret_key = std.mem.readIntBig(u32, b[4..8]), + .process_id = std.mem.readIntBig(u32, buf[0..4]), + .secret_key = std.mem.readIntBig(u32, buf[4..8]), }; } pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void { @@ -43,7 +44,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try BackendKeyData.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig index 80014e9..ed8e052 100644 --- a/src/proto/command_complete.zig +++ b/src/proto/command_complete.zig @@ -10,24 +10,24 @@ pub const Tag: u8 = 'C'; const CommandComplete = @This(); +buf: ?[]const u8 = null, // owned command_tag: []const u8, -owned: bool = false, -pub fn read(a: std.mem.Allocator, b: []const u8) !CommandComplete { +pub fn read(_: std.mem.Allocator, buf: []const u8) !CommandComplete { return .{ - .command_tag = try a.dupe(u8, b), - .owned = true, + .buf = buf, + .command_tag = buf[0..], }; } pub fn write(self: CommandComplete, _: std.mem.Allocator, stream_writer: anytype) !void { try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, @as(u32, @intCast(4+self.command_tag.len))); + try stream_writer.writeIntBig(u32, @as(u32, @intCast(4 + self.command_tag.len))); try stream_writer.writeAll(self.command_tag); } pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void { - if (self.owned) a.free(self.command_tag); + if (self.buf != null) a.free(self.buf.?); } test "round trip" { @@ -47,7 +47,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try CommandComplete.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/copy_x_response.zig b/src/proto/copy_x_response.zig index 9d1b26c..a4370a7 100644 --- a/src/proto/copy_x_response.zig +++ b/src/proto/copy_x_response.zig @@ -13,8 +13,9 @@ pub fn CopyXResponse(comptime tag: u8) type { overall_format_code: u8, format_codes: []const FormatCode, // owned - pub fn read(a: std.mem.Allocator, b: []const u8) !@This() { - var fbs = std.io.fixedBufferStream(b); + pub fn read(a: std.mem.Allocator, buf: []const u8) !@This() { + defer a.free(buf); + var fbs = std.io.fixedBufferStream(buf); var reader = fbs.reader(); const overall_format_code = try reader.readIntBig(u8); const n_columns = try reader.readIntBig(u16); @@ -76,7 +77,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try CopyInResponse.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/data_row.zig b/src/proto/data_row.zig index c20b794..43c4526 100644 --- a/src/proto/data_row.zig +++ b/src/proto/data_row.zig @@ -12,25 +12,23 @@ const DataRow = @This(); buf: ?[]const u8 = null, // owned columns: [][]const u8, // also owned -pub fn read(a: std.mem.Allocator, b: []const u8) !DataRow { - if (b.len < 2) return ProtocolError.InvalidMessageLength; - var buf = try a.dupe(u8, b); - var res: DataRow = undefined; - res.buf = buf; - errdefer res.deinit(a); - +pub fn read(a: std.mem.Allocator, buf: []const u8) !DataRow { + if (buf.len < 2) return ProtocolError.InvalidMessageLength; + errdefer a.free(buf); const n_columns = std.mem.readIntBig(u16, buf[0..2]); const columns = try a.alloc([]const u8, n_columns); errdefer a.free(columns); var pos: usize = 2; for (0..n_columns) |col| { - const len = std.mem.readIntBig(u32, buf[pos..(pos+4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig! - const data = if (len > 0) buf[(pos+4)..(pos+4+len)] else &[_]u8{}; + const len = std.mem.readIntBig(u32, buf[pos..(pos + 4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig! + const data = if (len > 0) buf[(pos + 4)..(pos + 4 + len)] else &[_]u8{}; columns[col] = data; - pos += (4+len); + pos += (4 + len); } - res.columns = columns; - return res; + return .{ + .buf = buf, + .columns = columns, + }; } pub fn write(self: DataRow, a: std.mem.Allocator, stream_writer: anytype) !void { @@ -77,7 +75,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try DataRow.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig index dc75053..58ca06e 100644 --- a/src/proto/error_response.zig +++ b/src/proto/error_response.zig @@ -27,13 +27,13 @@ line: ?u32 = null, routine: ?[]const u8 = null, unknown_fields: HMByteString, -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { +pub fn read(allocator: std.mem.Allocator, buf: []const u8) !ErrorResponse { var res = ErrorResponse{ .severity = "", .code = "", .message = "", .unknown_fields = HMByteString.init(allocator), - .buf = try allocator.dupe(u8, b), + .buf = buf, }; errdefer res.deinit(allocator); var it = std.mem.splitScalar(u8, res.buf.?, 0); @@ -178,7 +178,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try ErrorResponse.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/parameter_status.zig b/src/proto/parameter_status.zig index 5f95695..7bc306b 100644 --- a/src/proto/parameter_status.zig +++ b/src/proto/parameter_status.zig @@ -11,13 +11,14 @@ buf: ?[]const u8 = null, // owned name: []const u8, value: []const u8, -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus { - var res: ParameterStatus = undefined; - res.buf = try allocator.dupe(u8, b); - var it = std.mem.splitScalar(u8, res.buf.?, 0); - res.name = it.first(); - res.value = it.next() orelse return ProtocolError.MissingField; - return res; +pub fn read(a: std.mem.Allocator, buf: []const u8) !ParameterStatus { + errdefer a.free(buf); + var it = std.mem.splitScalar(u8, buf, 0); + return .{ + .buf = buf, + .name = it.first(), + .value = it.next() orelse return ProtocolError.MissingField, + }; } pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void { try stream_writer.writeByte(Tag); @@ -30,7 +31,7 @@ pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype try writer.writeByte(0); try writer.writeAll(self.value); try writer.writeByte(0); - std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); // Fix length try stream_writer.writeAll(al.items); } pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void { @@ -55,7 +56,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try ParameterStatus.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/proto.zig b/src/proto/proto.zig index 5e4489d..9025347 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -51,13 +51,13 @@ pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !Backe const tag = try stream_reader.readByte(); const len = try stream_reader.readIntBig(u32); const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); - defer allocator.free(buf); try stream_reader.readNoEof(buf); inline for (@typeInfo(BackendMessage).Union.fields) |field| { if (field.type.Tag == tag) { return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf)); } } else { + allocator.free(buf); return ProtocolError.InvalidMessageType; } } diff --git a/src/proto/ready_for_query.zig b/src/proto/ready_for_query.zig index ef99e60..6bf25f9 100644 --- a/src/proto/ready_for_query.zig +++ b/src/proto/ready_for_query.zig @@ -14,10 +14,10 @@ const TransactionStatus = enum(u8) { transaction_status: TransactionStatus, -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery { - _ = allocator; - if (b.len != 1) return ProtocolError.InvalidMessageLength; - return .{ .transaction_status = enum_from_int(TransactionStatus, b[0]) orelse return ProtocolError.InvalidTransactionStatus }; +pub fn read(a: std.mem.Allocator, buf: []const u8) !ReadyForQuery { + defer a.free(buf); + if (buf.len != 1) return ProtocolError.InvalidMessageLength; + return .{ .transaction_status = enum_from_int(TransactionStatus, buf[0]) orelse return ProtocolError.InvalidTransactionStatus }; } pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void { _ = allocator; @@ -44,7 +44,6 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try ReadyForQuery.read(allocator, buf); defer sm2.deinit(allocator); diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig index b8105e2..ff17716 100644 --- a/src/proto/row_description.zig +++ b/src/proto/row_description.zig @@ -11,7 +11,7 @@ pub const Tag: u8 = 'T'; const RowDescription = @This(); buf: ?[]const u8 = null, // owned -fields: ?[]Field = null, // owned +fields: []Field, // owned pub const Field = struct { name: []const u8, @@ -23,19 +23,18 @@ pub const Field = struct { format_code: FormatCode, }; -pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { - var res: RowDescription = undefined; - res.buf = try a.dupe(u8, b); - errdefer res.deinit(a); - var fbs = std.io.fixedBufferStream(res.buf.?); +pub fn read(a: std.mem.Allocator, buf: []const u8) !RowDescription { + errdefer a.free(buf); + var fbs = std.io.fixedBufferStream(buf); var reader = fbs.reader(); const n_fields = try reader.readIntBig(u16); - res.fields = try a.alloc(Field, n_fields); + var fields = try a.alloc(Field, n_fields); + errdefer a.free(fields); for (0..n_fields) |i| { const name_start = fbs.pos; try reader.skipUntilDelimiterOrEof(0); const name_end = fbs.pos - 1; - const name = res.buf.?[name_start..name_end]; + const name = buf[name_start..name_end]; const field = Field{ .name = name, .table_oid = try reader.readIntBig(u32), @@ -45,9 +44,12 @@ pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { .data_type_modifier = try reader.readIntBig(u32), .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode, }; - res.fields.?[i] = field; + fields[i] = field; } - return res; + return .{ + .buf = buf, + .fields = fields, + }; } pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void { @@ -57,8 +59,8 @@ pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) var cw = std.io.countingWriter(al.writer()); var writer = cw.writer(); try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len))); - for (self.fields.?) |field| { + try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.len))); + for (self.fields) |field| { try writer.writeAll(field.name); try writer.writeByte(0); try writer.writeIntBig(u32, field.table_oid); @@ -73,7 +75,7 @@ pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) } pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void { - if (self.fields != null) a.free(self.fields.?); + a.free(self.fields); if (self.buf != null) a.free(self.buf.?); } @@ -125,12 +127,11 @@ test "round trip" { try std.testing.expectEqual(Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); try reader.readNoEof(buf); var sm2 = try RowDescription.read(allocator, buf); defer sm2.deinit(allocator); - try std.testing.expectEqualDeep(f0, sm2.fields.?[0]); - try std.testing.expectEqualDeep(f1, sm2.fields.?[1]); - try std.testing.expectEqualDeep(f2, sm2.fields.?[2]); + try std.testing.expectEqualDeep(f0, sm2.fields[0]); + try std.testing.expectEqualDeep(f1, sm2.fields[1]); + try std.testing.expectEqualDeep(f2, sm2.fields[2]); } -- cgit v1.2.3-ZIG