const coral = @import("./coral.zig");

const std = @import("std");

/// Returns an unbalanced binary search tree type that maps unique `Key`s to `Value`s, ordered by
/// `traits.compare`.
///
/// Nodes released by `remove` and `clear` are parked on an internal free list and reused by later
/// insertions; only `deinit` hands memory back to the allocator. The tree is never rebalanced, so
/// inserting keys in sorted order produces a list-shaped tree with linear-time lookups.
pub fn Binary(comptime Key: type, comptime Value: type, comptime traits: Traits(Key)) type {
    return struct {
        free_nodes: Pool = .{},
        has_root: ?*Node = null,

        const Self = @This();

        /// A tree with no entries and no pooled nodes.
        pub const empty = Self{
            .has_root = null,
        };

        const Node = struct {
            key: Key,
            value: Value,
            has_lesser: ?*Node = null,
            has_greater: ?*Node = null,
            has_parent: ?*Node = null,

            /// Returns the node in the subtree rooted at `self` whose key compares `.equal` to
            /// `key`, or `null` if there is none.
            fn find(self: *Node, key: Key) ?*Node {
                var node = self;

                while (true) {
                    const has_next = switch (traits.compare(key, node.key)) {
                        .lesser => node.has_lesser,
                        .greater => node.has_greater,
                        .equal => return node,
                    };

                    node = has_next orelse return null;
                }
            }

            /// Returns the node with the greatest key in the subtree rooted at `self`.
            fn getMax(self: *Node) *Node {
                var node = self;

                while (node.has_greater) |greater| {
                    node = greater;
                }

                return node;
            }

            /// Returns the node with the least key in the subtree rooted at `self`.
            fn getMin(self: *Node) *Node {
                var node = self;

                while (node.has_lesser) |lesser| {
                    node = lesser;
                }

                return node;
            }

            /// Returns the node that follows `self` in ascending key order, or `null` if `self`
            /// holds the greatest key in the tree.
            fn getSuccessor(self: *Node) ?*Node {
                if (self.has_greater) |greater| {
                    return greater.getMin();
                }

                // With no greater subtree, the successor is the first ancestor that `self` sits
                // beneath on its lesser side.
                var child = self;

                while (child.has_parent) |parent| {
                    if (parent.has_lesser == child) {
                        return parent;
                    }

                    child = parent;
                }

                return null;
            }
        };

        /// Iterator over the entries of a tree in ascending key order.
        ///
        /// Traversal follows parent links and never writes to the tree, so it may be abandoned at
        /// any point. The tree must not be modified while an iterator over it is in use.
        pub const KeyValues = struct {
            /// Node holding the next entry to yield, or `null` once iteration is finished.
            nodes: ?*Node,

            /// Returns the next entry (with `value` pointing into the tree), or `null` when done.
            pub fn next(self: *KeyValues) ?coral.KeyValuePair(Key, *Value) {
                const node = self.nodes orelse return null;

                self.nodes = node.getSuccessor();

                return .{
                    .key = node.key,
                    .value = &node.value,
                };
            }

            /// Returns the key of the next entry, or `null` when done.
            pub fn nextKey(self: *KeyValues) ?Key {
                const key_value = self.next() orelse return null;

                return key_value.key;
            }

            /// Returns a pointer to the value of the next entry, or `null` when done.
            pub fn nextValue(self: *KeyValues) ?*Value {
                const key_value = self.next() orelse return null;

                return key_value.value;
            }
        };

        /// Intrusive free list of recycled nodes, linked through each node's `has_parent` field.
        const Pool = struct {
            has_node: ?*Node = null,

            /// Returns a node initialized to `node`, reusing a pooled node before allocating.
            fn create(self: *Pool, allocator: std.mem.Allocator, node: Node) std.mem.Allocator.Error!*Node {
                if (self.has_node) |free_node| {
                    self.has_node = free_node.has_parent;
                    free_node.* = node;

                    return free_node;
                }

                const created_node = try allocator.create(Node);

                created_node.* = node;

                return created_node;
            }

            /// Pushes `node` onto the free list. Its key and value become undefined.
            fn destroy(self: *Pool, node: *Node) void {
                node.* = .{
                    .key = undefined,
                    .value = undefined,
                    .has_parent = self.has_node,
                    .has_lesser = null,
                    .has_greater = null,
                };

                self.has_node = node;
            }

            /// Returns every pooled node to `allocator`, leaving the pool empty.
            fn deinit(self: *Pool, allocator: std.mem.Allocator) void {
                while (self.has_node) |node| {
                    self.has_node = node.has_parent;
                    allocator.destroy(node);
                }
            }
        };

        /// Removes every entry, moving all nodes onto the internal free list for reuse by later
        /// insertions. No memory is returned to the allocator.
        pub fn clear(self: *Self) void {
            // Depth-first walk that uses `has_parent` as an intrusive stack link. Parent links are
            // no longer needed once a node is scheduled for recycling, so nothing is allocated.
            var has_pending = self.has_root;

            if (has_pending) |root| {
                root.has_parent = null;
            }

            self.has_root = null;

            while (has_pending) |node| {
                has_pending = node.has_parent;

                if (node.has_lesser) |lesser| {
                    lesser.has_parent = has_pending;
                    has_pending = lesser;
                }

                if (node.has_greater) |greater| {
                    greater.has_parent = has_pending;
                    has_pending = greater;
                }

                self.free_nodes.destroy(node);
            }
        }

        /// Releases every node, both live entries and pooled ones, back to `allocator`.
        ///
        /// `allocator` must be the allocator passed to `insert`. The tree is undefined afterwards.
        pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
            // Move live nodes into the pool iteratively (no recursion, so degenerate trees cannot
            // exhaust the stack), then free the whole pool in one pass.
            self.clear();
            self.free_nodes.deinit(allocator);

            self.* = undefined;
        }

        /// Inserts `key` mapped to `value` and returns a pointer to the stored value.
        ///
        /// Returns `null` without modifying the tree if an entry with an equal key already exists.
        /// Fails with `error.OutOfMemory` only when no pooled node is available and allocating a
        /// new one fails.
        pub fn insert(self: *Self, allocator: std.mem.Allocator, key: Key, value: Value) std.mem.Allocator.Error!?*Value {
            var node = self.has_root orelse {
                const root = try self.free_nodes.create(allocator, .{
                    .key = key,
                    .value = value,
                });

                self.has_root = root;

                return &root.value;
            };

            while (true) {
                switch (traits.compare(key, node.key)) {
                    .equal => return null,

                    .lesser => {
                        node = node.has_lesser orelse {
                            const lesser = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            node.has_lesser = lesser;

                            return &lesser.value;
                        };
                    },

                    .greater => {
                        node = node.has_greater orelse {
                            const greater = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            node.has_greater = greater;

                            return &greater.value;
                        };
                    },
                }
            }
        }

        /// Returns a pointer to the value stored under `key`, or `null` if absent.
        ///
        /// The pointer stays valid until the entry is removed or the tree is cleared or deinited.
        pub fn get(self: Self, key: Key) ?*Value {
            const root = self.has_root orelse return null;
            const node = root.find(key) orelse return null;

            return &node.value;
        }

        /// Returns the stored key that compares `.equal` to `key`, or `null` if absent.
        ///
        /// The result is the key held by the tree, which may be a different value from `key` when
        /// `traits.compare` treats distinct values as equal.
        pub fn getKey(self: Self, key: Key) ?Key {
            const root = self.has_root orelse return null;
            const node = root.find(key) orelse return null;

            return node.key;
        }

        /// Returns `true` if the tree holds no entries.
        pub fn isEmpty(self: *const Self) bool {
            return self.has_root == null;
        }

        /// Returns an iterator over all entries in ascending key order.
        pub fn keyValues(self: *const Self) KeyValues {
            const root = self.has_root orelse return .{ .nodes = null };

            return .{ .nodes = root.getMin() };
        }

        /// Removes the entry whose key compares `.equal` to `key` and returns its key and value,
        /// or returns `null` if there is no such entry. The node is recycled into the pool.
        pub fn remove(self: *Self, key: Key) ?coral.KeyValuePair(Key, Value) {
            const root = self.has_root orelse return null;
            const node = root.find(key) orelse return null;

            if (node.has_lesser == null) {
                self.transplant(node, node.has_greater);
            } else if (node.has_greater == null) {
                self.transplant(node, node.has_lesser);
            } else {
                // Two children: splice in the in-order successor (minimum of the greater subtree).
                const greater = node.has_greater.?;
                const successor = greater.getMin();

                if (successor != greater) {
                    // Lift the successor out of its position first, then adopt `node`'s greater
                    // subtree.
                    self.transplant(successor, successor.has_greater);
                    successor.has_greater = greater;
                    greater.has_parent = successor;
                }

                self.transplant(node, successor);

                const lesser = node.has_lesser.?;

                successor.has_lesser = lesser;
                lesser.has_parent = successor;
            }

            // The return value is built before the deferred recycling clears the node.
            defer self.free_nodes.destroy(node);

            return .{
                .key = node.key,
                .value = node.value,
            };
        }

        /// Replaces the subtree rooted at `old` with the one rooted at `has_replacement` in `old`'s
        /// parent, or in the tree root when `old` has no parent. `old`'s own links are untouched.
        fn transplant(self: *Self, old: *Node, has_replacement: ?*Node) void {
            if (old.has_parent) |parent| {
                if (parent.has_lesser == old) {
                    parent.has_lesser = has_replacement;
                } else {
                    parent.has_greater = has_replacement;
                }
            } else {
                self.has_root = has_replacement;
            }

            if (has_replacement) |replacement| {
                replacement.has_parent = old.has_parent;
            }
        }
    };
}

