diff --git a/scripts/mzteinit/build.zig b/scripts/mzteinit/build.zig index 97662fe..84a92ec 100644 --- a/scripts/mzteinit/build.zig +++ b/scripts/mzteinit/build.zig @@ -6,7 +6,6 @@ pub fn build(b: *std.build.Builder) !void { const optimize = b.standardOptimizeOption(.{}); const ansi_term_mod = b.dependency("ansi_term", .{}).module("ansi-term"); - const s2s_mod = b.dependency("s2s", .{}).module("s2s"); const exe = b.addExecutable(.{ .name = "mzteinit", @@ -30,7 +29,6 @@ pub fn build(b: *std.build.Builder) !void { inline for (.{ mzteinitctl, exe }) |e| { e.addModule("ansi-term", ansi_term_mod); - e.addModule("s2s", s2s_mod); } const cg_opt = try common.confgenGet(struct { diff --git a/scripts/mzteinit/build.zig.zon b/scripts/mzteinit/build.zig.zon index 371374f..3cb0ac5 100644 --- a/scripts/mzteinit/build.zig.zon +++ b/scripts/mzteinit/build.zig.zon @@ -7,10 +7,6 @@ .url = "https://github.com/ziglibs/ansi-term/archive/1614b61486d567b59abe11a097d11aa6ce679819.tar.gz", .hash = "1220647eea49d2c48d5e59354291e975f813be3cc5a9d9920a50bbfaa40a891a06ee", }, - .s2s = .{ - .url = "https://github.com/ziglibs/s2s/archive/f1d0508cc47b2af353658d4e52616a45aafa91ce.tar.gz", - .hash = "122084b614cc4ac1e0694812d8863446840a314d8654afebad9b6966ebe6792931b1", - }, }, .paths = .{""}, } diff --git a/scripts/mzteinit/src/mzteinitctl.zig b/scripts/mzteinit/src/mzteinitctl.zig index 3354126..e2d053c 100644 --- a/scripts/mzteinit/src/mzteinitctl.zig +++ b/scripts/mzteinit/src/mzteinitctl.zig @@ -2,6 +2,10 @@ const std = @import("std"); const Client = @import("sock/Client.zig"); +pub const std_options = struct { + pub const log_level = .debug; +}; + pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); @@ -12,19 +16,35 @@ pub fn main() !void { const verb = std.mem.span(std.os.argv[1]); - const client = try Client.connect( - std.os.getenv("MZTEINIT_SOCKET") orelse return error.SocketPathUnknown, - ); - defer client.deinit(); - if (std.mem.eql(u8, verb, "ping")) { + const client = try Client.connect( + std.os.getenv("MZTEINIT_SOCKET") orelse return error.SocketPathUnknown, + ); + defer client.deinit(); + try client.ping(alloc); } else if (std.mem.eql(u8, verb, "getenv")) { if (std.os.argv.len < 3) return error.InvalidArgs; - const val = try client.getenv(alloc, std.mem.span(std.os.argv[2])); - defer if (val) |v| alloc.free(v); + const client = if (std.os.getenv("MZTEINIT_SOCKET")) |sockpath| + try Client.connect(sockpath) + else nosock: { + std.log.warn("MZTEINIT_SOCKET not set", .{}); + break :nosock null; + }; + defer if (client) |cl| cl.deinit(); + + const mzteinit_val = if (client) |cl| + try cl.getenv(alloc, std.mem.span(std.os.argv[2])) + else + null; + defer if (mzteinit_val) |v| alloc.free(v); + + const val = mzteinit_val orelse getenv: { + std.log.warn("Variable not known to MZTEINIT, falling back to current environment.", .{}); + break :getenv std.os.getenv(std.mem.span(std.os.argv[2])); + }; if (val) |v| { try std.io.getStdOut().writer().print("{s}\n", .{v}); diff --git a/scripts/mzteinit/src/sock/Client.zig b/scripts/mzteinit/src/sock/Client.zig index 3a5b047..66c84d9 100644 --- a/scripts/mzteinit/src/sock/Client.zig +++ b/scripts/mzteinit/src/sock/Client.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const s2s = @import("s2s"); const message = @import("message.zig"); @@ -17,19 +16,19 @@ pub fn deinit(self: Client) void { } pub fn ping(self: Client, alloc: std.mem.Allocator) !void { - try s2s.serialize(self.stream.writer(), message.Serverbound, .ping); - var msg = try s2s.deserializeAlloc(self.stream.reader(), message.Clientbound, alloc); - defer s2s.free(alloc, message.Clientbound, &msg); - if (msg != .pong) + try (message.Serverbound{ .ping = .{} }).write(self.stream.writer()); + const res = try message.Clientbound.read(self.stream.reader(), alloc); + defer res.deinit(alloc); + if (!std.meta.eql(res, .{ .pong = .{} })) return error.InvalidResponse; } pub fn getenv(self: Client, alloc: std.mem.Allocator, key: []const u8) !?[]u8 { - try s2s.serialize(self.stream.writer(), message.Serverbound, .{ .getenv = key }); - var msg = try s2s.deserializeAlloc(self.stream.reader(), message.Clientbound, alloc); - defer s2s.free(alloc, message.Clientbound, &msg); - return switch (msg) { - .getenv_res => |val| if (val) |v| try alloc.dupe(u8, v) else null, + try (message.Serverbound{ .getenv = .{ .data = key } }).write(self.stream.writer()); + const res = try message.Clientbound.read(self.stream.reader(), alloc); + defer res.deinit(alloc); + return switch (res) { + .getenv_res => |val| if (val.inner) |v| try alloc.dupe(u8, v.data) else null, else => error.InvalidResponse, }; } diff --git a/scripts/mzteinit/src/sock/Server.zig b/scripts/mzteinit/src/sock/Server.zig index ece103a..a010225 100644 --- a/scripts/mzteinit/src/sock/Server.zig +++ b/scripts/mzteinit/src/sock/Server.zig @@ -1,10 +1,11 @@ const std = @import("std"); -const s2s = @import("s2s"); const message = @import("message.zig"); const Mutex = @import("../mutex.zig").Mutex; +const log = std.log.scoped(.server); + alloc: std.mem.Allocator, env: *Mutex(std.process.EnvMap), ss: std.net.StreamServer, @@ -26,29 +27,34 @@ pub fn run(self: *Server) !void { } pub fn handleConnection(self: *Server, con: std.net.StreamServer.Connection) !void { + defer con.stream.close(); while (true) { - var msg = s2s.deserializeAlloc(con.stream.reader(), message.Serverbound, self.alloc) catch |e| { + const msg = message.Serverbound.read(con.stream.reader(), self.alloc) catch |e| { switch (e) { - error.EndOfStream => { - con.stream.close(); - return; - }, + error.EndOfStream => return, else => return e, } }; - defer s2s.free(self.alloc, message.Serverbound, &msg); + defer msg.deinit(self.alloc); switch (msg) { - .ping => try s2s.serialize(con.stream.writer(), message.Clientbound, .pong), + .ping => { + log.info("got ping!", .{}); + try (message.Clientbound{ .pong = .{} }).write(con.stream.writer()); + }, .getenv => |key| { self.env.mtx.lock(); defer self.env.mtx.unlock(); - try s2s.serialize( - con.stream.writer(), - message.Clientbound, - .{ .getenv_res = self.env.data.get(key) }, - ); + log.info("env var '{s}' requested", .{key.data}); + + const payload = message.Clientbound{ .getenv_res = .{ + .inner = if (self.env.data.get(key.data)) |v| + .{ .data = v } + else + null, + } }; + try payload.write(con.stream.writer()); }, } } diff --git a/scripts/mzteinit/src/sock/message.zig b/scripts/mzteinit/src/sock/message.zig index ff4be66..13b4ca6 100644 --- a/scripts/mzteinit/src/sock/message.zig +++ b/scripts/mzteinit/src/sock/message.zig @@ -1,12 +1,117 @@ const std = @import("std"); +const native_endian = @import("builtin").cpu.arch.endian(); + pub const Serverbound = union(enum) { - ping, - getenv: []const u8, + ping: NullPayload, + getenv: BytesPayload, + + pub usingnamespace MessageFunctions(Serverbound); }; pub const Clientbound = union(enum) { - pong, - getenv_res: ?[]const u8, + pong: NullPayload, + getenv_res: OptionalPayload(BytesPayload), + + pub usingnamespace MessageFunctions(Clientbound); }; +pub const NullPayload = struct { + fn read(_: anytype, _: std.mem.Allocator) !NullPayload { + return .{}; + } + + fn write(_: NullPayload, _: anytype) !void {} +}; + +pub const BytesPayload = struct { + data: []const u8, + + fn read(reader: anytype, alloc: std.mem.Allocator) !BytesPayload { + const len = try reader.readInt(usize, native_endian); + const data = try alloc.alloc(u8, len); + errdefer alloc.free(data); + try reader.readNoEof(data); + + return .{ .data = data }; + } + + fn write(self: *const BytesPayload, writer: anytype) !void { + try writer.writeInt(usize, self.data.len, native_endian); + try writer.writeAll(self.data); + } + + fn deinit(self: BytesPayload, alloc: std.mem.Allocator) void { + alloc.free(self.data); + } +}; + +pub fn OptionalPayload(comptime T: type) type { + return struct { + inner: ?T, + + const Self = @This(); + + fn read(reader: anytype, alloc: std.mem.Allocator) !Self { + const present_byte = try reader.readByte(); + return switch (present_byte) { + 0 => .{ .inner = null }, + 1 => .{ .inner = try T.read(reader, alloc) }, + else => error.InvalidPacket, + }; + } + + fn write(self: *const Self, writer: anytype) !void { + if (self.inner) |i| { + try writer.writeByte(1); + try i.write(writer); + } else { + try writer.writeByte(0); + } + } + + fn deinit(self: Self, alloc: std.mem.Allocator) void { + if (@hasDecl(T, "deinit")) { + if (self.inner) |i| { + i.deinit(alloc); + } + } + } + }; +} + +fn EnumIntRoundUp(comptime T: type) type { + const int_info = @typeInfo(@typeInfo(T).Enum.tag_type).Int; + return std.meta.Int(int_info.signedness, std.mem.alignForward(u16, int_info.bits, 8)); +} + +fn MessageFunctions(comptime Self: type) type { + return struct { + const Tag = std.meta.Tag(Self); + + pub fn read(reader: anytype, alloc: std.mem.Allocator) !Self { + switch (try std.meta.intToEnum( + Tag, + try reader.readInt(EnumIntRoundUp(Tag), native_endian), + )) { + inline else => |t| { + const Field = std.meta.FieldType(Self, t); + return @unionInit(Self, @tagName(t), try Field.read(reader, alloc)); + }, + } + } + + pub fn write(self: *const Self, writer: anytype) !void { + try writer.writeInt(EnumIntRoundUp(Tag), @intFromEnum(self.*), native_endian); + switch (self.*) { + inline else => |*t| try t.write(writer), + } + } + + pub fn deinit(self: Self, alloc: std.mem.Allocator) void { + switch (self) { + inline else => |d| if (@hasDecl(@TypeOf(d), "deinit")) d.deinit(alloc), + } + } + }; +}