const coral = @import("./coral.zig");

const std = @import("std");

/// Returns an unbalanced binary search tree type mapping `Key` to `Value`, ordered by `traits.compare`.
///
/// Nodes released by `clear` are kept on an internal free list and reused by later inserts; `deinit` frees both
/// the live tree and the free list. Insertion performs no rebalancing, so lookups cost O(height), which degrades
/// to O(n) when keys are inserted in sorted order.
pub fn Binary(comptime Key: type, comptime Value: type, comptime traits: Traits(Key)) type {
    return struct {
        /// Recycled nodes awaiting reuse, linked through their `has_parent` field.
        free_nodes: ?*Node = null,
        /// Root of the tree, or `null` when the tree is empty.
        active_nodes: ?*Node = null,

        const Node = struct {
            key: Key,
            value: Value,
            has_left: ?*Node = null,
            has_right: ?*Node = null,
            /// Parent in the live tree (`null` for the root); doubles as the next-link while on the free list.
            has_parent: ?*Node = null,

            /// Returns the node holding the smallest key in the subtree rooted at `self`.
            fn leftmost(self: *Node) *Node {
                var node = self;

                while (node.has_left) |left| {
                    node = left;
                }

                return node;
            }

            /// Returns the in-order successor of `self`, or `null` when `self` holds the largest key in the tree.
            fn successor(self: *Node) ?*Node {
                if (self.has_right) |right| {
                    return right.leftmost();
                }

                // No right subtree: climb until we step up out of a left subtree; that parent comes next.
                var child = self;

                while (child.has_parent) |parent| {
                    if (parent.has_left) |left| {
                        if (left == child) {
                            return parent;
                        }
                    }

                    child = parent;
                }

                return null;
            }
        };

        /// In-order (ascending by `traits.compare`) iterator over the tree's entries.
        ///
        /// The walk follows parent links and never writes to the tree, so the iterator may be abandoned at any
        /// point and lookups stay valid while it is alive. Clearing or deinitializing the tree invalidates it.
        pub const KeyValues = struct {
            /// Before the first `next` call: the tree root to iterate. Afterwards: the next node to yield.
            nodes: ?*Node,
            /// Whether `nodes` has already been advanced to the leftmost (first) entry.
            started: bool = false,

            /// Returns the next entry, or `null` once every entry has been visited.
            pub fn next(self: *KeyValues) ?coral.KeyValuePair(Key, *Value) {
                if (!self.started) {
                    self.started = true;

                    if (self.nodes) |root| {
                        self.nodes = root.leftmost();
                    }
                }

                const node = self.nodes orelse return null;

                self.nodes = node.successor();

                return .{
                    .key = node.key,
                    .value = &node.value,
                };
            }

            /// Returns the key of the next entry, or `null` when exhausted.
            pub fn nextKey(self: *KeyValues) ?Key {
                return if (self.next()) |key_value| key_value.key else null;
            }

            /// Returns a pointer to the value of the next entry, or `null` when exhausted.
            pub fn nextValue(self: *KeyValues) ?*Value {
                return if (self.next()) |key_value| key_value.value else null;
            }
        };

        const Self = @This();

        /// Removes every entry, moving all nodes onto the free list for reuse by later inserts.
        ///
        /// Traversal uses an explicit stack threaded through `has_parent` rather than recursion, so deep trees
        /// cannot overflow the call stack.
        pub fn clear(self: *Self) void {
            var pending: ?*Node = self.active_nodes;

            self.active_nodes = null;

            if (pending) |root| {
                root.has_parent = null;
            }

            while (pending) |node| {
                // Pop `node` before its `has_parent` is reused as the free-list link below.
                pending = node.has_parent;

                if (node.has_left) |left| {
                    left.has_parent = pending;
                    pending = left;
                }

                if (node.has_right) |right| {
                    right.has_parent = pending;
                    pending = right;
                }

                // Link through `has_parent`, the same field `createNode` pops from.
                node.* = .{
                    .key = undefined,
                    .value = undefined,
                    .has_parent = self.free_nodes,
                };

                self.free_nodes = node;
            }
        }

        /// Returns a node initialized to `node`, reusing one from the free list when available and otherwise
        /// allocating a new one with `allocator`.
        pub fn createNode(self: *Self, allocator: std.mem.Allocator, node: Node) std.mem.Allocator.Error!*Node {
            if (self.free_nodes) |free_node| {
                self.free_nodes = free_node.has_parent;
                free_node.* = node;

                return free_node;
            }

            const created_node = try allocator.create(Node);

            created_node.* = node;

            return created_node;
        }

        /// Frees every node, whether live or on the free list, with `allocator` (the allocator used by `insert`).
        /// The tree must not be used afterwards.
        pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
            self.clear();

            while (self.free_nodes) |node| {
                self.free_nodes = node.has_parent;
                allocator.destroy(node);
            }

            self.* = undefined;
        }

        /// Inserts `key` with `value` unless an equal key is already present.
        ///
        /// Returns a pointer to the newly stored value, or `null` when an equal key already exists; in that case
        /// the existing entry is left untouched and `value` is not stored.
        pub fn insert(self: *Self, allocator: std.mem.Allocator, key: Key, value: Value) std.mem.Allocator.Error!?*Value {
            var node = self.active_nodes orelse {
                self.active_nodes = try self.createNode(allocator, .{
                    .key = key,
                    .value = value,
                });

                return &self.active_nodes.?.value;
            };

            while (true) {
                switch (traits.compare(key, node.key)) {
                    .equal => {
                        return null;
                    },

                    .lesser => {
                        node = node.has_left orelse {
                            node.has_left = try self.createNode(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_left.?.value;
                        };
                    },

                    .greater => {
                        node = node.has_right orelse {
                            node.has_right = try self.createNode(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_right.?.value;
                        };
                    },
                }
            }
        }

        /// A tree with no entries and an empty free list.
        pub const empty = Self{
            .active_nodes = null,
        };

        /// Returns the node whose key compares equal to `key`, or `null` if there is none.
        fn find(self: Self, key: Key) ?*Node {
            var nodes = self.active_nodes;

            while (nodes) |node| {
                nodes = switch (traits.compare(key, node.key)) {
                    .lesser => node.has_left,
                    .greater => node.has_right,
                    .equal => return node,
                };
            }

            return null;
        }

        /// Returns a pointer to the value stored under `key`, or `null` if `key` is absent.
        pub fn get(self: Self, key: Key) ?*Value {
            const node = self.find(key) orelse return null;

            return &node.value;
        }

        /// Returns the stored key that compares equal to `key`, or `null` if there is none.
        pub fn getKey(self: Self, key: Key) ?Key {
            const node = self.find(key) orelse return null;

            return node.key;
        }

        /// Returns whether the tree holds no entries.
        pub fn isEmpty(self: Self) bool {
            return self.active_nodes == null;
        }

        /// Returns an iterator over the entries in ascending key order.
        pub fn keyValues(self: *const Self) KeyValues {
            return .{ .nodes = self.active_nodes };
        }
    };
}

/// Three-way ordering result produced by `Traits.compare`.
pub const Comparison = enum(i2) {
    lesser = -1,
    equal = 0,
    greater = 1,
};

/// Ordering behavior a `Binary` tree uses for its keys.
pub fn Traits(comptime Key: type) type {
    return struct {
        /// Returns how the first key orders relative to the second.
        compare: fn (Key, Key) Comparison,
    };
}

/// Returns traits ordering enums by tag value, pointers by address, and integers numerically.
/// Any other type is a compile error.
pub fn scalarTraits(comptime Scalar: type) Traits(Scalar) {
    const traits = switch (@typeInfo(Scalar)) {
        .@"enum" => struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                const a_int = @intFromEnum(a);
                const b_int = @intFromEnum(b);

                if (a_int < b_int) {
                    return .lesser;
                }

                if (a_int > b_int) {
                    return .greater;
                }

                return .equal;
            }
        },

        .pointer => struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                const a_int = @intFromPtr(a);
                const b_int = @intFromPtr(b);

                if (a_int < b_int) {
                    return .lesser;
                }

                if (a_int > b_int) {
                    return .greater;
                }

                return .equal;
            }
        },

        .int => struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                if (a < b) {
                    return .lesser;
                }

                if (a > b) {
                    return .greater;
                }

                return .equal;
            }
        },

        else => {
            @compileError(std.fmt.comptimePrint("parameter `Scalar` must be a scalar type, not {s}", .{
                @typeName(Scalar),
            }));
        },
    };

    return .{
        .compare = traits.compare,
    };
}

/// Returns traits ordering slices lexicographically by element, as defined by `std.mem.order`.
pub fn sliceTraits(comptime Slice: type) Traits(Slice) {
    const slice_pointer = switch (@typeInfo(Slice)) {
        .pointer => |pointer| pointer,

        else => {
            @compileError(std.fmt.comptimePrint("parameter `Slice` must be a slice type, not {s}", .{
                @typeName(Slice),
            }));
        },
    };

    const traits = struct {
        fn compare(a: Slice, b: Slice) Comparison {
            return switch (std.mem.order(slice_pointer.child, a, b)) {
                .lt => .lesser,
                .gt => .greater,
                .eq => .equal,
            };
        }
    };

    return .{
        .compare = traits.compare,
    };
}