/// Result of ordering `a` relative to `b`.
pub const Comparison = enum(i2) {
    lesser = -1,
    equal = 0,
    greater = 1,

    /// Converts a standard-library ordering into the equivalent `Comparison`.
    pub fn fromOrder(order: std.math.Order) Comparison {
        return switch (order) {
            .lt => .lesser,
            .eq => .equal,
            .gt => .greater,
        };
    }
};

/// Ordering capabilities a key type must provide to be stored in a `Binary` tree.
///
/// `compare(a, b)` returns `.lesser` when `a` orders before `b`, `.greater` when after, and
/// `.equal` when both denote the same key.
pub fn Traits(comptime Key: type) type {
    return struct {
        compare: fn (Key, Key) Comparison,
    };
}

/// Returns traits that order integers numerically, enums by their integer tag value, and
/// non-slice pointers by address. Any other type (including slices) is a compile error.
pub fn scalarTraits(comptime Scalar: type) Traits(Scalar) {
    const traits = switch (@typeInfo(Scalar)) {
        .@"enum" => struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                return Comparison.fromOrder(std.math.order(@intFromEnum(a), @intFromEnum(b)));
            }
        },

        // Slices are pointer-typed too, but `@intFromPtr` cannot take them; reject them here with
        // the same diagnostic as other non-scalar types.
        .pointer => |pointer| if (pointer.size == .slice) @compileError(std.fmt.comptimePrint("parameter `Scalar` must be a scalar type, not {s}", .{
            @typeName(Scalar),
        })) else struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                return Comparison.fromOrder(std.math.order(@intFromPtr(a), @intFromPtr(b)));
            }
        },

        .int => struct {
            fn compare(a: Scalar, b: Scalar) Comparison {
                return Comparison.fromOrder(std.math.order(a, b));
            }
        },

        else => @compileError(std.fmt.comptimePrint("parameter `Scalar` must be a scalar type, not {s}", .{
            @typeName(Scalar),
        })),
    };

    return .{
        .compare = traits.compare,
    };
}

/// Returns traits that order slices lexicographically by element (as `std.mem.order` does).
/// `Slice` must be a slice type; anything else is a compile error.
pub fn sliceTraits(comptime Slice: type) Traits(Slice) {
    const Element = switch (@typeInfo(Slice)) {
        .pointer => |pointer| if (pointer.size == .slice) pointer.child else @compileError(std.fmt.comptimePrint("parameter `Slice` must be a slice type, not {s}", .{
            @typeName(Slice),
        })),

        else => @compileError(std.fmt.comptimePrint("parameter `Slice` must be a slice type, not {s}", .{
            @typeName(Slice),
        })),
    };

    const traits = struct {
        fn compare(a: Slice, b: Slice) Comparison {
            return Comparison.fromOrder(std.mem.order(Element, a, b));
        }
    };

    return .{
        .compare = traits.compare,
    };
}